diff --git a/api/models/Transaction.js b/api/models/Transaction.js index e553e2bb3b..7f018e1c30 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,140 +1,7 @@ -const { logger } = require('@librechat/data-schemas'); +const { logger, CANCEL_RATE } = require('@librechat/data-schemas'); const { getMultiplier, getCacheMultiplier } = require('./tx'); -const { Transaction, Balance } = require('~/db/models'); - -const cancelRate = 1.15; - -/** - * Updates a user's token balance based on a transaction using optimistic concurrency control - * without schema changes. Compatible with DocumentDB. - * @async - * @function - * @param {Object} params - The function parameters. - * @param {string|mongoose.Types.ObjectId} params.user - The user ID. - * @param {number} params.incrementValue - The value to increment the balance by (can be negative). - * @param {import('mongoose').UpdateQuery['$set']} [params.setValues] - Optional additional fields to set. - * @returns {Promise} Returns the updated balance document (lean). - * @throws {Error} Throws an error if the update fails after multiple retries. - */ -const updateBalance = async ({ user, incrementValue, setValues }) => { - let maxRetries = 10; // Number of times to retry on conflict - let delay = 50; // Initial retry delay in ms - let lastError = null; - - for (let attempt = 1; attempt <= maxRetries; attempt++) { - let currentBalanceDoc; - try { - // 1. Read the current document state - currentBalanceDoc = await Balance.findOne({ user }).lean(); - const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0; - - // 2. Calculate the desired new state - const potentialNewCredits = currentCredits + incrementValue; - const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero - - // 3. Prepare the update payload - const updatePayload = { - $set: { - tokenCredits: newCredits, - ...(setValues || {}), // Merge other values to set - }, - }; - - // 4. Attempt the conditional update or upsert - let updatedBalance = null; - if (currentBalanceDoc) { - // --- Document Exists: Perform Conditional Update --- - // Try to update only if the tokenCredits match the value we read (currentCredits) - updatedBalance = await Balance.findOneAndUpdate( - { - user: user, - tokenCredits: currentCredits, // Optimistic lock: condition based on the read value - }, - updatePayload, - { - new: true, // Return the modified document - // lean: true, // .lean() is applied after query execution in Mongoose >= 6 - }, - ).lean(); // Use lean() for plain JS object - - if (updatedBalance) { - // Success! The update was applied based on the expected current state. - return updatedBalance; - } - // If updatedBalance is null, it means tokenCredits changed between read and write (conflict). - lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`); - // Proceed to retry logic below. - } else { - // --- Document Does Not Exist: Perform Conditional Upsert --- - // Try to insert the document, but only if it still doesn't exist. - // Using tokenCredits: {$exists: false} helps prevent race conditions where - // another process creates the doc between our findOne and findOneAndUpdate. - try { - updatedBalance = await Balance.findOneAndUpdate( - { - user: user, - // Attempt to match only if the document doesn't exist OR was just created - // without tokenCredits (less likely but possible). A simple { user } filter - // might also work, relying on the retry for conflicts. - // Let's use a simpler filter and rely on retry for races. - // tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits - }, - updatePayload, - { - upsert: true, // Create if doesn't exist - new: true, // Return the created/updated document - // setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert - // lean: true, - }, - ).lean(); - - if (updatedBalance) { - // Upsert succeeded (likely created the document) - return updatedBalance; - } - // If null, potentially a rare race condition during upsert. Retry should handle it. - lastError = new Error( - `Upsert race condition suspected for user ${user} on attempt ${attempt}.`, - ); - } catch (error) { - if (error.code === 11000) { - // E11000 duplicate key error on index - // This means another process created the document *just* before our upsert. - // It's a concurrency conflict during creation. We should retry. - lastError = error; // Store the error - // Proceed to retry logic below. - } else { - // Different error, rethrow - throw error; - } - } - } // End if/else (document exists?) - } catch (error) { - // Catch errors from findOne or unexpected findOneAndUpdate errors - logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error); - lastError = error; // Store the error - // Consider stopping retries for non-transient errors, but for now, we retry. - } - - // If we reached here, it means the update failed (conflict or error), wait and retry - if (attempt < maxRetries) { - const jitter = Math.random() * delay * 0.5; // Add jitter to delay - await new Promise((resolve) => setTimeout(resolve, delay + jitter)); - delay = Math.min(delay * 2, 2000); // Exponential backoff with cap - } - } // End for loop (retries) - - // If loop finishes without success, throw the last encountered error or a generic one - logger.error( - `[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`, - ); - throw ( - lastError || - new Error( - `Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`, - ) - ); -}; +const { Transaction } = require('~/db/models'); +const { updateBalance } = require('~/models'); /** Method to calculate and set the tokenValue for a transaction */ function calculateTokenValue(txn) { @@ -145,8 +12,8 @@ function calculateTokenValue(txn) { txn.rate = multiplier; txn.tokenValue = txn.rawAmount * multiplier; if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { - txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); - txn.rate *= cancelRate; + txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE); + txn.rate *= CANCEL_RATE; } } @@ -321,11 +188,11 @@ function calculateStructuredTokenValue(txn) { } if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { - txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); - txn.rate *= cancelRate; + txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE); + txn.rate *= CANCEL_RATE; if (txn.rateDetail) { txn.rateDetail = Object.fromEntries( - Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]), + Object.entries(txn.rateDetail).map(([k, v]) => [k, v * CANCEL_RATE]), ); } } diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index 545c7b2755..f363c472e1 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -1,8 +1,10 @@ const mongoose = require('mongoose'); +const { recordCollectedUsage } = require('@librechat/api'); +const { createMethods } = require('@librechat/data-schemas'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { getMultiplier, getCacheMultiplier, premiumTokenValues, tokenValues } = require('./tx'); const { createTransaction, createStructuredTransaction } = require('./Transaction'); +const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { Balance, Transaction } = require('~/db/models'); let mongoServer; @@ -985,3 +987,339 @@ describe('Premium Token Pricing Integration Tests', () => { expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); }); + +describe('Bulk path parity', () => { + /** + * Each test here mirrors an existing legacy test above, replacing spendTokens/ + * spendStructuredTokens with recordCollectedUsage + bulk deps. + * The balance deduction and transaction document fields must be numerically identical. + */ + let bulkDeps; + let methods; + + beforeEach(() => { + methods = createMethods(mongoose); + bulkDeps = { + spendTokens: () => Promise.resolve(), + spendStructuredTokens: () => Promise.resolve(), + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { + insertMany: methods.bulkInsertTransactions, + updateBalance: methods.updateBalance, + }, + }; + }); + + test('balance should decrease when spending tokens via bulk path', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const promptTokens = 100; + const completionTokens = 50; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model, + context: 'test', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], + }); + + const updatedBalance = await Balance.findOne({ user: userId }); + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: promptTokens, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + const expectedTotalCost = + promptTokens * promptMultiplier + completionTokens * completionMultiplier; + const expectedBalance = initialBalance - expectedTotalCost; + + expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + }); + + test('bulk path should not update balance when balance.enabled is false', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model, + context: 'test', + balance: { enabled: false }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: 100, output_tokens: 50, model }], + }); + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBe(initialBalance); + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); // transactions still recorded + }); + + test('bulk path should not insert when transactions.enabled is false', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'gpt-3.5-turbo', + context: 'test', + balance: { enabled: true }, + transactions: { enabled: false }, + collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(0); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('bulk path handles incomplete context for completion tokens — same CANCEL_RATE as legacy', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 17613154.55; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const promptTokens = 10; + const completionTokens = 50; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-convo', + model, + context: 'incomplete', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + const completionTx = txns.find((t) => t.tokenType === 'completion'); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + expect(completionTx.tokenValue).toBeCloseTo(-completionTokens * completionMultiplier * 1.15, 0); + }); + + test('bulk path structured tokens — balance deduction matches legacy spendStructuredTokens', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 17613154.55; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const promptInput = 11; + const promptWrite = 140522; + const promptRead = 0; + const completionTokens = 5; + const totalInput = promptInput + promptWrite + promptRead; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-convo', + model, + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [ + { + input_tokens: promptInput, + output_tokens: completionTokens, + model, + input_token_details: { cache_creation: promptWrite, cache_read: promptRead }, + }, + ], + }); + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; + + const expectedPromptCost = + promptInput * promptMultiplier + promptWrite * writeMultiplier + promptRead * readMultiplier; + const expectedCompletionCost = completionTokens * completionMultiplier; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + const expectedBalance = initialBalance - expectedTotalCost; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(Math.abs(updatedBalance.tokenCredits - expectedBalance)).toBeLessThan(100); + }); + + test('premium pricing above threshold via bulk path — same balance as legacy', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-premium', + model, + context: 'test', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], + }); + + const premiumPromptRate = premiumTokenValues[model].prompt; + const premiumCompletionRate = premiumTokenValues[model].completion; + const expectedCost = + promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('real-world multi-entry batch: 5 sequential tool calls — same total deduction as 5 legacy spendTokens calls', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-5-20251101'; + const calls = [ + { input_tokens: 31596, output_tokens: 151 }, + { input_tokens: 35368, output_tokens: 150 }, + { input_tokens: 58362, output_tokens: 295 }, + { input_tokens: 112604, output_tokens: 193 }, + { input_tokens: 257440, output_tokens: 2217 }, + ]; + + let expectedTotalCost = 0; + for (const { input_tokens, output_tokens } of calls) { + const pm = getMultiplier({ model, tokenType: 'prompt', inputTokenCount: input_tokens }); + const cm = getMultiplier({ model, tokenType: 'completion', inputTokenCount: input_tokens }); + expectedTotalCost += input_tokens * pm + output_tokens * cm; + } + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-sequential', + model, + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: calls.map((c) => ({ ...c, model })), + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(10); // 5 calls × 2 docs (prompt + completion) + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); + + test('bulk path should save transaction but not update balance when balance disabled, transactions enabled', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'gpt-3.5-turbo', + context: 'test', + balance: { enabled: false }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + expect(txns[0].rawAmount).toBeDefined(); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('bulk path structured tokens should not save when transactions.enabled is false', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'claude-3-5-sonnet', + context: 'message', + balance: { enabled: true }, + transactions: { enabled: false }, + collectedUsage: [ + { + input_tokens: 10, + output_tokens: 5, + model: 'claude-3-5-sonnet', + input_token_details: { cache_creation: 100, cache_read: 5 }, + }, + ], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(0); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('bulk path structured tokens should save but not update balance when balance disabled', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'claude-3-5-sonnet', + context: 'message', + balance: { enabled: false }, + transactions: { enabled: true }, + collectedUsage: [ + { + input_tokens: 10, + output_tokens: 5, + model: 'claude-3-5-sonnet', + input_token_details: { cache_creation: 100, cache_read: 5 }, + }, + ], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx.inputTokens).toBe(-10); + expect(promptTx.writeTokens).toBe(-100); + expect(promptTx.readTokens).toBe(-5); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); +}); diff --git a/api/package.json b/api/package.json index f9c9601a37..3e9350ac34 100644 --- a/api/package.json +++ b/api/package.json @@ -44,7 +44,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js index 8592c79a2d..835343e798 100644 --- a/api/server/controllers/agents/__tests__/openai.spec.js +++ b/api/server/controllers/agents/__tests__/openai.spec.js @@ -82,6 +82,13 @@ jest.mock('~/models/spendTokens', () => ({ spendStructuredTokens: mockSpendStructuredTokens, })); +const mockGetMultiplier = jest.fn().mockReturnValue(1); +const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); +jest.mock('~/models/tx', () => ({ + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, +})); + jest.mock('~/server/controllers/agents/callbacks', () => ({ createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), })); @@ -103,6 +110,8 @@ jest.mock('~/models/Agent', () => ({ getAgents: jest.fn().mockResolvedValue([]), })); +const mockUpdateBalance = jest.fn().mockResolvedValue({}); +const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined); jest.mock('~/models', () => ({ getFiles: jest.fn(), getUserKey: jest.fn(), @@ -112,6 +121,8 @@ jest.mock('~/models', () => ({ getUserCodeFiles: jest.fn(), getToolFilesByIds: jest.fn(), getCodeGeneratedFiles: jest.fn(), + updateBalance: mockUpdateBalance, + bulkInsertTransactions: mockBulkInsertTransactions, })); describe('OpenAIChatCompletionController', () => { @@ -155,7 +166,15 @@ describe('OpenAIChatCompletionController', () => { expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); expect(mockRecordCollectedUsage).toHaveBeenCalledWith( - { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier }, + bulkWriteOps: { + insertMany: mockBulkInsertTransactions, + updateBalance: mockUpdateBalance, + }, + }, expect.objectContaining({ user: 'user-123', conversationId: expect.any(String), @@ -182,12 +201,18 @@ describe('OpenAIChatCompletionController', () => { ); }); - it('should pass spendTokens and spendStructuredTokens as dependencies', async () => { + it('should pass spendTokens, spendStructuredTokens, pricing, and bulkWriteOps as dependencies', async () => { await OpenAIChatCompletionController(req, res); const [deps] = mockRecordCollectedUsage.mock.calls[0]; expect(deps).toHaveProperty('spendTokens', mockSpendTokens); expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens); + expect(deps).toHaveProperty('pricing'); + expect(deps.pricing).toHaveProperty('getMultiplier', mockGetMultiplier); + expect(deps.pricing).toHaveProperty('getCacheMultiplier', mockGetCacheMultiplier); + expect(deps).toHaveProperty('bulkWriteOps'); + expect(deps.bulkWriteOps).toHaveProperty('insertMany', mockBulkInsertTransactions); + expect(deps.bulkWriteOps).toHaveProperty('updateBalance', mockUpdateBalance); }); it('should include model from primaryConfig in recordCollectedUsage params', async () => { diff --git a/api/server/controllers/agents/__tests__/responses.unit.spec.js b/api/server/controllers/agents/__tests__/responses.unit.spec.js index e16ca394b2..45ec31fc68 100644 --- a/api/server/controllers/agents/__tests__/responses.unit.spec.js +++ b/api/server/controllers/agents/__tests__/responses.unit.spec.js @@ -106,6 +106,13 @@ jest.mock('~/models/spendTokens', () => ({ spendStructuredTokens: mockSpendStructuredTokens, })); +const mockGetMultiplier = jest.fn().mockReturnValue(1); +const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); +jest.mock('~/models/tx', () => ({ + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, +})); + jest.mock('~/server/controllers/agents/callbacks', () => ({ createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()), @@ -131,6 +138,8 @@ jest.mock('~/models/Agent', () => ({ getAgents: jest.fn().mockResolvedValue([]), })); +const mockUpdateBalance = jest.fn().mockResolvedValue({}); +const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined); jest.mock('~/models', () => ({ getFiles: jest.fn(), getUserKey: jest.fn(), @@ -141,6 +150,8 @@ jest.mock('~/models', () => ({ getUserCodeFiles: jest.fn(), getToolFilesByIds: jest.fn(), getCodeGeneratedFiles: jest.fn(), + updateBalance: mockUpdateBalance, + bulkInsertTransactions: mockBulkInsertTransactions, })); describe('createResponse controller', () => { @@ -184,7 +195,15 @@ describe('createResponse controller', () => { expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); expect(mockRecordCollectedUsage).toHaveBeenCalledWith( - { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier }, + bulkWriteOps: { + insertMany: mockBulkInsertTransactions, + updateBalance: mockUpdateBalance, + }, + }, expect.objectContaining({ user: 'user-123', conversationId: expect.any(String), @@ -209,12 +228,18 @@ describe('createResponse controller', () => { ); }); - it('should pass spendTokens and spendStructuredTokens as dependencies', async () => { + it('should pass spendTokens, spendStructuredTokens, pricing, and bulkWriteOps as dependencies', async () => { await createResponse(req, res); const [deps] = mockRecordCollectedUsage.mock.calls[0]; expect(deps).toHaveProperty('spendTokens', mockSpendTokens); expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens); + expect(deps).toHaveProperty('pricing'); + expect(deps.pricing).toHaveProperty('getMultiplier', mockGetMultiplier); + expect(deps.pricing).toHaveProperty('getCacheMultiplier', mockGetCacheMultiplier); + expect(deps).toHaveProperty('bulkWriteOps'); + expect(deps.bulkWriteOps).toHaveProperty('insertMany', mockBulkInsertTransactions); + expect(deps.bulkWriteOps).toHaveProperty('updateBalance', mockUpdateBalance); }); it('should include model from primaryConfig in recordCollectedUsage params', async () => { @@ -244,7 +269,15 @@ describe('createResponse controller', () => { expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); expect(mockRecordCollectedUsage).toHaveBeenCalledWith( - { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier }, + bulkWriteOps: { + insertMany: mockBulkInsertTransactions, + updateBalance: mockUpdateBalance, + }, + }, expect.objectContaining({ user: 'user-123', context: 'message', diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index d69281d49c..5f99a0762b 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -13,11 +13,12 @@ const { createSafeUser, initializeAgent, getBalanceConfig, - getProviderConfig, omitTitleOptions, + getProviderConfig, memoryInstructions, - applyContextToAgent, createTokenCounter, + applyContextToAgent, + recordCollectedUsage, GenerationJobManager, getTransactionsConfig, createMemoryProcessor, @@ -45,6 +46,8 @@ const { } = require('librechat-data-provider'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { updateBalance, bulkInsertTransactions } = require('~/models'); +const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const { createContextHandlers } = require('~/app/clients/prompts'); const { getConvoFiles } = require('~/models/Conversation'); const BaseClient = require('~/app/clients/BaseClient'); @@ -624,83 +627,29 @@ class AgentClient extends BaseClient { context = 'message', collectedUsage = this.collectedUsage, }) { - if (!collectedUsage || !collectedUsage.length) { - return; - } - // Use first entry's input_tokens as the base input (represents initial user message context) - // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens) - const firstUsage = collectedUsage[0]; - const input_tokens = - (firstUsage?.input_tokens || 0) + - (Number(firstUsage?.input_token_details?.cache_creation) || - Number(firstUsage?.cache_creation_input_tokens) || - 0) + - (Number(firstUsage?.input_token_details?.cache_read) || - Number(firstUsage?.cache_read_input_tokens) || - 0); - - // Sum output_tokens directly from all entries - works for both sequential and parallel execution - // This avoids the incremental calculation that produced negative values for parallel agents - let total_output_tokens = 0; - - for (const usage of collectedUsage) { - if (!usage) { - continue; - } - - // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens) - const cache_creation = - Number(usage.input_token_details?.cache_creation) || - Number(usage.cache_creation_input_tokens) || - 0; - const cache_read = - Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0; - - // Accumulate output tokens for the usage summary - total_output_tokens += Number(usage.output_tokens) || 0; - - const txMetadata = { + const result = await recordCollectedUsage( + { + spendTokens, + spendStructuredTokens, + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance }, + }, + { + user: this.user ?? this.options.req.user?.id, + conversationId: this.conversationId, + collectedUsage, + model: model ?? this.model ?? this.options.agent.model_parameters.model, context, + messageId: this.responseMessageId, balance, transactions, - messageId: this.responseMessageId, - conversationId: this.conversationId, - user: this.user ?? this.options.req.user?.id, endpointTokenConfig: this.options.endpointTokenConfig, - model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model, - }; + }, + ); - if (cache_creation > 0 || cache_read > 0) { - spendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error( - '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens', - err, - ); - }); - continue; - } - spendTokens(txMetadata, { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error( - '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens', - err, - ); - }); + if (result) { + this.usage = result; } - - this.usage = { - input_tokens, - output_tokens: total_output_tokens, - }; } /** diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js index a083bd9291..e8561f15fe 100644 --- a/api/server/controllers/agents/openai.js +++ b/api/server/controllers/agents/openai.js @@ -25,6 +25,7 @@ const { loadAgentTools, loadToolsForExecution } = require('~/server/services/Too const { createToolEndCallback } = require('~/server/controllers/agents/callbacks'); const { findAccessibleResources } = require('~/server/services/PermissionService'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const { getConvoFiles } = require('~/models/Conversation'); const { getAgent, getAgents } = require('~/models/Agent'); const db = require('~/models'); @@ -493,7 +494,12 @@ const OpenAIChatCompletionController = async (req, res) => { const balanceConfig = getBalanceConfig(appConfig); const transactionsConfig = getTransactionsConfig(appConfig); recordCollectedUsage( - { spendTokens, spendStructuredTokens }, + { + spendTokens, + spendStructuredTokens, + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, + }, { user: userId, conversationId, diff --git a/api/server/controllers/agents/recordCollectedUsage.spec.js b/api/server/controllers/agents/recordCollectedUsage.spec.js index 6904f2ed39..21720023ca 100644 --- a/api/server/controllers/agents/recordCollectedUsage.spec.js +++ b/api/server/controllers/agents/recordCollectedUsage.spec.js @@ -2,23 +2,37 @@ * Tests for AgentClient.recordCollectedUsage * * This is a critical function that handles token spending for agent LLM calls. - * It must correctly handle: - * - Sequential execution (single agent with tool calls) - * - Parallel execution (multiple agents with independent inputs) - * - Cache token handling (OpenAI and Anthropic formats) + * The client now delegates to the TS recordCollectedUsage from @librechat/api, + * passing pricing and bulkWriteOps deps. */ const { EModelEndpoint } = require('librechat-data-provider'); -// Mock dependencies before requiring the module const mockSpendTokens = jest.fn().mockResolvedValue(); const mockSpendStructuredTokens = jest.fn().mockResolvedValue(); +const mockGetMultiplier = jest.fn().mockReturnValue(1); +const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); +const mockUpdateBalance = jest.fn().mockResolvedValue({}); +const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); jest.mock('~/models/spendTokens', () => ({ spendTokens: (...args) => mockSpendTokens(...args), spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args), })); +jest.mock('~/models/tx', () => ({ + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, +})); + +jest.mock('~/models', () => ({ + updateBalance: mockUpdateBalance, + bulkInsertTransactions: mockBulkInsertTransactions, +})); + jest.mock('~/config', () => ({ logger: { debug: jest.fn(), @@ -39,6 +53,14 @@ jest.mock('@librechat/agents', () => ({ }), })); +jest.mock('@librechat/api', () => { + const actual = jest.requireActual('@librechat/api'); + return { + ...actual, + recordCollectedUsage: (...args) => mockRecordCollectedUsage(...args), + }; +}); + const AgentClient = require('./client'); describe('AgentClient - recordCollectedUsage', () => { @@ -74,31 +96,66 @@ describe('AgentClient - recordCollectedUsage', () => { }); describe('basic functionality', () => { - it('should return early if collectedUsage is empty', async () => { + it('should delegate to recordCollectedUsage with full deps', async () => { + const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; + + await client.recordCollectedUsage({ + collectedUsage, + balance: { enabled: true }, + transactions: { enabled: true }, + }); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + const [deps, params] = mockRecordCollectedUsage.mock.calls[0]; + + expect(deps).toHaveProperty('spendTokens'); + expect(deps).toHaveProperty('spendStructuredTokens'); + expect(deps).toHaveProperty('pricing'); + expect(deps.pricing).toHaveProperty('getMultiplier'); + expect(deps.pricing).toHaveProperty('getCacheMultiplier'); + expect(deps).toHaveProperty('bulkWriteOps'); + expect(deps.bulkWriteOps).toHaveProperty('insertMany'); + expect(deps.bulkWriteOps).toHaveProperty('updateBalance'); + + expect(params).toEqual( + expect.objectContaining({ + user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + }), + ); + }); + + it('should not set this.usage if collectedUsage is empty (returns undefined)', async () => { + mockRecordCollectedUsage.mockResolvedValue(undefined); + await client.recordCollectedUsage({ collectedUsage: [], balance: { enabled: true }, transactions: { enabled: true }, }); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); expect(client.usage).toBeUndefined(); }); - it('should return early if collectedUsage is null', async () => { + it('should not set this.usage if collectedUsage is null (returns undefined)', async () => { + mockRecordCollectedUsage.mockResolvedValue(undefined); + await client.recordCollectedUsage({ collectedUsage: null, balance: { enabled: true }, transactions: { enabled: true }, }); - expect(mockSpendTokens).not.toHaveBeenCalled(); expect(client.usage).toBeUndefined(); }); - it('should handle single usage entry correctly', async () => { - const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; + it('should set this.usage from recordCollectedUsage result', async () => { + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 75 }); + const collectedUsage = [{ input_tokens: 200, output_tokens: 75, model: 'gpt-4' }]; await client.recordCollectedUsage({ collectedUsage, @@ -106,521 +163,122 @@ describe('AgentClient - recordCollectedUsage', () => { transactions: { enabled: true }, }); - expect(mockSpendTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ - conversationId: 'convo-123', - user: 'user-123', - model: 'gpt-4', - }), - { promptTokens: 100, completionTokens: 50 }, - ); - expect(client.usage.input_tokens).toBe(100); - expect(client.usage.output_tokens).toBe(50); - }); - - it('should skip null entries in collectedUsage', async () => { - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - null, - { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(2); + expect(client.usage).toEqual({ input_tokens: 200, output_tokens: 75 }); }); }); describe('sequential execution (single agent with tool calls)', () => { - it('should calculate tokens correctly for sequential tool calls', async () => { - // Sequential flow: output of call N becomes part of input for call N+1 - // Call 1: input=100, output=50 - // Call 2: input=150 (100+50), output=30 - // Call 3: input=180 (150+30), output=20 + it('should pass all usage entries to recordCollectedUsage', async () => { const collectedUsage = [ { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, { input_tokens: 150, output_tokens: 30, model: 'gpt-4' }, { input_tokens: 180, output_tokens: 20, model: 'gpt-4' }, ]; + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 100 }); + await client.recordCollectedUsage({ collectedUsage, balance: { enabled: true }, transactions: { enabled: true }, }); - expect(mockSpendTokens).toHaveBeenCalledTimes(3); - // Total output should be sum of all output_tokens: 50 + 30 + 20 = 100 + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + const [, params] = mockRecordCollectedUsage.mock.calls[0]; + expect(params.collectedUsage).toHaveLength(3); expect(client.usage.output_tokens).toBe(100); - expect(client.usage.input_tokens).toBe(100); // First entry's input + expect(client.usage.input_tokens).toBe(100); }); }); describe('parallel execution (multiple agents)', () => { - it('should handle parallel agents with independent input tokens', async () => { - // Parallel agents have INDEPENDENT input tokens (not cumulative) - // Agent A: input=100, output=50 - // Agent B: input=80, output=40 (different context, not 100+50) + it('should pass parallel agent usage to recordCollectedUsage', async () => { const collectedUsage = [ { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, { input_tokens: 80, output_tokens: 40, model: 'gpt-4' }, ]; + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 90 }); + await client.recordCollectedUsage({ collectedUsage, balance: { enabled: true }, transactions: { enabled: true }, }); - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - // Expected total output: 50 + 40 = 90 - // output_tokens must be positive and should reflect total output + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(client.usage.output_tokens).toBe(90); expect(client.usage.output_tokens).toBeGreaterThan(0); }); - it('should NOT produce negative output_tokens for parallel execution', async () => { - // Critical bug scenario: parallel agents where second agent has LOWER input tokens + /** Bug regression: parallel agents where second agent has LOWER input tokens produced negative output via incremental calculation. */ + it('should NOT produce negative output_tokens', async () => { const collectedUsage = [ { input_tokens: 200, output_tokens: 100, model: 'gpt-4' }, { input_tokens: 50, output_tokens: 30, model: 'gpt-4' }, ]; + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 130 }); + await client.recordCollectedUsage({ collectedUsage, balance: { enabled: true }, transactions: { enabled: true }, }); - // output_tokens MUST be positive for proper token tracking expect(client.usage.output_tokens).toBeGreaterThan(0); - // Correct value should be 100 + 30 = 130 - }); - - it('should calculate correct total output for parallel agents', async () => { - // Three parallel agents with independent contexts - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - { input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' }, - { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(3); - // Total output should be 50 + 60 + 40 = 150 - expect(client.usage.output_tokens).toBe(150); - }); - - it('should handle worst-case parallel scenario without negative tokens', async () => { - // Extreme case: first agent has very high input, subsequent have low - const collectedUsage = [ - { input_tokens: 1000, output_tokens: 500, model: 'gpt-4' }, - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - { input_tokens: 50, output_tokens: 25, model: 'gpt-4' }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - // Must be positive, should be 500 + 50 + 25 = 575 - expect(client.usage.output_tokens).toBeGreaterThan(0); - expect(client.usage.output_tokens).toBe(575); + expect(client.usage.output_tokens).toBe(130); }); }); describe('real-world scenarios', () => { - it('should correctly sum output tokens for sequential tool calls with growing context', async () => { - // Real production data: Claude Opus with multiple tool calls - // Context grows as tool results are added, but output_tokens should only count model generations + it('should correctly handle sequential tool calls with growing context', async () => { const collectedUsage = [ - { - input_tokens: 31596, - output_tokens: 151, - total_tokens: 31747, - input_token_details: { cache_read: 0, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 35368, - output_tokens: 150, - total_tokens: 35518, - input_token_details: { cache_read: 0, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 58362, - output_tokens: 295, - total_tokens: 58657, - input_token_details: { cache_read: 0, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 112604, - output_tokens: 193, - total_tokens: 112797, - input_token_details: { cache_read: 0, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 257440, - output_tokens: 2217, - total_tokens: 259657, - input_token_details: { cache_read: 0, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, + { input_tokens: 31596, output_tokens: 151, model: 'claude-opus-4-5-20251101' }, + { input_tokens: 35368, output_tokens: 150, model: 'claude-opus-4-5-20251101' }, + { input_tokens: 58362, output_tokens: 295, model: 'claude-opus-4-5-20251101' }, + { input_tokens: 112604, output_tokens: 193, model: 'claude-opus-4-5-20251101' }, + { input_tokens: 257440, output_tokens: 2217, model: 'claude-opus-4-5-20251101' }, ]; + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 31596, output_tokens: 3006 }); + await client.recordCollectedUsage({ collectedUsage, balance: { enabled: true }, transactions: { enabled: true }, }); - // input_tokens should be first entry's input (initial context) expect(client.usage.input_tokens).toBe(31596); - - // output_tokens should be sum of all model outputs: 151 + 150 + 295 + 193 + 2217 = 3006 - // NOT the inflated value from incremental calculation (338,559) expect(client.usage.output_tokens).toBe(3006); - - // Verify spendTokens was called for each entry with correct values - expect(mockSpendTokens).toHaveBeenCalledTimes(5); - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ model: 'claude-opus-4-5-20251101' }), - { promptTokens: 31596, completionTokens: 151 }, - ); - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 5, - expect.objectContaining({ model: 'claude-opus-4-5-20251101' }), - { promptTokens: 257440, completionTokens: 2217 }, - ); }); - it('should handle single followup message correctly', async () => { - // Real production data: followup to the above conversation - const collectedUsage = [ - { - input_tokens: 263406, - output_tokens: 257, - total_tokens: 263663, - input_token_details: { cache_read: 0, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(client.usage.input_tokens).toBe(263406); - expect(client.usage.output_tokens).toBe(257); - - expect(mockSpendTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'claude-opus-4-5-20251101' }), - { promptTokens: 263406, completionTokens: 257 }, - ); - }); - - it('should ensure output_tokens > 0 check passes for BaseClient.sendMessage', async () => { - // This verifies the fix for the duplicate token spending bug - // BaseClient.sendMessage checks: if (usage != null && Number(usage[this.outputTokensKey]) > 0) - const collectedUsage = [ - { - input_tokens: 31596, - output_tokens: 151, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 35368, - output_tokens: 150, - model: 'claude-opus-4-5-20251101', - }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - const usage = client.getStreamUsage(); - - // The check that was failing before the fix - expect(usage).not.toBeNull(); - expect(Number(usage.output_tokens)).toBeGreaterThan(0); - - // Verify correct value - expect(usage.output_tokens).toBe(301); // 151 + 150 - }); - - it('should correctly handle cache tokens with multiple tool calls', async () => { - // Real production data: Claude Opus with cache tokens (prompt caching) - // First entry has cache_creation, subsequent entries have cache_read + it('should correctly handle cache tokens', async () => { const collectedUsage = [ { input_tokens: 788, output_tokens: 163, - total_tokens: 951, input_token_details: { cache_read: 0, cache_creation: 30808 }, model: 'claude-opus-4-5-20251101', }, - { - input_tokens: 3802, - output_tokens: 149, - total_tokens: 3951, - input_token_details: { cache_read: 30808, cache_creation: 768 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 26808, - output_tokens: 225, - total_tokens: 27033, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 80912, - output_tokens: 204, - total_tokens: 81116, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 136454, - output_tokens: 206, - total_tokens: 136660, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 146316, - output_tokens: 224, - total_tokens: 146540, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 150402, - output_tokens: 1248, - total_tokens: 151650, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 156268, - output_tokens: 139, - total_tokens: 156407, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, - { - input_tokens: 167126, - output_tokens: 2961, - total_tokens: 170087, - input_token_details: { cache_read: 31576, cache_creation: 0 }, - model: 'claude-opus-4-5-20251101', - }, ]; + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 31596, output_tokens: 163 }); + await client.recordCollectedUsage({ collectedUsage, balance: { enabled: true }, transactions: { enabled: true }, }); - // input_tokens = first entry's input + cache_creation + cache_read - // = 788 + 30808 + 0 = 31596 expect(client.usage.input_tokens).toBe(31596); - - // output_tokens = sum of all output_tokens - // = 163 + 149 + 225 + 204 + 206 + 224 + 1248 + 139 + 2961 = 5519 - expect(client.usage.output_tokens).toBe(5519); - - // First 2 entries have cache tokens, should use spendStructuredTokens - // Remaining 7 entries have cache_read but no cache_creation, still structured - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(9); - expect(mockSpendTokens).toHaveBeenCalledTimes(0); - - // Verify first entry uses structured tokens with cache_creation - expect(mockSpendStructuredTokens).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ model: 'claude-opus-4-5-20251101' }), - { - promptTokens: { input: 788, write: 30808, read: 0 }, - completionTokens: 163, - }, - ); - - // Verify second entry uses structured tokens with both cache_creation and cache_read - expect(mockSpendStructuredTokens).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ model: 'claude-opus-4-5-20251101' }), - { - promptTokens: { input: 3802, write: 768, read: 30808 }, - completionTokens: 149, - }, - ); - }); - }); - - describe('cache token handling', () => { - it('should handle OpenAI format cache tokens (input_token_details)', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'gpt-4', - input_token_details: { - cache_creation: 20, - cache_read: 10, - }, - }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - expect(mockSpendStructuredTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'gpt-4' }), - { - promptTokens: { - input: 100, - write: 20, - read: 10, - }, - completionTokens: 50, - }, - ); - }); - - it('should handle Anthropic format cache tokens (cache_*_input_tokens)', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'claude-3', - cache_creation_input_tokens: 25, - cache_read_input_tokens: 15, - }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - expect(mockSpendStructuredTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'claude-3' }), - { - promptTokens: { - input: 100, - write: 25, - read: 15, - }, - completionTokens: 50, - }, - ); - }); - - it('should use spendTokens for entries without cache tokens', async () => { - const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(1); - expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); - }); - - it('should handle mixed cache and non-cache entries', async () => { - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - { - input_tokens: 150, - output_tokens: 30, - model: 'gpt-4', - input_token_details: { cache_creation: 10, cache_read: 5 }, - }, - { input_tokens: 200, output_tokens: 20, model: 'gpt-4' }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - }); - - it('should include cache tokens in total input calculation', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'gpt-4', - input_token_details: { - cache_creation: 20, - cache_read: 10, - }, - }, - ]; - - await client.recordCollectedUsage({ - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - // Total input should include cache tokens: 100 + 20 + 10 = 130 - expect(client.usage.input_tokens).toBe(130); + expect(client.usage.output_tokens).toBe(163); }); }); describe('model fallback', () => { - it('should use usage.model when available', async () => { - const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' }]; - - await client.recordCollectedUsage({ - model: 'fallback-model', - collectedUsage, - balance: { enabled: true }, - transactions: { enabled: true }, - }); - - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'gpt-4-turbo' }), - expect.any(Object), - ); - }); - - it('should fallback to param model when usage.model is missing', async () => { + it('should use param model when available', async () => { + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }]; await client.recordCollectedUsage({ @@ -630,14 +288,13 @@ describe('AgentClient - recordCollectedUsage', () => { transactions: { enabled: true }, }); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'param-model' }), - expect.any(Object), - ); + const [, params] = mockRecordCollectedUsage.mock.calls[0]; + expect(params.model).toBe('param-model'); }); it('should fallback to client.model when param model is missing', async () => { client.model = 'client-model'; + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }]; await client.recordCollectedUsage({ @@ -646,13 +303,12 @@ describe('AgentClient - recordCollectedUsage', () => { transactions: { enabled: true }, }); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'client-model' }), - expect.any(Object), - ); + const [, params] = mockRecordCollectedUsage.mock.calls[0]; + expect(params.model).toBe('client-model'); }); it('should fallback to agent model_parameters.model as last resort', async () => { + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }]; await client.recordCollectedUsage({ @@ -661,15 +317,14 @@ describe('AgentClient - recordCollectedUsage', () => { transactions: { enabled: true }, }); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'gpt-4' }), - expect.any(Object), - ); + const [, params] = mockRecordCollectedUsage.mock.calls[0]; + expect(params.model).toBe('gpt-4'); }); }); describe('getStreamUsage integration', () => { it('should return the usage object set by recordCollectedUsage', async () => { + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; await client.recordCollectedUsage({ @@ -679,10 +334,7 @@ describe('AgentClient - recordCollectedUsage', () => { }); const usage = client.getStreamUsage(); - expect(usage).toEqual({ - input_tokens: 100, - output_tokens: 50, - }); + expect(usage).toEqual({ input_tokens: 100, output_tokens: 50 }); }); it('should return undefined before recordCollectedUsage is called', () => { @@ -690,9 +342,9 @@ describe('AgentClient - recordCollectedUsage', () => { expect(usage).toBeUndefined(); }); + /** Verifies usage passes the check in BaseClient.sendMessage: if (usage != null && Number(usage[this.outputTokensKey]) > 0) */ it('should have output_tokens > 0 for BaseClient.sendMessage check', async () => { - // This test verifies the usage will pass the check in BaseClient.sendMessage: - // if (usage != null && Number(usage[this.outputTokensKey]) > 0) + mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 130 }); const collectedUsage = [ { input_tokens: 200, output_tokens: 100, model: 'gpt-4' }, { input_tokens: 50, output_tokens: 30, model: 'gpt-4' }, diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js index 8ce15766c7..83e6ad6efd 100644 --- a/api/server/controllers/agents/responses.js +++ b/api/server/controllers/agents/responses.js @@ -38,6 +38,7 @@ const { loadAgentTools, loadToolsForExecution } = require('~/server/services/Too const { findAccessibleResources } = require('~/server/services/PermissionService'); const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const { getAgent, getAgents } = require('~/models/Agent'); const db = require('~/models'); @@ -509,7 +510,12 @@ const createResponse = async (req, res) => { const balanceConfig = getBalanceConfig(req.config); const transactionsConfig = getTransactionsConfig(req.config); recordCollectedUsage( - { spendTokens, spendStructuredTokens }, + { + spendTokens, + spendStructuredTokens, + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, + }, { user: userId, conversationId, @@ -658,7 +664,12 @@ const createResponse = async (req, res) => { const balanceConfig = getBalanceConfig(req.config); const transactionsConfig = getTransactionsConfig(req.config); recordCollectedUsage( - { spendTokens, spendStructuredTokens }, + { + spendTokens, + spendStructuredTokens, + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, + }, { user: userId, conversationId, diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index acc9299b04..d39b0104a8 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,17 +1,19 @@ const { logger } = require('@librechat/data-schemas'); const { - countTokens, isEnabled, sendEvent, + countTokens, GenerationJobManager, + recordCollectedUsage, sanitizeMessageForTransmit, } = require('@librechat/api'); const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider'); +const { saveMessage, getConvo, updateBalance, bulkInsertTransactions } = require('~/models'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); +const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const clearPendingReq = require('~/cache/clearPendingReq'); const { sendError } = require('~/server/middleware/error'); -const { saveMessage, getConvo } = require('~/models'); const { abortRun } = require('./abortRun'); /** @@ -40,57 +42,22 @@ async function spendCollectedUsage({ return; } - const spendPromises = []; - - for (const usage of collectedUsage) { - if (!usage) { - continue; - } - - // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens) - const cache_creation = - Number(usage.input_token_details?.cache_creation) || - Number(usage.cache_creation_input_tokens) || - 0; - const cache_read = - Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0; - - const txMetadata = { + await recordCollectedUsage( + { + spendTokens, + spendStructuredTokens, + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance }, + }, + { + user: userId, + conversationId, + collectedUsage, context: 'abort', messageId, - conversationId, - user: userId, - model: usage.model ?? fallbackModel, - }; - - if (cache_creation > 0 || cache_read > 0) { - spendPromises.push( - spendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[abortMiddleware] Error spending structured tokens for abort', err); - }), - ); - continue; - } - - spendPromises.push( - spendTokens(txMetadata, { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[abortMiddleware] Error spending tokens for abort', err); - }), - ); - } - - // Wait for all token spending to complete - await Promise.all(spendPromises); + model: fallbackModel, + }, + ); // Clear the array to prevent double-spending from the AgentClient finally block. // The collectedUsage array is shared by reference with AgentClient.collectedUsage, @@ -301,4 +268,5 @@ const handleAbortError = async (res, req, error, data) => { module.exports = { handleAbort, handleAbortError, + spendCollectedUsage, }; diff --git a/api/server/middleware/abortMiddleware.spec.js b/api/server/middleware/abortMiddleware.spec.js index 93f2ce558b..795814a928 100644 --- a/api/server/middleware/abortMiddleware.spec.js +++ b/api/server/middleware/abortMiddleware.spec.js @@ -4,16 +4,32 @@ * This tests the token spending logic for abort scenarios, * particularly for parallel agents (addedConvo) where multiple * models need their tokens spent. + * + * spendCollectedUsage delegates to recordCollectedUsage from @librechat/api, + * passing pricing + bulkWriteOps deps, with context: 'abort'. + * After spending, it clears the collectedUsage array to prevent double-spending + * from the AgentClient finally block (which shares the same array reference). */ const mockSpendTokens = jest.fn().mockResolvedValue(); const mockSpendStructuredTokens = jest.fn().mockResolvedValue(); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); + +const mockGetMultiplier = jest.fn().mockReturnValue(1); +const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); jest.mock('~/models/spendTokens', () => ({ spendTokens: (...args) => mockSpendTokens(...args), spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args), })); +jest.mock('~/models/tx', () => ({ + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, +})); + jest.mock('@librechat/data-schemas', () => ({ logger: { debug: jest.fn(), @@ -30,6 +46,7 @@ jest.mock('@librechat/api', () => ({ GenerationJobManager: { abortJob: jest.fn(), }, + recordCollectedUsage: mockRecordCollectedUsage, sanitizeMessageForTransmit: jest.fn((msg) => msg), })); @@ -49,94 +66,27 @@ jest.mock('~/server/middleware/error', () => ({ sendError: jest.fn(), })); +const mockUpdateBalance = jest.fn().mockResolvedValue({}); +const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined); jest.mock('~/models', () => ({ saveMessage: jest.fn().mockResolvedValue(), getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }), + updateBalance: mockUpdateBalance, + bulkInsertTransactions: mockBulkInsertTransactions, })); jest.mock('./abortRun', () => ({ abortRun: jest.fn(), })); -// Import the module after mocks are set up -// We need to extract the spendCollectedUsage function for testing -// Since it's not exported, we'll test it through the handleAbort flow +const { spendCollectedUsage } = require('./abortMiddleware'); describe('abortMiddleware - spendCollectedUsage', () => { beforeEach(() => { jest.clearAllMocks(); }); - describe('spendCollectedUsage logic', () => { - // Since spendCollectedUsage is not exported, we test the logic directly - // by replicating the function here for unit testing - - const spendCollectedUsage = async ({ - userId, - conversationId, - collectedUsage, - fallbackModel, - }) => { - if (!collectedUsage || collectedUsage.length === 0) { - return; - } - - const spendPromises = []; - - for (const usage of collectedUsage) { - if (!usage) { - continue; - } - - const cache_creation = - Number(usage.input_token_details?.cache_creation) || - Number(usage.cache_creation_input_tokens) || - 0; - const cache_read = - Number(usage.input_token_details?.cache_read) || - Number(usage.cache_read_input_tokens) || - 0; - - const txMetadata = { - context: 'abort', - conversationId, - user: userId, - model: usage.model ?? fallbackModel, - }; - - if (cache_creation > 0 || cache_read > 0) { - spendPromises.push( - mockSpendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch(() => { - // Log error but don't throw - }), - ); - continue; - } - - spendPromises.push( - mockSpendTokens(txMetadata, { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, - }).catch(() => { - // Log error but don't throw - }), - ); - } - - // Wait for all token spending to complete - await Promise.all(spendPromises); - - // Clear the array to prevent double-spending - collectedUsage.length = 0; - }; - + describe('spendCollectedUsage delegation', () => { it('should return early if collectedUsage is empty', async () => { await spendCollectedUsage({ userId: 'user-123', @@ -145,8 +95,7 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockRecordCollectedUsage).not.toHaveBeenCalled(); }); it('should return early if collectedUsage is null', async () => { @@ -157,28 +106,10 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockRecordCollectedUsage).not.toHaveBeenCalled(); }); - it('should skip null entries in collectedUsage', async () => { - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - null, - { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'gpt-4', - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - }); - - it('should spend tokens for single model', async () => { + it('should call recordCollectedUsage with abort context and full deps', async () => { const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; await spendCollectedUsage({ @@ -186,21 +117,35 @@ describe('abortMiddleware - spendCollectedUsage', () => { conversationId: 'convo-123', collectedUsage, fallbackModel: 'gpt-4', + messageId: 'msg-123', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ - context: 'abort', - conversationId: 'convo-123', + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { + spendTokens: expect.any(Function), + spendStructuredTokens: expect.any(Function), + pricing: { + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, + }, + bulkWriteOps: { + insertMany: mockBulkInsertTransactions, + updateBalance: mockUpdateBalance, + }, + }, + { user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + context: 'abort', + messageId: 'msg-123', model: 'gpt-4', - }), - { promptTokens: 100, completionTokens: 50 }, + }, ); }); - it('should spend tokens for multiple models (parallel agents)', async () => { + it('should pass context abort for multiple models (parallel agents)', async () => { const collectedUsage = [ { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, @@ -214,136 +159,17 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(3); - - // Verify each model was called - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ model: 'gpt-4' }), - { promptTokens: 100, completionTokens: 50 }, - ); - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ model: 'claude-3' }), - { promptTokens: 80, completionTokens: 40 }, - ); - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 3, - expect.objectContaining({ model: 'gemini-pro' }), - { promptTokens: 120, completionTokens: 60 }, - ); - }); - - it('should use fallbackModel when usage.model is missing', async () => { - const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'fallback-model', - }); - - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'fallback-model' }), + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( expect.any(Object), + expect.objectContaining({ + context: 'abort', + collectedUsage, + }), ); }); - it('should use spendStructuredTokens for OpenAI format cache tokens', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'gpt-4', - input_token_details: { - cache_creation: 20, - cache_read: 10, - }, - }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'gpt-4', - }); - - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'gpt-4', context: 'abort' }), - { - promptTokens: { - input: 100, - write: 20, - read: 10, - }, - completionTokens: 50, - }, - ); - }); - - it('should use spendStructuredTokens for Anthropic format cache tokens', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'claude-3', - cache_creation_input_tokens: 25, - cache_read_input_tokens: 15, - }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'claude-3', - }); - - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'claude-3' }), - { - promptTokens: { - input: 100, - write: 25, - read: 15, - }, - completionTokens: 50, - }, - ); - }); - - it('should handle mixed cache and non-cache entries', async () => { - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - { - input_tokens: 150, - output_tokens: 30, - model: 'claude-3', - cache_creation_input_tokens: 20, - cache_read_input_tokens: 10, - }, - { input_tokens: 200, output_tokens: 20, model: 'gemini-pro' }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'gpt-4', - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - }); - it('should handle real-world parallel agent abort scenario', async () => { - // Simulates: Primary agent (gemini) + addedConvo agent (gpt-5) aborted mid-stream const collectedUsage = [ { input_tokens: 31596, output_tokens: 151, model: 'gemini-3-flash-preview' }, { input_tokens: 28000, output_tokens: 120, model: 'gpt-5.2' }, @@ -356,27 +182,24 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gemini-3-flash-preview', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - - // Primary model - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ model: 'gemini-3-flash-preview' }), - { promptTokens: 31596, completionTokens: 151 }, - ); - - // Parallel model (addedConvo) - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ model: 'gpt-5.2' }), - { promptTokens: 28000, completionTokens: 120 }, + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + user: 'user-123', + conversationId: 'convo-123', + context: 'abort', + model: 'gemini-3-flash-preview', + }), ); }); + /** + * Race condition prevention: after abort middleware spends tokens, + * the collectedUsage array is cleared so AgentClient.recordCollectedUsage() + * (which shares the same array reference) sees an empty array and returns early. + */ it('should clear collectedUsage array after spending to prevent double-spending', async () => { - // This tests the race condition fix: after abort middleware spends tokens, - // the collectedUsage array is cleared so AgentClient.recordCollectedUsage() - // (which shares the same array reference) sees an empty array and returns early. const collectedUsage = [ { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, @@ -391,19 +214,16 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - - // The array should be cleared after spending + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); expect(collectedUsage.length).toBe(0); }); - it('should await all token spending operations before clearing array', async () => { - // Ensure we don't clear the array before spending completes - let spendCallCount = 0; - mockSpendTokens.mockImplementation(async () => { - spendCallCount++; - // Simulate async delay + it('should await recordCollectedUsage before clearing array', async () => { + let resolved = false; + mockRecordCollectedUsage.mockImplementation(async () => { await new Promise((resolve) => setTimeout(resolve, 10)); + resolved = true; + return { input_tokens: 100, output_tokens: 50 }; }); const collectedUsage = [ @@ -418,10 +238,7 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - // Both spend calls should have completed - expect(spendCallCount).toBe(2); - - // Array should be cleared after awaiting + expect(resolved).toBe(true); expect(collectedUsage.length).toBe(0); }); }); diff --git a/api/server/utils/import/importers.spec.js b/api/server/utils/import/importers.spec.js index a695a31555..2ddfa76658 100644 --- a/api/server/utils/import/importers.spec.js +++ b/api/server/utils/import/importers.spec.js @@ -1277,12 +1277,9 @@ describe('processAssistantMessage', () => { results.push(duration); }); - // Check if processing time increases exponentially - // In a ReDoS vulnerability, time would roughly double with each size increase - for (let i = 1; i < results.length; i++) { - const ratio = results[i] / results[i - 1]; - expect(ratio).toBeLessThan(3); // Allow for CI environment variability while still catching ReDoS - console.log(`Size ${sizes[i]} processing time ratio: ${ratio}`); + // Each size should complete well under 100ms; a ReDoS would cause exponential blowup + for (let i = 0; i < results.length; i++) { + expect(results[i]).toBeLessThan(100); } // Also test with the exact payload from the security report diff --git a/package-lock.json b/package-lock.json index c03ef33c8d..2b90bbec3e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -59,7 +59,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -11836,9 +11836,9 @@ } }, "node_modules/@librechat/agents": { - "version": "3.1.53", - "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.53.tgz", - "integrity": "sha512-jK9JHIhQYgr+Ha2FhknEYQmS6Ft3/TGdYIlL6L6EtIq20SIA59r1DvQx/x9sd3wHoHkk6AZumMgqAUTTCaWBIA==", + "version": "3.1.54", + "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.54.tgz", + "integrity": "sha512-OdsE8kDgtIhxs0sR0rG7I5WynbZKAH/j/50OCZEjLdv//jR8Lj6fpL9RCEzRGnu44MjUZgRkSgf2JV3LHsCJiQ==", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "^0.73.0", @@ -43797,7 +43797,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/package.json b/packages/api/package.json index 3ceaeb7a12..f2529ecea5 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -90,7 +90,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/src/agents/index.ts b/packages/api/src/agents/index.ts index 9d13b3dd8e..47e15b8c28 100644 --- a/packages/api/src/agents/index.ts +++ b/packages/api/src/agents/index.ts @@ -9,6 +9,7 @@ export * from './legacy'; export * from './memory'; export * from './migration'; export * from './openai'; +export * from './transactions'; export * from './usage'; export * from './resources'; export * from './responses'; diff --git a/packages/api/src/agents/transactions.bulk-parity.spec.ts b/packages/api/src/agents/transactions.bulk-parity.spec.ts new file mode 100644 index 0000000000..bf89682d6f --- /dev/null +++ b/packages/api/src/agents/transactions.bulk-parity.spec.ts @@ -0,0 +1,559 @@ +/** + * Real-DB parity tests for the bulk transaction path. + * + * Each test uses the actual getMultiplier/getCacheMultiplier pricing functions + * (the same ones the legacy createTransaction path uses) and runs the bulk path + * against a real MongoMemoryServer instance. + * + * The assertion pattern: compute the expected tokenValue/rate/rawAmount from the + * pricing functions directly, then verify the DB state matches exactly. Since both + * legacy (createTransaction) and bulk (prepareTokenSpend + bulkWriteTransactions) + * call the same pricing functions with the same inputs, their outputs must be + * numerically identical. + */ +import mongoose from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { + CANCEL_RATE, + createMethods, + balanceSchema, + transactionSchema, +} from '@librechat/data-schemas'; +import type { PricingFns, TxMetadata } from './transactions'; +import { + prepareStructuredTokenSpend, + bulkWriteTransactions, + prepareTokenSpend, +} from './transactions'; + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { + ...actual, + logger: { debug: jest.fn(), error: jest.fn(), warn: jest.fn(), info: jest.fn() }, + }; +}); + +// Real pricing functions from api/models/tx.js — same ones the legacy path uses +/* eslint-disable @typescript-eslint/no-require-imports */ +const { + getMultiplier, + getCacheMultiplier, + tokenValues, + premiumTokenValues, +} = require('../../../../api/models/tx.js'); +/* eslint-enable @typescript-eslint/no-require-imports */ + +const pricing: PricingFns = { getMultiplier, getCacheMultiplier }; + +let mongoServer: MongoMemoryServer; +let Transaction: mongoose.Model; +let Balance: mongoose.Model; +let dbMethods: ReturnType; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + Transaction = mongoose.models.Transaction || mongoose.model('Transaction', transactionSchema); + Balance = mongoose.models.Balance || mongoose.model('Balance', balanceSchema); + dbMethods = createMethods(mongoose); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.connection.dropDatabase(); +}); + +const dbOps = () => ({ + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, +}); + +function txMeta(user: string, extra: Partial = {}): TxMetadata { + return { + user, + conversationId: 'test-convo', + context: 'test', + balance: { enabled: true }, + transactions: { enabled: true }, + ...extra, + }; +} + +describe('Standard token parity', () => { + test('balance should decrease by promptCost + completionCost — identical to legacy path', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const promptTokens = 100; + const completionTokens = 50; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: promptTokens, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + const expectedCost = promptTokens * promptMultiplier + completionTokens * completionMultiplier; + const expectedBalance = initialBalance - expectedCost; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBeCloseTo(expectedBalance, 0); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + expect(txns).toHaveLength(2); + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + const completionTx = txns.find((t) => t.tokenType === 'completion'); + expect(promptTx!.rawAmount).toBe(-promptTokens); + expect(promptTx!.rate).toBe(promptMultiplier); + expect(promptTx!.tokenValue).toBe(-promptTokens * promptMultiplier); + expect(completionTx!.rawAmount).toBe(-completionTokens); + expect(completionTx!.rate).toBe(completionMultiplier); + expect(completionTx!.tokenValue).toBe(-completionTokens * completionMultiplier); + }); + + test('balance unchanged when balance.enabled is false — identical to legacy path', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo', balance: { enabled: false } }), + { promptTokens: 100, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBe(initialBalance); + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); // transactions still inserted + }); + + test('no docs when transactions.enabled is false — identical to legacy path', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo', transactions: { enabled: false } }), + { promptTokens: 100, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(0); + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('abort context — transactions inserted, no balance update when balance not passed', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const entries = prepareTokenSpend( + txMeta(userId, { model, context: 'abort', balance: undefined }), + { promptTokens: 100, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('NaN promptTokens — only completion doc inserted, identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo' }), + { promptTokens: NaN, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + expect(txns).toHaveLength(1); + expect(txns[0].tokenType).toBe('completion'); + }); + + test('zero tokens produce docs with rawAmount=0, tokenValue=0', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 10000 }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo' }), + { promptTokens: 0, completionTokens: 0 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + expect(txns).toHaveLength(2); + expect(txns.every((t) => t.rawAmount === 0)).toBe(true); + expect(txns.every((t) => t.tokenValue === 0)).toBe(true); + }); +}); + +describe('CANCEL_RATE parity (incomplete context)', () => { + test('CANCEL_RATE applied to completion token — same tokenValue as legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 10000000 }); + + const model = 'claude-3-5-sonnet'; + const completionTokens = 50; + const promptTokens = 10; + + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + const expectedCompletionTokenValue = Math.ceil( + -completionTokens * completionMultiplier * CANCEL_RATE, + ); + + const entries = prepareTokenSpend( + txMeta(userId, { model, context: 'incomplete' }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + const completionTx = txns.find((t) => t.tokenType === 'completion'); + expect(completionTx!.tokenValue).toBe(expectedCompletionTokenValue); + expect(completionTx!.rate).toBeCloseTo(completionMultiplier * CANCEL_RATE, 5); + }); + + test('CANCEL_RATE NOT applied to prompt tokens in incomplete context', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 10000000 }); + + const model = 'claude-3-5-sonnet'; + const promptTokens = 100; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: promptTokens, + }); + + const entries = prepareTokenSpend( + txMeta(userId, { model, context: 'incomplete' }), + { promptTokens, completionTokens: 0 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx!.rate).toBe(promptMultiplier); // no CANCEL_RATE + }); +}); + +describe('Structured token parity', () => { + test('balance deduction identical to legacy spendStructuredTokens', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 17613154.55; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const tokenUsage = { + promptTokens: { input: 11, write: 140522, read: 0 }, + completionTokens: 5, + }; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: 11 + 140522, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: 11 + 140522, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; + + const expectedPromptCost = + tokenUsage.promptTokens.input * promptMultiplier + + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; + const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + const expectedBalance = initialBalance - expectedTotalCost; + + const entries = prepareStructuredTokenSpend(txMeta(userId, { model }), tokenUsage, pricing); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(Math.abs((balance.tokenCredits as number) - expectedBalance)).toBeLessThan(100); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx!.inputTokens).toBe(-11); + expect(promptTx!.writeTokens).toBe(-140522); + expect(Math.abs(Number(promptTx!.readTokens ?? 0))).toBe(0); + }); + + test('structured tokens with both cache_creation and cache_read', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const tokenUsage = { + promptTokens: { input: 100, write: 50, read: 30 }, + completionTokens: 80, + }; + const totalInput = 180; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: totalInput, + }); + + const expectedPromptCost = 100 * promptMultiplier + 50 * writeMultiplier + 30 * readMultiplier; + const expectedCost = expectedPromptCost + 80 * completionMultiplier; + + const entries = prepareStructuredTokenSpend(txMeta(userId, { model }), tokenUsage, pricing); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + expect(txns).toHaveLength(2); + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx!.inputTokens).toBe(-100); + expect(promptTx!.writeTokens).toBe(-50); + expect(promptTx!.readTokens).toBe(-30); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect( + Math.abs((balance.tokenCredits as number) - (initialBalance - expectedCost)), + ).toBeLessThan(1); + }); + + test('CANCEL_RATE applied to completion in structured incomplete context', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 17613154.55 }); + + const model = 'claude-3-5-sonnet'; + const tokenUsage = { + promptTokens: { input: 10, write: 100, read: 5 }, + completionTokens: 50, + }; + + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: 115, + }); + const expectedCompletionTokenValue = Math.ceil(-50 * completionMultiplier * CANCEL_RATE); + + const entries = prepareStructuredTokenSpend( + txMeta(userId, { model, context: 'incomplete' }), + tokenUsage, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record[]; + const completionTx = txns.find((t) => t.tokenType === 'completion'); + expect(completionTx!.tokenValue).toBeCloseTo(expectedCompletionTokenValue, 0); + }); +}); + +describe('Premium pricing parity', () => { + test('standard pricing below threshold — identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 100000; + const completionTokens = 500; + + const standardPromptRate = (tokenValues as Record>)[model] + .prompt; + const standardCompletionRate = (tokenValues as Record>)[model] + .completion; + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('premium pricing above threshold — identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + const premiumPromptRate = (premiumTokenValues as Record>)[model] + .prompt; + const premiumCompletionRate = (premiumTokenValues as Record>)[ + model + ].completion; + const expectedCost = + promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('standard pricing at exactly the threshold — identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = (premiumTokenValues as Record>)[model] + .threshold; + const completionTokens = 500; + + const standardPromptRate = (tokenValues as Record>)[model] + .prompt; + const standardCompletionRate = (tokenValues as Record>)[model] + .completion; + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); +}); + +describe('Multi-entry batch parity', () => { + test('real-world sequential tool calls — total balance deduction identical to N individual legacy calls', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-5-20251101'; + const calls = [ + { promptTokens: 31596, completionTokens: 151 }, + { promptTokens: 35368, completionTokens: 150 }, + { promptTokens: 58362, completionTokens: 295 }, + { promptTokens: 112604, completionTokens: 193 }, + { promptTokens: 257440, completionTokens: 2217 }, + ]; + + let expectedTotalCost = 0; + const allEntries = []; + for (const { promptTokens, completionTokens } of calls) { + const pm = getMultiplier({ model, tokenType: 'prompt', inputTokenCount: promptTokens }); + const cm = getMultiplier({ model, tokenType: 'completion', inputTokenCount: promptTokens }); + expectedTotalCost += promptTokens * pm + completionTokens * cm; + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + allEntries.push(...entries); + } + + await bulkWriteTransactions({ user: userId, docs: allEntries }, dbOps()); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(10); // 5 calls × 2 docs + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); + + test('structured premium above threshold — batch vs individual produce same balance deduction', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const tokenUsage = { + promptTokens: { input: 200000, write: 10000, read: 5000 }, + completionTokens: 1000, + }; + const totalInput = 215000; + + const premiumPromptRate = (premiumTokenValues as Record>)[model] + .prompt; + const premiumCompletionRate = (premiumTokenValues as Record>)[ + model + ].completion; + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + + const expectedPromptCost = + tokenUsage.promptTokens.input * premiumPromptRate + + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; + const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + + expect(totalInput).toBeGreaterThan( + (premiumTokenValues as Record>)[model].threshold, + ); + + const entries = prepareStructuredTokenSpend(txMeta(userId, { model }), tokenUsage, pricing); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); +}); diff --git a/packages/api/src/agents/transactions.spec.ts b/packages/api/src/agents/transactions.spec.ts new file mode 100644 index 0000000000..99fb7cdd85 --- /dev/null +++ b/packages/api/src/agents/transactions.spec.ts @@ -0,0 +1,474 @@ +import mongoose from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { + CANCEL_RATE, + createMethods, + balanceSchema, + transactionSchema, +} from '@librechat/data-schemas'; +import type { PricingFns, TxMetadata, PreparedEntry } from './transactions'; +import { + prepareStructuredTokenSpend, + bulkWriteTransactions, + prepareTokenSpend, +} from './transactions'; + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { + ...actual, + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + info: jest.fn(), + }, + }; +}); + +let mongoServer: MongoMemoryServer; +let Transaction: mongoose.Model; +let Balance: mongoose.Model; +let dbMethods: ReturnType; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + Transaction = mongoose.models.Transaction || mongoose.model('Transaction', transactionSchema); + Balance = mongoose.models.Balance || mongoose.model('Balance', balanceSchema); + dbMethods = createMethods(mongoose); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.connection.dropDatabase(); +}); + +const testUserId = new mongoose.Types.ObjectId().toString(); + +const baseTxData: TxMetadata = { + user: testUserId, + context: 'message', + conversationId: 'convo-123', + model: 'gpt-4', + messageId: 'msg-123', + balance: { enabled: true }, + transactions: { enabled: true }, +}; + +const mockPricing: PricingFns = { + getMultiplier: jest.fn().mockReturnValue(2), + getCacheMultiplier: jest.fn().mockReturnValue(null), +}; + +describe('prepareTokenSpend', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should prepare prompt + completion entries', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(2); + expect(entries[0].doc.tokenType).toBe('prompt'); + expect(entries[1].doc.tokenType).toBe('completion'); + }); + + it('should return empty array when transactions disabled', () => { + const txData = { ...baseTxData, transactions: { enabled: false } }; + const entries = prepareTokenSpend( + txData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(0); + }); + + it('should filter out NaN rawAmount entries', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: NaN, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(1); + expect(entries[0].doc.tokenType).toBe('completion'); + }); + + it('should handle promptTokens only', () => { + const entries = prepareTokenSpend(baseTxData, { promptTokens: 100 }, mockPricing); + expect(entries).toHaveLength(1); + expect(entries[0].doc.tokenType).toBe('prompt'); + }); + + it('should handle completionTokens only', () => { + const entries = prepareTokenSpend(baseTxData, { completionTokens: 50 }, mockPricing); + expect(entries).toHaveLength(1); + expect(entries[0].doc.tokenType).toBe('completion'); + }); + + it('should handle zero tokens', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 0, completionTokens: 0 }, + mockPricing, + ); + expect(entries).toHaveLength(2); + expect(entries[0].doc.rawAmount).toBe(0); + expect(entries[1].doc.rawAmount).toBe(0); + }); + + it('should calculate tokenValue using pricing multiplier', () => { + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(3); + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + expect(entries[0].doc.rate).toBe(3); + expect(entries[0].doc.tokenValue).toBe(-100 * 3); + expect(entries[1].doc.rate).toBe(3); + expect(entries[1].doc.tokenValue).toBe(-50 * 3); + }); + + it('should pass valueKey to getMultiplier', () => { + prepareTokenSpend(baseTxData, { promptTokens: 100 }, mockPricing); + expect(mockPricing.getMultiplier).toHaveBeenCalledWith( + expect.objectContaining({ tokenType: 'prompt', model: 'gpt-4' }), + ); + }); + + it('should carry balance config on each entry', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + for (const entry of entries) { + expect(entry.balance).toEqual({ enabled: true }); + } + }); +}); + +describe('prepareTokenSpend — CANCEL_RATE', () => { + beforeEach(() => { + jest.clearAllMocks(); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(2); + }); + + it('should apply CANCEL_RATE to completion tokens with incomplete context', () => { + const txData: TxMetadata = { ...baseTxData, context: 'incomplete' }; + const entries = prepareTokenSpend( + txData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + const completion = entries.find((e) => e.doc.tokenType === 'completion'); + expect(completion).toBeDefined(); + expect(completion!.doc.rate).toBe(2 * CANCEL_RATE); + expect(completion!.doc.tokenValue).toBe(Math.ceil(-50 * 2 * CANCEL_RATE)); + }); + + it('should NOT apply CANCEL_RATE to prompt tokens with incomplete context', () => { + const txData: TxMetadata = { ...baseTxData, context: 'incomplete' }; + const entries = prepareTokenSpend( + txData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + const prompt = entries.find((e) => e.doc.tokenType === 'prompt'); + expect(prompt!.doc.rate).toBe(2); + }); + + it('should NOT apply CANCEL_RATE for abort context', () => { + const txData: TxMetadata = { ...baseTxData, context: 'abort' }; + const entries = prepareTokenSpend(txData, { completionTokens: 50 }, mockPricing); + expect(entries[0].doc.rate).toBe(2); + }); +}); + +describe('prepareStructuredTokenSpend', () => { + beforeEach(() => { + jest.clearAllMocks(); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(2); + (mockPricing.getCacheMultiplier as jest.Mock).mockReturnValue(null); + }); + + it('should prepare prompt + completion for structured tokens', () => { + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 100, write: 50, read: 30 }, completionTokens: 80 }, + mockPricing, + ); + expect(entries).toHaveLength(2); + expect(entries[0].doc.tokenType).toBe('prompt'); + expect(entries[0].doc.inputTokens).toBe(-100); + expect(entries[0].doc.writeTokens).toBe(-50); + expect(entries[0].doc.readTokens).toBe(-30); + expect(entries[1].doc.tokenType).toBe('completion'); + }); + + it('should use cache multipliers when available', () => { + (mockPricing.getCacheMultiplier as jest.Mock).mockImplementation(({ cacheType }) => { + if (cacheType === 'write') { + return 5; + } + if (cacheType === 'read') { + return 0.5; + } + return null; + }); + + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 100, write: 50, read: 30 }, completionTokens: 0 }, + mockPricing, + ); + const prompt = entries.find((e) => e.doc.tokenType === 'prompt'); + expect(prompt).toBeDefined(); + expect(prompt!.doc.rateDetail).toEqual({ input: 2, write: 5, read: 0.5 }); + }); + + it('should return empty when transactions disabled', () => { + const txData = { ...baseTxData, transactions: { enabled: false } }; + const entries = prepareStructuredTokenSpend( + txData, + { promptTokens: { input: 100 }, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(0); + }); + + it('should handle zero totalPromptTokens (fallback rate)', () => { + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 0, write: 0, read: 0 }, completionTokens: 50 }, + mockPricing, + ); + const prompt = entries.find((e) => e.doc.tokenType === 'prompt'); + expect(prompt).toBeDefined(); + expect(prompt!.doc.rate).toBe(2); + }); +}); + +describe('bulkWriteTransactions (real DB)', () => { + it('should return early for empty docs without DB writes', async () => { + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs: [] }, dbOps); + const txCount = await Transaction.countDocuments(); + expect(txCount).toBe(0); + }); + + it('should insert transaction documents into MongoDB', async () => { + const docs: PreparedEntry[] = [ + { + doc: { + user: testUserId, + conversationId: 'c1', + tokenType: 'prompt', + tokenValue: -200, + rate: 2, + rawAmount: -100, + }, + tokenValue: -200, + balance: { enabled: true }, + }, + { + doc: { + user: testUserId, + conversationId: 'c1', + tokenType: 'completion', + tokenValue: -100, + rate: 2, + rawAmount: -50, + }, + tokenValue: -100, + balance: { enabled: true }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const saved = await Transaction.find({ user: testUserId }).lean(); + expect(saved).toHaveLength(2); + expect(saved.map((t: Record) => t.tokenType).sort()).toEqual([ + 'completion', + 'prompt', + ]); + }); + + it('should create balance document and update credits', async () => { + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -300 }, + tokenValue: -300, + balance: { enabled: true }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal).toBeDefined(); + expect(bal!.tokenCredits).toBe(0); + }); + + it('should NOT update balance when no docs have balance enabled', async () => { + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -100 }, + tokenValue: -100, + balance: { enabled: false }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const txCount = await Transaction.countDocuments({ user: testUserId }); + expect(txCount).toBe(1); + const bal = await Balance.findOne({ user: testUserId }).lean(); + expect(bal).toBeNull(); + }); + + it('should only sum tokenValue from balance-enabled docs', async () => { + await Balance.create({ user: testUserId, tokenCredits: 1000 }); + + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -100 }, + tokenValue: -100, + balance: { enabled: true }, + }, + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'completion', tokenValue: -50 }, + tokenValue: -50, + balance: { enabled: false }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal!.tokenCredits).toBe(900); + }); + + it('should handle null balance gracefully', async () => { + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -100 }, + tokenValue: -100, + balance: null, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const txCount = await Transaction.countDocuments({ user: testUserId }); + expect(txCount).toBe(1); + const bal = await Balance.findOne({ user: testUserId }).lean(); + expect(bal).toBeNull(); + }); +}); + +describe('end-to-end: prepare → bulk write → verify', () => { + it('should prepare, write, and correctly update balance for standard tokens', async () => { + await Balance.create({ user: testUserId, tokenCredits: 10000 }); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(2); + + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs: entries }, dbOps); + + const txns = (await Transaction.find({ user: testUserId }).lean()) as Record[]; + expect(txns).toHaveLength(2); + + const prompt = txns.find((t) => t.tokenType === 'prompt'); + const completion = txns.find((t) => t.tokenType === 'completion'); + expect(prompt!.tokenValue).toBe(-200); + expect(prompt!.rate).toBe(2); + expect(completion!.tokenValue).toBe(-100); + expect(completion!.rate).toBe(2); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal!.tokenCredits).toBe(10000 + -200 + -100); + }); + + it('should prepare and write structured tokens with cache pricing', async () => { + await Balance.create({ user: testUserId, tokenCredits: 5000 }); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(1); + (mockPricing.getCacheMultiplier as jest.Mock).mockImplementation(({ cacheType }) => { + if (cacheType === 'write') { + return 3; + } + if (cacheType === 'read') { + return 0.1; + } + return null; + }); + + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 100, write: 50, read: 200 }, completionTokens: 80 }, + mockPricing, + ); + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs: entries }, dbOps); + + const txns = (await Transaction.find({ user: testUserId }).lean()) as Record[]; + expect(txns).toHaveLength(2); + + const prompt = txns.find((t) => t.tokenType === 'prompt'); + expect(prompt!.inputTokens).toBe(-100); + expect(prompt!.writeTokens).toBe(-50); + expect(prompt!.readTokens).toBe(-200); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal!.tokenCredits).toBeLessThan(5000); + }); +}); diff --git a/packages/api/src/agents/transactions.ts b/packages/api/src/agents/transactions.ts new file mode 100644 index 0000000000..b746392b44 --- /dev/null +++ b/packages/api/src/agents/transactions.ts @@ -0,0 +1,345 @@ +import { CANCEL_RATE } from '@librechat/data-schemas'; +import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider'; +import type { TransactionData } from '@librechat/data-schemas'; +import type { EndpointTokenConfig } from '~/types/tokens'; + +interface GetMultiplierParams { + valueKey?: string; + tokenType?: string; + model?: string; + endpointTokenConfig?: EndpointTokenConfig; + inputTokenCount?: number; +} + +interface GetCacheMultiplierParams { + cacheType: 'write' | 'read'; + model?: string; + endpointTokenConfig?: EndpointTokenConfig; +} + +export interface PricingFns { + getMultiplier: (params: GetMultiplierParams) => number; + getCacheMultiplier: (params: GetCacheMultiplierParams) => number | null; +} + +interface BaseTxData { + user: string; + model?: string; + context: string; + messageId?: string; + conversationId: string; + endpointTokenConfig?: EndpointTokenConfig; + balance?: Partial | null; + transactions?: Partial; +} + +interface StandardTxData extends BaseTxData { + tokenType: string; + rawAmount: number; + inputTokenCount?: number; + valueKey?: string; +} + +interface StructuredTxData extends BaseTxData { + tokenType: string; + inputTokens?: number; + writeTokens?: number; + readTokens?: number; + inputTokenCount?: number; + rawAmount?: number; +} + +export interface PreparedEntry { + doc: TransactionData; + tokenValue: number; + balance?: Partial | null; +} + +export interface TokenUsage { + promptTokens?: number; + completionTokens?: number; +} + +export interface StructuredPromptTokens { + input?: number; + write?: number; + read?: number; +} + +export interface StructuredTokenUsage { + promptTokens?: StructuredPromptTokens; + completionTokens?: number; +} + +export interface TxMetadata { + user: string; + model?: string; + context: string; + messageId?: string; + conversationId: string; + balance?: Partial | null; + transactions?: Partial; + endpointTokenConfig?: EndpointTokenConfig; +} + +export interface BulkWriteDeps { + insertMany: (docs: TransactionData[]) => Promise; + updateBalance: (params: { user: string; incrementValue: number }) => Promise; +} + +function calculateTokenValue( + txData: StandardTxData, + pricing: PricingFns, +): { tokenValue: number; rate: number } { + const { tokenType, model, endpointTokenConfig, inputTokenCount, rawAmount, valueKey } = txData; + const multiplier = Math.abs( + pricing.getMultiplier({ valueKey, tokenType, model, endpointTokenConfig, inputTokenCount }), + ); + let rate = multiplier; + let tokenValue = rawAmount * multiplier; + if (txData.context === 'incomplete' && tokenType === 'completion') { + tokenValue = Math.ceil(tokenValue * CANCEL_RATE); + rate *= CANCEL_RATE; + } + return { tokenValue, rate }; +} + +function calculateStructuredTokenValue( + txData: StructuredTxData, + pricing: PricingFns, +): { tokenValue: number; rate: number; rawAmount: number; rateDetail?: Record } { + const { tokenType, model, endpointTokenConfig, inputTokenCount } = txData; + + if (!tokenType) { + return { tokenValue: txData.rawAmount ?? 0, rate: 0, rawAmount: txData.rawAmount ?? 0 }; + } + + if (tokenType === 'prompt') { + const inputMultiplier = pricing.getMultiplier({ + tokenType: 'prompt', + model, + endpointTokenConfig, + inputTokenCount, + }); + const writeMultiplier = + pricing.getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? + inputMultiplier; + const readMultiplier = + pricing.getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? + inputMultiplier; + + const inputAbs = Math.abs(txData.inputTokens ?? 0); + const writeAbs = Math.abs(txData.writeTokens ?? 0); + const readAbs = Math.abs(txData.readTokens ?? 0); + const totalPromptTokens = inputAbs + writeAbs + readAbs; + + const rate = + totalPromptTokens > 0 + ? (Math.abs(inputMultiplier * (txData.inputTokens ?? 0)) + + Math.abs(writeMultiplier * (txData.writeTokens ?? 0)) + + Math.abs(readMultiplier * (txData.readTokens ?? 0))) / + totalPromptTokens + : Math.abs(inputMultiplier); + + const tokenValue = -( + inputAbs * inputMultiplier + + writeAbs * writeMultiplier + + readAbs * readMultiplier + ); + + return { + tokenValue, + rate, + rawAmount: -totalPromptTokens, + rateDetail: { input: inputMultiplier, write: writeMultiplier, read: readMultiplier }, + }; + } + + const multiplier = pricing.getMultiplier({ + tokenType, + model, + endpointTokenConfig, + inputTokenCount, + }); + const rawAmount = -Math.abs(txData.rawAmount ?? 0); + let rate = Math.abs(multiplier); + let tokenValue = rawAmount * multiplier; + + if (txData.context === 'incomplete' && tokenType === 'completion') { + tokenValue = Math.ceil(tokenValue * CANCEL_RATE); + rate *= CANCEL_RATE; + } + + return { tokenValue, rate, rawAmount }; +} + +function prepareStandardTx( + _txData: StandardTxData & { + balance?: Partial | null; + transactions?: Partial; + }, + pricing: PricingFns, +): PreparedEntry | null { + const { balance, transactions, ...txData } = _txData; + if (txData.rawAmount != null && isNaN(txData.rawAmount)) { + return null; + } + if (transactions?.enabled === false) { + return null; + } + + const { tokenValue, rate } = calculateTokenValue(txData, pricing); + return { + doc: { ...txData, tokenValue, rate }, + tokenValue, + balance, + }; +} + +function prepareStructuredTx( + _txData: StructuredTxData & { + balance?: Partial | null; + transactions?: Partial; + }, + pricing: PricingFns, +): PreparedEntry | null { + const { balance, transactions, ...txData } = _txData; + if (transactions?.enabled === false) { + return null; + } + + const { tokenValue, rate, rawAmount, rateDetail } = calculateStructuredTokenValue( + txData, + pricing, + ); + return { + doc: { + ...txData, + tokenValue, + rate, + rawAmount, + ...(rateDetail && { rateDetail }), + }, + tokenValue, + balance, + }; +} + +export function prepareTokenSpend( + txData: TxMetadata, + tokenUsage: TokenUsage, + pricing: PricingFns, +): PreparedEntry[] { + const { promptTokens, completionTokens } = tokenUsage; + const results: PreparedEntry[] = []; + const normalizedPromptTokens = Math.max(promptTokens ?? 0, 0); + + if (promptTokens !== undefined) { + const entry = prepareStandardTx( + { + ...txData, + tokenType: 'prompt', + rawAmount: promptTokens === 0 ? 0 : -normalizedPromptTokens, + inputTokenCount: normalizedPromptTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + if (completionTokens !== undefined) { + const entry = prepareStandardTx( + { + ...txData, + tokenType: 'completion', + rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0), + inputTokenCount: normalizedPromptTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + return results; +} + +export function prepareStructuredTokenSpend( + txData: TxMetadata, + tokenUsage: StructuredTokenUsage, + pricing: PricingFns, +): PreparedEntry[] { + const { promptTokens, completionTokens } = tokenUsage; + const results: PreparedEntry[] = []; + + if (promptTokens) { + const input = Math.max(promptTokens.input ?? 0, 0); + const write = Math.max(promptTokens.write ?? 0, 0); + const read = Math.max(promptTokens.read ?? 0, 0); + const totalInputTokens = input + write + read; + const entry = prepareStructuredTx( + { + ...txData, + tokenType: 'prompt', + inputTokens: -input, + writeTokens: -write, + readTokens: -read, + inputTokenCount: totalInputTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + if (completionTokens) { + const totalInputTokens = promptTokens + ? Math.max(promptTokens.input ?? 0, 0) + + Math.max(promptTokens.write ?? 0, 0) + + Math.max(promptTokens.read ?? 0, 0) + : undefined; + const entry = prepareStandardTx( + { + ...txData, + tokenType: 'completion', + rawAmount: -Math.max(completionTokens, 0), + inputTokenCount: totalInputTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + return results; +} + +export async function bulkWriteTransactions( + { user, docs }: { user: string; docs: PreparedEntry[] }, + dbOps: BulkWriteDeps, +): Promise { + if (!docs.length) { + return; + } + + let totalTokenValue = 0; + let balanceEnabled = false; + const plainDocs = docs.map(({ doc, tokenValue, balance }) => { + if (balance?.enabled) { + balanceEnabled = true; + totalTokenValue += tokenValue; + } + return doc; + }); + + if (balanceEnabled) { + await dbOps.updateBalance({ user, incrementValue: totalTokenValue }); + } + + await dbOps.insertMany(plainDocs); +} diff --git a/packages/api/src/agents/usage.bulk-parity.spec.ts b/packages/api/src/agents/usage.bulk-parity.spec.ts new file mode 100644 index 0000000000..79dd50b2e3 --- /dev/null +++ b/packages/api/src/agents/usage.bulk-parity.spec.ts @@ -0,0 +1,533 @@ +/** + * Bulk path parity tests for recordCollectedUsage. + * + * Every test here mirrors a corresponding legacy-path test in usage.spec.ts. + * The return values (input_tokens, output_tokens) must be identical between paths. + * The docs written to insertMany must carry the same metadata as the args that + * would have been passed to spendTokens/spendStructuredTokens. + */ +import type { UsageMetadata } from '../stream/interfaces/IJobStore'; +import type { RecordUsageDeps, RecordUsageParams } from './usage'; +import type { BulkWriteDeps, PricingFns } from './transactions'; +import { recordCollectedUsage } from './usage'; + +describe('recordCollectedUsage — bulk path parity', () => { + let mockSpendTokens: jest.Mock; + let mockSpendStructuredTokens: jest.Mock; + let mockInsertMany: jest.Mock; + let mockUpdateBalance: jest.Mock; + let mockPricing: PricingFns; + let mockBulkWriteOps: BulkWriteDeps; + let deps: RecordUsageDeps; + + const baseParams: Omit = { + user: 'user-123', + conversationId: 'convo-123', + model: 'gpt-4', + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + }; + + beforeEach(() => { + jest.clearAllMocks(); + mockSpendTokens = jest.fn().mockResolvedValue(undefined); + mockSpendStructuredTokens = jest.fn().mockResolvedValue(undefined); + mockInsertMany = jest.fn().mockResolvedValue(undefined); + mockUpdateBalance = jest.fn().mockResolvedValue({}); + mockPricing = { + getMultiplier: jest.fn().mockReturnValue(1), + getCacheMultiplier: jest.fn().mockReturnValue(null), + }; + mockBulkWriteOps = { + insertMany: mockInsertMany, + updateBalance: mockUpdateBalance, + }; + deps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: mockPricing, + bulkWriteOps: mockBulkWriteOps, + }; + }); + + describe('basic functionality', () => { + it('should return undefined if collectedUsage is empty', async () => { + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage: [] }); + expect(result).toBeUndefined(); + expect(mockInsertMany).not.toHaveBeenCalled(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + + it('should return undefined if collectedUsage is null-ish', async () => { + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage: null as unknown as UsageMetadata[], + }); + expect(result).toBeUndefined(); + expect(mockInsertMany).not.toHaveBeenCalled(); + }); + + it('should handle single usage entry — same return value as legacy path', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockInsertMany).toHaveBeenCalledTimes(1); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(2); + const promptDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'prompt'); + const completionDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'completion'); + expect(promptDoc.user).toBe('user-123'); + expect(promptDoc.conversationId).toBe('convo-123'); + expect(promptDoc.model).toBe('gpt-4'); + expect(promptDoc.context).toBe('message'); + expect(promptDoc.rawAmount).toBe(-100); + expect(completionDoc.rawAmount).toBe(-50); + }); + + it('should skip null entries — same return value as legacy path', async () => { + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + null, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + ] as UsageMetadata[]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 110 }); + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(4); // 2 non-null entries × 2 docs each + }); + }); + + describe('sequential execution (tool calls)', () => { + it('should calculate tokens correctly for sequential tool calls — same totals as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 150, output_tokens: 30, model: 'gpt-4' }, + { input_tokens: 180, output_tokens: 20, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(100); // 50 + 30 + 20 + expect(result?.input_tokens).toBe(100); // first entry's input + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(6); // 3 entries × 2 docs + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('parallel execution (multiple agents)', () => { + it('should handle parallel agents — same output_tokens total as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 80, output_tokens: 40, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(90); // 50 + 40 + expect(result?.output_tokens).toBeGreaterThan(0); + expect(mockInsertMany).toHaveBeenCalledTimes(1); + }); + + /** Bug regression: parallel agents where second agent has LOWER input tokens produced negative output via incremental calculation. */ + it('should NOT produce negative output_tokens — same positive result as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 200, output_tokens: 100, model: 'gpt-4' }, + { input_tokens: 50, output_tokens: 30, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBeGreaterThan(0); + expect(result?.output_tokens).toBe(130); // 100 + 30 + }); + + it('should calculate correct total output for 3 parallel agents', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' }, + { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(150); // 50 + 60 + 40 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(6); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('cache token handling - OpenAI format', () => { + it('should route cache entries to structured path — same input_tokens as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', + input_token_details: { cache_creation: 20, cache_read: 10 }, + }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(130); // 100 + 20 + 10 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + + const docs = mockInsertMany.mock.calls[0][0]; + const promptDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'prompt'); + expect(promptDoc.inputTokens).toBe(-100); + expect(promptDoc.writeTokens).toBe(-20); + expect(promptDoc.readTokens).toBe(-10); + expect(promptDoc.model).toBe('gpt-4'); + }); + }); + + describe('cache token handling - Anthropic format', () => { + it('should route Anthropic cache entries to structured path — same input_tokens as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'claude-3', + cache_creation_input_tokens: 25, + cache_read_input_tokens: 15, + }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(140); // 100 + 25 + 15 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + + const docs = mockInsertMany.mock.calls[0][0]; + const promptDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'prompt'); + expect(promptDoc.inputTokens).toBe(-100); + expect(promptDoc.writeTokens).toBe(-25); + expect(promptDoc.readTokens).toBe(-15); + expect(promptDoc.model).toBe('claude-3'); + }); + }); + + describe('mixed cache and non-cache entries', () => { + it('should handle mixed entries — same output_tokens as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + { input_tokens: 200, output_tokens: 20, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(100); // 50 + 30 + 20 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(6); // 3 entries × 2 docs each + }); + }); + + describe('model fallback', () => { + it('should use usage.model when available — model lands in doc', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: 'fallback-model', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].model).toBe('gpt-4-turbo'); + }); + + it('should fallback to param model when usage.model is missing — model lands in doc', async () => { + const collectedUsage: UsageMetadata[] = [{ input_tokens: 100, output_tokens: 50 }]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: 'param-model', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].model).toBe('param-model'); + }); + + it('should fallback to undefined model when both usage.model and param model are missing', async () => { + const collectedUsage: UsageMetadata[] = [{ input_tokens: 100, output_tokens: 50 }]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: undefined, + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].model).toBeUndefined(); + }); + }); + + describe('real-world scenarios', () => { + it('should correctly sum output tokens for sequential tool calls with growing context', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 31596, output_tokens: 151, model: 'claude-opus' }, + { input_tokens: 35368, output_tokens: 150, model: 'claude-opus' }, + { input_tokens: 58362, output_tokens: 295, model: 'claude-opus' }, + { input_tokens: 112604, output_tokens: 193, model: 'claude-opus' }, + { input_tokens: 257440, output_tokens: 2217, model: 'claude-opus' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(31596); + expect(result?.output_tokens).toBe(3006); // 151+150+295+193+2217 + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(10); // 5 entries × 2 docs + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + + it('should handle cache tokens with multiple tool calls — same totals as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 788, + output_tokens: 163, + model: 'claude-opus', + input_token_details: { cache_read: 0, cache_creation: 30808 }, + }, + { + input_tokens: 3802, + output_tokens: 149, + model: 'claude-opus', + input_token_details: { cache_read: 30808, cache_creation: 768 }, + }, + { + input_tokens: 26808, + output_tokens: 225, + model: 'claude-opus', + input_token_details: { cache_read: 31576, cache_creation: 0 }, + }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(31596); // 788 + 30808 + 0 + expect(result?.output_tokens).toBe(537); // 163 + 149 + 225 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('error handling', () => { + it('should catch bulk write errors — still returns correct result', async () => { + mockInsertMany.mockRejectedValue(new Error('DB error')); + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + }); + + describe('transaction metadata — doc fields match what legacy would pass to spendTokens', () => { + it('should pass all metadata fields to docs', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + const endpointTokenConfig = { 'gpt-4': { prompt: 0.01, completion: 0.03, context: 8192 } }; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-123', + endpointTokenConfig, + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.user).toBe('user-123'); + expect(doc.conversationId).toBe('convo-123'); + expect(doc.model).toBe('gpt-4'); + expect(doc.context).toBe('message'); + expect(doc.messageId).toBe('msg-123'); + } + }); + + it('should use default context "message" when not provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].context).toBe('message'); + }); + + it('should allow custom context like "title"', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + context: 'title', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].context).toBe('title'); + }); + }); + + describe('messageId propagation — messageId on every doc', () => { + it('should propagate messageId to all docs', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 10, output_tokens: 5, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-1', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.messageId).toBe('msg-1'); + } + }); + + it('should propagate messageId to structured cache docs', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'claude-3', + cache_creation_input_tokens: 25, + cache_read_input_tokens: 15, + }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-cache-1', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.messageId).toBe('msg-cache-1'); + } + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + + it('should pass undefined messageId when not provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 10, output_tokens: 5, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].messageId).toBeUndefined(); + }); + + it('should propagate messageId across all entries in a multi-entry batch', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-multi', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.messageId).toBe('msg-multi'); + } + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + }); + + describe('balance behavior parity', () => { + it('should not call updateBalance when balance is disabled — same as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + balance: { enabled: false }, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + + it('should not insert docs when transactions are disabled — same as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + transactions: { enabled: false }, + collectedUsage, + }); + + expect(mockInsertMany).not.toHaveBeenCalled(); + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/api/src/agents/usage.spec.ts b/packages/api/src/agents/usage.spec.ts index 1937af5011..d0b065b8ff 100644 --- a/packages/api/src/agents/usage.spec.ts +++ b/packages/api/src/agents/usage.spec.ts @@ -1,6 +1,7 @@ -import { recordCollectedUsage } from './usage'; -import type { RecordUsageDeps, RecordUsageParams } from './usage'; import type { UsageMetadata } from '../stream/interfaces/IJobStore'; +import type { RecordUsageDeps, RecordUsageParams } from './usage'; +import type { BulkWriteDeps, PricingFns } from './transactions'; +import { recordCollectedUsage } from './usage'; describe('recordCollectedUsage', () => { let mockSpendTokens: jest.Mock; @@ -522,4 +523,199 @@ describe('recordCollectedUsage', () => { ); }); }); + + describe('bulk write path', () => { + let mockInsertMany: jest.Mock; + let mockUpdateBalance: jest.Mock; + let mockPricing: PricingFns; + let mockBulkWriteOps: BulkWriteDeps; + let bulkDeps: RecordUsageDeps; + + beforeEach(() => { + mockInsertMany = jest.fn().mockResolvedValue(undefined); + mockUpdateBalance = jest.fn().mockResolvedValue({}); + mockPricing = { + getMultiplier: jest.fn().mockReturnValue(1), + getCacheMultiplier: jest.fn().mockReturnValue(null), + }; + mockBulkWriteOps = { + insertMany: mockInsertMany, + updateBalance: mockUpdateBalance, + }; + bulkDeps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: mockPricing, + bulkWriteOps: mockBulkWriteOps, + }; + }); + + it('should use bulk path when pricing and bulkWriteOps are provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + + it('should batch all entries into a single insertMany call', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + { input_tokens: 300, output_tokens: 70, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const insertedDocs = mockInsertMany.mock.calls[0][0]; + expect(insertedDocs.length).toBe(6); // 2 per entry (prompt + completion) + }); + + it('should call updateBalance once when balance is enabled', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(bulkDeps, { + ...baseParams, + balance: { enabled: true }, + collectedUsage, + }); + + expect(mockUpdateBalance).toHaveBeenCalledTimes(1); + expect(mockUpdateBalance).toHaveBeenCalledWith( + expect.objectContaining({ + user: 'user-123', + incrementValue: expect.any(Number), + }), + ); + }); + + it('should not call updateBalance when balance is disabled', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(bulkDeps, { + ...baseParams, + balance: { enabled: false }, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + + it('should handle cache tokens via bulk path', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', + input_token_details: { cache_creation: 20, cache_read: 10 }, + }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(result).toBeDefined(); + }); + + it('should handle mixed cache and non-cache entries in bulk', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(result?.output_tokens).toBe(80); + }); + + it('should fall back to legacy path when pricing is missing', async () => { + const legacyDeps: RecordUsageDeps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + bulkWriteOps: mockBulkWriteOps, + // no pricing + }; + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(legacyDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(1); + expect(mockInsertMany).not.toHaveBeenCalled(); + }); + + it('should fall back to legacy path when bulkWriteOps is missing', async () => { + const legacyDeps: RecordUsageDeps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: mockPricing, + // no bulkWriteOps + }; + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(legacyDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(1); + expect(mockInsertMany).not.toHaveBeenCalled(); + }); + + it('should handle errors in bulk write gracefully', async () => { + mockInsertMany.mockRejectedValue(new Error('DB error')); + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + }); }); diff --git a/packages/api/src/agents/usage.ts b/packages/api/src/agents/usage.ts index 351452d698..c092702730 100644 --- a/packages/api/src/agents/usage.ts +++ b/packages/api/src/agents/usage.ts @@ -1,34 +1,20 @@ import { logger } from '@librechat/data-schemas'; import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider'; -import type { UsageMetadata } from '../stream/interfaces/IJobStore'; -import type { EndpointTokenConfig } from '../types/tokens'; - -interface TokenUsage { - promptTokens?: number; - completionTokens?: number; -} - -interface StructuredPromptTokens { - input?: number; - write?: number; - read?: number; -} - -interface StructuredTokenUsage { - promptTokens?: StructuredPromptTokens; - completionTokens?: number; -} - -interface TxMetadata { - user: string; - model?: string; - context: string; - messageId?: string; - conversationId: string; - balance?: Partial | null; - transactions?: Partial; - endpointTokenConfig?: EndpointTokenConfig; -} +import type { + StructuredTokenUsage, + BulkWriteDeps, + PreparedEntry, + TxMetadata, + TokenUsage, + PricingFns, +} from './transactions'; +import type { UsageMetadata } from '~/stream/interfaces/IJobStore'; +import type { EndpointTokenConfig } from '~/types/tokens'; +import { + prepareStructuredTokenSpend, + bulkWriteTransactions, + prepareTokenSpend, +} from './transactions'; type SpendTokensFn = (txData: TxMetadata, tokenUsage: TokenUsage) => Promise; type SpendStructuredTokensFn = ( @@ -39,6 +25,8 @@ type SpendStructuredTokensFn = ( export interface RecordUsageDeps { spendTokens: SpendTokensFn; spendStructuredTokens: SpendStructuredTokensFn; + pricing?: PricingFns; + bulkWriteOps?: BulkWriteDeps; } export interface RecordUsageParams { @@ -61,6 +49,9 @@ export interface RecordUsageResult { /** * Records token usage for collected LLM calls and spends tokens against balance. * This handles both sequential execution (tool calls) and parallel execution (multiple agents). + * + * When `pricing` and `bulkWriteOps` deps are provided, prepares all transaction documents + * in-memory first, then writes them in a single `insertMany` + one `updateBalance` call. */ export async function recordCollectedUsage( deps: RecordUsageDeps, @@ -78,8 +69,6 @@ export async function recordCollectedUsage( context = 'message', } = params; - const { spendTokens, spendStructuredTokens } = deps; - if (!collectedUsage || !collectedUsage.length) { return; } @@ -96,6 +85,11 @@ export async function recordCollectedUsage( let total_output_tokens = 0; + const { pricing, bulkWriteOps } = deps; + const useBulk = pricing && bulkWriteOps; + + const allDocs: PreparedEntry[] = []; + for (const usage of collectedUsage) { if (!usage) { continue; @@ -121,26 +115,68 @@ export async function recordCollectedUsage( model: usage.model ?? model, }; - if (cache_creation > 0 || cache_read > 0) { - spendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[packages/api #recordCollectedUsage] Error spending structured tokens', err); - }); + if (useBulk) { + const entries = + cache_creation > 0 || cache_read > 0 + ? prepareStructuredTokenSpend( + txMetadata, + { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }, + pricing, + ) + : prepareTokenSpend( + txMetadata, + { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }, + pricing, + ); + allDocs.push(...entries); continue; } - spendTokens(txMetadata, { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[packages/api #recordCollectedUsage] Error spending tokens', err); - }); + if (cache_creation > 0 || cache_read > 0) { + deps + .spendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }) + .catch((err) => { + logger.error( + '[packages/api #recordCollectedUsage] Error spending structured tokens', + err, + ); + }); + continue; + } + + deps + .spendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }) + .catch((err) => { + logger.error('[packages/api #recordCollectedUsage] Error spending tokens', err); + }); + } + + if (useBulk && allDocs.length > 0) { + try { + await bulkWriteTransactions({ user, docs: allDocs }, bulkWriteOps); + } catch (err) { + logger.error('[packages/api #recordCollectedUsage] Error in bulk write', err); + } } return { diff --git a/packages/api/src/endpoints/openai/config.spec.ts b/packages/api/src/endpoints/openai/config.spec.ts index 2bbe123e63..58b471aa66 100644 --- a/packages/api/src/endpoints/openai/config.spec.ts +++ b/packages/api/src/endpoints/openai/config.spec.ts @@ -872,9 +872,8 @@ describe('getOpenAIConfig', () => { modelOptions, }); - // OpenRouter reasoning object should only include effort, not summary - expect(result.llmConfig.reasoning).toEqual({ - effort: ReasoningEffort.high, + expect(result.llmConfig.modelKwargs).toMatchObject({ + reasoning: { effort: ReasoningEffort.high }, }); expect(result.llmConfig.include_reasoning).toBeUndefined(); expect(result.provider).toBe('openrouter'); @@ -1206,13 +1205,13 @@ describe('getOpenAIConfig', () => { model: 'gpt-4-turbo', temperature: 0.8, streaming: false, - reasoning: { effort: ReasoningEffort.high }, // OpenRouter reasoning object }); expect(result.llmConfig.include_reasoning).toBeUndefined(); // Should NOT have useResponsesApi for OpenRouter expect(result.llmConfig.useResponsesApi).toBeUndefined(); expect(result.llmConfig.maxTokens).toBe(2000); expect(result.llmConfig.modelKwargs).toEqual({ + reasoning: { effort: ReasoningEffort.high }, verbosity: Verbosity.medium, customParam: 'custom-value', plugins: [{ id: 'web' }], // OpenRouter web search format @@ -1482,13 +1481,11 @@ describe('getOpenAIConfig', () => { user: 'openrouter-user', temperature: 0.7, maxTokens: 4000, - reasoning: { - effort: ReasoningEffort.high, - }, apiKey: apiKey, }); expect(result.llmConfig.include_reasoning).toBeUndefined(); expect(result.llmConfig.modelKwargs).toMatchObject({ + reasoning: { effort: ReasoningEffort.high }, top_k: 50, repetition_penalty: 1.1, }); diff --git a/packages/api/src/endpoints/openai/llm.spec.ts b/packages/api/src/endpoints/openai/llm.spec.ts index 3c10179737..a78cc4b87d 100644 --- a/packages/api/src/endpoints/openai/llm.spec.ts +++ b/packages/api/src/endpoints/openai/llm.spec.ts @@ -393,7 +393,9 @@ describe('getOpenAILLMConfig', () => { }, }); - expect(result.llmConfig).toHaveProperty('reasoning', { effort: ReasoningEffort.high }); + expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', { + effort: ReasoningEffort.high, + }); expect(result.llmConfig).not.toHaveProperty('include_reasoning'); expect(result.llmConfig.modelKwargs).toHaveProperty('plugins', [{ id: 'web' }]); }); @@ -617,7 +619,9 @@ describe('getOpenAILLMConfig', () => { }, }); - expect(result.llmConfig).toHaveProperty('reasoning', { effort: ReasoningEffort.high }); + expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', { + effort: ReasoningEffort.high, + }); expect(result.llmConfig).not.toHaveProperty('include_reasoning'); expect(result.llmConfig).not.toHaveProperty('reasoning_effort'); }); @@ -634,7 +638,9 @@ describe('getOpenAILLMConfig', () => { }, }); - expect(result.llmConfig).toHaveProperty('reasoning', { effort: ReasoningEffort.high }); + expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', { + effort: ReasoningEffort.high, + }); }); it.each([ReasoningEffort.xhigh, ReasoningEffort.minimal, ReasoningEffort.none])( @@ -650,7 +656,7 @@ describe('getOpenAILLMConfig', () => { }, }); - expect(result.llmConfig).toHaveProperty('reasoning', { effort }); + expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', { effort }); expect(result.llmConfig).not.toHaveProperty('include_reasoning'); }, ); diff --git a/packages/api/src/endpoints/openai/llm.ts b/packages/api/src/endpoints/openai/llm.ts index c659645958..a89f6fce44 100644 --- a/packages/api/src/endpoints/openai/llm.ts +++ b/packages/api/src/endpoints/openai/llm.ts @@ -1,7 +1,6 @@ import { EModelEndpoint, removeNullishValues } from 'librechat-data-provider'; import type { BindToolsInput } from '@langchain/core/language_models/chat_models'; import type { SettingDefinition } from 'librechat-data-provider'; -import type { OpenRouterReasoning } from '@librechat/agents'; import type { AzureOpenAIInput } from '@langchain/openai'; import type { OpenAI } from 'openai'; import type * as t from '~/types'; @@ -231,7 +230,8 @@ export function getOpenAILLMConfig({ * `include_reasoning` is legacy compat that maps to `{ enabled: true }` only when * no `reasoning` object is present, so we intentionally omit it here. */ - llmConfig.reasoning = { effort: reasoning_effort } as OpenRouterReasoning; + modelKwargs.reasoning = { effort: reasoning_effort }; + hasModelKwargs = true; } else { /** No explicit effort; fall back to legacy `include_reasoning` for reasoning token inclusion */ llmConfig.include_reasoning = true; diff --git a/packages/data-schemas/src/methods/index.ts b/packages/data-schemas/src/methods/index.ts index 2f20b67fec..07e7cefc24 100644 --- a/packages/data-schemas/src/methods/index.ts +++ b/packages/data-schemas/src/methods/index.ts @@ -21,6 +21,7 @@ import { createAccessRoleMethods, type AccessRoleMethods } from './accessRole'; import { createUserGroupMethods, type UserGroupMethods } from './userGroup'; import { createAclEntryMethods, type AclEntryMethods } from './aclEntry'; import { createShareMethods, type ShareMethods } from './share'; +import { createTransactionMethods, type TransactionMethods } from './transaction'; export type AllMethods = UserMethods & SessionMethods & @@ -36,7 +37,8 @@ export type AllMethods = UserMethods & AclEntryMethods & ShareMethods & AccessRoleMethods & - PluginAuthMethods; + PluginAuthMethods & + TransactionMethods; /** * Creates all database methods for all collections @@ -59,6 +61,7 @@ export function createMethods(mongoose: typeof import('mongoose')): AllMethods { ...createAclEntryMethods(mongoose), ...createShareMethods(mongoose), ...createPluginAuthMethods(mongoose), + ...createTransactionMethods(mongoose), }; } @@ -78,4 +81,5 @@ export type { ShareMethods, AccessRoleMethods, PluginAuthMethods, + TransactionMethods, }; diff --git a/packages/data-schemas/src/methods/transaction.ts b/packages/data-schemas/src/methods/transaction.ts new file mode 100644 index 0000000000..d521b9e85e --- /dev/null +++ b/packages/data-schemas/src/methods/transaction.ts @@ -0,0 +1,100 @@ +import type { IBalance, TransactionData } from '~/types'; +import logger from '~/config/winston'; + +interface UpdateBalanceParams { + user: string; + incrementValue: number; + setValues?: Partial>; +} + +export function createTransactionMethods(mongoose: typeof import('mongoose')) { + async function updateBalance({ user, incrementValue, setValues }: UpdateBalanceParams) { + const maxRetries = 10; + let delay = 50; + let lastError: Error | null = null; + const Balance = mongoose.models.Balance; + + for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + const currentBalanceDoc = await Balance.findOne({ user }).lean(); + const currentCredits = currentBalanceDoc?.tokenCredits ?? 0; + const newCredits = Math.max(0, currentCredits + incrementValue); + + const updatePayload = { + $set: { + tokenCredits: newCredits, + ...(setValues ?? {}), + }, + }; + + if (currentBalanceDoc) { + const updatedBalance = await Balance.findOneAndUpdate( + { user, tokenCredits: currentCredits }, + updatePayload, + { new: true }, + ).lean(); + + if (updatedBalance) { + return updatedBalance; + } + lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`); + } else { + try { + const updatedBalance = await Balance.findOneAndUpdate({ user }, updatePayload, { + upsert: true, + new: true, + }).lean(); + + if (updatedBalance) { + return updatedBalance; + } + lastError = new Error( + `Upsert race condition suspected for user ${user} on attempt ${attempt}.`, + ); + } catch (error: unknown) { + if ( + error instanceof Error && + 'code' in error && + (error as { code: number }).code === 11000 + ) { + lastError = error; + } else { + throw error; + } + } + } + } catch (error) { + logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error); + lastError = error instanceof Error ? error : new Error(String(error)); + } + + if (attempt < maxRetries) { + const jitter = Math.random() * delay * 0.5; + await new Promise((resolve) => setTimeout(resolve, delay + jitter)); + delay = Math.min(delay * 2, 2000); + } + } + + logger.error( + `[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`, + ); + throw ( + lastError ?? + new Error( + `Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`, + ) + ); + } + + /** Bypasses document middleware; all computed fields must be pre-calculated before calling. */ + async function bulkInsertTransactions(docs: TransactionData[]): Promise { + const Transaction = mongoose.models.Transaction; + if (docs.length) { + await Transaction.insertMany(docs); + } + } + + return { updateBalance, bulkInsertTransactions }; +} + +export type TransactionMethods = ReturnType; diff --git a/packages/data-schemas/src/types/index.ts b/packages/data-schemas/src/types/index.ts index 38f9f22b50..d467d99d21 100644 --- a/packages/data-schemas/src/types/index.ts +++ b/packages/data-schemas/src/types/index.ts @@ -8,6 +8,7 @@ export * from './convo'; export * from './session'; export * from './balance'; export * from './banner'; +export * from './transaction'; export * from './message'; export * from './agent'; export * from './agentApiKey'; diff --git a/packages/data-schemas/src/types/transaction.ts b/packages/data-schemas/src/types/transaction.ts new file mode 100644 index 0000000000..978d7fd62b --- /dev/null +++ b/packages/data-schemas/src/types/transaction.ts @@ -0,0 +1,17 @@ +export interface TransactionData { + user: string; + conversationId: string; + tokenType: string; + model?: string; + context?: string; + valueKey?: string; + rate?: number; + rawAmount?: number; + tokenValue?: number; + inputTokens?: number; + writeTokens?: number; + readTokens?: number; + messageId?: string; + inputTokenCount?: number; + rateDetail?: Record; +} diff --git a/packages/data-schemas/src/utils/transactions.ts b/packages/data-schemas/src/utils/transactions.ts index 09bbb040c1..26f1f77e7e 100644 --- a/packages/data-schemas/src/utils/transactions.ts +++ b/packages/data-schemas/src/utils/transactions.ts @@ -1,5 +1,7 @@ import logger from '~/config/winston'; +export const CANCEL_RATE = 1.15; + /** * Checks if the connected MongoDB deployment supports transactions * This requires a MongoDB replica set configuration