🔄 fix: Ensure lastRefill Date for Existing Users & Refactor Balance Middleware (#9086)

- Deleted setBalanceConfig middleware and its associated file.
- Introduced createSetBalanceConfig factory function to create middleware for synchronizing user balance settings.
- Updated auth and oauth routes to use the new balance configuration middleware.
- Added comprehensive tests for the new balance middleware functionality.
- Updated package versions and dependencies in package.json and package-lock.json.
- Added balance types and updated middleware index to export new balance middleware.
This commit is contained in:
Danny Avila 2025-08-15 17:02:49 -04:00 committed by GitHub
parent 81186312ef
commit 50b7bd6643
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 835 additions and 280 deletions

View file

@ -9,7 +9,6 @@ const validateEndpoint = require('./validateEndpoint');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const accessResources = require('./accessResources');
const setBalanceConfig = require('./setBalanceConfig');
const requireLdapAuth = require('./requireLdapAuth');
const abortMiddleware = require('./abortMiddleware');
const checkInviteUser = require('./checkInviteUser');
@ -44,7 +43,6 @@ module.exports = {
requireLocalAuth,
canDeleteAccount,
validateEndpoint,
setBalanceConfig,
concurrentLimiter,
checkDomainAllowed,
validateMessageReq,

View file

@ -1,91 +0,0 @@
const { logger } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { Balance } = require('~/db/models');
/**
* Middleware to synchronize user balance settings with current balance configuration.
* @function
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
*/
const setBalanceConfig = async (req, res, next) => {
try {
const balanceConfig = await getBalanceConfig();
if (!balanceConfig?.enabled) {
return next();
}
if (balanceConfig.startBalance == null) {
return next();
}
const userId = req.user._id;
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
if (Object.keys(updateFields).length === 0) {
return next();
}
await Balance.findOneAndUpdate(
{ user: userId },
{ $set: updateFields },
{ upsert: true, new: true },
);
next();
} catch (error) {
logger.error('Error setting user balance:', error);
next(error);
}
};
/**
* Build an object containing fields that need updating
* @param {Object} config - The balance configuration
* @param {Object|null} userRecord - The user's current balance record, if any
* @returns {Object} Fields that need updating
*/
function buildUpdateFields(config, userRecord) {
const updateFields = {};
// Ensure user record has the required fields
if (!userRecord) {
updateFields.user = userRecord?.user;
updateFields.tokenCredits = config.startBalance;
}
if (userRecord?.tokenCredits == null && config.startBalance != null) {
updateFields.tokenCredits = config.startBalance;
}
const isAutoRefillConfigValid =
config.autoRefillEnabled &&
config.refillIntervalValue != null &&
config.refillIntervalUnit != null &&
config.refillAmount != null;
if (!isAutoRefillConfigValid) {
return updateFields;
}
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
updateFields.autoRefillEnabled = config.autoRefillEnabled;
}
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
updateFields.refillIntervalValue = config.refillIntervalValue;
}
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
updateFields.refillIntervalUnit = config.refillIntervalUnit;
}
if (userRecord?.refillAmount !== config.refillAmount) {
updateFields.refillAmount = config.refillAmount;
}
return updateFields;
}
module.exports = setBalanceConfig;

View file

