diff --git a/client/src/components/Chat/ChatView.tsx b/client/src/components/Chat/ChatView.tsx index ae8d35231b..2a7e3a76a1 100644 --- a/client/src/components/Chat/ChatView.tsx +++ b/client/src/components/Chat/ChatView.tsx @@ -7,9 +7,9 @@ import { Constants, buildTree } from 'librechat-data-provider'; import type { TMessage } from 'librechat-data-provider'; import type { ChatFormValues } from '~/common'; import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers'; -import { useGetMessagesByConvoId, useGetConversationCosts } from '~/data-provider'; import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks'; import ConversationStarters from './Input/ConversationStarters'; +import { useGetMessagesByConvoId } from '~/data-provider'; import MessagesView from './Messages/MessagesView'; import Presentation from './Presentation'; import ChatForm from './Input/ChatForm'; @@ -30,7 +30,13 @@ function LoadingSpinner() { ); } -function ChatView({ index = 0 }: { index?: number }) { +function ChatView({ + index = 0, + modelCosts, +}: { + index?: number; + modelCosts?: { modelCostTable: Record }; +}) { const { conversationId } = useParams(); const rootSubmission = useRecoilValue(store.submissionByIndex(index)); const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1)); @@ -52,13 +58,6 @@ function ChatView({ index = 0 }: { index?: number }) { enabled: !!fileMap, }); - const { data: conversationCosts } = useGetConversationCosts( - conversationId && conversationId !== Constants.NEW_CONVO ? conversationId : '', - { - enabled: !!conversationId && conversationId !== Constants.NEW_CONVO && conversationId !== '', - }, - ); - const chatHelpers = useChatHelpers(index, conversationId); const addedChatHelpers = useAddedResponse({ rootIndex: index }); @@ -138,15 +137,14 @@ function ChatView({ index = 0 }: { index?: number }) { messagesTree={messagesTree} costBar={ !isLandingPage && - conversationCosts && - conversationCosts.totals && ( + modelCosts && ( ) } - costs={conversationCosts} /> ); } else { diff --git a/client/src/components/Chat/CostBar.tsx b/client/src/components/Chat/CostBar.tsx index 7b2f9dbc97..21265b532c 100644 --- a/client/src/components/Chat/CostBar.tsx +++ b/client/src/components/Chat/CostBar.tsx @@ -1,19 +1,73 @@ +import { useMemo } from 'react'; import { useRecoilValue } from 'recoil'; import { ArrowIcon } from '@librechat/client'; -import type { TConversationCosts } from 'librechat-data-provider'; +import { TModelCosts, TMessage } from 'librechat-data-provider'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; import store from '~/store'; interface CostBarProps { - conversationCosts: TConversationCosts; + messagesTree: TMessage[]; + modelCosts: TModelCosts; showCostBar: boolean; } -export default function CostBar({ conversationCosts, showCostBar }: CostBarProps) { +export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) { const localize = useLocalize(); const showCostTracking = useRecoilValue(store.showCostTracking); + const conversationCosts = useMemo(() => { + if (!modelCosts?.modelCostTable || !messagesTree) { + return null; + } + + let totalPromptTokens = 0; + let totalCompletionTokens = 0; + let totalPromptUSD = 0; + let totalCompletionUSD = 0; + + const flattenMessages = (messages: TMessage[]) => { + const flattened: TMessage[] = []; + messages.forEach((message: TMessage) => { + flattened.push(message); + if (message.children && message.children.length > 0) { + flattened.push(...flattenMessages(message.children)); + } + }); + return flattened; + }; + + const allMessages = flattenMessages(messagesTree); + + allMessages.forEach((message) => { + if (!message.tokenCount) { + return null; + } + + const modelToUse = message.isCreatedByUser ? message.targetModel : message.model; + + const modelPricing = modelCosts.modelCostTable[modelToUse]; + if (message.isCreatedByUser) { + totalPromptTokens += message.tokenCount; + totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt; + } else { + totalCompletionTokens += message.tokenCount; + totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion; + } + }); + + const totalTokens = totalPromptTokens + totalCompletionTokens; + const totalUSD = totalPromptUSD + totalCompletionUSD; + + return { + totals: { + prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD }, + completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD }, + total: { tokenCount: totalTokens, usd: totalUSD }, + }, + }; + }, [modelCosts, messagesTree]); + if (!showCostTracking || !conversationCosts || !conversationCosts.totals) { return null; } diff --git a/client/src/routes/ChatRoute.tsx b/client/src/routes/ChatRoute.tsx index d81cbc075c..eb1a2220f8 100644 --- a/client/src/routes/ChatRoute.tsx +++ b/client/src/routes/ChatRoute.tsx @@ -6,6 +6,7 @@ import { useGetModelsQuery } from 'librechat-data-provider/react-query'; import type { TPreset } from 'librechat-data-provider'; import { useGetConvoIdQuery, useGetStartupConfig, useGetEndpointsQuery } from '~/data-provider'; import { useNewConvo, useAppStartup, useAssistantListMap, useIdChangeEffect } from '~/hooks'; +import { useGetModelCostsQuery } from 'librechat-data-provider/react-query'; import { getDefaultModelSpec, getModelSpecPreset, logger } from '~/utils'; import { ToolCallsMapProvider } from '~/Providers'; import ChatView from '~/components/Chat/ChatView'; @@ -44,6 +45,10 @@ export default function ChatRoute() { const endpointsQuery = useGetEndpointsQuery({ enabled: isAuthenticated }); const assistantListMap = useAssistantListMap(); + const modelCostsQuery = useGetModelCostsQuery(initialConvoQuery.data?.modelHistory || [], { + enabled: !!initialConvoQuery.data?.modelHistory?.length, + }); + const isTemporaryChat = conversation && conversation.expiredAt ? true : false; useEffect(() => { @@ -148,7 +153,7 @@ export default function ChatRoute() { return ( - + ); }