diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index efaa80cfc8..698014cbe0 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -1,8 +1,13 @@ const { logger } = require('@librechat/data-schemas'); const { SerpAPI } = require('@langchain/community/tools/serpapi'); const { Calculator } = require('@langchain/community/tools/calculator'); -const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api'); const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); +const { + checkAccess, + createSafeUser, + mcpToolPattern, + loadWebSearchAuth, +} = require('@librechat/api'); const { Tools, Constants, @@ -410,6 +415,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} /** MCP server tools are initialized sequentially by server */ let index = -1; const failedMCPServers = new Set(); + const safeUser = createSafeUser(options.req?.user); for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) { index++; /** @type {LCAvailableTools} */ @@ -420,14 +426,14 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} continue; } const mcpParams = { - res: options.res, - userId: user, index, - serverName: config.serverName, - userMCPAuthMap, - model: agent?.model ?? model, - provider: agent?.provider ?? endpoint, signal, + user: safeUser, + userMCPAuthMap, + res: options.res, + model: agent?.model ?? model, + serverName: config.serverName, + provider: agent?.provider ?? endpoint, }; if (config.type === 'all' && toolConfigs.length === 1) { diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 9182a2a0cd..b1022136e3 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,7 +1,12 @@ const { Router } = require('express'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); -const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api'); +const { + createSafeUser, + MCPOAuthHandler, + MCPTokenStorage, + getUserMCPAuthMap, +} = require('@librechat/api'); const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); @@ -335,9 +340,9 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { try { const { serverName } = req.params; - const userId = req.user?.id; + const user = createSafeUser(req.user); - if (!userId) { + if (!user.id) { return res.status(401).json({ error: 'User not authenticated' }); } @@ -351,7 +356,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { }); } - await mcpManager.disconnectUserConnection(userId, serverName); + await mcpManager.disconnectUserConnection(user.id, serverName); logger.info( `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, ); @@ -360,14 +365,14 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { let userMCPAuthMap; if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { userMCPAuthMap = await getUserMCPAuthMap({ - userId, + userId: user.id, servers: [serverName], findPluginAuthsByKeys, }); } const result = await reinitMCPServer({ - userId, + user, serverName, userMCPAuthMap, }); diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 36d61b2337..791f824dbf 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -153,7 +153,7 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) { /** * @param {Object} params * @param {ServerResponse} params.res - The Express response object for sending events. - * @param {string} params.userId - The user ID from the request object. + * @param {IUser} params.user - The user from the request object. * @param {string} params.serverName * @param {AbortSignal} params.signal * @param {string} params.model @@ -161,9 +161,9 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) { * @param {Record>} [params.userMCPAuthMap] * @returns { Promise unknown}>> } An object with `_call` method to execute the tool input. */ -async function reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap }) { +async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) { const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; - const flowId = `${userId}:${serverName}:${Date.now()}`; + const flowId = `${user.id}:${serverName}:${Date.now()}`; const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); const stepId = 'step_oauth_login_' + serverName; const toolCall = { @@ -192,7 +192,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP flowManager, }); return await reinitMCPServer({ - userId, + user, signal, serverName, oauthStart, @@ -212,7 +212,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP * * @param {Object} params * @param {ServerResponse} params.res - The Express response object for sending events. - * @param {string} params.userId - The user ID from the request object. + * @param {IUser} params.user - The user from the request object. * @param {string} params.serverName * @param {string} params.model * @param {Providers | EModelEndpoint} params.provider - The provider for the tool. @@ -221,16 +221,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP * @param {Record>} [params.userMCPAuthMap] * @returns { Promise unknown}>> } An object with `_call` method to execute the tool input. */ -async function createMCPTools({ - res, - userId, - index, - signal, - serverName, - provider, - userMCPAuthMap, -}) { - const result = await reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap }); +async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) { + const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }); if (!result || !result.tools) { logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); return; @@ -240,7 +232,7 @@ async function createMCPTools({ for (const tool of result.tools) { const toolInstance = await createMCPTool({ res, - userId, + user, provider, userMCPAuthMap, availableTools: result.availableTools, @@ -258,7 +250,7 @@ async function createMCPTools({ * Creates a single tool from the specified MCP Server via `toolKey`. * @param {Object} params * @param {ServerResponse} params.res - The Express response object for sending events. - * @param {string} params.userId - The user ID from the request object. + * @param {IUser} params.user - The user from the request object. * @param {string} params.toolKey - The toolKey for the tool. * @param {string} params.model - The model for the tool. * @param {number} [params.index] @@ -270,7 +262,7 @@ async function createMCPTools({ */ async function createMCPTool({ res, - userId, + user, index, signal, toolKey, @@ -288,7 +280,7 @@ async function createMCPTool({ ); const result = await reconnectServer({ res, - userId, + user, index, signal, serverName, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index cb90ace24c..9773d58745 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -1,13 +1,45 @@ const { logger } = require('@librechat/data-schemas'); const { MCPOAuthHandler } = require('@librechat/api'); const { CacheKeys } = require('librechat-data-provider'); -const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP'); +const { + createMCPTool, + createMCPTools, + getMCPSetupData, + checkOAuthFlowStatus, + getServerConnectionStatus, +} = require('./MCP'); // Mock all dependencies jest.mock('@librechat/data-schemas', () => ({ logger: { debug: jest.fn(), error: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + }, +})); + +jest.mock('@langchain/core/tools', () => ({ + tool: jest.fn((fn, config) => { + const toolInstance = { _call: fn, ...config }; + return toolInstance; + }), +})); + +jest.mock('@librechat/agents', () => ({ + Providers: { + VERTEXAI: 'vertexai', + GOOGLE: 'google', + }, + StepTypes: { + TOOL_CALLS: 'tool_calls', + }, + GraphEvents: { + ON_RUN_STEP_DELTA: 'on_run_step_delta', + ON_RUN_STEP: 'on_run_step', + }, + Constants: { + CONTENT_AND_ARTIFACT: 'content_and_artifact', }, })); @@ -15,12 +47,27 @@ jest.mock('@librechat/api', () => ({ MCPOAuthHandler: { generateFlowId: jest.fn(), }, + sendEvent: jest.fn(), + normalizeServerName: jest.fn((name) => name), + convertWithResolvedRefs: jest.fn((params) => params), })); jest.mock('librechat-data-provider', () => ({ CacheKeys: { FLOWS: 'flows', }, + Constants: { + USE_PRELIM_RESPONSE_MESSAGE_ID: 'prelim_response_id', + mcp_delimiter: '::', + mcp_prefix: 'mcp_', + }, + ContentTypes: { + TEXT: 'text', + }, + isAssistantsEndpoint: jest.fn(() => false), + Time: { + TWO_MINUTES: 120000, + }, })); jest.mock('./Config', () => ({ @@ -44,8 +91,11 @@ jest.mock('~/models', () => ({ updateToken: jest.fn(), })); +jest.mock('./Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); + describe('tests for the new helper functions used by the MCP connection status endpoints', () => { - let mockLoadCustomConfig; let mockGetMCPManager; let mockGetFlowStateManager; let mockGetLogStores; @@ -54,7 +104,6 @@ describe('tests for the new helper functions used by the MCP connection status e beforeEach(() => { jest.clearAllMocks(); - mockLoadCustomConfig = require('./Config').loadCustomConfig; mockGetMCPManager = require('~/config').getMCPManager; mockGetFlowStateManager = require('~/config').getFlowStateManager; mockGetLogStores = require('~/cache').getLogStores; @@ -567,3 +616,275 @@ describe('tests for the new helper functions used by the MCP connection status e }); }); }); + +describe('User parameter passing tests', () => { + let mockReinitMCPServer; + let mockGetFlowStateManager; + let mockGetLogStores; + + beforeEach(() => { + jest.clearAllMocks(); + mockReinitMCPServer = require('./Tools/mcp').reinitMCPServer; + mockGetFlowStateManager = require('~/config').getFlowStateManager; + mockGetLogStores = require('~/cache').getLogStores; + + // Setup default mocks + mockGetLogStores.mockReturnValue({}); + mockGetFlowStateManager.mockReturnValue({ + createFlowWithHandler: jest.fn(), + failFlow: jest.fn(), + }); + }); + + describe('createMCPTools', () => { + it('should pass user parameter to reinitMCPServer when calling reconnectServer internally', async () => { + const mockUser = { id: 'test-user-123', name: 'Test User' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + const mockSignal = new AbortController().signal; + + mockReinitMCPServer.mockResolvedValue({ + tools: [{ name: 'test-tool' }], + availableTools: { + 'test-tool::test-server': { + function: { + description: 'Test tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }, + }); + + await createMCPTools({ + res: mockRes, + user: mockUser, + serverName: 'test-server', + provider: 'openai', + signal: mockSignal, + userMCPAuthMap: {}, + }); + + // Verify reinitMCPServer was called with the user + expect(mockReinitMCPServer).toHaveBeenCalledWith( + expect.objectContaining({ + user: mockUser, + serverName: 'test-server', + }), + ); + expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser); + }); + + it('should throw error if user is not provided', async () => { + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + mockReinitMCPServer.mockResolvedValue({ + tools: [], + availableTools: {}, + }); + + // Call without user should throw error + await expect( + createMCPTools({ + res: mockRes, + user: undefined, + serverName: 'test-server', + provider: 'openai', + userMCPAuthMap: {}, + }), + ).rejects.toThrow("Cannot read properties of undefined (reading 'id')"); + + // Verify reinitMCPServer was not called due to early error + expect(mockReinitMCPServer).not.toHaveBeenCalled(); + }); + }); + + describe('createMCPTool', () => { + it('should pass user parameter to reinitMCPServer when tool not in cache', async () => { + const mockUser = { id: 'test-user-456', email: 'test@example.com' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + const mockSignal = new AbortController().signal; + + mockReinitMCPServer.mockResolvedValue({ + availableTools: { + 'test-tool::test-server': { + function: { + description: 'Test tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }, + }); + + // Call without availableTools to trigger reinit + await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + signal: mockSignal, + userMCPAuthMap: {}, + availableTools: undefined, // Force reinit + }); + + // Verify reinitMCPServer was called with the user + expect(mockReinitMCPServer).toHaveBeenCalledWith( + expect.objectContaining({ + user: mockUser, + serverName: 'test-server', + }), + ); + expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser); + }); + + it('should not call reinitMCPServer when tool is in cache', async () => { + const mockUser = { id: 'test-user-789' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + const availableTools = { + 'test-tool::test-server': { + function: { + description: 'Cached tool', + parameters: { type: 'object', properties: {} }, + }, + }, + }; + + await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: 'test-tool::test-server', + provider: 'openai', + userMCPAuthMap: {}, + availableTools: availableTools, + }); + + // Verify reinitMCPServer was NOT called since tool was in cache + expect(mockReinitMCPServer).not.toHaveBeenCalled(); + }); + }); + + describe('reinitMCPServer (via reconnectServer)', () => { + it('should always receive user parameter when called from createMCPTools', async () => { + const mockUser = { id: 'user-001', role: 'admin' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // Track all calls to reinitMCPServer + const reinitCalls = []; + mockReinitMCPServer.mockImplementation((params) => { + reinitCalls.push(params); + return Promise.resolve({ + tools: [{ name: 'tool1' }, { name: 'tool2' }], + availableTools: { + 'tool1::server1': { function: { description: 'Tool 1', parameters: {} } }, + 'tool2::server1': { function: { description: 'Tool 2', parameters: {} } }, + }, + }); + }); + + await createMCPTools({ + res: mockRes, + user: mockUser, + serverName: 'server1', + provider: 'anthropic', + userMCPAuthMap: {}, + }); + + // Verify all calls to reinitMCPServer had the user + expect(reinitCalls.length).toBeGreaterThan(0); + reinitCalls.forEach((call) => { + expect(call.user).toBe(mockUser); + expect(call.user.id).toBe('user-001'); + }); + }); + + it('should always receive user parameter when called from createMCPTool', async () => { + const mockUser = { id: 'user-002', permissions: ['read', 'write'] }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // Track all calls to reinitMCPServer + const reinitCalls = []; + mockReinitMCPServer.mockImplementation((params) => { + reinitCalls.push(params); + return Promise.resolve({ + availableTools: { + 'my-tool::my-server': { + function: { description: 'My Tool', parameters: {} }, + }, + }, + }); + }); + + await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: 'my-tool::my-server', + provider: 'google', + userMCPAuthMap: {}, + availableTools: undefined, // Force reinit + }); + + // Verify the call to reinitMCPServer had the user + expect(reinitCalls.length).toBe(1); + expect(reinitCalls[0].user).toBe(mockUser); + expect(reinitCalls[0].user.id).toBe('user-002'); + }); + }); + + describe('User parameter integrity', () => { + it('should preserve user object properties through the call chain', async () => { + const complexUser = { + id: 'complex-user', + name: 'John Doe', + email: 'john@example.com', + metadata: { subscription: 'premium', settings: { theme: 'dark' } }, + }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + let capturedUser = null; + mockReinitMCPServer.mockImplementation((params) => { + capturedUser = params.user; + return Promise.resolve({ + tools: [{ name: 'test' }], + availableTools: { + 'test::server': { function: { description: 'Test', parameters: {} } }, + }, + }); + }); + + await createMCPTools({ + res: mockRes, + user: complexUser, + serverName: 'server', + provider: 'openai', + userMCPAuthMap: {}, + }); + + // Verify the complete user object was passed + expect(capturedUser).toEqual(complexUser); + expect(capturedUser.id).toBe('complex-user'); + expect(capturedUser.metadata.subscription).toBe('premium'); + expect(capturedUser.metadata.settings.theme).toBe('dark'); + }); + + it('should throw error when user is null', async () => { + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + mockReinitMCPServer.mockResolvedValue({ + tools: [], + availableTools: {}, + }); + + await expect( + createMCPTools({ + res: mockRes, + user: null, + serverName: 'test-server', + provider: 'openai', + userMCPAuthMap: {}, + }), + ).rejects.toThrow("Cannot read properties of null (reading 'id')"); + + // Verify reinitMCPServer was not called due to early error + expect(mockReinitMCPServer).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 04c27eafe4..2669ba4658 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -7,7 +7,7 @@ const { getLogStores } = require('~/cache'); /** * @param {Object} params - * @param {string} params.userId + * @param {IUser} params.user - The user from the request object. * @param {string} params.serverName - The name of the MCP server * @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish * @param {AbortSignal} [params.signal] - The abort signal to handle cancellation. @@ -18,7 +18,7 @@ const { getLogStores } = require('~/cache'); * @param {Record>} [params.userMCPAuthMap] */ async function reinitMCPServer({ - userId, + user, signal, forceNew, serverName, @@ -51,7 +51,7 @@ async function reinitMCPServer({ try { userConnection = await mcpManager.getUserConnection({ - user: { id: userId }, + user, signal, forceNew, oauthStart, diff --git a/packages/api/src/utils/env.ts b/packages/api/src/utils/env.ts index 767d95101e..43429ef5a6 100644 --- a/packages/api/src/utils/env.ts +++ b/packages/api/src/utils/env.ts @@ -1,5 +1,6 @@ import { extractEnvVariable } from 'librechat-data-provider'; import type { TUser, MCPOptions } from 'librechat-data-provider'; +import type { IUser } from '@librechat/data-schemas'; import type { RequestBody } from '~/types'; /** @@ -26,6 +27,31 @@ const ALLOWED_USER_FIELDS = [ 'termsAccepted', ] as const; +type AllowedUserField = (typeof ALLOWED_USER_FIELDS)[number]; +type SafeUser = Pick; + +/** + * Creates a safe user object containing only allowed fields. + * Optimized for performance while maintaining type safety. + * + * @param user - The user object to extract safe fields from + * @returns A new object containing only allowed fields + */ +export function createSafeUser(user: IUser | null | undefined): Partial { + if (!user) { + return {}; + } + + const safeUser: Partial = {}; + for (const field of ALLOWED_USER_FIELDS) { + if (field in user) { + safeUser[field] = user[field]; + } + } + + return safeUser; +} + /** * List of allowed request body fields that can be used in header placeholders. * These are common fields from the request body that are safe to expose in headers.