🛡️ feat: Model Validation Middleware (#1841)

* refactor: add ViolationTypes enum and add new violation for illegal model requests

* feat: validateModel middleware to protect the backend against illicit requests for unlisted models
This commit is contained in:
Danny Avila 2024-02-19 22:47:39 -05:00 committed by GitHub
parent d8038e3b19
commit a8a19c6caa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 539 additions and 377 deletions

View file

@ -238,6 +238,8 @@ LIMIT_MESSAGE_USER=false
MESSAGE_USER_MAX=40 MESSAGE_USER_MAX=40
MESSAGE_USER_WINDOW=1 MESSAGE_USER_WINDOW=1
ILLEGAL_MODEL_REQ_SCORE=5
#========================# #========================#
# Balance # # Balance #
#========================# #========================#

View file

@ -1,5 +1,5 @@
const Keyv = require('keyv'); const Keyv = require('keyv');
const { CacheKeys } = require('librechat-data-provider'); const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { logFile, violationFile } = require('./keyvFiles'); const { logFile, violationFile } = require('./keyvFiles');
const { math, isEnabled } = require('~/server/utils'); const { math, isEnabled } = require('~/server/utils');
const keyvRedis = require('./keyvRedis'); const keyvRedis = require('./keyvRedis');
@ -49,7 +49,10 @@ const namespaces = {
message_limit: createViolationInstance('message_limit'), message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance('token_balance'), token_balance: createViolationInstance('token_balance'),
registrations: createViolationInstance('registrations'), registrations: createViolationInstance('registrations'),
[CacheKeys.FILE_UPLOAD_LIMIT]: createViolationInstance(CacheKeys.FILE_UPLOAD_LIMIT), [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
logins: createViolationInstance('logins'), logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys, [CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig, [CacheKeys.TOKEN_CONFIG]: tokenConfig,

View file

@ -3,6 +3,7 @@ const checkBan = require('./checkBan');
const uaParser = require('./uaParser'); const uaParser = require('./uaParser');
const setHeaders = require('./setHeaders'); const setHeaders = require('./setHeaders');
const loginLimiter = require('./loginLimiter'); const loginLimiter = require('./loginLimiter');
const validateModel = require('./validateModel');
const requireJwtAuth = require('./requireJwtAuth'); const requireJwtAuth = require('./requireJwtAuth');
const uploadLimiters = require('./uploadLimiters'); const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter'); const registerLimiter = require('./registerLimiter');
@ -32,6 +33,7 @@ module.exports = {
validateMessageReq, validateMessageReq,
buildEndpointOption, buildEndpointOption,
validateRegistration, validateRegistration,
validateModel,
moderateText, moderateText,
noIndex, noIndex,
}; };

View file

@ -1,5 +1,5 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { CacheKeys } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
@ -35,7 +35,7 @@ const createFileUploadHandler = (ip = true) => {
} = getEnvironmentVariables(); } = getEnvironmentVariables();
return async (req, res) => { return async (req, res) => {
const type = CacheKeys.FILE_UPLOAD_LIMIT; const type = ViolationTypes.FILE_UPLOAD_LIMIT;
const errorMessage = { const errorMessage = {
type, type,
max: ip ? fileUploadIpMax : fileUploadUserMax, max: ip ? fileUploadIpMax : fileUploadUserMax,

View file

@ -0,0 +1,50 @@
const { EModelEndpoint, CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { logViolation, getLogStores } = require('~/cache');
const { handleError } = require('~/server/utils');

/**
 * Middleware that rejects requests for models not present in the cached models
 * configuration, logging a violation for each illegal attempt.
 *
 * @async
 * @param {Express.Request} req - The Express request object; reads `model`,
 *   `endpoint`, and (for the plugins endpoint) `agentOptions.model` from `req.body`.
 * @param {Express.Response} res - The Express response object.
 * @param {Function} next - The Express next function; called only when the
 *   requested model is listed for the given endpoint.
 */
const validateModel = async (req, res, next) => {
  const { model, endpoint } = req.body;
  if (!model) {
    return handleError(res, { text: 'Model not provided' });
  }

  const cache = getLogStores(CacheKeys.CONFIG_STORE);
  const modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
  if (!modelsConfig) {
    return handleError(res, { text: 'Models not loaded' });
  }

  const availableModels = modelsConfig[endpoint];
  if (!availableModels) {
    return handleError(res, { text: 'Endpoint models not loaded' });
  }

  // `includes` replaces the original `!!find(...)` — same result, clearer intent.
  let validModel = availableModels.includes(model);

  if (endpoint === EModelEndpoint.gptPlugins) {
    // Plugins also use a secondary agent model, which must be listed as well.
    // NOTE(review): if `agentOptions` is absent this rejects the request —
    // confirm the client always sends it for the plugins endpoint.
    validModel = validModel && availableModels.includes(req.body.agentOptions?.model);
  }

  if (validModel) {
    return next();
  }

  // Environment variables are always strings; coerce the score to a number so
  // downstream accounting does not receive text. Falls back to 5 when the
  // variable is unset or malformed.
  const parsedScore = Number(process.env.ILLEGAL_MODEL_REQ_SCORE);
  const score = Number.isFinite(parsedScore) ? parsedScore : 5;

  const type = ViolationTypes.ILLEGAL_MODEL_REQUEST;
  await logViolation(req, res, type, { type }, score);
  return handleError(res, { text: 'Illegal model request' });
};

module.exports = validateModel;

View file

@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const { const {
setHeaders, setHeaders,
handleAbort, handleAbort,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
@ -12,8 +13,15 @@ const router = express.Router();
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await AskController(req, res, next, initializeClient); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
},
);
module.exports = router; module.exports = router;

View file

@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI');
const { const {
handleAbort, handleAbort,
setHeaders, setHeaders,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
@ -13,8 +14,15 @@ const router = express.Router();
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await AskController(req, res, next, initializeClient, addTitle); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router; module.exports = router;

View file

@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google');
const { const {
setHeaders, setHeaders,
handleAbort, handleAbort,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
@ -12,8 +13,15 @@ const router = express.Router();
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await AskController(req, res, next, initializeClient); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
},
);
module.exports = router; module.exports = router;

View file

@ -11,6 +11,7 @@ const {
createAbortController, createAbortController,
handleAbortError, handleAbortError,
setHeaders, setHeaders,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
moderateText, moderateText,
@ -20,207 +21,217 @@ const { logger } = require('~/config');
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { router.post(
let { '/',
text, validateEndpoint,
endpointOption, validateModel,
conversationId, buildEndpointOption,
parentMessageId = null, setHeaders,
overrideParentMessageId = null, async (req, res) => {
} = req.body; let {
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); text,
let metadata; endpointOption,
let userMessage; conversationId,
let promptTokens; parentMessageId = null,
let userMessageId; overrideParentMessageId = null,
let responseMessageId; } = req.body;
let lastSavedTimestamp = 0; logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let saveDelay = 100; let metadata;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model }); let userMessage;
const newConvo = !conversationId; let promptTokens;
const user = req.user.id; let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const newConvo = !conversationId;
const user = req.user.id;
const plugins = []; const plugins = [];
const addMetadata = (data) => (metadata = data); const addMetadata = (data) => (metadata = data);
const getReqData = (data = {}) => { const getReqData = (data = {}) => {
for (let key in data) { for (let key in data) {
if (key === 'userMessage') { if (key === 'userMessage') {
userMessage = data[key]; userMessage = data[key];
userMessageId = data[key].messageId; userMessageId = data[key].messageId;
} else if (key === 'responseMessageId') { } else if (key === 'responseMessageId') {
responseMessageId = data[key]; responseMessageId = data[key];
} else if (key === 'promptTokens') { } else if (key === 'promptTokens') {
promptTokens = data[key]; promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') { } else if (!conversationId && key === 'conversationId') {
conversationId = data[key]; conversationId = data[key];
}
} }
}
};
let streaming = null;
let timer = null;
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (timer) {
clearTimeout(timer);
}
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
error: false,
plugins,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
resolve();
}, 250);
});
},
});
const pluginMap = new Map();
const onAgentAction = async (action, runId) => {
pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, { plugins });
};
const onToolStart = async (tool, input, runId, parentRunId) => {
const pluginName = pluginMap.get(parentRunId);
const latestPlugin = {
runId,
loading: true,
inputs: [input],
latest: pluginName,
outputs: null,
}; };
if (streaming) { let streaming = null;
await streaming; let timer = null;
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(res, { plugins }, extraTokens);
};
const onToolEnd = async (output, runId) => { const {
if (streaming) { onProgress: progressCallback,
await streaming; sendIntermediateMessage,
}
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
if (pluginIndex !== -1) {
plugins[pluginIndex].loading = false;
plugins[pluginIndex].outputs = output;
}
};
const onChainEnd = () => {
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, { plugins });
};
const getAbortData = () => ({
sender,
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugins: plugins.map((p) => ({ ...p, loading: false })),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
onStart,
addMetadata,
getPartialText, getPartialText,
...endpointOption, } = createOnProgress({
onProgress: progressCallback.call(null, { onProgress: ({ text: partialText }) => {
res, const currentTimestamp = Date.now();
text,
parentMessageId: overrideParentMessageId || userMessageId, if (timer) {
plugins, clearTimeout(timer);
}), }
abortController,
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
error: false,
plugins,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
resolve();
}, 250);
});
},
}); });
if (overrideParentMessageId) { const pluginMap = new Map();
response.parentMessageId = overrideParentMessageId; const onAgentAction = async (action, runId) => {
} pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, { plugins });
};
if (metadata) { const onToolStart = async (tool, input, runId, parentRunId) => {
response = { ...response, ...metadata }; const pluginName = pluginMap.get(parentRunId);
} const latestPlugin = {
runId,
loading: true,
inputs: [input],
latest: pluginName,
outputs: null,
};
logger.debug('[/ask/gptPlugins]', response); if (streaming) {
await streaming;
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(res, { plugins }, extraTokens);
};
response.plugins = plugins.map((p) => ({ ...p, loading: false })); const onToolEnd = async (output, runId) => {
await saveMessage({ ...response, user }); if (streaming) {
await streaming;
}
sendMessage(res, { const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
title: await getConvoTitle(user, conversationId),
final: true, if (pluginIndex !== -1) {
conversation: await getConvo(user, conversationId), plugins[pluginIndex].loading = false;
requestMessage: userMessage, plugins[pluginIndex].outputs = output;
responseMessage: response, }
};
const onChainEnd = () => {
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, { plugins });
};
const getAbortData = () => ({
sender,
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugins: plugins.map((p) => ({ ...p, loading: false })),
userMessage,
promptTokens,
}); });
res.end(); const { abortController, onStart } = createAbortController(req, res, getAbortData);
if (parentMessageId === Constants.NO_PARENT && newConvo) { try {
addTitle(req, { endpointOption.tools = await validateTools(user, endpointOption.tools);
text, const { client } = await initializeClient({ req, res, endpointOption });
response,
client, let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
onStart,
addMetadata,
getPartialText,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
plugins,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
logger.debug('[/ask/gptPlugins]', response);
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response,
client,
});
}
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
}); });
} }
} catch (error) { },
const partialText = getPartialText(); );
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router; module.exports = router;

