🏷️ fix: Add user ID to MCP tools cache keys (#10201)

* add user id to mcp tools cache key

* tests

* clean up redundant tests

* remove unused imports
This commit is contained in:
Federico Ruggi 2025-10-30 22:09:56 +01:00 committed by GitHub
parent 8f4705f683
commit ea45d0b9c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 52 additions and 30 deletions

View file

@ -448,7 +448,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
} }
if (!availableTools) { if (!availableTools) {
try { try {
availableTools = await getMCPServerTools(serverName); availableTools = await getMCPServerTools(safeUser.id, serverName);
} catch (error) { } catch (error) {
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error); logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
} }

View file

@ -79,6 +79,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
/** @type {TEphemeralAgent | null} */ /** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent; const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp); const mcpServers = new Set(ephemeralAgent?.mcp);
const userId = req.user?.id; // note: userId cannot be undefined at runtime
if (modelSpec?.mcpServers) { if (modelSpec?.mcpServers) {
for (const mcpServer of modelSpec.mcpServers) { for (const mcpServer of modelSpec.mcpServers) {
mcpServers.add(mcpServer); mcpServers.add(mcpServer);
@ -102,7 +103,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
if (addedServers.has(mcpServer)) { if (addedServers.has(mcpServer)) {
continue; continue;
} }
const serverTools = await getMCPServerTools(mcpServer); const serverTools = await getMCPServerTools(userId, mcpServer);
if (!serverTools) { if (!serverTools) {
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`); tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
addedServers.add(mcpServer); addedServers.add(mcpServer);

View file

@ -1931,7 +1931,7 @@ describe('models/Agent', () => {
}); });
// Mock getMCPServerTools to return tools for each server // Mock getMCPServerTools to return tools for each server
getMCPServerTools.mockImplementation(async (server) => { getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') { if (server === 'server1') {
return { tool1_mcp_server1: {} }; return { tool1_mcp_server1: {} };
} else if (server === 'server2') { } else if (server === 'server2') {
@ -2125,7 +2125,7 @@ describe('models/Agent', () => {
getCachedTools.mockResolvedValue(availableTools); getCachedTools.mockResolvedValue(availableTools);
// Mock getMCPServerTools to return all tools for server1 // Mock getMCPServerTools to return all tools for server1
getMCPServerTools.mockImplementation(async (server) => { getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') { if (server === 'server1') {
return availableTools; // All 100 tools belong to server1 return availableTools; // All 100 tools belong to server1
} }
@ -2674,7 +2674,7 @@ describe('models/Agent', () => {
}); });
// Mock getMCPServerTools to return only tools matching the server // Mock getMCPServerTools to return only tools matching the server
getMCPServerTools.mockImplementation(async (server) => { getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') { if (server === 'server1') {
// Only return tool that correctly matches server1 format // Only return tool that correctly matches server1 format
return { tool_mcp_server1: {} }; return { tool_mcp_server1: {} };

View file

@ -32,7 +32,7 @@ const getMCPTools = async (req, res) => {
const mcpServers = {}; const mcpServers = {};
const cachePromises = configuredServers.map((serverName) => const cachePromises = configuredServers.map((serverName) =>
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })), getMCPServerTools(userId, serverName).then((tools) => ({ serverName, tools })),
); );
const cacheResults = await Promise.all(cachePromises); const cacheResults = await Promise.all(cachePromises);
@ -52,7 +52,7 @@ const getMCPTools = async (req, res) => {
if (Object.keys(serverTools).length > 0) { if (Object.keys(serverTools).length > 0) {
// Cache asynchronously without blocking // Cache asynchronously without blocking
cacheMCPServerTools({ serverName, serverTools }).catch((err) => cacheMCPServerTools({ userId, serverName, serverTools }).catch((err) =>
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err), logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
); );
} }

View file

@ -47,6 +47,7 @@ jest.mock('~/models', () => ({
jest.mock('~/server/services/Config', () => ({ jest.mock('~/server/services/Config', () => ({
setCachedTools: jest.fn(), setCachedTools: jest.fn(),
getCachedTools: jest.fn(), getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
loadCustomConfig: jest.fn(), loadCustomConfig: jest.fn(),
})); }));

View file

@ -205,6 +205,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const tools = await userConnection.fetchTools(); const tools = await userConnection.fetchTools();
await updateMCPServerTools({ await updateMCPServerTools({
userId: flowState.userId,
serverName, serverName,
tools, tools,
}); });

View file

@ -0,0 +1,10 @@
const { ToolCacheKeys } = require('../getCachedTools');
describe('getCachedTools - Cache Isolation Security', () => {
describe('ToolCacheKeys.MCP_SERVER', () => {
it('should generate cache keys that include userId', () => {
const key = ToolCacheKeys.MCP_SERVER('user123', 'github');
expect(key).toBe('tools:mcp:user123:github');
});
});
});

View file

@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores');
const ToolCacheKeys = { const ToolCacheKeys = {
/** Global tools available to all users */ /** Global tools available to all users */
GLOBAL: 'tools:global', GLOBAL: 'tools:global',
/** MCP tools cached by server name */ /** MCP tools cached by user ID and server name */
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`, MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`,
}; };
/** /**
* Retrieves available tools from cache * Retrieves available tools from cache
* @function getCachedTools * @function getCachedTools
* @param {Object} options - Options for retrieving tools * @param {Object} options - Options for retrieving tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name to get cached tools for * @param {string} [options.serverName] - MCP server name to get cached tools for
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached * @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/ */
async function getCachedTools(options = {}) { async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName } = options; const { userId, serverName } = options;
// Return MCP server-specific tools if requested // Return MCP server-specific tools if requested
if (serverName) { if (serverName && userId) {
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName)); return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
} }
// Default to global tools // Default to global tools
@ -36,17 +37,18 @@ async function getCachedTools(options = {}) {
* @function setCachedTools * @function setCachedTools
* @param {Object} tools - The tools object to cache * @param {Object} tools - The tools object to cache
* @param {Object} options - Options for caching tools * @param {Object} options - Options for caching tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name for server-specific tools * @param {string} [options.serverName] - MCP server name for server-specific tools
* @param {number} [options.ttl] - Time to live in milliseconds * @param {number} [options.ttl] - Time to live in milliseconds
* @returns {Promise<boolean>} Whether the operation was successful * @returns {Promise<boolean>} Whether the operation was successful
*/ */
async function setCachedTools(tools, options = {}) { async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, ttl } = options; const { userId, serverName, ttl } = options;
// Cache by MCP server if specified // Cache by MCP server if specified (requires userId)
if (serverName) { if (serverName && userId) {
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl); return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl);
} }
// Default to global cache // Default to global cache
@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) {
* Invalidates cached tools * Invalidates cached tools
* @function invalidateCachedTools * @function invalidateCachedTools
* @param {Object} options - Options for invalidating tools * @param {Object} options - Options for invalidating tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name to invalidate * @param {string} [options.serverName] - MCP server name to invalidate
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools * @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
async function invalidateCachedTools(options = {}) { async function invalidateCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, invalidateGlobal = false } = options; const { userId, serverName, invalidateGlobal = false } = options;
const keysToDelete = []; const keysToDelete = [];
@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) {
keysToDelete.push(ToolCacheKeys.GLOBAL); keysToDelete.push(ToolCacheKeys.GLOBAL);
} }
if (serverName) { if (serverName && userId) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName)); keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName));
} }
await Promise.all(keysToDelete.map((key) => cache.delete(key))); await Promise.all(keysToDelete.map((key) => cache.delete(key)));
} }
/** /**
* Gets MCP tools for a specific server from cache or merges with global tools * Gets MCP tools for a specific server from cache
* @function getMCPServerTools * @function getMCPServerTools
* @param {string} userId - The user ID
* @param {string} serverName - The MCP server name * @param {string} serverName - The MCP server name
* @returns {Promise<LCAvailableTools|null>} The available tools for the server * @returns {Promise<LCAvailableTools|null>} The available tools for the server
*/ */
async function getMCPServerTools(serverName) { async function getMCPServerTools(userId, serverName) {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName)); const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
if (serverTools) { if (serverTools) {
return serverTools; return serverTools;

View file

@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache');
/** /**
* Updates MCP tools in the cache for a specific server * Updates MCP tools in the cache for a specific server
* @param {Object} params - Parameters for updating MCP tools * @param {Object} params - Parameters for updating MCP tools
* @param {string} params.userId - User ID for user-specific caching
* @param {string} params.serverName - MCP server name * @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server * @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>} * @returns {Promise<LCAvailableTools>}
*/ */
async function updateMCPServerTools({ serverName, tools }) { async function updateMCPServerTools({ userId, serverName, tools }) {
try { try {
const serverTools = {}; const serverTools = {};
const mcpDelimiter = Constants.mcp_delimiter; const mcpDelimiter = Constants.mcp_delimiter;
@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) {
}; };
} }
await setCachedTools(serverTools, { serverName }); await setCachedTools(serverTools, { userId, serverName });
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS); await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`); logger.debug(
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
);
return serverTools; return serverTools;
} catch (error) { } catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error); logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error);
throw error; throw error;
} }
} }
@ -65,21 +68,22 @@ async function mergeAppTools(appTools) {
/** /**
* Caches MCP server tools (no longer merges with global) * Caches MCP server tools (no longer merges with global)
* @param {object} params * @param {object} params
* @param {string} params.userId - User ID for user-specific caching
* @param {string} params.serverName * @param {string} params.serverName
* @param {import('@librechat/api').LCAvailableTools} params.serverTools * @param {import('@librechat/api').LCAvailableTools} params.serverTools
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
async function cacheMCPServerTools({ serverName, serverTools }) { async function cacheMCPServerTools({ userId, serverName, serverTools }) {
try { try {
const count = Object.keys(serverTools).length; const count = Object.keys(serverTools).length;
if (!count) { if (!count) {
return; return;
} }
// Only cache server-specific tools, no merging with global // Only cache server-specific tools, no merging with global
await setCachedTools(serverTools, { serverName }); await setCachedTools(serverTools, { userId, serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName}`); logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`);
} catch (error) { } catch (error) {
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error); logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error);
throw error; throw error;
} }
} }

View file

@ -98,6 +98,7 @@ async function reinitMCPServer({
if (connection && !oauthRequired) { if (connection && !oauthRequired) {
tools = await connection.fetchTools(); tools = await connection.fetchTools();
availableTools = await updateMCPServerTools({ availableTools = await updateMCPServerTools({
userId: user.id,
serverName, serverName,
tools, tools,
}); });