From 9f8e9cb091022a1e11bceb86ec0fd93e1fe58f26 Mon Sep 17 00:00:00 2001 From: Wentao Lyu <35-wentao.lyu@users.noreply.git.stereye.tech> Date: Mon, 13 Mar 2023 14:04:47 +0800 Subject: [PATCH] feat: gen title by sperate api call feat: fix: rename of convo should based on real request. --- api/server/routes/ask.js | 3 ++- api/server/routes/askBing.js | 5 ++-- api/server/routes/askSydney.js | 5 ++-- api/server/routes/convos.js | 26 ++++++++++++++++++ .../components/Conversations/Conversation.jsx | 11 +++++--- client/src/components/Main/TextChat.jsx | 27 ++++++++++--------- client/src/components/Nav/index.jsx | 5 ++-- client/src/store/convoSlice.js | 11 ++++---- 8 files changed, 65 insertions(+), 28 deletions(-) diff --git a/api/server/routes/ask.js b/api/server/routes/ask.js index 9be3fe7014..d21f8be263 100644 --- a/api/server/routes/ask.js +++ b/api/server/routes/ask.js @@ -10,7 +10,7 @@ const { customClient, detectCode } = require('../../app/'); -const { getConvo, saveMessage, deleteMessagesSince, deleteMessages, saveConvo } = require('../../models'); +const { getConvo, saveMessage, getConvoTitle, saveConvo } = require('../../models'); const { handleError, sendMessage } = require('./handlers'); router.use('/bing', askBing); @@ -156,6 +156,7 @@ router.post('/', async (req, res) => { await saveMessage(gptResponse); await saveConvo(gptResponse); sendMessage(res, { + title: await getConvoTitle(conversationId), final: true, requestMessage: userMessage, responseMessage: gptResponse diff --git a/api/server/routes/askBing.js b/api/server/routes/askBing.js index cffb669ba8..f2722bb248 100644 --- a/api/server/routes/askBing.js +++ b/api/server/routes/askBing.js @@ -2,7 +2,7 @@ const express = require('express'); const crypto = require('crypto'); const router = express.Router(); const { titleConvo, getCitations, citeText, askBing } = require('../../app/'); -const { saveMessage, deleteMessages, deleteMessagesSince, saveConvo } = require('../../models'); +const { saveMessage, getConvoTitle, saveConvo } = require('../../models'); const { handleError, sendMessage } = require('./handlers'); const citationRegex = /\[\^\d+?\^]/g; @@ -11,7 +11,7 @@ router.post('/', async (req, res) => { if (text.length === 0) { return handleError(res, 'Prompt empty or too short'); } - + const conversationId = oldConversationId || crypto.randomUUID(); const userMessageId = messageId; @@ -98,6 +98,7 @@ router.post('/', async (req, res) => { await saveMessage(response); await saveConvo(response); sendMessage(res, { + title: await getConvoTitle(conversationId), final: true, requestMessage: userMessage, responseMessage: gptResponse diff --git a/api/server/routes/askSydney.js b/api/server/routes/askSydney.js index a4f1b2b866..a90dd39555 100644 --- a/api/server/routes/askSydney.js +++ b/api/server/routes/askSydney.js @@ -2,7 +2,7 @@ const express = require('express'); const crypto = require('crypto'); const router = express.Router(); const { titleConvo, getCitations, citeText, askSydney } = require('../../app/'); -const { saveMessage, deleteMessages, saveConvo, deleteMessagesSince, getConvoTitle } = require('../../models'); +const { saveMessage, saveConvo, getConvoTitle } = require('../../models'); const { handleError, sendMessage } = require('./handlers'); const citationRegex = /\[\^\d+?\^]/g; @@ -11,7 +11,7 @@ router.post('/', async (req, res) => { if (text.length === 0) { return handleError(res, 'Prompt empty or too short'); } - + const conversationId = oldConversationId || crypto.randomUUID(); const userMessageId = messageId; @@ -108,6 +108,7 @@ router.post('/', async (req, res) => { await saveMessage(response); await saveConvo(response); sendMessage(res, { + title: await getConvoTitle(conversationId), final: true, requestMessage: userMessage, responseMessage: gptResponse diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 4b9320873f..a82242db4a 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,12 +1,38 @@ const express = require('express'); const router = express.Router(); +const { titleConvo } = require('../../app/'); +const { getConvo, saveConvo, getConvoTitle } = require('../../models'); const { getConvos, deleteConvos, updateConvo } = require('../../models/Conversation'); +const { getMessages } = require('../../models/Message'); router.get('/', async (req, res) => { const pageNumber = req.query.pageNumber || 1; res.status(200).send(await getConvos(pageNumber)); }); +router.post('/gen_title', async (req, res) => { + const { conversationId } = req.body.arg; + + const convo = await getConvo(conversationId) + const firstMessage = (await getMessages({ conversationId }))[0] + const secondMessage = (await getMessages({ conversationId }))[1] + + const title = convo.jailbreakConversationId + ? await getConvoTitle(conversationId) + : await titleConvo({ + model: convo?.model, + message: firstMessage?.text, + response: JSON.stringify(secondMessage?.text || '') + }); + + await saveConvo({ + conversationId, + title + }) + + res.status(200).send(title); +}); + router.post('/clear', async (req, res) => { let filter = {}; const { conversationId } = req.body.arg; diff --git a/client/src/components/Conversations/Conversation.jsx b/client/src/components/Conversations/Conversation.jsx index 30df587da9..4d90ecb2cc 100644 --- a/client/src/components/Conversations/Conversation.jsx +++ b/client/src/components/Conversations/Conversation.jsx @@ -8,13 +8,14 @@ import { setMessages, setEmptyMessage } from '~/store/messageSlice'; import { setText } from '~/store/textSlice'; import manualSWR from '~/utils/fetchers'; import ConvoIcon from '../svg/ConvoIcon'; +import { refreshConversation } from '../../store/convoSlice'; export default function Conversation({ id, model, parentMessageId, conversationId, - title = 'New conversation', + title, chatGptLabel = null, promptPrefix = null, bingData, @@ -95,6 +96,7 @@ export default function Conversation({ const renameHandler = (e) => { e.preventDefault(); + setTitleInput(title); setRenaming(true); setTimeout(() => { inputRef.current.focus(); @@ -112,7 +114,10 @@ export default function Conversation({ if (titleInput === title) { return; } - rename.trigger({ conversationId, title: titleInput }); + rename.trigger({ conversationId, title: titleInput }) + .then(() => { + dispatch(refreshConversation()) + }); }; const handleKeyDown = (e) => { @@ -149,7 +154,7 @@ export default function Conversation({ onKeyDown={handleKeyDown} /> ) : ( - titleInput + title )} {conversationId === id ? ( diff --git a/client/src/components/Main/TextChat.jsx b/client/src/components/Main/TextChat.jsx index 69f4b65898..b541c9a51f 100644 --- a/client/src/components/Main/TextChat.jsx +++ b/client/src/components/Main/TextChat.jsx @@ -1,5 +1,6 @@ import React, { useEffect, useRef, useState } from 'react'; import { SSE } from '~/utils/sse'; +import axios from 'axios'; import SubmitButton from './SubmitButton'; import Regenerate from './Regenerate'; import ModelMenu from '../Models/ModelMenu'; @@ -7,10 +8,11 @@ import Footer from './Footer'; import TextareaAutosize from 'react-textarea-autosize'; import handleSubmit from '~/utils/handleSubmit'; import { useSelector, useDispatch } from 'react-redux'; -import { setConversation, setError } from '~/store/convoSlice'; +import { setConversation, setError, refreshConversation } from '~/store/convoSlice'; import { setMessages } from '~/store/messageSlice'; import { setSubmitState, setSubmission } from '~/store/submitSlice'; import { setText } from '~/store/textSlice'; +import manualSWR from '~/utils/fetchers'; export default function TextChat({ messages }) { const [errorMessage, setErrorMessage] = useState(''); @@ -22,6 +24,7 @@ export default function TextChat({ messages }) { useSelector((state) => state.submit); const { text } = useSelector((state) => state.text); const { error } = convo; + const genTitle = manualSWR(`/api/convos/gen_title`, 'post'); // auto focus to input, when enter a conversation. useEffect(() => { @@ -45,26 +48,24 @@ export default function TextChat({ messages }) { const convoHandler = (data, currentState, currentMsg) => { const { requestMessage, responseMessage } = data; + const { conversationId } = currentMsg; const { messages, _currentMsg, message, isCustomModel, sender } = currentState; const { model, chatGptLabel, promptPrefix } = message; dispatch( - setMessages([...messages, - { - ...requestMessage, - // messageId: data?.parentMessageId, - }, - { - ...responseMessage, - // sender, - // text: data.text || data.response, - } - ]) + setMessages([...messages, requestMessage, responseMessage,]) ); const isBing = model === 'bingai' || model === 'sydney'; - // if (!message.messageId) + if (requestMessage.parentMessageId == '00000000-0000-0000-0000-000000000000') { + genTitle.trigger({ conversationId }).then((ret) => { + const title = ret?.data + + if (title) + dispatch(refreshConversation()); + }) + } if (!isBing && convo.conversationId === null && convo.parentMessageId === null) { const { title } = data; diff --git a/client/src/components/Nav/index.jsx b/client/src/components/Nav/index.jsx index 7d548429d9..3029686cf8 100644 --- a/client/src/components/Nav/index.jsx +++ b/client/src/components/Nav/index.jsx @@ -11,7 +11,7 @@ import { incrementPage, setConvos } from '~/store/convoSlice'; export default function Nav({ navVisible, setNavVisible }) { const dispatch = useDispatch(); const [isHovering, setIsHovering] = useState(false); - const { conversationId, convos, pageNumber } = useSelector((state) => state.convo); + const { conversationId, convos, pageNumber, refreshConvoHint } = useSelector((state) => state.convo); const onSuccess = (data) => { dispatch(setConvos(data)); }; @@ -20,6 +20,7 @@ export default function Nav({ navVisible, setNavVisible }) { `/api/convos?pageNumber=${pageNumber}`, onSuccess ); + const containerRef = useRef(null); const scrollPositionRef = useRef(null); @@ -35,7 +36,7 @@ export default function Nav({ navVisible, setNavVisible }) { } }; - useDidMountEffect(() => mutate(), [conversationId]); + useDidMountEffect(() => mutate(), [conversationId, refreshConvoHint]); useEffect(() => { const container = containerRef.current; diff --git a/client/src/store/convoSlice.js b/client/src/store/convoSlice.js index 4e55fa2b94..258de84e70 100644 --- a/client/src/store/convoSlice.js +++ b/client/src/store/convoSlice.js @@ -13,6 +13,7 @@ const initialState = { promptPrefix: null, convosLoading: false, pageNumber: 1, + refreshConvoHint: 0, convos: [] }; @@ -20,6 +21,9 @@ const currentSlice = createSlice({ name: 'convo', initialState, reducers: { + refreshConversation: (state, action) => { + state.refreshConvoHint = state.refreshConvoHint + 1; + }, setConversation: (state, action) => { return { ...state, ...action.payload }; }, @@ -44,10 +48,7 @@ const currentSlice = createSlice({ state.pageNumber = 1; }, setConvos: (state, action) => { - const newConvos = action.payload.filter((convo) => { - return !state.convos.some((c) => c.conversationId === convo.conversationId); - }); - state.convos = [...state.convos, ...newConvos].sort( + state.convos = action.payload.sort( (a, b) => new Date(b.createdAt) - new Date(a.createdAt) ); }, @@ -60,7 +61,7 @@ const currentSlice = createSlice({ } }); -export const { setConversation, setConvos, setNewConvo, setError, incrementPage, removeConvo, removeAll } = +export const { refreshConversation, setConversation, setConvos, setNewConvo, setError, incrementPage, removeConvo, removeAll } = currentSlice.actions; export default currentSlice.reducer;