diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 5d55d33fd..ea1f9c0da 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -52,15 +52,23 @@ class BaseClient { if (opts && typeof opts === 'object') { this.setOptions(opts); } + + const { isEdited, isContinued } = opts; const user = opts.user ?? null; + const saveOptions = this.getSaveOptions(); + this.abortController = opts.abortController ?? new AbortController(); const conversationId = opts.conversationId ?? crypto.randomUUID(); const parentMessageId = opts.parentMessageId ?? '00000000-0000-0000-0000-000000000000'; const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID(); - const responseMessageId = opts.responseMessageId ?? crypto.randomUUID(); - const saveOptions = this.getSaveOptions(); - const head = opts.isEdited ? responseMessageId : parentMessageId; + let responseMessageId = opts.responseMessageId ?? crypto.randomUUID(); + let head = isEdited ? responseMessageId : parentMessageId; this.currentMessages = (await this.loadHistory(conversationId, head)) ?? []; - this.abortController = opts.abortController ?? new AbortController(); + + if (isEdited && !isContinued) { + responseMessageId = crypto.randomUUID(); + head = responseMessageId; + this.currentMessages[this.currentMessages.length - 1].messageId = head; + } return { ...opts, @@ -397,11 +405,16 @@ class BaseClient { const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } = await this.handleStartMethods(message, opts); + const { generation = '' } = opts; + this.user = user; // It's not necessary to push to currentMessages // depending on subclass implementation of handling messages // When this is an edit, all messages are already in currentMessages, both user and response - if (!isEdited) { + if (isEdited) { + /* TODO: edge case where latest message doesn't exist */ + this.currentMessages[this.currentMessages.length - 1].text = generation; + } else { this.currentMessages.push(userMessage); } @@ -419,7 +432,7 @@ class BaseClient { if (this.options.debug) { console.debug('payload'); - // console.debug(payload); + console.debug(payload); } if (tokenCountMap) { @@ -442,7 +455,6 @@ class BaseClient { await this.saveMessageToDatabase(userMessage, saveOptions, user); } - const generation = isEdited ? this.currentMessages[this.currentMessages.length - 1].text : ''; const responseMessage = { messageId: responseMessageId, conversationId, diff --git a/api/app/titleConvo.js b/api/app/titleConvo.js index ebdde7e5c..65ef44d28 100644 --- a/api/app/titleConvo.js +++ b/api/app/titleConvo.js @@ -1,4 +1,4 @@ -const _ = require('lodash'); +const throttle = require('lodash/throttle'); const { genAzureChatCompletion, getAzureCredentials } = require('../utils/'); const titleConvo = async ({ text, response, openAIApiKey, azure = false }) => { @@ -52,6 +52,6 @@ const titleConvo = async ({ text, response, openAIApiKey, azure = false }) => { return title; }; -const throttledTitleConvo = _.throttle(titleConvo, 1000); +const throttledTitleConvo = throttle(titleConvo, 1000); module.exports = throttledTitleConvo; diff --git a/api/app/titleConvoBing.js b/api/app/titleConvoBing.js index 8454517d8..cb75bd859 100644 --- a/api/app/titleConvoBing.js +++ b/api/app/titleConvoBing.js @@ -1,4 +1,4 @@ -const _ = require('lodash'); +const throttle = require('lodash/throttle'); const titleConvo = async ({ text, response }) => { let title = 'New Chat'; @@ -32,6 +32,6 @@ const titleConvo = async ({ text, response }) => { return title; }; -const throttledTitleConvo = _.throttle(titleConvo, 3000); +const throttledTitleConvo = throttle(titleConvo, 3000); module.exports = throttledTitleConvo; diff --git a/api/package.json b/api/package.json index bd2c76b97..514f53d25 100644 --- a/api/package.json +++ b/api/package.json @@ -22,7 +22,7 @@ "dependencies": { "@anthropic-ai/sdk": "^0.5.4", "@azure/search-documents": "^11.3.2", - "@dqbd/tiktoken": "^1.0.2", + "@dqbd/tiktoken": "^1.0.7", "@fortaine/fetch-event-source": "^3.0.6", "@keyv/mongo": "^2.1.8", "@waylaidwanderer/chatgpt-api": "^1.37.2", diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 648975efd..a3d355055 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -79,6 +79,7 @@ const handleAbortError = async (res, req, error, data) => { cancelled: false, error: true, text: error.message, + isCreatedByUser: false, }; if (abortControllers.has(conversationId)) { const { abortController } = abortControllers.get(conversationId); @@ -89,10 +90,11 @@ const handleAbortError = async (res, req, error, data) => { handleError(res, errorMessage); }; - if (partialText?.length > 2) { + if (partialText && partialText.length > 5) { try { return await abortMessage(req, res); } catch (err) { + console.error(err); return respondWithError(); } } else { diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 142626008..a5c7f3c96 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -3,6 +3,7 @@ const setHeaders = require('./setHeaders'); const requireJwtAuth = require('./requireJwtAuth'); const requireLocalAuth = require('./requireLocalAuth'); const validateEndpoint = require('./validateEndpoint'); +const validateMessageReq = require('./validateMessageReq'); const buildEndpointOption = require('./buildEndpointOption'); const validateRegistration = require('./validateRegistration'); @@ -12,6 +13,7 @@ module.exports = { requireJwtAuth, requireLocalAuth, validateEndpoint, + validateMessageReq, buildEndpointOption, validateRegistration, }; diff --git a/api/server/middleware/validateMessageReq.js b/api/server/middleware/validateMessageReq.js new file mode 100644 index 000000000..7492c8fd4 --- /dev/null +++ b/api/server/middleware/validateMessageReq.js @@ -0,0 +1,28 @@ +const { getConvo } = require('../../models'); + +// Middleware to validate conversationId and user relationship +const validateMessageReq = async (req, res, next) => { + let conversationId = req.params.conversationId || req.body.conversationId; + + if (conversationId === 'new') { + return res.status(200).send([]); + } + + if (!conversationId && req.body.message) { + conversationId = req.body.message.conversationId; + } + + const conversation = await getConvo(req.user.id, conversationId); + + if (!conversation) { + return res.status(404).json({ error: 'Conversation not found' }); + } + + if (conversation.user !== req.user.id) { + return res.status(403).json({ error: 'User not authorized for this conversation' }); + } + + next(); +}; + +module.exports = validateMessageReq; diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 7e141cbdf..10a291ea7 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -29,11 +29,12 @@ router.post( endpointOption, conversationId, responseMessageId, + isContinued = false, parentMessageId = null, overrideParentMessageId = null, } = req.body; console.log('edit log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); let metadata; let userMessage; let lastSavedTimestamp = 0; @@ -41,7 +42,10 @@ router.post( const userMessageId = parentMessageId; const addMetadata = (data) => (metadata = data); - const getIds = (data) => (userMessage = data.userMessage); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; const { onProgress: progressCallback, getPartialText } = createOnProgress({ generation, @@ -87,6 +91,8 @@ router.post( let response = await client.sendMessage(text, { user: req.user.id, + generation, + isContinued, isEdited: true, conversationId, parentMessageId, diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index a0c81d46c..89886ea82 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -30,11 +30,12 @@ router.post( endpointOption, conversationId, responseMessageId, + isContinued = false, parentMessageId = null, overrideParentMessageId = null, } = req.body; console.log('edit log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); let metadata; let userMessage; let lastSavedTimestamp = 0; @@ -50,7 +51,10 @@ router.post( }; const addMetadata = (data) => (metadata = data); - const getIds = (data) => (userMessage = data.userMessage); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; const { onProgress: progressCallback, @@ -128,6 +132,8 @@ router.post( let response = await client.sendMessage(text, { user, + generation, + isContinued, isEdited: true, conversationId, parentMessageId, diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index 1ad0bcbeb..1f15d25d0 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -29,11 +29,12 @@ router.post( endpointOption, conversationId, responseMessageId, + isContinued = false, parentMessageId = null, overrideParentMessageId = null, } = req.body; console.log('edit log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); let metadata; let userMessage; let lastSavedTimestamp = 0; @@ -41,7 +42,10 @@ router.post( const userMessageId = parentMessageId; const addMetadata = (data) => (metadata = data); - const getIds = (data) => (userMessage = data.userMessage); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; const { onProgress: progressCallback, getPartialText } = createOnProgress({ generation, @@ -90,6 +94,8 @@ router.post( let response = await client.sendMessage(text, { user: req.user.id, + generation, + isContinued, isEdited: true, conversationId, parentMessageId, diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 0530ebc26..7dd72fad1 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,11 +1,50 @@ const express = require('express'); const router = express.Router(); -const { getMessages } = require('../../models/Message'); -const requireJwtAuth = require('../middleware/requireJwtAuth'); +const { + getMessages, + updateMessage, + saveConvo, + saveMessage, + deleteMessages, +} = require('../../models'); +const { requireJwtAuth, validateMessageReq } = require('../middleware/'); -router.get('/:conversationId', requireJwtAuth, async (req, res) => { +router.get('/:conversationId', requireJwtAuth, validateMessageReq, async (req, res) => { const { conversationId } = req.params; res.status(200).send(await getMessages({ conversationId })); }); +// CREATE +router.post('/:conversationId', requireJwtAuth, validateMessageReq, async (req, res) => { + const message = req.body; + const savedMessage = await saveMessage(message); + await saveConvo(req.user.id, savedMessage); + res.status(201).send(savedMessage); +}); + +// READ +router.get('/:conversationId/:messageId', requireJwtAuth, validateMessageReq, async (req, res) => { + const { conversationId, messageId } = req.params; + res.status(200).send(await getMessages({ conversationId, messageId })); +}); + +// UPDATE +router.put('/:conversationId/:messageId', requireJwtAuth, validateMessageReq, async (req, res) => { + const { messageId } = req.params; + const { text } = req.body; + res.status(201).send(await updateMessage({ messageId, text })); +}); + +// DELETE +router.delete( + '/:conversationId/:messageId', + requireJwtAuth, + validateMessageReq, + async (req, res) => { + const { messageId } = req.params; + await deleteMessages({ messageId }); + res.status(204).send(); + }, +); + module.exports = router; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index b5efa08d8..d29a56cfe 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,4 +1,4 @@ -const _ = require('lodash'); +const partialRight = require('lodash/partialRight'); const citationRegex = /\[\^\d+?\^]/g; const { getCitations, citeText } = require('./citations'); const cursor = ''; @@ -73,7 +73,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { }; const onProgress = (opts) => { - return _.partialRight(progressCallback, opts); + return partialRight(progressCallback, opts); }; const getPartialText = () => { diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 950778235..d9a6a2257 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -1,4 +1,4 @@ -import { TConversation, TPreset } from 'librechat-data-provider'; +import { TConversation, TMessage, TPreset } from 'librechat-data-provider'; export type TSetOption = (param: number | string) => (newValue: number | string | boolean) => void; export type TSetExample = ( @@ -62,3 +62,34 @@ export type TOnClick = (e: React.MouseEvent) => void; export type TGenButtonProps = { onClick: TOnClick; }; + +export type TAskProps = { + text: string; + parentMessageId?: string | null; + conversationId?: string | null; + messageId?: string | null; +}; + +export type TOptions = { + editedMessageId?: string | null; + editedText?: string | null; + isRegenerate?: boolean; + isContinued?: boolean; + isEdited?: boolean; +}; + +export type TAskFunction = (props: TAskProps, options?: TOptions) => void; + +export type TMessageProps = { + conversation?: TConversation | null; + messageId?: string | null; + message?: TMessage; + messagesTree?: TMessage[]; + currentEditId: string | number | null; + isSearchView?: boolean; + siblingIdx?: number; + siblingCount?: number; + scrollToBottom?: () => void; + setCurrentEditId?: React.Dispatch> | null; + setSiblingIdx?: ((value: number) => void | React.Dispatch>) | null; +}; diff --git a/client/src/components/Input/Generations/GenerationButtons.tsx b/client/src/components/Input/Generations/GenerationButtons.tsx index 7c51f59e0..ad9ef15c4 100644 --- a/client/src/components/Input/Generations/GenerationButtons.tsx +++ b/client/src/components/Input/Generations/GenerationButtons.tsx @@ -1,3 +1,4 @@ +import { useEffect, useState } from 'react'; import type { TMessage } from 'librechat-data-provider'; import { useMessageHandler, useMediaQuery, useGenerations } from '~/hooks'; import { cn } from '~/utils'; @@ -31,6 +32,27 @@ export default function GenerationButtons({ isSubmitting, }); + const [userStopped, setUserStopped] = useState(false); + + const handleStop = (e: React.MouseEvent) => { + setUserStopped(true); + handleStopGenerating(e); + }; + + useEffect(() => { + let timer: NodeJS.Timeout; + + if (userStopped) { + timer = setTimeout(() => { + setUserStopped(false); + }, 200); + } + + return () => { + clearTimeout(timer); + }; + }, [userStopped]); + if (isSmallScreen) { return null; } @@ -38,8 +60,8 @@ export default function GenerationButtons({ let button: React.ReactNode = null; if (isSubmitting) { - button = ; - } else if (continueSupported) { + button = ; + } else if (userStopped || continueSupported) { button = ; } else if (messages && messages.length > 0 && regenerateEnabled) { button = ; diff --git a/client/src/components/Messages/Content/CodeBlock.tsx b/client/src/components/Messages/Content/CodeBlock.tsx index c61fc65a2..9329662e8 100644 --- a/client/src/components/Messages/Content/CodeBlock.tsx +++ b/client/src/components/Messages/Content/CodeBlock.tsx @@ -51,7 +51,7 @@ const CodeBar: React.FC = React.memo(({ lang, codeRef, plugin = nu interface CodeBlockProps { lang: string; - codeChildren: string; + codeChildren: React.ReactNode; classProp?: string; plugin?: boolean; } diff --git a/client/src/components/Messages/Content/Content.jsx b/client/src/components/Messages/Content/Content.tsx similarity index 76% rename from client/src/components/Messages/Content/Content.jsx rename to client/src/components/Messages/Content/Content.tsx index cbf9beccc..158dfd293 100644 --- a/client/src/components/Messages/Content/Content.jsx +++ b/client/src/components/Messages/Content/Content.tsx @@ -1,6 +1,8 @@ import React, { useState, useEffect } from 'react'; +import type { TMessage } from 'librechat-data-provider'; import { useRecoilValue } from 'recoil'; import ReactMarkdown from 'react-markdown'; +import type { PluggableList } from 'unified'; import rehypeKatex from 'rehype-katex'; import rehypeHighlight from 'rehype-highlight'; import remarkMath from 'remark-math'; @@ -11,8 +13,18 @@ import CodeBlock from './CodeBlock'; import store from '~/store'; import { langSubset } from '~/utils'; -const code = React.memo((props) => { - const { inline, className, children } = props; +type TCodeProps = { + inline: boolean; + className: string; + children: React.ReactNode; +}; + +type TContentProps = { + content: string; + message: TMessage; +}; + +const code = React.memo(({ inline, className, children }: TCodeProps) => { const match = /language-(\w+)/.exec(className || ''); const lang = match && match[1]; @@ -23,11 +35,11 @@ const code = React.memo((props) => { } }); -const p = React.memo((props) => { - return

{props?.children}

; +const p = React.memo(({ children }: { children: React.ReactNode }) => { + return

{children}

; }); -const Content = React.memo(({ content, message }) => { +const Content = React.memo(({ content, message }: TContentProps) => { const [cursor, setCursor] = useState('█'); const isSubmitting = useRecoilValue(store.isSubmitting); const latestMessage = useRecoilValue(store.latestMessage); @@ -57,7 +69,7 @@ const Content = React.memo(({ content, message }) => { }; }, [isSubmitting, isLatestMessage]); - let rehypePlugins = [ + const rehypePlugins: PluggableList = [ [rehypeKatex, { output: 'mathml' }], [ rehypeHighlight, @@ -79,10 +91,14 @@ const Content = React.memo(({ content, message }) => { remarkPlugins={[supersub, remarkGfm, [remarkMath, { singleDollarTextMath: true }]]} rehypePlugins={rehypePlugins} linkTarget="_new" - components={{ - code, - p, - }} + components={ + { + code, + p, + } as { + [nodeType: string]: React.ElementType; + } + } > {isLatestMessage && isSubmitting && !isInitializing ? currentContent + cursor diff --git a/client/src/components/Messages/Content/MessageContent.tsx b/client/src/components/Messages/Content/MessageContent.tsx new file mode 100644 index 000000000..e8910523d --- /dev/null +++ b/client/src/components/Messages/Content/MessageContent.tsx @@ -0,0 +1,194 @@ +import { useRef } from 'react'; +import { useRecoilState } from 'recoil'; +import { useUpdateMessageMutation } from 'librechat-data-provider'; +import type { TMessage } from 'librechat-data-provider'; +import type { TAskFunction } from '~/common'; +import { cn, getError } from '~/utils'; +import store from '~/store'; +import Content from './Content'; + +type TInitialProps = { + text: string; + edit: boolean; + error: boolean; + unfinished: boolean; + isSubmitting: boolean; +}; +type TAdditionalProps = { + ask: TAskFunction; + message: TMessage; + isCreatedByUser: boolean; + siblingIdx: number; + enterEdit: (cancel: boolean) => void; + setSiblingIdx: (value: number) => void; +}; + +type TMessageContent = TInitialProps & TAdditionalProps; + +type TText = Pick; +type TEditProps = Pick & + Omit; +type TDisplayProps = TText & Pick; + +// Container Component +const Container = ({ children }: { children: React.ReactNode }) => ( +
{children}
+); + +// Error Message Component +const ErrorMessage = ({ text }: TText) => ( + +
+ {getError(text)} +
+
+); + +// Edit Message Component +const EditMessage = ({ + text, + message, + isSubmitting, + ask, + enterEdit, + siblingIdx, + setSiblingIdx, +}: TEditProps) => { + const [messages, setMessages] = useRecoilState(store.messages); + const textEditor = useRef(null); + const { conversationId, parentMessageId, messageId } = message; + const updateMessageMutation = useUpdateMessageMutation(conversationId ?? ''); + + const resubmitMessage = () => { + const text = textEditor?.current?.innerText ?? ''; + console.log('siblingIdx:', siblingIdx); + if (message.isCreatedByUser) { + ask({ + text, + parentMessageId, + conversationId, + }); + + setSiblingIdx((siblingIdx ?? 0) - 1); + } else { + const parentMessage = messages?.find((msg) => msg.messageId === parentMessageId); + + if (!parentMessage) { + return; + } + ask( + { ...parentMessage }, + { + editedText: text, + editedMessageId: messageId, + isRegenerate: true, + isEdited: true, + }, + ); + + setSiblingIdx((siblingIdx ?? 0) - 1); + } + + enterEdit(true); + }; + + const updateMessage = () => { + if (!messages) { + return; + } + const text = textEditor?.current?.innerText ?? ''; + updateMessageMutation.mutate({ + conversationId: conversationId ?? '', + messageId, + text, + }); + setMessages(() => + messages.map((msg) => + msg.messageId === messageId + ? { + ...msg, + text, + } + : msg, + ), + ); + enterEdit(true); + }; + + return ( + +
+ {text} +
+
+ + + +
+
+ ); +}; + +// Display Message Component +const DisplayMessage = ({ text, isCreatedByUser, message }: TDisplayProps) => ( + +
+ {!isCreatedByUser ? : <>{text}} +
+
+); + +// Unfinished Message Component +const UnfinishedMessage = () => ( + +); + +// Content Component +const MessageContent = ({ + text, + edit, + error, + unfinished, + isSubmitting, + ...props +}: TMessageContent) => { + if (error) { + return ; + } else if (edit) { + return ; + } else { + return ( + <> + + {!isSubmitting && unfinished && } + + ); + } +}; + +export default MessageContent; diff --git a/client/src/components/Messages/Plugin.tsx b/client/src/components/Messages/Content/Plugin.tsx similarity index 89% rename from client/src/components/Messages/Plugin.tsx rename to client/src/components/Messages/Content/Plugin.tsx index cd1bae1c2..afe6b2236 100644 --- a/client/src/components/Messages/Plugin.tsx +++ b/client/src/components/Messages/Content/Plugin.tsx @@ -1,28 +1,13 @@ -import React, { useState, useCallback, memo, ReactNode } from 'react'; -import { Spinner } from '~/components'; -import { useRecoilValue } from 'recoil'; -import CodeBlock from './Content/CodeBlock.jsx'; -import { Disclosure } from '@headlessui/react'; +import { useState, useCallback, memo, ReactNode } from 'react'; +import type { TResPlugin, TInput } from 'librechat-data-provider'; import { ChevronDownIcon, LucideProps } from 'lucide-react'; +import { Disclosure } from '@headlessui/react'; +import { useRecoilValue } from 'recoil'; +import { Spinner } from '~/components'; +import CodeBlock from './CodeBlock'; import { cn } from '~/utils/'; import store from '~/store'; -interface Input { - inputStr: string; -} - -interface PluginProps { - plugin: { - plugin: string; - input: string; - thought: string; - loading?: boolean; - outputs?: string; - latest?: string; - inputs?: Input[]; - }; -} - type PluginsMap = { [pluginKey: string]: string; }; @@ -31,7 +16,7 @@ type PluginIconProps = LucideProps & { className?: string; }; -function formatInputs(inputs: Input[]) { +function formatInputs(inputs: TInput[]) { let output = ''; for (let i = 0; i < inputs.length; i++) { @@ -45,6 +30,10 @@ function formatInputs(inputs: Input[]) { return output; } +type PluginProps = { + plugin: TResPlugin; +}; + const Plugin: React.FC = ({ plugin }) => { const [loading, setLoading] = useState(plugin.loading); const finished = plugin.outputs && plugin.outputs.length > 0; diff --git a/client/src/components/Messages/Content/SubRow.jsx b/client/src/components/Messages/Content/SubRow.tsx similarity index 66% rename from client/src/components/Messages/Content/SubRow.jsx rename to client/src/components/Messages/Content/SubRow.tsx index be1e9d72e..9041cb50c 100644 --- a/client/src/components/Messages/Content/SubRow.jsx +++ b/client/src/components/Messages/Content/SubRow.tsx @@ -1,6 +1,11 @@ -import React from 'react'; +type TSubRowProps = { + children: React.ReactNode; + classes?: string; + subclasses?: string; + onClick?: () => void; +}; -export default function SubRow({ children, classes = '', subclasses = '', onClick }) { +export default function SubRow({ children, classes = '', subclasses = '', onClick }: TSubRowProps) { return (
void; - copyToClipboard: (setIsCopied: (isCopied: boolean) => void) => void; - conversation: TConversation; + enterEdit: (cancel?: boolean) => void; + copyToClipboard: (setIsCopied: React.Dispatch>) => void; + conversation: TConversation | null; isSubmitting: boolean; message: TMessage; regenerate: () => void; @@ -25,33 +25,47 @@ export default function HoverButtons({ regenerate, handleContinue, }: THoverButtons) { - const { endpoint } = conversation; + const { endpoint } = conversation ?? {}; const [isCopied, setIsCopied] = useState(false); - const { editEnabled, regenerateEnabled, continueSupported } = useGenerations({ + const { hideEditButton, regenerateEnabled, continueSupported } = useGenerations({ isEditing, isSubmitting, message, endpoint: endpoint ?? '', }); + if (!conversation) { + return null; + } + + const { isCreatedByUser } = message; + + const onEdit = () => { + if (isEditing) { + return enterEdit(true); + } + enterEdit(); + }; return (
- -
-
- ) : ( - <> -
- {/*
*/} -
- {!isCreatedByUser ? ( - <> - - - ) : ( - <>{text} - )} -
-
- {/* {!isSubmitting && cancelled ? ( -
-
- {`This is a cancelled message.`} -
-
- ) : null} */} - {!isSubmitting && unfinished ? ( -
-
- { - 'This is an unfinished message. The AI may still be generating a response, it was aborted, or a censor was triggered. Refresh or visit later to see more updates.' - } -
-
- ) : null} - - )} -
- enterEdit()} - regenerate={() => regenerateMessage()} - handleContinue={handleContinue} - copyToClipboard={copyToClipboard} - /> - - - -
- - - - - ); -} diff --git a/client/src/components/Messages/Message.tsx b/client/src/components/Messages/Message.tsx new file mode 100644 index 000000000..963ee32e8 --- /dev/null +++ b/client/src/components/Messages/Message.tsx @@ -0,0 +1,212 @@ +/* eslint-disable react-hooks/exhaustive-deps */ +import { useGetConversationByIdQuery } from 'librechat-data-provider'; +import { useState, useEffect } from 'react'; +import { useSetRecoilState } from 'recoil'; +import copy from 'copy-to-clipboard'; +import { Plugin, SubRow, MessageContent } from './Content'; +// eslint-disable-next-line import/no-cycle +import MultiMessage from './MultiMessage'; +import HoverButtons from './HoverButtons'; +import SiblingSwitch from './SiblingSwitch'; +import { getIcon } from '~/components/Endpoints'; +import { useMessageHandler } from '~/hooks'; +import type { TMessageProps } from '~/common'; +import store from '~/store'; + +export default function Message({ + conversation, + message, + scrollToBottom, + currentEditId, + setCurrentEditId, + siblingIdx, + siblingCount, + setSiblingIdx, +}: TMessageProps) { + const setLatestMessage = useSetRecoilState(store.latestMessage); + const [abortScroll, setAbort] = useState(false); + const { isSubmitting, ask, regenerate, handleContinue } = useMessageHandler(); + const { switchToConversation } = store.useConversation(); + const { + text, + children, + messageId = null, + searchResult, + isCreatedByUser, + error, + unfinished, + } = message ?? {}; + const last = !children?.length; + const edit = messageId == currentEditId; + const getConversationQuery = useGetConversationByIdQuery(message?.conversationId ?? '', { + enabled: false, + }); + const blinker = message?.submitting && isSubmitting; + + // debugging + // useEffect(() => { + // console.log('isSubmitting:', isSubmitting); + // console.log('unfinished:', unfinished); + // }, [isSubmitting, unfinished]); + + useEffect(() => { + if (blinker && scrollToBottom && !abortScroll) { + scrollToBottom(); + } + }, [isSubmitting, blinker, text, scrollToBottom]); + + useEffect(() => { + if (!message) { + return; + } else if (last) { + setLatestMessage({ ...message }); + } + }, [last, message]); + + if (!message) { + return null; + } + + const enterEdit = (cancel?: boolean) => + setCurrentEditId && setCurrentEditId(cancel ? -1 : messageId); + + const handleWheel = () => { + if (blinker) { + setAbort(true); + } else { + setAbort(false); + } + }; + + const props = { + className: + 'w-full border-b border-black/10 dark:border-gray-900/50 text-gray-800 bg-white dark:text-gray-100 group dark:bg-gray-800', + titleclass: '', + }; + + const icon = getIcon({ + ...conversation, + ...message, + model: message?.model ?? conversation?.model, + }); + + if (!isCreatedByUser) { + props.className = + 'w-full border-b border-black/10 bg-gray-50 dark:border-gray-900/50 text-gray-800 dark:text-gray-100 group bg-gray-100 dark:bg-gray-1000'; + } + + if (message?.bg && searchResult) { + props.className = message?.bg?.split('hover')[0]; + props.titleclass = message?.bg?.split(props.className)[1] + ' cursor-pointer'; + } + + const regenerateMessage = () => { + if (!isSubmitting && !isCreatedByUser) { + regenerate(message); + } + }; + + const copyToClipboard = (setIsCopied: React.Dispatch>) => { + setIsCopied(true); + copy(text ?? ''); + + setTimeout(() => { + setIsCopied(false); + }, 3000); + }; + + const clickSearchResult = async () => { + if (!searchResult) { + return; + } + if (!message) { + return; + } + getConversationQuery.refetch({ queryKey: [message?.conversationId] }).then((response) => { + console.log('getConversationQuery response.data:', response.data); + if (response.data) { + switchToConversation(response.data); + } + }); + }; + + return ( + <> +
+
+
+ {typeof icon === 'string' && /[^\\x00-\\x7F]+/.test(icon as string) ? ( + {icon} + ) : ( + icon + )} +
+ +
+
+
+ {searchResult && ( + + {`${message?.title} | ${message?.sender}`} + + )} +
+ {message?.plugin && } + { + return; + }) + } + /> +
+ regenerateMessage()} + handleContinue={handleContinue} + copyToClipboard={copyToClipboard} + /> + + + +
+
+
+ + + ); +} diff --git a/client/src/components/Messages/MessageHeader.jsx b/client/src/components/Messages/MessageHeader.tsx similarity index 94% rename from client/src/components/Messages/MessageHeader.jsx rename to client/src/components/Messages/MessageHeader.tsx index 37dd94884..34099afd0 100644 --- a/client/src/components/Messages/MessageHeader.jsx +++ b/client/src/components/Messages/MessageHeader.tsx @@ -1,5 +1,6 @@ import { useState } from 'react'; import { useRecoilValue } from 'recoil'; +import type { TPreset } from 'librechat-data-provider'; import { Plugin } from '~/components/svg'; import EndpointOptionsDialog from '../Endpoints/EndpointOptionsDialog'; import { cn, alternateName } from '~/utils/'; @@ -10,7 +11,17 @@ const MessageHeader = ({ isSearchView = false }) => { const [saveAsDialogShow, setSaveAsDialogShow] = useState(false); const conversation = useRecoilValue(store.conversation); const searchQuery = useRecoilValue(store.searchQuery); + + if (!conversation) { + return null; + } + const { endpoint, model } = conversation; + + if (!endpoint) { + return null; + } + const isNotClickable = endpoint === 'chatGPTBrowser'; const plugins = ( @@ -89,7 +100,7 @@ const MessageHeader = ({ isSearchView = false }) => { ); diff --git a/client/src/components/Messages/index.jsx b/client/src/components/Messages/Messages.tsx similarity index 75% rename from client/src/components/Messages/index.jsx rename to client/src/components/Messages/Messages.tsx index 667e67e81..bf454ea47 100644 --- a/client/src/components/Messages/index.jsx +++ b/client/src/components/Messages/Messages.tsx @@ -1,20 +1,20 @@ -import React, { useEffect, useState, useRef, useCallback } from 'react'; -import { useRecoilValue } from 'recoil'; -import { Spinner } from '~/components'; -import throttle from 'lodash/throttle'; +import { useEffect, useState, useRef } from 'react'; import { CSSTransition } from 'react-transition-group'; +import { useRecoilValue } from 'recoil'; + import ScrollToBottom from './ScrollToBottom'; -import MultiMessage from './MultiMessage'; import MessageHeader from './MessageHeader'; -import { useScreenshot } from '~/hooks'; +import MultiMessage from './MultiMessage'; +import { Spinner } from '~/components'; +import { useScreenshot, useScrollToRef } from '~/hooks'; import store from '~/store'; export default function Messages({ isSearchView = false }) { - const [currentEditId, setCurrentEditId] = useState(-1); + const [currentEditId, setCurrentEditId] = useState(-1); const [showScrollButton, setShowScrollButton] = useState(false); - const scrollableRef = useRef(null); - const messagesEndRef = useRef(null); + const scrollableRef = useRef(null); + const messagesEndRef = useRef(null); const messagesTree = useRecoilValue(store.messagesTree); const showPopover = useRecoilValue(store.showPopover); @@ -22,8 +22,8 @@ export default function Messages({ isSearchView = false }) { const _messagesTree = isSearchView ? searchResultMessagesTree : messagesTree; - const conversation = useRecoilValue(store.conversation) || {}; - const { conversationId } = conversation; + const conversation = useRecoilValue(store.conversation); + const { conversationId } = conversation ?? {}; const { screenshotTargetRef } = useScreenshot(); @@ -62,42 +62,15 @@ export default function Messages({ isSearchView = false }) { }; }, [_messagesTree]); - // eslint-disable-next-line react-hooks/exhaustive-deps - const scrollToBottom = useCallback( - throttle( - () => { - messagesEndRef.current?.scrollIntoView({ behavior: 'instant' }); - setShowScrollButton(false); - }, - 450, - { leading: true }, - ), - [messagesEndRef], - ); - - // eslint-disable-next-line react-hooks/exhaustive-deps - const scrollToBottomSmooth = useCallback( - throttle( - () => { - messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); - setShowScrollButton(false); - }, - 750, - { leading: true }, - ), - [messagesEndRef], - ); - - let timeoutId = null; + let timeoutId: ReturnType | undefined; const debouncedHandleScroll = () => { clearTimeout(timeoutId); timeoutId = setTimeout(handleScroll, 100); }; - const scrollHandler = (e) => { - e.preventDefault(); - scrollToBottomSmooth(); - }; + const { scrollToRef: scrollToBottom, handleSmoothToRef } = useScrollToRef(messagesEndRef, () => + setShowScrollButton(false), + ); return (
@@ -137,7 +110,7 @@ export default function Messages({ isSearchView = false }) { > {() => showScrollButton && - !showPopover && + !showPopover && } diff --git a/client/src/components/Messages/MultiMessage.jsx b/client/src/components/Messages/MultiMessage.tsx similarity index 82% rename from client/src/components/Messages/MultiMessage.jsx rename to client/src/components/Messages/MultiMessage.tsx index ce49f56f5..08a17c33f 100644 --- a/client/src/components/Messages/MultiMessage.jsx +++ b/client/src/components/Messages/MultiMessage.tsx @@ -1,5 +1,7 @@ import { useEffect } from 'react'; import { useRecoilState } from 'recoil'; +import type { TMessageProps } from '~/common'; +// eslint-disable-next-line import/no-cycle import Message from './Message'; import store from '~/store'; @@ -11,23 +13,21 @@ export default function MultiMessage({ currentEditId, setCurrentEditId, isSearchView, -}) { - // const [siblingIdx, setSiblingIdx] = useState(0); - +}: TMessageProps) { const [siblingIdx, setSiblingIdx] = useRecoilState(store.messagesSiblingIdxFamily(messageId)); - const setSiblingIdxRev = (value) => { - setSiblingIdx(messagesTree?.length - value - 1); + const setSiblingIdxRev = (value: number) => { + setSiblingIdx((messagesTree?.length ?? 0) - value - 1); }; useEffect(() => { - // reset siblingIdx when changes, mostly a new message is submitting. + // reset siblingIdx when the tree changes, mostly when a new message is submitting. setSiblingIdx(0); // eslint-disable-next-line react-hooks/exhaustive-deps }, [messagesTree?.length]); // if (!messageList?.length) return null; - if (!(messagesTree && messagesTree.length)) { + if (!(messagesTree && messagesTree?.length)) { return null; } diff --git a/client/src/components/Messages/SiblingSwitch.jsx b/client/src/components/Messages/SiblingSwitch.tsx similarity index 69% rename from client/src/components/Messages/SiblingSwitch.jsx rename to client/src/components/Messages/SiblingSwitch.tsx index e04b6c31a..0f55076ef 100644 --- a/client/src/components/Messages/SiblingSwitch.jsx +++ b/client/src/components/Messages/SiblingSwitch.tsx @@ -1,13 +1,26 @@ -import React from 'react'; +import type { TMessageProps } from '~/common'; + +type TSiblingSwitchProps = Pick; + +export default function SiblingSwitch({ + siblingIdx, + siblingCount, + setSiblingIdx, +}: TSiblingSwitchProps) { + if (siblingIdx === undefined) { + return null; + } else if (siblingCount === undefined) { + return null; + } -export default function SiblingSwitch({ siblingIdx, siblingCount, setSiblingIdx }) { const previous = () => { - setSiblingIdx(siblingIdx - 1); + setSiblingIdx && setSiblingIdx(siblingIdx - 1); }; const next = () => { - setSiblingIdx(siblingIdx + 1); + setSiblingIdx && setSiblingIdx(siblingIdx + 1); }; + return siblingCount > 1 ? ( <> @@ -50,7 +63,7 @@ export default function SiblingSwitch({ siblingIdx, siblingCount, setSiblingIdx width="1em" xmlns="http://www.w3.org/2000/svg" > - + diff --git a/client/src/components/svg/Plugin.tsx b/client/src/components/svg/Plugin.tsx index 05c53d1a0..4d6c25ffa 100644 --- a/client/src/components/svg/Plugin.tsx +++ b/client/src/components/svg/Plugin.tsx @@ -1,6 +1,6 @@ import { cn } from '~/utils/'; -export default function Plugin({ className, ...props }) { +export default function Plugin({ className = '', ...props }) { return ( ; + ref?: RefObject; }; const ScreenshotContext = createContext({}); diff --git a/client/src/hooks/index.ts b/client/src/hooks/index.ts index 552208df0..55f19eebe 100644 --- a/client/src/hooks/index.ts +++ b/client/src/hooks/index.ts @@ -7,5 +7,6 @@ export { default as useLocalize } from './useLocalize'; export { default as useMediaQuery } from './useMediaQuery'; export { default as useSetOptions } from './useSetOptions'; export { default as useGenerations } from './useGenerations'; +export { default as useScrollToRef } from './useScrollToRef'; export { default as useServerStream } from './useServerStream'; export { default as useMessageHandler } from './useMessageHandler'; diff --git a/client/src/hooks/useGenerations.ts b/client/src/hooks/useGenerations.ts index 8040ec6b8..549282d38 100644 --- a/client/src/hooks/useGenerations.ts +++ b/client/src/hooks/useGenerations.ts @@ -18,12 +18,17 @@ export default function useGenerations({ const latestMessage = useRecoilValue(store.latestMessage); const { error, messageId, searchResult, finish_reason, isCreatedByUser } = message ?? {}; + const isEditableEndpoint = !!['azureOpenAI', 'openAI', 'gptPlugins', 'anthropic'].find( + (e) => e === endpoint, + ); const continueSupported = latestMessage?.messageId === messageId && finish_reason && finish_reason !== 'stop' && - !!['azureOpenAI', 'openAI', 'gptPlugins', 'anthropic'].find((e) => e === endpoint); + !isEditing && + !searchResult && + isEditableEndpoint; const branchingSupported = // 5/21/23: Bing is allowing editing and Message regenerating @@ -37,19 +42,15 @@ export default function useGenerations({ 'anthropic', ].find((e) => e === endpoint); - const editEnabled = - !error && - isCreatedByUser && // TODO: allow AI editing - !searchResult && - !isEditing && - branchingSupported; - const regenerateEnabled = !isCreatedByUser && !searchResult && !isEditing && !isSubmitting && branchingSupported; + const hideEditButton = + error || searchResult || !branchingSupported || (!isEditableEndpoint && !isCreatedByUser); + return { continueSupported, - editEnabled, regenerateEnabled, + hideEditButton, }; } diff --git a/client/src/hooks/useMessageHandler.ts b/client/src/hooks/useMessageHandler.ts index 47c9a44ee..3b019a012 100644 --- a/client/src/hooks/useMessageHandler.ts +++ b/client/src/hooks/useMessageHandler.ts @@ -2,28 +2,31 @@ import { v4 } from 'uuid'; import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'; import { parseConvo, getResponseSender } from 'librechat-data-provider'; import type { TMessage, TSubmission } from 'librechat-data-provider'; +import type { TAskFunction } from '~/common'; import store from '~/store'; -type TAskProps = { - text: string; - parentMessageId?: string | null; - conversationId?: string | null; - messageId?: string | null; -}; - const useMessageHandler = () => { + const latestMessage = useRecoilValue(store.latestMessage); + const setSiblingIdx = useSetRecoilState( + store.messagesSiblingIdxFamily(latestMessage?.parentMessageId), + ); const currentConversation = useRecoilValue(store.conversation) || { endpoint: null }; const setSubmission = useSetRecoilState(store.submission); const isSubmitting = useRecoilValue(store.isSubmitting); const endpointsConfig = useRecoilValue(store.endpointsConfig); - const latestMessage = useRecoilValue(store.latestMessage); const [messages, setMessages] = useRecoilState(store.messages); const { endpoint } = currentConversation; const { getToken } = store.useToken(endpoint ?? ''); - const ask = ( - { text, parentMessageId = null, conversationId = null, messageId = null }: TAskProps, - { isRegenerate = false, isEdited = false } = {}, + const ask: TAskFunction = ( + { text, parentMessageId = null, conversationId = null, messageId = null }, + { + editedText = null, + editedMessageId = null, + isRegenerate = false, + isContinued = false, + isEdited = false, + } = {}, ) => { if (!!isSubmitting || text === '') { return; @@ -40,11 +43,12 @@ const useMessageHandler = () => { return; } - if (isEdited && !latestMessage) { - console.error('cannot edit AI message without latestMessage!'); + if (isContinued && !latestMessage) { + console.error('cannot continue AI message without latestMessage!'); return; } + const isEditOrContinue = isEdited || isContinued; const { userProvide } = endpointsConfig[endpoint] ?? {}; // set the endpoint option @@ -77,15 +81,17 @@ const useMessageHandler = () => { isCreatedByUser: true, parentMessageId, conversationId, - messageId: isEdited && messageId ? messageId : fakeMessageId, + messageId: isContinued && messageId ? messageId : fakeMessageId, error: false, }; // construct the placeholder response message - const generation = latestMessage?.text ?? ''; - const responseText = isEdited ? generation : ''; + const generation = editedText ?? latestMessage?.text ?? ''; + const responseText = isEditOrContinue + ? generation + : ''; - const responseMessageId = isEdited ? latestMessage?.messageId : null; + const responseMessageId = editedMessageId ?? latestMessage?.messageId ?? null; const initialResponse: TMessage = { sender: responseSender, text: responseText, @@ -98,6 +104,10 @@ const useMessageHandler = () => { error: false, }; + if (isContinued) { + currentMessages = currentMessages.filter((msg) => msg.messageId !== responseMessageId); + } + const submission: TSubmission = { conversation: { ...currentConversation, @@ -111,7 +121,8 @@ const useMessageHandler = () => { overrideParentMessageId: isRegenerate ? messageId : null, }, messages: currentMessages, - isEdited, + isEdited: isEditOrContinue, + isContinued, isRegenerate, initialResponse, }; @@ -119,12 +130,9 @@ const useMessageHandler = () => { console.log('User Input:', text, submission); if (isRegenerate) { - setMessages([ - ...(isEdited ? currentMessages.slice(0, -1) : currentMessages), - initialResponse, - ]); + setMessages([...submission.messages, initialResponse]); } else { - setMessages([...currentMessages, currentMsg, initialResponse]); + setMessages([...submission.messages, currentMsg, initialResponse]); } setSubmission(submission); }; @@ -152,7 +160,7 @@ const useMessageHandler = () => { ); if (parentMessage && parentMessage.isCreatedByUser) { - ask({ ...parentMessage }, { isRegenerate: true, isEdited: true }); + ask({ ...parentMessage }, { isContinued: true, isRegenerate: true, isEdited: true }); } else { console.error( 'Failed to regenerate the message: parentMessage not found, or not created by user.', @@ -182,6 +190,7 @@ const useMessageHandler = () => { const handleContinue = (e: React.MouseEvent) => { e.preventDefault(); continueGeneration(); + setSiblingIdx(0); }; return { diff --git a/client/src/hooks/useScrollToRef.ts b/client/src/hooks/useScrollToRef.ts new file mode 100644 index 000000000..ab70424de --- /dev/null +++ b/client/src/hooks/useScrollToRef.ts @@ -0,0 +1,40 @@ +import { RefObject, useCallback } from 'react'; +import throttle from 'lodash/throttle'; + +export default function useScrollToRef(targetRef: RefObject, callback: () => void) { + // eslint-disable-next-line react-hooks/exhaustive-deps + const scrollToRef = useCallback( + throttle( + () => { + targetRef.current?.scrollIntoView({ behavior: 'instant' }); + callback(); + }, + 450, + { leading: true }, + ), + [targetRef], + ); + + // eslint-disable-next-line react-hooks/exhaustive-deps + const scrollToRefSmooth = useCallback( + throttle( + () => { + targetRef.current?.scrollIntoView({ behavior: 'smooth' }); + callback(); + }, + 750, + { leading: true }, + ), + [targetRef], + ); + + const handleSmoothToRef: React.MouseEventHandler = (e) => { + e.preventDefault(); + scrollToRefSmooth(); + }; + + return { + scrollToRef, + handleSmoothToRef, + }; +} diff --git a/client/src/hooks/useServerStream.ts b/client/src/hooks/useServerStream.ts index f8ecbcf72..b7f92bb06 100644 --- a/client/src/hooks/useServerStream.ts +++ b/client/src/hooks/useServerStream.ts @@ -1,12 +1,12 @@ import { useEffect } from 'react'; import { useResetRecoilState, useSetRecoilState } from 'recoil'; import { SSE, createPayload, tMessageSchema, tConversationSchema } from 'librechat-data-provider'; -import type { TPlugin, TMessage, TConversation, TSubmission } from 'librechat-data-provider'; +import type { TResPlugin, TMessage, TConversation, TSubmission } from 'librechat-data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import store from '~/store'; type TResData = { - plugin: TPlugin; + plugin: TResPlugin; final?: boolean; initial?: boolean; requestMessage: TMessage; @@ -24,18 +24,11 @@ export default function useServerStream(submission: TSubmission | null) { const { refreshConversations } = store.useConversations(); const messageHandler = (data: string, submission: TSubmission) => { - const { - messages, - message, - plugin, - initialResponse, - isRegenerate = false, - isEdited = false, - } = submission; + const { messages, message, plugin, initialResponse, isRegenerate = false } = submission; if (isRegenerate) { setMessages([ - ...(isEdited ? messages.slice(0, -1) : messages), + ...messages, { ...initialResponse, text: data, @@ -65,11 +58,11 @@ export default function useServerStream(submission: TSubmission | null) { const cancelHandler = (data: TResData, submission: TSubmission) => { const { requestMessage, responseMessage, conversation } = data; - const { messages, isRegenerate = false, isEdited = false } = submission; + const { messages, isRegenerate = false } = submission; // update the messages if (isRegenerate) { - setMessages([...(isEdited ? messages.slice(0, -1) : messages), responseMessage]); + setMessages([...messages, responseMessage]); } else { setMessages([...messages, requestMessage, responseMessage]); } @@ -94,17 +87,11 @@ export default function useServerStream(submission: TSubmission | null) { }; const createdHandler = (data: TResData, submission: TSubmission) => { - const { - messages, - message, - initialResponse, - isRegenerate = false, - isEdited = false, - } = submission; + const { messages, message, initialResponse, isRegenerate = false } = submission; if (isRegenerate) { setMessages([ - ...(isEdited ? messages.slice(0, -1) : messages), + ...messages, { ...initialResponse, parentMessageId: message?.overrideParentMessageId ?? null, @@ -137,11 +124,11 @@ export default function useServerStream(submission: TSubmission | null) { const finalHandler = (data: TResData, submission: TSubmission) => { const { requestMessage, responseMessage, conversation } = data; - const { messages, isRegenerate = false, isEdited = false } = submission; + const { messages, isRegenerate = false } = submission; // update the messages if (isRegenerate) { - setMessages([...(isEdited ? messages.slice(0, -1) : messages), responseMessage]); + setMessages([...messages, responseMessage]); } else { setMessages([...messages, requestMessage, responseMessage]); } diff --git a/client/src/routes/Chat.tsx b/client/src/routes/Chat.tsx index 8453fbb17..cea4357ab 100644 --- a/client/src/routes/Chat.tsx +++ b/client/src/routes/Chat.tsx @@ -4,7 +4,7 @@ import { useNavigate, useParams } from 'react-router-dom'; import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'; import Landing from '~/components/ui/Landing'; -import Messages from '~/components/Messages'; +import Messages from '~/components/Messages/Messages'; import TextChat from '~/components/Input/TextChat'; import store from '~/store'; diff --git a/client/src/routes/Search.tsx b/client/src/routes/Search.tsx index 882d5a5ec..95d7cc861 100644 --- a/client/src/routes/Search.tsx +++ b/client/src/routes/Search.tsx @@ -2,7 +2,7 @@ import React, { useEffect } from 'react'; import { useNavigate, useParams } from 'react-router-dom'; import { useRecoilState, useRecoilValue } from 'recoil'; -import Messages from '~/components/Messages'; +import Messages from '~/components/Messages/Messages'; import TextChat from '~/components/Input/TextChat'; import store from '~/store'; diff --git a/e2e/specs/messages.spec.ts b/e2e/specs/messages.spec.ts index d826600d4..a81ff9cd3 100644 --- a/e2e/specs/messages.spec.ts +++ b/e2e/specs/messages.spec.ts @@ -12,7 +12,10 @@ function isUUID(uuid: string) { } const waitForServerStream = async (response: Response) => { - return response.url().includes(`/api/ask/${endpoint}`) && response.status() === 200; + const endpointCheck = + response.url().includes(`/api/ask/${endpoint}`) || + response.url().includes(`/api/edit/${endpoint}`); + return endpointCheck && response.status() === 200; }; async function clearConvos(page: Page) { @@ -52,7 +55,7 @@ test.afterEach(async ({ page }) => { }); test.describe('Messaging suite', () => { - test('textbox should be focused after receiving message & test expected navigation', async ({ + test('textbox should be focused after generation, test expected navigation, & test editing messages', async ({ page, }) => { test.setTimeout(120000); @@ -91,6 +94,33 @@ test.describe('Messaging suite', () => { const finalUrl = page.url(); const conversationId = finalUrl.split(basePath).pop() ?? ''; expect(isUUID(conversationId)).toBeTruthy(); + + // Check if editing works + const editText = 'All work and no play makes Johnny a poor boy'; + await page.getByRole('button', { name: 'edit' }).click(); + const textEditor = page.getByTestId('message-text-editor'); + await textEditor.click(); + await textEditor.fill(editText); + await page.getByRole('button', { name: 'Save', exact: true }).click(); + + const updatedTextElement = page.getByText(editText); + expect(updatedTextElement).toBeTruthy(); + + // Check edit response + await page.getByRole('button', { name: 'edit' }).click(); + const editResponsePromise = [ + page.waitForResponse(waitForServerStream), + await page.getByRole('button', { name: 'Save & Submit' }).click(), + ]; + + const [editResponse] = (await Promise.all(editResponsePromise)) as [Response]; + const editResponseBody = await editResponse.body(); + const editSuccess = editResponseBody.includes('"final":true'); + expect(editSuccess).toBe(true); + + // The generated message should include the edited text + const currentTextContent = await updatedTextElement.innerText(); + expect(currentTextContent.includes(editText)).toBeTruthy(); }); test('message should stop and continue', async ({ page }) => { @@ -124,6 +154,9 @@ test.describe('Messaging suite', () => { const regenerateButton = page.getByRole('button', { name: 'Regenerate' }); expect(regenerateButton).toBeTruthy(); + + // Clear conversation since it seems to persist despite other tests clearing it + await page.getByTestId('convo-item').getByRole('button').nth(1).click(); }); // in this spec as we are testing post-message navigation, we are not testing the message response diff --git a/package-lock.json b/package-lock.json index 7d2ef2892..5760a96a5 100644 --- a/package-lock.json +++ b/package-lock.json @@ -48,7 +48,7 @@ "dependencies": { "@anthropic-ai/sdk": "^0.5.4", "@azure/search-documents": "^11.3.2", - "@dqbd/tiktoken": "^1.0.2", + "@dqbd/tiktoken": "^1.0.7", "@fortaine/fetch-event-source": "^3.0.6", "@keyv/mongo": "^2.1.8", "@waylaidwanderer/chatgpt-api": "^1.37.2", @@ -26431,7 +26431,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.1.5", + "version": "0.1.6", "license": "ISC", "dependencies": { "@tanstack/react-query": "^4.28.0", diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 324b44226..a678cce12 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.1.5", + "version": "0.1.6", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/src/api-endpoints.ts b/packages/data-provider/src/api-endpoints.ts index fe1579cd5..a84a57e17 100644 --- a/packages/data-provider/src/api-endpoints.ts +++ b/packages/data-provider/src/api-endpoints.ts @@ -6,8 +6,8 @@ export const userPlugins = () => { return '/api/user/plugins'; }; -export const messages = (id: string) => { - return `/api/messages/${id}`; +export const messages = (conversationId: string, messageId?: string) => { + return `/api/messages/${conversationId}${messageId ? `/${messageId}` : ''}`; }; export const abortRequest = (endpoint: string) => { diff --git a/packages/data-provider/src/createPayload.ts b/packages/data-provider/src/createPayload.ts index b2d2f0e4e..eab38cfba 100644 --- a/packages/data-provider/src/createPayload.ts +++ b/packages/data-provider/src/createPayload.ts @@ -2,7 +2,7 @@ import { tConversationSchema } from './schemas'; import { TSubmission, EModelEndpoint } from './types'; export default function createPayload(submission: TSubmission) { - const { conversation, message, endpointOption, isEdited } = submission; + const { conversation, message, endpointOption, isEdited, isContinued } = submission; const { conversationId } = tConversationSchema.parse(conversation); const { endpoint } = endpointOption as { endpoint: EModelEndpoint }; @@ -26,6 +26,7 @@ export default function createPayload(submission: TSubmission) { const payload = { ...message, ...endpointOption, + isContinued: isEdited && isContinued, conversationId, }; diff --git a/packages/data-provider/src/data-service.ts b/packages/data-provider/src/data-service.ts index 44d3dea41..24d7822d1 100644 --- a/packages/data-provider/src/data-service.ts +++ b/packages/data-provider/src/data-service.ts @@ -24,8 +24,8 @@ export function clearAllConversations(): Promise { return request.post(endpoints.deleteConversation(), { arg: {} }); } -export function getMessagesByConvoId(id: string): Promise { - return request.get(endpoints.messages(id)); +export function getMessagesByConvoId(conversationId: string): Promise { + return request.get(endpoints.messages(conversationId)); } export function getConversationById(id: string): Promise { @@ -38,6 +38,15 @@ export function updateConversation( return request.post(endpoints.updateConversation(), { arg: payload }); } +export function updateMessage(payload: t.TUpdateMessageRequest): Promise { + const { conversationId, messageId, text } = payload; + if (!conversationId) { + throw new Error('conversationId is required'); + } + + return request.put(endpoints.messages(conversationId, messageId), { text }); +} + export function getPresets(): Promise { return request.get(endpoints.presets()); } diff --git a/packages/data-provider/src/react-query-service.ts b/packages/data-provider/src/react-query-service.ts index 3bcbffc7d..d75849c47 100644 --- a/packages/data-provider/src/react-query-service.ts +++ b/packages/data-provider/src/react-query-service.ts @@ -110,6 +110,17 @@ export const useUpdateConversationMutation = ( ); }; +export const useUpdateMessageMutation = ( + id: string, +): UseMutationResult => { + const queryClient = useQueryClient(); + return useMutation((payload: t.TUpdateMessageRequest) => dataService.updateMessage(payload), { + onSuccess: () => { + queryClient.invalidateQueries([QueryKeys.messages, id]); + }, + }); +}; + export const useDeleteConversationMutation = ( id?: string, ): UseMutationResult< diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index a66860390..996c4ddb4 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -32,6 +32,20 @@ export const tPluginSchema = z.object({ export type TPlugin = z.infer; +export type TInput = { + inputStr: string; +}; + +export type TResPlugin = { + plugin: string; + input: string; + thought: string; + loading?: boolean; + outputs?: string; + latest?: string; + inputs?: TInput[]; +}; + export const tExampleSchema = z.object({ input: z.object({ content: z.string(), @@ -57,7 +71,9 @@ export const tMessageSchema = z.object({ parentMessageId: z.string().nullable(), responseMessageId: z.string().nullable().optional(), overrideParentMessageId: z.string().nullable().optional(), - plugin: tPluginSchema.nullable().optional(), + bg: z.string().nullable().optional(), + model: z.string().nullable().optional(), + title: z.string().nullable().optional(), sender: z.string(), text: z.string(), generation: z.string().nullable().optional(), @@ -78,7 +94,10 @@ export const tMessageSchema = z.object({ finish_reason: z.string().optional(), }); -export type TMessage = z.input; +export type TMessage = z.input & { + children?: TMessage[]; + plugin?: TResPlugin | null; +}; export const tConversationSchema = z.object({ conversationId: z.string().nullable(), diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts index 40bb5ad1e..ecf8adbd8 100644 --- a/packages/data-provider/src/types.ts +++ b/packages/data-provider/src/types.ts @@ -1,4 +1,4 @@ -import type { TPlugin, TMessage, TConversation, TEndpointOption } from './schemas'; +import type { TResPlugin, TMessage, TConversation, TEndpointOption } from './schemas'; export * from './schemas'; @@ -7,9 +7,10 @@ export type TMessages = TMessage[]; export type TMessagesAtom = TMessages | null; export type TSubmission = { - plugin?: TPlugin; + plugin?: TResPlugin; message: TMessage; isEdited?: boolean; + isContinued?: boolean; messages: TMessage[]; isRegenerate?: boolean; conversationId?: string; @@ -37,6 +38,7 @@ export type TError = { data?: { message?: string; }; + status?: number; }; }; @@ -60,6 +62,12 @@ export type TGetConversationsResponse = { pages: string | number; }; +export type TUpdateMessageRequest = { + conversationId: string; + messageId: string; + text: string; +}; + export type TUpdateConversationRequest = { conversationId: string; title: string;