refactor: get balance config primarily from appConfig

This commit is contained in:
Danny Avila 2025-08-18 03:45:02 -04:00
parent d853c10920
commit 1d2be247cf
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
11 changed files with 89 additions and 79 deletions

View file

@ -113,13 +113,15 @@ class BaseClient {
* If a correction to the token usage is needed, the method should return an object with the corrected token counts. * If a correction to the token usage is needed, the method should return an object with the corrected token counts.
* Should only be used if `recordCollectedUsage` was not used instead. * Should only be used if `recordCollectedUsage` was not used instead.
* @param {string} [model] * @param {string} [model]
* @param {AppConfig['balance']} [balance]
* @param {number} promptTokens * @param {number} promptTokens
* @param {number} completionTokens * @param {number} completionTokens
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
async recordTokenUsage({ model, promptTokens, completionTokens }) { async recordTokenUsage({ model, balance, promptTokens, completionTokens }) {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', { logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
model, model,
balance,
promptTokens, promptTokens,
completionTokens, completionTokens,
}); });
@ -754,6 +756,7 @@ class BaseClient {
completionTokens = responseMessage.tokenCount; completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({ await this.recordTokenUsage({
usage, usage,
balance,
promptTokens, promptTokens,
completionTokens, completionTokens,
model: responseMessage.model, model: responseMessage.model,

View file

@ -1,5 +1,4 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx'); const { getMultiplier, getCacheMultiplier } = require('./tx');
const { Transaction, Balance } = require('~/db/models'); const { Transaction, Balance } = require('~/db/models');
@ -187,9 +186,10 @@ async function createAutoRefillTransaction(txData) {
/** /**
* Static method to create a transaction and update the balance * Static method to create a transaction and update the balance
* @param {txData} txData - Transaction data. * @param {txData} _txData - Transaction data.
*/ */
async function createTransaction(txData) { async function createTransaction(_txData) {
const { balance, ...txData } = _txData;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) { if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return; return;
} }
@ -199,8 +199,6 @@ async function createTransaction(txData) {
calculateTokenValue(transaction); calculateTokenValue(transaction);
await transaction.save(); await transaction.save();
const balance = await getBalanceConfig();
if (!balance?.enabled) { if (!balance?.enabled) {
return; return;
} }
@ -221,9 +219,10 @@ async function createTransaction(txData) {
/** /**
* Static method to create a structured transaction and update the balance * Static method to create a structured transaction and update the balance
* @param {txData} txData - Transaction data. * @param {txData} _txData - Transaction data.
*/ */
async function createStructuredTransaction(txData) { async function createStructuredTransaction(_txData) {
const { balance, ...txData } = _txData;
const transaction = new Transaction({ const transaction = new Transaction({
...txData, ...txData,
endpointTokenConfig: txData.endpointTokenConfig, endpointTokenConfig: txData.endpointTokenConfig,
@ -233,7 +232,6 @@ async function createStructuredTransaction(txData) {
await transaction.save(); await transaction.save();
const balance = await getBalanceConfig();
if (!balance?.enabled) { if (!balance?.enabled) {
return; return;
} }

View file

@ -1,14 +1,11 @@
const mongoose = require('mongoose'); const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server'); const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx'); const { getMultiplier, getCacheMultiplier } = require('./tx');
const { createTransaction } = require('./Transaction'); const { createTransaction } = require('./Transaction');
const { Balance } = require('~/db/models'); const { Balance } = require('~/db/models');
// Mock the custom config module so we can control the balance flag.
jest.mock('~/server/services/Config');
let mongoServer; let mongoServer;
beforeAll(async () => { beforeAll(async () => {
mongoServer = await MongoMemoryServer.create(); mongoServer = await MongoMemoryServer.create();
@ -23,8 +20,6 @@ afterAll(async () => {
beforeEach(async () => { beforeEach(async () => {
await mongoose.connection.dropDatabase(); await mongoose.connection.dropDatabase();
// Default: enable balance updates in tests.
getBalanceConfig.mockResolvedValue({ enabled: true });
}); });
describe('Regular Token Spending Tests', () => { describe('Regular Token Spending Tests', () => {
@ -41,6 +36,7 @@ describe('Regular Token Spending Tests', () => {
model, model,
context: 'test', context: 'test',
endpointTokenConfig: null, endpointTokenConfig: null,
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -74,6 +70,7 @@ describe('Regular Token Spending Tests', () => {
model, model,
context: 'test', context: 'test',
endpointTokenConfig: null, endpointTokenConfig: null,
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -104,6 +101,7 @@ describe('Regular Token Spending Tests', () => {
model, model,
context: 'test', context: 'test',
endpointTokenConfig: null, endpointTokenConfig: null,
balance: { enabled: true },
}; };
const tokenUsage = {}; const tokenUsage = {};
@ -128,6 +126,7 @@ describe('Regular Token Spending Tests', () => {
model, model,
context: 'test', context: 'test',
endpointTokenConfig: null, endpointTokenConfig: null,
balance: { enabled: true },
}; };
const tokenUsage = { promptTokens: 100 }; const tokenUsage = { promptTokens: 100 };
@ -143,8 +142,7 @@ describe('Regular Token Spending Tests', () => {
}); });
test('spendTokens should not update balance when balance feature is disabled', async () => { test('spendTokens should not update balance when balance feature is disabled', async () => {
// Arrange: Override the config to disable balance updates. // Arrange: Balance config is now passed directly in txData
getBalanceConfig.mockResolvedValue({ balance: { enabled: false } });
const userId = new mongoose.Types.ObjectId(); const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000; const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance }); await Balance.create({ user: userId, tokenCredits: initialBalance });
@ -156,6 +154,7 @@ describe('Regular Token Spending Tests', () => {
model, model,
context: 'test', context: 'test',
endpointTokenConfig: null, endpointTokenConfig: null,
balance: { enabled: false },
}; };
const tokenUsage = { const tokenUsage = {
@ -186,6 +185,7 @@ describe('Structured Token Spending Tests', () => {
model, model,
context: 'message', context: 'message',
endpointTokenConfig: null, endpointTokenConfig: null,
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -239,6 +239,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model, model,
context: 'message', context: 'message',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -271,6 +272,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model, model,
context: 'message', context: 'message',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -302,6 +304,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model, model,
context: 'message', context: 'message',
balance: { enabled: true },
}; };
const tokenUsage = {}; const tokenUsage = {};
@ -328,6 +331,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model, model,
context: 'incomplete', context: 'incomplete',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -364,6 +368,7 @@ describe('NaN Handling Tests', () => {
endpointTokenConfig: null, endpointTokenConfig: null,
rawAmount: NaN, rawAmount: NaN,
tokenType: 'prompt', tokenType: 'prompt',
balance: { enabled: true },
}; };
// Act // Act

View file

@ -5,13 +5,7 @@ const { createTransaction, createStructuredTransaction } = require('./Transactio
* *
* @function * @function
* @async * @async
* @param {Object} txData - Transaction data. * @param {txData} txData - Transaction data.
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
* @param {String} txData.conversationId - The ID of the conversation.
* @param {String} txData.model - The model name.
* @param {String} txData.context - The context in which the transaction is made.
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
* @param {String} [txData.valueKey] - The value key (optional).
* @param {Object} tokenUsage - The number of tokens used. * @param {Object} tokenUsage - The number of tokens used.
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used. * @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
* @param {Number} tokenUsage.completionTokens - The number of completion tokens used. * @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
@ -69,13 +63,7 @@ const spendTokens = async (txData, tokenUsage) => {
* *
* @function * @function
* @async * @async
* @param {Object} txData - Transaction data. * @param {txData} txData - Transaction data.
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
* @param {String} txData.conversationId - The ID of the conversation.
* @param {String} txData.model - The model name.
* @param {String} txData.context - The context in which the transaction is made.
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
* @param {String} [txData.valueKey] - The value key (optional).
* @param {Object} tokenUsage - The number of tokens used. * @param {Object} tokenUsage - The number of tokens used.
* @param {Object} tokenUsage.promptTokens - The number of prompt tokens used. * @param {Object} tokenUsage.promptTokens - The number of prompt tokens used.
* @param {Number} tokenUsage.promptTokens.input - The number of input tokens. * @param {Number} tokenUsage.promptTokens.input - The number of input tokens.

View file

@ -5,7 +5,6 @@ const { createTransaction, createAutoRefillTransaction } = require('./Transactio
require('~/db/models'); require('~/db/models');
// Mock the logger to prevent console output during tests
jest.mock('~/config', () => ({ jest.mock('~/config', () => ({
logger: { logger: {
debug: jest.fn(), debug: jest.fn(),
@ -13,10 +12,6 @@ jest.mock('~/config', () => ({
}, },
})); }));
// Mock the Config service
const { getBalanceConfig } = require('~/server/services/Config');
jest.mock('~/server/services/Config');
describe('spendTokens', () => { describe('spendTokens', () => {
let mongoServer; let mongoServer;
let userId; let userId;
@ -44,8 +39,7 @@ describe('spendTokens', () => {
// Create a new user ID for each test // Create a new user ID for each test
userId = new mongoose.Types.ObjectId(); userId = new mongoose.Types.ObjectId();
// Mock the balance config to be enabled by default // Balance config is now passed directly in txData
getBalanceConfig.mockResolvedValue({ enabled: true });
}); });
it('should create transactions for both prompt and completion tokens', async () => { it('should create transactions for both prompt and completion tokens', async () => {
@ -60,6 +54,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
promptTokens: 100, promptTokens: 100,
@ -98,6 +93,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
promptTokens: 100, promptTokens: 100,
@ -127,6 +123,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage = {}; const tokenUsage = {};
@ -138,8 +135,7 @@ describe('spendTokens', () => {
}); });
it('should not update balance when the balance feature is disabled', async () => { it('should not update balance when the balance feature is disabled', async () => {
// Override configuration: disable balance updates // Balance is now passed directly in txData
getBalanceConfig.mockResolvedValue({ enabled: false });
// Create a balance for the user // Create a balance for the user
await Balance.create({ await Balance.create({
user: userId, user: userId,
@ -151,6 +147,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'gpt-3.5-turbo', model: 'gpt-3.5-turbo',
context: 'test', context: 'test',
balance: { enabled: false },
}; };
const tokenUsage = { const tokenUsage = {
promptTokens: 100, promptTokens: 100,
@ -180,6 +177,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'gpt-4', // Using a more expensive model model: 'gpt-4', // Using a more expensive model
context: 'test', context: 'test',
balance: { enabled: true },
}; };
// Spending more tokens than the user has balance for // Spending more tokens than the user has balance for
@ -233,6 +231,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-1', conversationId: 'test-convo-1',
model: 'gpt-4', model: 'gpt-4',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage1 = { const tokenUsage1 = {
@ -252,6 +251,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-2', conversationId: 'test-convo-2',
model: 'gpt-4', model: 'gpt-4',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage2 = { const tokenUsage2 = {
@ -292,6 +292,7 @@ describe('spendTokens', () => {
tokenType: 'completion', tokenType: 'completion',
rawAmount: -100, rawAmount: -100,
context: 'test', context: 'test',
balance: { enabled: true },
}); });
console.log('Direct Transaction.create result:', directResult); console.log('Direct Transaction.create result:', directResult);
@ -316,6 +317,7 @@ describe('spendTokens', () => {
conversationId: `test-convo-${model}`, conversationId: `test-convo-${model}`,
model, model,
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
@ -352,6 +354,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-1', conversationId: 'test-convo-1',
model: 'claude-3-5-sonnet', model: 'claude-3-5-sonnet',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage1 = { const tokenUsage1 = {
@ -375,6 +378,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo-2', conversationId: 'test-convo-2',
model: 'claude-3-5-sonnet', model: 'claude-3-5-sonnet',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage2 = { const tokenUsage2 = {
@ -426,6 +430,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'claude-3-5-sonnet', // Using a model that supports structured tokens model: 'claude-3-5-sonnet', // Using a model that supports structured tokens
context: 'test', context: 'test',
balance: { enabled: true },
}; };
// Spending more tokens than the user has balance for // Spending more tokens than the user has balance for
@ -505,6 +510,7 @@ describe('spendTokens', () => {
conversationId, conversationId,
user: userId, user: userId,
model: usage.model, model: usage.model,
balance: { enabled: true },
}; };
// Calculate expected spend for this transaction // Calculate expected spend for this transaction
@ -617,6 +623,7 @@ describe('spendTokens', () => {
tokenType: 'credits', tokenType: 'credits',
context: 'concurrent-refill-test', context: 'concurrent-refill-test',
rawAmount: refillAmount, rawAmount: refillAmount,
balance: { enabled: true },
}), }),
); );
} }
@ -683,6 +690,7 @@ describe('spendTokens', () => {
conversationId: 'test-convo', conversationId: 'test-convo',
model: 'claude-3-5-sonnet', model: 'claude-3-5-sonnet',
context: 'test', context: 'test',
balance: { enabled: true },
}; };
const tokenUsage = { const tokenUsage = {
promptTokens: { promptTokens: {

View file

@ -627,9 +627,15 @@ class AgentClient extends BaseClient {
* @param {Object} params * @param {Object} params
* @param {string} [params.model] * @param {string} [params.model]
* @param {string} [params.context='message'] * @param {string} [params.context='message']
* @param {AppConfig['balance']} [params.balance]
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage] * @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/ */
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) { async recordCollectedUsage({
model,
balance,
context = 'message',
collectedUsage = this.collectedUsage,
}) {
if (!collectedUsage || !collectedUsage.length) { if (!collectedUsage || !collectedUsage.length) {
return; return;
} }
@ -651,6 +657,7 @@ class AgentClient extends BaseClient {
const txMetadata = { const txMetadata = {
context, context,
balance,
conversationId: this.conversationId, conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id, user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig, endpointTokenConfig: this.options.endpointTokenConfig,
@ -1044,7 +1051,7 @@ class AgentClient extends BaseClient {
this.artifactPromises.push(...attachments); this.artifactPromises.push(...attachments);
} }
await this.recordCollectedUsage({ context: 'message' }); await this.recordCollectedUsage({ context: 'message', balance: appConfig?.balance });
} catch (err) { } catch (err) {
logger.error( logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage', '[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
@ -1219,9 +1226,10 @@ class AgentClient extends BaseClient {
}); });
await this.recordCollectedUsage({ await this.recordCollectedUsage({
model: clientOptions.model,
context: 'title',
collectedUsage, collectedUsage,
context: 'title',
model: clientOptions.model,
balance: appConfig?.balance,
}).catch((err) => { }).catch((err) => {
logger.error( logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage', '[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
@ -1240,17 +1248,26 @@ class AgentClient extends BaseClient {
* @param {object} params * @param {object} params
* @param {number} params.promptTokens * @param {number} params.promptTokens
* @param {number} params.completionTokens * @param {number} params.completionTokens
* @param {OpenAIUsageMetadata} [params.usage]
* @param {string} [params.model] * @param {string} [params.model]
* @param {OpenAIUsageMetadata} [params.usage]
* @param {AppConfig['balance']} [params.balance]
* @param {string} [params.context='message'] * @param {string} [params.context='message']
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) { async recordTokenUsage({
model,
usage,
balance,
promptTokens,
completionTokens,
context = 'message',
}) {
try { try {
await spendTokens( await spendTokens(
{ {
model, model,
context, context,
balance,
conversationId: this.conversationId, conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id, user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig, endpointTokenConfig: this.options.endpointTokenConfig,
@ -1267,6 +1284,7 @@ class AgentClient extends BaseClient {
await spendTokens( await spendTokens(
{ {
model, model,
balance,
context: 'reasoning', context: 'reasoning',
conversationId: this.conversationId, conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id, user: this.user ?? this.options.req.user?.id,

View file

@ -20,10 +20,10 @@ const {
deleteUserById, deleteUserById,
generateRefreshToken, generateRefreshToken,
} = require('~/models'); } = require('~/models');
const { getBalanceConfig, getAppConfig } = require('~/server/services/Config');
const { isEmailDomainAllowed } = require('~/server/services/domains'); const { isEmailDomainAllowed } = require('~/server/services/domains');
const { checkEmailConfig, sendEmail } = require('~/server/utils'); const { checkEmailConfig, sendEmail } = require('~/server/utils');
const { registerSchema } = require('~/strategies/validators'); const { registerSchema } = require('~/strategies/validators');
const { getAppConfig } = require('~/server/services/Config');
const domains = { const domains = {
client: process.env.DOMAIN_CLIENT, client: process.env.DOMAIN_CLIENT,
@ -220,9 +220,8 @@ const registerUser = async (user, additionalData = {}) => {
const emailEnabled = checkEmailConfig(); const emailEnabled = checkEmailConfig();
const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN); const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN);
const balanceConfig = await getBalanceConfig();
const newUser = await createUser(newUserData, balanceConfig, disableTTL, true); const newUser = await createUser(newUserData, appConfig.balance, disableTTL, true);
newUserId = newUser._id; newUserId = newUser._id;
if (emailEnabled && !newUser.emailVerified) { if (emailEnabled && !newUser.emailVerified) {
await sendVerificationEmail({ await sendVerificationEmail({

View file

@ -1,25 +1,16 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint } = require('librechat-data-provider');
const { isEnabled, getUserMCPAuthMap, normalizeEndpointName } = require('@librechat/api'); const { isEnabled, getUserMCPAuthMap, normalizeEndpointName } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { getAppConfig } = require('./app');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
/**
* Retrieves the configuration object
* @function getCustomConfig
* @returns {Promise<TCustomConfig | null>}
* */
async function getCustomConfig() {
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
return (await cache.get(CacheKeys.LIBRECHAT_YAML_CONFIG)) || (await loadCustomConfig());
}
/** /**
* Retrieves the configuration object * Retrieves the configuration object
* @function getBalanceConfig * @function getBalanceConfig
* @param {Object} params
* @param {string} [params.role]
* @returns {Promise<TCustomConfig['balance'] | null>} * @returns {Promise<TCustomConfig['balance'] | null>}
* */ * */
async function getBalanceConfig() { async function getBalanceConfig({ role }) {
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE); const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
const startBalance = process.env.START_BALANCE; const startBalance = process.env.START_BALANCE;
/** @type {TCustomConfig['balance']} */ /** @type {TCustomConfig['balance']} */
@ -27,11 +18,11 @@ async function getBalanceConfig() {
enabled: isLegacyEnabled, enabled: isLegacyEnabled,
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined, startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
}; };
const customConfig = await getCustomConfig(); const appConfig = await getAppConfig({ role });
if (!customConfig) { if (!appConfig) {
return config; return config;
} }
return { ...config, ...(customConfig?.['balance'] ?? {}) }; return { ...config, ...(appConfig?.['balance'] ?? {}) };
} }
/** /**
@ -40,13 +31,12 @@ async function getBalanceConfig() {
* @returns {Promise<TEndpoint | undefined>} * @returns {Promise<TEndpoint | undefined>}
*/ */
const getCustomEndpointConfig = async (endpoint) => { const getCustomEndpointConfig = async (endpoint) => {
const customConfig = await getCustomConfig(); const appConfig = await getAppConfig();
if (!customConfig) { if (!appConfig) {
throw new Error(`Config not found for the ${endpoint} custom endpoint.`); throw new Error(`Config not found for the ${endpoint} custom endpoint.`);
} }
const { endpoints = {} } = customConfig; const customEndpoints = appConfig[EModelEndpoint.custom] ?? [];
const customEndpoints = endpoints[EModelEndpoint.custom] ?? [];
return customEndpoints.find( return customEndpoints.find(
(endpointConfig) => normalizeEndpointName(endpointConfig.name) === endpoint, (endpointConfig) => normalizeEndpointName(endpointConfig.name) === endpoint,
); );
@ -81,14 +71,13 @@ async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
async function hasCustomUserVars() { async function hasCustomUserVars() {
const customConfig = await getCustomConfig(); const customConfig = await getAppConfig();
const mcpServers = customConfig?.mcpServers; const mcpServers = customConfig?.mcpConfig;
return Object.values(mcpServers ?? {}).some((server) => server.customUserVars); return Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
} }
module.exports = { module.exports = {
getMCPAuthMap, getMCPAuthMap,
getCustomConfig,
getBalanceConfig, getBalanceConfig,
hasCustomUserVars, hasCustomUserVars,
getCustomEndpointConfig, getCustomEndpointConfig,

View file

@ -123,6 +123,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
if (!user) { if (!user) {
const isFirstRegisteredUser = (await countUsers()) === 0; const isFirstRegisteredUser = (await countUsers()) === 0;
const role = isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER;
user = { user = {
provider: 'ldap', provider: 'ldap',
ldapId, ldapId,
@ -130,9 +131,9 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
email: mail, email: mail,
emailVerified: true, // The ldap server administrator should verify the email emailVerified: true, // The ldap server administrator should verify the email
name: fullName, name: fullName,
role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER, role,
}; };
const balanceConfig = await getBalanceConfig(); const balanceConfig = await getBalanceConfig({ role });
const userId = await createUser(user, balanceConfig); const userId = await createUser(user, balanceConfig);
user._id = userId; user._id = userId;
} else { } else {

View file

@ -1783,8 +1783,8 @@
* @property {String} conversationId - The ID of the conversation. * @property {String} conversationId - The ID of the conversation.
* @property {String} model - The model name. * @property {String} model - The model name.
* @property {String} context - The context in which the transaction is made. * @property {String} context - The context in which the transaction is made.
* @property {AppConfig['balance']} [balance] - The balance config
* @property {EndpointTokenConfig} [endpointTokenConfig] - The current endpoint token config. * @property {EndpointTokenConfig} [endpointTokenConfig] - The current endpoint token config.
* @property {object} [cacheUsage] - Cache usage, if any.
* @property {String} [valueKey] - The value key (optional). * @property {String} [valueKey] - The value key (optional).
* @memberof typedefs * @memberof typedefs
*/ */
@ -1829,6 +1829,7 @@
* @callback sendCompletion * @callback sendCompletion
* @param {Array<ChatCompletionMessage> | string} payload - The messages or prompt to send to the model * @param {Array<ChatCompletionMessage> | string} payload - The messages or prompt to send to the model
* @param {object} opts - Options for the completion * @param {object} opts - Options for the completion
* @param {AppConfig} opts.appConfig - Callback function to handle token progress
* @param {onTokenProgress} opts.onProgress - Callback function to handle token progress * @param {onTokenProgress} opts.onProgress - Callback function to handle token progress
* @param {AbortController} opts.abortController - AbortController instance * @param {AbortController} opts.abortController - AbortController instance
* @returns {Promise<string>} * @returns {Promise<string>}

View file

@ -5,7 +5,7 @@ import type { Model } from 'mongoose';
import type { BalanceUpdateFields } from '~/types'; import type { BalanceUpdateFields } from '~/types';
export interface BalanceMiddlewareOptions { export interface BalanceMiddlewareOptions {
getBalanceConfig: () => Promise<BalanceConfig | null>; getBalanceConfig: ({ role }?: { role?: string }) => Promise<BalanceConfig | null>;
Balance: Model<IBalance>; Balance: Model<IBalance>;
} }
@ -82,7 +82,8 @@ export function createSetBalanceConfig({
) => Promise<void> { ) => Promise<void> {
return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise<void> => { return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise<void> => {
try { try {
const balanceConfig = await getBalanceConfig(); const user = req.user as IUser & { _id: string | ObjectId };
const balanceConfig = await getBalanceConfig({ role: user?.role });
if (!balanceConfig?.enabled) { if (!balanceConfig?.enabled) {
return next(); return next();
} }
@ -90,7 +91,6 @@ export function createSetBalanceConfig({
return next(); return next();
} }
const user = req.user as IUser & { _id: string | ObjectId };
if (!user || !user._id) { if (!user || !user._id) {
return next(); return next();
} }