mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
🛡️ feat: Model Validation Middleware (#1841)
* refactor: add ViolationTypes enum and add new violation for illegal model requests * feat: validateModel middleware to protect the backend against illicit requests for unlisted models
This commit is contained in:
parent
d8038e3b19
commit
a8a19c6caa
19 changed files with 539 additions and 377 deletions
|
|
@ -238,6 +238,8 @@ LIMIT_MESSAGE_USER=false
|
|||
MESSAGE_USER_MAX=40
|
||||
MESSAGE_USER_WINDOW=1
|
||||
|
||||
ILLEGAL_MODEL_REQ_SCORE=5
|
||||
|
||||
#========================#
|
||||
# Balance #
|
||||
#========================#
|
||||
|
|
|
|||
7
api/cache/getLogStores.js
vendored
7
api/cache/getLogStores.js
vendored
|
|
@ -1,5 +1,5 @@
|
|||
const Keyv = require('keyv');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const { math, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
|
|
@ -49,7 +49,10 @@ const namespaces = {
|
|||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance('token_balance'),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
[CacheKeys.FILE_UPLOAD_LIMIT]: createViolationInstance(CacheKeys.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||
),
|
||||
logins: createViolationInstance('logins'),
|
||||
[CacheKeys.ABORT_KEYS]: abortKeys,
|
||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ const checkBan = require('./checkBan');
|
|||
const uaParser = require('./uaParser');
|
||||
const setHeaders = require('./setHeaders');
|
||||
const loginLimiter = require('./loginLimiter');
|
||||
const validateModel = require('./validateModel');
|
||||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const uploadLimiters = require('./uploadLimiters');
|
||||
const registerLimiter = require('./registerLimiter');
|
||||
|
|
@ -32,6 +33,7 @@ module.exports = {
|
|||
validateMessageReq,
|
||||
buildEndpointOption,
|
||||
validateRegistration,
|
||||
validateModel,
|
||||
moderateText,
|
||||
noIndex,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const rateLimit = require('express-rate-limit');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
|
|
@ -35,7 +35,7 @@ const createFileUploadHandler = (ip = true) => {
|
|||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = CacheKeys.FILE_UPLOAD_LIMIT;
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? fileUploadIpMax : fileUploadUserMax,
|
||||
|
|
|
|||
50
api/server/middleware/validateModel.js
Normal file
50
api/server/middleware/validateModel.js
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
const { EModelEndpoint, CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const { handleError } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* Validates the model of the request.
|
||||
*
|
||||
* @async
|
||||
* @param {Express.Request} req - The Express request object.
|
||||
* @param {Express.Response} res - The Express response object.
|
||||
* @param {Function} next - The Express next function.
|
||||
*/
|
||||
const validateModel = async (req, res, next) => {
|
||||
const { model, endpoint } = req.body;
|
||||
if (!model) {
|
||||
return handleError(res, { text: 'Model not provided' });
|
||||
}
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
|
||||
if (!modelsConfig) {
|
||||
return handleError(res, { text: 'Models not loaded' });
|
||||
}
|
||||
|
||||
const availableModels = modelsConfig[endpoint];
|
||||
if (!availableModels) {
|
||||
return handleError(res, { text: 'Endpoint models not loaded' });
|
||||
}
|
||||
|
||||
let validModel = !!availableModels.find((availableModel) => availableModel === model);
|
||||
if (endpoint === EModelEndpoint.gptPlugins) {
|
||||
validModel = validModel && availableModels.includes(req.body.agentOptions?.model);
|
||||
}
|
||||
|
||||
if (validModel) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const { ILLEGAL_MODEL_REQ_SCORE: score = 5 } = process.env ?? {};
|
||||
|
||||
const type = ViolationTypes.ILLEGAL_MODEL_REQUEST;
|
||||
const errorMessage = {
|
||||
type,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
return handleError(res, { text: 'Illegal model request' });
|
||||
};
|
||||
|
||||
module.exports = validateModel;
|
||||
|
|
@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic');
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -12,8 +13,15 @@ const router = express.Router();
|
|||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI');
|
|||
const {
|
||||
handleAbort,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -13,8 +14,15 @@ const router = express.Router();
|
|||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google');
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -12,8 +13,15 @@ const router = express.Router();
|
|||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ const {
|
|||
createAbortController,
|
||||
handleAbortError,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
|
|
@ -20,7 +21,13 @@ const { logger } = require('~/config');
|
|||
router.use(moderateText);
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
|
|
@ -36,7 +43,10 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
|||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
model: endpointOption.modelOptions.model,
|
||||
});
|
||||
const newConvo = !conversationId;
|
||||
const user = req.user.id;
|
||||
|
||||
|
|
@ -221,6 +231,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
|||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const { addTitle, initializeClient } = require('~/server/services/Endpoints/open
|
|||
const {
|
||||
handleAbort,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
|
|
@ -13,8 +14,15 @@ const router = express.Router();
|
|||
router.use(moderateText);
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const router = express.Router();
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
handleAbortError,
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
|
|
@ -36,7 +37,7 @@ router.post('/abort', handleAbort());
|
|||
* @param {express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
router.post('/', buildEndpointOption, setHeaders, async (req, res) => {
|
||||
router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => {
|
||||
logger.debug('[/assistants/chat/] req.body', req.body);
|
||||
const {
|
||||
text,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic');
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -12,8 +13,15 @@ const router = express.Router();
|
|||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await EditController(req, res, next, initializeClient);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI');
|
|||
const {
|
||||
handleAbort,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -13,8 +14,15 @@ const router = express.Router();
|
|||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await EditController(req, res, next, initializeClient, addTitle);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google');
|
|||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -12,8 +13,15 @@ const router = express.Router();
|
|||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await EditController(req, res, next, initializeClient);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const {
|
|||
createAbortController,
|
||||
handleAbortError,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
|
|
@ -19,7 +20,13 @@ const { logger } = require('~/config');
|
|||
router.use(moderateText);
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res) => {
|
||||
let {
|
||||
text,
|
||||
generation,
|
||||
|
|
@ -43,7 +50,10 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
|||
let promptTokens;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
model: endpointOption.modelOptions.model,
|
||||
});
|
||||
const userMessageId = parentMessageId;
|
||||
const user = req.user.id;
|
||||
|
||||
|
|
@ -192,6 +202,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
|||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/openAI');
|
|||
const {
|
||||
handleAbort,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
|
|
@ -13,8 +14,15 @@ const router = express.Router();
|
|||
router.use(moderateText);
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await EditController(req, res, next, initializeClient);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
|
|
|||
|
|
@ -69,8 +69,13 @@ MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_M
|
|||
LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW
|
||||
MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW
|
||||
MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages
|
||||
|
||||
ILLEGAL_MODEL_REQ_SCORE=5 #Violation score to accrue if a user attempts to use an unlisted model.
|
||||
|
||||
```
|
||||
|
||||
> Note: Illegal model requests are almost always nefarious as it means a 3rd party is attempting to access the server through an automated script. For this, I recommend a relatively high score, no less than 5.
|
||||
|
||||
## OpenAI moderation text
|
||||
|
||||
### OPENAI_MODERATION
|
||||
|
|
|
|||
|
|
@ -602,8 +602,11 @@ REGISTRATION_VIOLATION_SCORE=1
|
|||
CONCURRENT_VIOLATION_SCORE=1
|
||||
MESSAGE_VIOLATION_SCORE=1
|
||||
NON_BROWSER_VIOLATION_SCORE=20
|
||||
ILLEGAL_MODEL_REQ_SCORE=5
|
||||
```
|
||||
|
||||
> Note: Non-browser access and Illegal model requests are almost always nefarious as it means a 3rd party is attempting to access the server through an automated script.
|
||||
|
||||
#### Login and registration rate limiting.
|
||||
- `LOGIN_MAX`: The max amount of logins allowed per IP per `LOGIN_WINDOW`
|
||||
- `LOGIN_WINDOW`: In minutes, determines the window of time for `LOGIN_MAX` logins
|
||||
|
|
|
|||
|
|
@ -284,10 +284,20 @@ export enum CacheKeys {
|
|||
* Key for the override config cache.
|
||||
*/
|
||||
OVERRIDE_CONFIG = 'overrideConfig',
|
||||
}
|
||||
|
||||
/**
|
||||
* Key for accessing File Upload Violations (exceeding limit).
|
||||
* Enum for violation types, used to identify, log, and cache violations.
|
||||
*/
|
||||
export enum ViolationTypes {
|
||||
/**
|
||||
* File Upload Violations (exceeding limit).
|
||||
*/
|
||||
FILE_UPLOAD_LIMIT = 'file_upload_limit',
|
||||
/**
|
||||
* Illegal Model Request (not available).
|
||||
*/
|
||||
ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue