From e1e204d6cffca9b174ef5511746972b990ef8206 Mon Sep 17 00:00:00 2001
From: Danny Avila
Date: Sun, 1 Mar 2026 12:26:36 -0500
Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AE=20refactor:=20Bulk=20Transactions?=
 =?UTF-8?q?=20&=20Balance=20Updates=20for=20Token=20Spending=20(#11996)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor: transaction handling by integrating pricing and bulk write operations

- Updated `recordCollectedUsage` to accept pricing functions and bulk write
  operations, improving transaction management.
- Refactored `AgentClient` and related controllers to utilize the new
  transaction handling capabilities, ensuring better performance and accuracy
  in token spending.
- Added tests to validate the new functionality, ensuring correct behavior
  for both standard and bulk transaction paths.
- Introduced a new `transactions.ts` file to encapsulate transaction-related
  logic and types, enhancing code organization and maintainability.

* chore: reorganize imports in agents client controller

- Moved `getMultiplier` and `getCacheMultiplier` imports to maintain
  consistency and clarity in the import structure.
- Removed duplicate import of `updateBalance` and `bulkInsertTransactions`,
  streamlining the code for better readability.

* refactor: add TransactionData type and CANCEL_RATE constant to data-schemas

Establishes a single source of truth for the transaction document shape and
the incomplete-context billing rate constant, both consumed by packages/api
and api/.
* refactor: use proper types in data-schemas transaction methods

- Replace `as unknown as { tokenCredits }` with `lean()`
- Use `TransactionData[]` instead of `Record[]` for bulkInsertTransactions
  parameter
- Add JSDoc noting insertMany bypasses document middleware
- Remove orphan section comment in methods/index.ts

* refactor: use shared types in transactions.ts, fix bulk write logic

- Import CANCEL_RATE from data-schemas instead of local duplicate
- Import TransactionData from data-schemas for PreparedEntry/BulkWriteDeps
- Use tilde alias for EndpointTokenConfig import
- Pass valueKey through to getMultiplier
- Only sum tokenValue for balance-enabled docs in bulkWriteTransactions
- Consolidate two loops into single-pass map

* refactor: remove duplicate updateBalance from Transaction.js

Import updateBalance from ~/models (sourced from data-schemas) instead of
maintaining a second copy. Also import CANCEL_RATE from data-schemas and
remove the Balance model import (no longer needed directly).

* fix: test real spendCollectedUsage instead of IIFE replica

Export spendCollectedUsage from abortMiddleware.js and rewrite the test file
to import and test the actual function. Previously the tests ran against a
hand-written replica that could silently diverge from the real
implementation.

* test: add transactions.spec.ts and restore regression comments

Add 22 direct unit tests for transactions.ts financial logic covering
prepareTokenSpend, prepareStructuredTokenSpend, bulkWriteTransactions,
CANCEL_RATE paths, NaN guards, disabled transactions, zero tokens, cache
multipliers, and balance-enabled filtering. Restore critical regression
documentation comments in recordCollectedUsage.spec.js explaining which
production bugs the tests guard against.

* fix: widen setValues type to include lastRefill

The UpdateBalanceParams.setValues type was Partial> which excluded
lastRefill — used by createAutoRefillTransaction. Widen to also pick
'lastRefill'.
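The `updateBalance` being de-duplicated in the "remove duplicate updateBalance" entry implements optimistic concurrency control, as documented in the removed Transaction.js code later in this patch: read the current `tokenCredits`, compute the clamped new value, then issue a conditional update that only matches if `tokenCredits` still equals what was read, retrying with exponential backoff and jitter on conflict. A condensed sketch of that loop; an in-memory `Map` stands in for the Mongo `Balance` collection, and `compareAndSwap` is an illustrative stand-in for the real conditional `findOneAndUpdate`:

```typescript
// In-memory stand-in for the Balance collection.
const balances = new Map<string, number>();

// Conditional write: succeeds only if the stored value still equals
// `expected` — the analogue of findOneAndUpdate({ user, tokenCredits }).
function compareAndSwap(user: string, expected: number, next: number): boolean {
  const current = balances.get(user) ?? 0;
  if (current !== expected) {
    return false; // another writer changed the balance between read and write
  }
  balances.set(user, next);
  return true;
}

async function updateBalance(user: string, incrementValue: number): Promise<number> {
  const maxRetries = 10;
  let delay = 50; // initial retry delay in ms
  for (let attempt = 1; attempt <= maxRetries; attempt++) {
    const current = balances.get(user) ?? 0;
    const next = Math.max(0, current + incrementValue); // never below zero
    if (compareAndSwap(user, current, next)) {
      return next;
    }
    // Conflict: back off with jitter, then re-read and retry.
    const jitter = Math.random() * delay * 0.5;
    await new Promise((resolve) => setTimeout(resolve, delay + jitter));
    delay = Math.min(delay * 2, 2000); // exponential backoff with a 2s cap
  }
  throw new Error(`Failed to update balance for user ${user} after ${maxRetries} attempts.`);
}
```

The clamp to zero and the retry/backoff parameters mirror the removed implementation; the upsert branch and duplicate-key handling of the real code are omitted here for brevity.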
* test: use real MongoDB for bulkWriteTransactions tests

Replace mock-based bulkWriteTransactions tests with real DB tests using
MongoMemoryServer. Pure function tests (prepareTokenSpend,
prepareStructuredTokenSpend) remain mock-based since they don't touch the DB.
Add end-to-end integration tests that verify the full prepare → bulk write →
DB state pipeline with real Transaction and Balance models.

* chore: update @librechat/agents dependency to version 3.1.54 in
  package-lock.json and related package.json files

* test: add bulk path parity tests proving identical DB outcomes

Three test suites proving the bulk path (prepareTokenSpend/
prepareStructuredTokenSpend + bulkWriteTransactions) produces numerically
identical results to the legacy path for all scenarios:

- usage.bulk-parity.spec.ts: mirrors all legacy recordCollectedUsage tests;
  asserts same return values and verifies metadata fields on the insertMany
  docs match what spendTokens args would carry
- transactions.bulk-parity.spec.ts: real-DB tests using actual
  getMultiplier/getCacheMultiplier pricing functions; asserts exact
  tokenValue, rate, rawAmount and balance deductions for standard tokens,
  structured/cache tokens, CANCEL_RATE, premium pricing, multi-entry batches,
  and edge cases (NaN, zero, disabled)
- Transaction.spec.js: adds describe('Bulk path parity') that mirrors 7 key
  legacy tests via recordCollectedUsage + bulk deps against real MongoDB,
  asserting same balance deductions and doc counts

* refactor: update llmConfig structure to use modelKwargs for reasoning effort

Refactor the llmConfig in getOpenAILLMConfig to store reasoning effort within
modelKwargs instead of directly on llmConfig. This change ensures consistency
in the configuration structure and improves clarity in the handling of
reasoning properties in the tests.
* test: update performance checks in processAssistantMessage tests

Revise the performance assertions in the processAssistantMessage tests to
ensure that each message processing time remains under 100ms, addressing
potential ReDoS vulnerabilities. This change enhances the reliability of the
tests by focusing on maximum processing time rather than relative ratios.

* test: fill parity test gaps — model fallback, abort context, structured edge cases

- usage.bulk-parity: add undefined model fallback test
- transactions.bulk-parity: add abort context test (txns inserted, balance
  unchanged when balance not passed), fix readTokens type cast
- Transaction.spec: add 3 missing mirrors — balance disabled with
  transactions enabled, structured transactions disabled, structured balance
  disabled

* fix: deduct balance before inserting transactions to prevent orphaned docs

Swap the order in bulkWriteTransactions: updateBalance runs before
insertMany. If updateBalance fails (after exhausting retries), no transaction
documents are written — avoiding the inconsistent state where transactions
exist in MongoDB with no corresponding balance deduction.

* chore: import order

* test: update config.spec.ts for OpenRouter reasoning in modelKwargs

Same fix as llm.spec.ts — OpenRouter reasoning is now passed via modelKwargs
instead of llmConfig.reasoning directly.
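The ordering fix above ("deduct balance before inserting transactions") can be sketched as follows. The dependency shape loosely mirrors the `bulkWriteOps` described in the commit message (`updateBalance` + `insertMany`) and the "only sum tokenValue for balance-enabled docs" rule from the transactions.ts entry, but the exact types and field names here are assumptions for illustration, not the real transactions.ts signatures:

```typescript
// Hypothetical shapes, loosely modeled on the commit message; the real
// TransactionData / BulkWriteDeps types live in the patched packages.
interface TxnDoc {
  user: string;
  tokenValue: number; // negative for spends
  balanceEnabled: boolean; // only these docs count toward the deduction
}

interface BulkWriteDeps {
  updateBalance: (args: { user: string; incrementValue: number }) => Promise<void>;
  insertMany: (docs: TxnDoc[]) => Promise<void>;
}

// Deduct the balance FIRST, then insert the transaction documents.
// If updateBalance throws (e.g. after exhausting its retries), no
// transaction docs are written, so the DB never holds transactions
// with no matching balance deduction.
async function bulkWriteTransactions(deps: BulkWriteDeps, docs: TxnDoc[]): Promise<void> {
  if (docs.length === 0) {
    return;
  }
  // Single pass: only balance-enabled docs contribute to the increment.
  const increment = docs.reduce((sum, d) => (d.balanceEnabled ? sum + d.tokenValue : sum), 0);
  if (increment !== 0) {
    await deps.updateBalance({ user: docs[0].user, incrementValue: increment });
  }
  await deps.insertMany(docs);
}

// Tiny in-memory demo of the "no orphaned docs" property.
async function demo(): Promise<{ inserted: number; balance: number }> {
  let balance = 1000;
  const store: TxnDoc[] = [];
  const deps: BulkWriteDeps = {
    updateBalance: async ({ incrementValue }) => {
      balance = Math.max(0, balance + incrementValue);
    },
    insertMany: async (docs) => {
      store.push(...docs);
    },
  };
  await bulkWriteTransactions(deps, [
    { user: 'u1', tokenValue: -150, balanceEnabled: true },
    { user: 'u1', tokenValue: -50, balanceEnabled: true },
    { user: 'u1', tokenValue: -999, balanceEnabled: false }, // recorded, not billed
  ]);
  return { inserted: store.length, balance };
}

demo().then((r) => console.log(r.inserted, r.balance)); // 3 800
```

Running the sequence in the opposite order would leave the three documents inserted even when the balance write fails, which is exactly the inconsistent state the swap avoids.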
---
 api/models/Transaction.js                     | 149 +----
 api/models/Transaction.spec.js                | 340 ++++++++++-
 api/package.json                              |   2 +-
 .../agents/__tests__/openai.spec.js           |  29 +-
 .../agents/__tests__/responses.unit.spec.js   |  39 +-
 api/server/controllers/agents/client.js       |  95 +--
 api/server/controllers/agents/openai.js       |   8 +-
 .../agents/recordCollectedUsage.spec.js       | 574 ++++--------
 api/server/controllers/agents/responses.js    |  15 +-
 api/server/middleware/abortMiddleware.js      |  70 +--
 api/server/middleware/abortMiddleware.spec.js | 333 +++-------
 api/server/utils/import/importers.spec.js     |   9 +-
 package-lock.json                             |  10 +-
 packages/api/package.json                     |   2 +-
 packages/api/src/agents/index.ts              |   1 +
 .../agents/transactions.bulk-parity.spec.ts   | 559 +++++++++++++++++
 packages/api/src/agents/transactions.spec.ts  | 474 +++++++++++++++
 packages/api/src/agents/transactions.ts       | 345 +++++++++++
 .../api/src/agents/usage.bulk-parity.spec.ts  | 533 ++++++++++++++++
 packages/api/src/agents/usage.spec.ts         | 200 +++++-
 packages/api/src/agents/usage.ts              | 132 ++--
 .../api/src/endpoints/openai/config.spec.ts   |  11 +-
 packages/api/src/endpoints/openai/llm.spec.ts |  14 +-
 packages/api/src/endpoints/openai/llm.ts      |   4 +-
 packages/data-schemas/src/methods/index.ts    |   6 +-
 .../data-schemas/src/methods/transaction.ts   | 100 +++
 packages/data-schemas/src/types/index.ts      |   1 +
 .../data-schemas/src/types/transaction.ts     |  17 +
 .../data-schemas/src/utils/transactions.ts    |   2 +
 29 files changed, 3004 insertions(+), 1070 deletions(-)
 create mode 100644 packages/api/src/agents/transactions.bulk-parity.spec.ts
 create mode 100644 packages/api/src/agents/transactions.spec.ts
 create mode 100644 packages/api/src/agents/transactions.ts
 create mode 100644 packages/api/src/agents/usage.bulk-parity.spec.ts
 create mode 100644 packages/data-schemas/src/methods/transaction.ts
 create mode 100644 packages/data-schemas/src/types/transaction.ts

diff --git a/api/models/Transaction.js b/api/models/Transaction.js
index e553e2bb3b..7f018e1c30
100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,140 +1,7 @@ -const { logger } = require('@librechat/data-schemas'); +const { logger, CANCEL_RATE } = require('@librechat/data-schemas'); const { getMultiplier, getCacheMultiplier } = require('./tx'); -const { Transaction, Balance } = require('~/db/models'); - -const cancelRate = 1.15; - -/** - * Updates a user's token balance based on a transaction using optimistic concurrency control - * without schema changes. Compatible with DocumentDB. - * @async - * @function - * @param {Object} params - The function parameters. - * @param {string|mongoose.Types.ObjectId} params.user - The user ID. - * @param {number} params.incrementValue - The value to increment the balance by (can be negative). - * @param {import('mongoose').UpdateQuery['$set']} [params.setValues] - Optional additional fields to set. - * @returns {Promise} Returns the updated balance document (lean). - * @throws {Error} Throws an error if the update fails after multiple retries. - */ -const updateBalance = async ({ user, incrementValue, setValues }) => { - let maxRetries = 10; // Number of times to retry on conflict - let delay = 50; // Initial retry delay in ms - let lastError = null; - - for (let attempt = 1; attempt <= maxRetries; attempt++) { - let currentBalanceDoc; - try { - // 1. Read the current document state - currentBalanceDoc = await Balance.findOne({ user }).lean(); - const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0; - - // 2. Calculate the desired new state - const potentialNewCredits = currentCredits + incrementValue; - const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero - - // 3. Prepare the update payload - const updatePayload = { - $set: { - tokenCredits: newCredits, - ...(setValues || {}), // Merge other values to set - }, - }; - - // 4. 
Attempt the conditional update or upsert - let updatedBalance = null; - if (currentBalanceDoc) { - // --- Document Exists: Perform Conditional Update --- - // Try to update only if the tokenCredits match the value we read (currentCredits) - updatedBalance = await Balance.findOneAndUpdate( - { - user: user, - tokenCredits: currentCredits, // Optimistic lock: condition based on the read value - }, - updatePayload, - { - new: true, // Return the modified document - // lean: true, // .lean() is applied after query execution in Mongoose >= 6 - }, - ).lean(); // Use lean() for plain JS object - - if (updatedBalance) { - // Success! The update was applied based on the expected current state. - return updatedBalance; - } - // If updatedBalance is null, it means tokenCredits changed between read and write (conflict). - lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`); - // Proceed to retry logic below. - } else { - // --- Document Does Not Exist: Perform Conditional Upsert --- - // Try to insert the document, but only if it still doesn't exist. - // Using tokenCredits: {$exists: false} helps prevent race conditions where - // another process creates the doc between our findOne and findOneAndUpdate. - try { - updatedBalance = await Balance.findOneAndUpdate( - { - user: user, - // Attempt to match only if the document doesn't exist OR was just created - // without tokenCredits (less likely but possible). A simple { user } filter - // might also work, relying on the retry for conflicts. - // Let's use a simpler filter and rely on retry for races. 
- // tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits - }, - updatePayload, - { - upsert: true, // Create if doesn't exist - new: true, // Return the created/updated document - // setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert - // lean: true, - }, - ).lean(); - - if (updatedBalance) { - // Upsert succeeded (likely created the document) - return updatedBalance; - } - // If null, potentially a rare race condition during upsert. Retry should handle it. - lastError = new Error( - `Upsert race condition suspected for user ${user} on attempt ${attempt}.`, - ); - } catch (error) { - if (error.code === 11000) { - // E11000 duplicate key error on index - // This means another process created the document *just* before our upsert. - // It's a concurrency conflict during creation. We should retry. - lastError = error; // Store the error - // Proceed to retry logic below. - } else { - // Different error, rethrow - throw error; - } - } - } // End if/else (document exists?) - } catch (error) { - // Catch errors from findOne or unexpected findOneAndUpdate errors - logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error); - lastError = error; // Store the error - // Consider stopping retries for non-transient errors, but for now, we retry. 
- } - - // If we reached here, it means the update failed (conflict or error), wait and retry - if (attempt < maxRetries) { - const jitter = Math.random() * delay * 0.5; // Add jitter to delay - await new Promise((resolve) => setTimeout(resolve, delay + jitter)); - delay = Math.min(delay * 2, 2000); // Exponential backoff with cap - } - } // End for loop (retries) - - // If loop finishes without success, throw the last encountered error or a generic one - logger.error( - `[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`, - ); - throw ( - lastError || - new Error( - `Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`, - ) - ); -}; +const { Transaction } = require('~/db/models'); +const { updateBalance } = require('~/models'); /** Method to calculate and set the tokenValue for a transaction */ function calculateTokenValue(txn) { @@ -145,8 +12,8 @@ function calculateTokenValue(txn) { txn.rate = multiplier; txn.tokenValue = txn.rawAmount * multiplier; if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { - txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); - txn.rate *= cancelRate; + txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE); + txn.rate *= CANCEL_RATE; } } @@ -321,11 +188,11 @@ function calculateStructuredTokenValue(txn) { } if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { - txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); - txn.rate *= cancelRate; + txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE); + txn.rate *= CANCEL_RATE; if (txn.rateDetail) { txn.rateDetail = Object.fromEntries( - Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]), + Object.entries(txn.rateDetail).map(([k, v]) => [k, v * CANCEL_RATE]), ); } } diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index 545c7b2755..f363c472e1 100644 --- a/api/models/Transaction.spec.js +++ 
b/api/models/Transaction.spec.js @@ -1,8 +1,10 @@ const mongoose = require('mongoose'); +const { recordCollectedUsage } = require('@librechat/api'); +const { createMethods } = require('@librechat/data-schemas'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { getMultiplier, getCacheMultiplier, premiumTokenValues, tokenValues } = require('./tx'); const { createTransaction, createStructuredTransaction } = require('./Transaction'); +const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { Balance, Transaction } = require('~/db/models'); let mongoServer; @@ -985,3 +987,339 @@ describe('Premium Token Pricing Integration Tests', () => { expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); }); + +describe('Bulk path parity', () => { + /** + * Each test here mirrors an existing legacy test above, replacing spendTokens/ + * spendStructuredTokens with recordCollectedUsage + bulk deps. + * The balance deduction and transaction document fields must be numerically identical. 
+ */ + let bulkDeps; + let methods; + + beforeEach(() => { + methods = createMethods(mongoose); + bulkDeps = { + spendTokens: () => Promise.resolve(), + spendStructuredTokens: () => Promise.resolve(), + pricing: { getMultiplier, getCacheMultiplier }, + bulkWriteOps: { + insertMany: methods.bulkInsertTransactions, + updateBalance: methods.updateBalance, + }, + }; + }); + + test('balance should decrease when spending tokens via bulk path', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const promptTokens = 100; + const completionTokens = 50; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model, + context: 'test', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], + }); + + const updatedBalance = await Balance.findOne({ user: userId }); + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: promptTokens, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + const expectedTotalCost = + promptTokens * promptMultiplier + completionTokens * completionMultiplier; + const expectedBalance = initialBalance - expectedTotalCost; + + expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + }); + + test('bulk path should not update balance when balance.enabled is false', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + + await recordCollectedUsage(bulkDeps, { + user: 
userId.toString(), + conversationId: 'test-conversation-id', + model, + context: 'test', + balance: { enabled: false }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: 100, output_tokens: 50, model }], + }); + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBe(initialBalance); + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); // transactions still recorded + }); + + test('bulk path should not insert when transactions.enabled is false', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'gpt-3.5-turbo', + context: 'test', + balance: { enabled: true }, + transactions: { enabled: false }, + collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(0); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('bulk path handles incomplete context for completion tokens — same CANCEL_RATE as legacy', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 17613154.55; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const promptTokens = 10; + const completionTokens = 50; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-convo', + model, + context: 'incomplete', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], + }); + + const txns = await Transaction.find({ user: userId }).lean(); 
+ const completionTx = txns.find((t) => t.tokenType === 'completion'); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + expect(completionTx.tokenValue).toBeCloseTo(-completionTokens * completionMultiplier * 1.15, 0); + }); + + test('bulk path structured tokens — balance deduction matches legacy spendStructuredTokens', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 17613154.55; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const promptInput = 11; + const promptWrite = 140522; + const promptRead = 0; + const completionTokens = 5; + const totalInput = promptInput + promptWrite + promptRead; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-convo', + model, + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [ + { + input_tokens: promptInput, + output_tokens: completionTokens, + model, + input_token_details: { cache_creation: promptWrite, cache_read: promptRead }, + }, + ], + }); + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? 
promptMultiplier; + + const expectedPromptCost = + promptInput * promptMultiplier + promptWrite * writeMultiplier + promptRead * readMultiplier; + const expectedCompletionCost = completionTokens * completionMultiplier; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + const expectedBalance = initialBalance - expectedTotalCost; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(Math.abs(updatedBalance.tokenCredits - expectedBalance)).toBeLessThan(100); + }); + + test('premium pricing above threshold via bulk path — same balance as legacy', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-premium', + model, + context: 'test', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], + }); + + const premiumPromptRate = premiumTokenValues[model].prompt; + const premiumCompletionRate = premiumTokenValues[model].completion; + const expectedCost = + promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('real-world multi-entry batch: 5 sequential tool calls — same total deduction as 5 legacy spendTokens calls', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-5-20251101'; + const calls = [ + { input_tokens: 31596, output_tokens: 151 }, + { input_tokens: 35368, output_tokens: 150 
}, + { input_tokens: 58362, output_tokens: 295 }, + { input_tokens: 112604, output_tokens: 193 }, + { input_tokens: 257440, output_tokens: 2217 }, + ]; + + let expectedTotalCost = 0; + for (const { input_tokens, output_tokens } of calls) { + const pm = getMultiplier({ model, tokenType: 'prompt', inputTokenCount: input_tokens }); + const cm = getMultiplier({ model, tokenType: 'completion', inputTokenCount: input_tokens }); + expectedTotalCost += input_tokens * pm + output_tokens * cm; + } + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-sequential', + model, + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + collectedUsage: calls.map((c) => ({ ...c, model })), + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(10); // 5 calls × 2 docs (prompt + completion) + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); + + test('bulk path should save transaction but not update balance when balance disabled, transactions enabled', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'gpt-3.5-turbo', + context: 'test', + balance: { enabled: false }, + transactions: { enabled: true }, + collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + expect(txns[0].rawAmount).toBeDefined(); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('bulk path structured tokens should not save when transactions.enabled 
is false', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'claude-3-5-sonnet', + context: 'message', + balance: { enabled: true }, + transactions: { enabled: false }, + collectedUsage: [ + { + input_tokens: 10, + output_tokens: 5, + model: 'claude-3-5-sonnet', + input_token_details: { cache_creation: 100, cache_read: 5 }, + }, + ], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(0); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('bulk path structured tokens should save but not update balance when balance disabled', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await recordCollectedUsage(bulkDeps, { + user: userId.toString(), + conversationId: 'test-conversation-id', + model: 'claude-3-5-sonnet', + context: 'message', + balance: { enabled: false }, + transactions: { enabled: true }, + collectedUsage: [ + { + input_tokens: 10, + output_tokens: 5, + model: 'claude-3-5-sonnet', + input_token_details: { cache_creation: 100, cache_read: 5 }, + }, + ], + }); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx.inputTokens).toBe(-10); + expect(promptTx.writeTokens).toBe(-100); + expect(promptTx.readTokens).toBe(-5); + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(initialBalance); + }); +}); diff --git a/api/package.json b/api/package.json index f9c9601a37..3e9350ac34 100644 --- a/api/package.json +++ 
b/api/package.json
@@ -44,7 +44,7 @@
     "@google/genai": "^1.19.0",
     "@keyv/redis": "^4.3.3",
     "@langchain/core": "^0.3.80",
-    "@librechat/agents": "^3.1.53",
+    "@librechat/agents": "^3.1.54",
     "@librechat/api": "*",
     "@librechat/data-schemas": "*",
     "@microsoft/microsoft-graph-client": "^3.0.7",
diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js
index 8592c79a2d..835343e798 100644
--- a/api/server/controllers/agents/__tests__/openai.spec.js
+++ b/api/server/controllers/agents/__tests__/openai.spec.js
@@ -82,6 +82,13 @@ jest.mock('~/models/spendTokens', () => ({
   spendStructuredTokens: mockSpendStructuredTokens,
 }));

+const mockGetMultiplier = jest.fn().mockReturnValue(1);
+const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
+jest.mock('~/models/tx', () => ({
+  getMultiplier: mockGetMultiplier,
+  getCacheMultiplier: mockGetCacheMultiplier,
+}));
+
 jest.mock('~/server/controllers/agents/callbacks', () => ({
   createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
 }));
@@ -103,6 +110,8 @@ jest.mock('~/models/Agent', () => ({
   getAgents: jest.fn().mockResolvedValue([]),
 }));

+const mockUpdateBalance = jest.fn().mockResolvedValue({});
+const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
 jest.mock('~/models', () => ({
   getFiles: jest.fn(),
   getUserKey: jest.fn(),
@@ -112,6 +121,8 @@ jest.mock('~/models', () => ({
   getUserCodeFiles: jest.fn(),
   getToolFilesByIds: jest.fn(),
   getCodeGeneratedFiles: jest.fn(),
+  updateBalance: mockUpdateBalance,
+  bulkInsertTransactions: mockBulkInsertTransactions,
 }));

 describe('OpenAIChatCompletionController', () => {
@@ -155,7 +166,15 @@ describe('OpenAIChatCompletionController', () => {

     expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
     expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
-      { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
+      {
+        spendTokens: mockSpendTokens,
+        spendStructuredTokens: mockSpendStructuredTokens,
+        pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier },
+        bulkWriteOps: {
+          insertMany: mockBulkInsertTransactions,
+          updateBalance: mockUpdateBalance,
+        },
+      },
       expect.objectContaining({
         user: 'user-123',
         conversationId: expect.any(String),
@@ -182,12 +201,18 @@ describe('OpenAIChatCompletionController', () => {
     );
   });

-  it('should pass spendTokens and spendStructuredTokens as dependencies', async () => {
+  it('should pass spendTokens, spendStructuredTokens, pricing, and bulkWriteOps as dependencies', async () => {
     await OpenAIChatCompletionController(req, res);

     const [deps] = mockRecordCollectedUsage.mock.calls[0];
     expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
     expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
+    expect(deps).toHaveProperty('pricing');
+    expect(deps.pricing).toHaveProperty('getMultiplier', mockGetMultiplier);
+    expect(deps.pricing).toHaveProperty('getCacheMultiplier', mockGetCacheMultiplier);
+    expect(deps).toHaveProperty('bulkWriteOps');
+    expect(deps.bulkWriteOps).toHaveProperty('insertMany', mockBulkInsertTransactions);
+    expect(deps.bulkWriteOps).toHaveProperty('updateBalance', mockUpdateBalance);
   });

   it('should include model from primaryConfig in recordCollectedUsage params', async () => {
diff --git a/api/server/controllers/agents/__tests__/responses.unit.spec.js b/api/server/controllers/agents/__tests__/responses.unit.spec.js
index e16ca394b2..45ec31fc68 100644
--- a/api/server/controllers/agents/__tests__/responses.unit.spec.js
+++ b/api/server/controllers/agents/__tests__/responses.unit.spec.js
@@ -106,6 +106,13 @@ jest.mock('~/models/spendTokens', () => ({
   spendStructuredTokens: mockSpendStructuredTokens,
 }));

+const mockGetMultiplier = jest.fn().mockReturnValue(1);
+const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
+jest.mock('~/models/tx', () => ({
+  getMultiplier: mockGetMultiplier,
+  getCacheMultiplier: mockGetCacheMultiplier,
+}));
+
 jest.mock('~/server/controllers/agents/callbacks', () => ({
   createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
   createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
@@ -131,6 +138,8 @@ jest.mock('~/models/Agent', () => ({
   getAgents: jest.fn().mockResolvedValue([]),
 }));

+const mockUpdateBalance = jest.fn().mockResolvedValue({});
+const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
 jest.mock('~/models', () => ({
   getFiles: jest.fn(),
   getUserKey: jest.fn(),
@@ -141,6 +150,8 @@ jest.mock('~/models', () => ({
   getUserCodeFiles: jest.fn(),
   getToolFilesByIds: jest.fn(),
   getCodeGeneratedFiles: jest.fn(),
+  updateBalance: mockUpdateBalance,
+  bulkInsertTransactions: mockBulkInsertTransactions,
 }));

 describe('createResponse controller', () => {
@@ -184,7 +195,15 @@ describe('createResponse controller', () => {

     expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
     expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
-      { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
+      {
+        spendTokens: mockSpendTokens,
+        spendStructuredTokens: mockSpendStructuredTokens,
+        pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier },
+        bulkWriteOps: {
+          insertMany: mockBulkInsertTransactions,
+          updateBalance: mockUpdateBalance,
+        },
+      },
       expect.objectContaining({
         user: 'user-123',
         conversationId: expect.any(String),
@@ -209,12 +228,18 @@ describe('createResponse controller', () => {
     );
   });

-  it('should pass spendTokens and spendStructuredTokens as dependencies', async () => {
+  it('should pass spendTokens, spendStructuredTokens, pricing, and bulkWriteOps as dependencies', async () => {
     await createResponse(req, res);

     const [deps] = mockRecordCollectedUsage.mock.calls[0];
     expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
     expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
+    expect(deps).toHaveProperty('pricing');
+    expect(deps.pricing).toHaveProperty('getMultiplier', mockGetMultiplier);
+    expect(deps.pricing).toHaveProperty('getCacheMultiplier', mockGetCacheMultiplier);
+    expect(deps).toHaveProperty('bulkWriteOps');
+    expect(deps.bulkWriteOps).toHaveProperty('insertMany', mockBulkInsertTransactions);
+    expect(deps.bulkWriteOps).toHaveProperty('updateBalance', mockUpdateBalance);
   });

   it('should include model from primaryConfig in recordCollectedUsage params', async () => {
@@ -244,7 +269,15 @@ describe('createResponse controller', () => {

     expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
     expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
-      { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
+      {
+        spendTokens: mockSpendTokens,
+        spendStructuredTokens: mockSpendStructuredTokens,
+        pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier },
+        bulkWriteOps: {
+          insertMany: mockBulkInsertTransactions,
+          updateBalance: mockUpdateBalance,
+        },
+      },
       expect.objectContaining({
         user: 'user-123',
         context: 'message',
diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js
index d69281d49c..5f99a0762b 100644
--- a/api/server/controllers/agents/client.js
+++ b/api/server/controllers/agents/client.js
@@ -13,11 +13,12 @@ const {
   createSafeUser,
   initializeAgent,
   getBalanceConfig,
-  getProviderConfig,
   omitTitleOptions,
+  getProviderConfig,
   memoryInstructions,
-  applyContextToAgent,
   createTokenCounter,
+  applyContextToAgent,
+  recordCollectedUsage,
   GenerationJobManager,
   getTransactionsConfig,
   createMemoryProcessor,
@@ -45,6 +46,8 @@ const {
 } = require('librechat-data-provider');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { updateBalance, bulkInsertTransactions } = require('~/models');
+const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
 const { createContextHandlers } = require('~/app/clients/prompts');
 const { getConvoFiles } = require('~/models/Conversation');
 const BaseClient = require('~/app/clients/BaseClient');
@@ -624,83 +627,29 @@ class AgentClient extends BaseClient {
     context = 'message',
     collectedUsage = this.collectedUsage,
   }) {
-    if (!collectedUsage || !collectedUsage.length) {
-      return;
-    }
-    // Use first entry's input_tokens as the base input (represents initial user message context)
-    // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
-    const firstUsage = collectedUsage[0];
-    const input_tokens =
-      (firstUsage?.input_tokens || 0) +
-      (Number(firstUsage?.input_token_details?.cache_creation) ||
-        Number(firstUsage?.cache_creation_input_tokens) ||
-        0) +
-      (Number(firstUsage?.input_token_details?.cache_read) ||
-        Number(firstUsage?.cache_read_input_tokens) ||
-        0);
-
-    // Sum output_tokens directly from all entries - works for both sequential and parallel execution
-    // This avoids the incremental calculation that produced negative values for parallel agents
-    let total_output_tokens = 0;
-
-    for (const usage of collectedUsage) {
-      if (!usage) {
-        continue;
-      }
-
-      // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
-      const cache_creation =
-        Number(usage.input_token_details?.cache_creation) ||
-        Number(usage.cache_creation_input_tokens) ||
-        0;
-      const cache_read =
-        Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
-
-      // Accumulate output tokens for the usage summary
-      total_output_tokens += Number(usage.output_tokens) || 0;
-
-      const txMetadata = {
+    const result = await recordCollectedUsage(
+      {
+        spendTokens,
+        spendStructuredTokens,
+        pricing: { getMultiplier, getCacheMultiplier },
+        bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance },
+      },
+      {
+        user: this.user ?? this.options.req.user?.id,
+        conversationId: this.conversationId,
+        collectedUsage,
+        model: model ?? this.model ?? this.options.agent.model_parameters.model,
         context,
+        messageId: this.responseMessageId,
         balance,
         transactions,
-        messageId: this.responseMessageId,
-        conversationId: this.conversationId,
-        user: this.user ?? this.options.req.user?.id,
         endpointTokenConfig: this.options.endpointTokenConfig,
-        model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
-      };
+      },
+    );

-      if (cache_creation > 0 || cache_read > 0) {
-        spendStructuredTokens(txMetadata, {
-          promptTokens: {
-            input: usage.input_tokens,
-            write: cache_creation,
-            read: cache_read,
-          },
-          completionTokens: usage.output_tokens,
-        }).catch((err) => {
-          logger.error(
-            '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens',
-            err,
-          );
-        });
-        continue;
-      }
-      spendTokens(txMetadata, {
-        promptTokens: usage.input_tokens,
-        completionTokens: usage.output_tokens,
-      }).catch((err) => {
-        logger.error(
-          '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
-          err,
-        );
-      });
+    if (result) {
+      this.usage = result;
     }
-
-    this.usage = {
-      input_tokens,
-      output_tokens: total_output_tokens,
-    };
   }

   /**
diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js
index a083bd9291..e8561f15fe 100644
--- a/api/server/controllers/agents/openai.js
+++ b/api/server/controllers/agents/openai.js
@@ -25,6 +25,7 @@ const { loadAgentTools, loadToolsForExecution } = require('~/server/services/Too
 const { createToolEndCallback } = require('~/server/controllers/agents/callbacks');
 const { findAccessibleResources } = require('~/server/services/PermissionService');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
+const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
 const { getConvoFiles } = require('~/models/Conversation');
 const { getAgent, getAgents } = require('~/models/Agent');
 const db = require('~/models');
@@ -493,7 +494,12 @@ const OpenAIChatCompletionController = async (req, res) => {
       const balanceConfig = getBalanceConfig(appConfig);
       const transactionsConfig = getTransactionsConfig(appConfig);
       recordCollectedUsage(
-        { spendTokens, spendStructuredTokens },
+        {
+          spendTokens,
+          spendStructuredTokens,
+          pricing: { getMultiplier, getCacheMultiplier },
+          bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
+        },
         {
           user: userId,
           conversationId,
diff --git a/api/server/controllers/agents/recordCollectedUsage.spec.js b/api/server/controllers/agents/recordCollectedUsage.spec.js
index 6904f2ed39..21720023ca 100644
--- a/api/server/controllers/agents/recordCollectedUsage.spec.js
+++ b/api/server/controllers/agents/recordCollectedUsage.spec.js
@@ -2,23 +2,37 @@
  * Tests for AgentClient.recordCollectedUsage
  *
  * This is a critical function that handles token spending for agent LLM calls.
- * It must correctly handle:
- * - Sequential execution (single agent with tool calls)
- * - Parallel execution (multiple agents with independent inputs)
- * - Cache token handling (OpenAI and Anthropic formats)
+ * The client now delegates to the TS recordCollectedUsage from @librechat/api,
+ * passing pricing and bulkWriteOps deps.
  */
 const { EModelEndpoint } = require('librechat-data-provider');

-// Mock dependencies before requiring the module
 const mockSpendTokens = jest.fn().mockResolvedValue();
 const mockSpendStructuredTokens = jest.fn().mockResolvedValue();
+const mockGetMultiplier = jest.fn().mockReturnValue(1);
+const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
+const mockUpdateBalance = jest.fn().mockResolvedValue({});
+const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
+const mockRecordCollectedUsage = jest
+  .fn()
+  .mockResolvedValue({ input_tokens: 100, output_tokens: 50 });

 jest.mock('~/models/spendTokens', () => ({
   spendTokens: (...args) => mockSpendTokens(...args),
   spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
 }));

+jest.mock('~/models/tx', () => ({
+  getMultiplier: mockGetMultiplier,
+  getCacheMultiplier: mockGetCacheMultiplier,
+}));
+
+jest.mock('~/models', () => ({
+  updateBalance: mockUpdateBalance,
+  bulkInsertTransactions: mockBulkInsertTransactions,
+}));
+
 jest.mock('~/config', () => ({
   logger: {
     debug: jest.fn(),
@@ -39,6 +53,14 @@ jest.mock('@librechat/agents', () => ({
   }),
 }));

+jest.mock('@librechat/api', () => {
+  const actual = jest.requireActual('@librechat/api');
+  return {
+    ...actual,
+    recordCollectedUsage: (...args) => mockRecordCollectedUsage(...args),
+  };
+});
+
 const AgentClient = require('./client');

 describe('AgentClient - recordCollectedUsage', () => {
@@ -74,31 +96,66 @@ describe('AgentClient - recordCollectedUsage', () => {
   });

   describe('basic functionality', () => {
-    it('should return early if collectedUsage is empty', async () => {
+    it('should delegate to recordCollectedUsage with full deps', async () => {
+      const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
+
+      await client.recordCollectedUsage({
+        collectedUsage,
+        balance: { enabled: true },
+        transactions: { enabled: true },
+      });
+
+      expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
+      const [deps, params] = mockRecordCollectedUsage.mock.calls[0];
+
+      expect(deps).toHaveProperty('spendTokens');
+      expect(deps).toHaveProperty('spendStructuredTokens');
+      expect(deps).toHaveProperty('pricing');
+      expect(deps.pricing).toHaveProperty('getMultiplier');
+      expect(deps.pricing).toHaveProperty('getCacheMultiplier');
+      expect(deps).toHaveProperty('bulkWriteOps');
+      expect(deps.bulkWriteOps).toHaveProperty('insertMany');
+      expect(deps.bulkWriteOps).toHaveProperty('updateBalance');
+
+      expect(params).toEqual(
+        expect.objectContaining({
+          user: 'user-123',
+          conversationId: 'convo-123',
+          collectedUsage,
+          context: 'message',
+          balance: { enabled: true },
+          transactions: { enabled: true },
+        }),
+      );
+    });
+
+    it('should not set this.usage if collectedUsage is empty (returns undefined)', async () => {
+      mockRecordCollectedUsage.mockResolvedValue(undefined);
+
       await client.recordCollectedUsage({
         collectedUsage: [],
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).not.toHaveBeenCalled();
-      expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
       expect(client.usage).toBeUndefined();
     });

-    it('should return early if collectedUsage is null', async () => {
+    it('should not set this.usage if collectedUsage is null (returns undefined)', async () => {
+      mockRecordCollectedUsage.mockResolvedValue(undefined);
+
       await client.recordCollectedUsage({
         collectedUsage: null,
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).not.toHaveBeenCalled();
       expect(client.usage).toBeUndefined();
     });

-    it('should handle single usage entry correctly', async () => {
-      const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
+    it('should set this.usage from recordCollectedUsage result', async () => {
+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 75 });
+      const collectedUsage = [{ input_tokens: 200, output_tokens: 75, model: 'gpt-4' }];

       await client.recordCollectedUsage({
         collectedUsage,
@@ -106,521 +163,122 @@ describe('AgentClient - recordCollectedUsage', () => {
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).toHaveBeenCalledTimes(1);
-      expect(mockSpendTokens).toHaveBeenCalledWith(
-        expect.objectContaining({
-          conversationId: 'convo-123',
-          user: 'user-123',
-          model: 'gpt-4',
-        }),
-        { promptTokens: 100, completionTokens: 50 },
-      );
-      expect(client.usage.input_tokens).toBe(100);
-      expect(client.usage.output_tokens).toBe(50);
-    });
-
-    it('should skip null entries in collectedUsage', async () => {
-      const collectedUsage = [
-        { input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
-        null,
-        { input_tokens: 200, output_tokens: 60, model: 'gpt-4' },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendTokens).toHaveBeenCalledTimes(2);
+      expect(client.usage).toEqual({ input_tokens: 200, output_tokens: 75 });
     });
   });

   describe('sequential execution (single agent with tool calls)', () => {
-    it('should calculate tokens correctly for sequential tool calls', async () => {
-      // Sequential flow: output of call N becomes part of input for call N+1
-      // Call 1: input=100, output=50
-      // Call 2: input=150 (100+50), output=30
-      // Call 3: input=180 (150+30), output=20
+    it('should pass all usage entries to recordCollectedUsage', async () => {
       const collectedUsage = [
         { input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
         { input_tokens: 150, output_tokens: 30, model: 'gpt-4' },
         { input_tokens: 180, output_tokens: 20, model: 'gpt-4' },
       ];

+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 100 });
+
       await client.recordCollectedUsage({
         collectedUsage,
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).toHaveBeenCalledTimes(3);
-      // Total output should be sum of all output_tokens: 50 + 30 + 20 = 100
+      expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
+      const [, params] = mockRecordCollectedUsage.mock.calls[0];
+      expect(params.collectedUsage).toHaveLength(3);
       expect(client.usage.output_tokens).toBe(100);
-      expect(client.usage.input_tokens).toBe(100); // First entry's input
+      expect(client.usage.input_tokens).toBe(100);
     });
   });

   describe('parallel execution (multiple agents)', () => {
-    it('should handle parallel agents with independent input tokens', async () => {
-      // Parallel agents have INDEPENDENT input tokens (not cumulative)
-      // Agent A: input=100, output=50
-      // Agent B: input=80, output=40 (different context, not 100+50)
+    it('should pass parallel agent usage to recordCollectedUsage', async () => {
       const collectedUsage = [
         { input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
         { input_tokens: 80, output_tokens: 40, model: 'gpt-4' },
       ];

+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 90 });
+
       await client.recordCollectedUsage({
         collectedUsage,
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).toHaveBeenCalledTimes(2);
-      // Expected total output: 50 + 40 = 90
-      // output_tokens must be positive and should reflect total output
+      expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
+      expect(client.usage.output_tokens).toBe(90);
       expect(client.usage.output_tokens).toBeGreaterThan(0);
     });

-    it('should NOT produce negative output_tokens for parallel execution', async () => {
-      // Critical bug scenario: parallel agents where second agent has LOWER input tokens
+    /** Bug regression: parallel agents where second agent has LOWER input tokens produced negative output via incremental calculation. */
+    it('should NOT produce negative output_tokens', async () => {
       const collectedUsage = [
         { input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
         { input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
       ];

+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 130 });
+
       await client.recordCollectedUsage({
         collectedUsage,
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      // output_tokens MUST be positive for proper token tracking
       expect(client.usage.output_tokens).toBeGreaterThan(0);
-      // Correct value should be 100 + 30 = 130
-    });
-
-    it('should calculate correct total output for parallel agents', async () => {
-      // Three parallel agents with independent contexts
-      const collectedUsage = [
-        { input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
-        { input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' },
-        { input_tokens: 80, output_tokens: 40, model: 'claude-3' },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendTokens).toHaveBeenCalledTimes(3);
-      // Total output should be 50 + 60 + 40 = 150
-      expect(client.usage.output_tokens).toBe(150);
-    });
-
-    it('should handle worst-case parallel scenario without negative tokens', async () => {
-      // Extreme case: first agent has very high input, subsequent have low
-      const collectedUsage = [
-        { input_tokens: 1000, output_tokens: 500, model: 'gpt-4' },
-        { input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
-        { input_tokens: 50, output_tokens: 25, model: 'gpt-4' },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      // Must be positive, should be 500 + 50 + 25 = 575
-      expect(client.usage.output_tokens).toBeGreaterThan(0);
-      expect(client.usage.output_tokens).toBe(575);
+      expect(client.usage.output_tokens).toBe(130);
     });
   });

   describe('real-world scenarios', () => {
-    it('should correctly sum output tokens for sequential tool calls with growing context', async () => {
-      // Real production data: Claude Opus with multiple tool calls
-      // Context grows as tool results are added, but output_tokens should only count model generations
+    it('should correctly handle sequential tool calls with growing context', async () => {
       const collectedUsage = [
-        {
-          input_tokens: 31596,
-          output_tokens: 151,
-          total_tokens: 31747,
-          input_token_details: { cache_read: 0, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 35368,
-          output_tokens: 150,
-          total_tokens: 35518,
-          input_token_details: { cache_read: 0, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 58362,
-          output_tokens: 295,
-          total_tokens: 58657,
-          input_token_details: { cache_read: 0, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 112604,
-          output_tokens: 193,
-          total_tokens: 112797,
-          input_token_details: { cache_read: 0, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 257440,
-          output_tokens: 2217,
-          total_tokens: 259657,
-          input_token_details: { cache_read: 0, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
+        { input_tokens: 31596, output_tokens: 151, model: 'claude-opus-4-5-20251101' },
+        { input_tokens: 35368, output_tokens: 150, model: 'claude-opus-4-5-20251101' },
+        { input_tokens: 58362, output_tokens: 295, model: 'claude-opus-4-5-20251101' },
+        { input_tokens: 112604, output_tokens: 193, model: 'claude-opus-4-5-20251101' },
+        { input_tokens: 257440, output_tokens: 2217, model: 'claude-opus-4-5-20251101' },
       ];

+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 31596, output_tokens: 3006 });
+
       await client.recordCollectedUsage({
         collectedUsage,
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      // input_tokens should be first entry's input (initial context)
       expect(client.usage.input_tokens).toBe(31596);
-
-      // output_tokens should be sum of all model outputs: 151 + 150 + 295 + 193 + 2217 = 3006
-      // NOT the inflated value from incremental calculation (338,559)
       expect(client.usage.output_tokens).toBe(3006);
-
-      // Verify spendTokens was called for each entry with correct values
-      expect(mockSpendTokens).toHaveBeenCalledTimes(5);
-      expect(mockSpendTokens).toHaveBeenNthCalledWith(
-        1,
-        expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
-        { promptTokens: 31596, completionTokens: 151 },
-      );
-      expect(mockSpendTokens).toHaveBeenNthCalledWith(
-        5,
-        expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
-        { promptTokens: 257440, completionTokens: 2217 },
-      );
     });

-    it('should handle single followup message correctly', async () => {
-      // Real production data: followup to the above conversation
-      const collectedUsage = [
-        {
-          input_tokens: 263406,
-          output_tokens: 257,
-          total_tokens: 263663,
-          input_token_details: { cache_read: 0, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(client.usage.input_tokens).toBe(263406);
-      expect(client.usage.output_tokens).toBe(257);
-
-      expect(mockSpendTokens).toHaveBeenCalledTimes(1);
-      expect(mockSpendTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
-        { promptTokens: 263406, completionTokens: 257 },
-      );
-    });
-
-    it('should ensure output_tokens > 0 check passes for BaseClient.sendMessage', async () => {
-      // This verifies the fix for the duplicate token spending bug
-      // BaseClient.sendMessage checks: if (usage != null && Number(usage[this.outputTokensKey]) > 0)
-      const collectedUsage = [
-        {
-          input_tokens: 31596,
-          output_tokens: 151,
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 35368,
-          output_tokens: 150,
-          model: 'claude-opus-4-5-20251101',
-        },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      const usage = client.getStreamUsage();
-
-      // The check that was failing before the fix
-      expect(usage).not.toBeNull();
-      expect(Number(usage.output_tokens)).toBeGreaterThan(0);
-
-      // Verify correct value
-      expect(usage.output_tokens).toBe(301); // 151 + 150
-    });
-
-    it('should correctly handle cache tokens with multiple tool calls', async () => {
-      // Real production data: Claude Opus with cache tokens (prompt caching)
-      // First entry has cache_creation, subsequent entries have cache_read
+    it('should correctly handle cache tokens', async () => {
       const collectedUsage = [
         {
           input_tokens: 788,
           output_tokens: 163,
-          total_tokens: 951,
           input_token_details: { cache_read: 0, cache_creation: 30808 },
           model: 'claude-opus-4-5-20251101',
         },
-        {
-          input_tokens: 3802,
-          output_tokens: 149,
-          total_tokens: 3951,
-          input_token_details: { cache_read: 30808, cache_creation: 768 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 26808,
-          output_tokens: 225,
-          total_tokens: 27033,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 80912,
-          output_tokens: 204,
-          total_tokens: 81116,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 136454,
-          output_tokens: 206,
-          total_tokens: 136660,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 146316,
-          output_tokens: 224,
-          total_tokens: 146540,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 150402,
-          output_tokens: 1248,
-          total_tokens: 151650,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 156268,
-          output_tokens: 139,
-          total_tokens: 156407,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
-        {
-          input_tokens: 167126,
-          output_tokens: 2961,
-          total_tokens: 170087,
-          input_token_details: { cache_read: 31576, cache_creation: 0 },
-          model: 'claude-opus-4-5-20251101',
-        },
       ];

+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 31596, output_tokens: 163 });
+
       await client.recordCollectedUsage({
         collectedUsage,
         balance: { enabled: true },
         transactions: { enabled: true },
       });

-      // input_tokens = first entry's input + cache_creation + cache_read
-      // = 788 + 30808 + 0 = 31596
       expect(client.usage.input_tokens).toBe(31596);
-
-      // output_tokens = sum of all output_tokens
-      // = 163 + 149 + 225 + 204 + 206 + 224 + 1248 + 139 + 2961 = 5519
-      expect(client.usage.output_tokens).toBe(5519);
-
-      // First 2 entries have cache tokens, should use spendStructuredTokens
-      // Remaining 7 entries have cache_read but no cache_creation, still structured
-      expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(9);
-      expect(mockSpendTokens).toHaveBeenCalledTimes(0);
-
-      // Verify first entry uses structured tokens with cache_creation
-      expect(mockSpendStructuredTokens).toHaveBeenNthCalledWith(
-        1,
-        expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
-        {
-          promptTokens: { input: 788, write: 30808, read: 0 },
-          completionTokens: 163,
-        },
-      );
-
-      // Verify second entry uses structured tokens with both cache_creation and cache_read
-      expect(mockSpendStructuredTokens).toHaveBeenNthCalledWith(
-        2,
-        expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
-        {
-          promptTokens: { input: 3802, write: 768, read: 30808 },
-          completionTokens: 149,
-        },
-      );
-    });
-  });
-
-  describe('cache token handling', () => {
-    it('should handle OpenAI format cache tokens (input_token_details)', async () => {
-      const collectedUsage = [
-        {
-          input_tokens: 100,
-          output_tokens: 50,
-          model: 'gpt-4',
-          input_token_details: {
-            cache_creation: 20,
-            cache_read: 10,
-          },
-        },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
-      expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'gpt-4' }),
-        {
-          promptTokens: {
-            input: 100,
-            write: 20,
-            read: 10,
-          },
-          completionTokens: 50,
-        },
-      );
-    });
-
-    it('should handle Anthropic format cache tokens (cache_*_input_tokens)', async () => {
-      const collectedUsage = [
-        {
-          input_tokens: 100,
-          output_tokens: 50,
-          model: 'claude-3',
-          cache_creation_input_tokens: 25,
-          cache_read_input_tokens: 15,
-        },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
-      expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'claude-3' }),
-        {
-          promptTokens: {
-            input: 100,
-            write: 25,
-            read: 15,
-          },
-          completionTokens: 50,
-        },
-      );
-    });
-
-    it('should use spendTokens for entries without cache tokens', async () => {
-      const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendTokens).toHaveBeenCalledTimes(1);
-      expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
-    });
-
-    it('should handle mixed cache and non-cache entries', async () => {
-      const collectedUsage = [
-        { input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
-        {
-          input_tokens: 150,
-          output_tokens: 30,
-          model: 'gpt-4',
-          input_token_details: { cache_creation: 10, cache_read: 5 },
-        },
-        { input_tokens: 200, output_tokens: 20, model: 'gpt-4' },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendTokens).toHaveBeenCalledTimes(2);
-      expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
-    });
-
-    it('should include cache tokens in total input calculation', async () => {
-      const collectedUsage = [
-        {
-          input_tokens: 100,
-          output_tokens: 50,
-          model: 'gpt-4',
-          input_token_details: {
-            cache_creation: 20,
-            cache_read: 10,
-          },
-        },
-      ];
-
-      await client.recordCollectedUsage({
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      // Total input should include cache tokens: 100 + 20 + 10 = 130
-      expect(client.usage.input_tokens).toBe(130);
+      expect(client.usage.output_tokens).toBe(163);
     });
   });

   describe('model fallback', () => {
-    it('should use usage.model when available', async () => {
-      const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' }];
-
-      await client.recordCollectedUsage({
-        model: 'fallback-model',
-        collectedUsage,
-        balance: { enabled: true },
-        transactions: { enabled: true },
-      });
-
-      expect(mockSpendTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'gpt-4-turbo' }),
-        expect.any(Object),
-      );
-    });
-
-    it('should fallback to param model when usage.model is missing', async () => {
+    it('should use param model when available', async () => {
+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
       const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];

       await client.recordCollectedUsage({
@@ -630,14 +288,13 @@ describe('AgentClient - recordCollectedUsage', () => {
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'param-model' }),
-        expect.any(Object),
-      );
+      const [, params] = mockRecordCollectedUsage.mock.calls[0];
+      expect(params.model).toBe('param-model');
     });

     it('should fallback to client.model when param model is missing', async () => {
       client.model = 'client-model';
+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
       const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];

       await client.recordCollectedUsage({
@@ -646,13 +303,12 @@ describe('AgentClient - recordCollectedUsage', () => {
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'client-model' }),
-        expect.any(Object),
-      );
+      const [, params] = mockRecordCollectedUsage.mock.calls[0];
+      expect(params.model).toBe('client-model');
     });

     it('should fallback to agent model_parameters.model as last resort', async () => {
+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
       const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];

       await client.recordCollectedUsage({
@@ -661,15 +317,14 @@ describe('AgentClient - recordCollectedUsage', () => {
         transactions: { enabled: true },
       });

-      expect(mockSpendTokens).toHaveBeenCalledWith(
-        expect.objectContaining({ model: 'gpt-4' }),
-        expect.any(Object),
-      );
+      const [, params] = mockRecordCollectedUsage.mock.calls[0];
+      expect(params.model).toBe('gpt-4');
     });
   });

   describe('getStreamUsage integration', () => {
     it('should return the usage object set by recordCollectedUsage', async () => {
+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
       const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];

       await client.recordCollectedUsage({
@@ -679,10 +334,7 @@ describe('AgentClient - recordCollectedUsage', () => {
       });

       const usage = client.getStreamUsage();
-      expect(usage).toEqual({
-        input_tokens: 100,
-        output_tokens: 50,
-      });
+      expect(usage).toEqual({ input_tokens: 100, output_tokens: 50 });
     });

     it('should return undefined before recordCollectedUsage is called', () => {
@@ -690,9 +342,9 @@ describe('AgentClient - recordCollectedUsage', () => {
       expect(usage).toBeUndefined();
     });

+    /** Verifies usage passes the check in BaseClient.sendMessage: if (usage != null && Number(usage[this.outputTokensKey]) > 0) */
     it('should have output_tokens > 0 for BaseClient.sendMessage check', async () => {
-      // This test verifies the usage will pass the check in BaseClient.sendMessage:
-      // if (usage != null && Number(usage[this.outputTokensKey]) > 0)
+      mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 130 });
       const collectedUsage = [
         { input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
         { input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js
index 8ce15766c7..83e6ad6efd 100644
--- a/api/server/controllers/agents/responses.js
+++ b/api/server/controllers/agents/responses.js
@@ -38,6 +38,7 @@ const { loadAgentTools, loadToolsForExecution } = require('~/server/services/Too
 const { findAccessibleResources } = require('~/server/services/PermissionService');
 const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
+const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
 const { getAgent, getAgents } = require('~/models/Agent');
 const db = require('~/models');
@@ -509,7 +510,12 @@ const createResponse = async (req, res) => {
       const balanceConfig = getBalanceConfig(req.config);
       const transactionsConfig = getTransactionsConfig(req.config);
       recordCollectedUsage(
-        { spendTokens, spendStructuredTokens },
+        {
+          spendTokens,
+          spendStructuredTokens,
+          pricing: { getMultiplier, getCacheMultiplier },
+          bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
+        },
         {
           user: userId,
           conversationId,
@@ -658,7 +664,12 @@ const createResponse = async (req, res) => {
       const balanceConfig = getBalanceConfig(req.config);
       const transactionsConfig = getTransactionsConfig(req.config);
       recordCollectedUsage(
-        { spendTokens, spendStructuredTokens },
+        {
+          spendTokens,
+          spendStructuredTokens,
+          pricing: { getMultiplier, getCacheMultiplier },
+          bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
+        },
         {
           user: userId,
           conversationId,
diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js
index acc9299b04..d39b0104a8 100644
--- a/api/server/middleware/abortMiddleware.js
+++ b/api/server/middleware/abortMiddleware.js
@@ -1,17 +1,19 @@
 const { logger } = require('@librechat/data-schemas');
 const {
-  countTokens,
   isEnabled,
   sendEvent,
+  countTokens,
   GenerationJobManager,
+  recordCollectedUsage,
   sanitizeMessageForTransmit,
 } = require('@librechat/api');
 const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
+const { saveMessage, getConvo, updateBalance, bulkInsertTransactions } = require('~/models');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
 const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
+const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
 const clearPendingReq = require('~/cache/clearPendingReq');
 const { sendError } = require('~/server/middleware/error');
-const { saveMessage, getConvo } = require('~/models');
 const { abortRun } = require('./abortRun');

 /**
@@ -40,57 +42,22 @@ async function spendCollectedUsage({
     return;
   }

-  const spendPromises = [];
-
-  for (const usage of collectedUsage) {
-    if (!usage) {
-      continue;
-    }
-
-    // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
-    const cache_creation =
-      Number(usage.input_token_details?.cache_creation) ||
-      Number(usage.cache_creation_input_tokens) ||
-      0;
-    const cache_read =
-      Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
-
-    const txMetadata = {
+  await recordCollectedUsage(
+    {
+      spendTokens,
+      spendStructuredTokens,
+      pricing: { getMultiplier, getCacheMultiplier },
+      bulkWriteOps: { insertMany: 
bulkInsertTransactions, updateBalance }, + }, + { + user: userId, + conversationId, + collectedUsage, context: 'abort', messageId, - conversationId, - user: userId, - model: usage.model ?? fallbackModel, - }; - - if (cache_creation > 0 || cache_read > 0) { - spendPromises.push( - spendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[abortMiddleware] Error spending structured tokens for abort', err); - }), - ); - continue; - } - - spendPromises.push( - spendTokens(txMetadata, { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[abortMiddleware] Error spending tokens for abort', err); - }), - ); - } - - // Wait for all token spending to complete - await Promise.all(spendPromises); + model: fallbackModel, + }, + ); // Clear the array to prevent double-spending from the AgentClient finally block. // The collectedUsage array is shared by reference with AgentClient.collectedUsage, @@ -301,4 +268,5 @@ const handleAbortError = async (res, req, error, data) => { module.exports = { handleAbort, handleAbortError, + spendCollectedUsage, }; diff --git a/api/server/middleware/abortMiddleware.spec.js b/api/server/middleware/abortMiddleware.spec.js index 93f2ce558b..795814a928 100644 --- a/api/server/middleware/abortMiddleware.spec.js +++ b/api/server/middleware/abortMiddleware.spec.js @@ -4,16 +4,32 @@ * This tests the token spending logic for abort scenarios, * particularly for parallel agents (addedConvo) where multiple * models need their tokens spent. + * + * spendCollectedUsage delegates to recordCollectedUsage from @librechat/api, + * passing pricing + bulkWriteOps deps, with context: 'abort'. + * After spending, it clears the collectedUsage array to prevent double-spending + * from the AgentClient finally block (which shares the same array reference). 
*/ const mockSpendTokens = jest.fn().mockResolvedValue(); const mockSpendStructuredTokens = jest.fn().mockResolvedValue(); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); + +const mockGetMultiplier = jest.fn().mockReturnValue(1); +const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); jest.mock('~/models/spendTokens', () => ({ spendTokens: (...args) => mockSpendTokens(...args), spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args), })); +jest.mock('~/models/tx', () => ({ + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, +})); + jest.mock('@librechat/data-schemas', () => ({ logger: { debug: jest.fn(), @@ -30,6 +46,7 @@ jest.mock('@librechat/api', () => ({ GenerationJobManager: { abortJob: jest.fn(), }, + recordCollectedUsage: mockRecordCollectedUsage, sanitizeMessageForTransmit: jest.fn((msg) => msg), })); @@ -49,94 +66,27 @@ jest.mock('~/server/middleware/error', () => ({ sendError: jest.fn(), })); +const mockUpdateBalance = jest.fn().mockResolvedValue({}); +const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined); jest.mock('~/models', () => ({ saveMessage: jest.fn().mockResolvedValue(), getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }), + updateBalance: mockUpdateBalance, + bulkInsertTransactions: mockBulkInsertTransactions, })); jest.mock('./abortRun', () => ({ abortRun: jest.fn(), })); -// Import the module after mocks are set up -// We need to extract the spendCollectedUsage function for testing -// Since it's not exported, we'll test it through the handleAbort flow +const { spendCollectedUsage } = require('./abortMiddleware'); describe('abortMiddleware - spendCollectedUsage', () => { beforeEach(() => { jest.clearAllMocks(); }); - describe('spendCollectedUsage logic', () => { - // Since spendCollectedUsage is not exported, we test the logic directly - // by replicating the function here for unit testing - - const 
spendCollectedUsage = async ({ - userId, - conversationId, - collectedUsage, - fallbackModel, - }) => { - if (!collectedUsage || collectedUsage.length === 0) { - return; - } - - const spendPromises = []; - - for (const usage of collectedUsage) { - if (!usage) { - continue; - } - - const cache_creation = - Number(usage.input_token_details?.cache_creation) || - Number(usage.cache_creation_input_tokens) || - 0; - const cache_read = - Number(usage.input_token_details?.cache_read) || - Number(usage.cache_read_input_tokens) || - 0; - - const txMetadata = { - context: 'abort', - conversationId, - user: userId, - model: usage.model ?? fallbackModel, - }; - - if (cache_creation > 0 || cache_read > 0) { - spendPromises.push( - mockSpendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch(() => { - // Log error but don't throw - }), - ); - continue; - } - - spendPromises.push( - mockSpendTokens(txMetadata, { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, - }).catch(() => { - // Log error but don't throw - }), - ); - } - - // Wait for all token spending to complete - await Promise.all(spendPromises); - - // Clear the array to prevent double-spending - collectedUsage.length = 0; - }; - + describe('spendCollectedUsage delegation', () => { it('should return early if collectedUsage is empty', async () => { await spendCollectedUsage({ userId: 'user-123', @@ -145,8 +95,7 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockRecordCollectedUsage).not.toHaveBeenCalled(); }); it('should return early if collectedUsage is null', async () => { @@ -157,28 +106,10 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - 
expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockRecordCollectedUsage).not.toHaveBeenCalled(); }); - it('should skip null entries in collectedUsage', async () => { - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - null, - { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'gpt-4', - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - }); - - it('should spend tokens for single model', async () => { + it('should call recordCollectedUsage with abort context and full deps', async () => { const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; await spendCollectedUsage({ @@ -186,21 +117,35 @@ describe('abortMiddleware - spendCollectedUsage', () => { conversationId: 'convo-123', collectedUsage, fallbackModel: 'gpt-4', + messageId: 'msg-123', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ - context: 'abort', - conversationId: 'convo-123', + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { + spendTokens: expect.any(Function), + spendStructuredTokens: expect.any(Function), + pricing: { + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, + }, + bulkWriteOps: { + insertMany: mockBulkInsertTransactions, + updateBalance: mockUpdateBalance, + }, + }, + { user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + context: 'abort', + messageId: 'msg-123', model: 'gpt-4', - }), - { promptTokens: 100, completionTokens: 50 }, + }, ); }); - it('should spend tokens for multiple models (parallel agents)', async () => { + it('should pass context abort for multiple models (parallel agents)', async () => { const 
collectedUsage = [ { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, @@ -214,136 +159,17 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(3); - - // Verify each model was called - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ model: 'gpt-4' }), - { promptTokens: 100, completionTokens: 50 }, - ); - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ model: 'claude-3' }), - { promptTokens: 80, completionTokens: 40 }, - ); - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 3, - expect.objectContaining({ model: 'gemini-pro' }), - { promptTokens: 120, completionTokens: 60 }, - ); - }); - - it('should use fallbackModel when usage.model is missing', async () => { - const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'fallback-model', - }); - - expect(mockSpendTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'fallback-model' }), + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( expect.any(Object), + expect.objectContaining({ + context: 'abort', + collectedUsage, + }), ); }); - it('should use spendStructuredTokens for OpenAI format cache tokens', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'gpt-4', - input_token_details: { - cache_creation: 20, - cache_read: 10, - }, - }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'gpt-4', - }); - - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).toHaveBeenCalledWith( - 
expect.objectContaining({ model: 'gpt-4', context: 'abort' }), - { - promptTokens: { - input: 100, - write: 20, - read: 10, - }, - completionTokens: 50, - }, - ); - }); - - it('should use spendStructuredTokens for Anthropic format cache tokens', async () => { - const collectedUsage = [ - { - input_tokens: 100, - output_tokens: 50, - model: 'claude-3', - cache_creation_input_tokens: 25, - cache_read_input_tokens: 15, - }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'claude-3', - }); - - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - expect(mockSpendTokens).not.toHaveBeenCalled(); - expect(mockSpendStructuredTokens).toHaveBeenCalledWith( - expect.objectContaining({ model: 'claude-3' }), - { - promptTokens: { - input: 100, - write: 25, - read: 15, - }, - completionTokens: 50, - }, - ); - }); - - it('should handle mixed cache and non-cache entries', async () => { - const collectedUsage = [ - { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, - { - input_tokens: 150, - output_tokens: 30, - model: 'claude-3', - cache_creation_input_tokens: 20, - cache_read_input_tokens: 10, - }, - { input_tokens: 200, output_tokens: 20, model: 'gemini-pro' }, - ]; - - await spendCollectedUsage({ - userId: 'user-123', - conversationId: 'convo-123', - collectedUsage, - fallbackModel: 'gpt-4', - }); - - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); - }); - it('should handle real-world parallel agent abort scenario', async () => { - // Simulates: Primary agent (gemini) + addedConvo agent (gpt-5) aborted mid-stream const collectedUsage = [ { input_tokens: 31596, output_tokens: 151, model: 'gemini-3-flash-preview' }, { input_tokens: 28000, output_tokens: 120, model: 'gpt-5.2' }, @@ -356,27 +182,24 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gemini-3-flash-preview', }); - 
expect(mockSpendTokens).toHaveBeenCalledTimes(2); - - // Primary model - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ model: 'gemini-3-flash-preview' }), - { promptTokens: 31596, completionTokens: 151 }, - ); - - // Parallel model (addedConvo) - expect(mockSpendTokens).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ model: 'gpt-5.2' }), - { promptTokens: 28000, completionTokens: 120 }, + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + user: 'user-123', + conversationId: 'convo-123', + context: 'abort', + model: 'gemini-3-flash-preview', + }), ); }); + /** + * Race condition prevention: after abort middleware spends tokens, + * the collectedUsage array is cleared so AgentClient.recordCollectedUsage() + * (which shares the same array reference) sees an empty array and returns early. + */ it('should clear collectedUsage array after spending to prevent double-spending', async () => { - // This tests the race condition fix: after abort middleware spends tokens, - // the collectedUsage array is cleared so AgentClient.recordCollectedUsage() - // (which shares the same array reference) sees an empty array and returns early. 
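The cleared-array guard described in that comment can be sketched in isolation. This is an illustrative sketch only — `spendOnce` is a hypothetical helper, and the real `spendCollectedUsage` awaits the async spend before truncating; the sync version below just demonstrates why truncating a shared array in place (`arr.length = 0`) makes the second caller's empty-array check fire:

```javascript
// Hypothetical helper (not LibreChat code): both callers hold the SAME array
// object, so truncating it in place is visible through the shared reference.
function spendOnce(collectedUsage, spend) {
  if (!collectedUsage || collectedUsage.length === 0) {
    return false; // second caller sees the cleared array and returns early
  }
  spend(collectedUsage.slice()); // spend against a snapshot
  collectedUsage.length = 0; // truncate in place, not `collectedUsage = []`
  return true;
}

const shared = [{ input_tokens: 100, output_tokens: 50 }];
let spends = 0;
const record = () => {
  spends += 1;
};

const first = spendOnce(shared, record); // abort-middleware path
const second = spendOnce(shared, record); // AgentClient finally-block path
console.log(first, second, spends, shared.length); // true false 1 0
```

Reassigning the variable (`collectedUsage = []`) would not work here: it rebinds only the local reference, so the other holder of the array would still see the original entries and double-spend.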
const collectedUsage = [ { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, @@ -391,19 +214,16 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - expect(mockSpendTokens).toHaveBeenCalledTimes(2); - - // The array should be cleared after spending + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); expect(collectedUsage.length).toBe(0); }); - it('should await all token spending operations before clearing array', async () => { - // Ensure we don't clear the array before spending completes - let spendCallCount = 0; - mockSpendTokens.mockImplementation(async () => { - spendCallCount++; - // Simulate async delay + it('should await recordCollectedUsage before clearing array', async () => { + let resolved = false; + mockRecordCollectedUsage.mockImplementation(async () => { await new Promise((resolve) => setTimeout(resolve, 10)); + resolved = true; + return { input_tokens: 100, output_tokens: 50 }; }); const collectedUsage = [ @@ -418,10 +238,7 @@ describe('abortMiddleware - spendCollectedUsage', () => { fallbackModel: 'gpt-4', }); - // Both spend calls should have completed - expect(spendCallCount).toBe(2); - - // Array should be cleared after awaiting + expect(resolved).toBe(true); expect(collectedUsage.length).toBe(0); }); }); diff --git a/api/server/utils/import/importers.spec.js b/api/server/utils/import/importers.spec.js index a695a31555..2ddfa76658 100644 --- a/api/server/utils/import/importers.spec.js +++ b/api/server/utils/import/importers.spec.js @@ -1277,12 +1277,9 @@ describe('processAssistantMessage', () => { results.push(duration); }); - // Check if processing time increases exponentially - // In a ReDoS vulnerability, time would roughly double with each size increase - for (let i = 1; i < results.length; i++) { - const ratio = results[i] / results[i - 1]; - expect(ratio).toBeLessThan(3); // Allow for CI environment variability while still 
catching ReDoS - console.log(`Size ${sizes[i]} processing time ratio: ${ratio}`); + // Each size should complete well under 100ms; a ReDoS would cause exponential blowup + for (let i = 0; i < results.length; i++) { + expect(results[i]).toBeLessThan(100); } // Also test with the exact payload from the security report diff --git a/package-lock.json b/package-lock.json index c03ef33c8d..2b90bbec3e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -59,7 +59,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -11836,9 +11836,9 @@ } }, "node_modules/@librechat/agents": { - "version": "3.1.53", - "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.53.tgz", - "integrity": "sha512-jK9JHIhQYgr+Ha2FhknEYQmS6Ft3/TGdYIlL6L6EtIq20SIA59r1DvQx/x9sd3wHoHkk6AZumMgqAUTTCaWBIA==", + "version": "3.1.54", + "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.54.tgz", + "integrity": "sha512-OdsE8kDgtIhxs0sR0rG7I5WynbZKAH/j/50OCZEjLdv//jR8Lj6fpL9RCEzRGnu44MjUZgRkSgf2JV3LHsCJiQ==", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "^0.73.0", @@ -43797,7 +43797,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/package.json b/packages/api/package.json index 3ceaeb7a12..f2529ecea5 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -90,7 +90,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.53", + "@librechat/agents": "^3.1.54", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", 
"@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/src/agents/index.ts b/packages/api/src/agents/index.ts index 9d13b3dd8e..47e15b8c28 100644 --- a/packages/api/src/agents/index.ts +++ b/packages/api/src/agents/index.ts @@ -9,6 +9,7 @@ export * from './legacy'; export * from './memory'; export * from './migration'; export * from './openai'; +export * from './transactions'; export * from './usage'; export * from './resources'; export * from './responses'; diff --git a/packages/api/src/agents/transactions.bulk-parity.spec.ts b/packages/api/src/agents/transactions.bulk-parity.spec.ts new file mode 100644 index 0000000000..bf89682d6f --- /dev/null +++ b/packages/api/src/agents/transactions.bulk-parity.spec.ts @@ -0,0 +1,559 @@ +/** + * Real-DB parity tests for the bulk transaction path. + * + * Each test uses the actual getMultiplier/getCacheMultiplier pricing functions + * (the same ones the legacy createTransaction path uses) and runs the bulk path + * against a real MongoMemoryServer instance. + * + * The assertion pattern: compute the expected tokenValue/rate/rawAmount from the + * pricing functions directly, then verify the DB state matches exactly. Since both + * legacy (createTransaction) and bulk (prepareTokenSpend + bulkWriteTransactions) + * call the same pricing functions with the same inputs, their outputs must be + * numerically identical. 
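The parity-assertion pattern described in the doc comment above reduces to simple arithmetic: charge `promptTokens * promptRate + completionTokens * completionRate` and compare against the stored balance. A minimal sketch, using hypothetical multiplier values rather than LibreChat's real rates:

```javascript
// Illustrative only: the multiplier values are made up; in the real tests
// they come from getMultiplier({ model, tokenType, inputTokenCount }).
function expectedBalance({ initial, promptTokens, completionTokens, promptRate, completionRate }) {
  // Same cost formula the legacy and bulk paths must both produce.
  const cost = promptTokens * promptRate + completionTokens * completionRate;
  return initial - cost;
}

const result = expectedBalance({
  initial: 10_000_000,
  promptTokens: 100,
  completionTokens: 50,
  promptRate: 1.5, // hypothetical prompt multiplier
  completionRate: 2, // hypothetical completion multiplier
});
console.log(result); // 10000000 - (150 + 100) = 9999750
```

Because both paths call the same pricing functions with the same inputs, asserting the DB state against this directly computed value proves numerical parity without duplicating either implementation.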
+ */ +import mongoose from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { + CANCEL_RATE, + createMethods, + balanceSchema, + transactionSchema, +} from '@librechat/data-schemas'; +import type { PricingFns, TxMetadata } from './transactions'; +import { + prepareStructuredTokenSpend, + bulkWriteTransactions, + prepareTokenSpend, +} from './transactions'; + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { + ...actual, + logger: { debug: jest.fn(), error: jest.fn(), warn: jest.fn(), info: jest.fn() }, + }; +}); + +// Real pricing functions from api/models/tx.js — same ones the legacy path uses +/* eslint-disable @typescript-eslint/no-require-imports */ +const { + getMultiplier, + getCacheMultiplier, + tokenValues, + premiumTokenValues, +} = require('../../../../api/models/tx.js'); +/* eslint-enable @typescript-eslint/no-require-imports */ + +const pricing: PricingFns = { getMultiplier, getCacheMultiplier }; + +let mongoServer: MongoMemoryServer; +let Transaction: mongoose.Model<unknown>; +let Balance: mongoose.Model<unknown>; +let dbMethods: ReturnType<typeof createMethods>; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + Transaction = mongoose.models.Transaction || mongoose.model('Transaction', transactionSchema); + Balance = mongoose.models.Balance || mongoose.model('Balance', balanceSchema); + dbMethods = createMethods(mongoose); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.connection.dropDatabase(); +}); + +const dbOps = () => ({ + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, +}); + +function txMeta(user: string, extra: Partial<TxMetadata> = {}): TxMetadata { + return { + user, + conversationId: 'test-convo', + context: 'test', + balance: { enabled: true }, + transactions: { enabled: true }, +
...extra, + }; +} + +describe('Standard token parity', () => { + test('balance should decrease by promptCost + completionCost — identical to legacy path', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const promptTokens = 100; + const completionTokens = 50; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: promptTokens, + }); + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + const expectedCost = promptTokens * promptMultiplier + completionTokens * completionMultiplier; + const expectedBalance = initialBalance - expectedCost; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBeCloseTo(expectedBalance, 0); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + expect(txns).toHaveLength(2); + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + const completionTx = txns.find((t) => t.tokenType === 'completion'); + expect(promptTx!.rawAmount).toBe(-promptTokens); + expect(promptTx!.rate).toBe(promptMultiplier); + expect(promptTx!.tokenValue).toBe(-promptTokens * promptMultiplier); + expect(completionTx!.rawAmount).toBe(-completionTokens); + expect(completionTx!.rate).toBe(completionMultiplier); + expect(completionTx!.tokenValue).toBe(-completionTokens * completionMultiplier); + }); + + test('balance unchanged when balance.enabled is false — identical to legacy path', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await
Balance.create({ user: userId, tokenCredits: initialBalance }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo', balance: { enabled: false } }), + { promptTokens: 100, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBe(initialBalance); + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); // transactions still inserted + }); + + test('no docs when transactions.enabled is false — identical to legacy path', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo', transactions: { enabled: false } }), + { promptTokens: 100, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(0); + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('abort context — transactions inserted, no balance update when balance not passed', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const entries = prepareTokenSpend( + txMeta(userId, { model, context: 'abort', balance: undefined }), + { promptTokens: 100, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(2); + const balance =
(await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBe(initialBalance); + }); + + test('NaN promptTokens — only completion doc inserted, identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo' }), + { promptTokens: NaN, completionTokens: 50 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + expect(txns).toHaveLength(1); + expect(txns[0].tokenType).toBe('completion'); + }); + + test('zero tokens produce docs with rawAmount=0, tokenValue=0', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 10000 }); + + const entries = prepareTokenSpend( + txMeta(userId, { model: 'gpt-3.5-turbo' }), + { promptTokens: 0, completionTokens: 0 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + expect(txns).toHaveLength(2); + expect(txns.every((t) => t.rawAmount === 0)).toBe(true); + expect(txns.every((t) => t.tokenValue === 0)).toBe(true); + }); +}); + +describe('CANCEL_RATE parity (incomplete context)', () => { + test('CANCEL_RATE applied to completion token — same tokenValue as legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 10000000 }); + + const model = 'claude-3-5-sonnet'; + const completionTokens = 50; + const promptTokens = 10; + + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: promptTokens, + }); + const expectedCompletionTokenValue = Math.ceil(
-completionTokens * completionMultiplier * CANCEL_RATE, + ); + + const entries = prepareTokenSpend( + txMeta(userId, { model, context: 'incomplete' }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + const completionTx = txns.find((t) => t.tokenType === 'completion'); + expect(completionTx!.tokenValue).toBe(expectedCompletionTokenValue); + expect(completionTx!.rate).toBeCloseTo(completionMultiplier * CANCEL_RATE, 5); + }); + + test('CANCEL_RATE NOT applied to prompt tokens in incomplete context', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 10000000 }); + + const model = 'claude-3-5-sonnet'; + const promptTokens = 100; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: promptTokens, + }); + + const entries = prepareTokenSpend( + txMeta(userId, { model, context: 'incomplete' }), + { promptTokens, completionTokens: 0 }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx!.rate).toBe(promptMultiplier); // no CANCEL_RATE + }); +}); + +describe('Structured token parity', () => { + test('balance deduction identical to legacy spendStructuredTokens', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 17613154.55; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const tokenUsage = { + promptTokens: { input: 11, write: 140522, read: 0 }, + completionTokens: 5, + }; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: 11 + 140522, + }); + const
completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: 11 + 140522, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; + + const expectedPromptCost = + tokenUsage.promptTokens.input * promptMultiplier + + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; + const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + const expectedBalance = initialBalance - expectedTotalCost; + + const entries = prepareStructuredTokenSpend(txMeta(userId, { model }), tokenUsage, pricing); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(Math.abs((balance.tokenCredits as number) - expectedBalance)).toBeLessThan(100); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx!.inputTokens).toBe(-11); + expect(promptTx!.writeTokens).toBe(-140522); + expect(Math.abs(Number(promptTx!.readTokens ?? 0))).toBe(0); + }); + + test('structured tokens with both cache_creation and cache_read', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-3-5-sonnet'; + const tokenUsage = { + promptTokens: { input: 100, write: 50, read: 30 }, + completionTokens: 80, + }; + const totalInput = 180; + + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? 
promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: totalInput, + }); + + const expectedPromptCost = 100 * promptMultiplier + 50 * writeMultiplier + 30 * readMultiplier; + const expectedCost = expectedPromptCost + 80 * completionMultiplier; + + const entries = prepareStructuredTokenSpend(txMeta(userId, { model }), tokenUsage, pricing); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + expect(txns).toHaveLength(2); + const promptTx = txns.find((t) => t.tokenType === 'prompt'); + expect(promptTx!.inputTokens).toBe(-100); + expect(promptTx!.writeTokens).toBe(-50); + expect(promptTx!.readTokens).toBe(-30); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect( + Math.abs((balance.tokenCredits as number) - (initialBalance - expectedCost)), + ).toBeLessThan(1); + }); + + test('CANCEL_RATE applied to completion in structured incomplete context', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + await Balance.create({ user: userId, tokenCredits: 17613154.55 }); + + const model = 'claude-3-5-sonnet'; + const tokenUsage = { + promptTokens: { input: 10, write: 100, read: 5 }, + completionTokens: 50, + }; + + const completionMultiplier = getMultiplier({ + model, + tokenType: 'completion', + inputTokenCount: 115, + }); + const expectedCompletionTokenValue = Math.ceil(-50 * completionMultiplier * CANCEL_RATE); + + const entries = prepareStructuredTokenSpend( + txMeta(userId, { model, context: 'incomplete' }), + tokenUsage, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const txns = (await Transaction.find({ user: userId }).lean()) as Record<string, unknown>[]; + const completionTx = txns.find((t) => t.tokenType === 
'completion'); + expect(completionTx!.tokenValue).toBeCloseTo(expectedCompletionTokenValue, 0); + }); +}); + +describe('Premium pricing parity', () => { + test('standard pricing below threshold — identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 100000; + const completionTokens = 500; + + const standardPromptRate = (tokenValues as Record<string, Record<string, number>>)[model] + .prompt; + const standardCompletionRate = (tokenValues as Record<string, Record<string, number>>)[model] + .completion; + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('premium pricing above threshold — identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + const premiumPromptRate = (premiumTokenValues as Record<string, Record<string, number>>)[model] + .prompt; + const premiumCompletionRate = (premiumTokenValues as Record<string, Record<string, number>>)[ + model + ].completion; + const expectedCost = + promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) 
as Record<string, unknown>; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('standard pricing at exactly the threshold — identical to legacy', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = (premiumTokenValues as Record<string, Record<string, number>>)[model] + .threshold; + const completionTokens = 500; + + const standardPromptRate = (tokenValues as Record<string, Record<string, number>>)[model] + .prompt; + const standardCompletionRate = (tokenValues as Record<string, Record<string, number>>)[model] + .completion; + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + await bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); +}); + +describe('Multi-entry batch parity', () => { + test('real-world sequential tool calls — total balance deduction identical to N individual legacy calls', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-5-20251101'; + const calls = [ + { promptTokens: 31596, completionTokens: 151 }, + { promptTokens: 35368, completionTokens: 150 }, + { promptTokens: 58362, completionTokens: 295 }, + { promptTokens: 112604, completionTokens: 193 }, + { promptTokens: 257440, completionTokens: 2217 }, + ]; + + let expectedTotalCost = 0; + const allEntries = []; + for (const { promptTokens, completionTokens } of calls) { + const pm = getMultiplier({ model, tokenType: 'prompt', inputTokenCount: promptTokens }); + const cm = 
getMultiplier({ model, tokenType: 'completion', inputTokenCount: promptTokens }); + expectedTotalCost += promptTokens * pm + completionTokens * cm; + const entries = prepareTokenSpend( + txMeta(userId, { model }), + { promptTokens, completionTokens }, + pricing, + ); + allEntries.push(...entries); + } + + await bulkWriteTransactions({ user: userId, docs: allEntries }, dbOps()); + + const txns = await Transaction.find({ user: userId }).lean(); + expect(txns).toHaveLength(10); // 5 calls × 2 docs + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); + + test('structured premium above threshold — batch vs individual produce same balance deduction', async () => { + const userId = new mongoose.Types.ObjectId().toString(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const tokenUsage = { + promptTokens: { input: 200000, write: 10000, read: 5000 }, + completionTokens: 1000, + }; + const totalInput = 215000; + + const premiumPromptRate = (premiumTokenValues as Record<string, Record<string, number>>)[model] + .prompt; + const premiumCompletionRate = (premiumTokenValues as Record<string, Record<string, number>>)[ + model + ].completion; + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + + const expectedPromptCost = + tokenUsage.promptTokens.input * premiumPromptRate + + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; + const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + + expect(totalInput).toBeGreaterThan( + (premiumTokenValues as Record<string, Record<string, number>>)[model].threshold, + ); + + const entries = prepareStructuredTokenSpend(txMeta(userId, { model }), tokenUsage, pricing); + await 
bulkWriteTransactions({ user: userId, docs: entries }, dbOps()); + + const balance = (await Balance.findOne({ user: userId }).lean()) as Record<string, unknown>; + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); +}); diff --git a/packages/api/src/agents/transactions.spec.ts b/packages/api/src/agents/transactions.spec.ts new file mode 100644 index 0000000000..99fb7cdd85 --- /dev/null +++ b/packages/api/src/agents/transactions.spec.ts @@ -0,0 +1,474 @@ +import mongoose from 'mongoose'; +import { MongoMemoryServer } from 'mongodb-memory-server'; +import { + CANCEL_RATE, + createMethods, + balanceSchema, + transactionSchema, +} from '@librechat/data-schemas'; +import type { PricingFns, TxMetadata, PreparedEntry } from './transactions'; +import { + prepareStructuredTokenSpend, + bulkWriteTransactions, + prepareTokenSpend, +} from './transactions'; + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { + ...actual, + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + info: jest.fn(), + }, + }; +}); + +let mongoServer: MongoMemoryServer; +let Transaction: mongoose.Model<unknown>; +let Balance: mongoose.Model<unknown>; +let dbMethods: ReturnType<typeof createMethods>; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + Transaction = mongoose.models.Transaction || mongoose.model('Transaction', transactionSchema); + Balance = mongoose.models.Balance || mongoose.model('Balance', balanceSchema); + dbMethods = createMethods(mongoose); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.connection.dropDatabase(); +}); + +const testUserId = new mongoose.Types.ObjectId().toString(); + +const baseTxData: TxMetadata = { + user: testUserId, + context: 'message', + conversationId: 'convo-123', + model: 'gpt-4', + messageId: 'msg-123', + balance: { enabled: 
true }, + transactions: { enabled: true }, +}; + +const mockPricing: PricingFns = { + getMultiplier: jest.fn().mockReturnValue(2), + getCacheMultiplier: jest.fn().mockReturnValue(null), +}; + +describe('prepareTokenSpend', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should prepare prompt + completion entries', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(2); + expect(entries[0].doc.tokenType).toBe('prompt'); + expect(entries[1].doc.tokenType).toBe('completion'); + }); + + it('should return empty array when transactions disabled', () => { + const txData = { ...baseTxData, transactions: { enabled: false } }; + const entries = prepareTokenSpend( + txData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(0); + }); + + it('should filter out NaN rawAmount entries', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: NaN, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(1); + expect(entries[0].doc.tokenType).toBe('completion'); + }); + + it('should handle promptTokens only', () => { + const entries = prepareTokenSpend(baseTxData, { promptTokens: 100 }, mockPricing); + expect(entries).toHaveLength(1); + expect(entries[0].doc.tokenType).toBe('prompt'); + }); + + it('should handle completionTokens only', () => { + const entries = prepareTokenSpend(baseTxData, { completionTokens: 50 }, mockPricing); + expect(entries).toHaveLength(1); + expect(entries[0].doc.tokenType).toBe('completion'); + }); + + it('should handle zero tokens', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 0, completionTokens: 0 }, + mockPricing, + ); + expect(entries).toHaveLength(2); + expect(entries[0].doc.rawAmount).toBe(0); + expect(entries[1].doc.rawAmount).toBe(0); + }); + + it('should calculate tokenValue using pricing multiplier', () => { 
+ (mockPricing.getMultiplier as jest.Mock).mockReturnValue(3); + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + expect(entries[0].doc.rate).toBe(3); + expect(entries[0].doc.tokenValue).toBe(-100 * 3); + expect(entries[1].doc.rate).toBe(3); + expect(entries[1].doc.tokenValue).toBe(-50 * 3); + }); + + it('should pass valueKey to getMultiplier', () => { + prepareTokenSpend(baseTxData, { promptTokens: 100 }, mockPricing); + expect(mockPricing.getMultiplier).toHaveBeenCalledWith( + expect.objectContaining({ tokenType: 'prompt', model: 'gpt-4' }), + ); + }); + + it('should carry balance config on each entry', () => { + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + for (const entry of entries) { + expect(entry.balance).toEqual({ enabled: true }); + } + }); +}); + +describe('prepareTokenSpend — CANCEL_RATE', () => { + beforeEach(() => { + jest.clearAllMocks(); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(2); + }); + + it('should apply CANCEL_RATE to completion tokens with incomplete context', () => { + const txData: TxMetadata = { ...baseTxData, context: 'incomplete' }; + const entries = prepareTokenSpend( + txData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + const completion = entries.find((e) => e.doc.tokenType === 'completion'); + expect(completion).toBeDefined(); + expect(completion!.doc.rate).toBe(2 * CANCEL_RATE); + expect(completion!.doc.tokenValue).toBe(Math.ceil(-50 * 2 * CANCEL_RATE)); + }); + + it('should NOT apply CANCEL_RATE to prompt tokens with incomplete context', () => { + const txData: TxMetadata = { ...baseTxData, context: 'incomplete' }; + const entries = prepareTokenSpend( + txData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + const prompt = entries.find((e) => e.doc.tokenType === 'prompt'); + expect(prompt!.doc.rate).toBe(2); + }); + + it('should 
NOT apply CANCEL_RATE for abort context', () => { + const txData: TxMetadata = { ...baseTxData, context: 'abort' }; + const entries = prepareTokenSpend(txData, { completionTokens: 50 }, mockPricing); + expect(entries[0].doc.rate).toBe(2); + }); +}); + +describe('prepareStructuredTokenSpend', () => { + beforeEach(() => { + jest.clearAllMocks(); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(2); + (mockPricing.getCacheMultiplier as jest.Mock).mockReturnValue(null); + }); + + it('should prepare prompt + completion for structured tokens', () => { + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 100, write: 50, read: 30 }, completionTokens: 80 }, + mockPricing, + ); + expect(entries).toHaveLength(2); + expect(entries[0].doc.tokenType).toBe('prompt'); + expect(entries[0].doc.inputTokens).toBe(-100); + expect(entries[0].doc.writeTokens).toBe(-50); + expect(entries[0].doc.readTokens).toBe(-30); + expect(entries[1].doc.tokenType).toBe('completion'); + }); + + it('should use cache multipliers when available', () => { + (mockPricing.getCacheMultiplier as jest.Mock).mockImplementation(({ cacheType }) => { + if (cacheType === 'write') { + return 5; + } + if (cacheType === 'read') { + return 0.5; + } + return null; + }); + + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 100, write: 50, read: 30 }, completionTokens: 0 }, + mockPricing, + ); + const prompt = entries.find((e) => e.doc.tokenType === 'prompt'); + expect(prompt).toBeDefined(); + expect(prompt!.doc.rateDetail).toEqual({ input: 2, write: 5, read: 0.5 }); + }); + + it('should return empty when transactions disabled', () => { + const txData = { ...baseTxData, transactions: { enabled: false } }; + const entries = prepareStructuredTokenSpend( + txData, + { promptTokens: { input: 100 }, completionTokens: 50 }, + mockPricing, + ); + expect(entries).toHaveLength(0); + }); + + it('should handle zero totalPromptTokens (fallback rate)', () 
=> { + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 0, write: 0, read: 0 }, completionTokens: 50 }, + mockPricing, + ); + const prompt = entries.find((e) => e.doc.tokenType === 'prompt'); + expect(prompt).toBeDefined(); + expect(prompt!.doc.rate).toBe(2); + }); +}); + +describe('bulkWriteTransactions (real DB)', () => { + it('should return early for empty docs without DB writes', async () => { + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs: [] }, dbOps); + const txCount = await Transaction.countDocuments(); + expect(txCount).toBe(0); + }); + + it('should insert transaction documents into MongoDB', async () => { + const docs: PreparedEntry[] = [ + { + doc: { + user: testUserId, + conversationId: 'c1', + tokenType: 'prompt', + tokenValue: -200, + rate: 2, + rawAmount: -100, + }, + tokenValue: -200, + balance: { enabled: true }, + }, + { + doc: { + user: testUserId, + conversationId: 'c1', + tokenType: 'completion', + tokenValue: -100, + rate: 2, + rawAmount: -50, + }, + tokenValue: -100, + balance: { enabled: true }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const saved = await Transaction.find({ user: testUserId }).lean(); + expect(saved).toHaveLength(2); + expect(saved.map((t: Record) => t.tokenType).sort()).toEqual([ + 'completion', + 'prompt', + ]); + }); + + it('should create balance document and update credits', async () => { + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -300 }, + tokenValue: -300, + balance: { enabled: true }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ 
user: testUserId, docs }, dbOps); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal).toBeDefined(); + expect(bal!.tokenCredits).toBe(0); + }); + + it('should NOT update balance when no docs have balance enabled', async () => { + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -100 }, + tokenValue: -100, + balance: { enabled: false }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const txCount = await Transaction.countDocuments({ user: testUserId }); + expect(txCount).toBe(1); + const bal = await Balance.findOne({ user: testUserId }).lean(); + expect(bal).toBeNull(); + }); + + it('should only sum tokenValue from balance-enabled docs', async () => { + await Balance.create({ user: testUserId, tokenCredits: 1000 }); + + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -100 }, + tokenValue: -100, + balance: { enabled: true }, + }, + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'completion', tokenValue: -50 }, + tokenValue: -50, + balance: { enabled: false }, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal!.tokenCredits).toBe(900); + }); + + it('should handle null balance gracefully', async () => { + const docs: PreparedEntry[] = [ + { + doc: { user: testUserId, conversationId: 'c1', tokenType: 'prompt', tokenValue: -100 }, + tokenValue: -100, + balance: null, + }, + ]; + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + 
updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs }, dbOps); + + const txCount = await Transaction.countDocuments({ user: testUserId }); + expect(txCount).toBe(1); + const bal = await Balance.findOne({ user: testUserId }).lean(); + expect(bal).toBeNull(); + }); +}); + +describe('end-to-end: prepare → bulk write → verify', () => { + it('should prepare, write, and correctly update balance for standard tokens', async () => { + await Balance.create({ user: testUserId, tokenCredits: 10000 }); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(2); + + const entries = prepareTokenSpend( + baseTxData, + { promptTokens: 100, completionTokens: 50 }, + mockPricing, + ); + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs: entries }, dbOps); + + const txns = (await Transaction.find({ user: testUserId }).lean()) as Record<string, unknown>[]; + expect(txns).toHaveLength(2); + + const prompt = txns.find((t) => t.tokenType === 'prompt'); + const completion = txns.find((t) => t.tokenType === 'completion'); + expect(prompt!.tokenValue).toBe(-200); + expect(prompt!.rate).toBe(2); + expect(completion!.tokenValue).toBe(-100); + expect(completion!.rate).toBe(2); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal!.tokenCredits).toBe(10000 + -200 + -100); + }); + + it('should prepare and write structured tokens with cache pricing', async () => { + await Balance.create({ user: testUserId, tokenCredits: 5000 }); + (mockPricing.getMultiplier as jest.Mock).mockReturnValue(1); + (mockPricing.getCacheMultiplier as jest.Mock).mockImplementation(({ cacheType }) => { + if (cacheType === 'write') { + return 3; + } + if (cacheType === 'read') { + return 0.1; + } + return null; + }); + + const entries = prepareStructuredTokenSpend( + baseTxData, + { promptTokens: { input: 
100, write: 50, read: 200 }, completionTokens: 80 }, + mockPricing, + ); + const dbOps = { + insertMany: dbMethods.bulkInsertTransactions, + updateBalance: dbMethods.updateBalance, + }; + await bulkWriteTransactions({ user: testUserId, docs: entries }, dbOps); + + const txns = (await Transaction.find({ user: testUserId }).lean()) as Record<string, unknown>[]; + expect(txns).toHaveLength(2); + + const prompt = txns.find((t) => t.tokenType === 'prompt'); + expect(prompt!.inputTokens).toBe(-100); + expect(prompt!.writeTokens).toBe(-50); + expect(prompt!.readTokens).toBe(-200); + + const bal = (await Balance.findOne({ user: testUserId }).lean()) as Record< + string, + unknown + > | null; + expect(bal!.tokenCredits).toBeLessThan(5000); + }); +}); diff --git a/packages/api/src/agents/transactions.ts b/packages/api/src/agents/transactions.ts new file mode 100644 index 0000000000..b746392b44 --- /dev/null +++ b/packages/api/src/agents/transactions.ts @@ -0,0 +1,345 @@ +import { CANCEL_RATE } from '@librechat/data-schemas'; +import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider'; +import type { TransactionData } from '@librechat/data-schemas'; +import type { EndpointTokenConfig } from '~/types/tokens'; + +interface GetMultiplierParams { + valueKey?: string; + tokenType?: string; + model?: string; + endpointTokenConfig?: EndpointTokenConfig; + inputTokenCount?: number; +} + +interface GetCacheMultiplierParams { + cacheType: 'write' | 'read'; + model?: string; + endpointTokenConfig?: EndpointTokenConfig; +} + +export interface PricingFns { + getMultiplier: (params: GetMultiplierParams) => number; + getCacheMultiplier: (params: GetCacheMultiplierParams) => number | null; +} + +interface BaseTxData { + user: string; + model?: string; + context: string; + messageId?: string; + conversationId: string; + endpointTokenConfig?: EndpointTokenConfig; + balance?: Partial<TCustomConfig['balance']> | null; + transactions?: Partial<TTransactionsConfig>; +} + +interface StandardTxData extends BaseTxData { + tokenType: string; 
+ rawAmount: number; + inputTokenCount?: number; + valueKey?: string; +} + +interface StructuredTxData extends BaseTxData { + tokenType: string; + inputTokens?: number; + writeTokens?: number; + readTokens?: number; + inputTokenCount?: number; + rawAmount?: number; +} + +export interface PreparedEntry { + doc: TransactionData; + tokenValue: number; + balance?: Partial<TCustomConfig['balance']> | null; +} + +export interface TokenUsage { + promptTokens?: number; + completionTokens?: number; +} + +export interface StructuredPromptTokens { + input?: number; + write?: number; + read?: number; +} + +export interface StructuredTokenUsage { + promptTokens?: StructuredPromptTokens; + completionTokens?: number; +} + +export interface TxMetadata { + user: string; + model?: string; + context: string; + messageId?: string; + conversationId: string; + balance?: Partial<TCustomConfig['balance']> | null; + transactions?: Partial<TTransactionsConfig>; + endpointTokenConfig?: EndpointTokenConfig; +} + +export interface BulkWriteDeps { + insertMany: (docs: TransactionData[]) => Promise<unknown>; + updateBalance: (params: { user: string; incrementValue: number }) => Promise<unknown>; +} + +function calculateTokenValue( + txData: StandardTxData, + pricing: PricingFns, +): { tokenValue: number; rate: number } { + const { tokenType, model, endpointTokenConfig, inputTokenCount, rawAmount, valueKey } = txData; + const multiplier = Math.abs( + pricing.getMultiplier({ valueKey, tokenType, model, endpointTokenConfig, inputTokenCount }), + ); + let rate = multiplier; + let tokenValue = rawAmount * multiplier; + if (txData.context === 'incomplete' && tokenType === 'completion') { + tokenValue = Math.ceil(tokenValue * CANCEL_RATE); + rate *= CANCEL_RATE; + } + return { tokenValue, rate }; +} + +function calculateStructuredTokenValue( + txData: StructuredTxData, + pricing: PricingFns, +): { tokenValue: number; rate: number; rawAmount: number; rateDetail?: Record<string, number> } { + const { tokenType, model, endpointTokenConfig, inputTokenCount } = txData; + + if (!tokenType) { + return { 
tokenValue: txData.rawAmount ?? 0, rate: 0, rawAmount: txData.rawAmount ?? 0 }; + } + + if (tokenType === 'prompt') { + const inputMultiplier = pricing.getMultiplier({ + tokenType: 'prompt', + model, + endpointTokenConfig, + inputTokenCount, + }); + const writeMultiplier = + pricing.getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? + inputMultiplier; + const readMultiplier = + pricing.getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? + inputMultiplier; + + const inputAbs = Math.abs(txData.inputTokens ?? 0); + const writeAbs = Math.abs(txData.writeTokens ?? 0); + const readAbs = Math.abs(txData.readTokens ?? 0); + const totalPromptTokens = inputAbs + writeAbs + readAbs; + + const rate = + totalPromptTokens > 0 + ? (Math.abs(inputMultiplier * (txData.inputTokens ?? 0)) + + Math.abs(writeMultiplier * (txData.writeTokens ?? 0)) + + Math.abs(readMultiplier * (txData.readTokens ?? 0))) / + totalPromptTokens + : Math.abs(inputMultiplier); + + const tokenValue = -( + inputAbs * inputMultiplier + + writeAbs * writeMultiplier + + readAbs * readMultiplier + ); + + return { + tokenValue, + rate, + rawAmount: -totalPromptTokens, + rateDetail: { input: inputMultiplier, write: writeMultiplier, read: readMultiplier }, + }; + } + + const multiplier = pricing.getMultiplier({ + tokenType, + model, + endpointTokenConfig, + inputTokenCount, + }); + const rawAmount = -Math.abs(txData.rawAmount ?? 
0); + let rate = Math.abs(multiplier); + let tokenValue = rawAmount * multiplier; + + if (txData.context === 'incomplete' && tokenType === 'completion') { + tokenValue = Math.ceil(tokenValue * CANCEL_RATE); + rate *= CANCEL_RATE; + } + + return { tokenValue, rate, rawAmount }; +} + +function prepareStandardTx( + _txData: StandardTxData & { + balance?: Partial<TCustomConfig['balance']> | null; + transactions?: Partial<TTransactionsConfig>; + }, + pricing: PricingFns, +): PreparedEntry | null { + const { balance, transactions, ...txData } = _txData; + if (txData.rawAmount != null && isNaN(txData.rawAmount)) { + return null; + } + if (transactions?.enabled === false) { + return null; + } + + const { tokenValue, rate } = calculateTokenValue(txData, pricing); + return { + doc: { ...txData, tokenValue, rate }, + tokenValue, + balance, + }; +} + +function prepareStructuredTx( + _txData: StructuredTxData & { + balance?: Partial<TCustomConfig['balance']> | null; + transactions?: Partial<TTransactionsConfig>; + }, + pricing: PricingFns, +): PreparedEntry | null { + const { balance, transactions, ...txData } = _txData; + if (transactions?.enabled === false) { + return null; + } + + const { tokenValue, rate, rawAmount, rateDetail } = calculateStructuredTokenValue( + txData, + pricing, + ); + return { + doc: { + ...txData, + tokenValue, + rate, + rawAmount, + ...(rateDetail && { rateDetail }), + }, + tokenValue, + balance, + }; + } + +export function prepareTokenSpend( + txData: TxMetadata, + tokenUsage: TokenUsage, + pricing: PricingFns, +): PreparedEntry[] { + const { promptTokens, completionTokens } = tokenUsage; + const results: PreparedEntry[] = []; + const normalizedPromptTokens = Math.max(promptTokens ?? 0, 0); + + if (promptTokens !== undefined) { + const entry = prepareStandardTx( + { + ...txData, + tokenType: 'prompt', + rawAmount: promptTokens === 0 ? 
0 : -normalizedPromptTokens, + inputTokenCount: normalizedPromptTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + if (completionTokens !== undefined) { + const entry = prepareStandardTx( + { + ...txData, + tokenType: 'completion', + rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0), + inputTokenCount: normalizedPromptTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + return results; +} + +export function prepareStructuredTokenSpend( + txData: TxMetadata, + tokenUsage: StructuredTokenUsage, + pricing: PricingFns, +): PreparedEntry[] { + const { promptTokens, completionTokens } = tokenUsage; + const results: PreparedEntry[] = []; + + if (promptTokens) { + const input = Math.max(promptTokens.input ?? 0, 0); + const write = Math.max(promptTokens.write ?? 0, 0); + const read = Math.max(promptTokens.read ?? 0, 0); + const totalInputTokens = input + write + read; + const entry = prepareStructuredTx( + { + ...txData, + tokenType: 'prompt', + inputTokens: -input, + writeTokens: -write, + readTokens: -read, + inputTokenCount: totalInputTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + if (completionTokens) { + const totalInputTokens = promptTokens + ? Math.max(promptTokens.input ?? 0, 0) + + Math.max(promptTokens.write ?? 0, 0) + + Math.max(promptTokens.read ?? 
0, 0) + : undefined; + const entry = prepareStandardTx( + { + ...txData, + tokenType: 'completion', + rawAmount: -Math.max(completionTokens, 0), + inputTokenCount: totalInputTokens, + }, + pricing, + ); + if (entry) { + results.push(entry); + } + } + + return results; +} + +export async function bulkWriteTransactions( + { user, docs }: { user: string; docs: PreparedEntry[] }, + dbOps: BulkWriteDeps, +): Promise { + if (!docs.length) { + return; + } + + let totalTokenValue = 0; + let balanceEnabled = false; + const plainDocs = docs.map(({ doc, tokenValue, balance }) => { + if (balance?.enabled) { + balanceEnabled = true; + totalTokenValue += tokenValue; + } + return doc; + }); + + if (balanceEnabled) { + await dbOps.updateBalance({ user, incrementValue: totalTokenValue }); + } + + await dbOps.insertMany(plainDocs); +} diff --git a/packages/api/src/agents/usage.bulk-parity.spec.ts b/packages/api/src/agents/usage.bulk-parity.spec.ts new file mode 100644 index 0000000000..79dd50b2e3 --- /dev/null +++ b/packages/api/src/agents/usage.bulk-parity.spec.ts @@ -0,0 +1,533 @@ +/** + * Bulk path parity tests for recordCollectedUsage. + * + * Every test here mirrors a corresponding legacy-path test in usage.spec.ts. + * The return values (input_tokens, output_tokens) must be identical between paths. + * The docs written to insertMany must carry the same metadata as the args that + * would have been passed to spendTokens/spendStructuredTokens. 
+ */ +import type { UsageMetadata } from '../stream/interfaces/IJobStore'; +import type { RecordUsageDeps, RecordUsageParams } from './usage'; +import type { BulkWriteDeps, PricingFns } from './transactions'; +import { recordCollectedUsage } from './usage'; + +describe('recordCollectedUsage — bulk path parity', () => { + let mockSpendTokens: jest.Mock; + let mockSpendStructuredTokens: jest.Mock; + let mockInsertMany: jest.Mock; + let mockUpdateBalance: jest.Mock; + let mockPricing: PricingFns; + let mockBulkWriteOps: BulkWriteDeps; + let deps: RecordUsageDeps; + + const baseParams: Omit = { + user: 'user-123', + conversationId: 'convo-123', + model: 'gpt-4', + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + }; + + beforeEach(() => { + jest.clearAllMocks(); + mockSpendTokens = jest.fn().mockResolvedValue(undefined); + mockSpendStructuredTokens = jest.fn().mockResolvedValue(undefined); + mockInsertMany = jest.fn().mockResolvedValue(undefined); + mockUpdateBalance = jest.fn().mockResolvedValue({}); + mockPricing = { + getMultiplier: jest.fn().mockReturnValue(1), + getCacheMultiplier: jest.fn().mockReturnValue(null), + }; + mockBulkWriteOps = { + insertMany: mockInsertMany, + updateBalance: mockUpdateBalance, + }; + deps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: mockPricing, + bulkWriteOps: mockBulkWriteOps, + }; + }); + + describe('basic functionality', () => { + it('should return undefined if collectedUsage is empty', async () => { + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage: [] }); + expect(result).toBeUndefined(); + expect(mockInsertMany).not.toHaveBeenCalled(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + + it('should return undefined if collectedUsage is null-ish', async () => { + const result = await recordCollectedUsage(deps, { + ...baseParams, + collectedUsage: null as unknown as UsageMetadata[], + }); + 
expect(result).toBeUndefined(); + expect(mockInsertMany).not.toHaveBeenCalled(); + }); + + it('should handle single usage entry — same return value as legacy path', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockInsertMany).toHaveBeenCalledTimes(1); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(2); + const promptDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'prompt'); + const completionDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'completion'); + expect(promptDoc.user).toBe('user-123'); + expect(promptDoc.conversationId).toBe('convo-123'); + expect(promptDoc.model).toBe('gpt-4'); + expect(promptDoc.context).toBe('message'); + expect(promptDoc.rawAmount).toBe(-100); + expect(completionDoc.rawAmount).toBe(-50); + }); + + it('should skip null entries — same return value as legacy path', async () => { + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + null, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + ] as UsageMetadata[]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 110 }); + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(4); // 2 non-null entries × 2 docs each + }); + }); + + describe('sequential execution (tool calls)', () => { + it('should calculate tokens correctly for sequential tool calls — same totals as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 150, output_tokens: 30, 
model: 'gpt-4' }, + { input_tokens: 180, output_tokens: 20, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(100); // 50 + 30 + 20 + expect(result?.input_tokens).toBe(100); // first entry's input + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(6); // 3 entries × 2 docs + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('parallel execution (multiple agents)', () => { + it('should handle parallel agents — same output_tokens total as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 80, output_tokens: 40, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(90); // 50 + 40 + expect(result?.output_tokens).toBeGreaterThan(0); + expect(mockInsertMany).toHaveBeenCalledTimes(1); + }); + + /** Bug regression: parallel agents where second agent has LOWER input tokens produced negative output via incremental calculation. 
*/ + it('should NOT produce negative output_tokens — same positive result as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 200, output_tokens: 100, model: 'gpt-4' }, + { input_tokens: 50, output_tokens: 30, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBeGreaterThan(0); + expect(result?.output_tokens).toBe(130); // 100 + 30 + }); + + it('should calculate correct total output for 3 parallel agents', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' }, + { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(150); // 50 + 60 + 40 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(6); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('cache token handling - OpenAI format', () => { + it('should route cache entries to structured path — same input_tokens as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', + input_token_details: { cache_creation: 20, cache_read: 10 }, + }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(130); // 100 + 20 + 10 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + + const docs = mockInsertMany.mock.calls[0][0]; + const promptDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'prompt'); + expect(promptDoc.inputTokens).toBe(-100); + 
expect(promptDoc.writeTokens).toBe(-20); + expect(promptDoc.readTokens).toBe(-10); + expect(promptDoc.model).toBe('gpt-4'); + }); + }); + + describe('cache token handling - Anthropic format', () => { + it('should route Anthropic cache entries to structured path — same input_tokens as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'claude-3', + cache_creation_input_tokens: 25, + cache_read_input_tokens: 15, + }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(140); // 100 + 25 + 15 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + + const docs = mockInsertMany.mock.calls[0][0]; + const promptDoc = docs.find((d: { tokenType: string }) => d.tokenType === 'prompt'); + expect(promptDoc.inputTokens).toBe(-100); + expect(promptDoc.writeTokens).toBe(-25); + expect(promptDoc.readTokens).toBe(-15); + expect(promptDoc.model).toBe('claude-3'); + }); + }); + + describe('mixed cache and non-cache entries', () => { + it('should handle mixed entries — same output_tokens as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + { input_tokens: 200, output_tokens: 20, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.output_tokens).toBe(100); // 50 + 30 + 20 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(6); // 3 entries × 2 docs each + }); + }); + + describe('model fallback', () => { + it('should 
use usage.model when available — model lands in doc', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: 'fallback-model', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].model).toBe('gpt-4-turbo'); + }); + + it('should fallback to param model when usage.model is missing — model lands in doc', async () => { + const collectedUsage: UsageMetadata[] = [{ input_tokens: 100, output_tokens: 50 }]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: 'param-model', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].model).toBe('param-model'); + }); + + it('should fallback to undefined model when both usage.model and param model are missing', async () => { + const collectedUsage: UsageMetadata[] = [{ input_tokens: 100, output_tokens: 50 }]; + + await recordCollectedUsage(deps, { + ...baseParams, + model: undefined, + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].model).toBeUndefined(); + }); + }); + + describe('real-world scenarios', () => { + it('should correctly sum output tokens for sequential tool calls with growing context', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 31596, output_tokens: 151, model: 'claude-opus' }, + { input_tokens: 35368, output_tokens: 150, model: 'claude-opus' }, + { input_tokens: 58362, output_tokens: 295, model: 'claude-opus' }, + { input_tokens: 112604, output_tokens: 193, model: 'claude-opus' }, + { input_tokens: 257440, output_tokens: 2217, model: 'claude-opus' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(31596); + expect(result?.output_tokens).toBe(3006); // 151+150+295+193+2217 + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const docs = 
mockInsertMany.mock.calls[0][0]; + expect(docs).toHaveLength(10); // 5 entries × 2 docs + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + + it('should handle cache tokens with multiple tool calls — same totals as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 788, + output_tokens: 163, + model: 'claude-opus', + input_token_details: { cache_read: 0, cache_creation: 30808 }, + }, + { + input_tokens: 3802, + output_tokens: 149, + model: 'claude-opus', + input_token_details: { cache_read: 30808, cache_creation: 768 }, + }, + { + input_tokens: 26808, + output_tokens: 225, + model: 'claude-opus', + input_token_details: { cache_read: 31576, cache_creation: 0 }, + }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result?.input_tokens).toBe(31596); // 788 + 30808 + 0 + expect(result?.output_tokens).toBe(537); // 163 + 149 + 225 + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(mockSpendTokens).not.toHaveBeenCalled(); + }); + }); + + describe('error handling', () => { + it('should catch bulk write errors — still returns correct result', async () => { + mockInsertMany.mockRejectedValue(new Error('DB error')); + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(deps, { ...baseParams, collectedUsage }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + }); + + describe('transaction metadata — doc fields match what legacy would pass to spendTokens', () => { + it('should pass all metadata fields to docs', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + const endpointTokenConfig = { 'gpt-4': { prompt: 0.01, completion: 0.03, context: 8192 } }; + + await recordCollectedUsage(deps, { + ...baseParams, + 
messageId: 'msg-123', + endpointTokenConfig, + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.user).toBe('user-123'); + expect(doc.conversationId).toBe('convo-123'); + expect(doc.model).toBe('gpt-4'); + expect(doc.context).toBe('message'); + expect(doc.messageId).toBe('msg-123'); + } + }); + + it('should use default context "message" when not provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].context).toBe('message'); + }); + + it('should allow custom context like "title"', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + context: 'title', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].context).toBe('title'); + }); + }); + + describe('messageId propagation — messageId on every doc', () => { + it('should propagate messageId to all docs', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 10, output_tokens: 5, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-1', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.messageId).toBe('msg-1'); + } + }); + + it('should propagate messageId to structured cache docs', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'claude-3', + cache_creation_input_tokens: 25, + cache_read_input_tokens: 15, + }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-cache-1', + collectedUsage, + }); + + const docs = 
mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.messageId).toBe('msg-cache-1'); + } + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + + it('should pass undefined messageId when not provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 10, output_tokens: 5, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + user: 'user-123', + conversationId: 'convo-123', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + expect(docs[0].messageId).toBeUndefined(); + }); + + it('should propagate messageId across all entries in a multi-entry batch', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + messageId: 'msg-multi', + collectedUsage, + }); + + const docs = mockInsertMany.mock.calls[0][0]; + for (const doc of docs) { + expect(doc.messageId).toBe('msg-multi'); + } + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + }); + + describe('balance behavior parity', () => { + it('should not call updateBalance when balance is disabled — same as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(deps, { + ...baseParams, + balance: { enabled: false }, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + + it('should not insert docs when transactions are disabled — same as legacy', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + 
await recordCollectedUsage(deps, { + ...baseParams, + transactions: { enabled: false }, + collectedUsage, + }); + + expect(mockInsertMany).not.toHaveBeenCalled(); + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/api/src/agents/usage.spec.ts b/packages/api/src/agents/usage.spec.ts index 1937af5011..d0b065b8ff 100644 --- a/packages/api/src/agents/usage.spec.ts +++ b/packages/api/src/agents/usage.spec.ts @@ -1,6 +1,7 @@ -import { recordCollectedUsage } from './usage'; -import type { RecordUsageDeps, RecordUsageParams } from './usage'; import type { UsageMetadata } from '../stream/interfaces/IJobStore'; +import type { RecordUsageDeps, RecordUsageParams } from './usage'; +import type { BulkWriteDeps, PricingFns } from './transactions'; +import { recordCollectedUsage } from './usage'; describe('recordCollectedUsage', () => { let mockSpendTokens: jest.Mock; @@ -522,4 +523,199 @@ describe('recordCollectedUsage', () => { ); }); }); + + describe('bulk write path', () => { + let mockInsertMany: jest.Mock; + let mockUpdateBalance: jest.Mock; + let mockPricing: PricingFns; + let mockBulkWriteOps: BulkWriteDeps; + let bulkDeps: RecordUsageDeps; + + beforeEach(() => { + mockInsertMany = jest.fn().mockResolvedValue(undefined); + mockUpdateBalance = jest.fn().mockResolvedValue({}); + mockPricing = { + getMultiplier: jest.fn().mockReturnValue(1), + getCacheMultiplier: jest.fn().mockReturnValue(null), + }; + mockBulkWriteOps = { + insertMany: mockInsertMany, + updateBalance: mockUpdateBalance, + }; + bulkDeps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: mockPricing, + bulkWriteOps: mockBulkWriteOps, + }; + }); + + it('should use bulk path when pricing and bulkWriteOps are provided', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + 
collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + + it('should batch all entries into a single insertMany call', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + { input_tokens: 300, output_tokens: 70, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + const insertedDocs = mockInsertMany.mock.calls[0][0]; + expect(insertedDocs.length).toBe(6); // 2 per entry (prompt + completion) + }); + + it('should call updateBalance once when balance is enabled', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(bulkDeps, { + ...baseParams, + balance: { enabled: true }, + collectedUsage, + }); + + expect(mockUpdateBalance).toHaveBeenCalledTimes(1); + expect(mockUpdateBalance).toHaveBeenCalledWith( + expect.objectContaining({ + user: 'user-123', + incrementValue: expect.any(Number), + }), + ); + }); + + it('should not call updateBalance when balance is disabled', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(bulkDeps, { + ...baseParams, + balance: { enabled: false }, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + + it('should handle cache tokens via bulk path', async () => { + const collectedUsage: UsageMetadata[] = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', 
+ input_token_details: { cache_creation: 20, cache_read: 10 }, + }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(result).toBeDefined(); + }); + + it('should handle mixed cache and non-cache entries in bulk', async () => { + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'gpt-4', + input_token_details: { cache_creation: 10, cache_read: 5 }, + }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockInsertMany).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + expect(result?.output_tokens).toBe(80); + }); + + it('should fall back to legacy path when pricing is missing', async () => { + const legacyDeps: RecordUsageDeps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + bulkWriteOps: mockBulkWriteOps, + // no pricing + }; + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(legacyDeps, { + ...baseParams, + collectedUsage, + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(1); + expect(mockInsertMany).not.toHaveBeenCalled(); + }); + + it('should fall back to legacy path when bulkWriteOps is missing', async () => { + const legacyDeps: RecordUsageDeps = { + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, + pricing: mockPricing, + // no bulkWriteOps + }; + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + await recordCollectedUsage(legacyDeps, { + ...baseParams, + collectedUsage, + }); + + 
expect(mockSpendTokens).toHaveBeenCalledTimes(1); + expect(mockInsertMany).not.toHaveBeenCalled(); + }); + + it('should handle errors in bulk write gracefully', async () => { + mockInsertMany.mockRejectedValue(new Error('DB error')); + + const collectedUsage: UsageMetadata[] = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + ]; + + const result = await recordCollectedUsage(bulkDeps, { + ...baseParams, + collectedUsage, + }); + + expect(result).toEqual({ input_tokens: 100, output_tokens: 50 }); + }); + }); }); diff --git a/packages/api/src/agents/usage.ts b/packages/api/src/agents/usage.ts index 351452d698..c092702730 100644 --- a/packages/api/src/agents/usage.ts +++ b/packages/api/src/agents/usage.ts @@ -1,34 +1,20 @@ import { logger } from '@librechat/data-schemas'; import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider'; -import type { UsageMetadata } from '../stream/interfaces/IJobStore'; -import type { EndpointTokenConfig } from '../types/tokens'; - -interface TokenUsage { - promptTokens?: number; - completionTokens?: number; -} - -interface StructuredPromptTokens { - input?: number; - write?: number; - read?: number; -} - -interface StructuredTokenUsage { - promptTokens?: StructuredPromptTokens; - completionTokens?: number; -} - -interface TxMetadata { - user: string; - model?: string; - context: string; - messageId?: string; - conversationId: string; - balance?: Partial | null; - transactions?: Partial; - endpointTokenConfig?: EndpointTokenConfig; -} +import type { + StructuredTokenUsage, + BulkWriteDeps, + PreparedEntry, + TxMetadata, + TokenUsage, + PricingFns, +} from './transactions'; +import type { UsageMetadata } from '~/stream/interfaces/IJobStore'; +import type { EndpointTokenConfig } from '~/types/tokens'; +import { + prepareStructuredTokenSpend, + bulkWriteTransactions, + prepareTokenSpend, +} from './transactions'; type SpendTokensFn = (txData: TxMetadata, tokenUsage: TokenUsage) => Promise; type 
SpendStructuredTokensFn = ( @@ -39,6 +25,8 @@ type SpendStructuredTokensFn = ( export interface RecordUsageDeps { spendTokens: SpendTokensFn; spendStructuredTokens: SpendStructuredTokensFn; + pricing?: PricingFns; + bulkWriteOps?: BulkWriteDeps; } export interface RecordUsageParams { @@ -61,6 +49,9 @@ export interface RecordUsageResult { /** * Records token usage for collected LLM calls and spends tokens against balance. * This handles both sequential execution (tool calls) and parallel execution (multiple agents). + * + * When `pricing` and `bulkWriteOps` deps are provided, prepares all transaction documents + * in-memory first, then writes them in a single `insertMany` + one `updateBalance` call. */ export async function recordCollectedUsage( deps: RecordUsageDeps, @@ -78,8 +69,6 @@ export async function recordCollectedUsage( context = 'message', } = params; - const { spendTokens, spendStructuredTokens } = deps; - if (!collectedUsage || !collectedUsage.length) { return; } @@ -96,6 +85,11 @@ export async function recordCollectedUsage( let total_output_tokens = 0; + const { pricing, bulkWriteOps } = deps; + const useBulk = pricing && bulkWriteOps; + + const allDocs: PreparedEntry[] = []; + for (const usage of collectedUsage) { if (!usage) { continue; @@ -121,26 +115,68 @@ export async function recordCollectedUsage( model: usage.model ?? model, }; - if (cache_creation > 0 || cache_read > 0) { - spendStructuredTokens(txMetadata, { - promptTokens: { - input: usage.input_tokens, - write: cache_creation, - read: cache_read, - }, - completionTokens: usage.output_tokens, - }).catch((err) => { - logger.error('[packages/api #recordCollectedUsage] Error spending structured tokens', err); - }); + if (useBulk) { + const entries = + cache_creation > 0 || cache_read > 0 + ? 
prepareStructuredTokenSpend(
+              txMetadata,
+              {
+                promptTokens: {
+                  input: usage.input_tokens,
+                  write: cache_creation,
+                  read: cache_read,
+                },
+                completionTokens: usage.output_tokens,
+              },
+              pricing,
+            )
+          : prepareTokenSpend(
+              txMetadata,
+              {
+                promptTokens: usage.input_tokens,
+                completionTokens: usage.output_tokens,
+              },
+              pricing,
+            );
+      allDocs.push(...entries);
       continue;
     }
 
-    spendTokens(txMetadata, {
-      promptTokens: usage.input_tokens,
-      completionTokens: usage.output_tokens,
-    }).catch((err) => {
-      logger.error('[packages/api #recordCollectedUsage] Error spending tokens', err);
-    });
+    if (cache_creation > 0 || cache_read > 0) {
+      deps
+        .spendStructuredTokens(txMetadata, {
+          promptTokens: {
+            input: usage.input_tokens,
+            write: cache_creation,
+            read: cache_read,
+          },
+          completionTokens: usage.output_tokens,
+        })
+        .catch((err) => {
+          logger.error(
+            '[packages/api #recordCollectedUsage] Error spending structured tokens',
+            err,
+          );
+        });
+      continue;
+    }
+
+    deps
+      .spendTokens(txMetadata, {
+        promptTokens: usage.input_tokens,
+        completionTokens: usage.output_tokens,
+      })
+      .catch((err) => {
+        logger.error('[packages/api #recordCollectedUsage] Error spending tokens', err);
+      });
+  }
+
+  if (useBulk && allDocs.length > 0) {
+    try {
+      await bulkWriteTransactions({ user, docs: allDocs }, bulkWriteOps);
+    } catch (err) {
+      logger.error('[packages/api #recordCollectedUsage] Error in bulk write', err);
+    }
+  }
 
   return {
diff --git a/packages/api/src/endpoints/openai/config.spec.ts b/packages/api/src/endpoints/openai/config.spec.ts
index 2bbe123e63..58b471aa66 100644
--- a/packages/api/src/endpoints/openai/config.spec.ts
+++ b/packages/api/src/endpoints/openai/config.spec.ts
@@ -872,9 +872,8 @@ describe('getOpenAIConfig', () => {
         modelOptions,
       });
 
-      // OpenRouter reasoning object should only include effort, not summary
-      expect(result.llmConfig.reasoning).toEqual({
-        effort: ReasoningEffort.high,
+      expect(result.llmConfig.modelKwargs).toMatchObject({
+        reasoning: { effort: ReasoningEffort.high },
       });
       expect(result.llmConfig.include_reasoning).toBeUndefined();
       expect(result.provider).toBe('openrouter');
@@ -1206,13 +1205,13 @@ describe('getOpenAIConfig', () => {
         model: 'gpt-4-turbo',
         temperature: 0.8,
         streaming: false,
-        reasoning: { effort: ReasoningEffort.high }, // OpenRouter reasoning object
       });
       expect(result.llmConfig.include_reasoning).toBeUndefined();
 
       // Should NOT have useResponsesApi for OpenRouter
       expect(result.llmConfig.useResponsesApi).toBeUndefined();
       expect(result.llmConfig.maxTokens).toBe(2000);
       expect(result.llmConfig.modelKwargs).toEqual({
+        reasoning: { effort: ReasoningEffort.high },
         verbosity: Verbosity.medium,
         customParam: 'custom-value',
         plugins: [{ id: 'web' }], // OpenRouter web search format
@@ -1482,13 +1481,11 @@ describe('getOpenAIConfig', () => {
         user: 'openrouter-user',
         temperature: 0.7,
         maxTokens: 4000,
-        reasoning: {
-          effort: ReasoningEffort.high,
-        },
         apiKey: apiKey,
       });
       expect(result.llmConfig.include_reasoning).toBeUndefined();
       expect(result.llmConfig.modelKwargs).toMatchObject({
+        reasoning: { effort: ReasoningEffort.high },
         top_k: 50,
         repetition_penalty: 1.1,
       });
diff --git a/packages/api/src/endpoints/openai/llm.spec.ts b/packages/api/src/endpoints/openai/llm.spec.ts
index 3c10179737..a78cc4b87d 100644
--- a/packages/api/src/endpoints/openai/llm.spec.ts
+++ b/packages/api/src/endpoints/openai/llm.spec.ts
@@ -393,7 +393,9 @@ describe('getOpenAILLMConfig', () => {
       },
     });
 
-    expect(result.llmConfig).toHaveProperty('reasoning', { effort: ReasoningEffort.high });
+    expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', {
+      effort: ReasoningEffort.high,
+    });
     expect(result.llmConfig).not.toHaveProperty('include_reasoning');
     expect(result.llmConfig.modelKwargs).toHaveProperty('plugins', [{ id: 'web' }]);
   });
@@ -617,7 +619,9 @@ describe('getOpenAILLMConfig', () => {
       },
     });
 
-    expect(result.llmConfig).toHaveProperty('reasoning', { effort: ReasoningEffort.high });
+    expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', {
+      effort: ReasoningEffort.high,
+    });
     expect(result.llmConfig).not.toHaveProperty('include_reasoning');
     expect(result.llmConfig).not.toHaveProperty('reasoning_effort');
   });
@@ -634,7 +638,9 @@ describe('getOpenAILLMConfig', () => {
       },
     });
 
-    expect(result.llmConfig).toHaveProperty('reasoning', { effort: ReasoningEffort.high });
+    expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', {
+      effort: ReasoningEffort.high,
+    });
   });
 
   it.each([ReasoningEffort.xhigh, ReasoningEffort.minimal, ReasoningEffort.none])(
@@ -650,7 +656,7 @@ describe('getOpenAILLMConfig', () => {
       },
     });
 
-      expect(result.llmConfig).toHaveProperty('reasoning', { effort });
+      expect(result.llmConfig.modelKwargs).toHaveProperty('reasoning', { effort });
       expect(result.llmConfig).not.toHaveProperty('include_reasoning');
     },
   );
diff --git a/packages/api/src/endpoints/openai/llm.ts b/packages/api/src/endpoints/openai/llm.ts
index c659645958..a89f6fce44 100644
--- a/packages/api/src/endpoints/openai/llm.ts
+++ b/packages/api/src/endpoints/openai/llm.ts
@@ -1,7 +1,6 @@
 import { EModelEndpoint, removeNullishValues } from 'librechat-data-provider';
 import type { BindToolsInput } from '@langchain/core/language_models/chat_models';
 import type { SettingDefinition } from 'librechat-data-provider';
-import type { OpenRouterReasoning } from '@librechat/agents';
 import type { AzureOpenAIInput } from '@langchain/openai';
 import type { OpenAI } from 'openai';
 import type * as t from '~/types';
@@ -231,7 +230,8 @@ export function getOpenAILLMConfig({
        * `include_reasoning` is legacy compat that maps to `{ enabled: true }` only when
        * no `reasoning` object is present, so we intentionally omit it here.
        */
-      llmConfig.reasoning = { effort: reasoning_effort } as OpenRouterReasoning;
+      modelKwargs.reasoning = { effort: reasoning_effort };
+      hasModelKwargs = true;
     } else {
       /** No explicit effort; fall back to legacy `include_reasoning` for reasoning token inclusion */
       llmConfig.include_reasoning = true;
diff --git a/packages/data-schemas/src/methods/index.ts b/packages/data-schemas/src/methods/index.ts
index 2f20b67fec..07e7cefc24 100644
--- a/packages/data-schemas/src/methods/index.ts
+++ b/packages/data-schemas/src/methods/index.ts
@@ -21,6 +21,7 @@ import { createAccessRoleMethods, type AccessRoleMethods } from './accessRole';
 import { createUserGroupMethods, type UserGroupMethods } from './userGroup';
 import { createAclEntryMethods, type AclEntryMethods } from './aclEntry';
 import { createShareMethods, type ShareMethods } from './share';
+import { createTransactionMethods, type TransactionMethods } from './transaction';
 
 export type AllMethods = UserMethods &
   SessionMethods &
@@ -36,7 +37,8 @@ export type AllMethods = UserMethods &
   AclEntryMethods &
   ShareMethods &
   AccessRoleMethods &
-  PluginAuthMethods;
+  PluginAuthMethods &
+  TransactionMethods;
 
 /**
  * Creates all database methods for all collections
@@ -59,6 +61,7 @@ export function createMethods(mongoose: typeof import('mongoose')): AllMethods {
     ...createAclEntryMethods(mongoose),
     ...createShareMethods(mongoose),
     ...createPluginAuthMethods(mongoose),
+    ...createTransactionMethods(mongoose),
   };
 }
 
@@ -78,4 +81,5 @@ export type {
   ShareMethods,
   AccessRoleMethods,
   PluginAuthMethods,
+  TransactionMethods,
 };
diff --git a/packages/data-schemas/src/methods/transaction.ts b/packages/data-schemas/src/methods/transaction.ts
new file mode 100644
index 0000000000..d521b9e85e
--- /dev/null
+++ b/packages/data-schemas/src/methods/transaction.ts
@@ -0,0 +1,100 @@
+import type { IBalance, TransactionData } from '~/types';
+import logger from '~/config/winston';
+
+interface UpdateBalanceParams {
+  user: string;
+  incrementValue: number;
+  setValues?: Partial<Pick<IBalance, 'tokenCredits' | 'lastRefill'>>;
+}
+
+export function createTransactionMethods(mongoose: typeof import('mongoose')) {
+  async function updateBalance({ user, incrementValue, setValues }: UpdateBalanceParams) {
+    const maxRetries = 10;
+    let delay = 50;
+    let lastError: Error | null = null;
+    const Balance = mongoose.models.Balance;
+
+    for (let attempt = 1; attempt <= maxRetries; attempt++) {
+      try {
+        const currentBalanceDoc = await Balance.findOne({ user }).lean();
+        const currentCredits = currentBalanceDoc?.tokenCredits ?? 0;
+        const newCredits = Math.max(0, currentCredits + incrementValue);
+
+        const updatePayload = {
+          $set: {
+            tokenCredits: newCredits,
+            ...(setValues ?? {}),
+          },
+        };
+
+        if (currentBalanceDoc) {
+          const updatedBalance = await Balance.findOneAndUpdate(
+            { user, tokenCredits: currentCredits },
+            updatePayload,
+            { new: true },
+          ).lean();
+
+          if (updatedBalance) {
+            return updatedBalance;
+          }
+          lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`);
+        } else {
+          try {
+            const updatedBalance = await Balance.findOneAndUpdate({ user }, updatePayload, {
+              upsert: true,
+              new: true,
+            }).lean();
+
+            if (updatedBalance) {
+              return updatedBalance;
+            }
+            lastError = new Error(
+              `Upsert race condition suspected for user ${user} on attempt ${attempt}.`,
+            );
+          } catch (error: unknown) {
+            if (
+              error instanceof Error &&
+              'code' in error &&
+              (error as { code: number }).code === 11000
+            ) {
+              lastError = error;
+            } else {
+              throw error;
+            }
+          }
+        }
+      } catch (error) {
+        logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error);
+        lastError = error instanceof Error ? error : new Error(String(error));
+      }
+
+      if (attempt < maxRetries) {
+        const jitter = Math.random() * delay * 0.5;
+        await new Promise((resolve) => setTimeout(resolve, delay + jitter));
+        delay = Math.min(delay * 2, 2000);
+      }
+    }
+
+    logger.error(
+      `[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`,
+    );
+    throw (
+      lastError ??
+      new Error(
+        `Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`,
+      )
+    );
+  }
+
+  /** Bypasses document middleware; all computed fields must be pre-calculated before calling. */
+  async function bulkInsertTransactions(docs: TransactionData[]): Promise<void> {
+    const Transaction = mongoose.models.Transaction;
+    if (docs.length) {
+      await Transaction.insertMany(docs);
+    }
+  }
+
+  return { updateBalance, bulkInsertTransactions };
+}
+
+export type TransactionMethods = ReturnType<typeof createTransactionMethods>;
diff --git a/packages/data-schemas/src/types/index.ts b/packages/data-schemas/src/types/index.ts
index 38f9f22b50..d467d99d21 100644
--- a/packages/data-schemas/src/types/index.ts
+++ b/packages/data-schemas/src/types/index.ts
@@ -8,6 +8,7 @@ export * from './convo';
 export * from './session';
 export * from './balance';
 export * from './banner';
+export * from './transaction';
 export * from './message';
 export * from './agent';
 export * from './agentApiKey';
diff --git a/packages/data-schemas/src/types/transaction.ts b/packages/data-schemas/src/types/transaction.ts
new file mode 100644
index 0000000000..978d7fd62b
--- /dev/null
+++ b/packages/data-schemas/src/types/transaction.ts
@@ -0,0 +1,17 @@
+export interface TransactionData {
+  user: string;
+  conversationId: string;
+  tokenType: string;
+  model?: string;
+  context?: string;
+  valueKey?: string;
+  rate?: number;
+  rawAmount?: number;
+  tokenValue?: number;
+  inputTokens?: number;
+  writeTokens?: number;
+  readTokens?: number;
+  messageId?: string;
+  inputTokenCount?: number;
+  rateDetail?: Record<string, number>;
+}
diff --git a/packages/data-schemas/src/utils/transactions.ts b/packages/data-schemas/src/utils/transactions.ts
index 09bbb040c1..26f1f77e7e 100644
--- a/packages/data-schemas/src/utils/transactions.ts
+++ b/packages/data-schemas/src/utils/transactions.ts
@@ -1,5 +1,7 @@
 import logger from '~/config/winston';
 
+export const CANCEL_RATE = 1.15;
+
 /**
  * Checks if the connected MongoDB deployment supports transactions
  * This requires a MongoDB replica set configuration
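The new `updateBalance` method above relies on an optimistic compare-and-swap: re-read the balance, compute the floored new value, and only write if `tokenCredits` is unchanged, retrying with capped exponential backoff plus jitter on conflict. A minimal sketch of that pattern in isolation — `readBalance` and `casBalance` are hypothetical in-memory stand-ins for the Mongoose `findOne`/`findOneAndUpdate` calls, not part of this patch:

```typescript
type BalanceDoc = { user: string; tokenCredits: number };

const store = new Map<string, BalanceDoc>();

function readBalance(user: string): BalanceDoc | undefined {
  return store.get(user);
}

/** Compare-and-swap: writes only if tokenCredits still equals `expected`. */
function casBalance(user: string, expected: number, next: number): BalanceDoc | null {
  const doc = store.get(user);
  if (doc && doc.tokenCredits !== expected) {
    return null; // a concurrent writer won; caller must retry
  }
  const updated = { user, tokenCredits: next };
  store.set(user, updated);
  return updated;
}

async function updateBalanceSketch(user: string, incrementValue: number): Promise<BalanceDoc> {
  const maxRetries = 10;
  let delay = 50;
  for (let attempt = 1; attempt <= maxRetries; attempt++) {
    const current = readBalance(user)?.tokenCredits ?? 0;
    // Balances are floored at zero, matching Math.max(0, ...) in the patch.
    const next = Math.max(0, current + incrementValue);
    const updated = casBalance(user, current, next);
    if (updated) {
      return updated;
    }
    // Exponential backoff with jitter, capped at 2s, as in the patch.
    const jitter = Math.random() * delay * 0.5;
    await new Promise((resolve) => setTimeout(resolve, delay + jitter));
    delay = Math.min(delay * 2, 2000);
  }
  throw new Error(`Failed to update balance for user ${user} after ${maxRetries} attempts.`);
}
```

Because the guard is on the previously observed `tokenCredits`, a lost race surfaces as a null write rather than silent double-spending, which is why the real method can safely retry up to ten times.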
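The bulk path this PR introduces prepares transaction docs up front, inserts them in one `insertMany` via `bulkInsertTransactions`, and applies a single balance update whose amount sums `tokenValue` only for balance-enabled docs (the "single-pass map" from the commit notes). A sketch of that summation under assumed shapes — the `balanceEnabled` flag and `PreparedDoc` type are illustrative, not the patch's exact types:

```typescript
// Illustrative stand-in for the prepared-entry shape used by bulkWriteTransactions.
type PreparedDoc = {
  user: string;
  tokenType: string;
  tokenValue?: number; // negative for spends
  balanceEnabled: boolean; // hypothetical flag: does this doc affect the balance?
};

/** Single pass over prepared docs: only balance-enabled entries contribute. */
function sumBalanceImpact(docs: PreparedDoc[]): number {
  return docs.reduce(
    (total, doc) => (doc.balanceEnabled ? total + (doc.tokenValue ?? 0) : total),
    0,
  );
}
```

The result feeds one `updateBalance` call per user instead of one per transaction, which is the round-trip saving the refactor is after.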