@ -1,75 +1,75 @@
const express = require('express');
const { createSetBalanceConfig } = require('@librechat/api');
const {
refreshController,
registrationController,
resetPasswordController,
resetPasswordRequestController,
resetPasswordController,
registrationController,
graphTokenController,
refreshController,
} = require('~/server/controllers/AuthController');
const { loginController } = require('~/server/controllers/auth/LoginController');
const { logoutController } = require('~/server/controllers/auth/LogoutController');
const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
const {
regenerateBackupCodes,
disable2FA,
confirm2FA,
enable2FA,
verify2FA,
disable2FA,
regenerateBackupCodes,
confirm2FA,
} = require('~/server/controllers/TwoFactorController');
const {
checkBan,
logHeaders,
loginLimiter,
requireJwtAuth,
checkInviteUser,
registerLimiter,
requireLdapAuth,
setBalanceConfig,
requireLocalAuth,
resetPasswordLimiter,
validateRegistration,
validatePasswordReset,
} = require('~/server/middleware');
const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
const { logoutController } = require('~/server/controllers/auth/LogoutController');
const { loginController } = require('~/server/controllers/auth/LoginController');
const { getBalanceConfig } = require('~/server/services/Config');
const middleware = require('~/server/middleware');
const { Balance } = require('~/db/models');
const setBalanceConfig = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const router = express.Router();
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
//Local
router.post('/logout', requireJwtAuth, logoutController);
router.post('/logout', middleware.requireJwtAuth, logoutController);
router.post(
'/login',
logHeaders,
loginLimiter,
checkBan,
ldapAuth ? requireLdapAuth : requireLocalAuth,
middleware.logHeaders,
middleware.loginLimiter,
middleware.checkBan,
ldapAuth ? middleware.requireLdapAuth : middleware.requireLocalAuth,
setBalanceConfig,
loginController,
);
router.post('/refresh', refreshController);
router.post(
'/register',
registerLimiter,
checkBan,
checkInviteUser,
validateRegistration,
middleware.registerLimiter,
middleware.checkBan,
middleware.checkInviteUser,
middleware.validateRegistration,
registrationController,
);
router.post(
'/requestPasswordReset',
resetPasswordLimiter,
checkBan,
validatePasswordReset,
middleware.resetPasswordLimiter,
middleware.checkBan,
middleware.validatePasswordReset,
resetPasswordRequestController,
);
router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController);
router.post(
'/resetPassword',
middleware.checkBan,
middleware.validatePasswordReset,
resetPasswordController,
);
router.get('/2fa/enable', requireJwtAuth, enable2FA);
router.post('/2fa/verify', requireJwtAuth, verify2FA);
router.post('/2fa/verify-temp', checkBan, verify2FAWithTempToken);
router.post('/2fa/confirm', requireJwtAuth, confirm2FA);
router.post('/2fa/disable', requireJwtAuth, disable2FA);
router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodes);
router.get('/2fa/enable', middleware.requireJwtAuth, enable2FA);
router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA);
router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken);
router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA);
router.post('/2fa/disable', middleware.requireJwtAuth, disable2FA);
router.post('/2fa/backup/regenerate', middleware.requireJwtAuth, regenerateBackupCodes);
router.get('/graph-token', requireJwtAuth, graphTokenController);
router.get('/graph-token', middleware.requireJwtAuth, graphTokenController);
module.exports = router;

View file

@ -1,19 +1,20 @@
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
const express = require('express');
const passport = require('passport');
const { isEnabled } = require('@librechat/api');
const { randomState } = require('openid-client');
const { logger } = require('@librechat/data-schemas');
const { ErrorTypes } = require('librechat-data-provider');
const {
checkBan,
logHeaders,
loginLimiter,
setBalanceConfig,
checkDomainAllowed,
} = require('~/server/middleware');
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
const { getBalanceConfig } = require('~/server/services/Config');
const { Balance } = require('~/db/models');
const setBalanceConfig = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const router = express.Router();

205
package-lock.json generated
View file

@ -2451,72 +2451,6 @@
"mkdirp": "bin/cmd.js"
}
},
"api/node_modules/mongodb": {
"version": "6.14.2",
"resolved": "https://registry.npmjs.org/mongodb/-/mongodb-6.14.2.tgz",
"integrity": "sha512-kMEHNo0F3P6QKDq17zcDuPeaywK/YaJVCEQRzPF3TOM/Bl9MFg64YE5Tu7ifj37qZJMhwU1tl2Ioivws5gRG5Q==",
"dependencies": {
"@mongodb-js/saslprep": "^1.1.9",
"bson": "^6.10.3",
"mongodb-connection-string-url": "^3.0.0"
},
"engines": {
"node": ">=16.20.1"
},
"peerDependencies": {
"@aws-sdk/credential-providers": "^3.188.0",
"@mongodb-js/zstd": "^1.1.0 || ^2.0.0",
"gcp-metadata": "^5.2.0",
"kerberos": "^2.0.1",
"mongodb-client-encryption": ">=6.0.0 <7",
"snappy": "^7.2.2",
"socks": "^2.7.1"
},
"peerDependenciesMeta": {
"@aws-sdk/credential-providers": {
"optional": true
},
"@mongodb-js/zstd": {
"optional": true
},
"gcp-metadata": {
"optional": true
},
"kerberos": {
"optional": true
},
"mongodb-client-encryption": {
"optional": true
},
"snappy": {
"optional": true
},
"socks": {
"optional": true
}
}
},
"api/node_modules/mongoose": {
"version": "8.12.1",
"resolved": "https://registry.npmjs.org/mongoose/-/mongoose-8.12.1.tgz",
"integrity": "sha512-UW22y8QFVYmrb36hm8cGncfn4ARc/XsYWQwRTaj0gxtQk1rDuhzDO1eBantS+hTTatfAIS96LlRCJrcNHvW5+Q==",
"dependencies": {
"bson": "^6.10.3",
"kareem": "2.6.3",
"mongodb": "~6.14.0",
"mpath": "0.9.0",
"mquery": "5.0.0",
"ms": "2.1.3",
"sift": "17.1.3"
},
"engines": {
"node": ">=16.20.1"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/mongoose"
}
},
"api/node_modules/multer": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/multer/-/multer-2.0.2.tgz",
@ -41372,6 +41306,74 @@
"node": ">= 14"
}
},
"node_modules/mongoose": {
"version": "8.12.1",
"resolved": "https://registry.npmjs.org/mongoose/-/mongoose-8.12.1.tgz",
"integrity": "sha512-UW22y8QFVYmrb36hm8cGncfn4ARc/XsYWQwRTaj0gxtQk1rDuhzDO1eBantS+hTTatfAIS96LlRCJrcNHvW5+Q==",
"license": "MIT",
"dependencies": {
"bson": "^6.10.3",
"kareem": "2.6.3",
"mongodb": "~6.14.0",
"mpath": "0.9.0",
"mquery": "5.0.0",
"ms": "2.1.3",
"sift": "17.1.3"
},
"engines": {
"node": ">=16.20.1"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/mongoose"
}
},
"node_modules/mongoose/node_modules/mongodb": {
"version": "6.14.2",
"resolved": "https://registry.npmjs.org/mongodb/-/mongodb-6.14.2.tgz",
"integrity": "sha512-kMEHNo0F3P6QKDq17zcDuPeaywK/YaJVCEQRzPF3TOM/Bl9MFg64YE5Tu7ifj37qZJMhwU1tl2Ioivws5gRG5Q==",
"license": "Apache-2.0",
"dependencies": {
"@mongodb-js/saslprep": "^1.1.9",
"bson": "^6.10.3",
"mongodb-connection-string-url": "^3.0.0"
},
"engines": {
"node": ">=16.20.1"
},
"peerDependencies": {
"@aws-sdk/credential-providers": "^3.188.0",
"@mongodb-js/zstd": "^1.1.0 || ^2.0.0",
"gcp-metadata": "^5.2.0",
"kerberos": "^2.0.1",
"mongodb-client-encryption": ">=6.0.0 <7",
"snappy": "^7.2.2",
"socks": "^2.7.1"
},
"peerDependenciesMeta": {
"@aws-sdk/credential-providers": {
"optional": true
},
"@mongodb-js/zstd": {
"optional": true
},
"gcp-metadata": {
"optional": true
},
"kerberos": {
"optional": true
},
"mongodb-client-encryption": {
"optional": true
},
"snappy": {
"optional": true
},
"socks": {
"optional": true
}
}
},
"node_modules/moo-color": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/moo-color/-/moo-color-1.0.3.tgz",
@ -51344,7 +51346,7 @@
},
"packages/api": {
"name": "@librechat/api",
"version": "1.3.1",
"version": "1.3.2",
"license": "ISC",
"devDependencies": {
"@babel/preset-env": "^7.21.5",
@ -51368,6 +51370,7 @@
"jest": "^29.5.0",
"jest-junit": "^16.0.0",
"librechat-data-provider": "*",
"mongoose": "^8.12.1",
"rimraf": "^5.0.1",
"rollup": "^4.22.4",
"rollup-plugin-generate-package-json": "^3.2.0",
@ -51991,74 +51994,6 @@
"url": "https://github.com/sponsors/isaacs"
}
},
"packages/data-schemas/node_modules/mongodb": {
"version": "6.14.2",
"resolved": "https://registry.npmjs.org/mongodb/-/mongodb-6.14.2.tgz",
"integrity": "sha512-kMEHNo0F3P6QKDq17zcDuPeaywK/YaJVCEQRzPF3TOM/Bl9MFg64YE5Tu7ifj37qZJMhwU1tl2Ioivws5gRG5Q==",
"peer": true,
"dependencies": {
"@mongodb-js/saslprep": "^1.1.9",
"bson": "^6.10.3",
"mongodb-connection-string-url": "^3.0.0"
},
"engines": {
"node": ">=16.20.1"
},
"peerDependencies": {
"@aws-sdk/credential-providers": "^3.188.0",
"@mongodb-js/zstd": "^1.1.0 || ^2.0.0",
"gcp-metadata": "^5.2.0",
"kerberos": "^2.0.1",
"mongodb-client-encryption": ">=6.0.0 <7",
"snappy": "^7.2.2",
"socks": "^2.7.1"
},
"peerDependenciesMeta": {
"@aws-sdk/credential-providers": {
"optional": true
},
"@mongodb-js/zstd": {
"optional": true
},
"gcp-metadata": {
"optional": true
},
"kerberos": {
"optional": true
},
"mongodb-client-encryption": {
"optional": true
},
"snappy": {
"optional": true
},
"socks": {
"optional": true
}
}
},
"packages/data-schemas/node_modules/mongoose": {
"version": "8.12.1",
"resolved": "https://registry.npmjs.org/mongoose/-/mongoose-8.12.1.tgz",
"integrity": "sha512-UW22y8QFVYmrb36hm8cGncfn4ARc/XsYWQwRTaj0gxtQk1rDuhzDO1eBantS+hTTatfAIS96LlRCJrcNHvW5+Q==",
"peer": true,
"dependencies": {
"bson": "^6.10.3",
"kareem": "2.6.3",
"mongodb": "~6.14.0",
"mpath": "0.9.0",
"mquery": "5.0.0",
"ms": "2.1.3",
"sift": "17.1.3"
},
"engines": {
"node": ">=16.20.1"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/mongoose"
}
},
"packages/data-schemas/node_modules/object-hash": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz",

