fix: stops stream upon conversation

This commit is contained in:
Daniel Avila 2023-03-11 21:42:08 -05:00
parent 79f050bac7
commit 23de688bf3
6 changed files with 218 additions and 120 deletions

View file

@ -2,7 +2,7 @@ const { ModelOperations } = require('@vscode/vscode-languagedetection');
const languages = require('../utils/languages.js'); const languages = require('../utils/languages.js');
const codeRegex = /(```[\s\S]*?```)/g; const codeRegex = /(```[\s\S]*?```)/g;
// const languageMatch = /```(\w+)/; // const languageMatch = /```(\w+)/;
const replaceRegex = /```\w+/g; const replaceRegex = /```\w+\n/g;
const detectCode = async (input) => { const detectCode = async (input) => {
try { try {
@ -22,7 +22,7 @@ const detectCode = async (input) => {
} }
console.log('[detectCode.js] replacing', match, 'with', '```shell'); console.log('[detectCode.js] replacing', match, 'with', '```shell');
text = text.replace(match, '```shell'); text = text.replace(match, '```shell\n');
}); });
return text; return text;

View file

@ -0,0 +1,11 @@
module.exports = (req, res, next) => {
let { stopStream } = req.body;
if (stopStream) {
console.log('stopStream');
res.write('event: stop\ndata:\n\n');
res.end();
return;
} else {
next();
}
};

View file

