diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 905cadfd23..60c3d40cb5 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -585,6 +585,8 @@ class BaseClient { responseMessage.text = completion.join(''); } + /** @type {(() => Promise) | null} */ + let reconcileTokenCount = null; if (tokenCountMap && this.recordTokenUsage && this.getTokenCountForResponse) { let completionTokens; @@ -599,25 +601,38 @@ class BaseClient { responseMessage.tokenCount = usage[this.outputTokensKey]; completionTokens = responseMessage.tokenCount; } else { - responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); - completionTokens = responseMessage.tokenCount; - await this.recordTokenUsage({ - usage, - promptTokens, - completionTokens, - balance: balanceConfig, - /** Note: When using agents, responseMessage.model is the agent ID, not the model */ - model: this.model, - messageId: this.responseMessageId, - }); + reconcileTokenCount = async () => { + const tokenCount = this.getTokenCountForResponse(responseMessage); + await this.recordTokenUsage({ + usage, + promptTokens, + completionTokens: tokenCount, + balance: balanceConfig, + /** Note: When using agents, responseMessage.model is the agent ID, not the model */ + model: this.model, + messageId: this.responseMessageId, + }); + await this.updateMessageInDatabase({ + messageId: this.responseMessageId, + tokenCount, + }); + logger.debug('[BaseClient] Async response token reconciliation complete', { + messageId: responseMessage.messageId, + model: responseMessage.model, + promptTokens, + completionTokens: tokenCount, + }); + }; } - logger.debug('[BaseClient] Response token usage', { - messageId: responseMessage.messageId, - model: responseMessage.model, - promptTokens, - completionTokens, - }); + if (completionTokens != null) { + logger.debug('[BaseClient] Response token usage', { + messageId: responseMessage.messageId, + model: responseMessage.model, + 
promptTokens, + completionTokens, + }); + } } if (userMessagePromise) { @@ -666,6 +681,13 @@ class BaseClient { saveOptions, user, ); + if (reconcileTokenCount != null) { + void responseMessage.databasePromise + .then(() => reconcileTokenCount()) + .catch((error) => { + logger.error('[BaseClient] Async token reconciliation failed', error); + }); + } this.savedMessageIds.add(responseMessage.messageId); delete responseMessage.tokenCount; return responseMessage; diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index 7030d8fe35..08ccf2a45f 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -3,7 +3,7 @@ import { useWatch } from 'react-hook-form'; import { TextareaAutosize } from '@librechat/client'; import { useRecoilState, useRecoilValue } from 'recoil'; import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider'; -import type { TConversation, TMessage } from 'librechat-data-provider'; +import type { TConversation } from 'librechat-data-provider'; import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common'; import { useChatContext, @@ -49,7 +49,6 @@ interface ChatFormProps { setFilesLoading: React.Dispatch>; newConversation: ConvoGenerator; handleStopGenerating: (e: React.MouseEvent) => void; - getMessages: () => TMessage[] | undefined; } const ChatForm = memo(function ChatForm({ @@ -62,7 +61,6 @@ const ChatForm = memo(function ChatForm({ setFilesLoading, newConversation, handleStopGenerating, - getMessages, }: ChatFormProps) { const submitButtonRef = useRef(null); const textAreaRef = useRef(null); @@ -321,7 +319,7 @@ const ChatForm = memo(function ChatForm({ onBlur={handleTextareaBlur} aria-label={localize('com_ui_message_input')} onClick={handleFocusOrClick} - style={{ height: 44, overflowY: 'auto' }} + style={{ height: 44 }} className={cn( baseClasses, removeFocusRings, @@ -376,9 +374,7 @@ const 
ChatForm = memo(function ChatForm({ /> )}
- {endpoint && ( - - )} + {endpoint && } {isSubmitting && showStopButton ? ( ) : ( @@ -416,7 +412,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) { setFilesLoading, newConversation, handleStopGenerating, - getMessages, } = useChatContext(); /** @@ -456,10 +451,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) { [], ); - const getMessagesRef = useRef(getMessages); - getMessagesRef.current = getMessages; - const stableGetMessages = useCallback(() => getMessagesRef.current(), []); - return ( ); } diff --git a/client/src/components/Chat/Input/ContextTracker.tsx b/client/src/components/Chat/Input/ContextTracker.tsx index 02cbf16c72..a664a8091b 100644 --- a/client/src/components/Chat/Input/ContextTracker.tsx +++ b/client/src/components/Chat/Input/ContextTracker.tsx @@ -1,4 +1,4 @@ -import { useEffect, useMemo, useRef, useState } from 'react'; +import { useCallback, useMemo, useSyncExternalStore } from 'react'; import { useQueryClient } from '@tanstack/react-query'; import { QueryKeys, Constants } from 'librechat-data-provider'; import type { TConversation, TMessage, TModelSpec, TStartupConfig } from 'librechat-data-provider'; @@ -11,8 +11,6 @@ import { cn } from '~/utils'; type ContextTrackerProps = { conversation: TConversation | null; - getMessages: () => TMessage[] | undefined; - isSubmitting: boolean; }; type MessageWithTokenCount = TMessage & { tokenCount?: number }; @@ -68,44 +66,38 @@ const getSpecMaxContextTokens = ( return maxContextTokens; }; -export default function ContextTracker({ - conversation, - getMessages, - isSubmitting, -}: ContextTrackerProps) { +export default function ContextTracker({ conversation }: ContextTrackerProps) { const localize = useLocalize(); const queryClient = useQueryClient(); const { data: startupConfig } = useGetStartupConfig(); const showContextTracker = useRecoilValue(store.showContextTracker); const conversationId = conversation?.conversationId ?? 
Constants.NEW_CONVO; - const [usedTokens, setUsedTokens] = useState(() => getUsedTokens(getMessages())); + const subscribeToMessages = useCallback( + (onStoreChange: () => void) => + queryClient.getQueryCache().subscribe((event) => { + const queryKey = event?.query?.queryKey; + if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) { + return; + } + if (queryKey[1] !== conversationId) { + return; + } + onStoreChange(); + }), + [conversationId, queryClient], + ); - useEffect(() => { - setUsedTokens(getUsedTokens(getMessages())); - }, [conversationId, getMessages]); + const getMessagesSnapshot = useCallback( + () => queryClient.getQueryData([QueryKeys.messages, conversationId]), + [conversationId, queryClient], + ); - useEffect(() => { - const unsubscribe = queryClient.getQueryCache().subscribe((event) => { - const queryKey = event?.query?.queryKey; - if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) { - return; - } - - setUsedTokens(getUsedTokens(getMessages())); - }); - - return unsubscribe; - }, [getMessages, queryClient]); - - const prevIsSubmitting = useRef(isSubmitting); - useEffect(() => { - if (prevIsSubmitting.current && !isSubmitting) { - // Messages from SSE don't include tokenCount (server strips it). - // Invalidate to refetch from API which includes tokenCount from DB. 
- queryClient.invalidateQueries({ queryKey: [QueryKeys.messages] }); - } - prevIsSubmitting.current = isSubmitting; - }, [isSubmitting, queryClient]); + const messages = useSyncExternalStore( + subscribeToMessages, + getMessagesSnapshot, + getMessagesSnapshot, + ); + const usedTokens = useMemo(() => getUsedTokens(messages), [messages]); const maxContextTokens = typeof conversation?.maxContextTokens === 'number' && diff --git a/client/src/hooks/SSE/useEventHandlers.ts b/client/src/hooks/SSE/useEventHandlers.ts index 366775c4c1..b05e805fab 100644 --- a/client/src/hooks/SSE/useEventHandlers.ts +++ b/client/src/hooks/SSE/useEventHandlers.ts @@ -491,6 +491,27 @@ export default function useEventHandlers({ setMessages(_messages); queryClient.setQueryData([QueryKeys.messages, id], _messages); }; + const shouldRefreshTokenCounts = (targetMessages: TMessage[]) => + targetMessages.some((message) => { + if (message.isCreatedByUser) { + return false; + } + const tokenCount = message.tokenCount; + return ( + typeof tokenCount !== 'number' || !Number.isFinite(tokenCount) || tokenCount <= 0 + ); + }); + const scheduleTokenCountRefresh = (targetConversationId: string, targetMessages: TMessage[]) => { + if (!shouldRefreshTokenCounts(targetMessages)) { + return; + } + setTimeout(() => { + void queryClient.invalidateQueries({ + queryKey: [QueryKeys.messages, targetConversationId], + exact: true, + }); + }, 0); + }; const hasNoResponse = responseMessage?.content?.[0]?.['text']?.value === @@ -545,6 +566,9 @@ export default function useEventHandlers({ if (finalMessages.length > 0) { setFinalMessages(conversation.conversationId, finalMessages); + if (conversation.conversationId) { + scheduleTokenCountRefresh(conversation.conversationId, finalMessages); + } } else if ( isAssistantsEndpoint(submissionConvo.endpoint) && (!submissionConvo.conversationId || diff --git a/packages/api/src/utils/tokenizer.ts b/packages/api/src/utils/tokenizer.ts index 4c638c948e..17baedcdcb 100644 --- 
a/packages/api/src/utils/tokenizer.ts +++ b/packages/api/src/utils/tokenizer.ts @@ -47,14 +47,29 @@ class Tokenizer { const TokenizerSingleton = new Tokenizer(); +export function resolveEncodingFromModel(model?: string): EncodingName { + if (typeof model === 'string' && model.toLowerCase().includes('claude')) { + return 'claude'; + } + return 'o200k_base'; +} + /** * Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding. * @param text - The text to count tokens in. Defaults to an empty string. + * @param modelOrEncoding - Optional model id or explicit encoding name. * @returns The number of tokens in the provided text. */ -export async function countTokens(text = ''): Promise<number> { - await TokenizerSingleton.initEncoding('o200k_base'); - return TokenizerSingleton.getTokenCount(text, 'o200k_base'); +export async function countTokens( + text = '', + modelOrEncoding?: string | EncodingName, +): Promise<number> { + const encoding = + modelOrEncoding === 'claude' || modelOrEncoding === 'o200k_base' + ? modelOrEncoding + : resolveEncodingFromModel(modelOrEncoding); + await TokenizerSingleton.initEncoding(encoding); + return TokenizerSingleton.getTokenCount(text, encoding); } export default TokenizerSingleton;