View file

@ -1,6 +1,6 @@
{
"name": "@librechat/api",
"version": "1.3.1",
"version": "1.3.2",
"type": "commonjs",
"description": "MCP services for LibreChat",
"main": "dist/index.js",
@ -58,6 +58,7 @@
"jest": "^29.5.0",
"jest-junit": "^16.0.0",
"librechat-data-provider": "*",
"mongoose": "^8.12.1",
"rimraf": "^5.0.1",
"rollup": "^4.22.4",
"rollup-plugin-generate-package-json": "^3.2.0",

View file

@ -0,0 +1,583 @@
import mongoose from 'mongoose';
import { MongoMemoryServer } from 'mongodb-memory-server';
import { logger, balanceSchema } from '@librechat/data-schemas';
import type { NextFunction, Request as ServerRequest, Response as ServerResponse } from 'express';
import type { IBalance, BalanceConfig } from '@librechat/data-schemas';
import { createSetBalanceConfig } from './balance';
jest.mock('@librechat/data-schemas', () => ({
...jest.requireActual('@librechat/data-schemas'),
logger: {
error: jest.fn(),
},
}));
let mongoServer: MongoMemoryServer;
let Balance: mongoose.Model<IBalance>;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
Balance = mongoose.models.Balance || mongoose.model('Balance', balanceSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.connection.dropDatabase();
jest.clearAllMocks();
jest.restoreAllMocks();
});
describe('createSetBalanceConfig', () => {
const createMockRequest = (userId: string | mongoose.Types.ObjectId): Partial<ServerRequest> => ({
user: {
_id: userId,
id: userId.toString(),
email: 'test@example.com',
},
});
const createMockResponse = (): Partial<ServerResponse> => ({
status: jest.fn().mockReturnThis(),
json: jest.fn().mockReturnThis(),
});
const mockNext: NextFunction = jest.fn();
const defaultBalanceConfig: BalanceConfig = {
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
};
describe('Basic Functionality', () => {
test('should create balance record for new user with start balance', async () => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue(defaultBalanceConfig);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(getBalanceConfig).toHaveBeenCalled();
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeTruthy();
expect(balanceRecord?.tokenCredits).toBe(1000);
expect(balanceRecord?.autoRefillEnabled).toBe(true);
expect(balanceRecord?.refillIntervalValue).toBe(30);
expect(balanceRecord?.refillIntervalUnit).toBe('days');
expect(balanceRecord?.refillAmount).toBe(500);
expect(balanceRecord?.lastRefill).toBeInstanceOf(Date);
});
test('should skip if balance config is not enabled', async () => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue({ enabled: false });
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeNull();
});
test('should skip if startBalance is null', async () => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: null,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeNull();
});
test('should handle user._id as string', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const getBalanceConfig = jest.fn().mockResolvedValue(defaultBalanceConfig);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeTruthy();
expect(balanceRecord?.tokenCredits).toBe(1000);
});
test('should skip if user is not present in request', async () => {
const getBalanceConfig = jest.fn().mockResolvedValue(defaultBalanceConfig);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = {} as ServerRequest;
const res = createMockResponse();
await middleware(req, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
expect(getBalanceConfig).toHaveBeenCalled();
});
});
describe('Edge Case: Auto-refill without lastRefill', () => {
test('should initialize lastRefill when enabling auto-refill for existing user without lastRefill', async () => {
const userId = new mongoose.Types.ObjectId();
// Create existing balance record without lastRefill
// Note: We need to unset lastRefill after creation since the schema has a default
const doc = await Balance.create({
user: userId,
tokenCredits: 500,
autoRefillEnabled: false,
});
// Remove lastRefill to simulate existing user without it
await Balance.updateOne({ _id: doc._id }, { $unset: { lastRefill: 1 } });
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
const beforeTime = new Date();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
const afterTime = new Date();
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeTruthy();
expect(balanceRecord?.tokenCredits).toBe(500); // Should not change existing credits
expect(balanceRecord?.autoRefillEnabled).toBe(true);
expect(balanceRecord?.lastRefill).toBeInstanceOf(Date);
// Verify lastRefill was set to current time
const lastRefillTime = balanceRecord?.lastRefill?.getTime() || 0;
expect(lastRefillTime).toBeGreaterThanOrEqual(beforeTime.getTime());
expect(lastRefillTime).toBeLessThanOrEqual(afterTime.getTime());
});
test('should not update lastRefill if it already exists', async () => {
const userId = new mongoose.Types.ObjectId();
const existingLastRefill = new Date('2024-01-01');
// Create existing balance record with lastRefill
await Balance.create({
user: userId,
tokenCredits: 500,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
lastRefill: existingLastRefill,
});
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord?.lastRefill?.getTime()).toBe(existingLastRefill.getTime());
});
test('should handle existing user with auto-refill enabled but missing lastRefill', async () => {
const userId = new mongoose.Types.ObjectId();
// Create a balance record with auto-refill enabled but missing lastRefill
// This simulates the exact edge case reported by the user
const doc = await Balance.create({
user: userId,
tokenCredits: 500,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
});
// Remove lastRefill to simulate the edge case
await Balance.updateOne({ _id: doc._id }, { $unset: { lastRefill: 1 } });
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeTruthy();
expect(balanceRecord?.autoRefillEnabled).toBe(true);
expect(balanceRecord?.lastRefill).toBeInstanceOf(Date);
// This should have fixed the issue - user should no longer get the error
});
test('should not set lastRefill when auto-refill is disabled', async () => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: false,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeTruthy();
expect(balanceRecord?.tokenCredits).toBe(1000);
expect(balanceRecord?.autoRefillEnabled).toBe(false);
// lastRefill should have default value from schema
expect(balanceRecord?.lastRefill).toBeInstanceOf(Date);
});
});
describe('Update Scenarios', () => {
test('should update auto-refill settings for existing user', async () => {
const userId = new mongoose.Types.ObjectId();
// Create existing balance record
await Balance.create({
user: userId,
tokenCredits: 500,
autoRefillEnabled: false,
refillIntervalValue: 7,
refillIntervalUnit: 'days',
refillAmount: 100,
});
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord?.tokenCredits).toBe(500); // Should not change
expect(balanceRecord?.autoRefillEnabled).toBe(true);
expect(balanceRecord?.refillIntervalValue).toBe(30);
expect(balanceRecord?.refillIntervalUnit).toBe('days');
expect(balanceRecord?.refillAmount).toBe(500);
});
test('should not update if values are already the same', async () => {
const userId = new mongoose.Types.ObjectId();
const lastRefillTime = new Date();
// Create existing balance record with same values
await Balance.create({
user: userId,
tokenCredits: 1000,
autoRefillEnabled: true,
refillIntervalValue: 30,
refillIntervalUnit: 'days',
refillAmount: 500,
lastRefill: lastRefillTime,
});
const getBalanceConfig = jest.fn().mockResolvedValue(defaultBalanceConfig);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
// Spy on Balance.findOneAndUpdate to verify it's not called
const updateSpy = jest.spyOn(Balance, 'findOneAndUpdate');
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
expect(updateSpy).not.toHaveBeenCalled();
});
test('should set tokenCredits for user with null tokenCredits', async () => {
const userId = new mongoose.Types.ObjectId();
// Create balance record with null tokenCredits
await Balance.create({
user: userId,
tokenCredits: null,
});
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 2000,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord?.tokenCredits).toBe(2000);
});
});
describe('Error Handling', () => {
test('should handle database errors gracefully', async () => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue(defaultBalanceConfig);
const dbError = new Error('Database error');
// Mock Balance.findOne to throw an error
jest.spyOn(Balance, 'findOne').mockImplementationOnce((() => {
return {
lean: jest.fn().mockRejectedValue(dbError),
};
}) as unknown as mongoose.Model<IBalance>['findOne']);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(logger.error).toHaveBeenCalledWith('Error setting user balance:', dbError);
expect(mockNext).toHaveBeenCalledWith(dbError);
});
test('should handle getBalanceConfig errors', async () => {
const userId = new mongoose.Types.ObjectId();
const configError = new Error('Config error');
const getBalanceConfig = jest.fn().mockRejectedValue(configError);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(logger.error).toHaveBeenCalledWith('Error setting user balance:', configError);
expect(mockNext).toHaveBeenCalledWith(configError);
});
test('should handle invalid auto-refill configuration', async () => {
const userId = new mongoose.Types.ObjectId();
// Missing required auto-refill fields
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: null, // Invalid
refillIntervalUnit: 'days',
refillAmount: 500,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
expect(mockNext).toHaveBeenCalled();
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord).toBeTruthy();
expect(balanceRecord?.tokenCredits).toBe(1000);
// Auto-refill fields should not be updated due to invalid config
expect(balanceRecord?.autoRefillEnabled).toBe(false);
});
});
describe('Concurrent Updates', () => {
test('should handle concurrent middleware calls for same user', async () => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue(defaultBalanceConfig);
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res1 = createMockResponse();
const res2 = createMockResponse();
const mockNext1 = jest.fn();
const mockNext2 = jest.fn();
// Run middleware concurrently
await Promise.all([
middleware(req as ServerRequest, res1 as ServerResponse, mockNext1),
middleware(req as ServerRequest, res2 as ServerResponse, mockNext2),
]);
expect(mockNext1).toHaveBeenCalled();
expect(mockNext2).toHaveBeenCalled();
// Should only have one balance record
const balanceRecords = await Balance.find({ user: userId });
expect(balanceRecords).toHaveLength(1);
expect(balanceRecords[0].tokenCredits).toBe(1000);
});
});
describe('Integration with Different refillIntervalUnits', () => {
test.each(['seconds', 'minutes', 'hours', 'days', 'weeks', 'months'])(
'should handle refillIntervalUnit: %s',
async (unit) => {
const userId = new mongoose.Types.ObjectId();
const getBalanceConfig = jest.fn().mockResolvedValue({
enabled: true,
startBalance: 1000,
autoRefillEnabled: true,
refillIntervalValue: 10,
refillIntervalUnit: unit,
refillAmount: 100,
});
const middleware = createSetBalanceConfig({
getBalanceConfig,
Balance,
});
const req = createMockRequest(userId);
const res = createMockResponse();
await middleware(req as ServerRequest, res as ServerResponse, mockNext);
const balanceRecord = await Balance.findOne({ user: userId });
expect(balanceRecord?.refillIntervalUnit).toBe(unit);
expect(balanceRecord?.refillIntervalValue).toBe(10);
expect(balanceRecord?.lastRefill).toBeInstanceOf(Date);
},
);
});
});

