diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index fab82db93b..e5771aac55 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -4,6 +4,7 @@ const { logger } = require('@librechat/data-schemas'); const { countTokens, getBalanceConfig, + buildMessageFiles, extractFileContext, encodeAndFormatAudios, encodeAndFormatVideos, @@ -670,6 +671,14 @@ class BaseClient { } if (!isEdited && !this.skipSaveUserMessage) { + const reqFiles = this.options.req?.body?.files; + if (reqFiles && Array.isArray(this.options.attachments)) { + const files = buildMessageFiles(reqFiles, this.options.attachments); + if (files.length > 0) { + userMessage.files = files; + } + delete userMessage.image_urls; + } userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); this.savedMessageIds.add(userMessage.messageId); if (typeof opts?.getReqData === 'function') { diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index fed80de28c..15328af644 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -928,4 +928,123 @@ describe('BaseClient', () => { expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label) }); }); + + describe('sendMessage file population', () => { + const attachment = { + file_id: 'file-abc', + filename: 'image.png', + filepath: '/uploads/image.png', + type: 'image/png', + bytes: 1024, + object: 'file', + user: 'user-1', + embedded: false, + usage: 0, + text: 'large ocr blob that should be stripped', + _id: 'mongo-id-1', + }; + + beforeEach(() => { + TestClient.options.req = { body: { files: [{ file_id: 'file-abc' }] } }; + TestClient.options.attachments = [attachment]; + }); + + test('populates userMessage.files before saveMessageToDatabase is called', async () => { + TestClient.saveMessageToDatabase = jest.fn().mockImplementation((msg) => { + return Promise.resolve({ message: msg }); + }); + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg.isCreatedByUser, + ); + expect(userSave).toBeDefined(); + expect(userSave[0].files).toBeDefined(); + expect(userSave[0].files).toHaveLength(1); + expect(userSave[0].files[0].file_id).toBe('file-abc'); + }); + + test('strips text and _id from files before saving', async () => { + TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} }); + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg.isCreatedByUser, + ); + expect(userSave[0].files[0].text).toBeUndefined(); + expect(userSave[0].files[0]._id).toBeUndefined(); + expect(userSave[0].files[0].filename).toBe('image.png'); + }); + + test('deletes image_urls from userMessage when files are present', async () => { + TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} }); + TestClient.options.attachments = [ + { ...attachment, image_urls: ['data:image/png;base64,...'] }, + ]; + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg.isCreatedByUser, + ); + expect(userSave[0].image_urls).toBeUndefined(); + }); + + test('does not set files when no attachments match request file IDs', async () => { + TestClient.options.req = { body: { files: [{ file_id: 'file-nomatch' }] } }; + TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} }); + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg.isCreatedByUser, + ); + expect(userSave[0].files).toBeUndefined(); + }); + + test('skips file population when attachments is not an array (Promise case)', async () => { + TestClient.options.attachments = Promise.resolve([attachment]); + TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} }); + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg.isCreatedByUser, + ); + expect(userSave[0].files).toBeUndefined(); + }); + + test('skips file population when skipSaveUserMessage is true', async () => { + TestClient.skipSaveUserMessage = true; + TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} }); + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg?.isCreatedByUser, + ); + expect(userSave).toBeUndefined(); + }); + + test('ignores file_id: undefined entries in req.body.files (no set poisoning)', async () => { + TestClient.options.req = { + body: { files: [{ file_id: undefined }, { file_id: 'file-abc' }] }, + }; + TestClient.options.attachments = [ + { ...attachment, file_id: undefined }, + { ...attachment, file_id: 'file-abc' }, + ]; + TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} }); + + await TestClient.sendMessage('Hello'); + + const userSave = TestClient.saveMessageToDatabase.mock.calls.find( + ([msg]) => msg.isCreatedByUser, + ); + expect(userSave[0].files).toHaveLength(1); + expect(userSave[0].files[0].file_id).toBe('file-abc'); + }); + }); }); diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 79387b6e89..dea5400036 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -3,9 +3,9 @@ const { Constants, ViolationTypes } = require('librechat-data-provider'); const { sendEvent, getViolationInfo, + buildMessageFiles, GenerationJobManager, decrementPendingRequest, - sanitizeFileForTransmit, sanitizeMessageForTransmit, checkAndIncrementPendingRequest, } = require('@librechat/api'); @@ -252,13 +252,10 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit conversation.title = conversation && !conversation.title ? null : conversation?.title || 'New Chat'; - if (req.body.files && client.options?.attachments) { - userMessage.files = []; - const messageFiles = new Set(req.body.files.map((file) => file.file_id)); - for (const attachment of client.options.attachments) { - if (messageFiles.has(attachment.file_id)) { - userMessage.files.push(sanitizeFileForTransmit(attachment)); - } + if (req.body.files && Array.isArray(client.options.attachments)) { + const files = buildMessageFiles(req.body.files, client.options.attachments); + if (files.length > 0) { + userMessage.files = files; } delete userMessage.image_urls; } @@ -639,14 +636,10 @@ const _LegacyAgentController = async (req, res, next, initializeClient, addTitle conversation.title = conversation && !conversation.title ? null : conversation?.title || 'New Chat'; - // Process files if needed (sanitize to remove large text fields before transmission) - if (req.body.files && client.options?.attachments) { - userMessage.files = []; - const messageFiles = new Set(req.body.files.map((file) => file.file_id)); - for (const attachment of client.options.attachments) { - if (messageFiles.has(attachment.file_id)) { - userMessage.files.push(sanitizeFileForTransmit(attachment)); - } + if (req.body.files && Array.isArray(client.options.attachments)) { + const files = buildMessageFiles(req.body.files, client.options.attachments); + if (files.length > 0) { + userMessage.files = files; } delete userMessage.image_urls; } diff --git a/api/strategies/openIdJwtStrategy.js b/api/strategies/openIdJwtStrategy.js index ececf8df54..83a40bf948 100644 --- a/api/strategies/openIdJwtStrategy.js +++ b/api/strategies/openIdJwtStrategy.js @@ -4,8 +4,8 @@ const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { SystemRoles } = require('librechat-data-provider'); const { isEnabled, findOpenIDUser, math } = require('@librechat/api'); -const { getOpenIdEmail } = require('./openidStrategy'); const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt'); +const { getOpenIdEmail } = require('./openidStrategy'); const { updateUser, findUser } = require('~/models'); /** diff --git a/packages/api/src/utils/message.spec.ts b/packages/api/src/utils/message.spec.ts index ba626c83fd..7fe6cf5239 100644 --- a/packages/api/src/utils/message.spec.ts +++ b/packages/api/src/utils/message.spec.ts @@ -1,5 +1,10 @@ import { Constants } from 'librechat-data-provider'; -import { sanitizeFileForTransmit, sanitizeMessageForTransmit, getThreadData } from './message'; +import { + sanitizeMessageForTransmit, + sanitizeFileForTransmit, + buildMessageFiles, + getThreadData, +} from './message'; /** Cast to string for type compatibility with ThreadMessage */ const NO_PARENT = Constants.NO_PARENT as string; @@ -125,47 +130,107 @@ describe('sanitizeMessageForTransmit', () => { }); }); +describe('buildMessageFiles', () => { + const baseAttachment = { + file_id: 'file-1', + filename: 'test.png', + filepath: '/uploads/test.png', + type: 'image/png', + bytes: 512, + object: 'file' as const, + user: 'user-1', + embedded: false, + usage: 0, + text: 'big ocr text', + _id: 'mongo-id', + }; + + it('returns sanitized files matching request file IDs', () => { + const result = buildMessageFiles([{ file_id: 'file-1' }], [baseAttachment]); + expect(result).toHaveLength(1); + expect(result?.[0].file_id).toBe('file-1'); + expect(result?.[0]).not.toHaveProperty('text'); + expect(result?.[0]).not.toHaveProperty('_id'); + }); + + it('returns undefined when no attachments match request IDs', () => { + const result = buildMessageFiles([{ file_id: 'file-nomatch' }], [baseAttachment]); + expect(result).toEqual([]); + }); + + it('returns undefined for empty attachments array', () => { + const result = buildMessageFiles([{ file_id: 'file-1' }], []); + expect(result).toEqual([]); + }); + + it('returns undefined for empty request files array', () => { + const result = buildMessageFiles([], [baseAttachment]); + expect(result).toEqual([]); + }); + + it('filters out undefined file_id entries in request files (no set poisoning)', () => { + const undefinedAttachment = { ...baseAttachment, file_id: undefined as unknown as string }; + const result = buildMessageFiles( + [{ file_id: undefined }, { file_id: 'file-1' }], + [undefinedAttachment, baseAttachment], + ); + expect(result).toHaveLength(1); + expect(result?.[0].file_id).toBe('file-1'); + }); + + it('returns only attachments whose file_id is in the request set', () => { + const attachment2 = { ...baseAttachment, file_id: 'file-2', filename: 'b.png' }; + const result = buildMessageFiles([{ file_id: 'file-1' }], [baseAttachment, attachment2]); + expect(result).toHaveLength(1); + expect(result?.[0].file_id).toBe('file-1'); + }); + + it('does not mutate original attachment objects', () => { + buildMessageFiles([{ file_id: 'file-1' }], [baseAttachment]); + expect(baseAttachment.text).toBe('big ocr text'); + expect(baseAttachment._id).toBe('mongo-id'); + }); +}); + describe('getThreadData', () => { - describe('edge cases - empty and null inputs', () => { - it('should return empty result for empty messages array', () => { - const result = getThreadData([], 'parent-123'); + it('should return empty result for empty messages array', () => { + const result = getThreadData([], 'parent-123'); - expect(result.messageIds).toEqual([]); - expect(result.fileIds).toEqual([]); - }); + expect(result.messageIds).toEqual([]); + expect(result.fileIds).toEqual([]); + }); - it('should return empty result for null parentMessageId', () => { - const messages = [ - { messageId: 'msg-1', parentMessageId: null }, - { messageId: 'msg-2', parentMessageId: 'msg-1' }, - ]; + it('should return empty result for null parentMessageId', () => { + const messages = [ + { messageId: 'msg-1', parentMessageId: null }, + { messageId: 'msg-2', parentMessageId: 'msg-1' }, + ]; - const result = getThreadData(messages, null); + const result = getThreadData(messages, null); - expect(result.messageIds).toEqual([]); - expect(result.fileIds).toEqual([]); - }); + expect(result.messageIds).toEqual([]); + expect(result.fileIds).toEqual([]); + }); - it('should return empty result for undefined parentMessageId', () => { - const messages = [{ messageId: 'msg-1', parentMessageId: null }]; + it('should return empty result for undefined parentMessageId', () => { + const messages = [{ messageId: 'msg-1', parentMessageId: null }]; - const result = getThreadData(messages, undefined); + const result = getThreadData(messages, undefined); - expect(result.messageIds).toEqual([]); - expect(result.fileIds).toEqual([]); - }); + expect(result.messageIds).toEqual([]); + expect(result.fileIds).toEqual([]); + }); - it('should return empty result when parentMessageId not found in messages', () => { - const messages = [ - { messageId: 'msg-1', parentMessageId: null }, - { messageId: 'msg-2', parentMessageId: 'msg-1' }, - ]; + it('should return empty result when parentMessageId not found in messages', () => { + const messages = [ + { messageId: 'msg-1', parentMessageId: null }, + { messageId: 'msg-2', parentMessageId: 'msg-1' }, + ]; - const result = getThreadData(messages, 'non-existent'); + const result = getThreadData(messages, 'non-existent'); - expect(result.messageIds).toEqual([]); - expect(result.fileIds).toEqual([]); - }); + expect(result.messageIds).toEqual([]); + expect(result.fileIds).toEqual([]); }); describe('thread traversal', () => { diff --git a/packages/api/src/utils/message.ts b/packages/api/src/utils/message.ts index b1e939c6d7..719d04b838 100644 --- a/packages/api/src/utils/message.ts +++ b/packages/api/src/utils/message.ts @@ -1,6 +1,9 @@ import { Constants } from 'librechat-data-provider'; import type { TFile, TMessage } from 'librechat-data-provider'; +/** Minimal shape for request file entries (from `req.body.files`) */ +type RequestFile = { file_id?: string }; + /** Fields to strip from files before client transmission */ const FILE_STRIP_FIELDS = ['text', '_id', '__v'] as const; @@ -32,6 +35,27 @@ export function sanitizeFileForTransmit>( return sanitized; } +/** Filters attachments to those whose `file_id` appears in `requestFiles`, then sanitizes each. */ +export function buildMessageFiles>( + requestFiles: RequestFile[], + attachments: T[], +): Omit[] { + const requestFileIds = new Set(); + for (const f of requestFiles) { + if (f.file_id) { + requestFileIds.add(f.file_id); + } + } + + const files: Omit[] = []; + for (const attachment of attachments) { + if (attachment.file_id != null && requestFileIds.has(attachment.file_id)) { + files.push(sanitizeFileForTransmit(attachment)); + } + } + return files; +} + /** * Sanitizes a message object before transmitting to client. * Removes large fields like `fileContext` and strips `text` from embedded files.