mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-31 23:58:50 +01:00
🔐 fix: Enhance Message & Image Access Security (#3363)
* chore: slight refactor * fix: prevent message updates unless explicitly owned * refactor: rethrow errors, update deleteMessagesSince (not used), add basic tests * fix: Add path normalization and validation to image request middleware * fix: image validation path security
This commit is contained in:
parent
0a1d38e318
commit
d5d188eebf
17 changed files with 595 additions and 229 deletions
|
|
@ -55,7 +55,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
onProgress: throttle(
|
||||
({ text: partialText }) => {
|
||||
saveMessage({
|
||||
saveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
|
|
@ -144,11 +144,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(req, { ...response, user });
|
||||
}
|
||||
|
||||
if (!client.skipSaveUserMessage) {
|
||||
await saveMessage(userMessage);
|
||||
await saveMessage(req, userMessage);
|
||||
}
|
||||
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
generation,
|
||||
onProgress: throttle(
|
||||
({ text: partialText }) => {
|
||||
saveMessage({
|
||||
saveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
|
|
@ -141,7 +141,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(req, { ...response, user });
|
||||
}
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
|
|
|
|||
|
|
@ -120,21 +120,22 @@ const chatV1 = async (req, res) => {
|
|||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(res, messageData, errorMessage);
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
req,
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(res, messageData, error.message);
|
||||
return sendResponse(req, res, messageData, error.message);
|
||||
} else {
|
||||
logger.error('[/assistants/chat/]', error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(res, messageData, defaultErrorMessage);
|
||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await sleep(2000);
|
||||
|
|
@ -221,10 +222,10 @@ const chatV1 = async (req, res) => {
|
|||
};
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error finalizing error process', error);
|
||||
return sendResponse(res, messageData, 'The Assistant run failed');
|
||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(res, finalEvent);
|
||||
return sendResponse(req, res, finalEvent);
|
||||
};
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -117,21 +117,22 @@ const chatV2 = async (req, res) => {
|
|||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(res, messageData, errorMessage);
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
req,
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(res, messageData, error.message);
|
||||
return sendResponse(req, res, messageData, error.message);
|
||||
} else {
|
||||
logger.error('[/assistants/chat/]', error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(res, messageData, defaultErrorMessage);
|
||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await sleep(2000);
|
||||
|
|
@ -218,10 +219,10 @@ const chatV2 = async (req, res) => {
|
|||
};
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error finalizing error process', error);
|
||||
return sendResponse(res, messageData, 'The Assistant run failed');
|
||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(res, finalEvent);
|
||||
return sendResponse(req, res, finalEvent);
|
||||
};
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
|||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
saveMessage(req, { ...responseMessage, user });
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
|
|
@ -190,7 +190,7 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
}
|
||||
};
|
||||
|
||||
await sendError(res, options, callback);
|
||||
await sendError(req, res, options, callback);
|
||||
};
|
||||
|
||||
if (partialText && partialText.length > 5) {
|
||||
|
|
|
|||
|
|
@ -41,10 +41,10 @@ const denyRequest = async (req, res, errorMessage) => {
|
|||
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
|
||||
|
||||
if (shouldSaveMessage) {
|
||||
await saveMessage({ ...userMessage, user: req.user.id });
|
||||
await saveMessage(req, { ...userMessage, user: req.user.id });
|
||||
}
|
||||
|
||||
return await sendError(res, {
|
||||
return await sendError(req, res, {
|
||||
sender: getResponseSender(req.body),
|
||||
messageId: crypto.randomUUID(),
|
||||
conversationId,
|
||||
|
|
|
|||
|
|
@ -31,10 +31,14 @@ function validateImageRequest(req, res, next) {
|
|||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (req.path.includes(payload.id)) {
|
||||
const fullPath = decodeURIComponent(req.originalUrl);
|
||||
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ router.post('/', setHeaders, async (req, res) => {
|
|||
});
|
||||
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({ ...userMessage, user: req.user.id });
|
||||
await saveMessage(req, { ...userMessage, user: req.user.id });
|
||||
await saveConvo(req.user.id, {
|
||||
...userMessage,
|
||||
...endpointOption,
|
||||
|
|
@ -93,7 +93,7 @@ const ask = async ({
|
|||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
saveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
|
||||
conversationId,
|
||||
|
|
@ -159,7 +159,7 @@ const ask = async ({
|
|||
isCreatedByUser: false,
|
||||
};
|
||||
|
||||
await saveMessage({ ...responseMessage, user });
|
||||
await saveMessage(req, { ...responseMessage, user });
|
||||
responseMessage.messageId = newResponseMessageId;
|
||||
|
||||
// STEP2 update the conversation
|
||||
|
|
@ -192,7 +192,7 @@ const ask = async ({
|
|||
|
||||
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({
|
||||
await saveMessage(req, {
|
||||
...userMessage,
|
||||
user,
|
||||
messageId: userMessageId,
|
||||
|
|
@ -229,7 +229,7 @@ const ask = async ({
|
|||
isCreatedByUser: false,
|
||||
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
|
||||
};
|
||||
await saveMessage({ ...errorMessage, user });
|
||||
await saveMessage(req, { ...errorMessage, user });
|
||||
handleError(res, errorMessage);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ router.post('/', setHeaders, async (req, res) => {
|
|||
});
|
||||
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({ ...userMessage, user: req.user.id });
|
||||
await saveMessage(req, { ...userMessage, user: req.user.id });
|
||||
await saveConvo(req.user.id, {
|
||||
...userMessage,
|
||||
...endpointOption,
|
||||
|
|
@ -118,7 +118,7 @@ const ask = async ({
|
|||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
saveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender: model,
|
||||
conversationId,
|
||||
|
|
@ -197,7 +197,7 @@ const ask = async ({
|
|||
isCreatedByUser: false,
|
||||
};
|
||||
|
||||
await saveMessage({ ...responseMessage, user });
|
||||
await saveMessage(req, { ...responseMessage, user });
|
||||
responseMessage.messageId = newResponseMessageId;
|
||||
|
||||
let conversationUpdate = {
|
||||
|
|
@ -221,7 +221,7 @@ const ask = async ({
|
|||
|
||||
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({
|
||||
await saveMessage(req, {
|
||||
...userMessage,
|
||||
user,
|
||||
messageId: userMessageId,
|
||||
|
|
@ -266,7 +266,7 @@ const ask = async ({
|
|||
isCreatedByUser: false,
|
||||
};
|
||||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
saveMessage(req, { ...responseMessage, user });
|
||||
|
||||
return {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
|
|
@ -288,7 +288,7 @@ const ask = async ({
|
|||
model,
|
||||
isCreatedByUser: false,
|
||||
};
|
||||
await saveMessage({ ...errorMessage, user });
|
||||
await saveMessage(req, { ...errorMessage, user });
|
||||
handleError(res, errorMessage);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ router.post(
|
|||
clearTimeout(timer);
|
||||
}
|
||||
|
||||
throttledSaveMessage({
|
||||
throttledSaveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
|
|
@ -170,7 +170,7 @@ router.post(
|
|||
|
||||
const onChainEnd = () => {
|
||||
if (!client.skipSaveUserMessage) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
saveMessage(req, { ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
|
|
@ -208,7 +208,7 @@ router.post(
|
|||
logger.debug('[/ask/gptPlugins]', response);
|
||||
|
||||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(req, { ...response, user });
|
||||
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ router.post(
|
|||
plugin.loading = false;
|
||||
}
|
||||
|
||||
throttledSaveMessage({
|
||||
throttledSaveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
|
|
@ -110,7 +110,7 @@ router.post(
|
|||
let { intermediateSteps: steps } = data;
|
||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||
plugin.loading = false;
|
||||
saveMessage({ ...userMessage, user });
|
||||
saveMessage(req, { ...userMessage, user });
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
|
|
@ -141,7 +141,7 @@ router.post(
|
|||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start && !client.skipSaveUserMessage) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
saveMessage(req, { ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
|
|
@ -180,7 +180,7 @@ router.post(
|
|||
|
||||
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
|
||||
response.plugin = { ...plugin, loading: false };
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(req, { ...response, user });
|
||||
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
|
|
|
|||
|
|
@ -1,46 +1,42 @@
|
|||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const {
|
||||
getMessages,
|
||||
updateMessage,
|
||||
saveConvo,
|
||||
saveMessage,
|
||||
deleteMessages,
|
||||
} = require('../../models');
|
||||
const { countTokens } = require('../utils');
|
||||
const { requireJwtAuth, validateMessageReq } = require('../middleware/');
|
||||
const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models');
|
||||
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(validateMessageReq);
|
||||
|
||||
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
router.get('/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
|
||||
});
|
||||
|
||||
// CREATE
|
||||
router.post('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
router.post('/:conversationId', async (req, res) => {
|
||||
const message = req.body;
|
||||
const savedMessage = await saveMessage({ ...message, user: req.user.id });
|
||||
const savedMessage = await saveMessage(req, { ...message, user: req.user.id });
|
||||
await saveConvo(req.user.id, savedMessage);
|
||||
res.status(201).send(savedMessage);
|
||||
});
|
||||
|
||||
// READ
|
||||
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
router.get('/:conversationId/:messageId', async (req, res) => {
|
||||
const { conversationId, messageId } = req.params;
|
||||
res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
|
||||
});
|
||||
|
||||
// UPDATE
|
||||
router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
router.put('/:conversationId/:messageId', async (req, res) => {
|
||||
const { messageId, model } = req.params;
|
||||
const { text } = req.body;
|
||||
const tokenCount = await countTokens(text, model);
|
||||
res.status(201).json(await updateMessage({ messageId, text, tokenCount }));
|
||||
const result = await updateMessage(req, { messageId, text, tokenCount });
|
||||
res.status(201).json(result);
|
||||
});
|
||||
|
||||
// DELETE
|
||||
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
router.delete('/:conversationId/:messageId', async (req, res) => {
|
||||
const { messageId } = req.params;
|
||||
await deleteMessages({ messageId });
|
||||
res.status(204).send();
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class StreamRunManager {
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
async saveInitialMessage() {
|
||||
return saveMessage({
|
||||
return saveMessage(this.req, {
|
||||
conversationId: this.finalMessage.conversationId,
|
||||
messageId: this.finalMessage.messageId,
|
||||
parentMessageId: this.parentMessageId,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ const sendMessage = (res, message, event = 'message') => {
|
|||
/**
|
||||
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
|
||||
* @async
|
||||
* @param {object} res - The server response.
|
||||
* @param {object} req - The request.
|
||||
* @param {object} res - The response.
|
||||
* @param {object} options - The options for handling the error containing message properties.
|
||||
* @param {object} options.user - The user ID.
|
||||
* @param {string} options.sender - The sender of the message.
|
||||
|
|
@ -41,7 +42,7 @@ const sendMessage = (res, message, event = 'message') => {
|
|||
* @param {boolean} options.shouldSaveMessage - [Optional] Whether the message should be saved. Default is true.
|
||||
* @param {function} callback - [Optional] The callback function to be executed.
|
||||
*/
|
||||
const sendError = async (res, options, callback) => {
|
||||
const sendError = async (req, res, options, callback) => {
|
||||
const {
|
||||
user,
|
||||
sender,
|
||||
|
|
@ -69,7 +70,7 @@ const sendError = async (res, options, callback) => {
|
|||
}
|
||||
|
||||
if (shouldSaveMessage) {
|
||||
await saveMessage({ ...errorMessage, user });
|
||||
await saveMessage(req, { ...errorMessage, user });
|
||||
}
|
||||
|
||||
if (!errorMessage.error) {
|
||||
|
|
@ -97,11 +98,12 @@ const sendError = async (res, options, callback) => {
|
|||
|
||||
/**
|
||||
* Sends the response based on whether headers have been sent or not.
|
||||
* @param {Express.Request} req - The server response.
|
||||
* @param {Express.Response} res - The server response.
|
||||
* @param {Object} data - The data to be sent.
|
||||
* @param {string} [errorMessage] - The error message, if any.
|
||||
*/
|
||||
const sendResponse = (res, data, errorMessage) => {
|
||||
const sendResponse = (req, res, data, errorMessage) => {
|
||||
if (!res.headersSent) {
|
||||
if (errorMessage) {
|
||||
return res.status(500).json({ error: errorMessage });
|
||||
|
|
@ -110,7 +112,7 @@ const sendResponse = (res, data, errorMessage) => {
|
|||
}
|
||||
|
||||
if (errorMessage) {
|
||||
return sendError(res, { ...data, text: errorMessage });
|
||||
return sendError(req, res, { ...data, text: errorMessage });
|
||||
}
|
||||
return sendMessage(res, data);
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue