diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 698014cbe0..e32ca6bc44 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -448,7 +448,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} } if (!availableTools) { try { - availableTools = await getMCPServerTools(serverName); + availableTools = await getMCPServerTools(safeUser.id, serverName); } catch (error) { logger.error(`Error fetching available tools for MCP server ${serverName}:`, error); } diff --git a/api/models/Agent.js b/api/models/Agent.js index f5f740ba7b..b802ca187b 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -79,6 +79,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet /** @type {TEphemeralAgent | null} */ const ephemeralAgent = req.body.ephemeralAgent; const mcpServers = new Set(ephemeralAgent?.mcp); + const userId = req.user?.id; // note: userId cannot be undefined at runtime if (modelSpec?.mcpServers) { for (const mcpServer of modelSpec.mcpServers) { mcpServers.add(mcpServer); @@ -102,7 +103,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet if (addedServers.has(mcpServer)) { continue; } - const serverTools = await getMCPServerTools(mcpServer); + const serverTools = await getMCPServerTools(userId, mcpServer); if (!serverTools) { tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`); addedServers.add(mcpServer); diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index f95db65013..6c7db6121e 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -1931,7 +1931,7 @@ describe('models/Agent', () => { }); // Mock getMCPServerTools to return tools for each server - getMCPServerTools.mockImplementation(async (server) => { + getMCPServerTools.mockImplementation(async (_userId, server) => { if (server === 'server1') { return { tool1_mcp_server1: {} }; } else if (server === 'server2') { @@ -2125,7 +2125,7 @@ describe('models/Agent', () => { getCachedTools.mockResolvedValue(availableTools); // Mock getMCPServerTools to return all tools for server1 - getMCPServerTools.mockImplementation(async (server) => { + getMCPServerTools.mockImplementation(async (_userId, server) => { if (server === 'server1') { return availableTools; // All 100 tools belong to server1 } @@ -2674,7 +2674,7 @@ describe('models/Agent', () => { }); // Mock getMCPServerTools to return only tools matching the server - getMCPServerTools.mockImplementation(async (server) => { + getMCPServerTools.mockImplementation(async (_userId, server) => { if (server === 'server1') { // Only return tool that correctly matches server1 format return { tool_mcp_server1: {} }; diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 839d9bd17b..9e520d392e 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -32,7 +32,7 @@ const getMCPTools = async (req, res) => { const mcpServers = {}; const cachePromises = configuredServers.map((serverName) => - getMCPServerTools(serverName).then((tools) => ({ serverName, tools })), + getMCPServerTools(userId, serverName).then((tools) => ({ serverName, tools })), ); const cacheResults = await Promise.all(cachePromises); @@ -52,7 +52,7 @@ const getMCPTools = async (req, res) => { if (Object.keys(serverTools).length > 0) { // Cache asynchronously without blocking - cacheMCPServerTools({ serverName, serverTools }).catch((err) => + cacheMCPServerTools({ userId, serverName, serverTools }).catch((err) => logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err), ); } diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 64c95c58ee..8ae92cdd3d 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -47,6 +47,7 @@ jest.mock('~/models', () => ({ jest.mock('~/server/services/Config', () => ({ setCachedTools: jest.fn(), getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), loadCustomConfig: jest.fn(), })); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index e8415fd801..9b66b10e52 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -205,6 +205,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const tools = await userConnection.fetchTools(); await updateMCPServerTools({ + userId: flowState.userId, serverName, tools, }); diff --git a/api/server/services/Config/__tests__/getCachedTools.spec.js b/api/server/services/Config/__tests__/getCachedTools.spec.js new file mode 100644 index 0000000000..48ab6e0737 --- /dev/null +++ b/api/server/services/Config/__tests__/getCachedTools.spec.js @@ -0,0 +1,10 @@ +const { ToolCacheKeys } = require('../getCachedTools'); + +describe('getCachedTools - Cache Isolation Security', () => { + describe('ToolCacheKeys.MCP_SERVER', () => { + it('should generate cache keys that include userId', () => { + const key = ToolCacheKeys.MCP_SERVER('user123', 'github'); + expect(key).toBe('tools:mcp:user123:github'); + }); + }); +}); diff --git a/api/server/services/Config/getCachedTools.js b/api/server/services/Config/getCachedTools.js index 59a0c8cc5d..841ca04c94 100644 --- a/api/server/services/Config/getCachedTools.js +++ b/api/server/services/Config/getCachedTools.js @@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores'); const ToolCacheKeys = { /** Global tools available to all users */ GLOBAL: 'tools:global', - /** MCP tools cached by server name */ - MCP_SERVER: (serverName) => `tools:mcp:${serverName}`, + /** MCP tools cached by user ID and server name */ + MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`, }; /** * Retrieves available tools from cache * @function getCachedTools * @param {Object} options - Options for retrieving tools + * @param {string} [options.userId] - User ID for user-specific MCP tools * @param {string} [options.serverName] - MCP server name to get cached tools for * @returns {Promise} The available tools object or null if not cached */ async function getCachedTools(options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const { serverName } = options; + const { userId, serverName } = options; // Return MCP server-specific tools if requested - if (serverName) { - return await cache.get(ToolCacheKeys.MCP_SERVER(serverName)); + if (serverName && userId) { + return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName)); } // Default to global tools @@ -36,17 +37,18 @@ async function getCachedTools(options = {}) { * @function setCachedTools * @param {Object} tools - The tools object to cache * @param {Object} options - Options for caching tools + * @param {string} [options.userId] - User ID for user-specific MCP tools * @param {string} [options.serverName] - MCP server name for server-specific tools * @param {number} [options.ttl] - Time to live in milliseconds * @returns {Promise} Whether the operation was successful */ async function setCachedTools(tools, options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const { serverName, ttl } = options; + const { userId, serverName, ttl } = options; - // Cache by MCP server if specified - if (serverName) { - return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl); + // Cache by MCP server if specified (requires userId) + if (serverName && userId) { + return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl); } // Default to global cache @@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) { * Invalidates cached tools * @function invalidateCachedTools * @param {Object} options - Options for invalidating tools + * @param {string} [options.userId] - User ID for user-specific MCP tools * @param {string} [options.serverName] - MCP server name to invalidate * @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools * @returns {Promise} */ async function invalidateCachedTools(options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const { serverName, invalidateGlobal = false } = options; + const { userId, serverName, invalidateGlobal = false } = options; const keysToDelete = []; @@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) { keysToDelete.push(ToolCacheKeys.GLOBAL); } - if (serverName) { - keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName)); + if (serverName && userId) { + keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName)); } await Promise.all(keysToDelete.map((key) => cache.delete(key))); } /** - * Gets MCP tools for a specific server from cache or merges with global tools + * Gets MCP tools for a specific server from cache * @function getMCPServerTools + * @param {string} userId - The user ID * @param {string} serverName - The MCP server name * @returns {Promise} The available tools for the server */ -async function getMCPServerTools(serverName) { +async function getMCPServerTools(userId, serverName) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName)); + const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName)); if (serverTools) { return serverTools; diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index 75824d1b30..7f4210f8c9 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache'); /** * Updates MCP tools in the cache for a specific server * @param {Object} params - Parameters for updating MCP tools + * @param {string} params.userId - User ID for user-specific caching * @param {string} params.serverName - MCP server name * @param {Array} params.tools - Array of tool objects from MCP server * @returns {Promise} */ -async function updateMCPServerTools({ serverName, tools }) { +async function updateMCPServerTools({ userId, serverName, tools }) { try { const serverTools = {}; const mcpDelimiter = Constants.mcp_delimiter; @@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) { }; } - await setCachedTools(serverTools, { serverName }); + await setCachedTools(serverTools, { userId, serverName }); const cache = getLogStores(CacheKeys.CONFIG_STORE); await cache.delete(CacheKeys.TOOLS); - logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`); + logger.debug( + `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, + ); return serverTools; } catch (error) { - logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error); + logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error); throw error; } } @@ -65,21 +68,22 @@ async function mergeAppTools(appTools) { /** * Caches MCP server tools (no longer merges with global) * @param {object} params + * @param {string} params.userId - User ID for user-specific caching * @param {string} params.serverName * @param {import('@librechat/api').LCAvailableTools} params.serverTools * @returns {Promise} */ -async function cacheMCPServerTools({ serverName, serverTools }) { +async function cacheMCPServerTools({ userId, serverName, serverTools }) { try { const count = Object.keys(serverTools).length; if (!count) { return; } // Only cache server-specific tools, no merging with global - await setCachedTools(serverTools, { serverName }); - logger.debug(`Cached ${count} MCP server tools for ${serverName}`); + await setCachedTools(serverTools, { userId, serverName }); + logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`); } catch (error) { - logger.error(`Failed to cache MCP server tools for ${serverName}:`, error); + logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error); throw error; } } diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index e6d293800d..521560aad4 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -98,6 +98,7 @@ async function reinitMCPServer({ if (connection && !oauthRequired) { tools = await connection.fetchTools(); availableTools = await updateMCPServerTools({ + userId: user.id, serverName, tools, });