From 1fe977e48fdef75313674c4671cecc50cad76973 Mon Sep 17 00:00:00 2001 From: Dustin Healy <54083382+dustinhealy@users.noreply.github.com> Date: Thu, 24 Jul 2025 07:44:58 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20MCP=20Name=20Normalizatio?= =?UTF-8?q?n=20breaking=20User=20Provided=20Variables=20(#8644)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/server/services/Config/getCustomConfig.js | 5 - api/server/services/MCP.js | 1 + packages/api/src/mcp/auth.test.ts | 168 ++++++++++++++++++ packages/api/src/mcp/auth.ts | 11 +- 4 files changed, 172 insertions(+), 13 deletions(-) create mode 100644 packages/api/src/mcp/auth.test.ts diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index 7495ce1e2a..a7cb74de54 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -3,7 +3,6 @@ const { isEnabled, getUserMCPAuthMap } = require('@librechat/api'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { normalizeEndpointName } = require('~/server/utils'); const loadCustomConfig = require('./loadCustomConfig'); -const { getCachedTools } = require('./getCachedTools'); const getLogStores = require('~/cache/getLogStores'); /** @@ -66,13 +65,9 @@ async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) { if (!tools || tools.length === 0) { return; } - const appTools = await getCachedTools({ - userId, - }); return await getUserMCPAuthMap({ tools, userId, - appTools, findPluginAuthsByKeys, }); } catch (err) { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 3f0a4d618e..9970981828 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -235,6 +235,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) { responseFormat: AgentConstants.CONTENT_AND_ARTIFACT, }); toolInstance.mcp = true; + toolInstance.mcpRawServerName = serverName; return toolInstance; } diff --git a/packages/api/src/mcp/auth.test.ts b/packages/api/src/mcp/auth.test.ts new file mode 100644 index 0000000000..7bfb40ae93 --- /dev/null +++ b/packages/api/src/mcp/auth.test.ts @@ -0,0 +1,168 @@ +import type { PluginAuthMethods } from '@librechat/data-schemas'; +import type { GenericTool } from '@librechat/agents'; +import { getPluginAuthMap } from '~/agents/auth'; +import { getUserMCPAuthMap } from './auth'; + +jest.mock('~/agents/auth', () => ({ + getPluginAuthMap: jest.fn(), +})); + +const mockGetPluginAuthMap = getPluginAuthMap as jest.MockedFunction; + +const createMockTool = ( + name: string, + mcpRawServerName?: string, + mcp = true, +): GenericTool & { mcpRawServerName?: string; mcp?: boolean } => + ({ + name, + mcpRawServerName, + mcp, + description: 'Mock tool', + }) as GenericTool & { mcpRawServerName?: string; mcp?: boolean }; + +const mockFindPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'] = jest.fn(); + +describe('getUserMCPAuthMap', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('Core Functionality', () => { + it('should handle server names with various special characters and spaces', async () => { + const testCases = [ + { + originalName: 'Connector: Company', + normalizedToolName: 'tool_mcp_Connector__Company', + }, + { + originalName: 'Server (Production) @ Company.com', + normalizedToolName: 'tool_mcp_Server__Production____Company.com', + }, + { + originalName: '🌟 Testing Server™ (α-β) 测试服务器', + normalizedToolName: 'tool_mcp_____Testing_Server_________', + }, + ]; + + const tools = testCases.map((testCase) => + createMockTool(testCase.normalizedToolName, testCase.originalName), + ); + + const expectedKeys = testCases.map((tc) => `mcp_${tc.originalName}`); + mockGetPluginAuthMap.mockResolvedValue({}); + + await getUserMCPAuthMap({ + userId: 'user123', + tools, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(mockGetPluginAuthMap).toHaveBeenCalledWith({ + userId: 'user123', + pluginKeys: expectedKeys, + throwError: false, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + }); + }); + + describe('Edge Cases', () => { + it('should return empty object when no tools have mcpRawServerName', async () => { + const tools = [ + createMockTool('regular_tool', undefined, false), + createMockTool('another_tool', undefined, false), + createMockTool('test_mcp_Server_no_raw_name', undefined), + ]; + + const result = await getUserMCPAuthMap({ + userId: 'user123', + tools, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual({}); + expect(mockGetPluginAuthMap).not.toHaveBeenCalled(); + }); + + it('should handle empty or undefined tools array', async () => { + let result = await getUserMCPAuthMap({ + userId: 'user123', + tools: [], + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + expect(result).toEqual({}); + expect(mockGetPluginAuthMap).not.toHaveBeenCalled(); + + result = await getUserMCPAuthMap({ + userId: 'user123', + tools: undefined, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + expect(result).toEqual({}); + expect(mockGetPluginAuthMap).not.toHaveBeenCalled(); + }); + + it('should handle database errors gracefully', async () => { + const tools = [createMockTool('test_mcp_Server1', 'Server1')]; + const dbError = new Error('Database connection failed'); + + mockGetPluginAuthMap.mockRejectedValue(dbError); + + const result = await getUserMCPAuthMap({ + userId: 'user123', + tools, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual({}); + }); + + it('should handle non-Error exceptions gracefully', async () => { + const tools = [createMockTool('test_mcp_Server1', 'Server1')]; + + mockGetPluginAuthMap.mockRejectedValue('String error'); + + const result = await getUserMCPAuthMap({ + userId: 'user123', + tools, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual({}); + }); + }); + + describe('Integration', () => { + it('should handle complete workflow with normalized tool names and original server names', async () => { + const originalServerName = 'Connector: Company'; + const toolName = 'test_auth_mcp_Connector__Company'; + + const tools = [createMockTool(toolName, originalServerName)]; + + const mockCustomUserVars = { + 'mcp_Connector: Company': { + API_KEY: 'test123', + SECRET_TOKEN: 'secret456', + }, + }; + + mockGetPluginAuthMap.mockResolvedValue(mockCustomUserVars); + + const result = await getUserMCPAuthMap({ + userId: 'user123', + tools, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(mockGetPluginAuthMap).toHaveBeenCalledWith({ + userId: 'user123', + pluginKeys: ['mcp_Connector: Company'], + throwError: false, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual(mockCustomUserVars); + }); + }); +}); diff --git a/packages/api/src/mcp/auth.ts b/packages/api/src/mcp/auth.ts index 7f6f6001fa..8221278fd1 100644 --- a/packages/api/src/mcp/auth.ts +++ b/packages/api/src/mcp/auth.ts @@ -3,17 +3,14 @@ import { Constants } from 'librechat-data-provider'; import type { PluginAuthMethods } from '@librechat/data-schemas'; import type { GenericTool } from '@librechat/agents'; import { getPluginAuthMap } from '~/agents/auth'; -import { mcpToolPattern } from './utils'; export async function getUserMCPAuthMap({ userId, tools, - appTools, findPluginAuthsByKeys, }: { userId: string; tools: GenericTool[] | undefined; - appTools: Record; findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys']; }) { if (!tools || tools.length === 0) { @@ -23,11 +20,9 @@ export async function getUserMCPAuthMap({ const uniqueMcpServers = new Set(); for (const tool of tools) { - const toolKey = tool.name; - if (toolKey && appTools[toolKey] && mcpToolPattern.test(toolKey)) { - const parts = toolKey.split(Constants.mcp_delimiter); - const serverName = parts[parts.length - 1]; - uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`); + const mcpTool = tool as GenericTool & { mcpRawServerName?: string }; + if (mcpTool.mcpRawServerName) { + uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`); } }