diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js index 28af80314c..d702907c6a 100644 --- a/api/server/controllers/agents/__tests__/openai.spec.js +++ b/api/server/controllers/agents/__tests__/openai.spec.js @@ -79,11 +79,6 @@ jest.mock('~/server/services/ToolService', () => ({ const mockGetMultiplier = jest.fn().mockReturnValue(1); const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); -jest.mock('~/models/tx', () => ({ - getMultiplier: mockGetMultiplier, - getCacheMultiplier: mockGetCacheMultiplier, -})); - jest.mock('~/server/controllers/agents/callbacks', () => ({ createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), @@ -110,6 +105,8 @@ jest.mock('~/models', () => ({ bulkInsertTransactions: mockBulkInsertTransactions, spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens, + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, getConvoFiles: jest.fn().mockResolvedValue([]), })); diff --git a/api/server/controllers/agents/__tests__/responses.unit.spec.js b/api/server/controllers/agents/__tests__/responses.unit.spec.js index 2ce22a2060..2b38a00771 100644 --- a/api/server/controllers/agents/__tests__/responses.unit.spec.js +++ b/api/server/controllers/agents/__tests__/responses.unit.spec.js @@ -103,11 +103,6 @@ jest.mock('~/server/services/ToolService', () => ({ const mockGetMultiplier = jest.fn().mockReturnValue(1); const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); -jest.mock('~/models/tx', () => ({ - getMultiplier: mockGetMultiplier, - getCacheMultiplier: mockGetCacheMultiplier, -})); - jest.mock('~/server/controllers/agents/callbacks', () => ({ createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), @@ -136,6 +131,8 @@ jest.mock('~/models', () => ({ bulkInsertTransactions: mockBulkInsertTransactions, spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens, + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, getConvoFiles: jest.fn().mockResolvedValue([]), saveConvo: jest.fn().mockResolvedValue({}), getConvo: jest.fn().mockResolvedValue(null), diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 598dddbd09..43c429fb34 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -46,8 +46,6 @@ const { removeNullishValues, } = require('librechat-data-provider'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { updateBalance, bulkInsertTransactions } = require('~/models'); -const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const { createContextHandlers } = require('~/app/clients/prompts'); const { getMCPServerTools } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); @@ -631,8 +629,8 @@ class AgentClient extends BaseClient { { spendTokens: db.spendTokens, spendStructuredTokens: db.spendStructuredTokens, - pricing: { getMultiplier, getCacheMultiplier }, - bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance }, + pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier }, + bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, }, { user: this.user ?? this.options.req.user?.id, diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js index 11e100cc9b..4a38f6f097 100644 --- a/api/server/controllers/agents/openai.js +++ b/api/server/controllers/agents/openai.js @@ -24,7 +24,6 @@ const { const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { createToolEndCallback } = require('~/server/controllers/agents/callbacks'); const { findAccessibleResources } = require('~/server/services/PermissionService'); -const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const db = require('~/models'); /** @@ -494,7 +493,7 @@ const OpenAIChatCompletionController = async (req, res) => { { spendTokens: db.spendTokens, spendStructuredTokens: db.spendStructuredTokens, - pricing: { getMultiplier, getCacheMultiplier }, + pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier }, bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, }, { diff --git a/api/server/controllers/agents/recordCollectedUsage.spec.js b/api/server/controllers/agents/recordCollectedUsage.spec.js index 2d4730c603..009c5b262c 100644 --- a/api/server/controllers/agents/recordCollectedUsage.spec.js +++ b/api/server/controllers/agents/recordCollectedUsage.spec.js @@ -21,14 +21,8 @@ const mockRecordCollectedUsage = jest jest.mock('~/models', () => ({ spendTokens: (...args) => mockSpendTokens(...args), spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args), -})); - -jest.mock('~/models/tx', () => ({ getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier, -})); - -jest.mock('~/models', () => ({ updateBalance: mockUpdateBalance, bulkInsertTransactions: mockBulkInsertTransactions, })); diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js index 90e16b3128..6f629d2a81 100644 --- a/api/server/controllers/agents/responses.js +++ b/api/server/controllers/agents/responses.js @@ -36,7 +36,6 @@ const { } = require('~/server/controllers/agents/callbacks'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { findAccessibleResources } = require('~/server/services/PermissionService'); -const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const db = require('~/models'); /** @type {import('@librechat/api').AppConfig | null} */ @@ -514,7 +513,7 @@ const createResponse = async (req, res) => { { spendTokens: db.spendTokens, spendStructuredTokens: db.spendStructuredTokens, - pricing: { getMultiplier, getCacheMultiplier }, + pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier }, bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, }, { @@ -668,7 +667,7 @@ const createResponse = async (req, res) => { { spendTokens: db.spendTokens, spendStructuredTokens: db.spendStructuredTokens, - pricing: { getMultiplier, getCacheMultiplier }, + pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier }, bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, }, { diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 624ace7f9f..e0c5ae0ff0 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -8,13 +8,11 @@ const { recordCollectedUsage, sanitizeMessageForTransmit, } = require('@librechat/api'); -const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider'); -const { saveMessage, getConvo, updateBalance, bulkInsertTransactions } = require('~/models'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); -const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); const clearPendingReq = require('~/cache/clearPendingReq'); const { sendError } = require('~/server/middleware/error'); const { abortRun } = require('./abortRun'); +const db = require('~/models'); /** * Spend tokens for all models from collected usage. @@ -44,10 +42,10 @@ async function spendCollectedUsage({ await recordCollectedUsage( { - spendTokens, - spendStructuredTokens, - pricing: { getMultiplier, getCacheMultiplier }, - bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance }, + spendTokens: db.spendTokens, + spendStructuredTokens: db.spendStructuredTokens, + pricing: { getMultiplier: db.getMultiplier, getCacheMultiplier: db.getCacheMultiplier }, + bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance }, }, { user: userId, @@ -123,13 +121,13 @@ async function abortMessage(req, res) { }); } else { // Fallback: no collected usage, use text-based token counting for primary model only - await spendTokens( + await db.spendTokens( { ...responseMessage, context: 'incomplete', user: userId }, { promptTokens, completionTokens }, ); } - await saveMessage( + await db.saveMessage( { userId: req?.user?.id, isTemporary: req?.body?.isTemporary, @@ -140,7 +138,7 @@ async function abortMessage(req, res) { ); // Get conversation for title - const conversation = await getConvo(userId, conversationId); + const conversation = await db.getConvo(userId, conversationId); const finalEvent = { title: conversation && !conversation.title ? null : conversation?.title || 'New Chat', diff --git a/api/server/middleware/abortMiddleware.spec.js b/api/server/middleware/abortMiddleware.spec.js index c9c0d5cc60..a4ce85674b 100644 --- a/api/server/middleware/abortMiddleware.spec.js +++ b/api/server/middleware/abortMiddleware.spec.js @@ -20,8 +20,6 @@ const mockRecordCollectedUsage = jest const mockGetMultiplier = jest.fn().mockReturnValue(1); const mockGetCacheMultiplier = jest.fn().mockReturnValue(null); - - jest.mock('@librechat/data-schemas', () => ({ logger: { debug: jest.fn(), @@ -65,6 +63,10 @@ jest.mock('~/models', () => ({ getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }), updateBalance: mockUpdateBalance, bulkInsertTransactions: mockBulkInsertTransactions, + spendTokens: (...args) => mockSpendTokens(...args), + spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args), + getMultiplier: mockGetMultiplier, + getCacheMultiplier: mockGetCacheMultiplier, })); jest.mock('./abortRun', () => ({ diff --git a/api/server/services/Files/process.spec.js b/api/server/services/Files/process.spec.js index 7737255a52..39300161a8 100644 --- a/api/server/services/Files/process.spec.js +++ b/api/server/services/Files/process.spec.js @@ -30,11 +30,6 @@ jest.mock('~/server/controllers/assistants/v2', () => ({ deleteResourceFileId: jest.fn(), })); -jest.mock('~/models/Agent', () => ({ - addAgentResourceFile: jest.fn().mockResolvedValue({}), - removeAgentResourceFiles: jest.fn(), -})); - jest.mock('~/server/controllers/assistants/helpers', () => ({ getOpenAIClient: jest.fn(), })); @@ -47,6 +42,8 @@ jest.mock('~/models', () => ({ createFile: jest.fn().mockResolvedValue({ file_id: 'created-file-id' }), updateFileUsage: jest.fn(), deleteFiles: jest.fn(), + addAgentResourceFile: jest.fn().mockResolvedValue({}), + removeAgentResourceFiles: jest.fn(), })); jest.mock('~/server/utils/getFileStrategy', () => ({ diff --git a/packages/api/src/agents/transactions.bulk-parity.spec.ts b/packages/api/src/agents/transactions.bulk-parity.spec.ts index bf89682d6f..327856d18b 100644 --- a/packages/api/src/agents/transactions.bulk-parity.spec.ts +++ b/packages/api/src/agents/transactions.bulk-parity.spec.ts @@ -14,10 +14,12 @@ import mongoose from 'mongoose'; import { MongoMemoryServer } from 'mongodb-memory-server'; import { + tokenValues, CANCEL_RATE, createMethods, balanceSchema, transactionSchema, + premiumTokenValues, } from '@librechat/data-schemas'; import type { PricingFns, TxMetadata } from './transactions'; import { @@ -26,6 +28,26 @@ import { prepareTokenSpend, } from './transactions'; +/** Inlined from packages/data-schemas/src/methods/test-helpers.ts — keep in sync */ +function findMatchingPattern( + modelName: string, + tokensMap: Record>, +): string | undefined { + const keys = Object.keys(tokensMap); + const lowerModelName = modelName.toLowerCase(); + for (let i = keys.length - 1; i >= 0; i--) { + if (lowerModelName.includes(keys[i])) { + return keys[i]; + } + } + return undefined; +} + +/** Inlined from packages/data-schemas/src/methods/test-helpers.ts — keep in sync */ +function matchModelName(modelName: string, _endpoint?: string): string | undefined { + return typeof modelName === 'string' ? modelName : undefined; +} + jest.mock('@librechat/data-schemas', () => { const actual = jest.requireActual('@librechat/data-schemas'); return { @@ -34,29 +56,23 @@ jest.mock('@librechat/data-schemas', () => { }; }); -// Real pricing functions from api/models/tx.js — same ones the legacy path uses -/* eslint-disable @typescript-eslint/no-require-imports */ -const { - getMultiplier, - getCacheMultiplier, - tokenValues, - premiumTokenValues, -} = require('../../../../api/models/tx.js'); -/* eslint-enable @typescript-eslint/no-require-imports */ - -const pricing: PricingFns = { getMultiplier, getCacheMultiplier }; - let mongoServer: MongoMemoryServer; let Transaction: mongoose.Model; let Balance: mongoose.Model; let dbMethods: ReturnType; +let pricing: PricingFns; +let getMultiplier: ReturnType['getMultiplier']; +let getCacheMultiplier: ReturnType['getCacheMultiplier']; beforeAll(async () => { mongoServer = await MongoMemoryServer.create(); await mongoose.connect(mongoServer.getUri()); Transaction = mongoose.models.Transaction || mongoose.model('Transaction', transactionSchema); Balance = mongoose.models.Balance || mongoose.model('Balance', balanceSchema); - dbMethods = createMethods(mongoose); + dbMethods = createMethods(mongoose, { matchModelName, findMatchingPattern }); + getMultiplier = dbMethods.getMultiplier; + getCacheMultiplier = dbMethods.getCacheMultiplier; + pricing = { getMultiplier, getCacheMultiplier }; }); afterAll(async () => { @@ -536,8 +552,13 @@ describe('Multi-entry batch parity', () => { const premiumCompletionRate = (premiumTokenValues as Record>)[ model ].completion; - const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); - const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; const expectedPromptCost = tokenUsage.promptTokens.input * premiumPromptRate + diff --git a/packages/api/src/agents/transactions.ts b/packages/api/src/agents/transactions.ts index b746392b44..a9eeda1973 100644 --- a/packages/api/src/agents/transactions.ts +++ b/packages/api/src/agents/transactions.ts @@ -3,9 +3,11 @@ import type { TCustomConfig, TTransactionsConfig } from 'librechat-data-provider import type { TransactionData } from '@librechat/data-schemas'; import type { EndpointTokenConfig } from '~/types/tokens'; +type TokenType = 'prompt' | 'completion'; + interface GetMultiplierParams { valueKey?: string; - tokenType?: string; + tokenType?: TokenType; model?: string; endpointTokenConfig?: EndpointTokenConfig; inputTokenCount?: number; @@ -34,14 +36,14 @@ interface BaseTxData { } interface StandardTxData extends BaseTxData { - tokenType: string; + tokenType: TokenType; rawAmount: number; inputTokenCount?: number; valueKey?: string; } interface StructuredTxData extends BaseTxData { - tokenType: string; + tokenType: TokenType; inputTokens?: number; writeTokens?: number; readTokens?: number; diff --git a/packages/api/src/types/tokens.ts b/packages/api/src/types/tokens.ts index f6e03d2e8d..b555031049 100644 --- a/packages/api/src/types/tokens.ts +++ b/packages/api/src/types/tokens.ts @@ -1,16 +1,8 @@ -/** Configuration object mapping model keys to their respective prompt, completion rates, and context limit - * - * Note: the [key: string]: unknown is not in the original JSDoc typedef in /api/typedefs.js, but I've included it since - * getModelMaxOutputTokens calls getModelTokenValue with a key of 'output', which was not in the original JSDoc typedef, - * but would be referenced in a TokenConfig in the if(matchedPattern) portion of getModelTokenValue. - * So in order to preserve functionality for that case and any others which might reference an additional key I'm unaware of, - * I've included it here until the interface can be typed more tightly. - */ export interface TokenConfig { + [key: string]: number; prompt: number; completion: number; context: number; - [key: string]: unknown; } /** An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit */ diff --git a/packages/data-schemas/src/methods/index.ts b/packages/data-schemas/src/methods/index.ts index 4663f94622..ee4b1cdf9d 100644 --- a/packages/data-schemas/src/methods/index.ts +++ b/packages/data-schemas/src/methods/index.ts @@ -85,7 +85,10 @@ export interface CreateMethodsDeps { /** Matches a model name to a canonical key. From @librechat/api. */ matchModelName?: (model: string, endpoint?: string) => string | undefined; /** Finds the first key in values whose key is a substring of model. From @librechat/api. */ - findMatchingPattern?: (model: string, values: Record) => string | undefined; + findMatchingPattern?: ( + model: string, + values: Record>, + ) => string | undefined; /** Removes all ACL permissions for a resource. From PermissionService. */ removeAllPermissions?: (params: { resourceType: string; resourceId: unknown }) => Promise; /** Returns a cache store for the given key. From getLogStores. */ diff --git a/packages/data-schemas/src/methods/test-helpers.ts b/packages/data-schemas/src/methods/test-helpers.ts index 26b5038dd6..bd64e0268a 100644 --- a/packages/data-schemas/src/methods/test-helpers.ts +++ b/packages/data-schemas/src/methods/test-helpers.ts @@ -11,7 +11,7 @@ */ export function findMatchingPattern( modelName: string, - tokensMap: Record, + tokensMap: Record>, ): string | undefined { const keys = Object.keys(tokensMap); const lowerModelName = modelName.toLowerCase(); diff --git a/packages/data-schemas/src/methods/transaction.spec.ts b/packages/data-schemas/src/methods/transaction.spec.ts index feaf9b758f..ee7df36c57 100644 --- a/packages/data-schemas/src/methods/transaction.spec.ts +++ b/packages/data-schemas/src/methods/transaction.spec.ts @@ -247,8 +247,8 @@ describe('Structured Token Spending Tests', () => { const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); - const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); - const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; // Act const result = await spendStructuredTokens(txData, tokenUsage); @@ -256,8 +256,8 @@ describe('Structured Token Spending Tests', () => { // Calculate expected costs. const expectedPromptCost = tokenUsage.promptTokens.input * promptMultiplier + - tokenUsage.promptTokens.write * (writeMultiplier ?? 0) + - tokenUsage.promptTokens.read * (readMultiplier ?? 0); + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier; const expectedTotalCost = expectedPromptCost + expectedCompletionCost; const expectedBalance = initialBalance - expectedTotalCost; @@ -813,13 +813,18 @@ describe('Premium Token Pricing Integration Tests', () => { const premiumPromptRate = premiumTokenValues[model].prompt; const premiumCompletionRate = premiumTokenValues[model].completion; - const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); - const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; const expectedPromptCost = tokenUsage.promptTokens.input * premiumPromptRate + - tokenUsage.promptTokens.write * (writeMultiplier ?? 0) + - tokenUsage.promptTokens.read * (readMultiplier ?? 0); + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; const expectedTotalCost = expectedPromptCost + expectedCompletionCost; @@ -859,13 +864,18 @@ describe('Premium Token Pricing Integration Tests', () => { const standardPromptRate = tokenValues[model].prompt; const standardCompletionRate = tokenValues[model].completion; - const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); - const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; const expectedPromptCost = tokenUsage.promptTokens.input * standardPromptRate + - tokenUsage.promptTokens.write * (writeMultiplier ?? 0) + - tokenUsage.promptTokens.read * (readMultiplier ?? 0); + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; const expectedCompletionCost = tokenUsage.completionTokens * standardCompletionRate; const expectedTotalCost = expectedPromptCost + expectedCompletionCost; @@ -900,7 +910,7 @@ describe('Premium Token Pricing Integration Tests', () => { promptTokens * standardPromptRate + completionTokens * standardCompletionRate; const updatedBalance = await Balance.findOne({ user: userId }); - expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); test('spendTokens should apply premium pricing for gemini-3.1-pro-preview above threshold', async () => { @@ -929,7 +939,7 @@ describe('Premium Token Pricing Integration Tests', () => { promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; const updatedBalance = await Balance.findOne({ user: userId }); - expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); test('spendTokens should apply standard pricing for gemini-3.1-pro-preview at exactly the threshold', async () => { @@ -958,7 +968,7 @@ describe('Premium Token Pricing Integration Tests', () => { promptTokens * standardPromptRate + completionTokens * standardCompletionRate; const updatedBalance = await Balance.findOne({ user: userId }); - expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); test('spendStructuredTokens should apply premium pricing for gemini-3.1 when total input exceeds threshold', async () => { @@ -992,8 +1002,13 @@ describe('Premium Token Pricing Integration Tests', () => { const premiumPromptRate = premiumTokenValues['gemini-3.1'].prompt; const premiumCompletionRate = premiumTokenValues['gemini-3.1'].completion; - const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); - const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + const promptMultiplier = getMultiplier({ + model, + tokenType: 'prompt', + inputTokenCount: totalInput, + }); + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; const expectedPromptCost = tokenUsage.promptTokens.input * premiumPromptRate + @@ -1004,7 +1019,7 @@ describe('Premium Token Pricing Integration Tests', () => { const updatedBalance = await Balance.findOne({ user: userId }); expect(totalInput).toBeGreaterThan(premiumTokenValues['gemini-3.1'].threshold); - expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); }); test('non-premium models should not be affected by inputTokenCount regardless of prompt size', async () => { @@ -1036,339 +1051,3 @@ describe('Premium Token Pricing Integration Tests', () => { expect(updatedBalance?.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); }); - -describe('Bulk path parity', () => { - /** - * Each test here mirrors an existing legacy test above, replacing spendTokens/ - * spendStructuredTokens with recordCollectedUsage + bulk deps. - * The balance deduction and transaction document fields must be numerically identical. - */ - let bulkDeps; - let methods; - - beforeEach(() => { - methods = createMethods(mongoose); - bulkDeps = { - spendTokens: () => Promise.resolve(), - spendStructuredTokens: () => Promise.resolve(), - pricing: { getMultiplier, getCacheMultiplier }, - bulkWriteOps: { - insertMany: methods.bulkInsertTransactions, - updateBalance: methods.updateBalance, - }, - }; - }); - - test('balance should decrease when spending tokens via bulk path', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - const model = 'gpt-3.5-turbo'; - const promptTokens = 100; - const completionTokens = 50; - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-conversation-id', - model, - context: 'test', - balance: { enabled: true }, - transactions: { enabled: true }, - collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], - }); - - const updatedBalance = await Balance.findOne({ user: userId }); - const promptMultiplier = getMultiplier({ - model, - tokenType: 'prompt', - inputTokenCount: promptTokens, - }); - const completionMultiplier = getMultiplier({ - model, - tokenType: 'completion', - inputTokenCount: promptTokens, - }); - const expectedTotalCost = - promptTokens * promptMultiplier + completionTokens * completionMultiplier; - const expectedBalance = initialBalance - expectedTotalCost; - - expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0); - - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(2); - }); - - test('bulk path should not update balance when balance.enabled is false', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - const model = 'gpt-3.5-turbo'; - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-conversation-id', - model, - context: 'test', - balance: { enabled: false }, - transactions: { enabled: true }, - collectedUsage: [{ input_tokens: 100, output_tokens: 50, model }], - }); - - const updatedBalance = await Balance.findOne({ user: userId }); - expect(updatedBalance.tokenCredits).toBe(initialBalance); - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(2); // transactions still recorded - }); - - test('bulk path should not insert when transactions.enabled is false', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-conversation-id', - model: 'gpt-3.5-turbo', - context: 'test', - balance: { enabled: true }, - transactions: { enabled: false }, - collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }], - }); - - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(0); - const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBe(initialBalance); - }); - - test('bulk path handles incomplete context for completion tokens — same CANCEL_RATE as legacy', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 17613154.55; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - const model = 'claude-3-5-sonnet'; - const promptTokens = 10; - const completionTokens = 50; - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-convo', - model, - context: 'incomplete', - balance: { enabled: true }, - transactions: { enabled: true }, - collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], - }); - - const txns = await Transaction.find({ user: userId }).lean(); - const completionTx = txns.find((t) => t.tokenType === 'completion'); - const completionMultiplier = getMultiplier({ - model, - tokenType: 'completion', - inputTokenCount: promptTokens, - }); - expect(completionTx.tokenValue).toBeCloseTo(-completionTokens * completionMultiplier * 1.15, 0); - }); - - test('bulk path structured tokens — balance deduction matches legacy spendStructuredTokens', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 17613154.55; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - const model = 'claude-3-5-sonnet'; - const promptInput = 11; - const promptWrite = 140522; - const promptRead = 0; - const completionTokens = 5; - const totalInput = promptInput + promptWrite + promptRead; - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-convo', - model, - context: 'message', - balance: { enabled: true }, - transactions: { enabled: true }, - collectedUsage: [ - { - input_tokens: promptInput, - output_tokens: completionTokens, - model, - input_token_details: { cache_creation: promptWrite, cache_read: promptRead }, - }, - ], - }); - - const promptMultiplier = getMultiplier({ - model, - tokenType: 'prompt', - inputTokenCount: totalInput, - }); - const completionMultiplier = getMultiplier({ - model, - tokenType: 'completion', - inputTokenCount: totalInput, - }); - const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier; - const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier; - - const expectedPromptCost = - promptInput * promptMultiplier + promptWrite * writeMultiplier + promptRead * readMultiplier; - const expectedCompletionCost = completionTokens * completionMultiplier; - const expectedTotalCost = expectedPromptCost + expectedCompletionCost; - const expectedBalance = initialBalance - expectedTotalCost; - - const updatedBalance = await Balance.findOne({ user: userId }); - expect(Math.abs(updatedBalance.tokenCredits - expectedBalance)).toBeLessThan(100); - }); - - test('premium pricing above threshold via bulk path — same balance as legacy', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 100000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - const model = 'claude-opus-4-6'; - const promptTokens = 250000; - const completionTokens = 500; - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-premium', - model, - context: 'test', - balance: { enabled: true }, - transactions: { enabled: true }, - collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }], - }); - - const premiumPromptRate = premiumTokenValues[model].prompt; - const premiumCompletionRate = premiumTokenValues[model].completion; - const expectedCost = - promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; - - const updatedBalance = await Balance.findOne({ user: userId }); - expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); - }); - - test('real-world multi-entry batch: 5 sequential tool calls — same total deduction as 5 legacy spendTokens calls', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 100000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - const model = 'claude-opus-4-5-20251101'; - const calls = [ - { input_tokens: 31596, output_tokens: 151 }, - { input_tokens: 35368, output_tokens: 150 }, - { input_tokens: 58362, output_tokens: 295 }, - { input_tokens: 112604, output_tokens: 193 }, - { input_tokens: 257440, output_tokens: 2217 }, - ]; - - let expectedTotalCost = 0; - for (const { input_tokens, output_tokens } of calls) { - const pm = getMultiplier({ model, tokenType: 'prompt', inputTokenCount: input_tokens }); - const cm = getMultiplier({ model, tokenType: 'completion', inputTokenCount: input_tokens }); - expectedTotalCost += input_tokens * pm + output_tokens * cm; - } - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-sequential', - model, - context: 'message', - balance: { enabled: true }, - transactions: { enabled: true }, - collectedUsage: calls.map((c) => ({ ...c, model })), - }); - - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(10); // 5 calls × 2 docs (prompt + completion) - - const updatedBalance = await Balance.findOne({ user: userId }); - expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); - }); - - test('bulk path should save transaction but not update balance when balance disabled, transactions enabled', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-conversation-id', - model: 'gpt-3.5-turbo', - context: 'test', - balance: { enabled: false }, - transactions: { enabled: true }, - collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }], - }); - - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(2); - expect(txns[0].rawAmount).toBeDefined(); - const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBe(initialBalance); - }); - - test('bulk path structured tokens should not save when transactions.enabled is false', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-conversation-id', - model: 'claude-3-5-sonnet', - context: 'message', - balance: { enabled: true }, - transactions: { enabled: false }, - collectedUsage: [ - { - input_tokens: 10, - output_tokens: 5, - model: 'claude-3-5-sonnet', - input_token_details: { cache_creation: 100, cache_read: 5 }, - }, - ], - }); - - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(0); - const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBe(initialBalance); - }); - - test('bulk path structured tokens should save but not update balance when balance disabled', async () => { - const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; - await Balance.create({ user: userId, tokenCredits: initialBalance }); - - await recordCollectedUsage(bulkDeps, { - user: userId.toString(), - conversationId: 'test-conversation-id', - model: 'claude-3-5-sonnet', - context: 'message', - balance: { enabled: false }, - transactions: { enabled: true }, - collectedUsage: [ - { - input_tokens: 10, - output_tokens: 5, - model: 'claude-3-5-sonnet', - input_token_details: { cache_creation: 100, cache_read: 5 }, - }, - ], - }); - - const txns = await Transaction.find({ user: userId }).lean(); - expect(txns).toHaveLength(2); - const promptTx = txns.find((t) => t.tokenType === 'prompt'); - expect(promptTx.inputTokens).toBe(-10); - expect(promptTx.writeTokens).toBe(-100); - expect(promptTx.readTokens).toBe(-5); - const balance = await Balance.findOne({ user: userId }); - expect(balance.tokenCredits).toBe(initialBalance); - }); -}); diff --git a/packages/data-schemas/src/methods/transaction.ts b/packages/data-schemas/src/methods/transaction.ts index 3f019defa2..66c34b7e00 100644 --- a/packages/data-schemas/src/methods/transaction.ts +++ b/packages/data-schemas/src/methods/transaction.ts @@ -1,7 +1,7 @@ import logger from '~/config/winston'; import type { FilterQuery, Model, Types } from 'mongoose'; +import type { IBalance, IBalanceUpdate, TransactionData } from '~/types'; import type { ITransaction } from '~/schema/transaction'; -import type { IBalance, IBalanceUpdate } from '~/types'; const cancelRate = 1.15; @@ -408,7 +408,22 @@ export function createTransactionMethods( return Balance.deleteMany(filter); } + async function bulkInsertTransactions(docs: TransactionData[]): Promise { + if (!docs.length) { + return; + } + try { + const Transaction = mongoose.models.Transaction; + await Transaction.insertMany(docs); + } catch (error) { + logger.error('[bulkInsertTransactions] Error inserting transaction docs:', error); + throw error; + } + } + return { + updateBalance, + bulkInsertTransactions, findBalanceByUser, upsertBalanceFields, getTransactions, diff --git a/packages/data-schemas/src/methods/tx.ts b/packages/data-schemas/src/methods/tx.ts index df436f4fec..a7254b34a4 100644 --- a/packages/data-schemas/src/methods/tx.ts +++ b/packages/data-schemas/src/methods/tx.ts @@ -20,7 +20,10 @@ export interface TxDeps { /** From @librechat/api — matches a model name to a canonical key. */ matchModelName: (model: string, endpoint?: string) => string | undefined; /** From @librechat/api — finds the longest key in `values` whose key is a substring of `model`. */ - findMatchingPattern: (model: string, values: Record) => string | undefined; + findMatchingPattern: ( + model: string, + values: Record>, + ) => string | undefined; } export const defaultRate = 6;