@ -3,7 +3,7 @@ import RenameButton from './RenameButton';
import DeleteButton from './DeleteButton'; import DeleteButton from './DeleteButton';
import { useSelector, useDispatch } from 'react-redux'; import { useSelector, useDispatch } from 'react-redux';
import { setConversation } from '~/store/convoSlice'; import { setConversation } from '~/store/convoSlice';
import { setStopStream, setCustomGpt, setModel, setCustomModel } from '~/store/submitSlice'; import { setSubmitState, setSubmission, setStopStream, setCustomGpt, setModel, setCustomModel } from '~/store/submitSlice';
import { setMessages, setEmptyMessage } from '~/store/messageSlice'; import { setMessages, setEmptyMessage } from '~/store/messageSlice';
import { setText } from '~/store/textSlice'; import { setText } from '~/store/textSlice';
import manualSWR from '~/utils/fetchers'; import manualSWR from '~/utils/fetchers';
@ -21,6 +21,7 @@ export default function Conversation({
const [renaming, setRenaming] = useState(false); const [renaming, setRenaming] = useState(false);
const [titleInput, setTitleInput] = useState(title); const [titleInput, setTitleInput] = useState(title);
const { modelMap } = useSelector((state) => state.models); const { modelMap } = useSelector((state) => state.models);
const { stopStream } = useSelector((state) => state.submit);
const inputRef = useRef(null); const inputRef = useRef(null);
const dispatch = useDispatch(); const dispatch = useDispatch();
const { trigger, isMutating } = manualSWR(`/api/messages/${id}`, 'get'); const { trigger, isMutating } = manualSWR(`/api/messages/${id}`, 'get');
@ -31,7 +32,11 @@ export default function Conversation({
return; return;
} }
if (!stopStream) {
dispatch(setStopStream(true)); dispatch(setStopStream(true));
dispatch(setSubmission({}));
dispatch(setSubmitState(false));
}
dispatch(setEmptyMessage()); dispatch(setEmptyMessage());
const convo = { title, error: false, conversationId: id, chatGptLabel, promptPrefix }; const convo = { title, error: false, conversationId: id, chatGptLabel, promptPrefix };
@ -67,12 +72,6 @@ export default function Conversation({
); );
} }
const data = await trigger(); const data = await trigger();
// while (isMutating) {
// await new Promise((resolve) => setTimeout(() => {
// dispatch(setMessages([]));
// resolve();
// }, 50));
// }
if (chatGptLabel) { if (chatGptLabel) {
dispatch(setModel('chatgptCustom')); dispatch(setModel('chatgptCustom'));

View file

@ -1,4 +1,5 @@
import React, { useState } from 'react'; import React, { useEffect, useState } from 'react';
import { SSE } from '~/utils/sse';
import SubmitButton from './SubmitButton'; import SubmitButton from './SubmitButton';
import Regenerate from './Regenerate'; import Regenerate from './Regenerate';
import ModelMenu from '../Models/ModelMenu'; import ModelMenu from '../Models/ModelMenu';
@ -8,7 +9,7 @@ import handleSubmit from '~/utils/handleSubmit';
import { useSelector, useDispatch } from 'react-redux'; import { useSelector, useDispatch } from 'react-redux';
import { setConversation, setError } from '~/store/convoSlice'; import { setConversation, setError } from '~/store/convoSlice';
import { setMessages } from '~/store/messageSlice'; import { setMessages } from '~/store/messageSlice';
import { setSubmitState } from '~/store/submitSlice'; import { setSubmitState, setSubmission } from '~/store/submitSlice';
import { setText } from '~/store/textSlice'; import { setText } from '~/store/textSlice';
export default function TextChat({ messages }) { export default function TextChat({ messages }) {
@ -16,48 +17,26 @@ export default function TextChat({ messages }) {
const dispatch = useDispatch(); const dispatch = useDispatch();
const convo = useSelector((state) => state.convo); const convo = useSelector((state) => state.convo);
const { initial } = useSelector((state) => state.models); const { initial } = useSelector((state) => state.models);
const { isSubmitting, stopStream, disabled, model, chatGptLabel, promptPrefix } = useSelector( const { isSubmitting, stopStream, submission, disabled, model, chatGptLabel, promptPrefix } =
(state) => state.submit useSelector((state) => state.submit);
);
const { text } = useSelector((state) => state.text); const { text } = useSelector((state) => state.text);
const { error } = convo; const { error } = convo;
const isCustomModel = model === 'chatgptCustom' || !initial[model];
const submitMessage = () => { const messageHandler = (data, currentState) => {
if (error) { const { messages, currentMsg, sender } = currentState;
dispatch(setError(false));
}
if (!!isSubmitting || text.trim() === '') {
return;
}
dispatch(setSubmitState(true));
const message = text.trim();
const currentMsg = { sender: 'User', text: message, current: true };
const sender = model === 'chatgptCustom' ? chatGptLabel : model;
const initialResponse = { sender, text: '' };
dispatch(setMessages([...messages, currentMsg, initialResponse]));
dispatch(setText(''));
const messageHandler = (data, events) => {
if (stopStream) {
console.log('Stopping stream');
events.close();
return;
}
dispatch(setMessages([...messages, currentMsg, { sender, text: data }])); dispatch(setMessages([...messages, currentMsg, { sender, text: data }]));
}; };
const convoHandler = (data) => {
const convoHandler = (data, currentState) => {
const { messages, currentMsg, sender, isCustomModel, model, chatGptLabel, promptPrefix } =
currentState;
dispatch( dispatch(
setMessages([...messages, currentMsg, { sender, text: data.text || data.response }]) setMessages([...messages, currentMsg, { sender, text: data.text || data.response }])
); );
const isBing = model === 'bingai' || model === 'sydney'; const isBing = model === 'bingai' || model === 'sydney';
if ( if (!isBing && convo.conversationId === null && convo.parentMessageId === null) {
!isBing &&
convo.conversationId === null &&
convo.parentMessageId === null
) {
const { title, conversationId, id } = data; const { title, conversationId, id } = data;
dispatch( dispatch(
setConversation({ setConversation({
@ -77,14 +56,8 @@ export default function TextChat({ messages }) {
convo.conversationId === null && convo.conversationId === null &&
convo.invocationId === null convo.invocationId === null
) { ) {
console.log('Bing data:', data) console.log('Bing data:', data);
const { const { title, conversationSignature, clientId, conversationId, invocationId } = data;
title,
conversationSignature,
clientId,
conversationId,
invocationId
} = data;
dispatch( dispatch(
setConversation({ setConversation({
title, title,
@ -92,7 +65,7 @@ export default function TextChat({ messages }) {
conversationSignature, conversationSignature,
clientId, clientId,
conversationId, conversationId,
invocationId, invocationId
}) })
); );
} else if (model === 'sydney') { } else if (model === 'sydney') {
@ -113,7 +86,7 @@ export default function TextChat({ messages }) {
conversationSignature, conversationSignature,
clientId, clientId,
conversationId, conversationId,
invocationId, invocationId
}) })
); );
} }
@ -121,7 +94,8 @@ export default function TextChat({ messages }) {
dispatch(setSubmitState(false)); dispatch(setSubmitState(false));
}; };
const errorHandler = (event) => { const errorHandler = (event, currentState) => {
const { initialResponse, messages, currentMsg, message } = currentState;
console.log('Error:', event); console.log('Error:', event);
const errorResponse = { const errorResponse = {
...initialResponse, ...initialResponse,
@ -135,20 +109,124 @@ export default function TextChat({ messages }) {
dispatch(setError(true)); dispatch(setError(true));
return; return;
}; };
const submitMessage = () => {
if (error) {
dispatch(setError(false));
}
if (!!isSubmitting || text.trim() === '') {
return;
}
const isCustomModel = model === 'chatgptCustom' || !initial[model];
const message = text.trim();
const currentMsg = { sender: 'User', text: message, current: true };
const sender = model === 'chatgptCustom' ? chatGptLabel : model;
const initialResponse = { sender, text: '' };
dispatch(setSubmitState(true));
dispatch(setMessages([...messages, currentMsg, initialResponse]));
dispatch(setText(''));
const submission = { const submission = {
model, model,
text: message, text: message,
convo, convo,
messageHandler,
convoHandler,
errorHandler,
chatGptLabel, chatGptLabel,
promptPrefix promptPrefix,
isCustomModel,
message,
messages,
currentMsg,
sender,
initialResponse
}; };
console.log('User Input:', message); console.log('User Input:', message);
handleSubmit(submission); // handleSubmit(submission);
dispatch(setSubmission(submission));
}; };
const createPayload = ({ model, text, convo, chatGptLabel, promptPrefix }) => {
const endpoint = `/api/ask`;
let payload = { model, text, chatGptLabel, promptPrefix };
if (convo.conversationId && convo.parentMessageId) {
payload = {
...payload,
conversationId: convo.conversationId,
parentMessageId: convo.parentMessageId
};
}
const isBing = model === 'bingai' || model === 'sydney';
if (isBing && convo.conversationId) {
payload = {
...payload,
jailbreakConversationId: convo.jailbreakConversationId,
conversationId: convo.conversationId,
conversationSignature: convo.conversationSignature,
clientId: convo.clientId,
invocationId: convo.invocationId
};
}
let server = endpoint;
server = model === 'bingai' ? server + '/bing' : server;
server = model === 'sydney' ? server + '/sydney' : server;
return { server, payload };
};
useEffect(() => {
if (Object.keys(submission).length === 0) {
return;
}
const currentState = submission;
const { server, payload } = createPayload(submission);
const onMessage = (e) => {
if (stopStream) {
return;
}
const data = JSON.parse(e.data);
let text = data.text || data.response;
if (data.message) {
messageHandler(text, currentState);
}
if (data.final) {
convoHandler(data, currentState);
console.log('final', data);
} else {
// console.log('dataStream', data);
}
};
const events = new SSE(server, {
payload: JSON.stringify(payload),
headers: { 'Content-Type': 'application/json' }
});
events.onopen = function () {
console.log('connection is opened');
};
events.onmessage = onMessage;
events.onerror = function (e) {
console.log('error in opening conn.');
events.close();
errorHandler(e, currentState);
};
events.stream();
return () => {
events.removeEventListener('message', onMessage);
events.close();
};
}, [submission]);
const handleKeyDown = (e) => { const handleKeyDown = (e) => {
if (e.key === 'Enter' && !e.shiftKey) { if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault(); e.preventDefault();

View file

@ -2,12 +2,13 @@ import { createSlice } from '@reduxjs/toolkit';
const initialState = { const initialState = {
isSubmitting: false, isSubmitting: false,
submission: {},
stopStream: false, stopStream: false,
disabled: false, disabled: false,
model: 'chatgpt', model: 'chatgpt',
promptPrefix: '', promptPrefix: '',
chatGptLabel: '', chatGptLabel: '',
customModel: null customModel: null,
}; };
const currentSlice = createSlice({ const currentSlice = createSlice({
@ -17,6 +18,9 @@ const currentSlice = createSlice({
setSubmitState: (state, action) => { setSubmitState: (state, action) => {
state.isSubmitting = action.payload; state.isSubmitting = action.payload;
}, },
setSubmission: (state, action) => {
state.submission = action.payload;
},
setStopStream: (state, action) => { setStopStream: (state, action) => {
state.stopStream = action.payload; state.stopStream = action.payload;
}, },
@ -36,7 +40,7 @@ const currentSlice = createSlice({
} }
}); });
export const { setSubmitState, setStopStream, setDisabled, setModel, setCustomGpt, setCustomModel } = export const { setSubmitState, setSubmission, setStopStream, setDisabled, setModel, setCustomGpt, setCustomModel } =
currentSlice.actions; currentSlice.actions;
export default currentSlice.reducer; export default currentSlice.reducer;

View file

@ -68,5 +68,11 @@ export default function handleSubmit({
errorHandler(e); errorHandler(e);
}; };
events.addEventListener('stop', () => {
// Close the SSE stream
console.log('stop event received');
events.close();
});
events.stream(); events.stream();
} }