diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 00cc7c92c0..cdce2edf4d 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -523,6 +523,8 @@ class BaseClient { getMultiplier: db.getMultiplier, findBalanceByUser: db.findBalanceByUser, createAutoRefillTransaction: db.createAutoRefillTransaction, + balanceConfig, + upsertBalanceFields: db.upsertBalanceFields, }, ); } diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index e4a20c2a5e..631831e617 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -40,6 +40,7 @@ const { sendResponse } = require('~/server/middleware/error'); const { createAutoRefillTransaction, findBalanceByUser, + upsertBalanceFields, getTransactions, getMultiplier, getConvo, @@ -296,7 +297,14 @@ const chatV1 = async (req, res) => { amount: promptTokens, }, }, - { findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation }, + { + findBalanceByUser, + getMultiplier, + createAutoRefillTransaction, + logViolation, + balanceConfig, + upsertBalanceFields, + }, ); }; diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 559d9d8953..237af1b11a 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -37,6 +37,7 @@ const { getMultiplier, getTransactions, findBalanceByUser, + upsertBalanceFields, createAutoRefillTransaction, } = require('~/models'); const { logViolation, getLogStores } = require('~/cache'); @@ -169,7 +170,14 @@ const chatV2 = async (req, res) => { amount: promptTokens, }, }, - { findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation }, + { + findBalanceByUser, + getMultiplier, + createAutoRefillTransaction, + logViolation, + balanceConfig, + upsertBalanceFields, + }, ); }; diff --git a/packages/api/src/middleware/checkBalance.spec.ts b/packages/api/src/middleware/checkBalance.spec.ts new file mode 100644 index 0000000000..8d272d2e60 --- /dev/null +++ b/packages/api/src/middleware/checkBalance.spec.ts @@ -0,0 +1,266 @@ +import { ViolationTypes } from 'librechat-data-provider'; +import type { Response } from 'express'; +import type { CheckBalanceDeps } from './checkBalance'; +import type { ServerRequest } from '~/types/http'; +import { checkBalance } from './checkBalance'; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + error: jest.fn(), + }, +})); + +describe('checkBalance', () => { + const createMockDeps = (overrides: Partial = {}): CheckBalanceDeps => ({ + findBalanceByUser: jest.fn().mockResolvedValue({ tokenCredits: 1000 }), + getMultiplier: jest.fn().mockReturnValue(1), + createAutoRefillTransaction: jest.fn(), + logViolation: jest.fn().mockResolvedValue(undefined), + ...overrides, + }); + + const req = { user: { id: 'user-1' } } as ServerRequest; + const res = {} as Response; + + const baseTxData = { + user: 'user-1', + tokenType: 'prompt', + amount: 100, + endpoint: 'openAI', + model: 'gpt-4', + }; + + it('should return true when user has sufficient balance', async () => { + const deps = createMockDeps(); + + const result = await checkBalance({ req, res, txData: baseTxData }, deps); + expect(result).toBe(true); + }); + + it('should throw when user has insufficient balance', async () => { + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue({ tokenCredits: 10 }), + getMultiplier: jest.fn().mockReturnValue(1), + }); + + await expect( + checkBalance({ req, res, txData: { ...baseTxData, amount: 100 } }, deps), + ).rejects.toThrow(); + + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 10, tokenCost: 100 }), + 0, + ); + }); + + describe('lazy balance initialization', () => { + it('should create balance record when no record exists and startBalance is configured', async () => { + const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 5000 }); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + balanceConfig: { startBalance: 5000 }, + upsertBalanceFields, + }); + + const result = await checkBalance({ req, res, txData: baseTxData }, deps); + + expect(result).toBe(true); + expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', { + user: 'user-1', + tokenCredits: 5000, + }); + }); + + it('should include auto-refill fields when configured', async () => { + const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 5000 }); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + balanceConfig: { + startBalance: 5000, + autoRefillEnabled: true, + refillIntervalValue: 1, + refillIntervalUnit: 'days', + refillAmount: 1000, + }, + upsertBalanceFields, + }); + + await checkBalance({ req, res, txData: baseTxData }, deps); + + expect(upsertBalanceFields).toHaveBeenCalledWith( + 'user-1', + expect.objectContaining({ + user: 'user-1', + tokenCredits: 5000, + autoRefillEnabled: true, + refillIntervalValue: 1, + refillIntervalUnit: 'days', + refillAmount: 1000, + lastRefill: expect.any(Date), + }), + ); + }); + + it('should not include auto-refill fields when config is partial', async () => { + const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 5000 }); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + balanceConfig: { startBalance: 5000, autoRefillEnabled: true }, + upsertBalanceFields, + }); + + await checkBalance({ req, res, txData: baseTxData }, deps); + + expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', { + user: 'user-1', + tokenCredits: 5000, + }); + }); + + it('should throw a TOKEN_BALANCE violation when lazy-initialized balance is less than token cost', async () => { + const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 50 }); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + getMultiplier: jest.fn().mockReturnValue(1), + balanceConfig: { startBalance: 50 }, + upsertBalanceFields, + }); + + await expect( + checkBalance({ req, res, txData: { ...baseTxData, amount: 100 } }, deps), + ).rejects.toThrow(); + + expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', { + user: 'user-1', + tokenCredits: 50, + }); + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 50, tokenCost: 100 }), + 0, + ); + }); + + it('should use DB-returned tokenCredits over raw startBalance config constant', async () => { + const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 3000 }); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + getMultiplier: jest.fn().mockReturnValue(1), + balanceConfig: { startBalance: 5000 }, + upsertBalanceFields, + }); + + await expect( + checkBalance({ req, res, txData: { ...baseTxData, amount: 4000 } }, deps), + ).rejects.toThrow(); + + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 3000, tokenCost: 4000 }), + 0, + ); + }); + + it('should throw a TOKEN_BALANCE violation when no record and no balanceConfig', async () => { + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + }); + + await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow(); + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 0 }), + 0, + ); + }); + + it('should throw a TOKEN_BALANCE violation when no record and startBalance is undefined', async () => { + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + balanceConfig: {}, + upsertBalanceFields: jest.fn(), + }); + + await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow(); + expect(deps.upsertBalanceFields).not.toHaveBeenCalled(); + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 0 }), + 0, + ); + }); + + it('should throw a TOKEN_BALANCE violation when upsertBalanceFields is not provided', async () => { + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + balanceConfig: { startBalance: 5000 }, + }); + + await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow(); + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 0 }), + 0, + ); + }); + + it('should handle startBalance of 0', async () => { + const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 0 }); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + getMultiplier: jest.fn().mockReturnValue(1), + balanceConfig: { startBalance: 0 }, + upsertBalanceFields, + }); + + await expect( + checkBalance({ req, res, txData: { ...baseTxData, amount: 100 } }, deps), + ).rejects.toThrow(); + + expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', { + user: 'user-1', + tokenCredits: 0, + }); + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 0, tokenCost: 100 }), + 0, + ); + }); + + it('should fall back to balance: 0 when upsertBalanceFields rejects', async () => { + const upsertBalanceFields = jest.fn().mockRejectedValue(new Error('DB unavailable')); + const deps = createMockDeps({ + findBalanceByUser: jest.fn().mockResolvedValue(null), + balanceConfig: { startBalance: 5000 }, + upsertBalanceFields, + }); + + await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow(); + expect(deps.logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.TOKEN_BALANCE, + expect.objectContaining({ balance: 0 }), + 0, + ); + }); + }); +}); diff --git a/packages/api/src/middleware/checkBalance.ts b/packages/api/src/middleware/checkBalance.ts index d99874dc07..d285614826 100644 --- a/packages/api/src/middleware/checkBalance.ts +++ b/packages/api/src/middleware/checkBalance.ts @@ -1,7 +1,8 @@ import { logger } from '@librechat/data-schemas'; import { ViolationTypes } from 'librechat-data-provider'; -import type { ServerRequest } from '~/types/http'; +import type { BalanceConfig, IBalanceUpdate } from '@librechat/data-schemas'; import type { Response } from 'express'; +import type { ServerRequest } from '~/types/http'; type TimeUnit = 'seconds' | 'minutes' | 'hours' | 'days' | 'weeks' | 'months'; @@ -38,6 +39,10 @@ export interface CheckBalanceDeps { errorMessage: Record, score: number, ) => Promise; + /** Balance config for lazy initialization when no record exists */ + balanceConfig?: BalanceConfig; + /** Upsert function for lazy initialization when no record exists */ + upsertBalanceFields?: (userId: string, fields: IBalanceUpdate) => Promise; } function addIntervalToDate(date: Date, value: number, unit: TimeUnit): Date { @@ -84,6 +89,37 @@ async function checkBalanceRecord( const record = await deps.findBalanceByUser(user); if (!record) { + if (deps.balanceConfig?.startBalance != null && deps.upsertBalanceFields) { + logger.debug('[Balance.check] Lazy-initializing balance record for user', { + user, + startBalance: deps.balanceConfig.startBalance, + }); + try { + const fields: IBalanceUpdate = { + user, + tokenCredits: deps.balanceConfig.startBalance, + }; + const config = deps.balanceConfig; + if ( + config.autoRefillEnabled && + config.refillIntervalValue != null && + config.refillIntervalUnit != null && + config.refillAmount != null + ) { + fields.autoRefillEnabled = config.autoRefillEnabled; + fields.refillIntervalValue = config.refillIntervalValue; + fields.refillIntervalUnit = config.refillIntervalUnit; + fields.refillAmount = config.refillAmount; + fields.lastRefill = new Date(); + } + const created = await deps.upsertBalanceFields(user, fields); + const balance = created?.tokenCredits ?? deps.balanceConfig.startBalance; + return { canSpend: balance >= tokenCost, balance, tokenCost }; + } catch (error) { + logger.error('[Balance.check] Failed to lazy-initialize balance record', { user, error }); + return { canSpend: false, balance: 0, tokenCost }; + } + } logger.debug('[Balance.check] No balance record found for user', { user }); return { canSpend: false, balance: 0, tokenCost }; }