refactor: cost bar uses new cost table

This commit is contained in:
Dustin Healy 2025-09-13 13:30:00 -07:00
parent adff605c50
commit b034624690
3 changed files with 74 additions and 17 deletions

View file

@ -7,9 +7,9 @@ import { Constants, buildTree } from 'librechat-data-provider';
import type { TMessage } from 'librechat-data-provider';
import type { ChatFormValues } from '~/common';
import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers';
import { useGetMessagesByConvoId, useGetConversationCosts } from '~/data-provider';
import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks';
import ConversationStarters from './Input/ConversationStarters';
import { useGetMessagesByConvoId } from '~/data-provider';
import MessagesView from './Messages/MessagesView';
import Presentation from './Presentation';
import ChatForm from './Input/ChatForm';
@ -30,7 +30,13 @@ function LoadingSpinner() {
);
}
function ChatView({ index = 0 }: { index?: number }) {
function ChatView({
index = 0,
modelCosts,
}: {
index?: number;
modelCosts?: { modelCostTable: Record<string, { prompt: number; completion: number }> };
}) {
const { conversationId } = useParams();
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
@ -52,13 +58,6 @@ function ChatView({ index = 0 }: { index?: number }) {
enabled: !!fileMap,
});
const { data: conversationCosts } = useGetConversationCosts(
conversationId && conversationId !== Constants.NEW_CONVO ? conversationId : '',
{
enabled: !!conversationId && conversationId !== Constants.NEW_CONVO && conversationId !== '',
},
);
const chatHelpers = useChatHelpers(index, conversationId);
const addedChatHelpers = useAddedResponse({ rootIndex: index });
@ -138,15 +137,14 @@ function ChatView({ index = 0 }: { index?: number }) {
messagesTree={messagesTree}
costBar={
!isLandingPage &&
conversationCosts &&
conversationCosts.totals && (
modelCosts && (
<CostBar
conversationCosts={conversationCosts}
messagesTree={messagesTree}
modelCosts={modelCosts}
showCostBar={showCostBar && !isStreaming}
/>
)
}
costs={conversationCosts}
/>
);
} else {

View file

@ -1,19 +1,73 @@
import { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import type { TConversationCosts } from 'librechat-data-provider';
import { TModelCosts, TMessage } from 'librechat-data-provider';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import store from '~/store';
interface CostBarProps {
conversationCosts: TConversationCosts;
messagesTree: TMessage[];
modelCosts: TModelCosts;
showCostBar: boolean;
}
export default function CostBar({ conversationCosts, showCostBar }: CostBarProps) {
export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) {
const localize = useLocalize();
const showCostTracking = useRecoilValue(store.showCostTracking);
const conversationCosts = useMemo(() => {
if (!modelCosts?.modelCostTable || !messagesTree) {
return null;
}
let totalPromptTokens = 0;
let totalCompletionTokens = 0;
let totalPromptUSD = 0;
let totalCompletionUSD = 0;
const flattenMessages = (messages: TMessage[]) => {
const flattened: TMessage[] = [];
messages.forEach((message: TMessage) => {
flattened.push(message);
if (message.children && message.children.length > 0) {
flattened.push(...flattenMessages(message.children));
}
});
return flattened;
};
const allMessages = flattenMessages(messagesTree);
allMessages.forEach((message) => {
if (!message.tokenCount) {
return null;
}
const modelToUse = message.isCreatedByUser ? message.targetModel : message.model;
const modelPricing = modelCosts.modelCostTable[modelToUse];
if (message.isCreatedByUser) {
totalPromptTokens += message.tokenCount;
totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt;
} else {
totalCompletionTokens += message.tokenCount;
totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion;
}
});
const totalTokens = totalPromptTokens + totalCompletionTokens;
const totalUSD = totalPromptUSD + totalCompletionUSD;
return {
totals: {
prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD },
completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD },
total: { tokenCount: totalTokens, usd: totalUSD },
},
};
}, [modelCosts, messagesTree]);
if (!showCostTracking || !conversationCosts || !conversationCosts.totals) {
return null;
}

View file

@ -6,6 +6,7 @@ import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import type { TPreset } from 'librechat-data-provider';
import { useGetConvoIdQuery, useGetStartupConfig, useGetEndpointsQuery } from '~/data-provider';
import { useNewConvo, useAppStartup, useAssistantListMap, useIdChangeEffect } from '~/hooks';
import { useGetModelCostsQuery } from 'librechat-data-provider/react-query';
import { getDefaultModelSpec, getModelSpecPreset, logger } from '~/utils';
import { ToolCallsMapProvider } from '~/Providers';
import ChatView from '~/components/Chat/ChatView';
@ -44,6 +45,10 @@ export default function ChatRoute() {
const endpointsQuery = useGetEndpointsQuery({ enabled: isAuthenticated });
const assistantListMap = useAssistantListMap();
const modelCostsQuery = useGetModelCostsQuery(initialConvoQuery.data?.modelHistory || [], {
enabled: !!initialConvoQuery.data?.modelHistory?.length,
});
const isTemporaryChat = conversation && conversation.expiredAt ? true : false;
useEffect(() => {
@ -148,7 +153,7 @@ export default function ChatRoute() {
return (
<ToolCallsMapProvider conversationId={conversation.conversationId ?? ''}>
<ChatView index={index} />
<ChatView index={index} modelCosts={modelCostsQuery.data} />
</ToolCallsMapProvider>
);
}