mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-27 12:54:09 +01:00
📌 fix: Populate userMessage.files Before First DB Save (#11939)
* fix: populate userMessage.files before first DB save

* fix: ESLint error fixed

* fix: deduplicate file-population logic and add test coverage

Extract `buildMessageFiles` helper into `packages/api/src/utils/message` to replace three near-identical loops in BaseClient and both agent controllers. Fixes set poisoning from undefined file_id entries, moves file population inside the skipSaveUserMessage guard to avoid wasted work, and adds full unit test coverage for the new behavior.

* chore: reorder import statements in openIdJwtStrategy.js for consistency

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
parent
13df8ed67c
commit
3a079b980a
6 changed files with 258 additions and 48 deletions
|
|
@ -4,6 +4,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
countTokens,
|
countTokens,
|
||||||
getBalanceConfig,
|
getBalanceConfig,
|
||||||
|
buildMessageFiles,
|
||||||
extractFileContext,
|
extractFileContext,
|
||||||
encodeAndFormatAudios,
|
encodeAndFormatAudios,
|
||||||
encodeAndFormatVideos,
|
encodeAndFormatVideos,
|
||||||
|
|
@ -670,6 +671,14 @@ class BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isEdited && !this.skipSaveUserMessage) {
|
if (!isEdited && !this.skipSaveUserMessage) {
|
||||||
|
const reqFiles = this.options.req?.body?.files;
|
||||||
|
if (reqFiles && Array.isArray(this.options.attachments)) {
|
||||||
|
const files = buildMessageFiles(reqFiles, this.options.attachments);
|
||||||
|
if (files.length > 0) {
|
||||||
|
userMessage.files = files;
|
||||||
|
}
|
||||||
|
delete userMessage.image_urls;
|
||||||
|
}
|
||||||
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||||
this.savedMessageIds.add(userMessage.messageId);
|
this.savedMessageIds.add(userMessage.messageId);
|
||||||
if (typeof opts?.getReqData === 'function') {
|
if (typeof opts?.getReqData === 'function') {
|
||||||
|
|
|
||||||
|
|
@ -928,4 +928,123 @@ describe('BaseClient', () => {
|
||||||
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
|
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Verifies how `sendMessage` fills `userMessage.files` from `options.req.body.files`
// and `options.attachments` before the user message is handed to `saveMessageToDatabase`.
describe('sendMessage file population', () => {
  // Attachment fixture; `text` and `_id` are the fields the tests expect to be absent after saving.
  const attachment = {
    file_id: 'file-abc',
    filename: 'image.png',
    filepath: '/uploads/image.png',
    type: 'image/png',
    bytes: 1024,
    object: 'file',
    user: 'user-1',
    embedded: false,
    usage: 0,
    text: 'large ocr blob that should be stripped',
    _id: 'mongo-id-1',
  };

  /** Swaps `saveMessageToDatabase` for a mock that resolves to an empty message payload. */
  const stubSave = () => {
    TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
  };

  /** Returns the argument list of the first recorded save whose message is user-created. */
  const findUserSave = () =>
    TestClient.saveMessageToDatabase.mock.calls.find(([msg]) => msg.isCreatedByUser);

  beforeEach(() => {
    TestClient.options.attachments = [attachment];
    TestClient.options.req = { body: { files: [{ file_id: 'file-abc' }] } };
  });

  test('populates userMessage.files before saveMessageToDatabase is called', async () => {
    // Echo the saved message back so the mock mirrors what it was given.
    TestClient.saveMessageToDatabase = jest.fn((msg) => Promise.resolve({ message: msg }));

    await TestClient.sendMessage('Hello');

    const userSave = findUserSave();
    expect(userSave).toBeDefined();

    const [savedMessage] = userSave;
    expect(savedMessage.files).toBeDefined();
    expect(savedMessage.files).toHaveLength(1);
    expect(savedMessage.files[0].file_id).toBe('file-abc');
  });

  test('strips text and _id from files before saving', async () => {
    stubSave();

    await TestClient.sendMessage('Hello');

    const [savedMessage] = findUserSave();
    const [savedFile] = savedMessage.files;
    expect(savedFile.text).toBeUndefined();
    expect(savedFile._id).toBeUndefined();
    expect(savedFile.filename).toBe('image.png');
  });

  test('deletes image_urls from userMessage when files are present', async () => {
    stubSave();
    const withImageUrls = { ...attachment, image_urls: ['data:image/png;base64,...'] };
    TestClient.options.attachments = [withImageUrls];

    await TestClient.sendMessage('Hello');

    const [savedMessage] = findUserSave();
    expect(savedMessage.image_urls).toBeUndefined();
  });

  test('does not set files when no attachments match request file IDs', async () => {
    TestClient.options.req = { body: { files: [{ file_id: 'file-nomatch' }] } };
    stubSave();

    await TestClient.sendMessage('Hello');

    const [savedMessage] = findUserSave();
    expect(savedMessage.files).toBeUndefined();
  });

  test('skips file population when attachments is not an array (Promise case)', async () => {
    // A still-pending attachments Promise is not an array, so nothing should be attached.
    TestClient.options.attachments = Promise.resolve([attachment]);
    stubSave();

    await TestClient.sendMessage('Hello');

    const [savedMessage] = findUserSave();
    expect(savedMessage.files).toBeUndefined();
  });

  test('skips file population when skipSaveUserMessage is true', async () => {
    TestClient.skipSaveUserMessage = true;
    stubSave();

    await TestClient.sendMessage('Hello');

    // No user-created message should have reached the save mock at all.
    const savedCalls = TestClient.saveMessageToDatabase.mock.calls;
    const userSave = savedCalls.find(([msg]) => msg?.isCreatedByUser);
    expect(userSave).toBeUndefined();
  });

  test('ignores file_id: undefined entries in req.body.files (no set poisoning)', async () => {
    const requestFiles = [{ file_id: undefined }, { file_id: 'file-abc' }];
    TestClient.options.req = { body: { files: requestFiles } };
    TestClient.options.attachments = [
      { ...attachment, file_id: undefined },
      { ...attachment, file_id: 'file-abc' },
    ];
    stubSave();

    await TestClient.sendMessage('Hello');

    const [savedMessage] = findUserSave();
    expect(savedMessage.files).toHaveLength(1);
    expect(savedMessage.files[0].file_id).toBe('file-abc');
  });
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ const { Constants, ViolationTypes } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
getViolationInfo,
|
getViolationInfo,
|
||||||
|
buildMessageFiles,
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
decrementPendingRequest,
|
decrementPendingRequest,
|
||||||
sanitizeFileForTransmit,
|
|
||||||
sanitizeMessageForTransmit,
|
sanitizeMessageForTransmit,
|
||||||
checkAndIncrementPendingRequest,
|
checkAndIncrementPendingRequest,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
|
|
@ -252,13 +252,10 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
conversation.title =
|
conversation.title =
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
if (req.body.files && client.options?.attachments) {
|
if (req.body.files && Array.isArray(client.options.attachments)) {
|
||||||
userMessage.files = [];
|
const files = buildMessageFiles(req.body.files, client.options.attachments);
|
||||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
if (files.length > 0) {
|
||||||
for (const attachment of client.options.attachments) {
|
userMessage.files = files;
|
||||||
if (messageFiles.has(attachment.file_id)) {
|
|
||||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete userMessage.image_urls;
|
delete userMessage.image_urls;
|
||||||
}
|
}
|
||||||
|
|
@ -639,14 +636,10 @@ const _LegacyAgentController = async (req, res, next, initializeClient, addTitle
|
||||||
conversation.title =
|
conversation.title =
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
// Process files if needed (sanitize to remove large text fields before transmission)
|
if (req.body.files && Array.isArray(client.options.attachments)) {
|
||||||
if (req.body.files && client.options?.attachments) {
|
const files = buildMessageFiles(req.body.files, client.options.attachments);
|
||||||
userMessage.files = [];
|
if (files.length > 0) {
|
||||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
userMessage.files = files;
|
||||||
for (const attachment of client.options.attachments) {
|
|
||||||
if (messageFiles.has(attachment.file_id)) {
|
|
||||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete userMessage.image_urls;
|
delete userMessage.image_urls;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ const { logger } = require('@librechat/data-schemas');
|
||||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||||
const { SystemRoles } = require('librechat-data-provider');
|
const { SystemRoles } = require('librechat-data-provider');
|
||||||
const { isEnabled, findOpenIDUser, math } = require('@librechat/api');
|
const { isEnabled, findOpenIDUser, math } = require('@librechat/api');
|
||||||
const { getOpenIdEmail } = require('./openidStrategy');
|
|
||||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||||
|
const { getOpenIdEmail } = require('./openidStrategy');
|
||||||
const { updateUser, findUser } = require('~/models');
|
const { updateUser, findUser } = require('~/models');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
import { Constants } from 'librechat-data-provider';
|
import { Constants } from 'librechat-data-provider';
|
||||||
import { sanitizeFileForTransmit, sanitizeMessageForTransmit, getThreadData } from './message';
|
import {
|
||||||
|
sanitizeMessageForTransmit,
|
||||||
|
sanitizeFileForTransmit,
|
||||||
|
buildMessageFiles,
|
||||||
|
getThreadData,
|
||||||
|
} from './message';
|
||||||
|
|
||||||
/** Cast to string for type compatibility with ThreadMessage */
|
/** Cast to string for type compatibility with ThreadMessage */
|
||||||
const NO_PARENT = Constants.NO_PARENT as string;
|
const NO_PARENT = Constants.NO_PARENT as string;
|
||||||
|
|
@ -125,47 +130,107 @@ describe('sanitizeMessageForTransmit', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
 * Unit tests for `buildMessageFiles`.
 *
 * The helper always returns an array (empty when nothing matches), so the
 * "no match" cases are titled and asserted as empty arrays rather than `undefined`.
 */
describe('buildMessageFiles', () => {
  /** Attachment fixture carrying the `text` and `_id` fields the tests expect to be stripped. */
  const baseAttachment = {
    file_id: 'file-1',
    filename: 'test.png',
    filepath: '/uploads/test.png',
    type: 'image/png',
    bytes: 512,
    object: 'file' as const,
    user: 'user-1',
    embedded: false,
    usage: 0,
    text: 'big ocr text',
    _id: 'mongo-id',
  };

  it('returns sanitized files matching request file IDs', () => {
    const result = buildMessageFiles([{ file_id: 'file-1' }], [baseAttachment]);
    expect(result).toHaveLength(1);
    expect(result[0].file_id).toBe('file-1');
    expect(result[0]).not.toHaveProperty('text');
    expect(result[0]).not.toHaveProperty('_id');
  });

  it('returns an empty array when no attachments match request IDs', () => {
    const result = buildMessageFiles([{ file_id: 'file-nomatch' }], [baseAttachment]);
    expect(result).toEqual([]);
  });

  it('returns an empty array for empty attachments array', () => {
    const result = buildMessageFiles([{ file_id: 'file-1' }], []);
    expect(result).toEqual([]);
  });

  it('returns an empty array for empty request files array', () => {
    const result = buildMessageFiles([], [baseAttachment]);
    expect(result).toEqual([]);
  });

  it('filters out undefined file_id entries in request files (no set poisoning)', () => {
    // Both sides carry an `undefined` id; they must not be matched against each other.
    const undefinedAttachment = { ...baseAttachment, file_id: undefined as unknown as string };
    const result = buildMessageFiles(
      [{ file_id: undefined }, { file_id: 'file-1' }],
      [undefinedAttachment, baseAttachment],
    );
    expect(result).toHaveLength(1);
    expect(result[0].file_id).toBe('file-1');
  });

  it('returns only attachments whose file_id is in the request set', () => {
    const attachment2 = { ...baseAttachment, file_id: 'file-2', filename: 'b.png' };
    const result = buildMessageFiles([{ file_id: 'file-1' }], [baseAttachment, attachment2]);
    expect(result).toHaveLength(1);
    expect(result[0].file_id).toBe('file-1');
  });

  it('does not mutate original attachment objects', () => {
    buildMessageFiles([{ file_id: 'file-1' }], [baseAttachment]);
    expect(baseAttachment.text).toBe('big ocr text');
    expect(baseAttachment._id).toBe('mongo-id');
  });
});
|
||||||
|
|
||||||
describe('getThreadData', () => {
|
describe('getThreadData', () => {
|
||||||
describe('edge cases - empty and null inputs', () => {
|
it('should return empty result for empty messages array', () => {
|
||||||
it('should return empty result for empty messages array', () => {
|
const result = getThreadData([], 'parent-123');
|
||||||
const result = getThreadData([], 'parent-123');
|
|
||||||
|
|
||||||
expect(result.messageIds).toEqual([]);
|
expect(result.messageIds).toEqual([]);
|
||||||
expect(result.fileIds).toEqual([]);
|
expect(result.fileIds).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return empty result for null parentMessageId', () => {
|
it('should return empty result for null parentMessageId', () => {
|
||||||
const messages = [
|
const messages = [
|
||||||
{ messageId: 'msg-1', parentMessageId: null },
|
{ messageId: 'msg-1', parentMessageId: null },
|
||||||
{ messageId: 'msg-2', parentMessageId: 'msg-1' },
|
{ messageId: 'msg-2', parentMessageId: 'msg-1' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const result = getThreadData(messages, null);
|
const result = getThreadData(messages, null);
|
||||||
|
|
||||||
expect(result.messageIds).toEqual([]);
|
expect(result.messageIds).toEqual([]);
|
||||||
expect(result.fileIds).toEqual([]);
|
expect(result.fileIds).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return empty result for undefined parentMessageId', () => {
|
it('should return empty result for undefined parentMessageId', () => {
|
||||||
const messages = [{ messageId: 'msg-1', parentMessageId: null }];
|
const messages = [{ messageId: 'msg-1', parentMessageId: null }];
|
||||||
|
|
||||||
const result = getThreadData(messages, undefined);
|
const result = getThreadData(messages, undefined);
|
||||||
|
|
||||||
expect(result.messageIds).toEqual([]);
|
expect(result.messageIds).toEqual([]);
|
||||||
expect(result.fileIds).toEqual([]);
|
expect(result.fileIds).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return empty result when parentMessageId not found in messages', () => {
|
it('should return empty result when parentMessageId not found in messages', () => {
|
||||||
const messages = [
|
const messages = [
|
||||||
{ messageId: 'msg-1', parentMessageId: null },
|
{ messageId: 'msg-1', parentMessageId: null },
|
||||||
{ messageId: 'msg-2', parentMessageId: 'msg-1' },
|
{ messageId: 'msg-2', parentMessageId: 'msg-1' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const result = getThreadData(messages, 'non-existent');
|
const result = getThreadData(messages, 'non-existent');
|
||||||
|
|
||||||
expect(result.messageIds).toEqual([]);
|
expect(result.messageIds).toEqual([]);
|
||||||
expect(result.fileIds).toEqual([]);
|
expect(result.fileIds).toEqual([]);
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('thread traversal', () => {
|
describe('thread traversal', () => {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
import { Constants } from 'librechat-data-provider';
|
import { Constants } from 'librechat-data-provider';
|
||||||
import type { TFile, TMessage } from 'librechat-data-provider';
|
import type { TFile, TMessage } from 'librechat-data-provider';
|
||||||
|
|
||||||
|
/** Minimal shape for request file entries (from `req.body.files`) */
|
||||||
|
type RequestFile = { file_id?: string };
|
||||||
|
|
||||||
/** Fields to strip from files before client transmission */
|
/** Fields to strip from files before client transmission */
|
||||||
const FILE_STRIP_FIELDS = ['text', '_id', '__v'] as const;
|
const FILE_STRIP_FIELDS = ['text', '_id', '__v'] as const;
|
||||||
|
|
||||||
|
|
@ -32,6 +35,27 @@ export function sanitizeFileForTransmit<T extends Partial<TFile>>(
|
||||||
return sanitized;
|
return sanitized;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Picks the attachments referenced by the request and sanitizes each one.
 *
 * @param requestFiles - File entries from the request; entries with a falsy `file_id` are ignored.
 * @param attachments - Candidate attachments; those with a null/undefined `file_id` never match.
 * @returns A new array holding the `sanitizeFileForTransmit` result for every matching
 *   attachment, in attachment order. Empty when nothing matches.
 */
export function buildMessageFiles<T extends Partial<TFile>>(
  requestFiles: RequestFile[],
  attachments: T[],
): Omit<T, (typeof FILE_STRIP_FIELDS)[number]>[] {
  // Collect only usable ids so a missing `file_id` can never act as a match key.
  const wantedIds = new Set<string>(
    requestFiles.flatMap((entry) => (entry.file_id ? [entry.file_id] : [])),
  );

  return attachments
    .filter((candidate) => candidate.file_id != null && wantedIds.has(candidate.file_id))
    .map((candidate) => sanitizeFileForTransmit(candidate));
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sanitizes a message object before transmitting to client.
|
* Sanitizes a message object before transmitting to client.
|
||||||
* Removes large fields like `fileContext` and strips `text` from embedded files.
|
* Removes large fields like `fileContext` and strips `text` from embedded files.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue