diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 905cadfd23..60c3d40cb5 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -585,6 +585,8 @@ class BaseClient { responseMessage.text = completion.join(''); } + /** @type {(() => Promise) | null} */ + let reconcileTokenCount = null; if (tokenCountMap && this.recordTokenUsage && this.getTokenCountForResponse) { let completionTokens; @@ -599,25 +601,38 @@ class BaseClient { responseMessage.tokenCount = usage[this.outputTokensKey]; completionTokens = responseMessage.tokenCount; } else { - responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); - completionTokens = responseMessage.tokenCount; - await this.recordTokenUsage({ - usage, - promptTokens, - completionTokens, - balance: balanceConfig, - /** Note: When using agents, responseMessage.model is the agent ID, not the model */ - model: this.model, - messageId: this.responseMessageId, - }); + reconcileTokenCount = async () => { + const tokenCount = this.getTokenCountForResponse(responseMessage); + await this.recordTokenUsage({ + usage, + promptTokens, + completionTokens: tokenCount, + balance: balanceConfig, + /** Note: When using agents, responseMessage.model is the agent ID, not the model */ + model: this.model, + messageId: this.responseMessageId, + }); + await this.updateMessageInDatabase({ + messageId: this.responseMessageId, + tokenCount, + }); + logger.debug('[BaseClient] Async response token reconciliation complete', { + messageId: responseMessage.messageId, + model: responseMessage.model, + promptTokens, + completionTokens: tokenCount, + }); + }; } - logger.debug('[BaseClient] Response token usage', { - messageId: responseMessage.messageId, - model: responseMessage.model, - promptTokens, - completionTokens, - }); + if (completionTokens != null) { + logger.debug('[BaseClient] Response token usage', { + messageId: responseMessage.messageId, + model: responseMessage.model, + 
promptTokens, + completionTokens, + }); + } } if (userMessagePromise) { @@ -666,6 +681,13 @@ class BaseClient { saveOptions, user, ); + if (reconcileTokenCount != null) { + void responseMessage.databasePromise + .then(() => reconcileTokenCount()) + .catch((error) => { + logger.error('[BaseClient] Async token reconciliation failed', error); + }); + } this.savedMessageIds.add(responseMessage.messageId); delete responseMessage.tokenCount; return responseMessage; diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index 9e0ad7f382..08ccf2a45f 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -32,6 +32,7 @@ import CollapseChat from './CollapseChat'; import StreamAudio from './StreamAudio'; import StopButton from './StopButton'; import SendButton from './SendButton'; +import ContextTracker from './ContextTracker'; import EditBadges from './EditBadges'; import BadgeRow from './BadgeRow'; import Mention from './Mention'; @@ -318,7 +319,7 @@ const ChatForm = memo(function ChatForm({ onBlur={handleTextareaBlur} aria-label={localize('com_ui_message_input')} onClick={handleFocusOrClick} - style={{ height: 44, overflowY: 'auto' }} + style={{ height: 44 }} className={cn( baseClasses, removeFocusRings, @@ -372,7 +373,8 @@ const ChatForm = memo(function ChatForm({ isSubmitting={isSubmitting} /> )} -
+
+ {endpoint && } {isSubmitting && showStopButton ? ( ) : ( diff --git a/client/src/components/Chat/Input/ContextTracker.tsx b/client/src/components/Chat/Input/ContextTracker.tsx new file mode 100644 index 0000000000..2f4485b81f --- /dev/null +++ b/client/src/components/Chat/Input/ContextTracker.tsx @@ -0,0 +1,408 @@ +import { HoverCard, HoverCardContent, HoverCardPortal, HoverCardTrigger } from '@librechat/client'; +import { useQueryClient } from '@tanstack/react-query'; +import type { TConversation, TMessage, TModelSpec, TStartupConfig } from 'librechat-data-provider'; +import { Constants, QueryKeys } from 'librechat-data-provider'; +import { memo, useCallback, useMemo, useSyncExternalStore } from 'react'; +import { useRecoilValue } from 'recoil'; +import { useGetStartupConfig } from '~/data-provider'; +import { useLocalize } from '~/hooks'; +import store from '~/store'; +import { cn } from '~/utils'; + +type ContextTrackerProps = { + conversation: TConversation | null; +}; + +type MessageWithTokenCount = TMessage & { tokenCount?: number }; +type TokenTotals = { inputTokens: number; outputTokens: number; totalUsed: number }; + +const TRACKER_SIZE = 28; +const TRACKER_STROKE = 3.5; + +const formatTokenCount = (count: number): string => { + const formatted = new Intl.NumberFormat(undefined, { + notation: 'compact', + maximumFractionDigits: 1, + }).format(count); + return formatted.replace(/\.0(?=[A-Za-z]|$)/, ''); +}; + +const getTokenTotals = (messages: TMessage[] | undefined): TokenTotals => { + if (!messages?.length) { + return { inputTokens: 0, outputTokens: 0, totalUsed: 0 }; + } + + const totals = messages.reduce( + (accumulator: Omit, message) => { + const tokenCount = (message as MessageWithTokenCount).tokenCount; + if (typeof tokenCount !== 'number' || !Number.isFinite(tokenCount) || tokenCount <= 0) { + return accumulator; + } + + if (message.isCreatedByUser) { + accumulator.inputTokens += tokenCount; + } else { + accumulator.outputTokens += tokenCount; + 
} + + return accumulator; + }, + { inputTokens: 0, outputTokens: 0 }, + ); + + return { + inputTokens: totals.inputTokens, + outputTokens: totals.outputTokens, + totalUsed: totals.inputTokens + totals.outputTokens, + }; +}; + +const getSpecMaxContextTokens = ( + startupConfig: TStartupConfig | undefined, + specName: string | null | undefined, +): number | null => { + if (!specName) { + return null; + } + + const modelSpec = startupConfig?.modelSpecs?.list?.find( + (spec: TModelSpec) => spec.name === specName, + ); + const maxContextTokens = modelSpec?.preset?.maxContextTokens; + + if ( + typeof maxContextTokens !== 'number' || + !Number.isFinite(maxContextTokens) || + maxContextTokens <= 0 + ) { + return null; + } + + return maxContextTokens; +}; + +type ProgressBarProps = { + value: number; + max: number; + colorClass: string; + label: string; + showPercentage?: boolean; + indeterminate?: boolean; +}; + +function ProgressBar({ + value, + max, + colorClass, + label, + showPercentage = false, + indeterminate = false, +}: ProgressBarProps) { + const percentage = max > 0 ? Math.min((value / max) * 100, 100) : 0; + + return ( +
+
+ {indeterminate ? ( +
+ ) : ( +
+
+
+
+ )} +
+ {showPercentage && !indeterminate ? ( + + ) : null} +
+ ); +} + +type TokenRowProps = { + label: string; + value: number; + max: number | null; + colorClass: string; + ariaLabel: string; +}; + +function TokenRow({ label, value, max, colorClass, ariaLabel }: TokenRowProps) { + const hasMax = max != null && max > 0; + const percentage = hasMax ? Math.round(Math.min((value / max) * 100, 100)) : 0; + + return ( +
+
+ {label} + + {formatTokenCount(value)} + {hasMax ? ( + + ) : null} + +
+ +
+ ); +} + +function ContextTracker({ conversation }: ContextTrackerProps) { + const localize = useLocalize(); + const queryClient = useQueryClient(); + const { data: startupConfig } = useGetStartupConfig(); + const showContextTracker = useRecoilValue(store.showContextTracker); + const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO; + const subscribeToMessages = useCallback( + (onStoreChange: () => void) => + queryClient.getQueryCache().subscribe((event) => { + const queryKey = event?.query?.queryKey; + if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) { + return; + } + if (queryKey[1] !== conversationId) { + return; + } + onStoreChange(); + }), + [conversationId, queryClient], + ); + + const getMessagesSnapshot = useCallback( + () => queryClient.getQueryData([QueryKeys.messages, conversationId]), + [conversationId, queryClient], + ); + + const messages = useSyncExternalStore( + subscribeToMessages, + getMessagesSnapshot, + getMessagesSnapshot, + ); + const { inputTokens, outputTokens, totalUsed } = useMemo( + () => getTokenTotals(messages), + [messages], + ); + + const maxContextTokens = + typeof conversation?.maxContextTokens === 'number' && + Number.isFinite(conversation.maxContextTokens) && + conversation.maxContextTokens > 0 + ? conversation.maxContextTokens + : getSpecMaxContextTokens(startupConfig, conversation?.spec); + + const hasMaxContext = maxContextTokens != null && maxContextTokens > 0; + const usageRatio = useMemo(() => { + if (!hasMaxContext || maxContextTokens == null) { + return 0; + } + + return Math.min(totalUsed / maxContextTokens, 1); + }, [hasMaxContext, maxContextTokens, totalUsed]); + const percentage = Math.round(usageRatio * 100); + const inputPercentage = + hasMaxContext && maxContextTokens != null + ? Math.round(Math.min((inputTokens / maxContextTokens) * 100, 100)) + : 0; + const outputPercentage = + hasMaxContext && maxContextTokens != null + ? 
Math.round(Math.min((outputTokens / maxContextTokens) * 100, 100)) + : 0; + + const trackerRadius = useMemo(() => (TRACKER_SIZE - TRACKER_STROKE) / 2, []); + const circumference = useMemo(() => 2 * Math.PI * trackerRadius, [trackerRadius]); + const dashOffset = useMemo( + () => circumference - (percentage / 100) * circumference, + [circumference, percentage], + ); + + const getRingColorClass = () => { + if (!hasMaxContext) { + return 'stroke-text-secondary'; + } + if (percentage > 90) { + return 'stroke-red-500'; + } + if (percentage > 75) { + return 'stroke-yellow-500'; + } + return 'stroke-green-500'; + }; + + const getMainProgressColorClass = () => { + if (!hasMaxContext) { + return 'bg-text-secondary'; + } + if (percentage > 90) { + return 'bg-red-500'; + } + if (percentage > 75) { + return 'bg-yellow-500'; + } + return 'bg-green-500'; + }; + + const ariaLabel = hasMaxContext + ? localize('com_ui_token_usage_aria_full', { + 0: formatTokenCount(inputTokens), + 1: formatTokenCount(outputTokens), + 2: formatTokenCount(maxContextTokens ?? 0), + 3: percentage.toString(), + }) + : localize('com_ui_token_usage_aria_no_max', { + 0: formatTokenCount(inputTokens), + 1: formatTokenCount(outputTokens), + 2: formatTokenCount(totalUsed), + }); + + if (!showContextTracker) { + return null; + } + + return ( + + + + + + +
+
+ + {localize('com_ui_context_usage')} + + {hasMaxContext ? ( + 90, + 'text-yellow-500': percentage > 75 && percentage <= 90, + 'text-green-500': percentage <= 75, + })} + > + {localize('com_ui_token_usage_percent', { 0: percentage.toString() })} + + ) : null} +
+ +
+ + +
+ +
+ +
+ + +
+
+ + + + ); +} + +export default memo(ContextTracker); diff --git a/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx b/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx index 8d4f1817f3..00769c8ff0 100644 --- a/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx +++ b/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx @@ -84,6 +84,13 @@ const toggleSwitchConfigs = [ hoverCardText: 'com_nav_info_save_badges_state' as const, key: 'showBadges', }, + { + stateAtom: store.showContextTracker, + localizationKey: 'com_nav_show_context_tracker' as const, + switchId: 'showContextTracker', + hoverCardText: undefined, + key: 'showContextTracker', + }, { stateAtom: store.modularChat, localizationKey: 'com_nav_modular_chat' as const, diff --git a/client/src/hooks/SSE/useEventHandlers.ts b/client/src/hooks/SSE/useEventHandlers.ts index 366775c4c1..b05e805fab 100644 --- a/client/src/hooks/SSE/useEventHandlers.ts +++ b/client/src/hooks/SSE/useEventHandlers.ts @@ -491,6 +491,27 @@ export default function useEventHandlers({ setMessages(_messages); queryClient.setQueryData([QueryKeys.messages, id], _messages); }; + const shouldRefreshTokenCounts = (targetMessages: TMessage[]) => + targetMessages.some((message) => { + if (message.isCreatedByUser) { + return false; + } + const tokenCount = message.tokenCount; + return ( + typeof tokenCount !== 'number' || !Number.isFinite(tokenCount) || tokenCount <= 0 + ); + }); + const scheduleTokenCountRefresh = (targetConversationId: string, targetMessages: TMessage[]) => { + if (!shouldRefreshTokenCounts(targetMessages)) { + return; + } + setTimeout(() => { + void queryClient.invalidateQueries({ + queryKey: [QueryKeys.messages, targetConversationId], + exact: true, + }); + }, 0); + }; const hasNoResponse = responseMessage?.content?.[0]?.['text']?.value === @@ -545,6 +566,9 @@ export default function useEventHandlers({ if (finalMessages.length > 0) { setFinalMessages(conversation.conversationId, finalMessages); + if 
(conversation.conversationId) { + scheduleTokenCountRefresh(conversation.conversationId, finalMessages); + } } else if ( isAssistantsEndpoint(submissionConvo.endpoint) && (!submissionConvo.conversationId || diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 3d19f65ad6..e6cacb2a2d 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -570,6 +570,7 @@ "com_nav_plus_command_description": "Toggle command \"+\" for adding a multi-response setting", "com_nav_profile_picture": "Profile Picture", "com_nav_save_badges_state": "Save badges state", + "com_nav_show_context_tracker": "Show context tracker", "com_nav_save_drafts": "Save drafts locally", "com_nav_scroll_button": "Scroll to the end button", "com_nav_search_placeholder": "Search messages", @@ -1522,6 +1523,16 @@ "com_ui_upload_provider": "Upload to Provider", "com_ui_upload_success": "Successfully uploaded file", "com_ui_upload_type": "Select Upload Type", + "com_ui_context_usage": "Context usage", + "com_ui_context_usage_unknown_max": "{{0}} tokens used (max unavailable)", + "com_ui_context_usage_with_max": "{{0}} · {{1}} / {{2}} context used", + "com_ui_token_usage_input": "Input", + "com_ui_token_usage_output": "Output", + "com_ui_token_usage_percent": "{{0}}% used", + "com_ui_token_usage_input_aria": "Input usage: {{0}} of {{1}} max context, {{2}}% used", + "com_ui_token_usage_output_aria": "Output usage: {{0}} of {{1}} max context, {{2}}% used", + "com_ui_token_usage_aria_full": "Token usage: {{0}} input, {{1}} output, {{2}} max context, {{3}}% used", + "com_ui_token_usage_aria_no_max": "Token usage: {{0}} input, {{1}} output, {{2}} total tokens used", "com_ui_usage": "Usage", "com_ui_use_2fa_code": "Use 2FA Code Instead", "com_ui_use_backup_code": "Use Backup Code Instead", diff --git a/client/src/store/settings.ts b/client/src/store/settings.ts index 2a2796ad59..4d163a9621 100644 --- a/client/src/store/settings.ts 
+++ b/client/src/store/settings.ts @@ -39,6 +39,7 @@ const localStorageAtoms = { rememberDefaultFork: atomWithLocalStorage(LocalStorageKeys.REMEMBER_FORK_OPTION, false), showThinking: atomWithLocalStorage('showThinking', false), saveBadgesState: atomWithLocalStorage('saveBadgesState', false), + showContextTracker: atomWithLocalStorage('showContextTracker', true), // Beta features settings modularChat: atomWithLocalStorage('modularChat', true), diff --git a/packages/api/src/utils/tokenizer.ts b/packages/api/src/utils/tokenizer.ts index 4c638c948e..17baedcdcb 100644 --- a/packages/api/src/utils/tokenizer.ts +++ b/packages/api/src/utils/tokenizer.ts @@ -47,14 +47,29 @@ class Tokenizer { const TokenizerSingleton = new Tokenizer(); +export function resolveEncodingFromModel(model?: string): EncodingName { + if (typeof model === 'string' && model.toLowerCase().includes('claude')) { + return 'claude'; + } + return 'o200k_base'; +} + /** * Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding. * @param text - The text to count tokens in. Defaults to an empty string. + * @param modelOrEncoding - Optional model id or explicit encoding name. * @returns The number of tokens in the provided text. */ -export async function countTokens(text = ''): Promise { - await TokenizerSingleton.initEncoding('o200k_base'); - return TokenizerSingleton.getTokenCount(text, 'o200k_base'); +export async function countTokens( + text = '', + modelOrEncoding?: string | EncodingName, +): Promise { + const encoding = + modelOrEncoding === 'claude' || modelOrEncoding === 'o200k_base' + ? modelOrEncoding + : resolveEncodingFromModel(modelOrEncoding); + await TokenizerSingleton.initEncoding(encoding); + return TokenizerSingleton.getTokenCount(text, encoding); } export default TokenizerSingleton;