View file

@ -0,0 +1,117 @@
import { logger } from '@librechat/data-schemas';
import type { NextFunction, Request as ServerRequest, Response as ServerResponse } from 'express';
import type { IBalance, IUser, BalanceConfig, ObjectId } from '@librechat/data-schemas';
import type { Model } from 'mongoose';
import type { BalanceUpdateFields } from '~/types';
export interface BalanceMiddlewareOptions {
getBalanceConfig: () => Promise<BalanceConfig | null>;
Balance: Model<IBalance>;
}
/**
* Build an object containing fields that need updating
* @param config - The balance configuration
* @param userRecord - The user's current balance record, if any
* @param userId - The user's ID
* @returns Fields that need updating
*/
function buildUpdateFields(
config: BalanceConfig,
userRecord: IBalance | null,
userId: string,
): BalanceUpdateFields {
const updateFields: BalanceUpdateFields = {};
// Ensure user record has the required fields
if (!userRecord) {
updateFields.user = userId;
updateFields.tokenCredits = config.startBalance;
}
if (userRecord?.tokenCredits == null && config.startBalance != null) {
updateFields.tokenCredits = config.startBalance;
}
const isAutoRefillConfigValid =
config.autoRefillEnabled &&
config.refillIntervalValue != null &&
config.refillIntervalUnit != null &&
config.refillAmount != null;
if (!isAutoRefillConfigValid) {
return updateFields;
}
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
updateFields.autoRefillEnabled = config.autoRefillEnabled;
}
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
updateFields.refillIntervalValue = config.refillIntervalValue;
}
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
updateFields.refillIntervalUnit = config.refillIntervalUnit;
}
if (userRecord?.refillAmount !== config.refillAmount) {
updateFields.refillAmount = config.refillAmount;
}
// Initialize lastRefill if it's missing when auto-refill is enabled
if (config.autoRefillEnabled && !userRecord?.lastRefill) {
updateFields.lastRefill = new Date();
}
return updateFields;
}
/**
* Factory function to create middleware that synchronizes user balance settings with current balance configuration.
* @param options - Options containing getBalanceConfig function and Balance model
* @returns Express middleware function
*/
export function createSetBalanceConfig({
getBalanceConfig,
Balance,
}: BalanceMiddlewareOptions): (
req: ServerRequest,
res: ServerResponse,
next: NextFunction,
) => Promise<void> {
return async (req: ServerRequest, res: ServerResponse, next: NextFunction): Promise<void> => {
try {
const balanceConfig = await getBalanceConfig();
if (!balanceConfig?.enabled) {
return next();
}
if (balanceConfig.startBalance == null) {
return next();
}
const user = req.user as IUser & { _id: string | ObjectId };
if (!user || !user._id) {
return next();
}
const userId = typeof user._id === 'string' ? user._id : user._id.toString();
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord, userId);
if (Object.keys(updateFields).length === 0) {
return next();
}
await Balance.findOneAndUpdate(
{ user: userId },
{ $set: updateFields },
{ upsert: true, new: true },
);
next();
} catch (error) {
logger.error('Error setting user balance:', error);
next(error);
}
};
}

View file

@ -1,2 +1,3 @@
export * from './access';
export * from './error';
export * from './balance';

View file

@ -0,0 +1,9 @@
export interface BalanceUpdateFields {
user?: string;
tokenCredits?: number;
autoRefillEnabled?: boolean;
refillIntervalValue?: number;
refillIntervalUnit?: string;
refillAmount?: number;
lastRefill?: Date;
}

View file

@ -1,4 +1,5 @@
export * from './azure';
export * from './balance';
export * from './events';
export * from './error';
export * from './google';