import { HoverCard, HoverCardContent, HoverCardPortal, HoverCardTrigger } from '@librechat/client';
import { useQueryClient } from '@tanstack/react-query';
import type { TConversation, TMessage, TModelSpec, TStartupConfig } from 'librechat-data-provider';
import { Constants, QueryKeys } from 'librechat-data-provider';
import { memo, useCallback, useMemo, useSyncExternalStore } from 'react';
import { useRecoilValue } from 'recoil';
import { useGetStartupConfig } from '~/data-provider';
import { useLocalize } from '~/hooks';
import store from '~/store';
import { cn } from '~/utils';

type ContextTrackerProps = {
  conversation: TConversation | null;
};

/** A message that may carry a per-message token count. */
type MessageWithTokenCount = TMessage & { tokenCount?: number };

/** Token usage split by author, plus the sum of both sides. */
type TokenTotals = { inputTokens: number; outputTokens: number; totalUsed: number };

const TRACKER_SIZE = 28;
const TRACKER_STROKE = 3.5;

/**
 * Formats a token count in compact notation using the runtime's default locale,
 * with at most one fraction digit.
 *
 * A ".0" that is directly followed by a letter or by the end of the string is
 * removed from the formatted result.
 *
 * @param count - The number of tokens to format.
 * @returns The compact, human-readable representation.
 */
const formatTokenCount = (count: number): string => {
  const formatted = new Intl.NumberFormat(undefined, {
    notation: 'compact',
    maximumFractionDigits: 1,
  }).format(count);
  return formatted.replace(/\.0(?=[A-Za-z]|$)/, '');
};

/**
 * Sums the token counts of a message list, split by who created each message.
 *
 * Messages whose `tokenCount` is missing, not a finite number, or not positive
 * are ignored. Counts from messages with `isCreatedByUser` set are added to
 * `inputTokens`; all other counts are added to `outputTokens`.
 *
 * @param messages - The messages to total; `undefined` or an empty list yields all zeros.
 * @returns Input, output, and combined token totals.
 */
const getTokenTotals = (messages: TMessage[] | undefined): TokenTotals => {
  if (!messages?.length) {
    return { inputTokens: 0, outputTokens: 0, totalUsed: 0 };
  }

  const totals = messages.reduce(
    // The accumulator only tracks the two sides; `totalUsed` is derived afterwards.
    (accumulator: Omit<TokenTotals, 'totalUsed'>, message) => {
      const tokenCount = (message as MessageWithTokenCount).tokenCount;
      if (typeof tokenCount !== 'number' || !Number.isFinite(tokenCount) || tokenCount <= 0) {
        return accumulator;
      }

      if (message.isCreatedByUser) {
        accumulator.inputTokens += tokenCount;
      } else {
        accumulator.outputTokens += tokenCount;
      }

      return accumulator;
    },
    { inputTokens: 0, outputTokens: 0 },
  );

  return {
    inputTokens: totals.inputTokens,
    outputTokens: totals.outputTokens,
    totalUsed: totals.inputTokens + totals.outputTokens,
  };
};

/**
 * Looks up the `maxContextTokens` preset value of a named model spec.
 *
 * @param startupConfig - Startup configuration whose `modelSpecs.list` is searched.
 * @param specName - Name of the model spec to find.
 * @returns The preset's `maxContextTokens` when it is a finite, positive number;
 *   otherwise `null` (also when no spec name is given or no spec matches).
 */
const getSpecMaxContextTokens = (
  startupConfig: TStartupConfig | undefined,
  specName: string | null | undefined,
): number | null => {
  if (!specName) {
    return null;
  }

  const modelSpec = startupConfig?.modelSpecs?.list?.find(
    (spec: TModelSpec) => spec.name === specName,
  );
  const maxContextTokens = modelSpec?.preset?.maxContextTokens;

  if (
    typeof maxContextTokens !== 'number' ||
    !Number.isFinite(maxContextTokens) ||
    maxContextTokens <= 0
  ) {
    return null;
  }

  return maxContextTokens;
};

type ProgressBarProps = {
  value: number;
  max: number;
  colorClass: string;
  label: string;
  showPercentage?: boolean;
  indeterminate?: boolean;
};

/**
 * Progress bar with a determinate and an indeterminate branch and an optional
 * percentage section (shown only when `showPercentage` is set and the bar is
 * not indeterminate).
 *
 * The fill percentage is `value / max`, capped at 100, and is 0 when `max` is
 * not positive.
 */
function ProgressBar({
  value,
  max,
  colorClass,
  label,
  showPercentage = false,
  indeterminate = false,
}: ProgressBarProps) {
  const percentage = max > 0 ? Math.min((value / max) * 100, 100) : 0;

  // NOTE(review): the element markup of this return is not present in this copy
  // of the file; only the embedded expressions remain and are kept verbatim.
  // Restore the original JSX before relying on this component.
  return (
{indeterminate ? (
) : (
)}
{showPercentage && !indeterminate ? ( ) : null}
  );
}

type TokenRowProps = {
  label: string;
  value: number;
  max: number | null;
  colorClass: string;
  ariaLabel: string;
};

/**
 * One labelled row showing a formatted token count, with an extra section that
 * is rendered only when a positive `max` is available.
 *
 * The rounded percentage is `value / max`, capped at 100, and is 0 without a
 * usable `max`.
 */
function TokenRow({ label, value, max, colorClass, ariaLabel }: TokenRowProps) {
  const hasMax = max != null && max > 0;
  const percentage = hasMax ? Math.round(Math.min((value / max) * 100, 100)) : 0;

  // NOTE(review): the element markup of this return is not present in this copy
  // of the file; only the embedded expressions remain and are kept verbatim.
  return (
{label} {formatTokenCount(value)} {hasMax ? ( ) : null}
  );
}

/**
 * Context usage tracker for a conversation.
 *
 * Reads the conversation's messages from the query cache, totals their token
 * counts, and compares the total against the conversation's `maxContextTokens`
 * (falling back to the value from the conversation's model spec). Renders
 * nothing when the `showContextTracker` store value is falsy.
 */
function ContextTracker({ conversation }: ContextTrackerProps) {
  const localize = useLocalize();
  const queryClient = useQueryClient();
  const { data: startupConfig } = useGetStartupConfig();
  const showContextTracker = useRecoilValue(store.showContextTracker);
  const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;

  // Subscribe to query-cache events, but only notify for events whose query key
  // is [QueryKeys.messages, conversationId].
  const subscribeToMessages = useCallback(
    (onStoreChange: () => void) =>
      queryClient.getQueryCache().subscribe((event) => {
        const queryKey = event?.query?.queryKey;
        if (!Array.isArray(queryKey) || queryKey[0] !== QueryKeys.messages) {
          return;
        }
        if (queryKey[1] !== conversationId) {
          return;
        }
        onStoreChange();
      }),
    [conversationId, queryClient],
  );

  // Typed read of the cached messages so the snapshot is `TMessage[] | undefined`
  // rather than `unknown`.
  const getMessagesSnapshot = useCallback(
    () => queryClient.getQueryData<TMessage[]>([QueryKeys.messages, conversationId]),
    [conversationId, queryClient],
  );

  const messages = useSyncExternalStore(
    subscribeToMessages,
    getMessagesSnapshot,
    getMessagesSnapshot,
  );

  const { inputTokens, outputTokens, totalUsed } = useMemo(
    () => getTokenTotals(messages),
    [messages],
  );

  // Prefer a valid limit set on the conversation; otherwise use the model spec's preset.
  const maxContextTokens =
    typeof conversation?.maxContextTokens === 'number' &&
    Number.isFinite(conversation.maxContextTokens) &&
    conversation.maxContextTokens > 0
      ? conversation.maxContextTokens
      : getSpecMaxContextTokens(startupConfig, conversation?.spec);
  const hasMaxContext = maxContextTokens != null && maxContextTokens > 0;

  // Fraction of the context in use, capped at 1; 0 when no limit is known.
  const usageRatio = useMemo(() => {
    if (!hasMaxContext || maxContextTokens == null) {
      return 0;
    }
    return Math.min(totalUsed / maxContextTokens, 1);
  }, [hasMaxContext, maxContextTokens, totalUsed]);
  const percentage = Math.round(usageRatio * 100);

  const inputPercentage =
    hasMaxContext && maxContextTokens != null
      ? Math.round(Math.min((inputTokens / maxContextTokens) * 100, 100))
      : 0;
  const outputPercentage =
    hasMaxContext && maxContextTokens != null
      ? Math.round(Math.min((outputTokens / maxContextTokens) * 100, 100))
      : 0;

  // Ring geometry: the radius is reduced by half the stroke width, and the dash
  // offset is the share of the circumference that is NOT in use.
  const trackerRadius = useMemo(() => (TRACKER_SIZE - TRACKER_STROKE) / 2, []);
  const circumference = useMemo(() => 2 * Math.PI * trackerRadius, [trackerRadius]);
  const dashOffset = useMemo(
    () => circumference - (percentage / 100) * circumference,
    [circumference, percentage],
  );

  /** Stroke class: secondary text colour without a limit, red above 90%, yellow above 75%, else green. */
  const getRingColorClass = () => {
    if (!hasMaxContext) {
      return 'stroke-text-secondary';
    }
    if (percentage > 90) {
      return 'stroke-red-500';
    }
    if (percentage > 75) {
      return 'stroke-yellow-500';
    }
    return 'stroke-green-500';
  };

  /** Background class using the same thresholds as the ring colour. */
  const getMainProgressColorClass = () => {
    if (!hasMaxContext) {
      return 'bg-text-secondary';
    }
    if (percentage > 90) {
      return 'bg-red-500';
    }
    if (percentage > 75) {
      return 'bg-yellow-500';
    }
    return 'bg-green-500';
  };

  const ariaLabel = hasMaxContext
    ? localize('com_ui_token_usage_aria_full', {
        0: formatTokenCount(inputTokens),
        1: formatTokenCount(outputTokens),
        2: formatTokenCount(maxContextTokens ?? 0),
        3: percentage.toString(),
      })
    : localize('com_ui_token_usage_aria_no_max', {
        0: formatTokenCount(inputTokens),
        1: formatTokenCount(outputTokens),
        2: formatTokenCount(totalUsed),
      });

  // The early return comes after every hook call so the hook order stays stable.
  if (!showContextTracker) {
    return null;
  }

  // NOTE(review): the element markup of this return is not present in this copy
  // of the file; only the embedded expressions remain and are kept verbatim.
  // Restore the original JSX before relying on this component.
  return (
{localize('com_ui_context_usage')} {hasMaxContext ? ( 90, 'text-yellow-500': percentage > 75 && percentage <= 90, 'text-green-500': percentage <= 75, })} > {localize('com_ui_token_usage_percent', { 0: percentage.toString() })} ) : null}
  );
}

export default memo(ContextTracker);