View file

@ -4,6 +4,7 @@ const { addTitle, initializeClient } = require('~/server/services/Endpoints/open
const { const {
handleAbort, handleAbort,
setHeaders, setHeaders,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
moderateText, moderateText,
@ -13,8 +14,15 @@ const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await AskController(req, res, next, initializeClient, addTitle); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router; module.exports = router;

View file

@ -21,6 +21,7 @@ const router = express.Router();
const { const {
setHeaders, setHeaders,
handleAbort, handleAbort,
validateModel,
handleAbortError, handleAbortError,
// validateEndpoint, // validateEndpoint,
buildEndpointOption, buildEndpointOption,
@ -36,7 +37,7 @@ router.post('/abort', handleAbort());
* @param {express.Response} res - The response object, used to send back a response. * @param {express.Response} res - The response object, used to send back a response.
* @returns {void} * @returns {void}
*/ */
router.post('/', buildEndpointOption, setHeaders, async (req, res) => { router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => {
logger.debug('[/assistants/chat/] req.body', req.body); logger.debug('[/assistants/chat/] req.body', req.body);
const { const {
text, text,

View file

@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const { const {
setHeaders, setHeaders,
handleAbort, handleAbort,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
@ -12,8 +13,15 @@ const router = express.Router();
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await EditController(req, res, next, initializeClient); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await EditController(req, res, next, initializeClient);
},
);
module.exports = router; module.exports = router;

View file

@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI');
const { const {
handleAbort, handleAbort,
setHeaders, setHeaders,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
@ -13,8 +14,15 @@ const router = express.Router();
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await EditController(req, res, next, initializeClient, addTitle); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await EditController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router; module.exports = router;

View file

@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google');
const { const {
setHeaders, setHeaders,
handleAbort, handleAbort,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
@ -12,8 +13,15 @@ const router = express.Router();
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await EditController(req, res, next, initializeClient); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await EditController(req, res, next, initializeClient);
},
);
module.exports = router; module.exports = router;

View file

@ -10,6 +10,7 @@ const {
createAbortController, createAbortController,
handleAbortError, handleAbortError,
setHeaders, setHeaders,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
moderateText, moderateText,
@ -19,179 +20,189 @@ const { logger } = require('~/config');
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { router.post(
let { '/',
text, validateEndpoint,
generation, validateModel,
endpointOption, buildEndpointOption,
conversationId, setHeaders,
responseMessageId, async (req, res) => {
isContinued = false, let {
parentMessageId = null, text,
overrideParentMessageId = null, generation,
} = req.body; endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
logger.debug('[/edit/gptPlugins]', { logger.debug('[/edit/gptPlugins]', {
text, text,
generation,
isContinued,
conversationId,
...endpointOption,
});
let metadata;
let userMessage;
let promptTokens;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model });
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const addMetadata = (data) => (metadata = data);
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (plugin.loading === true) {
plugin.loading = false;
}
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
isEdited: true,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, { plugin });
// logger.debug('PLUGIN ACTION', formattedAction);
};
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, { plugin });
// logger.debug('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender,
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
user,
generation, generation,
isContinued, isContinued,
isEdited: true,
conversationId, conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onStart,
addMetadata,
...endpointOption, ...endpointOption,
onProgress: progressCallback.call(null, { });
res, let metadata;
text, let userMessage;
plugin, let promptTokens;
parentMessageId: overrideParentMessageId || userMessageId, let lastSavedTimestamp = 0;
}), let saveDelay = 100;
abortController, const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const addMetadata = (data) => (metadata = data);
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (plugin.loading === true) {
plugin.loading = false;
}
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
isEdited: true,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
}); });
if (overrideParentMessageId) { const onAgentAction = (action, start = false) => {
response.parentMessageId = overrideParentMessageId; const formattedAction = formatAction(action);
} plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, { plugin });
// logger.debug('PLUGIN ACTION', formattedAction);
};
if (metadata) { const onChainEnd = (data) => {
response = { ...response, ...metadata }; let { intermediateSteps: steps } = data;
} plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, { plugin });
// logger.debug('CHAIN END', plugin.outputs);
};
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); const getAbortData = () => ({
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender, sender,
conversationId,
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId, parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
promptTokens,
}); });
} const { abortController, onStart } = createAbortController(req, res, getAbortData);
});
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onStart,
addMetadata,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
plugin,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
},
);
module.exports = router; module.exports = router;

