From 3a62a2633d48b9995be4762d38828bc45ceb2f2f Mon Sep 17 00:00:00 2001 From: Ruben Talstra Date: Fri, 21 Mar 2025 22:48:11 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=B5=20feat:=20Add=20Automatic=20Balanc?= =?UTF-8?q?e=20Refill=20(#6452)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🚀 feat: Add automatic refill settings to balance schema * 🚀 feat: Refactor balance feature to use global interface configuration * 🚀 feat: Implement auto-refill functionality for balance management * 🚀 feat: Enhance auto-refill logic and configuration for balance management * 🚀 chore: Bump version to 0.7.74 in package.json and package-lock.json * 🚀 chore: Bump version to 0.0.5 in package.json and package-lock.json * 🚀 docs: Update comment for balance settings in librechat.example.yaml * chore: space in `.env.example` * 🚀 feat: Implement balance configuration loading and refactor related components * 🚀 test: Refactor tests to use custom config for balance feature * 🚀 fix: Update balance response handling in Transaction.js to use Balance model * 🚀 test: Update AppService tests to include balance configuration in mock setup * 🚀 test: Enhance AppService tests with complete balance configuration scenarios * 🚀 refactor: Rename balanceConfig to balance and update related tests for clarity * 🚀 refactor: Remove loadDefaultBalance and update balance handling in AppService * 🚀 test: Update AppService tests to reflect new balance structure and defaults * 🚀 test: Mock getCustomConfig in BaseClient tests to control balance configuration * 🚀 test: Add get method to mockCache in OpenAIClient tests for improved cache handling * 🚀 test: Mock getCustomConfig in OpenAIClient tests to control balance configuration * 🚀 test: Remove mock for getCustomConfig in OpenAIClient tests to streamline configuration handling * 🚀 fix: Update balance configuration reference in config.js for consistency * refactor: Add getBalanceConfig function to retrieve balance configuration * chore: Comment out example balance settings in librechat.example.yaml * refactor: Replace getCustomConfig with getBalanceConfig for balance handling * fix: tests * refactor: Replace getBalanceConfig call with balance from request locals * refactor: Update balance handling to use environment variables for configuration * refactor: Replace getBalanceConfig calls with balance from request locals * refactor: Simplify balance configuration logic in getBalanceConfig --------- Co-authored-by: Danny Avila --- .env.example | 2 +- api/app/clients/BaseClient.js | 5 +- api/app/clients/PluginsClient.js | 4 +- .../clients/callbacks/createStartHandler.js | 6 +- api/app/clients/specs/OpenAIClient.test.js | 7 +- api/models/Balance.js | 79 ++++++++-- api/models/Transaction.js | 28 ++-- api/models/Transaction.spec.js | 138 +++++++++--------- api/models/spendTokens.spec.js | 15 +- api/models/userMethods.js | 42 ++++-- api/server/controllers/assistants/chatV1.js | 5 +- api/server/controllers/assistants/chatV2.js | 5 +- api/server/routes/config.js | 2 +- api/server/services/AppService.js | 13 +- api/server/services/AppService.spec.js | 29 +++- api/server/services/Config/getCustomConfig.js | 27 +++- client/src/components/Nav/AccountSettings.tsx | 4 +- client/src/hooks/SSE/useSSE.ts | 7 +- librechat.example.yaml | 10 ++ package-lock.json | 4 +- packages/data-provider/package.json | 2 +- packages/data-provider/src/config.ts | 16 +- packages/data-schemas/package.json | 2 +- packages/data-schemas/src/schema/balance.ts | 29 ++++ 24 files changed, 334 insertions(+), 147 deletions(-) diff --git a/.env.example b/.env.example index d0e189acd..57af60354 100644 --- a/.env.example +++ b/.env.example @@ -364,7 +364,7 @@ ILLEGAL_MODEL_REQ_SCORE=5 # Balance # #========================# -CHECK_BALANCE=false +# CHECK_BALANCE=false # START_BALANCE=20000 # note: the number of tokens that will be credited after registration. #========================# diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index d3077e68f..54e88e595 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -11,8 +11,8 @@ const { Constants, } = require('librechat-data-provider'); const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models'); -const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const { truncateToolCallOutputs } = require('./prompts'); +const { addSpaceIfNeeded } = require('~/server/utils'); const checkBalance = require('~/models/checkBalance'); const { getFiles } = require('~/models/File'); const TextStream = require('./TextStream'); @@ -634,8 +634,9 @@ class BaseClient { } } + const balance = this.options.req?.app?.locals?.balance; if ( - isEnabled(process.env.CHECK_BALANCE) && + balance?.enabled && supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint] ) { await checkBalance({ diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index bfe222e24..9fd906ae4 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -7,7 +7,6 @@ const { processFileURL } = require('~/server/services/Files/process'); const { EModelEndpoint } = require('librechat-data-provider'); const { formatLangChainMessages } = require('./prompts'); const checkBalance = require('~/models/checkBalance'); -const { isEnabled } = require('~/server/utils'); const { extractBaseURL } = require('~/utils'); const { loadTools } = require('./tools/util'); const { logger } = require('~/config'); @@ -336,7 +335,8 @@ class PluginsClient extends OpenAIClient { } } - if (isEnabled(process.env.CHECK_BALANCE)) { + const balance = this.options.req?.app?.locals?.balance; + if (balance?.enabled) { await checkBalance({ req: this.options.req, res: this.options.res, diff --git a/api/app/clients/callbacks/createStartHandler.js b/api/app/clients/callbacks/createStartHandler.js index 4bc32bc0c..d5d806e6c 100644 --- a/api/app/clients/callbacks/createStartHandler.js +++ b/api/app/clients/callbacks/createStartHandler.js @@ -1,8 +1,8 @@ const { promptTokensEstimate } = require('openai-chat-tokens'); const { EModelEndpoint, supportsBalanceCheck } = require('librechat-data-provider'); const { formatFromLangChain } = require('~/app/clients/prompts'); +const { getBalanceConfig } = require('~/server/services/Config'); const checkBalance = require('~/models/checkBalance'); -const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); const createStartHandler = ({ @@ -49,8 +49,8 @@ const createStartHandler = ({ prelimPromptTokens += tokenBuffer; try { - // TODO: if plugins extends to non-OpenAI models, this will need to be updated - if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[EModelEndpoint.openAI]) { + const balance = await getBalanceConfig(); + if (balance?.enabled && supportsBalanceCheck[EModelEndpoint.openAI]) { const generations = initialMessageCount && messages.length > initialMessageCount ? messages.slice(initialMessageCount) diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 0e811cf38..adc290486 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -136,10 +136,11 @@ OpenAI.mockImplementation(() => ({ })); describe('OpenAIClient', () => { - const mockSet = jest.fn(); - const mockCache = { set: mockSet }; - beforeEach(() => { + const mockCache = { + get: jest.fn().mockResolvedValue({}), + set: jest.fn(), + }; getLogStores.mockReturnValue(mockCache); }); let client; diff --git a/api/models/Balance.js b/api/models/Balance.js index f7978d804..c26ed90bc 100644 --- a/api/models/Balance.js +++ b/api/models/Balance.js @@ -3,6 +3,40 @@ const { balanceSchema } = require('@librechat/data-schemas'); const { getMultiplier } = require('./tx'); const { logger } = require('~/config'); +/** + * Adds a time interval to a given date. + * @param {Date} date - The starting date. + * @param {number} value - The numeric value of the interval. + * @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time. + * @returns {Date} A new Date representing the starting date plus the interval. + */ +const addIntervalToDate = (date, value, unit) => { + const result = new Date(date); + switch (unit) { + case 'seconds': + result.setSeconds(result.getSeconds() + value); + break; + case 'minutes': + result.setMinutes(result.getMinutes() + value); + break; + case 'hours': + result.setHours(result.getHours() + value); + break; + case 'days': + result.setDate(result.getDate() + value); + break; + case 'weeks': + result.setDate(result.getDate() + value * 7); + break; + case 'months': + result.setMonth(result.getMonth() + value); + break; + default: + break; + } + return result; +}; + balanceSchema.statics.check = async function ({ user, model, @@ -14,9 +48,20 @@ balanceSchema.statics.check = async function ({ }) { const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig }); const tokenCost = amount * multiplier; - const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {}; - logger.debug('[Balance.check]', { + // Retrieve the complete balance record + let record = await this.findOne({ user }).lean(); + if (!record) { + logger.debug('[Balance.check] No balance record found for user', { user }); + return { + canSpend: false, + balance: 0, + tokenCost, + }; + } + let balance = record.tokenCredits; + + logger.debug('[Balance.check] Initial state', { user, model, endpoint, @@ -28,15 +73,31 @@ balanceSchema.statics.check = async function ({ endpointTokenConfig: !!endpointTokenConfig, }); - if (!balance) { - return { - canSpend: false, - balance: 0, - tokenCost, - }; + // Only perform auto-refill if spending would bring the balance to 0 or below + if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) { + const lastRefillDate = new Date(record.lastRefill); + const nextRefillDate = addIntervalToDate( + lastRefillDate, + record.refillIntervalValue, + record.refillIntervalUnit, + ); + const now = new Date(); + + if (now >= nextRefillDate) { + record = await this.findOneAndUpdate( + { user }, + { + $inc: { tokenCredits: record.refillAmount }, + $set: { lastRefill: new Date() }, + }, + { new: true }, + ).lean(); + balance = record.tokenCredits; + logger.debug('[Balance.check] Auto-refill performed', { balance }); + } } - logger.debug('[Balance.check]', { tokenCost }); + logger.debug('[Balance.check] Token cost', { tokenCost }); return { canSpend: balance >= tokenCost, balance, tokenCost }; }; diff --git a/api/models/Transaction.js b/api/models/Transaction.js index b1c4c6571..b9f444d98 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,6 +1,6 @@ const mongoose = require('mongoose'); -const { isEnabled } = require('~/server/utils/handleText'); const { transactionSchema } = require('@librechat/data-schemas'); +const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); const { logger } = require('~/config'); const Balance = require('./Balance'); @@ -37,18 +37,19 @@ transactionSchema.statics.create = async function (txData) { await transaction.save(); - if (!isEnabled(process.env.CHECK_BALANCE)) { + const balance = await getBalanceConfig(); + if (!balance?.enabled) { return; } - let balance = await Balance.findOne({ user: transaction.user }).lean(); + let balanceResponse = await Balance.findOne({ user: transaction.user }).lean(); let incrementValue = transaction.tokenValue; - if (balance && balance?.tokenCredits + incrementValue < 0) { - incrementValue = -balance.tokenCredits; + if (balanceResponse && balanceResponse.tokenCredits + incrementValue < 0) { + incrementValue = -balanceResponse.tokenCredits; } - balance = await Balance.findOneAndUpdate( + balanceResponse = await Balance.findOneAndUpdate( { user: transaction.user }, { $inc: { tokenCredits: incrementValue } }, { upsert: true, new: true }, @@ -57,7 +58,7 @@ transactionSchema.statics.create = async function (txData) { return { rate: transaction.rate, user: transaction.user.toString(), - balance: balance.tokenCredits, + balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; }; @@ -78,18 +79,19 @@ transactionSchema.statics.createStructured = async function (txData) { await transaction.save(); - if (!isEnabled(process.env.CHECK_BALANCE)) { + const balance = await getBalanceConfig(); + if (!balance?.enabled) { return; } - let balance = await Balance.findOne({ user: transaction.user }).lean(); + let balanceResponse = await Balance.findOne({ user: transaction.user }).lean(); let incrementValue = transaction.tokenValue; - if (balance && balance?.tokenCredits + incrementValue < 0) { - incrementValue = -balance.tokenCredits; + if (balanceResponse && balanceResponse.tokenCredits + incrementValue < 0) { + incrementValue = -balanceResponse.tokenCredits; } - balance = await Balance.findOneAndUpdate( + balanceResponse = await Balance.findOneAndUpdate( { user: transaction.user }, { $inc: { tokenCredits: incrementValue } }, { upsert: true, new: true }, @@ -98,7 +100,7 @@ transactionSchema.statics.createStructured = async function (txData) { return { rate: transaction.rate, user: transaction.user.toString(), - balance: balance.tokenCredits, + balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; }; diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index b8c69e13f..43f3c004b 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -1,9 +1,13 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); +const { spendTokens, spendStructuredTokens } = require('./spendTokens'); +const { getBalanceConfig } = require('~/server/services/Config'); +const { getMultiplier, getCacheMultiplier } = require('./tx'); const { Transaction } = require('./Transaction'); const Balance = require('./Balance'); -const { spendTokens, spendStructuredTokens } = require('./spendTokens'); -const { getMultiplier, getCacheMultiplier } = require('./tx'); + +// Mock the custom config module so we can control the balance flag. +jest.mock('~/server/services/Config'); let mongoServer; @@ -20,6 +24,8 @@ afterAll(async () => { beforeEach(async () => { await mongoose.connection.dropDatabase(); + // Default: enable balance updates in tests. + getBalanceConfig.mockResolvedValue({ enabled: true }); }); describe('Regular Token Spending Tests', () => { @@ -44,34 +50,22 @@ describe('Regular Token Spending Tests', () => { }; // Act - process.env.CHECK_BALANCE = 'true'; await spendTokens(txData, tokenUsage); // Assert - console.log('Initial Balance:', initialBalance); - const updatedBalance = await Balance.findOne({ user: userId }); - console.log('Updated Balance:', updatedBalance.tokenCredits); - const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); - - const expectedPromptCost = tokenUsage.promptTokens * promptMultiplier; - const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier; - const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + const expectedTotalCost = 100 * promptMultiplier + 50 * completionMultiplier; const expectedBalance = initialBalance - expectedTotalCost; - expect(updatedBalance.tokenCredits).toBeLessThan(initialBalance); expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0); - - console.log('Expected Total Cost:', expectedTotalCost); - console.log('Actual Balance Decrease:', initialBalance - updatedBalance.tokenCredits); }); test('spendTokens should handle zero completion tokens', async () => { // Arrange const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; // $10.00 + const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); const model = 'gpt-3.5-turbo'; @@ -89,24 +83,19 @@ describe('Regular Token Spending Tests', () => { }; // Act - process.env.CHECK_BALANCE = 'true'; await spendTokens(txData, tokenUsage); // Assert const updatedBalance = await Balance.findOne({ user: userId }); - const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); - const expectedCost = tokenUsage.promptTokens * promptMultiplier; + const expectedCost = 100 * promptMultiplier; expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); - - console.log('Initial Balance:', initialBalance); - console.log('Updated Balance:', updatedBalance.tokenCredits); - console.log('Expected Cost:', expectedCost); }); test('spendTokens should handle undefined token counts', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; // $10.00 + const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); const model = 'gpt-3.5-turbo'; @@ -120,14 +109,17 @@ describe('Regular Token Spending Tests', () => { const tokenUsage = {}; + // Act const result = await spendTokens(txData, tokenUsage); + // Assert: No transaction should be created expect(result).toBeUndefined(); }); test('spendTokens should handle only prompt tokens', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; // $10.00 + const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); const model = 'gpt-3.5-turbo'; @@ -141,14 +133,44 @@ describe('Regular Token Spending Tests', () => { const tokenUsage = { promptTokens: 100 }; + // Act await spendTokens(txData, tokenUsage); + // Assert const updatedBalance = await Balance.findOne({ user: userId }); - const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const expectedCost = 100 * promptMultiplier; expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); + + test('spendTokens should not update balance when balance feature is disabled', async () => { + // Arrange: Override the config to disable balance updates. + getBalanceConfig.mockResolvedValue({ balance: { enabled: false } }); + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const txData = { + user: userId, + conversationId: 'test-conversation-id', + model, + context: 'test', + endpointTokenConfig: null, + }; + + const tokenUsage = { + promptTokens: 100, + completionTokens: 50, + }; + + // Act + await spendTokens(txData, tokenUsage); + + // Assert: Balance should remain unchanged. + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBe(initialBalance); + }); }); describe('Structured Token Spending Tests', () => { @@ -164,7 +186,7 @@ describe('Structured Token Spending Tests', () => { conversationId: 'c23a18da-706c-470a-ac28-ec87ed065199', model, context: 'message', - endpointTokenConfig: null, // We'll use the default rates + endpointTokenConfig: null, }; const tokenUsage = { @@ -176,28 +198,15 @@ describe('Structured Token Spending Tests', () => { completionTokens: 5, }; - // Get the actual multipliers const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); - console.log('Multipliers:', { - promptMultiplier, - completionMultiplier, - writeMultiplier, - readMultiplier, - }); - // Act - process.env.CHECK_BALANCE = 'true'; const result = await spendStructuredTokens(txData, tokenUsage); - // Assert - console.log('Initial Balance:', initialBalance); - console.log('Updated Balance:', result.completion.balance); - console.log('Transaction Result:', result); - + // Calculate expected costs. const expectedPromptCost = tokenUsage.promptTokens.input * promptMultiplier + tokenUsage.promptTokens.write * writeMultiplier + @@ -206,37 +215,21 @@ describe('Structured Token Spending Tests', () => { const expectedTotalCost = expectedPromptCost + expectedCompletionCost; const expectedBalance = initialBalance - expectedTotalCost; - console.log('Expected Cost:', expectedTotalCost); - console.log('Expected Balance:', expectedBalance); - + // Assert expect(result.completion.balance).toBeLessThan(initialBalance); - - // Allow for a small difference (e.g., 100 token credits, which is $0.0001) const allowedDifference = 100; expect(Math.abs(result.completion.balance - expectedBalance)).toBeLessThan(allowedDifference); - - // Check if the decrease is approximately as expected const balanceDecrease = initialBalance - result.completion.balance; expect(balanceDecrease).toBeCloseTo(expectedTotalCost, 0); - // Check token values - const expectedPromptTokenValue = -( - tokenUsage.promptTokens.input * promptMultiplier + - tokenUsage.promptTokens.write * writeMultiplier + - tokenUsage.promptTokens.read * readMultiplier - ); - const expectedCompletionTokenValue = -tokenUsage.completionTokens * completionMultiplier; - + const expectedPromptTokenValue = -expectedPromptCost; + const expectedCompletionTokenValue = -expectedCompletionCost; expect(result.prompt.prompt).toBeCloseTo(expectedPromptTokenValue, 1); expect(result.completion.completion).toBe(expectedCompletionTokenValue); - - console.log('Expected prompt tokenValue:', expectedPromptTokenValue); - console.log('Actual prompt tokenValue:', result.prompt.prompt); - console.log('Expected completion tokenValue:', expectedCompletionTokenValue); - console.log('Actual completion tokenValue:', result.completion.completion); }); test('should handle zero completion tokens in structured spending', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -258,15 +251,17 @@ describe('Structured Token Spending Tests', () => { completionTokens: 0, }; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); + // Assert expect(result.prompt).toBeDefined(); expect(result.completion).toBeUndefined(); expect(result.prompt.prompt).toBeLessThan(0); }); test('should handle only prompt tokens in structured spending', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -287,15 +282,17 @@ describe('Structured Token Spending Tests', () => { }, }; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); + // Assert expect(result.prompt).toBeDefined(); expect(result.completion).toBeUndefined(); expect(result.prompt.prompt).toBeLessThan(0); }); test('should handle undefined token counts in structured spending', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -310,9 +307,10 @@ describe('Structured Token Spending Tests', () => { const tokenUsage = {}; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); + // Assert expect(result).toEqual({ prompt: undefined, completion: undefined, @@ -320,6 +318,7 @@ describe('Structured Token Spending Tests', () => { }); test('should handle incomplete context for completion tokens', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -341,15 +340,18 @@ describe('Structured Token Spending Tests', () => { completionTokens: 50, }; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); - expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15 + // Assert: + // (Assuming a multiplier for completion of 15 and a cancel rate of 1.15 as noted in the original test.) + expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); }); }); describe('NaN Handling Tests', () => { test('should skip transaction creation when rawAmount is NaN', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -365,9 +367,11 @@ describe('NaN Handling Tests', () => { tokenType: 'prompt', }; + // Act const result = await Transaction.create(txData); - expect(result).toBeUndefined(); + // Assert: No transaction should be created and balance remains unchanged. + expect(result).toBeUndefined(); const balance = await Balance.findOne({ user: userId }); expect(balance.tokenCredits).toBe(initialBalance); }); diff --git a/api/models/spendTokens.spec.js b/api/models/spendTokens.spec.js index 91056bb54..f855c3119 100644 --- a/api/models/spendTokens.spec.js +++ b/api/models/spendTokens.spec.js @@ -19,14 +19,19 @@ jest.mock('~/config', () => ({ }, })); +// New config module +const { getBalanceConfig } = require('~/server/services/Config'); +jest.mock('~/server/services/Config'); + // Import after mocking const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { Transaction } = require('./Transaction'); const Balance = require('./Balance'); + describe('spendTokens', () => { beforeEach(() => { jest.clearAllMocks(); - process.env.CHECK_BALANCE = 'true'; + getBalanceConfig.mockResolvedValue({ enabled: true }); }); it('should create transactions for both prompt and completion tokens', async () => { @@ -92,7 +97,7 @@ describe('spendTokens', () => { expect(Transaction.create).toHaveBeenCalledWith( expect.objectContaining({ tokenType: 'completion', - rawAmount: -0, // Changed from 0 to -0 + rawAmount: -0, }), ); }); @@ -111,8 +116,9 @@ describe('spendTokens', () => { expect(Transaction.create).not.toHaveBeenCalled(); }); - it('should not update balance when CHECK_BALANCE is false', async () => { - process.env.CHECK_BALANCE = 'false'; + it('should not update balance when the balance feature is disabled', async () => { + // Override configuration: disable balance updates. + getBalanceConfig.mockResolvedValue({ enabled: false }); const txData = { user: new mongoose.Types.ObjectId(), conversationId: 'test-convo', @@ -130,6 +136,7 @@ describe('spendTokens', () => { await spendTokens(txData, tokenUsage); expect(Transaction.create).toHaveBeenCalledTimes(2); + // When balance updates are disabled, Balance methods should not be called. expect(Balance.findOne).not.toHaveBeenCalled(); expect(Balance.findOneAndUpdate).not.toHaveBeenCalled(); }); diff --git a/api/models/userMethods.js b/api/models/userMethods.js index 63b25edd3..fbcd33aba 100644 --- a/api/models/userMethods.js +++ b/api/models/userMethods.js @@ -1,6 +1,6 @@ const bcrypt = require('bcryptjs'); +const { getBalanceConfig } = require('~/server/services/Config'); const signPayload = require('~/server/services/signPayload'); -const { isEnabled } = require('~/server/utils/handleText'); const Balance = require('./Balance'); const User = require('./User'); @@ -13,11 +13,9 @@ const User = require('./User'); */ const getUserById = async function (userId, fieldsToSelect = null) { const query = User.findById(userId); - if (fieldsToSelect) { query.select(fieldsToSelect); } - return await query.lean(); }; @@ -32,7 +30,6 @@ const findUser = async function (searchCriteria, fieldsToSelect = null) { if (fieldsToSelect) { query.select(fieldsToSelect); } - return await query.lean(); }; @@ -58,11 +55,12 @@ const updateUser = async function (userId, updateData) { * Creates a new user, optionally with a TTL of 1 week. * @param {MongoUser} data - The user data to be created, must contain user_id. * @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`. - * @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`. - * @returns {Promise} A promise that resolves to the created user document ID. + * @param {boolean} [returnUser=false] - Whether to return the created user object. + * @returns {Promise} A promise that resolves to the created user document ID or user object. * @throws {Error} If a user with the same user_id already exists. */ const createUser = async (data, disableTTL = true, returnUser = false) => { + const balance = await getBalanceConfig(); const userData = { ...data, expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds @@ -74,13 +72,27 @@ const createUser = async (data, disableTTL = true, returnUser = false) => { const user = await User.create(userData); - if (isEnabled(process.env.CHECK_BALANCE) && process.env.START_BALANCE) { - let incrementValue = parseInt(process.env.START_BALANCE); - await Balance.findOneAndUpdate( - { user: user._id }, - { $inc: { tokenCredits: incrementValue } }, - { upsert: true, new: true }, - ).lean(); + // If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance + if (balance?.enabled && balance?.startBalance) { + const update = { + $inc: { tokenCredits: balance.startBalance }, + }; + + if ( + balance.autoRefillEnabled && + balance.refillIntervalValue != null && + balance.refillIntervalUnit != null && + balance.refillAmount != null + ) { + update.$set = { + autoRefillEnabled: true, + refillIntervalValue: balance.refillIntervalValue, + refillIntervalUnit: balance.refillIntervalUnit, + refillAmount: balance.refillAmount, + }; + } + + await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean(); } if (returnUser) { @@ -123,7 +135,7 @@ const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15; /** * Generates a JWT token for a given user. * - * @param {MongoUser} user - ID of the user for whom the token is being generated. + * @param {MongoUser} user - The user for whom the token is being generated. * @returns {Promise} A promise that resolves to a JWT token. */ const generateToken = async (user) => { @@ -146,7 +158,7 @@ const generateToken = async (user) => { /** * Compares the provided password with the user's password. * - * @param {MongoUser} user - the user to compare password for. + * @param {MongoUser} user - The user to compare the password for. * @param {string} candidatePassword - The password to test against the user's password. * @returns {Promise} A promise that resolves to a boolean indicating if the password matches. */ diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index 8461941e0..cb30277e9 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -19,7 +19,7 @@ const { addThreadMetadata, saveAssistantMessage, } = require('~/server/services/Threads'); -const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); +const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts'); @@ -248,7 +248,8 @@ const chatV1 = async (req, res) => { } const checkBalanceBeforeRun = async () => { - if (!isEnabled(process.env.CHECK_BALANCE)) { + const balance = req.app?.locals?.balance; + if (!balance?.enabled) { return; } const transactions = diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 24a8e38fa..cf130e4ee 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -18,11 +18,11 @@ const { saveAssistantMessage, } = require('~/server/services/Threads'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); -const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils'); const { createErrorHandler } = require('~/server/controllers/assistants/errors'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); const { addTitle } = require('~/server/services/Endpoints/assistants'); +const { sendMessage, sleep, countTokens } = require('~/server/utils'); const { createRunBody } = require('~/server/services/createRunBody'); const { getTransactions } = require('~/models/Transaction'); const checkBalance = require('~/models/checkBalance'); @@ -124,7 +124,8 @@ const chatV2 = async (req, res) => { } const checkBalanceBeforeRun = async () => { - if (!isEnabled(process.env.CHECK_BALANCE)) { + const balance = req.app?.locals?.balance; + if (!balance?.enabled) { return; } const transactions = diff --git a/api/server/routes/config.js b/api/server/routes/config.js index e8d2fe57a..e1e8ba763 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -69,7 +69,6 @@ router.get('/', async function (req, res) { !!process.env.EMAIL_PASSWORD && !!process.env.EMAIL_FROM, passwordResetEnabled, - checkBalance: isEnabled(process.env.CHECK_BALANCE), showBirthdayIcon: isBirthday() || isEnabled(process.env.SHOW_BIRTHDAY_ICON) || @@ -77,6 +76,7 @@ router.get('/', async function (req, res) { helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai', interface: req.app.locals.interfaceConfig, modelSpecs: req.app.locals.modelSpecs, + balance: req.app.locals.balance, sharedLinksEnabled, publicSharedLinksEnabled, analyticsGtmId: process.env.ANALYTICS_GTM_ID, diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index 3fdae6ac1..baead9744 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -9,15 +9,16 @@ const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = requir const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants'); const { initializeAzureBlobService } = require('./Files/Azure/initialize'); const { initializeFirebase } = require('./Files/Firebase/initialize'); -const { initializeS3 } = require('./Files/S3/initialize'); const loadCustomConfig = require('./Config/loadCustomConfig'); const handleRateLimits = require('./Config/handleRateLimits'); const { loadDefaultInterface } = require('./start/interface'); const { azureConfigSetup } = require('./start/azureOpenAI'); const { processModelSpecs } = require('./start/modelSpecs'); +const { initializeS3 } = require('./Files/S3/initialize'); const { loadAndFormatTools } = require('./ToolService'); const { agentsConfigSetup } = require('./start/agents'); const { initializeRoles } = require('~/models/Role'); +const { isEnabled } = require('~/server/utils'); const { getMCPManager } = require('~/config'); const paths = require('~/config/paths'); @@ -29,7 +30,7 @@ const paths = require('~/config/paths'); */ const AppService = async (app) => { await initializeRoles(); - /** @type {TCustomConfig}*/ + /** @type {TCustomConfig} */ const config = (await loadCustomConfig()) ?? {}; const configDefaults = getConfigDefaults(); @@ -37,6 +38,11 @@ const AppService = async (app) => { const filteredTools = config.filteredTools; const includedTools = config.includedTools; const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy; + const startBalance = process.env.START_BALANCE; + const balance = config.balance ?? { + enabled: isEnabled(process.env.CHECK_BALANCE), + startBalance: startBalance ? parseInt(startBalance, 10) : undefined, + }; const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType; process.env.CDN_PROVIDER = fileStrategy; @@ -52,7 +58,7 @@ const AppService = async (app) => { initializeS3(); } - /** @type {Record} */ const availableTools = loadAndFormatTools({ adminFilter: filteredTools, adminIncluded: includedTools, @@ -79,6 +85,7 @@ const AppService = async (app) => { availableTools, imageOutputType, interfaceConfig, + balance, }; if (!Object.keys(config).length) { diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index e47bfe7d5..465ec9fdd 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -15,6 +15,9 @@ jest.mock('./Config/loadCustomConfig', () => { Promise.resolve({ registration: { socialLogins: ['testLogin'] }, fileStrategy: 'testStrategy', + balance: { + enabled: true, + }, }), ); }); @@ -124,6 +127,9 @@ describe('AppService', () => { imageOutputType: expect.any(String), fileConfig: undefined, secureImageLinks: undefined, + balance: { enabled: true }, + filteredTools: undefined, + includedTools: undefined, }); }); @@ -341,9 +347,6 @@ describe('AppService', () => { process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax'; process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow'; - // Mock a custom configuration without specific rate limits - require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); - await AppService(app); // Verify that process.env falls back to the initial values @@ -404,9 +407,6 @@ describe('AppService', () => { process.env.IMPORT_USER_MAX = 'initialUserMax'; process.env.IMPORT_USER_WINDOW = 'initialUserWindow'; - // Mock a custom configuration without specific rate limits - require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); - await AppService(app); // Verify that process.env falls back to the initial values @@ -445,13 +445,27 @@ describe('AppService updating app.locals and issuing warnings', () => { expect(app.locals.availableTools).toBeDefined(); expect(app.locals.fileStrategy).toEqual(FileSources.local); expect(app.locals.socialLogins).toEqual(defaultSocialLogins); + expect(app.locals.balance).toEqual( + expect.objectContaining({ + enabled: false, + startBalance: undefined, + }), + ); }); it('should update app.locals with values from loadCustomConfig', async () => { - // Mock loadCustomConfig to return a specific config object + // Mock loadCustomConfig to return a specific config object with a complete balance config const customConfig = { fileStrategy: 'firebase', registration: { socialLogins: ['testLogin'] }, + balance: { + enabled: false, + startBalance: 5000, + autoRefillEnabled: true, + refillIntervalValue: 15, + refillIntervalUnit: 'hours', + refillAmount: 5000, + }, }; require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(customConfig), @@ -464,6 +478,7 @@ describe('AppService updating app.locals and issuing warnings', () => { expect(app.locals.availableTools).toBeDefined(); expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy); expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins); + expect(app.locals.balance).toEqual(customConfig.balance); }); it('should apply the assistants endpoint configuration correctly to app.locals', async () => { diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index 5b9b2dd18..2a154421b 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -1,5 +1,5 @@ const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); -const { normalizeEndpointName } = require('~/server/utils'); +const { normalizeEndpointName, isEnabled } = require('~/server/utils'); const loadCustomConfig = require('./loadCustomConfig'); const getLogStores = require('~/cache/getLogStores'); @@ -23,6 +23,29 @@ async function getCustomConfig() { return customConfig; } +/** + * Retrieves the configuration object + * @function getBalanceConfig + * @returns {Promise} + * */ +async function getBalanceConfig() { + const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE); + const startBalance = process.env.START_BALANCE; + if (isLegacyEnabled || (startBalance != null && startBalance)) { + /** @type {TCustomConfig['balance']} */ + const config = { + enabled: isLegacyEnabled, + startBalance: startBalance ? parseInt(startBalance, 10) : undefined, + }; + return config; + } + const customConfig = await getCustomConfig(); + if (!customConfig) { + return null; + } + return customConfig?.['balance'] ?? null; +} + /** * * @param {string | EModelEndpoint} endpoint @@ -40,4 +63,4 @@ const getCustomEndpointConfig = async (endpoint) => { ); }; -module.exports = { getCustomConfig, getCustomEndpointConfig }; +module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig }; diff --git a/client/src/components/Nav/AccountSettings.tsx b/client/src/components/Nav/AccountSettings.tsx index 77ed49ce4..ec6b52fd5 100644 --- a/client/src/components/Nav/AccountSettings.tsx +++ b/client/src/components/Nav/AccountSettings.tsx @@ -17,7 +17,7 @@ function AccountSettings() { const { user, isAuthenticated, logout } = useAuthContext(); const { data: startupConfig } = useGetStartupConfig(); const balanceQuery = useGetUserBalance({ - enabled: !!isAuthenticated && startupConfig?.checkBalance, + enabled: !!isAuthenticated && startupConfig?.balance?.enabled, }); const [showSettings, setShowSettings] = useState(false); const [showFiles, setShowFiles] = useRecoilState(store.showFiles); @@ -75,7 +75,7 @@ function AccountSettings() { {user?.email ?? localize('com_nav_user')} - {startupConfig?.checkBalance === true && + {startupConfig?.balance?.enabled === true && balanceQuery.data != null && !isNaN(parseFloat(balanceQuery.data)) && ( <> diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts index a52928caa..92f03090e 100644 --- a/client/src/hooks/SSE/useSSE.ts +++ b/client/src/hooks/SSE/useSSE.ts @@ -76,7 +76,7 @@ export default function useSSE( const { data: startupConfig } = useGetStartupConfig(); const balanceQuery = useGetUserBalance({ - enabled: !!isAuthenticated && startupConfig?.checkBalance, + enabled: !!isAuthenticated && startupConfig?.balance?.enabled, }); useEffect(() => { @@ -114,7 +114,7 @@ export default function useSSE( if (data.final != null) { const { plugins } = data; finalHandler(data, { ...submission, plugins } as EventSubmission); - (startupConfig?.checkBalance ?? false) && balanceQuery.refetch(); + (startupConfig?.balance?.enabled ?? false) && balanceQuery.refetch(); console.log('final', data); return; } else if (data.created != null) { @@ -208,7 +208,7 @@ export default function useSSE( } console.log('error in server stream.'); - (startupConfig?.checkBalance ?? false) && balanceQuery.refetch(); + (startupConfig?.balance?.enabled ?? false) && balanceQuery.refetch(); let data: TResData | undefined = undefined; try { @@ -234,6 +234,5 @@ export default function useSSE( sse.dispatchEvent(e); } }; - // eslint-disable-next-line react-hooks/exhaustive-deps }, [submission]); } diff --git a/librechat.example.yaml b/librechat.example.yaml index c8c914ced..0b4963cb2 100644 --- a/librechat.example.yaml +++ b/librechat.example.yaml @@ -77,6 +77,16 @@ registration: # allowedDomains: # - "gmail.com" + +# Example Balance settings +# balance: +# enabled: false +# startBalance: 20000 +# autoRefillEnabled: false +# refillIntervalValue: 30 +# refillIntervalUnit: 'days' +# refillAmount: 10000 + # speech: # tts: # openai: diff --git a/package-lock.json b/package-lock.json index 38c4ec052..d7bc3c39a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -43995,7 +43995,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.7.73", + "version": "0.7.74", "license": "ISC", "dependencies": { "axios": "^1.8.2", @@ -44132,7 +44132,7 @@ }, "packages/data-schemas": { "name": "@librechat/data-schemas", - "version": "0.0.4", + "version": "0.0.5", "license": "MIT", "dependencies": { "mongoose": "^8.12.1" diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 6f495a9ae..ff7d4cb93 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.7.73", + "version": "0.7.74", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 679d600eb..eb4fb89fe 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -501,11 +501,13 @@ export const intefaceSchema = z }); export type TInterfaceConfig = z.infer; +export type TBalanceConfig = z.infer; export type TStartupConfig = { appTitle: string; socialLogins?: string[]; interface?: TInterfaceConfig; + balance?: TBalanceConfig; discordLoginEnabled: boolean; facebookLoginEnabled: boolean; githubLoginEnabled: boolean; @@ -528,7 +530,6 @@ export type TStartupConfig = { socialLoginEnabled: boolean; passwordResetEnabled: boolean; emailEnabled: boolean; - checkBalance: boolean; showBirthdayIcon: boolean; helpAndFaqURL: string; customFooter?: string; @@ -552,6 +553,18 @@ export const ocrSchema = z.object({ strategy: z.nativeEnum(OCRStrategy).default(OCRStrategy.MISTRAL_OCR), }); +export const balanceSchema = z.object({ + enabled: z.boolean().optional().default(false), + startBalance: z.number().optional().default(20000), + autoRefillEnabled: z.boolean().optional().default(false), + refillIntervalValue: z.number().optional().default(30), + refillIntervalUnit: z + .enum(['seconds', 'minutes', 'hours', 'days', 'weeks', 'months']) + .optional() + .default('days'), + refillAmount: z.number().optional().default(10000), +}); + export const configSchema = z.object({ version: z.string(), cache: z.boolean().default(true), @@ -574,6 +587,7 @@ export const configSchema = z.object({ allowedDomains: z.array(z.string()).optional(), }) .default({ socialLogins: defaultSocialLogins }), + balance: balanceSchema.optional(), speech: z .object({ tts: ttsSchema.optional(), diff --git a/packages/data-schemas/package.json b/packages/data-schemas/package.json index 5d1ab2cf4..3add216e4 100644 --- a/packages/data-schemas/package.json +++ b/packages/data-schemas/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/data-schemas", - "version": "0.0.4", + "version": "0.0.5", "description": "Mongoose schemas and models for LibreChat", "type": "module", "main": "dist/index.cjs", diff --git a/packages/data-schemas/src/schema/balance.ts b/packages/data-schemas/src/schema/balance.ts index 0878a5760..c02871dff 100644 --- a/packages/data-schemas/src/schema/balance.ts +++ b/packages/data-schemas/src/schema/balance.ts @@ -3,6 +3,12 @@ import { Schema, Document, Types } from 'mongoose'; export interface IBalance extends Document { user: Types.ObjectId; tokenCredits: number; + // Automatic refill settings + autoRefillEnabled: boolean; + refillIntervalValue: number; + refillIntervalUnit: 'seconds' | 'minutes' | 'hours' | 'days' | 'weeks' | 'months'; + lastRefill: Date; + refillAmount: number; } const balanceSchema = new Schema({ @@ -17,6 +23,29 @@ const balanceSchema = new Schema({ type: Number, default: 0, }, + // Automatic refill settings + autoRefillEnabled: { + type: Boolean, + default: false, + }, + refillIntervalValue: { + type: Number, + default: 30, + }, + refillIntervalUnit: { + type: String, + enum: ['seconds', 'minutes', 'hours', 'days', 'weeks', 'months'], + default: 'days', + }, + lastRefill: { + type: Date, + default: Date.now, + }, + // amount to add on each refill + refillAmount: { + type: Number, + default: 0, + }, }); export default balanceSchema;