mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
Better end of stream state
This commit is contained in:
parent
c944cd6074
commit
8e9f926e42
5 changed files with 110 additions and 67 deletions
|
|
@ -585,6 +585,8 @@ class BaseClient {
|
|||
responseMessage.text = completion.join('');
|
||||
}
|
||||
|
||||
/** @type {(() => Promise<void>) | null} */
|
||||
let reconcileTokenCount = null;
|
||||
if (tokenCountMap && this.recordTokenUsage && this.getTokenCountForResponse) {
|
||||
let completionTokens;
|
||||
|
||||
|
|
@ -599,25 +601,38 @@ class BaseClient {
|
|||
responseMessage.tokenCount = usage[this.outputTokensKey];
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
} else {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
await this.recordTokenUsage({
|
||||
usage,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
balance: balanceConfig,
|
||||
/** Note: When using agents, responseMessage.model is the agent ID, not the model */
|
||||
model: this.model,
|
||||
messageId: this.responseMessageId,
|
||||
});
|
||||
reconcileTokenCount = async () => {
|
||||
const tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
await this.recordTokenUsage({
|
||||
usage,
|
||||
promptTokens,
|
||||
completionTokens: tokenCount,
|
||||
balance: balanceConfig,
|
||||
/** Note: When using agents, responseMessage.model is the agent ID, not the model */
|
||||
model: this.model,
|
||||
messageId: this.responseMessageId,
|
||||
});
|
||||
await this.updateMessageInDatabase({
|
||||
messageId: this.responseMessageId,
|
||||
tokenCount,
|
||||
});
|
||||
logger.debug('[BaseClient] Async response token reconciliation complete', {
|
||||
messageId: responseMessage.messageId,
|
||||
model: responseMessage.model,
|
||||
promptTokens,
|
||||
completionTokens: tokenCount,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug('[BaseClient] Response token usage', {
|
||||
messageId: responseMessage.messageId,
|
||||
model: responseMessage.model,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
if (completionTokens != null) {
|
||||
logger.debug('[BaseClient] Response token usage', {
|
||||
messageId: responseMessage.messageId,
|
||||
model: responseMessage.model,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (userMessagePromise) {
|
||||
|
|
@ -666,6 +681,13 @@ class BaseClient {
|
|||
saveOptions,
|
||||
user,
|
||||
);
|
||||
if (reconcileTokenCount != null) {
|
||||
void responseMessage.databasePromise
|
||||
.then(() => reconcileTokenCount())
|
||||
.catch((error) => {
|
||||
logger.error('[BaseClient] Async token reconciliation failed', error);
|
||||
});
|
||||
}
|
||||
this.savedMessageIds.add(responseMessage.messageId);
|
||||
delete responseMessage.tokenCount;
|
||||
return responseMessage;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import { useWatch } from 'react-hook-form';
|
|||
import { TextareaAutosize } from '@librechat/client';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { TConversation, TMessage } from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common';
|
||||
import {
|
||||
useChatContext,
|
||||
|
|
@ -49,7 +49,6 @@ interface ChatFormProps {
|
|||
setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
newConversation: ConvoGenerator;
|
||||
handleStopGenerating: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
getMessages: () => TMessage[] | undefined;
|
||||
}
|
||||
|
||||
const ChatForm = memo(function ChatForm({
|
||||
|
|
@ -62,7 +61,6 @@ const ChatForm = memo(function ChatForm({
|
|||
setFilesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
getMessages,
|
||||
}: ChatFormProps) {
|
||||
const submitButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
|
@ -321,7 +319,7 @@ const ChatForm = memo(function ChatForm({
|
|||
onBlur={handleTextareaBlur}
|
||||
aria-label={localize('com_ui_message_input')}
|
||||
onClick={handleFocusOrClick}
|
||||
style={{ height: 44, overflowY: 'auto' }}
|
||||
style={{ height: 44 }}
|
||||
className={cn(
|
||||
baseClasses,
|
||||
removeFocusRings,
|
||||
|
|
@ -376,9 +374,7 @@ const ChatForm = memo(function ChatForm({
|
|||
/>
|
||||
)}
|
||||
<div className={cn('flex items-center gap-2', isRTL ? 'ml-2' : 'mr-2')}>
|
||||
{endpoint && (
|
||||
<ContextTracker conversation={conversation} getMessages={getMessages} isSubmitting={isSubmitting} />
|
||||
)}
|
||||
{endpoint && <ContextTracker conversation={conversation} />}
|
||||
{isSubmitting && showStopButton ? (
|
||||
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
|
||||
) : (
|
||||
|
|
@ -416,7 +412,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) {
|
|||
setFilesLoading,
|
||||
newConversation,
|
||||
handleStopGenerating,
|
||||
getMessages,
|
||||
} = useChatContext();
|
||||
|
||||
/**
|
||||
|
|
@ -456,10 +451,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) {
|
|||
[],
|
||||
);
|
||||
|
||||
const getMessagesRef = useRef(getMessages);
|
||||
getMessagesRef.current = getMessages;
|
||||
const stableGetMessages = useCallback(() => getMessagesRef.current(), []);
|
||||
|
||||
return (
|
||||
<ChatForm
|
||||
index={index}
|
||||
|
|
@ -471,7 +462,6 @@ function ChatFormWrapper({ index = 0 }: { index?: number }) {
|
|||
setFilesLoading={setFilesLoading}
|
||||
newConversation={stableNewConversation}
|
||||
handleStopGenerating={stableHandleStop}
|
||||
getMessages={stableGetMessages}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useCallback, useMemo, useSyncExternalStore } from 'react';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { QueryKeys, Constants } from 'librechat-data-provider';
|
||||
import type { TConversation, TMessage, TModelSpec, TStartupConfig } from 'librechat-data-provider';
|
||||
|
|
@ -11,8 +11,6 @@ import { cn } from '~/utils';
|
|||
|
||||
type ContextTrackerProps = {
|
||||
conversation: TConversation | null;
|
||||
getMessages: () => TMessage[] | undefined;
|
||||
isSubmitting: boolean;
|
||||
};
|
||||
|
||||
type MessageWithTokenCount = TMessage & { tokenCount?: number };
|
||||
|
|
@ -68,44 +66,38 @@ const getSpecMaxContextTokens = (
|
|||
return maxContextTokens;
|
||||
};
|
||||
|
||||
export default function ContextTracker({
|
||||
conversation,
|
||||
getMessages,
|
||||
isSubmitting,
|
||||
}: ContextTrackerProps) {
|
||||
export default function ContextTracker({ conversation }: ContextTrackerProps) {
|
||||
const localize = useLocalize();
|
||||
const queryClient = useQueryClient();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const showContextTracker = useRecoilValue(store.showContextTracker);
|
||||
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
|
||||
const [usedTokens, setUsedTokens] = useState<number>(() => getUsedTokens(getMessages()));
|
||||
const subscribeToMessages = useCallback(
|
||||
(onStoreChange: () => void) =>
|
||||
queryClient.getQueryCache().subscribe((event) => {
|
||||
const queryKey = event?.query?.queryKey;
|
||||
if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) {
|
||||
return;
|
||||
}
|
||||
if (queryKey[1] !== conversationId) {
|
||||
return;
|
||||
}
|
||||
onStoreChange();
|
||||
}),
|
||||
[conversationId, queryClient],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setUsedTokens(getUsedTokens(getMessages()));
|
||||
}, [conversationId, getMessages]);
|
||||
const getMessagesSnapshot = useCallback(
|
||||
() => queryClient.getQueryData<TMessage[]>([QueryKeys.messages, conversationId]),
|
||||
[conversationId, queryClient],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const unsubscribe = queryClient.getQueryCache().subscribe((event) => {
|
||||
const queryKey = event?.query?.queryKey;
|
||||
if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) {
|
||||
return;
|
||||
}
|
||||
|
||||
setUsedTokens(getUsedTokens(getMessages()));
|
||||
});
|
||||
|
||||
return unsubscribe;
|
||||
}, [getMessages, queryClient]);
|
||||
|
||||
const prevIsSubmitting = useRef(isSubmitting);
|
||||
useEffect(() => {
|
||||
if (prevIsSubmitting.current && !isSubmitting) {
|
||||
// Messages from SSE don't include tokenCount (server strips it).
|
||||
// Invalidate to refetch from API which includes tokenCount from DB.
|
||||
queryClient.invalidateQueries({ queryKey: [QueryKeys.messages] });
|
||||
}
|
||||
prevIsSubmitting.current = isSubmitting;
|
||||
}, [isSubmitting, queryClient]);
|
||||
const messages = useSyncExternalStore(
|
||||
subscribeToMessages,
|
||||
getMessagesSnapshot,
|
||||
getMessagesSnapshot,
|
||||
);
|
||||
const usedTokens = useMemo(() => getUsedTokens(messages), [messages]);
|
||||
|
||||
const maxContextTokens =
|
||||
typeof conversation?.maxContextTokens === 'number' &&
|
||||
|
|
|
|||
|
|
@ -491,6 +491,27 @@ export default function useEventHandlers({
|
|||
setMessages(_messages);
|
||||
queryClient.setQueryData<TMessage[]>([QueryKeys.messages, id], _messages);
|
||||
};
|
||||
/**
 * Reports whether any non-user message in `targetMessages` lacks a usable token count.
 *
 * Messages flagged `isCreatedByUser` are ignored. For every other message the
 * token count is considered missing when it is not a number, is not finite
 * (NaN / ±Infinity), or is less than or equal to zero.
 *
 * @param targetMessages - Messages to inspect.
 * @returns `true` when at least one non-user message has no positive, finite `tokenCount`;
 *          `false` otherwise (including for an empty list).
 */
const shouldRefreshTokenCounts = (targetMessages: TMessage[]) =>
  targetMessages.some((message) => {
    if (message.isCreatedByUser) {
      return false;
    }
    const tokenCount = message.tokenCount;
    return (
      typeof tokenCount !== 'number' || !Number.isFinite(tokenCount) || tokenCount <= 0
    );
  });
|
||||
/**
 * Queues an invalidation of one conversation's messages query when
 * `shouldRefreshTokenCounts` reports that the given messages need it.
 *
 * The invalidation is issued from a zero-delay `setTimeout` callback rather than
 * synchronously, and targets exactly the
 * `[QueryKeys.messages, targetConversationId]` query key (`exact: true`).
 *
 * @param targetConversationId - Conversation whose messages query is invalidated.
 * @param targetMessages - Messages checked to decide whether a refresh is needed.
 */
const scheduleTokenCountRefresh = (targetConversationId: string, targetMessages: TMessage[]) => {
  if (!shouldRefreshTokenCounts(targetMessages)) {
    return;
  }
  setTimeout(() => {
    // `void` marks the returned promise as intentionally not awaited.
    void queryClient.invalidateQueries({
      queryKey: [QueryKeys.messages, targetConversationId],
      exact: true,
    });
  }, 0);
};
|
||||
|
||||
const hasNoResponse =
|
||||
responseMessage?.content?.[0]?.['text']?.value ===
|
||||
|
|
@ -545,6 +566,9 @@ export default function useEventHandlers({
|
|||
|
||||
if (finalMessages.length > 0) {
|
||||
setFinalMessages(conversation.conversationId, finalMessages);
|
||||
if (conversation.conversationId) {
|
||||
scheduleTokenCountRefresh(conversation.conversationId, finalMessages);
|
||||
}
|
||||
} else if (
|
||||
isAssistantsEndpoint(submissionConvo.endpoint) &&
|
||||
(!submissionConvo.conversationId ||
|
||||
|
|
|
|||
|
|
@ -47,14 +47,29 @@ class Tokenizer {
|
|||
|
||||
const TokenizerSingleton = new Tokenizer();
|
||||
|
||||
/**
 * Chooses the tokenizer encoding for a model identifier.
 *
 * @param model - Optional model id; compared case-insensitively.
 * @returns `'claude'` when `model` is a string containing "claude",
 *          otherwise `'o200k_base'` (also used when `model` is omitted).
 */
export function resolveEncodingFromModel(model?: string): EncodingName {
  if (typeof model === 'string' && model.toLowerCase().includes('claude')) {
    return 'claude';
  }
  return 'o200k_base';
}
|
||||
|
||||
/**
 * Counts the number of tokens in a given text using ai-tokenizer.
 *
 * When `modelOrEncoding` is exactly `'claude'` or `'o200k_base'` it is used as the
 * encoding directly; any other value (including `undefined`) is treated as a model
 * id and passed to `resolveEncodingFromModel` to pick the encoding.
 *
 * @param text - The text to count tokens in. Defaults to an empty string.
 * @param modelOrEncoding - Optional model id or explicit encoding name.
 * @returns The number of tokens in the provided text.
 */
export async function countTokens(
  text = '',
  modelOrEncoding?: string | EncodingName,
): Promise<number> {
  const encoding =
    modelOrEncoding === 'claude' || modelOrEncoding === 'o200k_base'
      ? modelOrEncoding
      : resolveEncodingFromModel(modelOrEncoding);
  // The encoding is initialized before the count is requested from the singleton.
  await TokenizerSingleton.initEncoding(encoding);
  return TokenizerSingleton.getTokenCount(text, encoding);
}
|
||||
|
||||
export default TokenizerSingleton;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue