mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 17:00:15 +01:00
♻️ fix: Prevent Instructions from Removal when nearing Max Context (#5516)
* refactor: getMessagesWithinTokenLimit to accept params object * refactor: always include instructions in payload if provided * ci: remove obsolete test * refactor: update logoutUser to accept request object and handle session destruction * test: enhance getMessagesWithinTokenLimit tests for instruction handling
This commit is contained in:
parent
528ee62eb1
commit
4110209494
6 changed files with 185 additions and 83 deletions
|
|
@ -416,7 +416,7 @@ class AnthropicClient extends BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
let { context: messagesInWindow, remainingContextTokens } =
|
let { context: messagesInWindow, remainingContextTokens } =
|
||||||
await this.getMessagesWithinTokenLimit(formattedMessages);
|
await this.getMessagesWithinTokenLimit({ messages: formattedMessages });
|
||||||
|
|
||||||
const tokenCountMap = orderedMessages
|
const tokenCountMap = orderedMessages
|
||||||
.slice(orderedMessages.length - messagesInWindow.length)
|
.slice(orderedMessages.length - messagesInWindow.length)
|
||||||
|
|
|
||||||
|
|
@ -347,25 +347,38 @@ class BaseClient {
|
||||||
* If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array.
|
* If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array.
|
||||||
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages.
|
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages.
|
||||||
*
|
*
|
||||||
* @param {Array} _messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
|
* @param {Object} params
|
||||||
* @param {number} [maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
|
* @param {TMessage[]} params.messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
|
||||||
* @returns {Object} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
|
* @param {number} [params.maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
|
||||||
|
* @param {{ role: 'system', content: text, tokenCount: number }} [params.instructions] - Instructions already added to the context at index 0.
|
||||||
|
* @returns {Promise<{
|
||||||
|
* context: TMessage[],
|
||||||
|
* remainingContextTokens: number,
|
||||||
|
* messagesToRefine: TMessage[],
|
||||||
|
* summaryIndex: number,
|
||||||
|
* }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
|
||||||
* `context` is an array of messages that fit within the token limit.
|
* `context` is an array of messages that fit within the token limit.
|
||||||
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
|
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
|
||||||
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
|
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
|
||||||
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
|
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
|
||||||
*/
|
*/
|
||||||
async getMessagesWithinTokenLimit(_messages, maxContextTokens) {
|
async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) {
|
||||||
// Every reply is primed with <|start|>assistant<|message|>, so we
|
// Every reply is primed with <|start|>assistant<|message|>, so we
|
||||||
// start with 3 tokens for the label after all messages have been counted.
|
// start with 3 tokens for the label after all messages have been counted.
|
||||||
let currentTokenCount = 3;
|
|
||||||
let summaryIndex = -1;
|
let summaryIndex = -1;
|
||||||
let remainingContextTokens = maxContextTokens ?? this.maxContextTokens;
|
let currentTokenCount = 3;
|
||||||
|
const instructionsTokenCount = instructions?.tokenCount ?? 0;
|
||||||
|
let remainingContextTokens =
|
||||||
|
(maxContextTokens ?? this.maxContextTokens) - instructionsTokenCount;
|
||||||
const messages = [..._messages];
|
const messages = [..._messages];
|
||||||
|
|
||||||
const context = [];
|
const context = [];
|
||||||
|
|
||||||
if (currentTokenCount < remainingContextTokens) {
|
if (currentTokenCount < remainingContextTokens) {
|
||||||
while (messages.length > 0 && currentTokenCount < remainingContextTokens) {
|
while (messages.length > 0 && currentTokenCount < remainingContextTokens) {
|
||||||
|
if (messages.length === 1 && instructions) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
const poppedMessage = messages.pop();
|
const poppedMessage = messages.pop();
|
||||||
const { tokenCount } = poppedMessage;
|
const { tokenCount } = poppedMessage;
|
||||||
|
|
||||||
|
|
@ -379,6 +392,11 @@ class BaseClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (instructions) {
|
||||||
|
context.push(_messages[0]);
|
||||||
|
messages.shift();
|
||||||
|
}
|
||||||
|
|
||||||
const prunedMemory = messages;
|
const prunedMemory = messages;
|
||||||
summaryIndex = prunedMemory.length - 1;
|
summaryIndex = prunedMemory.length - 1;
|
||||||
remainingContextTokens -= currentTokenCount;
|
remainingContextTokens -= currentTokenCount;
|
||||||
|
|
@ -403,12 +421,18 @@ class BaseClient {
|
||||||
if (instructions) {
|
if (instructions) {
|
||||||
({ tokenCount, ..._instructions } = instructions);
|
({ tokenCount, ..._instructions } = instructions);
|
||||||
}
|
}
|
||||||
|
|
||||||
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
|
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
|
||||||
let payload = this.addInstructions(formattedMessages, _instructions);
|
if (tokenCount && tokenCount > this.maxContextTokens) {
|
||||||
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
|
const info = `${tokenCount} / ${this.maxContextTokens}`;
|
||||||
|
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||||
|
logger.warn(`Instructions token count exceeds max token count (${info}).`);
|
||||||
|
throw new Error(errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
if (this.clientName === EModelEndpoint.agents) {
|
if (this.clientName === EModelEndpoint.agents) {
|
||||||
const { dbMessages, editedIndices } = truncateToolCallOutputs(
|
const { dbMessages, editedIndices } = truncateToolCallOutputs(
|
||||||
orderedWithInstructions,
|
orderedMessages,
|
||||||
this.maxContextTokens,
|
this.maxContextTokens,
|
||||||
this.getTokenCountForMessage.bind(this),
|
this.getTokenCountForMessage.bind(this),
|
||||||
);
|
);
|
||||||
|
|
@ -416,14 +440,19 @@ class BaseClient {
|
||||||
if (editedIndices.length > 0) {
|
if (editedIndices.length > 0) {
|
||||||
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
|
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
|
||||||
for (const index of editedIndices) {
|
for (const index of editedIndices) {
|
||||||
payload[index].content = dbMessages[index].content;
|
formattedMessages[index].content = dbMessages[index].content;
|
||||||
}
|
}
|
||||||
orderedWithInstructions = dbMessages;
|
orderedMessages = dbMessages;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
|
||||||
|
|
||||||
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
|
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
|
||||||
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
|
await this.getMessagesWithinTokenLimit({
|
||||||
|
messages: orderedWithInstructions,
|
||||||
|
instructions,
|
||||||
|
});
|
||||||
|
|
||||||
logger.debug('[BaseClient] Context Count (1/2)', {
|
logger.debug('[BaseClient] Context Count (1/2)', {
|
||||||
remainingContextTokens,
|
remainingContextTokens,
|
||||||
|
|
@ -435,7 +464,9 @@ class BaseClient {
|
||||||
let { shouldSummarize } = this;
|
let { shouldSummarize } = this;
|
||||||
|
|
||||||
// Calculate the difference in length to determine how many messages were discarded if any
|
// Calculate the difference in length to determine how many messages were discarded if any
|
||||||
const { length } = payload;
|
let payload;
|
||||||
|
let { length } = formattedMessages;
|
||||||
|
length += instructions != null ? 1 : 0;
|
||||||
const diff = length - context.length;
|
const diff = length - context.length;
|
||||||
const firstMessage = orderedWithInstructions[0];
|
const firstMessage = orderedWithInstructions[0];
|
||||||
const usePrevSummary =
|
const usePrevSummary =
|
||||||
|
|
@ -445,18 +476,31 @@ class BaseClient {
|
||||||
this.previous_summary.messageId === firstMessage.messageId;
|
this.previous_summary.messageId === firstMessage.messageId;
|
||||||
|
|
||||||
if (diff > 0) {
|
if (diff > 0) {
|
||||||
payload = payload.slice(diff);
|
payload = formattedMessages.slice(diff);
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[BaseClient] Difference between original payload (${length}) and context (${context.length}): ${diff}`,
|
`[BaseClient] Difference between original payload (${length}) and context (${context.length}): ${diff}`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
payload = this.addInstructions(payload ?? formattedMessages, _instructions);
|
||||||
|
|
||||||
const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
|
const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
|
||||||
if (payload.length === 0 && !shouldSummarize && latestMessage) {
|
if (payload.length === 0 && !shouldSummarize && latestMessage) {
|
||||||
const info = `${latestMessage.tokenCount} / ${this.maxContextTokens}`;
|
const info = `${latestMessage.tokenCount} / ${this.maxContextTokens}`;
|
||||||
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||||
logger.warn(`Prompt token count exceeds max token count (${info}).`);
|
logger.warn(`Prompt token count exceeds max token count (${info}).`);
|
||||||
throw new Error(errorMessage);
|
throw new Error(errorMessage);
|
||||||
|
} else if (
|
||||||
|
_instructions &&
|
||||||
|
payload.length === 1 &&
|
||||||
|
payload[0].content === _instructions.content
|
||||||
|
) {
|
||||||
|
const info = `${tokenCount + 3} / ${this.maxContextTokens}`;
|
||||||
|
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
|
||||||
|
logger.warn(
|
||||||
|
`Including instructions, the prompt token count exceeds remaining max token count (${info}).`,
|
||||||
|
);
|
||||||
|
throw new Error(errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (usePrevSummary) {
|
if (usePrevSummary) {
|
||||||
|
|
|
||||||
|
|
@ -931,7 +931,10 @@ ${convo}
|
||||||
);
|
);
|
||||||
|
|
||||||
if (excessTokenCount > maxContextTokens) {
|
if (excessTokenCount > maxContextTokens) {
|
||||||
({ context } = await this.getMessagesWithinTokenLimit(context, maxContextTokens));
|
({ context } = await this.getMessagesWithinTokenLimit({
|
||||||
|
messages: context,
|
||||||
|
maxContextTokens,
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (context.length === 0) {
|
if (context.length === 0) {
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ describe('BaseClient', () => {
|
||||||
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
||||||
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
||||||
|
|
||||||
const result = await TestClient.getMessagesWithinTokenLimit(messages);
|
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
|
||||||
|
|
||||||
expect(result.context).toEqual(expectedContext);
|
expect(result.context).toEqual(expectedContext);
|
||||||
expect(result.summaryIndex).toEqual(expectedIndex);
|
expect(result.summaryIndex).toEqual(expectedIndex);
|
||||||
|
|
@ -195,7 +195,7 @@ describe('BaseClient', () => {
|
||||||
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
||||||
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
||||||
|
|
||||||
const result = await TestClient.getMessagesWithinTokenLimit(messages);
|
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
|
||||||
|
|
||||||
expect(result.context).toEqual(expectedContext);
|
expect(result.context).toEqual(expectedContext);
|
||||||
expect(result.summaryIndex).toEqual(expectedIndex);
|
expect(result.summaryIndex).toEqual(expectedIndex);
|
||||||
|
|
@ -203,66 +203,6 @@ describe('BaseClient', () => {
|
||||||
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('handles context strategy correctly in handleContextStrategy()', async () => {
|
|
||||||
TestClient.addInstructions = jest
|
|
||||||
.fn()
|
|
||||||
.mockReturnValue([
|
|
||||||
{ content: 'Hello' },
|
|
||||||
{ content: 'How can I help you?' },
|
|
||||||
{ content: 'Please provide more details.' },
|
|
||||||
{ content: 'I can assist you with that.' },
|
|
||||||
]);
|
|
||||||
TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({
|
|
||||||
context: [
|
|
||||||
{ content: 'How can I help you?' },
|
|
||||||
{ content: 'Please provide more details.' },
|
|
||||||
{ content: 'I can assist you with that.' },
|
|
||||||
],
|
|
||||||
remainingContextTokens: 80,
|
|
||||||
messagesToRefine: [{ content: 'Hello' }],
|
|
||||||
summaryIndex: 3,
|
|
||||||
});
|
|
||||||
|
|
||||||
TestClient.getTokenCount = jest.fn().mockReturnValue(40);
|
|
||||||
|
|
||||||
const instructions = { content: 'Please provide more details.' };
|
|
||||||
const orderedMessages = [
|
|
||||||
{ content: 'Hello' },
|
|
||||||
{ content: 'How can I help you?' },
|
|
||||||
{ content: 'Please provide more details.' },
|
|
||||||
{ content: 'I can assist you with that.' },
|
|
||||||
];
|
|
||||||
const formattedMessages = [
|
|
||||||
{ content: 'Hello' },
|
|
||||||
{ content: 'How can I help you?' },
|
|
||||||
{ content: 'Please provide more details.' },
|
|
||||||
{ content: 'I can assist you with that.' },
|
|
||||||
];
|
|
||||||
const expectedResult = {
|
|
||||||
payload: [
|
|
||||||
{
|
|
||||||
role: 'system',
|
|
||||||
content: 'Refined answer',
|
|
||||||
},
|
|
||||||
{ content: 'How can I help you?' },
|
|
||||||
{ content: 'Please provide more details.' },
|
|
||||||
{ content: 'I can assist you with that.' },
|
|
||||||
],
|
|
||||||
promptTokens: expect.any(Number),
|
|
||||||
tokenCountMap: {},
|
|
||||||
messages: expect.any(Array),
|
|
||||||
};
|
|
||||||
|
|
||||||
TestClient.shouldSummarize = true;
|
|
||||||
const result = await TestClient.handleContextStrategy({
|
|
||||||
instructions,
|
|
||||||
orderedMessages,
|
|
||||||
formattedMessages,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result).toEqual(expectedResult);
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getMessagesForConversation', () => {
|
describe('getMessagesForConversation', () => {
|
||||||
it('should return an empty array if the parentMessageId does not exist', () => {
|
it('should return an empty array if the parentMessageId does not exist', () => {
|
||||||
const result = TestClient.constructor.getMessagesForConversation({
|
const result = TestClient.constructor.getMessagesForConversation({
|
||||||
|
|
@ -674,4 +614,112 @@ describe('BaseClient', () => {
|
||||||
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
|
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('getMessagesWithinTokenLimit with instructions', () => {
|
||||||
|
test('should always include instructions when present', async () => {
|
||||||
|
TestClient.maxContextTokens = 50;
|
||||||
|
const instructions = {
|
||||||
|
role: 'system',
|
||||||
|
content: 'System instructions',
|
||||||
|
tokenCount: 20,
|
||||||
|
};
|
||||||
|
|
||||||
|
const messages = [
|
||||||
|
instructions,
|
||||||
|
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
||||||
|
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
|
||||||
|
];
|
||||||
|
|
||||||
|
const result = await TestClient.getMessagesWithinTokenLimit({
|
||||||
|
messages,
|
||||||
|
instructions,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.context[0]).toBe(instructions);
|
||||||
|
expect(result.remainingContextTokens).toBe(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle case when messages exceed limit but instructions must be preserved', async () => {
|
||||||
|
TestClient.maxContextTokens = 30;
|
||||||
|
const instructions = {
|
||||||
|
role: 'system',
|
||||||
|
content: 'System instructions',
|
||||||
|
tokenCount: 20,
|
||||||
|
};
|
||||||
|
|
||||||
|
const messages = [
|
||||||
|
instructions,
|
||||||
|
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
||||||
|
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
|
||||||
|
];
|
||||||
|
|
||||||
|
const result = await TestClient.getMessagesWithinTokenLimit({
|
||||||
|
messages,
|
||||||
|
instructions,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should only include instructions and the last message that fits
|
||||||
|
expect(result.context).toHaveLength(1);
|
||||||
|
expect(result.context[0].content).toBe(instructions.content);
|
||||||
|
expect(result.messagesToRefine).toHaveLength(2);
|
||||||
|
expect(result.remainingContextTokens).toBe(7); // 30 - 20 - 3 (assistant label)
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should work correctly without instructions (1/2)', async () => {
|
||||||
|
TestClient.maxContextTokens = 50;
|
||||||
|
const messages = [
|
||||||
|
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
||||||
|
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
|
||||||
|
];
|
||||||
|
|
||||||
|
const result = await TestClient.getMessagesWithinTokenLimit({
|
||||||
|
messages,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.context).toHaveLength(2);
|
||||||
|
expect(result.remainingContextTokens).toBe(22); // 50 - 10 - 15 - 3(assistant label)
|
||||||
|
expect(result.messagesToRefine).toHaveLength(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should work correctly without instructions (2/2)', async () => {
|
||||||
|
TestClient.maxContextTokens = 30;
|
||||||
|
const messages = [
|
||||||
|
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
||||||
|
{ role: 'assistant', content: 'Hi there', tokenCount: 20 },
|
||||||
|
];
|
||||||
|
|
||||||
|
const result = await TestClient.getMessagesWithinTokenLimit({
|
||||||
|
messages,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.context).toHaveLength(1);
|
||||||
|
expect(result.remainingContextTokens).toBe(7);
|
||||||
|
expect(result.messagesToRefine).toHaveLength(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle case when only instructions fit within limit', async () => {
|
||||||
|
TestClient.maxContextTokens = 25;
|
||||||
|
const instructions = {
|
||||||
|
role: 'system',
|
||||||
|
content: 'System instructions',
|
||||||
|
tokenCount: 20,
|
||||||
|
};
|
||||||
|
|
||||||
|
const messages = [
|
||||||
|
instructions,
|
||||||
|
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
||||||
|
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
|
||||||
|
];
|
||||||
|
|
||||||
|
const result = await TestClient.getMessagesWithinTokenLimit({
|
||||||
|
messages,
|
||||||
|
instructions,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.context).toHaveLength(1);
|
||||||
|
expect(result.context[0]).toBe(instructions);
|
||||||
|
expect(result.messagesToRefine).toHaveLength(2);
|
||||||
|
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ const { logger } = require('~/config');
|
||||||
const logoutController = async (req, res) => {
|
const logoutController = async (req, res) => {
|
||||||
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
||||||
try {
|
try {
|
||||||
const logout = await logoutUser(req.user._id, refreshToken);
|
const logout = await logoutUser(req, refreshToken);
|
||||||
const { status, message } = logout;
|
const { status, message } = logout;
|
||||||
res.clearCookie('refreshToken');
|
res.clearCookie('refreshToken');
|
||||||
return res.status(status).send({ message });
|
return res.status(status).send({ message });
|
||||||
|
|
|
||||||
|
|
@ -35,13 +35,14 @@ const genericVerificationMessage = 'Please check your email to verify your email
|
||||||
/**
|
/**
|
||||||
* Logout user
|
* Logout user
|
||||||
*
|
*
|
||||||
* @param {String} userId
|
* @param {ServerRequest} req
|
||||||
* @param {*} refreshToken
|
* @param {string} refreshToken
|
||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
const logoutUser = async (userId, refreshToken) => {
|
const logoutUser = async (req, refreshToken) => {
|
||||||
try {
|
try {
|
||||||
const session = await findSession({ userId: userId, refreshToken: refreshToken });
|
const userId = req.user._id;
|
||||||
|
const session = await findSession({ userId: userId, refreshToken });
|
||||||
|
|
||||||
if (session) {
|
if (session) {
|
||||||
try {
|
try {
|
||||||
|
|
@ -52,6 +53,12 @@ const logoutUser = async (userId, refreshToken) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
req.session.destroy();
|
||||||
|
} catch (destroyErr) {
|
||||||
|
logger.error('[logoutUser] Failed to destroy session.', destroyErr);
|
||||||
|
}
|
||||||
|
|
||||||
return { status: 200, message: 'Logout successful' };
|
return { status: 200, message: 'Logout successful' };
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
return { status: 500, message: err.message };
|
return { status: 500, message: err.message };
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue