Better end of stream state

This commit is contained in:
Murillo Camargo 2026-03-30 14:30:59 -07:00
parent c944cd6074
commit 8e9f926e42
5 changed files with 110 additions and 67 deletions

View file

@@ -585,6 +585,8 @@ class BaseClient {
responseMessage.text = completion.join('');
}
/** @type {(() => Promise<void>) | null} */
let reconcileTokenCount = null;
if (tokenCountMap && this.recordTokenUsage && this.getTokenCountForResponse) {
let completionTokens;
@@ -599,25 +601,38 @@ class BaseClient {
responseMessage.tokenCount = usage[this.outputTokensKey];
completionTokens = responseMessage.tokenCount;
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({
usage,
promptTokens,
completionTokens,
balance: balanceConfig,
/** Note: When using agents, responseMessage.model is the agent ID, not the model */
model: this.model,
messageId: this.responseMessageId,
});
// Deferred reconciliation: the provider usage did not carry an authoritative
// completion token count (see the surrounding else-branch), so recompute it
// locally and sync both the usage records and the persisted message later.
reconcileTokenCount = async () => {
const tokenCount = this.getTokenCountForResponse(responseMessage);
// Record usage/balance with the recomputed completion token count.
await this.recordTokenUsage({
usage,
promptTokens,
completionTokens: tokenCount,
balance: balanceConfig,
/** Note: When using agents, responseMessage.model is the agent ID, not the model */
model: this.model,
messageId: this.responseMessageId,
});
// Write the corrected count back onto the stored response message.
await this.updateMessageInDatabase({
messageId: this.responseMessageId,
tokenCount,
});
logger.debug('[BaseClient] Async response token reconciliation complete', {
messageId: responseMessage.messageId,
model: responseMessage.model,
promptTokens,
completionTokens: tokenCount,
});
};
}
logger.debug('[BaseClient] Response token usage', {
messageId: responseMessage.messageId,
model: responseMessage.model,
promptTokens,
completionTokens,
});
if (completionTokens != null) {
logger.debug('[BaseClient] Response token usage', {
messageId: responseMessage.messageId,
model: responseMessage.model,
promptTokens,
completionTokens,
});
}
}
if (userMessagePromise) {
@@ -666,6 +681,13 @@ class BaseClient {
saveOptions,
user,
);
// Fire-and-forget: run reconciliation only after databasePromise resolves —
// presumably so updateMessageInDatabase targets the already-saved row
// (TODO confirm). Failures are logged, never propagated to the caller.
if (reconcileTokenCount != null) {
void responseMessage.databasePromise
.then(() => reconcileTokenCount())
.catch((error) => {
logger.error('[BaseClient] Async token reconciliation failed', error);
});
}
this.savedMessageIds.add(responseMessage.messageId);
delete responseMessage.tokenCount;
return responseMessage;

View file

@@ -3,7 +3,7 @@ import { useWatch } from 'react-hook-form';
import { TextareaAutosize } from '@librechat/client';
import { useRecoilState, useRecoilValue } from 'recoil';
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
import type { TConversation, TMessage } from 'librechat-data-provider';
import type { TConversation } from 'librechat-data-provider';
import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common';
import {
useChatContext,
@@ -49,7 +49,6 @@ interface ChatFormProps {
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
newConversation: ConvoGenerator;
handleStopGenerating: (e: React.MouseEvent<HTMLButtonElement>) => void;
getMessages: () => TMessage[] | undefined;
}
const ChatForm = memo(function ChatForm({
@@ -62,7 +61,6 @@ const ChatForm = memo(function ChatForm({
setFilesLoading,
newConversation,
handleStopGenerating,
getMessages,
}: ChatFormProps) {
const submitButtonRef = useRef<HTMLButtonElement>(null);
const textAreaRef = useRef<HTMLTextAreaElement>(null);
@@ -321,7 +319,7 @@ const ChatForm = memo(function ChatForm({
onBlur={handleTextareaBlur}
aria-label={localize('com_ui_message_input')}
onClick={handleFocusOrClick}
style={{ height: 44, overflowY: 'auto' }}
style={{ height: 44 }}
className={cn(
baseClasses,
removeFocusRings,
@@ -376,9 +374,7 @@ const ChatForm = memo(function ChatForm({
/>
)}
<div className={cn('flex items-center gap-2', isRTL ? 'ml-2' : 'mr-2')}>
{endpoint && (
<ContextTracker conversation={conversation} getMessages={getMessages} isSubmitting={isSubmitting} />
)}
{endpoint && <ContextTracker conversation={conversation} />}
{isSubmitting && showStopButton ? (
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
) : (
@@ -416,7 +412,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) {
setFilesLoading,
newConversation,
handleStopGenerating,
getMessages,
} = useChatContext();
/**
@@ -456,10 +451,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) {
[],
);
const getMessagesRef = useRef(getMessages);
getMessagesRef.current = getMessages;
const stableGetMessages = useCallback(() => getMessagesRef.current(), []);
return (
<ChatForm
index={index}
@@ -471,7 +462,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) {
setFilesLoading={setFilesLoading}
newConversation={stableNewConversation}
handleStopGenerating={stableHandleStop}
getMessages={stableGetMessages}
/>
);
}

View file

@@ -1,4 +1,4 @@
import { useEffect, useMemo, useRef, useState } from 'react';
import { useCallback, useMemo, useSyncExternalStore } from 'react';
import { useQueryClient } from '@tanstack/react-query';
import { QueryKeys, Constants } from 'librechat-data-provider';
import type { TConversation, TMessage, TModelSpec, TStartupConfig } from 'librechat-data-provider';
@@ -11,8 +11,6 @@ import { cn } from '~/utils';
type ContextTrackerProps = {
conversation: TConversation | null;
getMessages: () => TMessage[] | undefined;
isSubmitting: boolean;
};
type MessageWithTokenCount = TMessage & { tokenCount?: number };
@@ -68,44 +66,38 @@ const getSpecMaxContextTokens = (
return maxContextTokens;
};
export default function ContextTracker({
conversation,
getMessages,
isSubmitting,
}: ContextTrackerProps) {
export default function ContextTracker({ conversation }: ContextTrackerProps) {
const localize = useLocalize();
const queryClient = useQueryClient();
const { data: startupConfig } = useGetStartupConfig();
const showContextTracker = useRecoilValue(store.showContextTracker);
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
const [usedTokens, setUsedTokens] = useState<number>(() => getUsedTokens(getMessages()));
// Subscription for useSyncExternalStore: listen to React Query's cache and
// notify only when the messages query for the current conversation changes.
const subscribeToMessages = useCallback(
(onStoreChange: () => void) =>
// `subscribe` returns an unsubscribe function, which the store hook
// invokes on cleanup.
queryClient.getQueryCache().subscribe((event) => {
const queryKey = event?.query?.queryKey;
// Ignore events from other query families.
if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) {
return;
}
// Ignore events for other conversations.
if (queryKey[1] !== conversationId) {
return;
}
onStoreChange();
}),
[conversationId, queryClient],
);
useEffect(() => {
setUsedTokens(getUsedTokens(getMessages()));
}, [conversationId, getMessages]);
// Snapshot for useSyncExternalStore: read the conversation's messages
// straight from the query cache (stable reference until the entry changes).
const getMessagesSnapshot = useCallback(
() => queryClient.getQueryData<TMessage[]>([QueryKeys.messages, conversationId]),
[conversationId, queryClient],
);
useEffect(() => {
const unsubscribe = queryClient.getQueryCache().subscribe((event) => {
const queryKey = event?.query?.queryKey;
if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) {
return;
}
setUsedTokens(getUsedTokens(getMessages()));
});
return unsubscribe;
}, [getMessages, queryClient]);
const prevIsSubmitting = useRef(isSubmitting);
useEffect(() => {
if (prevIsSubmitting.current && !isSubmitting) {
// Messages from SSE don't include tokenCount (server strips it).
// Invalidate to refetch from API which includes tokenCount from DB.
queryClient.invalidateQueries({ queryKey: [QueryKeys.messages] });
}
prevIsSubmitting.current = isSubmitting;
}, [isSubmitting, queryClient]);
// Re-render only when the current conversation's cached messages change;
// the same snapshot getter serves as the server-side snapshot.
const messages = useSyncExternalStore(
subscribeToMessages,
getMessagesSnapshot,
getMessagesSnapshot,
);
// Token usage derived from the cached messages (see getUsedTokens above).
const usedTokens = useMemo(() => getUsedTokens(messages), [messages]);
const maxContextTokens =
typeof conversation?.maxContextTokens === 'number' &&

View file

@@ -491,6 +491,27 @@ export default function useEventHandlers({
setMessages(_messages);
queryClient.setQueryData<TMessage[]>([QueryKeys.messages, id], _messages);
};
// True when any non-user (response) message lacks a usable tokenCount —
// i.e. the count is missing, non-numeric, non-finite, or not positive.
const shouldRefreshTokenCounts = (targetMessages: TMessage[]) => {
  for (const message of targetMessages) {
    if (message.isCreatedByUser) {
      continue;
    }
    const { tokenCount } = message;
    if (typeof tokenCount !== 'number' || !Number.isFinite(tokenCount) || tokenCount <= 0) {
      return true;
    }
  }
  return false;
};
// Queues a refetch of a conversation's messages when any response message is
// missing its token count (SSE payloads strip tokenCount; the DB copy has it).
const scheduleTokenCountRefresh = (targetConversationId: string, targetMessages: TMessage[]) => {
  if (!shouldRefreshTokenCounts(targetMessages)) {
    return;
  }
  const refetch = () =>
    void queryClient.invalidateQueries({
      queryKey: [QueryKeys.messages, targetConversationId],
      exact: true,
    });
  // Defer one tick so the invalidation runs after the current update settles.
  setTimeout(refetch, 0);
};
const hasNoResponse =
responseMessage?.content?.[0]?.['text']?.value ===
@@ -545,6 +566,9 @@ export default function useEventHandlers({
if (finalMessages.length > 0) {
setFinalMessages(conversation.conversationId, finalMessages);
if (conversation.conversationId) {
scheduleTokenCountRefresh(conversation.conversationId, finalMessages);
}
} else if (
isAssistantsEndpoint(submissionConvo.endpoint) &&
(!submissionConvo.conversationId ||

View file

@@ -47,14 +47,29 @@ class Tokenizer {
const TokenizerSingleton = new Tokenizer();
/**
 * Picks the tokenizer encoding for a model id.
 *
 * @param model - Optional model identifier.
 * @returns `'claude'` when the id contains "claude" (case-insensitive),
 *   otherwise the `'o200k_base'` default (also used for `undefined`).
 */
export function resolveEncodingFromModel(model?: string): EncodingName {
  const isClaudeModel = typeof model === 'string' && model.toLowerCase().includes('claude');
  return isClaudeModel ? 'claude' : 'o200k_base';
}
/**
* Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding.
* @param text - The text to count tokens in. Defaults to an empty string.
* @param modelOrEncoding - Optional model id or explicit encoding name.
* @returns The number of tokens in the provided text.
*/
export async function countTokens(text = ''): Promise<number> {
await TokenizerSingleton.initEncoding('o200k_base');
return TokenizerSingleton.getTokenCount(text, 'o200k_base');
export async function countTokens(
  text = '',
  modelOrEncoding?: string | EncodingName,
): Promise<number> {
  // Explicit encoding names pass through untouched; anything else is treated
  // as a model id (or undefined) and resolved to an encoding.
  let encoding: EncodingName;
  if (modelOrEncoding === 'claude' || modelOrEncoding === 'o200k_base') {
    encoding = modelOrEncoding;
  } else {
    encoding = resolveEncodingFromModel(modelOrEncoding);
  }
  await TokenizerSingleton.initEncoding(encoding);
  return TokenizerSingleton.getTokenCount(text, encoding);
}
export default TokenizerSingleton;