mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-19 09:50:15 +01:00
- Updated import paths for connectDb across various files to use the new centralized connect module. - Removed the old connectDb file to streamline the database connection logic. - Ensured all tests and models reference the new connection method for consistency.
177 lines
5.3 KiB
JavaScript
177 lines
5.3 KiB
JavaScript
const banViolation = require('./banViolation');
|
|
|
|
/**
 * Stand-in for the `models` object exposed by the mocked `~/db/connect` module.
 * The name keeps its `mock` prefix because it is referenced from inside a
 * `jest.mock` factory.
 */
const mockModels = {
  Session: { deleteAllUserSessions: jest.fn() },
};
|
|
|
|
// Replace the database connection module: `connectDb` becomes a bare spy and
// `models` is exposed through a getter, so `mockModels` is only read when the
// property is accessed rather than when the factory runs.
jest.mock('~/db/connect', () => ({
  connectDb: jest.fn(),
  get models() {
    return mockModels;
  },
}));
|
|
|
|
// Stub the server utilities with fixed results; override per test if needed.
jest.mock('~/server/utils', () => {
  return {
    // Always returns true by default.
    isEnabled: jest.fn(() => true),
    // Always returns 20 by default, whatever arguments it receives.
    math: jest.fn(() => 20),
    // Bare spy; returns undefined.
    removePorts: jest.fn(),
  };
});
|
|
|
|
jest.mock('keyv');
|
|
// jest.mock('../models/Session');
|
|
// Mock `./getLogStores` with a jest function that, on every call, builds and
// returns a minimal in-memory stand-in for a Keyv Mongo store.
jest.mock('./getLogStores', () =>
  jest.fn().mockImplementation(() => {
    // Modules are required inside the implementation, i.e. each time the
    // mocked function is called.
    const EventEmitter = require('events');
    const { CacheKeys } = require('librechat-data-provider');
    const math = require('../server/utils/math');

    // Fresh spies per store instance.
    const getSpy = jest.fn();
    const setSpy = jest.fn();

    class KeyvMongo extends EventEmitter {
      constructor(url = 'mongodb://127.0.0.1:27017', options) {
        super();
        this.get = getSpy;
        this.set = setSpy;
        this.ttlSupport = false;

        // Normalise the first argument into an options-like object.
        let settings = url ?? {};
        if (typeof settings === 'string') {
          settings = { url: settings };
        }
        if (settings.uri) {
          settings = { url: settings.uri, ...settings };
        }

        // Later spreads win: keys from `settings` override the defaults, and
        // keys from `options` override both.
        this.opts = { url: settings, collection: 'keyv', ...settings, ...options };
      }
    }

    // TTL comes from BAN_DURATION, with 7200000 passed as the second argument to `math`.
    const ttl = math(process.env.BAN_DURATION, 7200000);
    return new KeyvMongo('', { namespace: CacheKeys.BANS, ttl });
  }),
);
|
|
|
|
/**
 * Unit tests for `banViolation(req, res, errorMessage)`.
 *
 * Each test starts from a fresh request/response pair and an `errorMessage`
 * whose counts are both 0, with BAN_VIOLATIONS enabled, a 2 hour BAN_DURATION
 * and a BAN_INTERVAL of 20. A ban is observed through `errorMessage.ban`.
 */
describe('banViolation', () => {
  // The mocked utilities module; used to override `isEnabled` for one test.
  const { isEnabled } = require('~/server/utils');

  let req, res, errorMessage;

  beforeEach(() => {
    req = {
      ip: '127.0.0.1',
      cookies: {
        refreshToken: 'someToken',
      },
    };
    res = {
      clearCookie: jest.fn(),
    };
    errorMessage = {
      type: 'someViolation',
      user_id: '12345',
      prev_count: 0,
      violation_count: 0,
    };
    process.env.BAN_VIOLATIONS = 'true';
    process.env.BAN_DURATION = '7200000'; // 2 hours in ms
    process.env.BAN_INTERVAL = '20';
  });

  afterEach(() => {
    // Clears recorded calls/results only; mock implementations are kept.
    jest.clearAllMocks();
  });

  it('should not ban if BAN_VIOLATIONS are not enabled', async () => {
    // `isEnabled` is stubbed to return true by default, so it must be
    // overridden here for the disabled path to be exercised at all.
    isEnabled.mockReturnValueOnce(false);
    process.env.BAN_VIOLATIONS = 'false';
    // Counts that would otherwise cross the threshold and trigger a ban.
    errorMessage.prev_count = 19;
    errorMessage.violation_count = 39;
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeFalsy();
  });

  it('should not ban if errorMessage is not provided', async () => {
    await banViolation(req, res, null);
    // The fixture below was never handed to banViolation, so also check that
    // no ban side effect reached the response.
    expect(res.clearCookie).not.toHaveBeenCalled();
    expect(errorMessage.ban).toBeFalsy();
  });

  // Fixed boundary and representative values instead of random ones, so that
  // test titles and any failure are reproducible between runs.
  it.each([20, 39, 40, 119])(
    'should ban if violation_count crosses the interval threshold: 19 -> %i',
    async (violationCount) => {
      errorMessage.prev_count = 19;
      errorMessage.violation_count = violationCount;
      await banViolation(req, res, errorMessage);
      expect(errorMessage.ban).toBeTruthy();
    },
  );

  it('should handle invalid BAN_INTERVAL and default to 20', async () => {
    // NOTE(review): `math` from `~/server/utils` is stubbed to always return
    // 20, so this presumably exercises the stub rather than the real fallback
    // - confirm against banViolation's imports.
    process.env.BAN_INTERVAL = 'invalid';
    errorMessage.prev_count = 19;
    errorMessage.violation_count = 39;
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeTruthy();
  });

  it('should ban if BAN_DURATION is invalid as default is 2 hours', async () => {
    process.env.BAN_DURATION = 'invalid';
    errorMessage.prev_count = 19;
    errorMessage.violation_count = 39;
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeTruthy();
  });

  it('should not ban if BAN_DURATION is 0 but should clear cookies', async () => {
    process.env.BAN_DURATION = '0';
    errorMessage.prev_count = 19;
    errorMessage.violation_count = 39;
    await banViolation(req, res, errorMessage);
    expect(res.clearCookie).toHaveBeenCalledWith('refreshToken');
    // The title promises "no ban" as well, so assert it explicitly.
    expect(errorMessage.ban).toBeFalsy();
  });

  it('should not ban if violation_count does not change', async () => {
    errorMessage.prev_count = 0;
    errorMessage.violation_count = 0;
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeFalsy();
  });

  // Deterministic values spanning the range below the first threshold (1..19).
  it.each([1, 10, 19])(
    'should not ban if violation_count does not cross the interval threshold: 0 -> %i',
    async (violationCount) => {
      errorMessage.prev_count = 0;
      errorMessage.violation_count = violationCount;
      await banViolation(req, res, errorMessage);
      expect(errorMessage.ban).toBeFalsy();
    },
  );

  it('[EDGE CASE] should not ban if violation_count is lower', async () => {
    errorMessage.prev_count = 0;
    errorMessage.violation_count = -10;
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeFalsy();
  });
});
|