diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 9e5b1a8b47..dcdfbc5815 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -113,13 +113,15 @@ class BaseClient { * If a correction to the token usage is needed, the method should return an object with the corrected token counts. * Should only be used if `recordCollectedUsage` was not used instead. * @param {string} [model] + * @param {AppConfig['balance']} [balance] * @param {number} promptTokens * @param {number} completionTokens * @returns {Promise} */ - async recordTokenUsage({ model, promptTokens, completionTokens }) { + async recordTokenUsage({ model, balance, promptTokens, completionTokens }) { logger.debug('[BaseClient] `recordTokenUsage` not implemented.', { model, + balance, promptTokens, completionTokens, }); @@ -754,6 +756,7 @@ class BaseClient { completionTokens = responseMessage.tokenCount; await this.recordTokenUsage({ usage, + balance, promptTokens, completionTokens, model: responseMessage.model, diff --git a/api/models/Transaction.js b/api/models/Transaction.js index 0e0e327857..e1cff15c3b 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,5 +1,4 @@ const { logger } = require('@librechat/data-schemas'); -const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); const { Transaction, Balance } = require('~/db/models'); @@ -187,9 +186,10 @@ async function createAutoRefillTransaction(txData) { /** * Static method to create a transaction and update the balance - * @param {txData} txData - Transaction data. + * @param {txData} _txData - Transaction data. */ -async function createTransaction(txData) { +async function createTransaction(_txData) { + const { balance, ...txData } = _txData; if (txData.rawAmount != null && isNaN(txData.rawAmount)) { return; } @@ -199,8 +199,6 @@ async function createTransaction(txData) { calculateTokenValue(transaction); await transaction.save(); - - const balance = await getBalanceConfig(); if (!balance?.enabled) { return; } @@ -221,9 +219,10 @@ async function createTransaction(txData) { /** * Static method to create a structured transaction and update the balance - * @param {txData} txData - Transaction data. + * @param {txData} _txData - Transaction data. */ -async function createStructuredTransaction(txData) { +async function createStructuredTransaction(_txData) { + const { balance, ...txData } = _txData; const transaction = new Transaction({ ...txData, endpointTokenConfig: txData.endpointTokenConfig, @@ -233,7 +232,6 @@ async function createStructuredTransaction(txData) { await transaction.save(); - const balance = await getBalanceConfig(); if (!balance?.enabled) { return; } diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index 3a1303edec..891d9ca7db 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -1,14 +1,11 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); const { spendTokens, spendStructuredTokens } = require('./spendTokens'); -const { getBalanceConfig } = require('~/server/services/Config'); + const { getMultiplier, getCacheMultiplier } = require('./tx'); const { createTransaction } = require('./Transaction'); const { Balance } = require('~/db/models'); -// Mock the custom config module so we can control the balance flag. -jest.mock('~/server/services/Config'); - let mongoServer; beforeAll(async () => { mongoServer = await MongoMemoryServer.create(); @@ -23,8 +20,6 @@ afterAll(async () => { beforeEach(async () => { await mongoose.connection.dropDatabase(); - // Default: enable balance updates in tests. - getBalanceConfig.mockResolvedValue({ enabled: true }); }); describe('Regular Token Spending Tests', () => { @@ -41,6 +36,7 @@ describe('Regular Token Spending Tests', () => { model, context: 'test', endpointTokenConfig: null, + balance: { enabled: true }, }; const tokenUsage = { @@ -74,6 +70,7 @@ describe('Regular Token Spending Tests', () => { model, context: 'test', endpointTokenConfig: null, + balance: { enabled: true }, }; const tokenUsage = { @@ -104,6 +101,7 @@ describe('Regular Token Spending Tests', () => { model, context: 'test', endpointTokenConfig: null, + balance: { enabled: true }, }; const tokenUsage = {}; @@ -128,6 +126,7 @@ describe('Regular Token Spending Tests', () => { model, context: 'test', endpointTokenConfig: null, + balance: { enabled: true }, }; const tokenUsage = { promptTokens: 100 }; @@ -143,8 +142,7 @@ describe('Regular Token Spending Tests', () => { }); test('spendTokens should not update balance when balance feature is disabled', async () => { - // Arrange: Override the config to disable balance updates. - getBalanceConfig.mockResolvedValue({ balance: { enabled: false } }); + // Arrange: Balance config is now passed directly in txData const userId = new mongoose.Types.ObjectId(); const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -156,6 +154,7 @@ describe('Regular Token Spending Tests', () => { model, context: 'test', endpointTokenConfig: null, + balance: { enabled: false }, }; const tokenUsage = { @@ -186,6 +185,7 @@ describe('Structured Token Spending Tests', () => { model, context: 'message', endpointTokenConfig: null, + balance: { enabled: true }, }; const tokenUsage = { @@ -239,6 +239,7 @@ describe('Structured Token Spending Tests', () => { conversationId: 'test-convo', model, context: 'message', + balance: { enabled: true }, }; const tokenUsage = { @@ -271,6 +272,7 @@ describe('Structured Token Spending Tests', () => { conversationId: 'test-convo', model, context: 'message', + balance: { enabled: true }, }; const tokenUsage = { @@ -302,6 +304,7 @@ describe('Structured Token Spending Tests', () => { conversationId: 'test-convo', model, context: 'message', + balance: { enabled: true }, }; const tokenUsage = {}; @@ -328,6 +331,7 @@ describe('Structured Token Spending Tests', () => { conversationId: 'test-convo', model, context: 'incomplete', + balance: { enabled: true }, }; const tokenUsage = { @@ -364,6 +368,7 @@ describe('NaN Handling Tests', () => { endpointTokenConfig: null, rawAmount: NaN, tokenType: 'prompt', + balance: { enabled: true }, }; // Act diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index 834f740926..65fadd7896 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -5,13 +5,7 @@ const { createTransaction, createStructuredTransaction } = require('./Transactio * * @function * @async - * @param {Object} txData - Transaction data. - * @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID. - * @param {String} txData.conversationId - The ID of the conversation. - * @param {String} txData.model - The model name. - * @param {String} txData.context - The context in which the transaction is made. - * @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config. - * @param {String} [txData.valueKey] - The value key (optional). + * @param {txData} txData - Transaction data. * @param {Object} tokenUsage - The number of tokens used. * @param {Number} tokenUsage.promptTokens - The number of prompt tokens used. * @param {Number} tokenUsage.completionTokens - The number of completion tokens used. @@ -69,13 +63,7 @@ const spendTokens = async (txData, tokenUsage) => { * * @function * @async - * @param {Object} txData - Transaction data. - * @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID. - * @param {String} txData.conversationId - The ID of the conversation. - * @param {String} txData.model - The model name. - * @param {String} txData.context - The context in which the transaction is made. - * @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config. - * @param {String} [txData.valueKey] - The value key (optional). + * @param {txData} txData - Transaction data. * @param {Object} tokenUsage - The number of tokens used. * @param {Object} tokenUsage.promptTokens - The number of prompt tokens used. * @param {Number} tokenUsage.promptTokens.input - The number of input tokens. diff --git a/api/models/spendTokens.spec.js b/api/models/spendTokens.spec.js index 7ee067e589..eee6572736 100644 --- a/api/models/spendTokens.spec.js +++ b/api/models/spendTokens.spec.js @@ -5,7 +5,6 @@ const { createTransaction, createAutoRefillTransaction } = require('./Transactio require('~/db/models'); -// Mock the logger to prevent console output during tests jest.mock('~/config', () => ({ logger: { debug: jest.fn(), @@ -13,10 +12,6 @@ jest.mock('~/config', () => ({ }, })); -// Mock the Config service -const { getBalanceConfig } = require('~/server/services/Config'); -jest.mock('~/server/services/Config'); - describe('spendTokens', () => { let mongoServer; let userId; @@ -44,8 +39,7 @@ describe('spendTokens', () => { // Create a new user ID for each test userId = new mongoose.Types.ObjectId(); - // Mock the balance config to be enabled by default - getBalanceConfig.mockResolvedValue({ enabled: true }); + // Balance config is now passed directly in txData }); it('should create transactions for both prompt and completion tokens', async () => { @@ -60,6 +54,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', + balance: { enabled: true }, }; const tokenUsage = { promptTokens: 100, @@ -98,6 +93,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', + balance: { enabled: true }, }; const tokenUsage = { promptTokens: 100, @@ -127,6 +123,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', + balance: { enabled: true }, }; const tokenUsage = {}; @@ -138,8 +135,7 @@ describe('spendTokens', () => { }); it('should not update balance when the balance feature is disabled', async () => { - // Override configuration: disable balance updates - getBalanceConfig.mockResolvedValue({ enabled: false }); + // Balance is now passed directly in txData // Create a balance for the user await Balance.create({ user: userId, @@ -151,6 +147,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', + balance: { enabled: false }, }; const tokenUsage = { promptTokens: 100, @@ -180,6 +177,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'gpt-4', // Using a more expensive model context: 'test', + balance: { enabled: true }, }; // Spending more tokens than the user has balance for @@ -233,6 +231,7 @@ describe('spendTokens', () => { conversationId: 'test-convo-1', model: 'gpt-4', context: 'test', + balance: { enabled: true }, }; const tokenUsage1 = { @@ -252,6 +251,7 @@ describe('spendTokens', () => { conversationId: 'test-convo-2', model: 'gpt-4', context: 'test', + balance: { enabled: true }, }; const tokenUsage2 = { @@ -292,6 +292,7 @@ describe('spendTokens', () => { tokenType: 'completion', rawAmount: -100, context: 'test', + balance: { enabled: true }, }); console.log('Direct Transaction.create result:', directResult); @@ -316,6 +317,7 @@ describe('spendTokens', () => { conversationId: `test-convo-${model}`, model, context: 'test', + balance: { enabled: true }, }; const tokenUsage = { @@ -352,6 +354,7 @@ describe('spendTokens', () => { conversationId: 'test-convo-1', model: 'claude-3-5-sonnet', context: 'test', + balance: { enabled: true }, }; const tokenUsage1 = { @@ -375,6 +378,7 @@ describe('spendTokens', () => { conversationId: 'test-convo-2', model: 'claude-3-5-sonnet', context: 'test', + balance: { enabled: true }, }; const tokenUsage2 = { @@ -426,6 +430,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'claude-3-5-sonnet', // Using a model that supports structured tokens context: 'test', + balance: { enabled: true }, }; // Spending more tokens than the user has balance for @@ -505,6 +510,7 @@ describe('spendTokens', () => { conversationId, user: userId, model: usage.model, + balance: { enabled: true }, }; // Calculate expected spend for this transaction @@ -617,6 +623,7 @@ describe('spendTokens', () => { tokenType: 'credits', context: 'concurrent-refill-test', rawAmount: refillAmount, + balance: { enabled: true }, }), ); } @@ -683,6 +690,7 @@ describe('spendTokens', () => { conversationId: 'test-convo', model: 'claude-3-5-sonnet', context: 'test', + balance: { enabled: true }, }; const tokenUsage = { promptTokens: { diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 774c22f128..7a598c69d2 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -627,9 +627,15 @@ class AgentClient extends BaseClient { * @param {Object} params * @param {string} [params.model] * @param {string} [params.context='message'] + * @param {AppConfig['balance']} [params.balance] * @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage] */ - async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) { + async recordCollectedUsage({ + model, + balance, + context = 'message', + collectedUsage = this.collectedUsage, + }) { if (!collectedUsage || !collectedUsage.length) { return; } @@ -651,6 +657,7 @@ class AgentClient extends BaseClient { const txMetadata = { context, + balance, conversationId: this.conversationId, user: this.user ?? this.options.req.user?.id, endpointTokenConfig: this.options.endpointTokenConfig, @@ -1044,7 +1051,7 @@ class AgentClient extends BaseClient { this.artifactPromises.push(...attachments); } - await this.recordCollectedUsage({ context: 'message' }); + await this.recordCollectedUsage({ context: 'message', balance: appConfig?.balance }); } catch (err) { logger.error( '[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage', @@ -1219,9 +1226,10 @@ class AgentClient extends BaseClient { }); await this.recordCollectedUsage({ - model: clientOptions.model, - context: 'title', collectedUsage, + context: 'title', + model: clientOptions.model, + balance: appConfig?.balance, }).catch((err) => { logger.error( '[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage', @@ -1240,17 +1248,26 @@ class AgentClient extends BaseClient { * @param {object} params * @param {number} params.promptTokens * @param {number} params.completionTokens - * @param {OpenAIUsageMetadata} [params.usage] * @param {string} [params.model] + * @param {OpenAIUsageMetadata} [params.usage] + * @param {AppConfig['balance']} [params.balance] * @param {string} [params.context='message'] * @returns {Promise} */ - async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) { + async recordTokenUsage({ + model, + usage, + balance, + promptTokens, + completionTokens, + context = 'message', + }) { try { await spendTokens( { model, context, + balance, conversationId: this.conversationId, user: this.user ?? this.options.req.user?.id, endpointTokenConfig: this.options.endpointTokenConfig, @@ -1267,6 +1284,7 @@ class AgentClient extends BaseClient { await spendTokens( { model, + balance, context: 'reasoning', conversationId: this.conversationId, user: this.user ?? this.options.req.user?.id, diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index c35bb9bd54..1ac5f76928 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -20,10 +20,10 @@ const { deleteUserById, generateRefreshToken, } = require('~/models'); -const { getBalanceConfig, getAppConfig } = require('~/server/services/Config'); const { isEmailDomainAllowed } = require('~/server/services/domains'); const { checkEmailConfig, sendEmail } = require('~/server/utils'); const { registerSchema } = require('~/strategies/validators'); +const { getAppConfig } = require('~/server/services/Config'); const domains = { client: process.env.DOMAIN_CLIENT, @@ -220,9 +220,8 @@ const registerUser = async (user, additionalData = {}) => { const emailEnabled = checkEmailConfig(); const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN); - const balanceConfig = await getBalanceConfig(); - const newUser = await createUser(newUserData, balanceConfig, disableTTL, true); + const newUser = await createUser(newUserData, appConfig.balance, disableTTL, true); newUserId = newUser._id; if (emailEnabled && !newUser.emailVerified) { await sendVerificationEmail({ diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index a66cf664b2..a25ca8b76a 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -1,25 +1,16 @@ const { logger } = require('@librechat/data-schemas'); +const { EModelEndpoint } = require('librechat-data-provider'); const { isEnabled, getUserMCPAuthMap, normalizeEndpointName } = require('@librechat/api'); -const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); -const loadCustomConfig = require('./loadCustomConfig'); -const getLogStores = require('~/cache/getLogStores'); - -/** - * Retrieves the configuration object - * @function getCustomConfig - * @returns {Promise} - * */ -async function getCustomConfig() { - const cache = getLogStores(CacheKeys.STATIC_CONFIG); - return (await cache.get(CacheKeys.LIBRECHAT_YAML_CONFIG)) || (await loadCustomConfig()); -} +const { getAppConfig } = require('./app'); /** * Retrieves the configuration object * @function getBalanceConfig + * @param {Object} params + * @param {string} [params.role] * @returns {Promise} * */ -async function getBalanceConfig() { +async function getBalanceConfig({ role }) { const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE); const startBalance = process.env.START_BALANCE; /** @type {TCustomConfig['balance']} */ @@ -27,11 +18,11 @@ async function getBalanceConfig() { enabled: isLegacyEnabled, startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined, }; - const customConfig = await getCustomConfig(); - if (!customConfig) { + const appConfig = await getAppConfig({ role }); + if (!appConfig) { return config; } - return { ...config, ...(customConfig?.['balance'] ?? {}) }; + return { ...config, ...(appConfig?.['balance'] ?? {}) }; } /** @@ -40,13 +31,12 @@ async function getBalanceConfig() { * @returns {Promise} */ const getCustomEndpointConfig = async (endpoint) => { - const customConfig = await getCustomConfig(); - if (!customConfig) { + const appConfig = await getAppConfig(); + if (!appConfig) { throw new Error(`Config not found for the ${endpoint} custom endpoint.`); } - const { endpoints = {} } = customConfig; - const customEndpoints = endpoints[EModelEndpoint.custom] ?? []; + const customEndpoints = appConfig[EModelEndpoint.custom] ?? []; return customEndpoints.find( (endpointConfig) => normalizeEndpointName(endpointConfig.name) === endpoint, ); @@ -81,14 +71,13 @@ async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) { * @returns {Promise} */ async function hasCustomUserVars() { - const customConfig = await getCustomConfig(); - const mcpServers = customConfig?.mcpServers; + const customConfig = await getAppConfig(); + const mcpServers = customConfig?.mcpConfig; return Object.values(mcpServers ?? {}).some((server) => server.customUserVars); } module.exports = { getMCPAuthMap, - getCustomConfig, getBalanceConfig, hasCustomUserVars, getCustomEndpointConfig, diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index a9c00fa596..712b2e46cf 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -123,6 +123,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { if (!user) { const isFirstRegisteredUser = (await countUsers()) === 0; + const role = isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER; user = { provider: 'ldap', ldapId, @@ -130,9 +131,9 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { email: mail, emailVerified: true, // The ldap server administrator should verify the email name: fullName, - role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER, + role, }; - const balanceConfig = await getBalanceConfig(); + const balanceConfig = await getBalanceConfig({ role }); const userId = await createUser(user, balanceConfig); user._id = userId; } else { diff --git a/api/typedefs.js b/api/typedefs.js index 275379538b..a4a1592f73 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -1783,8 +1783,8 @@ * @property {String} conversationId - The ID of the conversation. * @property {String} model - The model name. * @property {String} context - The context in which the transaction is made. + * @property {AppConfig['balance']} [balance] - The balance config * @property {EndpointTokenConfig} [endpointTokenConfig] - The current endpoint token config. - * @property {object} [cacheUsage] - Cache usage, if any. * @property {String} [valueKey] - The value key (optional). * @memberof typedefs */ @@ -1829,6 +1829,7 @@ * @callback sendCompletion * @param {Array | string} payload - The messages or prompt to send to the model * @param {object} opts - Options for the completion + * @param {AppConfig} opts.appConfig - Callback function to handle token progress * @param {onTokenProgress} opts.onProgress - Callback function to handle token progress * @param {AbortController} opts.abortController - AbortController instance * @returns {Promise} diff --git a/packages/api/src/middleware/balance.ts b/packages/api/src/middleware/balance.ts index 3aaa20da68..2a84c886cc 100644 --- a/packages/api/src/middleware/balance.ts +++ b/packages/api/src/middleware/balance.ts @@ -5,7 +5,7 @@ import type { Model } from 'mongoose'; import type { BalanceUpdateFields } from '~/types'; export interface BalanceMiddlewareOptions { - getBalanceConfig: () => Promise; + getBalanceConfig: ({ role }?: { role?: string }) => Promise; Balance: Model; } @@ -82,7 +82,8 @@ export function createSetBalanceConfig({ ) => Promise { return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise => { try { - const balanceConfig = await getBalanceConfig(); + const user = req.user as IUser & { _id: string | ObjectId }; + const balanceConfig = await getBalanceConfig({ role: user?.role }); if (!balanceConfig?.enabled) { return next(); } @@ -90,7 +91,6 @@ export function createSetBalanceConfig({ return next(); } - const user = req.user as IUser & { _id: string | ObjectId }; if (!user || !user._id) { return next(); }