View file

@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/openAI');
const { const {
handleAbort, handleAbort,
setHeaders, setHeaders,
validateModel,
validateEndpoint, validateEndpoint,
buildEndpointOption, buildEndpointOption,
moderateText, moderateText,
@ -13,8 +14,15 @@ const router = express.Router();
router.use(moderateText); router.use(moderateText);
router.post('/abort', handleAbort()); router.post('/abort', handleAbort());
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { router.post(
await EditController(req, res, next, initializeClient); '/',
}); validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await EditController(req, res, next, initializeClient);
},
);
module.exports = router; module.exports = router;

View file

@ -69,8 +69,13 @@ MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_M
LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW
MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW
MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages
ILLEGAL_MODEL_REQ_SCORE=5 # Violation score to accrue if a user attempts to use an unlisted model.
``` ```
> Note: Illegal model requests are almost always nefarious as it means a 3rd party is attempting to access the server through an automated script. For this, I recommend a relatively high score, no less than 5.
## OpenAI moderation text ## OpenAI moderation text
### OPENAI_MODERATION ### OPENAI_MODERATION

View file

@ -602,8 +602,11 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1 CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1 MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20 NON_BROWSER_VIOLATION_SCORE=20
ILLEGAL_MODEL_REQ_SCORE=5
``` ```
> Note: Non-browser access and illegal model requests are almost always nefarious, as they mean a third party is attempting to access the server through an automated script.
#### Login and registration rate limiting. #### Login and registration rate limiting.
- `LOGIN_MAX`: The max amount of logins allowed per IP per `LOGIN_WINDOW` - `LOGIN_MAX`: The max amount of logins allowed per IP per `LOGIN_WINDOW`
- `LOGIN_WINDOW`: In minutes, determines the window of time for `LOGIN_MAX` logins - `LOGIN_WINDOW`: In minutes, determines the window of time for `LOGIN_MAX` logins

View file

@ -284,10 +284,20 @@ export enum CacheKeys {
* Key for the override config cache. * Key for the override config cache.
*/ */
OVERRIDE_CONFIG = 'overrideConfig', OVERRIDE_CONFIG = 'overrideConfig',
}
/**
* Enum for violation types, used to identify, log, and cache violations.
*/
export enum ViolationTypes {
/** /**
* Key for accessing File Upload Violations (exceeding limit). * File Upload Violations (exceeding limit).
*/ */
FILE_UPLOAD_LIMIT = 'file_upload_limit', FILE_UPLOAD_LIMIT = 'file_upload_limit',
/**
* Illegal Model Request (not available).
*/
ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
} }
/** /**