mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-16 16:30:15 +01:00
🏷️ fix: Add user ID to MCP tools cache keys (#10201)
* add user id to mcp tools cache key * tests * clean up redundant tests * remove unused imports
This commit is contained in:
parent
8f4705f683
commit
ea45d0b9c6
10 changed files with 52 additions and 30 deletions
|
|
@ -448,7 +448,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||||
}
|
}
|
||||||
if (!availableTools) {
|
if (!availableTools) {
|
||||||
try {
|
try {
|
||||||
availableTools = await getMCPServerTools(serverName);
|
availableTools = await getMCPServerTools(safeUser.id, serverName);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
|
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
|
||||||
/** @type {TEphemeralAgent | null} */
|
/** @type {TEphemeralAgent | null} */
|
||||||
const ephemeralAgent = req.body.ephemeralAgent;
|
const ephemeralAgent = req.body.ephemeralAgent;
|
||||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||||
|
const userId = req.user?.id; // note: userId cannot be undefined at runtime
|
||||||
if (modelSpec?.mcpServers) {
|
if (modelSpec?.mcpServers) {
|
||||||
for (const mcpServer of modelSpec.mcpServers) {
|
for (const mcpServer of modelSpec.mcpServers) {
|
||||||
mcpServers.add(mcpServer);
|
mcpServers.add(mcpServer);
|
||||||
|
|
@ -102,7 +103,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
|
||||||
if (addedServers.has(mcpServer)) {
|
if (addedServers.has(mcpServer)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const serverTools = await getMCPServerTools(mcpServer);
|
const serverTools = await getMCPServerTools(userId, mcpServer);
|
||||||
if (!serverTools) {
|
if (!serverTools) {
|
||||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||||
addedServers.add(mcpServer);
|
addedServers.add(mcpServer);
|
||||||
|
|
|
||||||
|
|
@ -1931,7 +1931,7 @@ describe('models/Agent', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Mock getMCPServerTools to return tools for each server
|
// Mock getMCPServerTools to return tools for each server
|
||||||
getMCPServerTools.mockImplementation(async (server) => {
|
getMCPServerTools.mockImplementation(async (_userId, server) => {
|
||||||
if (server === 'server1') {
|
if (server === 'server1') {
|
||||||
return { tool1_mcp_server1: {} };
|
return { tool1_mcp_server1: {} };
|
||||||
} else if (server === 'server2') {
|
} else if (server === 'server2') {
|
||||||
|
|
@ -2125,7 +2125,7 @@ describe('models/Agent', () => {
|
||||||
getCachedTools.mockResolvedValue(availableTools);
|
getCachedTools.mockResolvedValue(availableTools);
|
||||||
|
|
||||||
// Mock getMCPServerTools to return all tools for server1
|
// Mock getMCPServerTools to return all tools for server1
|
||||||
getMCPServerTools.mockImplementation(async (server) => {
|
getMCPServerTools.mockImplementation(async (_userId, server) => {
|
||||||
if (server === 'server1') {
|
if (server === 'server1') {
|
||||||
return availableTools; // All 100 tools belong to server1
|
return availableTools; // All 100 tools belong to server1
|
||||||
}
|
}
|
||||||
|
|
@ -2674,7 +2674,7 @@ describe('models/Agent', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Mock getMCPServerTools to return only tools matching the server
|
// Mock getMCPServerTools to return only tools matching the server
|
||||||
getMCPServerTools.mockImplementation(async (server) => {
|
getMCPServerTools.mockImplementation(async (_userId, server) => {
|
||||||
if (server === 'server1') {
|
if (server === 'server1') {
|
||||||
// Only return tool that correctly matches server1 format
|
// Only return tool that correctly matches server1 format
|
||||||
return { tool_mcp_server1: {} };
|
return { tool_mcp_server1: {} };
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ const getMCPTools = async (req, res) => {
|
||||||
const mcpServers = {};
|
const mcpServers = {};
|
||||||
|
|
||||||
const cachePromises = configuredServers.map((serverName) =>
|
const cachePromises = configuredServers.map((serverName) =>
|
||||||
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })),
|
getMCPServerTools(userId, serverName).then((tools) => ({ serverName, tools })),
|
||||||
);
|
);
|
||||||
const cacheResults = await Promise.all(cachePromises);
|
const cacheResults = await Promise.all(cachePromises);
|
||||||
|
|
||||||
|
|
@ -52,7 +52,7 @@ const getMCPTools = async (req, res) => {
|
||||||
|
|
||||||
if (Object.keys(serverTools).length > 0) {
|
if (Object.keys(serverTools).length > 0) {
|
||||||
// Cache asynchronously without blocking
|
// Cache asynchronously without blocking
|
||||||
cacheMCPServerTools({ serverName, serverTools }).catch((err) =>
|
cacheMCPServerTools({ userId, serverName, serverTools }).catch((err) =>
|
||||||
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
|
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ jest.mock('~/models', () => ({
|
||||||
jest.mock('~/server/services/Config', () => ({
|
jest.mock('~/server/services/Config', () => ({
|
||||||
setCachedTools: jest.fn(),
|
setCachedTools: jest.fn(),
|
||||||
getCachedTools: jest.fn(),
|
getCachedTools: jest.fn(),
|
||||||
|
getMCPServerTools: jest.fn(),
|
||||||
loadCustomConfig: jest.fn(),
|
loadCustomConfig: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -205,6 +205,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||||
|
|
||||||
const tools = await userConnection.fetchTools();
|
const tools = await userConnection.fetchTools();
|
||||||
await updateMCPServerTools({
|
await updateMCPServerTools({
|
||||||
|
userId: flowState.userId,
|
||||||
serverName,
|
serverName,
|
||||||
tools,
|
tools,
|
||||||
});
|
});
|
||||||
|
|
|
||||||
10
api/server/services/Config/__tests__/getCachedTools.spec.js
Normal file
10
api/server/services/Config/__tests__/getCachedTools.spec.js
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
const { ToolCacheKeys } = require('../getCachedTools');
|
||||||
|
|
||||||
|
describe('getCachedTools - Cache Isolation Security', () => {
|
||||||
|
describe('ToolCacheKeys.MCP_SERVER', () => {
|
||||||
|
it('should generate cache keys that include userId', () => {
|
||||||
|
const key = ToolCacheKeys.MCP_SERVER('user123', 'github');
|
||||||
|
expect(key).toBe('tools:mcp:user123:github');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores');
|
||||||
const ToolCacheKeys = {
|
const ToolCacheKeys = {
|
||||||
/** Global tools available to all users */
|
/** Global tools available to all users */
|
||||||
GLOBAL: 'tools:global',
|
GLOBAL: 'tools:global',
|
||||||
/** MCP tools cached by server name */
|
/** MCP tools cached by user ID and server name */
|
||||||
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
|
MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`,
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves available tools from cache
|
* Retrieves available tools from cache
|
||||||
* @function getCachedTools
|
* @function getCachedTools
|
||||||
* @param {Object} options - Options for retrieving tools
|
* @param {Object} options - Options for retrieving tools
|
||||||
|
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||||
* @param {string} [options.serverName] - MCP server name to get cached tools for
|
* @param {string} [options.serverName] - MCP server name to get cached tools for
|
||||||
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
|
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
|
||||||
*/
|
*/
|
||||||
async function getCachedTools(options = {}) {
|
async function getCachedTools(options = {}) {
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
const { serverName } = options;
|
const { userId, serverName } = options;
|
||||||
|
|
||||||
// Return MCP server-specific tools if requested
|
// Return MCP server-specific tools if requested
|
||||||
if (serverName) {
|
if (serverName && userId) {
|
||||||
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
|
return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to global tools
|
// Default to global tools
|
||||||
|
|
@ -36,17 +37,18 @@ async function getCachedTools(options = {}) {
|
||||||
* @function setCachedTools
|
* @function setCachedTools
|
||||||
* @param {Object} tools - The tools object to cache
|
* @param {Object} tools - The tools object to cache
|
||||||
* @param {Object} options - Options for caching tools
|
* @param {Object} options - Options for caching tools
|
||||||
|
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||||
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
||||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||||
* @returns {Promise<boolean>} Whether the operation was successful
|
* @returns {Promise<boolean>} Whether the operation was successful
|
||||||
*/
|
*/
|
||||||
async function setCachedTools(tools, options = {}) {
|
async function setCachedTools(tools, options = {}) {
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
const { serverName, ttl } = options;
|
const { userId, serverName, ttl } = options;
|
||||||
|
|
||||||
// Cache by MCP server if specified
|
// Cache by MCP server if specified (requires userId)
|
||||||
if (serverName) {
|
if (serverName && userId) {
|
||||||
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
|
return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to global cache
|
// Default to global cache
|
||||||
|
|
@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) {
|
||||||
* Invalidates cached tools
|
* Invalidates cached tools
|
||||||
* @function invalidateCachedTools
|
* @function invalidateCachedTools
|
||||||
* @param {Object} options - Options for invalidating tools
|
* @param {Object} options - Options for invalidating tools
|
||||||
|
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||||
* @param {string} [options.serverName] - MCP server name to invalidate
|
* @param {string} [options.serverName] - MCP server name to invalidate
|
||||||
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
|
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
|
||||||
* @returns {Promise<void>}
|
* @returns {Promise<void>}
|
||||||
*/
|
*/
|
||||||
async function invalidateCachedTools(options = {}) {
|
async function invalidateCachedTools(options = {}) {
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
const { serverName, invalidateGlobal = false } = options;
|
const { userId, serverName, invalidateGlobal = false } = options;
|
||||||
|
|
||||||
const keysToDelete = [];
|
const keysToDelete = [];
|
||||||
|
|
||||||
|
|
@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) {
|
||||||
keysToDelete.push(ToolCacheKeys.GLOBAL);
|
keysToDelete.push(ToolCacheKeys.GLOBAL);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (serverName) {
|
if (serverName && userId) {
|
||||||
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
|
keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName));
|
||||||
}
|
}
|
||||||
|
|
||||||
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
|
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets MCP tools for a specific server from cache or merges with global tools
|
* Gets MCP tools for a specific server from cache
|
||||||
* @function getMCPServerTools
|
* @function getMCPServerTools
|
||||||
|
* @param {string} userId - The user ID
|
||||||
* @param {string} serverName - The MCP server name
|
* @param {string} serverName - The MCP server name
|
||||||
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
|
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
|
||||||
*/
|
*/
|
||||||
async function getMCPServerTools(serverName) {
|
async function getMCPServerTools(userId, serverName) {
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
|
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
|
||||||
|
|
||||||
if (serverTools) {
|
if (serverTools) {
|
||||||
return serverTools;
|
return serverTools;
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache');
|
||||||
/**
|
/**
|
||||||
* Updates MCP tools in the cache for a specific server
|
* Updates MCP tools in the cache for a specific server
|
||||||
* @param {Object} params - Parameters for updating MCP tools
|
* @param {Object} params - Parameters for updating MCP tools
|
||||||
|
* @param {string} params.userId - User ID for user-specific caching
|
||||||
* @param {string} params.serverName - MCP server name
|
* @param {string} params.serverName - MCP server name
|
||||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||||
* @returns {Promise<LCAvailableTools>}
|
* @returns {Promise<LCAvailableTools>}
|
||||||
*/
|
*/
|
||||||
async function updateMCPServerTools({ serverName, tools }) {
|
async function updateMCPServerTools({ userId, serverName, tools }) {
|
||||||
try {
|
try {
|
||||||
const serverTools = {};
|
const serverTools = {};
|
||||||
const mcpDelimiter = Constants.mcp_delimiter;
|
const mcpDelimiter = Constants.mcp_delimiter;
|
||||||
|
|
@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
await setCachedTools(serverTools, { serverName });
|
await setCachedTools(serverTools, { userId, serverName });
|
||||||
|
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
await cache.delete(CacheKeys.TOOLS);
|
await cache.delete(CacheKeys.TOOLS);
|
||||||
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
|
logger.debug(
|
||||||
|
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
|
||||||
|
);
|
||||||
return serverTools;
|
return serverTools;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error);
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -65,21 +68,22 @@ async function mergeAppTools(appTools) {
|
||||||
/**
|
/**
|
||||||
* Caches MCP server tools (no longer merges with global)
|
* Caches MCP server tools (no longer merges with global)
|
||||||
* @param {object} params
|
* @param {object} params
|
||||||
|
* @param {string} params.userId - User ID for user-specific caching
|
||||||
* @param {string} params.serverName
|
* @param {string} params.serverName
|
||||||
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
|
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
|
||||||
* @returns {Promise<void>}
|
* @returns {Promise<void>}
|
||||||
*/
|
*/
|
||||||
async function cacheMCPServerTools({ serverName, serverTools }) {
|
async function cacheMCPServerTools({ userId, serverName, serverTools }) {
|
||||||
try {
|
try {
|
||||||
const count = Object.keys(serverTools).length;
|
const count = Object.keys(serverTools).length;
|
||||||
if (!count) {
|
if (!count) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Only cache server-specific tools, no merging with global
|
// Only cache server-specific tools, no merging with global
|
||||||
await setCachedTools(serverTools, { serverName });
|
await setCachedTools(serverTools, { userId, serverName });
|
||||||
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
|
logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
|
logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error);
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,7 @@ async function reinitMCPServer({
|
||||||
if (connection && !oauthRequired) {
|
if (connection && !oauthRequired) {
|
||||||
tools = await connection.fetchTools();
|
tools = await connection.fetchTools();
|
||||||
availableTools = await updateMCPServerTools({
|
availableTools = await updateMCPServerTools({
|
||||||
|
userId: user.id,
|
||||||
serverName,
|
serverName,
|
||||||
tools,
|
tools,
|
||||||
});
|
});
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue