This commit is contained in:
murillocamargo 2026-04-05 00:10:58 -03:00 committed by GitHub
commit e63ab2877a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 512 additions and 22 deletions

View file

@ -585,6 +585,8 @@ class BaseClient {
responseMessage.text = completion.join('');
}
/** @type {(() => Promise<void>) | null} */
let reconcileTokenCount = null;
if (tokenCountMap && this.recordTokenUsage && this.getTokenCountForResponse) {
let completionTokens;
@ -599,25 +601,38 @@ class BaseClient {
responseMessage.tokenCount = usage[this.outputTokensKey];
completionTokens = responseMessage.tokenCount;
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({
usage,
promptTokens,
completionTokens,
balance: balanceConfig,
/** Note: When using agents, responseMessage.model is the agent ID, not the model */
model: this.model,
messageId: this.responseMessageId,
});
// Deferred token accounting: recount the fully-assembled response message,
// record usage/balance, then persist the corrected count on the saved
// message. Invoked after the response is written to the database (chained on
// responseMessage.databasePromise further below).
reconcileTokenCount = async () => {
  // Recompute from the final response content rather than trusting
  // streaming-time estimates.
  const tokenCount = this.getTokenCountForResponse(responseMessage);
  await this.recordTokenUsage({
    usage,
    promptTokens,
    completionTokens: tokenCount,
    balance: balanceConfig,
    /** Note: When using agents, responseMessage.model is the agent ID, not the model */
    model: this.model,
    messageId: this.responseMessageId,
  });
  // Persist the reconciled count on the stored message document.
  await this.updateMessageInDatabase({
    messageId: this.responseMessageId,
    tokenCount,
  });
  logger.debug('[BaseClient] Async response token reconciliation complete', {
    messageId: responseMessage.messageId,
    model: responseMessage.model,
    promptTokens,
    completionTokens: tokenCount,
  });
};
}
logger.debug('[BaseClient] Response token usage', {
messageId: responseMessage.messageId,
model: responseMessage.model,
promptTokens,
completionTokens,
});
if (completionTokens != null) {
logger.debug('[BaseClient] Response token usage', {
messageId: responseMessage.messageId,
model: responseMessage.model,
promptTokens,
completionTokens,
});
}
}
if (userMessagePromise) {
@ -666,6 +681,13 @@ class BaseClient {
saveOptions,
user,
);
// Fire-and-forget: once the response message is saved, run the async token
// reconciliation so the request path is not blocked on recounting/recording.
// `void` marks the promise as intentionally unawaited; failures are logged.
if (reconcileTokenCount != null) {
  void responseMessage.databasePromise
    .then(() => reconcileTokenCount())
    .catch((error) => {
      logger.error('[BaseClient] Async token reconciliation failed', error);
    });
}
this.savedMessageIds.add(responseMessage.messageId);
delete responseMessage.tokenCount;
return responseMessage;

View file

@ -32,6 +32,7 @@ import CollapseChat from './CollapseChat';
import StreamAudio from './StreamAudio';
import StopButton from './StopButton';
import SendButton from './SendButton';
import ContextTracker from './ContextTracker';
import EditBadges from './EditBadges';
import BadgeRow from './BadgeRow';
import Mention from './Mention';
@ -318,7 +319,7 @@ const ChatForm = memo(function ChatForm({
onBlur={handleTextareaBlur}
aria-label={localize('com_ui_message_input')}
onClick={handleFocusOrClick}
style={{ height: 44, overflowY: 'auto' }}
style={{ height: 44 }}
className={cn(
baseClasses,
removeFocusRings,
@ -372,7 +373,8 @@ const ChatForm = memo(function ChatForm({
isSubmitting={isSubmitting}
/>
)}
<div className={`${isRTL ? 'ml-2' : 'mr-2'}`}>
<div className={cn('flex items-center gap-2', isRTL ? 'ml-2' : 'mr-2')}>
{endpoint && <ContextTracker conversation={conversation} />}
{isSubmitting && showStopButton ? (
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
) : (

View file

@ -0,0 +1,408 @@
import { HoverCard, HoverCardContent, HoverCardPortal, HoverCardTrigger } from '@librechat/client';
import { useQueryClient } from '@tanstack/react-query';
import type { TConversation, TMessage, TModelSpec, TStartupConfig } from 'librechat-data-provider';
import { Constants, QueryKeys } from 'librechat-data-provider';
import { memo, useCallback, useMemo, useSyncExternalStore } from 'react';
import { useRecoilValue } from 'recoil';
import { useGetStartupConfig } from '~/data-provider';
import { useLocalize } from '~/hooks';
import store from '~/store';
import { cn } from '~/utils';
// Props: the active conversation, or null when none is selected yet.
type ContextTrackerProps = {
  conversation: TConversation | null;
};
// Local widening so a message's tokenCount can be read without depending on
// the TMessage typing declaring it — TODO confirm against librechat-data-provider.
type MessageWithTokenCount = TMessage & { tokenCount?: number };
// Aggregated usage: user-authored (input), assistant (output), and their sum.
type TokenTotals = { inputTokens: number; outputTokens: number; totalUsed: number };
// Diameter (px) and stroke width (px) of the circular usage gauge.
const TRACKER_SIZE = 28;
const TRACKER_STROKE = 3.5;
/**
 * Renders a token count in compact human-readable form (e.g. 1.5K, 2M),
 * stripping a redundant trailing ".0" that a locale may emit before the
 * unit letter or at the end of the string.
 */
const formatTokenCount = (count: number): string => {
  const compact = new Intl.NumberFormat(undefined, {
    notation: 'compact',
    maximumFractionDigits: 1,
  }).format(count);
  // ".0" directly followed by a unit letter (K/M/…) or end-of-string is noise.
  return compact.replace(/\.0(?=[A-Za-z]|$)/, '');
};
/**
 * Sums per-message token counts into input (user-authored) and output
 * (assistant) totals. Messages whose tokenCount is missing, non-numeric,
 * non-finite, or non-positive are ignored.
 */
const getTokenTotals = (messages: TMessage[] | undefined): TokenTotals => {
  let inputTokens = 0;
  let outputTokens = 0;
  for (const message of messages ?? []) {
    const tokenCount = (message as MessageWithTokenCount).tokenCount;
    const isUsable =
      typeof tokenCount === 'number' && Number.isFinite(tokenCount) && tokenCount > 0;
    if (!isUsable) {
      continue;
    }
    if (message.isCreatedByUser) {
      inputTokens += tokenCount;
    } else {
      outputTokens += tokenCount;
    }
  }
  return { inputTokens, outputTokens, totalUsed: inputTokens + outputTokens };
};
/**
 * Looks up a model spec by name in the startup config and returns its
 * preset's maxContextTokens. Returns null when the spec name is empty,
 * the spec is not found, or the value is missing/non-finite/non-positive.
 */
const getSpecMaxContextTokens = (
  startupConfig: TStartupConfig | undefined,
  specName: string | null | undefined,
): number | null => {
  if (!specName) {
    return null;
  }
  const spec = startupConfig?.modelSpecs?.list?.find((s: TModelSpec) => s.name === specName);
  const max = spec?.preset?.maxContextTokens;
  if (typeof max !== 'number' || !Number.isFinite(max) || max <= 0) {
    return null;
  }
  return max;
};
type ProgressBarProps = {
  value: number;
  max: number;
  colorClass: string;
  label: string;
  showPercentage?: boolean;
  indeterminate?: boolean;
};

/**
 * Accessible horizontal progress bar. Renders a colored fill sized to
 * value/max (clamped at 100%), or a striped placeholder when indeterminate.
 * Optionally shows the rounded percentage beside the bar.
 */
function ProgressBar({
  value,
  max,
  colorClass,
  label,
  showPercentage = false,
  indeterminate = false,
}: ProgressBarProps) {
  const pct = max > 0 ? Math.min((value / max) * 100, 100) : 0;
  const stripes =
    'repeating-linear-gradient(-45deg, var(--border-medium), var(--border-medium) 4px, var(--surface-tertiary) 4px, var(--surface-tertiary) 8px)';
  const fill = indeterminate ? (
    <div className="h-full w-full rounded-full" style={{ background: stripes }} />
  ) : (
    <div className="flex h-full rounded-full">
      <div
        className={cn('rounded-full transition-all duration-300', colorClass)}
        style={{ width: `${pct}%` }}
      />
      <div className="flex-1 bg-surface-hover" />
    </div>
  );
  return (
    <div className="flex items-center gap-2">
      <div
        role="progressbar"
        aria-valuenow={indeterminate ? undefined : Math.round(pct)}
        aria-valuemin={0}
        aria-valuemax={100}
        aria-label={label}
        className="h-2 flex-1 overflow-hidden rounded-full bg-surface-secondary"
      >
        {fill}
      </div>
      {showPercentage && !indeterminate ? (
        <span className="min-w-[3rem] text-right text-xs text-text-secondary" aria-hidden="true">
          {Math.round(pct)}%
        </span>
      ) : null}
    </div>
  );
}
type TokenRowProps = {
  label: string;
  value: number;
  max: number | null;
  colorClass: string;
  ariaLabel: string;
};

/**
 * One labeled usage row (e.g. "Input"): the formatted token count, its share
 * of the context window when a max is known, and a matching progress bar.
 * Without a usable max, the bar renders as indeterminate.
 */
function TokenRow({ label, value, max, colorClass, ariaLabel }: TokenRowProps) {
  const boundedMax = max != null && max > 0 ? max : null;
  const pct = boundedMax != null ? Math.round(Math.min((value / boundedMax) * 100, 100)) : 0;
  return (
    <div className="space-y-1">
      <div className="flex items-center justify-between text-sm">
        <span className="text-text-secondary">{label}</span>
        <span className="font-medium text-text-primary">
          {formatTokenCount(value)}
          {boundedMax != null ? (
            <span className="ml-1 text-xs text-text-secondary" aria-hidden="true">
              ({pct}%)
            </span>
          ) : null}
        </span>
      </div>
      <ProgressBar
        value={value}
        max={boundedMax ?? 0}
        colorClass={colorClass}
        label={ariaLabel}
        indeterminate={boundedMax == null}
      />
    </div>
  );
}
/**
 * Circular context-window usage indicator for the chat input row.
 * Reads per-message token counts from the react-query messages cache and
 * compares the running total against the conversation's max context tokens
 * (falling back to the active model spec's preset when the conversation does
 * not carry a valid limit). Hovering reveals a detailed input/output breakdown.
 */
function ContextTracker({ conversation }: ContextTrackerProps) {
  const localize = useLocalize();
  const queryClient = useQueryClient();
  const { data: startupConfig } = useGetStartupConfig();
  // Settings toggle; note all hooks still run before the early return below,
  // preserving hook order across renders.
  const showContextTracker = useRecoilValue(store.showContextTracker);
  const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
  // Notify useSyncExternalStore only for cache events on this conversation's
  // messages query, so token-count updates trigger a re-render.
  const subscribeToMessages = useCallback(
    (onStoreChange: () => void) =>
      queryClient.getQueryCache().subscribe((event) => {
        const queryKey = event?.query?.queryKey;
        if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) {
          return;
        }
        if (queryKey[1] !== conversationId) {
          return;
        }
        onStoreChange();
      }),
    [conversationId, queryClient],
  );
  const getMessagesSnapshot = useCallback(
    () => queryClient.getQueryData<TMessage[]>([QueryKeys.messages, conversationId]),
    [conversationId, queryClient],
  );
  // The same snapshot getter doubles as the server-snapshot argument.
  const messages = useSyncExternalStore(
    subscribeToMessages,
    getMessagesSnapshot,
    getMessagesSnapshot,
  );
  const { inputTokens, outputTokens, totalUsed } = useMemo(
    () => getTokenTotals(messages),
    [messages],
  );
  // Prefer the conversation's own limit; otherwise consult the model spec.
  const maxContextTokens =
    typeof conversation?.maxContextTokens === 'number' &&
    Number.isFinite(conversation.maxContextTokens) &&
    conversation.maxContextTokens > 0
      ? conversation.maxContextTokens
      : getSpecMaxContextTokens(startupConfig, conversation?.spec);
  const hasMaxContext = maxContextTokens != null && maxContextTokens > 0;
  // Fraction of the context window consumed, clamped to [0, 1].
  const usageRatio = useMemo(() => {
    if (!hasMaxContext || maxContextTokens == null) {
      return 0;
    }
    return Math.min(totalUsed / maxContextTokens, 1);
  }, [hasMaxContext, maxContextTokens, totalUsed]);
  const percentage = Math.round(usageRatio * 100);
  const inputPercentage =
    hasMaxContext && maxContextTokens != null
      ? Math.round(Math.min((inputTokens / maxContextTokens) * 100, 100))
      : 0;
  const outputPercentage =
    hasMaxContext && maxContextTokens != null
      ? Math.round(Math.min((outputTokens / maxContextTokens) * 100, 100))
      : 0;
  // SVG ring geometry: the arc is drawn by offsetting the dash pattern.
  const trackerRadius = useMemo(() => (TRACKER_SIZE - TRACKER_STROKE) / 2, []);
  const circumference = useMemo(() => 2 * Math.PI * trackerRadius, [trackerRadius]);
  const dashOffset = useMemo(
    () => circumference - (percentage / 100) * circumference,
    [circumference, percentage],
  );
  // Color thresholds: >90% red, >75% yellow, else green; neutral with no max.
  const getRingColorClass = () => {
    if (!hasMaxContext) {
      return 'stroke-text-secondary';
    }
    if (percentage > 90) {
      return 'stroke-red-500';
    }
    if (percentage > 75) {
      return 'stroke-yellow-500';
    }
    return 'stroke-green-500';
  };
  // Same thresholds for the hover-card's main progress bar fill.
  const getMainProgressColorClass = () => {
    if (!hasMaxContext) {
      return 'bg-text-secondary';
    }
    if (percentage > 90) {
      return 'bg-red-500';
    }
    if (percentage > 75) {
      return 'bg-yellow-500';
    }
    return 'bg-green-500';
  };
  // Screen-reader summary for the trigger button; variant depends on whether
  // a max context size is known.
  const ariaLabel = hasMaxContext
    ? localize('com_ui_token_usage_aria_full', {
        0: formatTokenCount(inputTokens),
        1: formatTokenCount(outputTokens),
        2: formatTokenCount(maxContextTokens ?? 0),
        3: percentage.toString(),
      })
    : localize('com_ui_token_usage_aria_no_max', {
        0: formatTokenCount(inputTokens),
        1: formatTokenCount(outputTokens),
        2: formatTokenCount(totalUsed),
      });
  if (!showContextTracker) {
    return null;
  }
  return (
    <HoverCard openDelay={200} closeDelay={100}>
      <HoverCardTrigger asChild>
        <button
          type="button"
          className="flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring"
          aria-label={ariaLabel}
          aria-haspopup="dialog"
          data-testid="context-tracker"
        >
          <svg
            width={TRACKER_SIZE}
            height={TRACKER_SIZE}
            viewBox={`0 0 ${TRACKER_SIZE} ${TRACKER_SIZE}`}
            className="rotate-[-90deg]"
            aria-hidden="true"
            focusable="false"
          >
            <circle
              cx={TRACKER_SIZE / 2}
              cy={TRACKER_SIZE / 2}
              r={trackerRadius}
              fill="transparent"
              strokeWidth={TRACKER_STROKE}
              className="stroke-border-heavy"
            />
            <circle
              cx={TRACKER_SIZE / 2}
              cy={TRACKER_SIZE / 2}
              r={trackerRadius}
              fill="transparent"
              strokeWidth={TRACKER_STROKE}
              strokeDasharray={circumference}
              strokeDashoffset={hasMaxContext ? dashOffset : circumference}
              strokeLinecap="round"
              className={cn('transition-all duration-300', getRingColorClass())}
            />
          </svg>
        </button>
      </HoverCardTrigger>
      <HoverCardPortal>
        <HoverCardContent side="top" align="end" className="p-3">
          <div
            className="w-full space-y-3"
            role="region"
            aria-label={localize('com_ui_context_usage')}
          >
            <div className="flex items-center justify-between">
              <span className="text-sm font-medium text-text-primary">
                {localize('com_ui_context_usage')}
              </span>
              {hasMaxContext ? (
                <span
                  className={cn('text-xs font-medium', {
                    'text-red-500': percentage > 90,
                    'text-yellow-500': percentage > 75 && percentage <= 90,
                    'text-green-500': percentage <= 75,
                  })}
                >
                  {localize('com_ui_token_usage_percent', { 0: percentage.toString() })}
                </span>
              ) : null}
            </div>
            <div className="space-y-1">
              <ProgressBar
                value={totalUsed}
                max={hasMaxContext ? (maxContextTokens ?? 0) : 0}
                colorClass={getMainProgressColorClass()}
                label={
                  hasMaxContext
                    ? localize('com_ui_context_usage_with_max', {
                        0: `${percentage}%`,
                        1: formatTokenCount(totalUsed),
                        2: formatTokenCount(maxContextTokens ?? 0),
                      })
                    : localize('com_ui_context_usage_unknown_max', {
                        0: formatTokenCount(totalUsed),
                      })
                }
                indeterminate={!hasMaxContext}
              />
              <div className="flex justify-between text-xs text-text-secondary" aria-hidden="true">
                <span>{formatTokenCount(totalUsed)}</span>
                <span>{hasMaxContext ? formatTokenCount(maxContextTokens ?? 0) : '--'}</span>
              </div>
            </div>
            <div className="border-t border-border-light" role="separator" />
            <div className="space-y-3">
              <TokenRow
                label={localize('com_ui_token_usage_input')}
                value={inputTokens}
                max={maxContextTokens}
                colorClass="bg-blue-500"
                ariaLabel={localize('com_ui_token_usage_input_aria', {
                  0: formatTokenCount(inputTokens),
                  1: formatTokenCount(maxContextTokens ?? 0),
                  2: inputPercentage.toString(),
                })}
              />
              <TokenRow
                label={localize('com_ui_token_usage_output')}
                value={outputTokens}
                max={maxContextTokens}
                colorClass="bg-green-500"
                ariaLabel={localize('com_ui_token_usage_output_aria', {
                  0: formatTokenCount(outputTokens),
                  1: formatTokenCount(maxContextTokens ?? 0),
                  2: outputPercentage.toString(),
                })}
              />
            </div>
          </div>
        </HoverCardContent>
      </HoverCardPortal>
    </HoverCard>
  );
}
export default memo(ContextTracker);

View file

@ -84,6 +84,13 @@ const toggleSwitchConfigs = [
hoverCardText: 'com_nav_info_save_badges_state' as const,
key: 'showBadges',
},
{
stateAtom: store.showContextTracker,
localizationKey: 'com_nav_show_context_tracker' as const,
switchId: 'showContextTracker',
hoverCardText: undefined,
key: 'showContextTracker',
},
{
stateAtom: store.modularChat,
localizationKey: 'com_nav_modular_chat' as const,

View file

@ -491,6 +491,27 @@ export default function useEventHandlers({
setMessages(_messages);
queryClient.setQueryData<TMessage[]>([QueryKeys.messages, id], _messages);
};
/**
 * True when any assistant (non-user) message lacks a usable tokenCount
 * (positive, finite number) — i.e. cached token totals are incomplete.
 */
const shouldRefreshTokenCounts = (targetMessages: TMessage[]) =>
  targetMessages.some((message) => {
    if (message.isCreatedByUser) {
      return false;
    }
    const { tokenCount } = message;
    const hasValidCount =
      typeof tokenCount === 'number' && Number.isFinite(tokenCount) && tokenCount > 0;
    return !hasValidCount;
  });
/**
 * Queues a one-off invalidation of the messages query for a conversation so
 * that server-reconciled token counts get refetched. Deferred to the next
 * tick so the final message state lands in the cache first; no-op when every
 * assistant message already carries a valid tokenCount.
 */
const scheduleTokenCountRefresh = (targetConversationId: string, targetMessages: TMessage[]) => {
  if (!shouldRefreshTokenCounts(targetMessages)) {
    return;
  }
  const invalidate = () => {
    void queryClient.invalidateQueries({
      queryKey: [QueryKeys.messages, targetConversationId],
      exact: true,
    });
  };
  setTimeout(invalidate, 0);
};
const hasNoResponse =
responseMessage?.content?.[0]?.['text']?.value ===
@ -545,6 +566,9 @@ export default function useEventHandlers({
if (finalMessages.length > 0) {
setFinalMessages(conversation.conversationId, finalMessages);
if (conversation.conversationId) {
scheduleTokenCountRefresh(conversation.conversationId, finalMessages);
}
} else if (
isAssistantsEndpoint(submissionConvo.endpoint) &&
(!submissionConvo.conversationId ||

View file

@ -570,6 +570,7 @@
"com_nav_plus_command_description": "Toggle command \"+\" for adding a multi-response setting",
"com_nav_profile_picture": "Profile Picture",
"com_nav_save_badges_state": "Save badges state",
"com_nav_show_context_tracker": "Show context tracker",
"com_nav_save_drafts": "Save drafts locally",
"com_nav_scroll_button": "Scroll to the end button",
"com_nav_search_placeholder": "Search messages",
@ -1522,6 +1523,16 @@
"com_ui_upload_provider": "Upload to Provider",
"com_ui_upload_success": "Successfully uploaded file",
"com_ui_upload_type": "Select Upload Type",
"com_ui_context_usage": "Context usage",
"com_ui_context_usage_unknown_max": "{{0}} tokens used (max unavailable)",
"com_ui_context_usage_with_max": "{{0}} · {{1}} / {{2}} context used",
"com_ui_token_usage_input": "Input",
"com_ui_token_usage_output": "Output",
"com_ui_token_usage_percent": "{{0}}% used",
"com_ui_token_usage_input_aria": "Input usage: {{0}} of {{1}} max context, {{2}}% used",
"com_ui_token_usage_output_aria": "Output usage: {{0}} of {{1}} max context, {{2}}% used",
"com_ui_token_usage_aria_full": "Token usage: {{0}} input, {{1}} output, {{2}} max context, {{3}}% used",
"com_ui_token_usage_aria_no_max": "Token usage: {{0}} input, {{1}} output, {{2}} total tokens used",
"com_ui_usage": "Usage",
"com_ui_use_2fa_code": "Use 2FA Code Instead",
"com_ui_use_backup_code": "Use Backup Code Instead",

View file

@ -39,6 +39,7 @@ const localStorageAtoms = {
rememberDefaultFork: atomWithLocalStorage(LocalStorageKeys.REMEMBER_FORK_OPTION, false),
showThinking: atomWithLocalStorage('showThinking', false),
saveBadgesState: atomWithLocalStorage('saveBadgesState', false),
showContextTracker: atomWithLocalStorage('showContextTracker', true),
// Beta features settings
modularChat: atomWithLocalStorage('modularChat', true),

View file

@ -47,14 +47,29 @@ class Tokenizer {
const TokenizerSingleton = new Tokenizer();
/**
 * Maps a model identifier to a tokenizer encoding: any model whose id
 * contains "claude" (case-insensitive) uses the dedicated 'claude' encoding;
 * everything else (including an omitted model) falls back to 'o200k_base'.
 */
export function resolveEncodingFromModel(model?: string): EncodingName {
  const isClaude = typeof model === 'string' && model.toLowerCase().includes('claude');
  return isClaude ? 'claude' : 'o200k_base';
}
/**
 * Counts the number of tokens in a given text using ai-tokenizer, selecting
 * the encoding from the given model id or explicit encoding name
 * (defaults to o200k_base).
 * @param text - The text to count tokens in. Defaults to an empty string.
 * @param modelOrEncoding - Optional model id or explicit encoding name.
 * @returns The number of tokens in the provided text.
 */
export async function countTokens(text = ''): Promise<number> {
await TokenizerSingleton.initEncoding('o200k_base');
return TokenizerSingleton.getTokenCount(text, 'o200k_base');
/**
 * Counts the tokens in `text` using the shared Tokenizer singleton.
 * @param text - Text to tokenize; defaults to the empty string.
 * @param modelOrEncoding - Either an explicit encoding name ('claude' or
 *   'o200k_base'), used as-is, or a model id from which the encoding is
 *   derived; omitted means the default encoding.
 * @returns The number of tokens in the provided text.
 */
export async function countTokens(
  text = '',
  modelOrEncoding?: string | EncodingName,
): Promise<number> {
  let encoding: EncodingName;
  if (modelOrEncoding === 'claude' || modelOrEncoding === 'o200k_base') {
    encoding = modelOrEncoding;
  } else {
    encoding = resolveEncodingFromModel(modelOrEncoding);
  }
  await TokenizerSingleton.initEncoding(encoding);
  return TokenizerSingleton.getTokenCount(text, encoding);
}
export default TokenizerSingleton;