diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 2c27817985..5b00d6db99 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -11,10 +11,11 @@ const { } = require('~/models'); const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update'); const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); +const { tokenValues, getValueKey, defaultRate } = require('~/models/tx'); const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); const { getConvosQueried } = require('~/models/Conversation'); -const { Message, Transaction } = require('~/db/models'); const { countTokens } = require('~/server/utils'); +const { Message } = require('~/db/models'); const router = express.Router(); router.use(requireJwtAuth); @@ -160,103 +161,37 @@ router.post('/artifact/:messageId', async (req, res) => { } }); -router.get('/:conversationId/costs', validateMessageReq, async (req, res) => { +/** + * POST /costs + * Get cost information for models in modelHistory array + */ +router.post('/costs', async (req, res) => { try { - const user = req.user.id; - const { conversationId } = req.params; + const { modelHistory } = req.body; - const [transactions, messages] = await Promise.all([ - Transaction.find({ - conversationId, - user, - tokenType: { $in: ['prompt', 'completion'] }, - }) - .select('tokenType tokenValue createdAt') - .sort({ createdAt: 1 }) - .lean(), - Message.find({ conversationId, user }) - .select('messageId isCreatedByUser tokenCount createdAt') - .sort({ createdAt: 1 }) - .lean(), - ]); - - const userMsgs = messages.filter((m) => m.isCreatedByUser); - const aiMsgs = messages.filter((m) => !m.isCreatedByUser); - - const perMessageMap = new Map(); - for (const msg of messages) { - perMessageMap.set(msg.messageId, { - messageId: msg.messageId, - tokenType: msg.isCreatedByUser ? 'prompt' : 'completion', - tokenCount: msg.tokenCount ?? 0, - tokenValue: 0, - usd: 0, - }); + if (!Array.isArray(modelHistory)) { + return res.status(400).json({ error: 'modelHistory must be an array' }); } - let currentPrompt = 0; - let currentCompletion = 0; + const modelCostTable = {}; - let promptTokenValue = 0; - let completionTokenValue = 0; + modelHistory.forEach((modelEntry) => { + if (modelEntry && typeof modelEntry === 'object' && modelEntry.model && modelEntry.endpoint) { + const { model, endpoint } = modelEntry; - for (const tx of transactions) { - const value = Math.abs(tx.tokenValue ?? 0); - if (tx.tokenType === 'prompt') { - promptTokenValue += value; - const target = userMsgs[currentPrompt] ?? userMsgs[userMsgs.length - 1]; - if (target) { - const entry = perMessageMap.get(target.messageId); - entry.tokenValue += value; - perMessageMap.set(target.messageId, entry); - if (currentPrompt < userMsgs.length - 1) { - currentPrompt++; - } - } - } else if (tx.tokenType === 'completion') { - completionTokenValue += value; - const target = aiMsgs[currentCompletion] ?? aiMsgs[aiMsgs.length - 1]; - if (target) { - const entry = perMessageMap.get(target.messageId); - entry.tokenValue += value; - perMessageMap.set(target.messageId, entry); - if (currentCompletion < aiMsgs.length - 1) { - currentCompletion++; - } - } + const valueKey = getValueKey(model, endpoint); + const pricing = tokenValues[valueKey]; + + modelCostTable[model] = { + prompt: pricing?.prompt ?? defaultRate, + completion: pricing?.completion ?? defaultRate, + }; } - } + }); - const perMessage = Array.from(perMessageMap.values()).map((entry) => ({ - messageId: entry.messageId, - tokenType: entry.tokenType, - tokenCount: entry.tokenCount, - usd: entry.tokenValue / 1_000_000, - })); - - const promptTokenCount = userMsgs.reduce((sum, m) => sum + (m.tokenCount ?? 0), 0); - const completionTokenCount = aiMsgs.reduce((sum, m) => sum + (m.tokenCount ?? 0), 0); - const totalTokenCount = promptTokenCount + completionTokenCount; - - const totals = { - prompt: { - usd: promptTokenValue / 1_000_000, - tokenCount: promptTokenCount, - }, - completion: { - usd: completionTokenValue / 1_000_000, - tokenCount: completionTokenCount, - }, - total: { - usd: (promptTokenValue + completionTokenValue) / 1_000_000, - tokenCount: totalTokenCount, - }, - }; - - const response = { conversationId, totals, perMessage }; - res.status(200).json(response); + res.status(200).json({ modelCostTable }); } catch (error) { - logger.error('Error fetching conversation costs:', error); + logger.error('Error fetching model costs:', error); res.status(500).json({ error: 'Internal server error' }); } }); diff --git a/packages/data-provider/src/api-endpoints.ts b/packages/data-provider/src/api-endpoints.ts index 3097b0891c..330027d056 100644 --- a/packages/data-provider/src/api-endpoints.ts +++ b/packages/data-provider/src/api-endpoints.ts @@ -66,8 +66,7 @@ export const messages = (params: q.MessagesListParams) => { export const messagesArtifacts = (messageId: string) => `${messagesRoot}/artifacts/${messageId}`; -export const conversationCosts = (conversationId: string) => - `/api/messages/${conversationId}/costs`; +export const costs = () => `/api/messages/costs`; const shareRoot = `${BASE_URL}/api/share`; export const shareMessages = (shareId: string) => `${shareRoot}/${shareId}`;