refactor: get balance config primarily from appConfig

This commit is contained in:
Danny Avila 2025-08-18 03:45:02 -04:00
parent d853c10920
commit 1d2be247cf
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
11 changed files with 89 additions and 79 deletions

View file

@ -113,13 +113,15 @@ class BaseClient {
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
* Should only be used if `recordCollectedUsage` was not used instead.
* @param {string} [model]
* @param {AppConfig['balance']} [balance]
* @param {number} promptTokens
* @param {number} completionTokens
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens }) {
async recordTokenUsage({ model, balance, promptTokens, completionTokens }) {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
model,
balance,
promptTokens,
completionTokens,
});
@ -754,6 +756,7 @@ class BaseClient {
completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({
usage,
balance,
promptTokens,
completionTokens,
model: responseMessage.model,

View file

@ -1,5 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { Transaction, Balance } = require('~/db/models');
@ -187,9 +186,10 @@ async function createAutoRefillTransaction(txData) {
/**
* Static method to create a transaction and update the balance
* @param {txData} txData - Transaction data.
* @param {txData} _txData - Transaction data.
*/
async function createTransaction(txData) {
async function createTransaction(_txData) {
const { balance, ...txData } = _txData;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
@ -199,8 +199,6 @@ async function createTransaction(txData) {
calculateTokenValue(transaction);
await transaction.save();
const balance = await getBalanceConfig();
if (!balance?.enabled) {
return;
}
@ -221,9 +219,10 @@ async function createTransaction(txData) {
/**
* Static method to create a structured transaction and update the balance
* @param {txData} txData - Transaction data.
* @param {txData} _txData - Transaction data.
*/
async function createStructuredTransaction(txData) {
async function createStructuredTransaction(_txData) {
const { balance, ...txData } = _txData;
const transaction = new Transaction({
...txData,
endpointTokenConfig: txData.endpointTokenConfig,
@ -233,7 +232,6 @@ async function createStructuredTransaction(txData) {
await transaction.save();
const balance = await getBalanceConfig();
if (!balance?.enabled) {
return;
}

View file

@ -1,14 +1,11 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { createTransaction } = require('./Transaction');
const { Balance } = require('~/db/models');
// Mock the custom config module so we can control the balance flag.
jest.mock('~/server/services/Config');
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
@ -23,8 +20,6 @@ afterAll(async () => {
beforeEach(async () => {
await mongoose.connection.dropDatabase();
// Default: enable balance updates in tests.
getBalanceConfig.mockResolvedValue({ enabled: true });
});
describe('Regular Token Spending Tests', () => {
@ -41,6 +36,7 @@ describe('Regular Token Spending Tests', () => {
model,
context: 'test',
endpointTokenConfig: null,
balance: { enabled: true },
};
const tokenUsage = {
@ -74,6 +70,7 @@ describe('Regular Token Spending Tests', () => {
model,
context: 'test',
endpointTokenConfig: null,
balance: { enabled: true },
};
const tokenUsage = {
@ -104,6 +101,7 @@ describe('Regular Token Spending Tests', () => {
model,
context: 'test',
endpointTokenConfig: null,
balance: { enabled: true },
};
const tokenUsage = {};
@ -128,6 +126,7 @@ describe('Regular Token Spending Tests', () => {
model,
context: 'test',
endpointTokenConfig: null,
balance: { enabled: true },
};
const tokenUsage = { promptTokens: 100 };
@ -143,8 +142,7 @@ describe('Regular Token Spending Tests', () => {
});
test('spendTokens should not update balance when balance feature is disabled', async () => {
// Arrange: Override the config to disable balance updates.
getBalanceConfig.mockResolvedValue({ balance: { enabled: false } });
// Arrange: Balance config is now passed directly in txData
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
@ -156,6 +154,7 @@ describe('Regular Token Spending Tests', () => {
model,
context: 'test',
endpointTokenConfig: null,
balance: { enabled: false },
};
const tokenUsage = {
@ -186,6 +185,7 @@ describe('Structured Token Spending Tests', () => {
model,
context: 'message',
endpointTokenConfig: null,
balance: { enabled: true },
};
const tokenUsage = {
@ -239,6 +239,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo',
model,
context: 'message',
balance: { enabled: true },
};
const tokenUsage = {
@ -271,6 +272,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo',
model,
context: 'message',
balance: { enabled: true },
};
const tokenUsage = {
@ -302,6 +304,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo',
model,
context: 'message',
balance: { enabled: true },
};
const tokenUsage = {};
@ -328,6 +331,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo',
model,
context: 'incomplete',
balance: { enabled: true },
};
const tokenUsage = {
@ -364,6 +368,7 @@ describe('NaN Handling Tests', () => {
endpointTokenConfig: null,
rawAmount: NaN,
tokenType: 'prompt',
balance: { enabled: true },
};
// Act

View file

@ -5,13 +5,7 @@ const { createTransaction, createStructuredTransaction } = require('./Transactio
*
* @function
* @async
* @param {Object} txData - Transaction data.
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
* @param {String} txData.conversationId - The ID of the conversation.
* @param {String} txData.model - The model name.
* @param {String} txData.context - The context in which the transaction is made.
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
* @param {String} [txData.valueKey] - The value key (optional).
* @param {txData} txData - Transaction data.
* @param {Object} tokenUsage - The number of tokens used.
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
* @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
@ -69,13 +63,7 @@ const spendTokens = async (txData, tokenUsage) => {
*
* @function
* @async
* @param {Object} txData - Transaction data.
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
* @param {String} txData.conversationId - The ID of the conversation.
* @param {String} txData.model - The model name.
* @param {String} txData.context - The context in which the transaction is made.
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
* @param {String} [txData.valueKey] - The value key (optional).
* @param {txData} txData - Transaction data.
* @param {Object} tokenUsage - The number of tokens used.
* @param {Object} tokenUsage.promptTokens - The number of prompt tokens used.
* @param {Number} tokenUsage.promptTokens.input - The number of input tokens.

View file

@ -5,7 +5,6 @@ const { createTransaction, createAutoRefillTransaction } = require('./Transactio
require('~/db/models');
// Mock the logger to prevent console output during tests
jest.mock('~/config', () => ({
logger: {
debug: jest.fn(),
@ -13,10 +12,6 @@ jest.mock('~/config', () => ({
},
}));
// Mock the Config service
const { getBalanceConfig } = require('~/server/services/Config');
jest.mock('~/server/services/Config');
describe('spendTokens', () => {
let mongoServer;
let userId;
@ -44,8 +39,7 @@ describe('spendTokens', () => {
// Create a new user ID for each test
userId = new mongoose.Types.ObjectId();
// Mock the balance config to be enabled by default
getBalanceConfig.mockResolvedValue({ enabled: true });
// Balance config is now passed directly in txData
});
it('should create transactions for both prompt and completion tokens', async () => {
@ -60,6 +54,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
balance: { enabled: true },
};
const tokenUsage = {
promptTokens: 100,
@ -98,6 +93,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
balance: { enabled: true },
};
const tokenUsage = {
promptTokens: 100,
@ -127,6 +123,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
balance: { enabled: true },
};
const tokenUsage = {};
@ -138,8 +135,7 @@ describe('spendTokens', () => {
});
it('should not update balance when the balance feature is disabled', async () => {
// Override configuration: disable balance updates
getBalanceConfig.mockResolvedValue({ enabled: false });
// Balance is now passed directly in txData
// Create a balance for the user
await Balance.create({
user: userId,
@ -151,6 +147,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
balance: { enabled: false },
};
const tokenUsage = {
promptTokens: 100,
@ -180,6 +177,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'gpt-4', // Using a more expensive model
context: 'test',
balance: { enabled: true },
};
// Spending more tokens than the user has balance for
@ -233,6 +231,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-1',
model: 'gpt-4',
context: 'test',
balance: { enabled: true },
};
const tokenUsage1 = {
@ -252,6 +251,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-2',
model: 'gpt-4',
context: 'test',
balance: { enabled: true },
};
const tokenUsage2 = {
@ -292,6 +292,7 @@ describe('spendTokens', () => {
tokenType: 'completion',
rawAmount: -100,
context: 'test',
balance: { enabled: true },
});
console.log('Direct Transaction.create result:', directResult);
@ -316,6 +317,7 @@ describe('spendTokens', () => {
conversationId: `test-convo-${model}`,
model,
context: 'test',
balance: { enabled: true },
};
const tokenUsage = {
@ -352,6 +354,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-1',
model: 'claude-3-5-sonnet',
context: 'test',
balance: { enabled: true },
};
const tokenUsage1 = {
@ -375,6 +378,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-2',
model: 'claude-3-5-sonnet',
context: 'test',
balance: { enabled: true },
};
const tokenUsage2 = {
@ -426,6 +430,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'claude-3-5-sonnet', // Using a model that supports structured tokens
context: 'test',
balance: { enabled: true },
};
// Spending more tokens than the user has balance for
@ -505,6 +510,7 @@ describe('spendTokens', () => {
conversationId,
user: userId,
model: usage.model,
balance: { enabled: true },
};
// Calculate expected spend for this transaction
@ -617,6 +623,7 @@ describe('spendTokens', () => {
tokenType: 'credits',
context: 'concurrent-refill-test',
rawAmount: refillAmount,
balance: { enabled: true },
}),
);
}
@ -683,6 +690,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo',
model: 'claude-3-5-sonnet',
context: 'test',
balance: { enabled: true },
};
const tokenUsage = {
promptTokens: {

View file

@ -627,9 +627,15 @@ class AgentClient extends BaseClient {
* @param {Object} params
* @param {string} [params.model]
* @param {string} [params.context='message']
* @param {AppConfig['balance']} [params.balance]
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) {
async recordCollectedUsage({
model,
balance,
context = 'message',
collectedUsage = this.collectedUsage,
}) {
if (!collectedUsage || !collectedUsage.length) {
return;
}
@ -651,6 +657,7 @@ class AgentClient extends BaseClient {
const txMetadata = {
context,
balance,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
@ -1044,7 +1051,7 @@ class AgentClient extends BaseClient {
this.artifactPromises.push(...attachments);
}
await this.recordCollectedUsage({ context: 'message' });
await this.recordCollectedUsage({ context: 'message', balance: appConfig?.balance });
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
@ -1219,9 +1226,10 @@ class AgentClient extends BaseClient {
});
await this.recordCollectedUsage({
model: clientOptions.model,
context: 'title',
collectedUsage,
context: 'title',
model: clientOptions.model,
balance: appConfig?.balance,
}).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
@ -1240,17 +1248,26 @@ class AgentClient extends BaseClient {
* @param {object} params
* @param {number} params.promptTokens
* @param {number} params.completionTokens
* @param {OpenAIUsageMetadata} [params.usage]
* @param {string} [params.model]
* @param {OpenAIUsageMetadata} [params.usage]
* @param {AppConfig['balance']} [params.balance]
* @param {string} [params.context='message']
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
async recordTokenUsage({
model,
usage,
balance,
promptTokens,
completionTokens,
context = 'message',
}) {
try {
await spendTokens(
{
model,
context,
balance,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
@ -1267,6 +1284,7 @@ class AgentClient extends BaseClient {
await spendTokens(
{
model,
balance,
context: 'reasoning',
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,

View file

@ -20,10 +20,10 @@ const {
deleteUserById,
generateRefreshToken,
} = require('~/models');
const { getBalanceConfig, getAppConfig } = require('~/server/services/Config');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { checkEmailConfig, sendEmail } = require('~/server/utils');
const { registerSchema } = require('~/strategies/validators');
const { getAppConfig } = require('~/server/services/Config');
const domains = {
client: process.env.DOMAIN_CLIENT,
@ -220,9 +220,8 @@ const registerUser = async (user, additionalData = {}) => {
const emailEnabled = checkEmailConfig();
const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN);
const balanceConfig = await getBalanceConfig();
const newUser = await createUser(newUserData, balanceConfig, disableTTL, true);
const newUser = await createUser(newUserData, appConfig.balance, disableTTL, true);
newUserId = newUser._id;
if (emailEnabled && !newUser.emailVerified) {
await sendVerificationEmail({

View file

@ -1,25 +1,16 @@
const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint } = require('librechat-data-provider');
const { isEnabled, getUserMCPAuthMap, normalizeEndpointName } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
/**
* Retrieves the configuration object
* @function getCustomConfig
* @returns {Promise<TCustomConfig | null>}
* */
async function getCustomConfig() {
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
return (await cache.get(CacheKeys.LIBRECHAT_YAML_CONFIG)) || (await loadCustomConfig());
}
const { getAppConfig } = require('./app');
/**
* Retrieves the configuration object
* @function getBalanceConfig
* @param {Object} params
* @param {string} [params.role]
* @returns {Promise<TCustomConfig['balance'] | null>}
* */
async function getBalanceConfig() {
async function getBalanceConfig({ role }) {
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
const startBalance = process.env.START_BALANCE;
/** @type {TCustomConfig['balance']} */
@ -27,11 +18,11 @@ async function getBalanceConfig() {
enabled: isLegacyEnabled,
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
};
const customConfig = await getCustomConfig();
if (!customConfig) {
const appConfig = await getAppConfig({ role });
if (!appConfig) {
return config;
}
return { ...config, ...(customConfig?.['balance'] ?? {}) };
return { ...config, ...(appConfig?.['balance'] ?? {}) };
}
/**
@ -40,13 +31,12 @@ async function getBalanceConfig() {
* @returns {Promise<TEndpoint | undefined>}
*/
const getCustomEndpointConfig = async (endpoint) => {
const customConfig = await getCustomConfig();
if (!customConfig) {
const appConfig = await getAppConfig();
if (!appConfig) {
throw new Error(`Config not found for the ${endpoint} custom endpoint.`);
}
const { endpoints = {} } = customConfig;
const customEndpoints = endpoints[EModelEndpoint.custom] ?? [];
const customEndpoints = appConfig[EModelEndpoint.custom] ?? [];
return customEndpoints.find(
(endpointConfig) => normalizeEndpointName(endpointConfig.name) === endpoint,
);
@ -81,14 +71,13 @@ async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
* @returns {Promise<boolean>}
*/
async function hasCustomUserVars() {
const customConfig = await getCustomConfig();
const mcpServers = customConfig?.mcpServers;
const customConfig = await getAppConfig();
const mcpServers = customConfig?.mcpConfig;
return Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
}
module.exports = {
getMCPAuthMap,
getCustomConfig,
getBalanceConfig,
hasCustomUserVars,
getCustomEndpointConfig,

View file

@ -123,6 +123,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
if (!user) {
const isFirstRegisteredUser = (await countUsers()) === 0;
const role = isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER;
user = {
provider: 'ldap',
ldapId,
@ -130,9 +131,9 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
email: mail,
emailVerified: true, // The ldap server administrator should verify the email
name: fullName,
role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER,
role,
};
const balanceConfig = await getBalanceConfig();
const balanceConfig = await getBalanceConfig({ role });
const userId = await createUser(user, balanceConfig);
user._id = userId;
} else {

View file

@ -1783,8 +1783,8 @@
* @property {String} conversationId - The ID of the conversation.
* @property {String} model - The model name.
* @property {String} context - The context in which the transaction is made.
* @property {AppConfig['balance']} [balance] - The balance config
* @property {EndpointTokenConfig} [endpointTokenConfig] - The current endpoint token config.
* @property {object} [cacheUsage] - Cache usage, if any.
* @property {String} [valueKey] - The value key (optional).
* @memberof typedefs
*/
@ -1829,6 +1829,7 @@
* @callback sendCompletion
* @param {Array<ChatCompletionMessage> | string} payload - The messages or prompt to send to the model
* @param {object} opts - Options for the completion
* @param {AppConfig} opts.appConfig - Callback function to handle token progress
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
* @param {AbortController} opts.abortController - AbortController instance
* @returns {Promise<string>}

View file

@ -5,7 +5,7 @@ import type { Model } from 'mongoose';
import type { BalanceUpdateFields } from '~/types';
export interface BalanceMiddlewareOptions {
getBalanceConfig: () => Promise<BalanceConfig | null>;
getBalanceConfig: ({ role }?: { role?: string }) => Promise<BalanceConfig | null>;
Balance: Model<IBalance>;
}
@ -82,7 +82,8 @@ export function createSetBalanceConfig({
) => Promise<void> {
return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise<void> => {
try {
const balanceConfig = await getBalanceConfig();
const user = req.user as IUser & { _id: string | ObjectId };
const balanceConfig = await getBalanceConfig({ role: user?.role });
if (!balanceConfig?.enabled) {
return next();
}
@ -90,7 +91,6 @@ export function createSetBalanceConfig({
return next();
}
const user = req.user as IUser & { _id: string | ObjectId };
if (!user || !user._id) {
return next();
}