🔐 fix: Enhance Message & Image Access Security (#3363)

* chore: slight refactor

* fix: prevent message updates unless explicitly owned

* refactor: rethrow errors, update deleteMessagesSince (not used), add basic tests

* fix: Add path normalization and validation to image request middleware

* fix: image validation path security
This commit is contained in:
Danny Avila 2024-07-17 09:51:03 -04:00 committed by GitHub
parent 0a1d38e318
commit d5d188eebf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 595 additions and 229 deletions

View file

@ -598,7 +598,11 @@ class BaseClient {
* @param {string | null} user
*/
async saveMessageToDatabase(message, endpointOptions, user = null) {
const savedMessage = await saveMessage({
if (this.user && user !== this.user) {
throw new Error('User mismatch.');
}
const savedMessage = await saveMessage(this.options.req, {
...message,
endpoint: this.options.endpoint,
unfinished: false,
@ -619,7 +623,7 @@ class BaseClient {
}
async updateMessageInDatabase(message) {
await updateMessage(message);
await updateMessage(this.options.req, message);
}
/**

View file

@ -4,11 +4,37 @@ const logger = require('~/config/winston');
const idSchema = z.string().uuid();
module.exports = {
Message,
async saveMessage({
user,
/**
* Saves a message in the database.
*
* @async
* @function saveMessage
* @param {Express.Request} req - The request object containing user information.
* @param {Object} params - The message data object.
* @param {string} params.endpoint - The endpoint where the message originated.
* @param {string} params.iconURL - The URL of the sender's icon.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.newMessageId - The new unique identifier for the message (if applicable).
* @param {string} params.conversationId - The identifier of the conversation.
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
* @param {string} params.sender - The identifier of the sender.
* @param {string} params.text - The text content of the message.
* @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user.
* @param {string} [params.error] - Any error associated with the message.
* @param {boolean} [params.unfinished] - Indicates if the message is unfinished.
* @param {Object[]} [params.files] - An array of files associated with the message.
* @param {boolean} [params.isEdited] - Indicates if the message was edited.
* @param {string} [params.finish_reason] - Reason for finishing the message.
* @param {number} [params.tokenCount] - The number of tokens in the message.
* @param {string} [params.plugin] - Plugin associated with the message.
* @param {Object[]} [params.plugins] - An array of plugins associated with the message.
* @param {string} [params.model] - The model used to generate the message.
* @returns {Promise<TMessage>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
async function saveMessage(
req,
{
endpoint,
iconURL,
messageId,
@ -27,15 +53,20 @@ module.exports = {
plugin,
plugins,
model,
}) {
},
) {
try {
if (!req || !req.user || !req.user.id) {
throw new Error('User not authenticated');
}
const validConvoId = idSchema.safeParse(conversationId);
if (!validConvoId.success) {
return;
throw new Error('Invalid conversation ID');
}
const update = {
user,
user: req.user.id,
iconURL,
endpoint,
messageId: newMessageId || messageId,
@ -58,7 +89,7 @@ module.exports = {
update.files = files;
}
const message = await Message.findOneAndUpdate({ messageId }, update, {
const message = await Message.findOneAndUpdate({ messageId, user: req.user.id }, update, {
upsert: true,
new: true,
});
@ -66,11 +97,20 @@ module.exports = {
return message.toObject();
} catch (err) {
logger.error('Error saving message:', err);
throw new Error('Failed to save message.');
throw err;
}
}
},
async bulkSaveMessages(messages) {
/**
* Saves multiple messages in the database in bulk.
*
* @async
* @function bulkSaveMessages
* @param {Object[]} messages - An array of message objects to save.
* @returns {Promise<Object>} The result of the bulk write operation.
* @throws {Error} If there is an error in saving messages in bulk.
*/
async function bulkSaveMessages(messages) {
try {
const bulkOps = messages.map((message) => ({
updateOne: {
@ -84,9 +124,9 @@ module.exports = {
return result;
} catch (err) {
logger.error('Error saving messages in bulk:', err);
throw new Error('Failed to save messages in bulk.');
throw err;
}
}
},
/**
* Records a message in the database.
@ -103,7 +143,14 @@ module.exports = {
* @returns {Promise<Object>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) {
async function recordMessage({
user,
endpoint,
messageId,
conversationId,
parentMessageId,
...rest
}) {
try {
// No parsing of convoId as may use threadId
const message = {
@ -121,27 +168,61 @@ module.exports = {
});
} catch (err) {
logger.error('Error saving message:', err);
throw new Error('Failed to save message.');
throw err;
}
},
async updateMessageText({ messageId, text }) {
}
/**
* Updates the text of a message.
*
* @async
* @function updateMessageText
* @param {Object} params - The update data object.
* @param {Object} req - The request object.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.text - The new text content of the message.
* @returns {Promise<void>}
* @throws {Error} If there is an error in updating the message text.
*/
async function updateMessageText(req, { messageId, text }) {
try {
await Message.updateOne({ messageId }, { text });
await Message.updateOne({ messageId, user: req.user.id }, { text });
} catch (err) {
logger.error('Error updating message text:', err);
throw new Error('Failed to update message text.');
throw err;
}
},
async updateMessage(message) {
}
/**
* Updates a message.
*
* @async
* @function updateMessage
* @param {Object} message - The message object containing update data.
* @param {Object} req - The request object.
* @param {string} message.messageId - The unique identifier for the message.
* @param {string} [message.text] - The new text content of the message.
* @param {Object[]} [message.files] - The files associated with the message.
* @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user.
* @param {string} [message.sender] - The identifier of the sender.
* @param {number} [message.tokenCount] - The number of tokens in the message.
* @returns {Promise<TMessage>} The updated message document.
* @throws {Error} If there is an error in updating the message or if the message is not found.
*/
async function updateMessage(req, message) {
try {
const { messageId, ...update } = message;
update.isEdited = true;
const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, {
const updatedMessage = await Message.findOneAndUpdate(
{ messageId, user: req.user.id },
update,
{
new: true,
});
},
);
if (!updatedMessage) {
throw new Error('Message not found.');
throw new Error('Message not found or user not authorized.');
}
return {
@ -156,31 +237,49 @@ module.exports = {
};
} catch (err) {
logger.error('Error updating message:', err);
throw new Error('Failed to update message.');
throw err;
}
},
async deleteMessagesSince({ messageId, conversationId }) {
}
/**
* Deletes messages in a conversation since a specific message.
*
* @async
* @function deleteMessagesSince
* @param {Object} params - The parameters object.
* @param {Object} req - The request object.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.conversationId - The identifier of the conversation.
* @returns {Promise<Number>} The number of deleted messages.
* @throws {Error} If there is an error in deleting messages.
*/
async function deleteMessagesSince(req, { messageId, conversationId }) {
try {
const message = await Message.findOne({ messageId }).lean();
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
if (message) {
return await Message.find({ conversationId }).deleteMany({
const query = Message.find({ conversationId, user: req.user.id });
return await query.deleteMany({
createdAt: { $gt: message.createdAt },
});
}
return undefined;
} catch (err) {
logger.error('Error deleting messages:', err);
throw new Error('Failed to delete messages.');
throw err;
}
}
},
/**
* Retrieves messages from the database.
* @param {Record<string, unknown>} filter
* @param {string | undefined} [select]
* @returns
* @async
* @function getMessages
* @param {Record<string, unknown>} filter - The filter criteria.
* @param {string | undefined} [select] - The fields to select.
* @returns {Promise<TMessage[]>} The messages that match the filter criteria.
* @throws {Error} If there is an error in retrieving messages.
*/
async getMessages(filter, select) {
async function getMessages(filter, select) {
try {
if (select) {
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
@ -189,16 +288,36 @@ module.exports = {
return await Message.find(filter).sort({ createdAt: 1 }).lean();
} catch (err) {
logger.error('Error getting messages:', err);
throw new Error('Failed to get messages.');
throw err;
}
}
},
async deleteMessages(filter) {
/**
* Deletes messages from the database.
*
* @async
* @function deleteMessages
* @param {Object} filter - The filter criteria to find messages to delete.
* @returns {Promise<Number>} The number of deleted messages.
* @throws {Error} If there is an error in deleting messages.
*/
async function deleteMessages(filter) {
try {
return await Message.deleteMany(filter);
} catch (err) {
logger.error('Error deleting messages:', err);
throw new Error('Failed to delete messages.');
throw err;
}
},
}
module.exports = {
Message,
saveMessage,
bulkSaveMessages,
recordMessage,
updateMessageText,
updateMessage,
deleteMessagesSince,
getMessages,
deleteMessages,
};

239
api/models/Message.spec.js Normal file
View file

@ -0,0 +1,239 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
jest.mock('mongoose');
const mockFindQuery = {
select: jest.fn().mockReturnThis(),
sort: jest.fn().mockReturnThis(),
lean: jest.fn().mockReturnThis(),
deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
};
const mockSchema = {
findOneAndUpdate: jest.fn(),
updateOne: jest.fn(),
findOne: jest.fn(() => ({
lean: jest.fn(),
})),
find: jest.fn(() => mockFindQuery),
deleteMany: jest.fn(),
};
mongoose.model.mockReturnValue(mockSchema);
jest.mock('~/models/schema/messageSchema', () => mockSchema);
jest.mock('~/config/winston', () => ({
error: jest.fn(),
}));
const {
saveMessage,
getMessages,
updateMessage,
deleteMessages,
updateMessageText,
deleteMessagesSince,
} = require('~/models/Message');
describe('Message Operations', () => {
let mockReq;
let mockMessage;
beforeEach(() => {
jest.clearAllMocks();
mockReq = {
user: { id: 'user123' },
};
mockMessage = {
messageId: 'msg123',
conversationId: uuidv4(),
text: 'Hello, world!',
user: 'user123',
};
mockSchema.findOneAndUpdate.mockResolvedValue({
toObject: () => mockMessage,
});
});
describe('saveMessage', () => {
it('should save a message for an authenticated user', async () => {
const result = await saveMessage(mockReq, mockMessage);
expect(result).toEqual(mockMessage);
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: 'msg123', user: 'user123' },
expect.objectContaining({ user: 'user123' }),
expect.any(Object),
);
});
it('should throw an error for unauthenticated user', async () => {
mockReq.user = null;
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
});
it('should throw an error for invalid conversation ID', async () => {
mockMessage.conversationId = 'invalid-id';
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('Invalid conversation ID');
});
});
describe('updateMessageText', () => {
it('should update message text for the authenticated user', async () => {
await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
expect(mockSchema.updateOne).toHaveBeenCalledWith(
{ messageId: 'msg123', user: 'user123' },
{ text: 'Updated text' },
);
});
});
describe('updateMessage', () => {
it('should update a message for the authenticated user', async () => {
mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
expect(result).toEqual(
expect.objectContaining({
messageId: 'msg123',
text: 'Hello, world!',
isEdited: true,
}),
);
});
it('should throw an error if message is not found', async () => {
mockSchema.findOneAndUpdate.mockResolvedValue(null);
await expect(
updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
).rejects.toThrow('Message not found or user not authorized.');
});
});
describe('deleteMessagesSince', () => {
it('should delete messages only for the authenticated user', async () => {
mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
const result = await deleteMessagesSince(mockReq, {
messageId: 'msg123',
conversationId: 'convo123',
});
expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
expect(mockSchema.find).not.toHaveBeenCalled();
expect(result).toBeUndefined();
});
it('should return undefined if no message is found', async () => {
mockSchema.findOne().lean.mockResolvedValueOnce(null);
const result = await deleteMessagesSince(mockReq, {
messageId: 'nonexistent',
conversationId: 'convo123',
});
expect(result).toBeUndefined();
});
});
describe('getMessages', () => {
it('should retrieve messages with the correct filter', async () => {
const filter = { conversationId: 'convo123' };
await getMessages(filter);
expect(mockSchema.find).toHaveBeenCalledWith(filter);
expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
expect(mockFindQuery.lean).toHaveBeenCalled();
});
});
describe('deleteMessages', () => {
it('should delete messages with the correct filter', async () => {
await deleteMessages({ user: 'user123' });
expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
});
});
describe('Conversation Hijacking Prevention', () => {
it('should not allow editing a message in another user\'s conversation', async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
mockSchema.findOneAndUpdate.mockResolvedValue(null);
await expect(
updateMessage(attackerReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
text: 'Hacked message',
}),
).rejects.toThrow('Message not found or user not authorized.');
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: victimMessageId, user: 'attacker123' },
expect.anything(),
expect.anything(),
);
});
it('should not allow deleting messages from another user\'s conversation', async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
const result = await deleteMessagesSince(attackerReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
});
expect(result).toBeUndefined();
expect(mockSchema.findOne).toHaveBeenCalledWith({
messageId: victimMessageId,
user: 'attacker123',
});
});
it('should not allow inserting a new message into another user\'s conversation', async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = uuidv4(); // Use a valid UUID
await expect(
saveMessage(attackerReq, {
conversationId: victimConversationId,
text: 'Inserted malicious message',
messageId: 'new-msg-123',
}),
).resolves.not.toThrow(); // It should not throw an error
// Check that the message was saved with the attacker's user ID
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: 'new-msg-123', user: 'attacker123' },
expect.objectContaining({
user: 'attacker123',
conversationId: victimConversationId,
}),
expect.anything(),
);
});
it('should allow retrieving messages from any conversation', async () => {
const victimConversationId = 'victim-convo-123';
await getMessages({ conversationId: victimConversationId });
expect(mockSchema.find).toHaveBeenCalledWith({
conversationId: victimConversationId,
});
mockSchema.find.mockReturnValueOnce({
select: jest.fn().mockReturnThis(),
sort: jest.fn().mockReturnThis(),
lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
});
const result = await getMessages({ conversationId: victimConversationId });
expect(result).toEqual([{ text: 'Test message' }]);
});
});
});

View file

@ -55,7 +55,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
saveMessage({
saveMessage(req, {
messageId: responseMessageId,
sender,
conversationId,
@ -144,11 +144,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(req, { ...response, user });
}
if (!client.skipSaveUserMessage) {
await saveMessage(userMessage);
await saveMessage(req, userMessage);
}
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {

View file

@ -56,7 +56,7 @@ const EditController = async (req, res, next, initializeClient) => {
generation,
onProgress: throttle(
({ text: partialText }) => {
saveMessage({
saveMessage(req, {
messageId: responseMessageId,
sender,
conversationId,
@ -141,7 +141,7 @@ const EditController = async (req, res, next, initializeClient) => {
});
res.end();
await saveMessage({ ...response, user });
await saveMessage(req, { ...response, user });
}
} catch (error) {
const partialText = getPartialText();

View file

@ -120,21 +120,22 @@ const chatV1 = async (req, res) => {
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(res, messageData, errorMessage);
return sendResponse(req, res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(res, messageData, error.message);
return sendResponse(req, res, messageData, error.message);
} else {
logger.error('[/assistants/chat/]', error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(res, messageData, defaultErrorMessage);
return sendResponse(req, res, messageData, defaultErrorMessage);
}
await sleep(2000);
@ -221,10 +222,10 @@ const chatV1 = async (req, res) => {
};
} catch (error) {
logger.error('[/assistants/chat/] Error finalizing error process', error);
return sendResponse(res, messageData, 'The Assistant run failed');
return sendResponse(req, res, messageData, 'The Assistant run failed');
}
return sendResponse(res, finalEvent);
return sendResponse(req, res, finalEvent);
};
try {

View file

@ -117,21 +117,22 @@ const chatV2 = async (req, res) => {
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(res, messageData, errorMessage);
return sendResponse(req, res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(res, messageData, error.message);
return sendResponse(req, res, messageData, error.message);
} else {
logger.error('[/assistants/chat/]', error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(res, messageData, defaultErrorMessage);
return sendResponse(req, res, messageData, defaultErrorMessage);
}
await sleep(2000);
@ -218,10 +219,10 @@ const chatV2 = async (req, res) => {
};
} catch (error) {
logger.error('[/assistants/chat/] Error finalizing error process', error);
return sendResponse(res, messageData, 'The Assistant run failed');
return sendResponse(req, res, messageData, 'The Assistant run failed');
}
return sendResponse(res, finalEvent);
return sendResponse(req, res, finalEvent);
};
try {

View file

@ -116,7 +116,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
{ promptTokens, completionTokens },
);
saveMessage({ ...responseMessage, user });
saveMessage(req, { ...responseMessage, user });
let conversation;
if (userMessagePromise) {
@ -190,7 +190,7 @@ const handleAbortError = async (res, req, error, data) => {
}
};
await sendError(res, options, callback);
await sendError(req, res, options, callback);
};
if (partialText && partialText.length > 5) {

View file

@ -41,10 +41,10 @@ const denyRequest = async (req, res, errorMessage) => {
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
if (shouldSaveMessage) {
await saveMessage({ ...userMessage, user: req.user.id });
await saveMessage(req, { ...userMessage, user: req.user.id });
}
return await sendError(res, {
return await sendError(req, res, {
sender: getResponseSender(req.body),
messageId: crypto.randomUUID(),
conversationId,

View file

@ -31,10 +31,14 @@ function validateImageRequest(req, res, next) {
return res.status(403).send('Access Denied');
}
if (req.path.includes(payload.id)) {
const fullPath = decodeURIComponent(req.originalUrl);
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
if (pathPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
next();
} else {
logger.warn('[validateImageRequest] Invalid image path');
res.status(403).send('Access Denied');
}
}

View file

@ -51,7 +51,7 @@ router.post('/', setHeaders, async (req, res) => {
});
if (!overrideParentMessageId) {
await saveMessage({ ...userMessage, user: req.user.id });
await saveMessage(req, { ...userMessage, user: req.user.id });
await saveConvo(req.user.id, {
...userMessage,
...endpointOption,
@ -93,7 +93,7 @@ const ask = async ({
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
saveMessage(req, {
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
@ -159,7 +159,7 @@ const ask = async ({
isCreatedByUser: false,
};
await saveMessage({ ...responseMessage, user });
await saveMessage(req, { ...responseMessage, user });
responseMessage.messageId = newResponseMessageId;
// STEP2 update the conversation
@ -192,7 +192,7 @@ const ask = async ({
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
if (!overrideParentMessageId) {
await saveMessage({
await saveMessage(req, {
...userMessage,
user,
messageId: userMessageId,
@ -229,7 +229,7 @@ const ask = async ({
isCreatedByUser: false,
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
};
await saveMessage({ ...errorMessage, user });
await saveMessage(req, { ...errorMessage, user });
handleError(res, errorMessage);
}
};

View file

@ -70,7 +70,7 @@ router.post('/', setHeaders, async (req, res) => {
});
if (!overrideParentMessageId) {
await saveMessage({ ...userMessage, user: req.user.id });
await saveMessage(req, { ...userMessage, user: req.user.id });
await saveConvo(req.user.id, {
...userMessage,
...endpointOption,
@ -118,7 +118,7 @@ const ask = async ({
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
saveMessage(req, {
messageId: responseMessageId,
sender: model,
conversationId,
@ -197,7 +197,7 @@ const ask = async ({
isCreatedByUser: false,
};
await saveMessage({ ...responseMessage, user });
await saveMessage(req, { ...responseMessage, user });
responseMessage.messageId = newResponseMessageId;
let conversationUpdate = {
@ -221,7 +221,7 @@ const ask = async ({
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
if (!overrideParentMessageId) {
await saveMessage({
await saveMessage(req, {
...userMessage,
user,
messageId: userMessageId,
@ -266,7 +266,7 @@ const ask = async ({
isCreatedByUser: false,
};
saveMessage({ ...responseMessage, user });
saveMessage(req, { ...responseMessage, user });
return {
title: await getConvoTitle(user, conversationId),
@ -288,7 +288,7 @@ const ask = async ({
model,
isCreatedByUser: false,
};
await saveMessage({ ...errorMessage, user });
await saveMessage(req, { ...errorMessage, user });
handleError(res, errorMessage);
}
}

View file

@ -85,7 +85,7 @@ router.post(
clearTimeout(timer);
}
throttledSaveMessage({
throttledSaveMessage(req, {
messageId: responseMessageId,
sender,
conversationId,
@ -170,7 +170,7 @@ router.post(
const onChainEnd = () => {
if (!client.skipSaveUserMessage) {
saveMessage({ ...userMessage, user });
saveMessage(req, { ...userMessage, user });
}
sendIntermediateMessage(res, {
plugins,
@ -208,7 +208,7 @@ router.post(
logger.debug('[/ask/gptPlugins]', response);
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user });
await saveMessage(req, { ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =

View file

@ -91,7 +91,7 @@ router.post(
plugin.loading = false;
}
throttledSaveMessage({
throttledSaveMessage(req, {
messageId: responseMessageId,
sender,
conversationId,
@ -110,7 +110,7 @@ router.post(
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage({ ...userMessage, user });
saveMessage(req, { ...userMessage, user });
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
@ -141,7 +141,7 @@ router.post(
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage({ ...userMessage, user });
saveMessage(req, { ...userMessage, user });
}
sendIntermediateMessage(res, {
plugin,
@ -180,7 +180,7 @@ router.post(
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
await saveMessage(req, { ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =

View file

@ -1,46 +1,42 @@
const express = require('express');
const router = express.Router();
const {
getMessages,
updateMessage,
saveConvo,
saveMessage,
deleteMessages,
} = require('../../models');
const { countTokens } = require('../utils');
const { requireJwtAuth, validateMessageReq } = require('../middleware/');
const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { countTokens } = require('~/server/utils');
router.use(requireJwtAuth);
router.use(validateMessageReq);
router.get('/:conversationId', validateMessageReq, async (req, res) => {
router.get('/:conversationId', async (req, res) => {
const { conversationId } = req.params;
res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
});
// CREATE
router.post('/:conversationId', validateMessageReq, async (req, res) => {
router.post('/:conversationId', async (req, res) => {
const message = req.body;
const savedMessage = await saveMessage({ ...message, user: req.user.id });
const savedMessage = await saveMessage(req, { ...message, user: req.user.id });
await saveConvo(req.user.id, savedMessage);
res.status(201).send(savedMessage);
});
// READ
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
router.get('/:conversationId/:messageId', async (req, res) => {
const { conversationId, messageId } = req.params;
res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
});
// UPDATE
router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
router.put('/:conversationId/:messageId', async (req, res) => {
const { messageId, model } = req.params;
const { text } = req.body;
const tokenCount = await countTokens(text, model);
res.status(201).json(await updateMessage({ messageId, text, tokenCount }));
const result = await updateMessage(req, { messageId, text, tokenCount });
res.status(201).json(result);
});
// DELETE
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
router.delete('/:conversationId/:messageId', async (req, res) => {
const { messageId } = req.params;
await deleteMessages({ messageId });
res.status(204).send();

View file

@ -143,7 +143,7 @@ class StreamRunManager {
* @returns {Promise<void>}
*/
async saveInitialMessage() {
return saveMessage({
return saveMessage(this.req, {
conversationId: this.finalMessage.conversationId,
messageId: this.finalMessage.messageId,
parentMessageId: this.parentMessageId,

View file

@ -30,7 +30,8 @@ const sendMessage = (res, message, event = 'message') => {
/**
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
* @async
* @param {object} res - The server response.
* @param {object} req - The request.
* @param {object} res - The response.
* @param {object} options - The options for handling the error containing message properties.
* @param {object} options.user - The user ID.
* @param {string} options.sender - The sender of the message.
@ -41,7 +42,7 @@ const sendMessage = (res, message, event = 'message') => {
* @param {boolean} options.shouldSaveMessage - [Optional] Whether the message should be saved. Default is true.
* @param {function} callback - [Optional] The callback function to be executed.
*/
const sendError = async (res, options, callback) => {
const sendError = async (req, res, options, callback) => {
const {
user,
sender,
@ -69,7 +70,7 @@ const sendError = async (res, options, callback) => {
}
if (shouldSaveMessage) {
await saveMessage({ ...errorMessage, user });
await saveMessage(req, { ...errorMessage, user });
}
if (!errorMessage.error) {
@ -97,11 +98,12 @@ const sendError = async (res, options, callback) => {
/**
* Sends the response based on whether headers have been sent or not.
* @param {Express.Request} req - The server response.
* @param {Express.Response} res - The server response.
* @param {Object} data - The data to be sent.
* @param {string} [errorMessage] - The error message, if any.
*/
const sendResponse = (res, data, errorMessage) => {
const sendResponse = (req, res, data, errorMessage) => {
if (!res.headersSent) {
if (errorMessage) {
return res.status(500).json({ error: errorMessage });
@ -110,7 +112,7 @@ const sendResponse = (res, data, errorMessage) => {
}
if (errorMessage) {
return sendError(res, { ...data, text: errorMessage });
return sendError(req, res, { ...data, text: errorMessage });
}
return sendMessage(res, data);
};