🏷️ fix: Add user ID to MCP tools cache keys (#10201)

* add user id to mcp tools cache key

* tests

* clean up redundant tests

* remove unused imports
This commit is contained in:
Federico Ruggi 2025-10-30 22:09:56 +01:00 committed by GitHub
parent 8f4705f683
commit ea45d0b9c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 52 additions and 30 deletions

View file

@ -448,7 +448,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
}
if (!availableTools) {
try {
availableTools = await getMCPServerTools(serverName);
availableTools = await getMCPServerTools(safeUser.id, serverName);
} catch (error) {
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
}

View file

@ -79,6 +79,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);
const userId = req.user?.id; // note: userId cannot be undefined at runtime
if (modelSpec?.mcpServers) {
for (const mcpServer of modelSpec.mcpServers) {
mcpServers.add(mcpServer);
@ -102,7 +103,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
if (addedServers.has(mcpServer)) {
continue;
}
const serverTools = await getMCPServerTools(mcpServer);
const serverTools = await getMCPServerTools(userId, mcpServer);
if (!serverTools) {
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
addedServers.add(mcpServer);

View file

@ -1931,7 +1931,7 @@ describe('models/Agent', () => {
});
// Mock getMCPServerTools to return tools for each server
getMCPServerTools.mockImplementation(async (server) => {
getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') {
return { tool1_mcp_server1: {} };
} else if (server === 'server2') {
@ -2125,7 +2125,7 @@ describe('models/Agent', () => {
getCachedTools.mockResolvedValue(availableTools);
// Mock getMCPServerTools to return all tools for server1
getMCPServerTools.mockImplementation(async (server) => {
getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') {
return availableTools; // All 100 tools belong to server1
}
@ -2674,7 +2674,7 @@ describe('models/Agent', () => {
});
// Mock getMCPServerTools to return only tools matching the server
getMCPServerTools.mockImplementation(async (server) => {
getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') {
// Only return tool that correctly matches server1 format
return { tool_mcp_server1: {} };

View file

@ -32,7 +32,7 @@ const getMCPTools = async (req, res) => {
const mcpServers = {};
const cachePromises = configuredServers.map((serverName) =>
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })),
getMCPServerTools(userId, serverName).then((tools) => ({ serverName, tools })),
);
const cacheResults = await Promise.all(cachePromises);
@ -52,7 +52,7 @@ const getMCPTools = async (req, res) => {
if (Object.keys(serverTools).length > 0) {
// Cache asynchronously without blocking
cacheMCPServerTools({ serverName, serverTools }).catch((err) =>
cacheMCPServerTools({ userId, serverName, serverTools }).catch((err) =>
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
);
}

View file

@ -47,6 +47,7 @@ jest.mock('~/models', () => ({
jest.mock('~/server/services/Config', () => ({
setCachedTools: jest.fn(),
getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
loadCustomConfig: jest.fn(),
}));

View file

@ -205,6 +205,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const tools = await userConnection.fetchTools();
await updateMCPServerTools({
userId: flowState.userId,
serverName,
tools,
});

View file

@ -0,0 +1,10 @@
const { ToolCacheKeys } = require('../getCachedTools');
describe('getCachedTools - Cache Isolation Security', () => {
describe('ToolCacheKeys.MCP_SERVER', () => {
it('should generate cache keys that include userId', () => {
const key = ToolCacheKeys.MCP_SERVER('user123', 'github');
expect(key).toBe('tools:mcp:user123:github');
});
});
});

View file

@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores');
const ToolCacheKeys = {
/** Global tools available to all users */
GLOBAL: 'tools:global',
/** MCP tools cached by server name */
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
/** MCP tools cached by user ID and server name */
MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`,
};
/**
* Retrieves available tools from cache
* @function getCachedTools
* @param {Object} options - Options for retrieving tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name to get cached tools for
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName } = options;
const { userId, serverName } = options;
// Return MCP server-specific tools if requested
if (serverName) {
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
if (serverName && userId) {
return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
}
// Default to global tools
@ -36,17 +37,18 @@ async function getCachedTools(options = {}) {
* @function setCachedTools
* @param {Object} tools - The tools object to cache
* @param {Object} options - Options for caching tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name for server-specific tools
* @param {number} [options.ttl] - Time to live in milliseconds
* @returns {Promise<boolean>} Whether the operation was successful
*/
async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, ttl } = options;
const { userId, serverName, ttl } = options;
// Cache by MCP server if specified
if (serverName) {
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
// Cache by MCP server if specified (requires userId)
if (serverName && userId) {
return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl);
}
// Default to global cache
@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) {
* Invalidates cached tools
* @function invalidateCachedTools
* @param {Object} options - Options for invalidating tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name to invalidate
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
* @returns {Promise<void>}
*/
async function invalidateCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, invalidateGlobal = false } = options;
const { userId, serverName, invalidateGlobal = false } = options;
const keysToDelete = [];
@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) {
keysToDelete.push(ToolCacheKeys.GLOBAL);
}
if (serverName) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
if (serverName && userId) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName));
}
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
}
/**
* Gets MCP tools for a specific server from cache or merges with global tools
* Gets MCP tools for a specific server from cache
* @function getMCPServerTools
* @param {string} userId - The user ID
* @param {string} serverName - The MCP server name
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
*/
async function getMCPServerTools(serverName) {
async function getMCPServerTools(userId, serverName) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
if (serverTools) {
return serverTools;

View file

@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache');
/**
* Updates MCP tools in the cache for a specific server
* @param {Object} params - Parameters for updating MCP tools
* @param {string} params.userId - User ID for user-specific caching
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPServerTools({ serverName, tools }) {
async function updateMCPServerTools({ userId, serverName, tools }) {
try {
const serverTools = {};
const mcpDelimiter = Constants.mcp_delimiter;
@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) {
};
}
await setCachedTools(serverTools, { serverName });
await setCachedTools(serverTools, { userId, serverName });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
logger.debug(
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
);
return serverTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error);
throw error;
}
}
@ -65,21 +68,22 @@ async function mergeAppTools(appTools) {
/**
* Caches MCP server tools (no longer merges with global)
* @param {object} params
* @param {string} params.userId - User ID for user-specific caching
* @param {string} params.serverName
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
* @returns {Promise<void>}
*/
async function cacheMCPServerTools({ serverName, serverTools }) {
async function cacheMCPServerTools({ userId, serverName, serverTools }) {
try {
const count = Object.keys(serverTools).length;
if (!count) {
return;
}
// Only cache server-specific tools, no merging with global
await setCachedTools(serverTools, { serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
await setCachedTools(serverTools, { userId, serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`);
} catch (error) {
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error);
throw error;
}
}

View file

@ -98,6 +98,7 @@ async function reinitMCPServer({
if (connection && !oauthRequired) {
tools = await connection.fetchTools();
availableTools = await updateMCPServerTools({
userId: user.id,
serverName,
tools,
});