mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-03 14:27:20 +02:00
💰 fix: Lazy-Initialize Balance Record at Check Time for Overrides (#12474)
* fix: Lazy-initialize balance record when missing at check time When balance is configured via admin panel DB overrides, users with existing sessions never pass through the login middleware that creates their balance record. This causes checkBalanceRecord to find no record and return balance: 0, blocking the user. Add optional balanceConfig and upsertBalanceFields deps to CheckBalanceDeps. When no balance record exists but startBalance is configured, lazily create the record instead of returning canSpend: false. Pass the new deps from BaseClient, chatV1, and chatV2 callers. * test: Add checkBalance lazy initialization tests Cover lazy balance init scenarios: successful init with startBalance, insufficient startBalance, missing config fallback, undefined startBalance, missing upsertBalanceFields dep, and startBalance of 0. * fix: Address review findings for lazy balance initialization - Use canonical BalanceConfig and IBalanceUpdate types from @librechat/data-schemas instead of inline type definitions - Include auto-refill fields (autoRefillEnabled, refillIntervalValue, refillIntervalUnit, refillAmount, lastRefill) during lazy init, mirroring the login middleware's buildUpdateFields logic - Add try/catch around upsertBalanceFields with graceful fallback to canSpend: false on DB errors - Read balance from DB return value instead of raw startBalance constant - Fix misleading test names to describe observable throw behavior - Add tests: upsertBalanceFields rejection, auto-refill field inclusion, DB-returned balance value, and logViolation assertions * fix: Address second review pass findings - Fix import ordering: package type imports before local type imports - Remove misleading comment on DB-fallback test, rename for clarity - Add logViolation assertion to insufficient-balance lazy-init test - Add test for partial auto-refill config (autoRefillEnabled without required dependent fields) * refactor: Replace createMockReqRes factory with describe-scoped consts Replace zero-argument factory with plain const declarations using direct type casts instead of double-cast through unknown. * fix: Sort local type imports longest-first, add missing logViolation assertion - Reorder local type imports in spec file per AGENTS.md (longest to shortest within sub-group) - Add logViolation assertion to startBalance: 0 test for consistent violation payload coverage across all throw paths
This commit is contained in:
parent
4f37e8adb9
commit
fd01dfc083
5 changed files with 323 additions and 3 deletions
266
packages/api/src/middleware/checkBalance.spec.ts
Normal file
266
packages/api/src/middleware/checkBalance.spec.ts
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
import { ViolationTypes } from 'librechat-data-provider';
|
||||
import type { Response } from 'express';
|
||||
import type { CheckBalanceDeps } from './checkBalance';
|
||||
import type { ServerRequest } from '~/types/http';
|
||||
import { checkBalance } from './checkBalance';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('checkBalance', () => {
|
||||
const createMockDeps = (overrides: Partial<CheckBalanceDeps> = {}): CheckBalanceDeps => ({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue({ tokenCredits: 1000 }),
|
||||
getMultiplier: jest.fn().mockReturnValue(1),
|
||||
createAutoRefillTransaction: jest.fn(),
|
||||
logViolation: jest.fn().mockResolvedValue(undefined),
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const req = { user: { id: 'user-1' } } as ServerRequest;
|
||||
const res = {} as Response;
|
||||
|
||||
const baseTxData = {
|
||||
user: 'user-1',
|
||||
tokenType: 'prompt',
|
||||
amount: 100,
|
||||
endpoint: 'openAI',
|
||||
model: 'gpt-4',
|
||||
};
|
||||
|
||||
it('should return true when user has sufficient balance', async () => {
|
||||
const deps = createMockDeps();
|
||||
|
||||
const result = await checkBalance({ req, res, txData: baseTxData }, deps);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should throw when user has insufficient balance', async () => {
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue({ tokenCredits: 10 }),
|
||||
getMultiplier: jest.fn().mockReturnValue(1),
|
||||
});
|
||||
|
||||
await expect(
|
||||
checkBalance({ req, res, txData: { ...baseTxData, amount: 100 } }, deps),
|
||||
).rejects.toThrow();
|
||||
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 10, tokenCost: 100 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
describe('lazy balance initialization', () => {
|
||||
it('should create balance record when no record exists and startBalance is configured', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 5000 });
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
balanceConfig: { startBalance: 5000 },
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
const result = await checkBalance({ req, res, txData: baseTxData }, deps);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', {
|
||||
user: 'user-1',
|
||||
tokenCredits: 5000,
|
||||
});
|
||||
});
|
||||
|
||||
it('should include auto-refill fields when configured', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 5000 });
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
balanceConfig: {
|
||||
startBalance: 5000,
|
||||
autoRefillEnabled: true,
|
||||
refillIntervalValue: 1,
|
||||
refillIntervalUnit: 'days',
|
||||
refillAmount: 1000,
|
||||
},
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
await checkBalance({ req, res, txData: baseTxData }, deps);
|
||||
|
||||
expect(upsertBalanceFields).toHaveBeenCalledWith(
|
||||
'user-1',
|
||||
expect.objectContaining({
|
||||
user: 'user-1',
|
||||
tokenCredits: 5000,
|
||||
autoRefillEnabled: true,
|
||||
refillIntervalValue: 1,
|
||||
refillIntervalUnit: 'days',
|
||||
refillAmount: 1000,
|
||||
lastRefill: expect.any(Date),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not include auto-refill fields when config is partial', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 5000 });
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
balanceConfig: { startBalance: 5000, autoRefillEnabled: true },
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
await checkBalance({ req, res, txData: baseTxData }, deps);
|
||||
|
||||
expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', {
|
||||
user: 'user-1',
|
||||
tokenCredits: 5000,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw a TOKEN_BALANCE violation when lazy-initialized balance is less than token cost', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 50 });
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
getMultiplier: jest.fn().mockReturnValue(1),
|
||||
balanceConfig: { startBalance: 50 },
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
await expect(
|
||||
checkBalance({ req, res, txData: { ...baseTxData, amount: 100 } }, deps),
|
||||
).rejects.toThrow();
|
||||
|
||||
expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', {
|
||||
user: 'user-1',
|
||||
tokenCredits: 50,
|
||||
});
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 50, tokenCost: 100 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use DB-returned tokenCredits over raw startBalance config constant', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 3000 });
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
getMultiplier: jest.fn().mockReturnValue(1),
|
||||
balanceConfig: { startBalance: 5000 },
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
await expect(
|
||||
checkBalance({ req, res, txData: { ...baseTxData, amount: 4000 } }, deps),
|
||||
).rejects.toThrow();
|
||||
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 3000, tokenCost: 4000 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw a TOKEN_BALANCE violation when no record and no balanceConfig', async () => {
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
});
|
||||
|
||||
await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow();
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 0 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw a TOKEN_BALANCE violation when no record and startBalance is undefined', async () => {
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
balanceConfig: {},
|
||||
upsertBalanceFields: jest.fn(),
|
||||
});
|
||||
|
||||
await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow();
|
||||
expect(deps.upsertBalanceFields).not.toHaveBeenCalled();
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 0 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw a TOKEN_BALANCE violation when upsertBalanceFields is not provided', async () => {
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
balanceConfig: { startBalance: 5000 },
|
||||
});
|
||||
|
||||
await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow();
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 0 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle startBalance of 0', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockResolvedValue({ tokenCredits: 0 });
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
getMultiplier: jest.fn().mockReturnValue(1),
|
||||
balanceConfig: { startBalance: 0 },
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
await expect(
|
||||
checkBalance({ req, res, txData: { ...baseTxData, amount: 100 } }, deps),
|
||||
).rejects.toThrow();
|
||||
|
||||
expect(upsertBalanceFields).toHaveBeenCalledWith('user-1', {
|
||||
user: 'user-1',
|
||||
tokenCredits: 0,
|
||||
});
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 0, tokenCost: 100 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to balance: 0 when upsertBalanceFields rejects', async () => {
|
||||
const upsertBalanceFields = jest.fn().mockRejectedValue(new Error('DB unavailable'));
|
||||
const deps = createMockDeps({
|
||||
findBalanceByUser: jest.fn().mockResolvedValue(null),
|
||||
balanceConfig: { startBalance: 5000 },
|
||||
upsertBalanceFields,
|
||||
});
|
||||
|
||||
await expect(checkBalance({ req, res, txData: baseTxData }, deps)).rejects.toThrow();
|
||||
expect(deps.logViolation).toHaveBeenCalledWith(
|
||||
req,
|
||||
res,
|
||||
ViolationTypes.TOKEN_BALANCE,
|
||||
expect.objectContaining({ balance: 0 }),
|
||||
0,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { ViolationTypes } from 'librechat-data-provider';
|
||||
import type { ServerRequest } from '~/types/http';
|
||||
import type { BalanceConfig, IBalanceUpdate } from '@librechat/data-schemas';
|
||||
import type { Response } from 'express';
|
||||
import type { ServerRequest } from '~/types/http';
|
||||
|
||||
type TimeUnit = 'seconds' | 'minutes' | 'hours' | 'days' | 'weeks' | 'months';
|
||||
|
||||
|
|
@ -38,6 +39,10 @@ export interface CheckBalanceDeps {
|
|||
errorMessage: Record<string, unknown>,
|
||||
score: number,
|
||||
) => Promise<void>;
|
||||
/** Balance config for lazy initialization when no record exists */
|
||||
balanceConfig?: BalanceConfig;
|
||||
/** Upsert function for lazy initialization when no record exists */
|
||||
upsertBalanceFields?: (userId: string, fields: IBalanceUpdate) => Promise<BalanceRecord | null>;
|
||||
}
|
||||
|
||||
function addIntervalToDate(date: Date, value: number, unit: TimeUnit): Date {
|
||||
|
|
@ -84,6 +89,37 @@ async function checkBalanceRecord(
|
|||
|
||||
const record = await deps.findBalanceByUser(user);
|
||||
if (!record) {
|
||||
if (deps.balanceConfig?.startBalance != null && deps.upsertBalanceFields) {
|
||||
logger.debug('[Balance.check] Lazy-initializing balance record for user', {
|
||||
user,
|
||||
startBalance: deps.balanceConfig.startBalance,
|
||||
});
|
||||
try {
|
||||
const fields: IBalanceUpdate = {
|
||||
user,
|
||||
tokenCredits: deps.balanceConfig.startBalance,
|
||||
};
|
||||
const config = deps.balanceConfig;
|
||||
if (
|
||||
config.autoRefillEnabled &&
|
||||
config.refillIntervalValue != null &&
|
||||
config.refillIntervalUnit != null &&
|
||||
config.refillAmount != null
|
||||
) {
|
||||
fields.autoRefillEnabled = config.autoRefillEnabled;
|
||||
fields.refillIntervalValue = config.refillIntervalValue;
|
||||
fields.refillIntervalUnit = config.refillIntervalUnit;
|
||||
fields.refillAmount = config.refillAmount;
|
||||
fields.lastRefill = new Date();
|
||||
}
|
||||
const created = await deps.upsertBalanceFields(user, fields);
|
||||
const balance = created?.tokenCredits ?? deps.balanceConfig.startBalance;
|
||||
return { canSpend: balance >= tokenCost, balance, tokenCost };
|
||||
} catch (error) {
|
||||
logger.error('[Balance.check] Failed to lazy-initialize balance record', { user, error });
|
||||
return { canSpend: false, balance: 0, tokenCost };
|
||||
}
|
||||
}
|
||||
logger.debug('[Balance.check] No balance record found for user', { user });
|
||||
return { canSpend: false, balance: 0, tokenCost };
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue