From c827fdd10eed422eca137adb20584e79df0d64f5 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 23 Aug 2025 03:27:05 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A6=20feat:=20Auto-reinitialize=20MCP?= =?UTF-8?q?=20Servers=20on=20Request=20(#9226)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/clients/tools/util/handleTools.js | 60 +++- api/models/Agent.js | 11 +- api/server/controllers/agents/client.js | 34 +-- api/server/controllers/agents/request.js | 68 +++-- api/server/routes/__tests__/mcp.spec.js | 57 +++- api/server/routes/mcp.js | 110 ++------ api/server/services/Config/getCachedTools.js | 6 +- api/server/services/Config/getCustomConfig.js | 29 +- api/server/services/Config/mcpToolsCache.js | 3 +- api/server/services/Endpoints/agents/agent.js | 32 ++- .../services/Endpoints/agents/initialize.js | 20 +- api/server/services/MCP.js | 261 +++++++++++++++--- api/server/services/ToolService.js | 29 +- api/server/services/Tools/mcp.js | 142 ++++++++++ api/typedefs.js | 13 + client/src/hooks/Chat/useChatFunctions.ts | 8 +- client/src/hooks/SSE/useEventHandlers.ts | 3 +- client/src/hooks/SSE/useSSE.ts | 2 + client/src/hooks/SSE/useStepHandler.ts | 62 ++++- packages/api/src/index.ts | 1 + packages/api/src/mcp/MCPConnectionFactory.ts | 8 +- packages/api/src/mcp/UserConnectionManager.ts | 26 +- packages/api/src/mcp/__tests__/auth.test.ts | 121 +++++++- packages/api/src/mcp/auth.ts | 63 +++-- packages/api/src/mcp/connection.ts | 2 +- packages/api/src/mcp/types/index.ts | 1 + packages/api/src/tools/format.ts | 7 +- packages/data-provider/src/config.ts | 4 + 28 files changed, 871 insertions(+), 312 deletions(-) create mode 100644 api/server/services/Tools/mcp.js diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index fc04d0c58..ea127e4ab 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -3,7 +3,7 @@ const { SerpAPI } = require('@langchain/community/tools/serpapi'); const { Calculator } = require('@langchain/community/tools/calculator'); const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api'); const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); -const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider'); +const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider'); const { availableTools, manifestToolMap, @@ -24,9 +24,9 @@ const { const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); +const { createMCPTool, createMCPTools } = require('~/server/services/MCP'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getCachedTools } = require('~/server/services/Config'); -const { createMCPTool } = require('~/server/services/MCP'); /** * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. @@ -123,6 +123,8 @@ const getAuthFields = (toolKey) => { * * @param {object} object * @param {string} object.user + * @param {Record>} [object.userMCPAuthMap] + * @param {AbortSignal} [object.signal] * @param {Pick} [object.agent] * @param {string} [object.model] * @param {EModelEndpoint} [object.endpoint] @@ -137,7 +139,9 @@ const loadTools = async ({ user, agent, model, + signal, endpoint, + userMCPAuthMap, tools = [], options = {}, functions = true, @@ -231,6 +235,7 @@ const loadTools = async ({ /** @type {Record} */ const toolContextMap = {}; const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {}; + const requestedMCPTools = {}; for (const tool of tools) { if (tool === Tools.execute_code) { @@ -299,14 +304,35 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} }; continue; } else if (tool && cachedTools && mcpToolPattern.test(tool)) { - requestedTools[tool] = async () => + const [toolName, serverName] = tool.split(Constants.mcp_delimiter); + if (toolName === Constants.mcp_all) { + const currentMCPGenerator = async (index) => + createMCPTools({ + req: options.req, + res: options.res, + index, + serverName, + userMCPAuthMap, + model: agent?.model ?? model, + provider: agent?.provider ?? endpoint, + signal, + }); + requestedMCPTools[serverName] = [currentMCPGenerator]; + continue; + } + const currentMCPGenerator = async (index) => createMCPTool({ + index, req: options.req, res: options.res, toolKey: tool, + userMCPAuthMap, model: agent?.model ?? model, provider: agent?.provider ?? endpoint, + signal, }); + requestedMCPTools[serverName] = requestedMCPTools[serverName] || []; + requestedMCPTools[serverName].push(currentMCPGenerator); continue; } @@ -346,6 +372,34 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} } const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []); + const mcpToolPromises = []; + /** MCP server tools are initialized sequentially by server */ + let index = -1; + for (const [serverName, generators] of Object.entries(requestedMCPTools)) { + index++; + for (const generator of generators) { + try { + if (generator && generators.length === 1) { + mcpToolPromises.push( + generator(index).catch((error) => { + logger.error(`Error loading ${serverName} tools:`, error); + return null; + }), + ); + continue; + } + const mcpTool = await generator(index); + if (Array.isArray(mcpTool)) { + loadedTools.push(...mcpTool); + } else if (mcpTool) { + loadedTools.push(mcpTool); + } + } catch (error) { + logger.error(`Error loading MCP tool for server ${serverName}:`, error); + } + } + } + loadedTools.push(...(await Promise.all(mcpToolPromises)).flatMap((plugin) => plugin || [])); return { loadedTools, toolContextMap }; }; diff --git a/api/models/Agent.js b/api/models/Agent.js index be9fe62e6..13fc1e472 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -2,7 +2,7 @@ const mongoose = require('mongoose'); const crypto = require('node:crypto'); const { logger } = require('@librechat/data-schemas'); const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider'); -const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } = +const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } = require('librechat-data-provider').Constants; const { removeAgentFromAllProjects, @@ -78,6 +78,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _ tools.push(Tools.web_search); } + const addedServers = new Set(); if (mcpServers.size > 0) { for (const toolName of Object.keys(availableTools)) { if (!toolName.includes(mcp_delimiter)) { @@ -85,9 +86,17 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _ } const mcpServer = toolName.split(mcp_delimiter)?.[1]; if (mcpServer && mcpServers.has(mcpServer)) { + addedServers.add(mcpServer); tools.push(toolName); } } + + for (const mcpServer of mcpServers) { + if (addedServers.has(mcpServer)) { + continue; + } + tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`); + } } const instructions = req.body.promptPrefix; diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 14c9c3822..897e8f84f 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -33,18 +33,13 @@ const { bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { - findPluginAuthsByKeys, - getFormattedMemories, - deleteMemory, - setMemory, -} = require('~/models'); -const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config'); const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getFormattedMemories, deleteMemory, setMemory } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { getProviderConfig } = require('~/server/services/Endpoints'); +const { checkCapability } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getRoleByName } = require('~/models/Role'); const { loadAgent } = require('~/models/Agent'); @@ -615,6 +610,7 @@ class AgentClient extends BaseClient { await this.chatCompletion({ payload, onProgress: opts.onProgress, + userMCPAuthMap: opts.userMCPAuthMap, abortController: opts.abortController, }); return this.contentParts; @@ -747,7 +743,13 @@ class AgentClient extends BaseClient { return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate; } - async chatCompletion({ payload, abortController = null }) { + /** + * @param {object} params + * @param {string | ChatCompletionMessageParam[]} params.payload + * @param {Record>} [params.userMCPAuthMap] + * @param {AbortController} [params.abortController] + */ + async chatCompletion({ payload, userMCPAuthMap, abortController = null }) { /** @type {Partial} */ let config; /** @type {ReturnType} */ @@ -903,21 +905,9 @@ class AgentClient extends BaseClient { run.Graph.contentData = contentData; } - try { - if (await hasCustomUserVars()) { - config.configurable.userMCPAuthMap = await getMCPAuthMap({ - tools: agent.tools, - userId: this.options.req.user.id, - findPluginAuthsByKeys, - }); - } - } catch (err) { - logger.error( - `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`, - err, - ); + if (userMCPAuthMap != null) { + config.configurable.userMCPAuthMap = userMCPAuthMap; } - await run.processStream({ messages }, config, { keepContent: i !== 0, tokenCounter: createTokenCounter(this.getEncoding()), diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 8054a6a68..110d2fdd5 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -9,6 +9,24 @@ const { const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup'); const { saveMessage } = require('~/models'); +function createCloseHandler(abortController) { + return function (manual) { + if (!manual) { + logger.debug('[AgentController] Request closed'); + } + if (!abortController) { + return; + } else if (abortController.signal.aborted) { + return; + } else if (abortController.requestCompleted) { + return; + } + + abortController.abort(); + logger.debug('[AgentController] Request aborted on close'); + }; +} + const AgentController = async (req, res, next, initializeClient, addTitle) => { let { text, @@ -31,7 +49,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { let userMessagePromise; let getAbortData; let client = null; - // Initialize as an array let cleanupHandlers = []; const newConvo = !conversationId; @@ -62,9 +79,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { // Create a function to handle final cleanup const performCleanup = () => { logger.debug('[AgentController] Performing cleanup'); - // Make sure cleanupHandlers is an array before iterating if (Array.isArray(cleanupHandlers)) { - // Execute all cleanup handlers for (const handler of cleanupHandlers) { try { if (typeof handler === 'function') { @@ -105,8 +120,33 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { }; try { - /** @type {{ client: TAgentClient }} */ - const result = await initializeClient({ req, res, endpointOption }); + let prelimAbortController = new AbortController(); + const prelimCloseHandler = createCloseHandler(prelimAbortController); + res.on('close', prelimCloseHandler); + const removePrelimHandler = (manual) => { + try { + prelimCloseHandler(manual); + res.removeListener('close', prelimCloseHandler); + } catch (e) { + logger.error('[AgentController] Error removing close listener', e); + } + }; + cleanupHandlers.push(removePrelimHandler); + /** @type {{ client: TAgentClient; userMCPAuthMap?: Record> }} */ + const result = await initializeClient({ + req, + res, + endpointOption, + signal: prelimAbortController.signal, + }); + if (prelimAbortController.signal?.aborted) { + prelimAbortController = null; + throw new Error('Request was aborted before initialization could complete'); + } else { + prelimAbortController = null; + removePrelimHandler(true); + cleanupHandlers.pop(); + } client = result.client; // Register client with finalization registry if available @@ -138,22 +178,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { }; const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); - - // Simple handler to avoid capturing scope - const closeHandler = () => { - logger.debug('[AgentController] Request closed'); - if (!abortController) { - return; - } else if (abortController.signal.aborted) { - return; - } else if (abortController.requestCompleted) { - return; - } - - abortController.abort(); - logger.debug('[AgentController] Request aborted on close'); - }; - + const closeHandler = createCloseHandler(abortController); res.on('close', closeHandler); cleanupHandlers.push(() => { try { @@ -175,6 +200,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { abortController, overrideParentMessageId, isEdited: !!editedContent, + userMCPAuthMap: result.userMCPAuthMap, responseMessageId: editedResponseMessageId, progressOptions: { res, diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 25d17c51d..b572340c5 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -11,6 +11,7 @@ jest.mock('@librechat/api', () => ({ completeOAuthFlow: jest.fn(), generateFlowId: jest.fn(), }, + getUserMCPAuthMap: jest.fn(), })); jest.mock('@librechat/data-schemas', () => ({ @@ -37,6 +38,7 @@ jest.mock('~/models', () => ({ updateToken: jest.fn(), createToken: jest.fn(), deleteTokens: jest.fn(), + findPluginAuthsByKeys: jest.fn(), })); jest.mock('~/server/services/Config', () => ({ @@ -71,6 +73,10 @@ jest.mock('~/server/middleware', () => ({ requireJwtAuth: (req, res, next) => next(), })); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); + describe('MCP Routes', () => { let app; let mongoServer; @@ -682,6 +688,13 @@ describe('MCP Routes', () => { require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); + require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({ + success: true, + message: "MCP server 'oauth-server' ready for OAuth authentication", + serverName: 'oauth-server', + oauthRequired: true, + oauthUrl: 'https://oauth.example.com/auth', + }); const response = await request(app).post('/api/mcp/oauth-server/reinitialize'); @@ -706,6 +719,7 @@ describe('MCP Routes', () => { require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); + require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue(null); const response = await request(app).post('/api/mcp/error-server/reinitialize'); @@ -769,6 +783,14 @@ describe('MCP Routes', () => { setCachedTools.mockResolvedValue(); updateMCPUserTools.mockResolvedValue(); + require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({ + success: true, + message: "MCP server 'test-server' reinitialized successfully", + serverName: 'test-server', + oauthRequired: false, + oauthUrl: null, + }); + const response = await request(app).post('/api/mcp/test-server/reinitialize'); expect(response.status).toBe(200); @@ -783,14 +805,6 @@ describe('MCP Routes', () => { 'test-user-id', 'test-server', ); - expect(updateMCPUserTools).toHaveBeenCalledWith({ - userId: 'test-user-id', - serverName: 'test-server', - tools: [ - { name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } }, - { name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } }, - ], - }); }); it('should handle server with custom user variables', async () => { @@ -812,9 +826,14 @@ describe('MCP Routes', () => { require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); - require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue( - 'api-key-value', - ); + require('@librechat/api').getUserMCPAuthMap.mockResolvedValue({ + 'mcp:test-server': { + API_KEY: 'api-key-value', + }, + }); + require('~/models').findPluginAuthsByKeys.mockResolvedValue([ + { key: 'API_KEY', value: 'api-key-value' }, + ]); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache'); @@ -822,13 +841,23 @@ describe('MCP Routes', () => { setCachedTools.mockResolvedValue(); updateMCPUserTools.mockResolvedValue(); + require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({ + success: true, + message: "MCP server 'test-server' reinitialized successfully", + serverName: 'test-server', + oauthRequired: false, + oauthUrl: null, + }); + const response = await request(app).post('/api/mcp/test-server/reinitialize'); expect(response.status).toBe(200); expect(response.body.success).toBe(true); - expect( - require('~/server/services/PluginService').getUserPluginAuthValue, - ).toHaveBeenCalledWith('test-user-id', 'API_KEY', false); + expect(require('@librechat/api').getUserMCPAuthMap).toHaveBeenCalledWith({ + userId: 'test-user-id', + servers: ['test-server'], + findPluginAuthsByKeys: require('~/models').findPluginAuthsByKeys, + }); }); }); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 0a068cd8f..d41cc6d73 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,13 +1,15 @@ -const { logger } = require('@librechat/data-schemas'); -const { MCPOAuthHandler } = require('@librechat/api'); const { Router } = require('express'); +const { logger } = require('@librechat/data-schemas'); +const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { CacheKeys, Constants } = require('librechat-data-provider'); const { getMCPManager, getFlowStateManager } = require('~/config'); +const { reinitMCPServer } = require('~/server/services/Tools/mcp'); const { requireJwtAuth } = require('~/server/middleware'); +const { findPluginAuthsByKeys } = require('~/models'); const { getLogStores } = require('~/cache'); const router = Router(); @@ -302,107 +304,39 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { }); } - const flowsCache = getLogStores(CacheKeys.FLOWS); - const flowManager = getFlowStateManager(flowsCache); - await mcpManager.disconnectUserConnection(user.id, serverName); logger.info( `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, ); - let customUserVars = {}; + /** @type {Record> | undefined} */ + let userMCPAuthMap; if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { - for (const varName of Object.keys(serverConfig.customUserVars)) { - try { - const value = await getUserPluginAuthValue(user.id, varName, false); - customUserVars[varName] = value; - } catch (err) { - logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err); - } - } - } - - let userConnection = null; - let oauthRequired = false; - let oauthUrl = null; - - try { - userConnection = await mcpManager.getUserConnection({ - user, - serverName, - flowManager, - customUserVars, - tokenMethods: { - findToken, - updateToken, - createToken, - deleteTokens, - }, - returnOnOAuth: true, - oauthStart: async (authURL) => { - logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`); - oauthUrl = authURL; - oauthRequired = true; - }, - }); - - logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); - } catch (err) { - logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`); - logger.info( - `[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, - ); - - const isOAuthError = - err.message?.includes('OAuth') || - err.message?.includes('authentication') || - err.message?.includes('401'); - - const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early'; - - if (isOAuthError || oauthRequired || isOAuthFlowInitiated) { - logger.info( - `[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`, - ); - oauthRequired = true; - } else { - logger.error( - `[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`, - err, - ); - return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); - } - } - - if (userConnection && !oauthRequired) { - const tools = await userConnection.fetchTools(); - await updateMCPUserTools({ + userMCPAuthMap = await getUserMCPAuthMap({ userId: user.id, - serverName, - tools, + servers: [serverName], + findPluginAuthsByKeys, }); } - logger.debug( - `[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, - ); + const result = await reinitMCPServer({ + req, + serverName, + userMCPAuthMap, + }); - const getResponseMessage = () => { - if (oauthRequired) { - return `MCP server '${serverName}' ready for OAuth authentication`; - } - if (userConnection) { - return `MCP server '${serverName}' reinitialized successfully`; - } - return `Failed to reinitialize MCP server '${serverName}'`; - }; + if (!result) { + return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); + } + + const { success, message, oauthRequired, oauthUrl } = result; res.json({ - success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)), - message: getResponseMessage(), + success, + message, + oauthUrl, serverName, oauthRequired, - oauthUrl, }); } catch (error) { logger.error('[MCP Reinitialize] Unexpected error', error); diff --git a/api/server/services/Config/getCachedTools.js b/api/server/services/Config/getCachedTools.js index b3a4f0c86..669c179a8 100644 --- a/api/server/services/Config/getCachedTools.js +++ b/api/server/services/Config/getCachedTools.js @@ -26,7 +26,7 @@ const ToolCacheKeys = { * @param {string[]} [options.roleIds] - Role IDs for role-based tools * @param {string[]} [options.groupIds] - Group IDs for group-based tools * @param {boolean} [options.includeGlobal=true] - Whether to include global tools - * @returns {Promise} The available tools object or null if not cached + * @returns {Promise} The available tools object or null if not cached */ async function getCachedTools(options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); @@ -41,13 +41,13 @@ async function getCachedTools(options = {}) { // Future implementation will merge tools from multiple sources // based on user permissions, roles, and groups if (userId) { - // Check if we have pre-computed effective tools for this user + /** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */ const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId)); if (effectiveTools) { return effectiveTools; } - // Otherwise, compute from individual sources + /** @type {LCAvailableTools | null} Otherwise, compute from individual sources */ const toolSources = []; if (includeGlobal) { diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index 2b9f658b4..ced319050 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -1,5 +1,4 @@ -const { logger } = require('@librechat/data-schemas'); -const { isEnabled, getUserMCPAuthMap } = require('@librechat/api'); +const { isEnabled } = require('@librechat/api'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { normalizeEndpointName } = require('~/server/utils'); const loadCustomConfig = require('./loadCustomConfig'); @@ -53,31 +52,6 @@ const getCustomEndpointConfig = async (endpoint) => { ); }; -/** - * @param {Object} params - * @param {string} params.userId - * @param {GenericTool[]} [params.tools] - * @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys - * @returns {Promise> | undefined>} - */ -async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) { - try { - if (!tools || tools.length === 0) { - return; - } - return await getUserMCPAuthMap({ - tools, - userId, - findPluginAuthsByKeys, - }); - } catch (err) { - logger.error( - `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`, - err, - ); - } -} - /** * @returns {Promise} */ @@ -88,7 +62,6 @@ async function hasCustomUserVars() { } module.exports = { - getMCPAuthMap, getCustomConfig, getBalanceConfig, hasCustomUserVars, diff --git a/api/server/services/Config/mcpToolsCache.js b/api/server/services/Config/mcpToolsCache.js index bd21807cd..d335868d8 100644 --- a/api/server/services/Config/mcpToolsCache.js +++ b/api/server/services/Config/mcpToolsCache.js @@ -9,7 +9,7 @@ const { getLogStores } = require('~/cache'); * @param {string} params.userId - User ID * @param {string} params.serverName - MCP server name * @param {Array} params.tools - Array of tool objects from MCP server - * @returns {Promise} + * @returns {Promise} */ async function updateMCPUserTools({ userId, serverName, tools }) { try { @@ -39,6 +39,7 @@ async function updateMCPUserTools({ userId, serverName, tools }) { const cache = getLogStores(CacheKeys.CONFIG_STORE); await cache.delete(CacheKeys.TOOLS); logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`); + return userTools; } catch (error) { logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error); throw error; diff --git a/api/server/services/Endpoints/agents/agent.js b/api/server/services/Endpoints/agents/agent.js index a64ce97e7..40dbe1700 100644 --- a/api/server/services/Endpoints/agents/agent.js +++ b/api/server/services/Endpoints/agents/agent.js @@ -30,7 +30,13 @@ const { getModelMaxTokens } = require('~/utils'); * @param {TEndpointOption} [params.endpointOption] * @param {Set} [params.allowedProviders] * @param {boolean} [params.isInitialAgent] - * @returns {Promise, toolContextMap: Record, maxContextTokens: number }>} + * @returns {Promise, + * toolContextMap: Record, + * maxContextTokens: number, + * userMCPAuthMap?: Record> + * }>} */ const initializeAgent = async ({ req, @@ -91,16 +97,19 @@ const initializeAgent = async ({ }); const provider = agent.provider; - const { tools: structuredTools, toolContextMap } = - (await loadTools?.({ - req, - res, - provider, - agentId: agent.id, - tools: agent.tools, - model: agent.model, - tool_resources, - })) ?? {}; + const { + tools: structuredTools, + toolContextMap, + userMCPAuthMap, + } = (await loadTools?.({ + req, + res, + provider, + agentId: agent.id, + tools: agent.tools, + model: agent.model, + tool_resources, + })) ?? {}; agent.endpoint = provider; const { getOptions, overrideProvider } = await getProviderConfig(provider); @@ -189,6 +198,7 @@ const initializeAgent = async ({ tools, attachments, resendFiles, + userMCPAuthMap, toolContextMap, useLegacyContent: !!options.useLegacyContent, maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9), diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index dfe780c41..74e134375 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -19,7 +19,10 @@ const AgentClient = require('~/server/controllers/agents/client'); const { getAgent } = require('~/models/Agent'); const { logViolation } = require('~/cache'); -function createToolLoader() { +/** + * @param {AbortSignal} signal + */ +function createToolLoader(signal) { /** * @param {object} params * @param {ServerRequest} params.req @@ -29,7 +32,11 @@ function createToolLoader() { * @param {string} params.provider * @param {string} params.model * @param {AgentToolResources} params.tool_resources - * @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record } | undefined>} + * @returns {Promise<{ + * tools: StructuredTool[], + * toolContextMap: Record, + * userMCPAuthMap?: Record> + * } | undefined>} */ return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { const agent = { id: agentId, tools, provider, model }; @@ -38,6 +45,7 @@ function createToolLoader() { req, res, agent, + signal, tool_resources, }); } catch (error) { @@ -46,7 +54,7 @@ function createToolLoader() { }; } -const initializeClient = async ({ req, res, endpointOption }) => { +const initializeClient = async ({ req, res, signal, endpointOption }) => { if (!endpointOption) { throw new Error('Endpoint option not provided'); } @@ -92,7 +100,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { /** @type {Set} */ const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders); - const loadTools = createToolLoader(); + const loadTools = createToolLoader(signal); /** @type {Array} */ const requestFiles = req.body.files ?? []; /** @type {string} */ @@ -111,6 +119,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { }); const agent_ids = primaryConfig.agent_ids; + let userMCPAuthMap = primaryConfig.userMCPAuthMap; if (agent_ids?.length) { for (const agentId of agent_ids) { const agent = await getAgent({ id: agentId }); @@ -140,6 +149,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { endpointOption, allowedProviders, }); + Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {}); agentConfigs.set(agentId, config); } } @@ -188,7 +198,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { : EModelEndpoint.agents, }); - return { client }; + return { client, userMCPAuthMap }; }; module.exports = { initializeClient }; diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index eb9d91e4c..2fc03f629 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,7 +1,12 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); const { logger } = require('@librechat/data-schemas'); -const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents'); +const { + Providers, + StepTypes, + GraphEvents, + Constants: AgentConstants, +} = require('@librechat/agents'); const { sendEvent, MCPOAuthHandler, @@ -11,14 +16,14 @@ const { const { Time, CacheKeys, - StepTypes, Constants, ContentTypes, isAssistantsEndpoint, } = require('librechat-data-provider'); +const { getCachedTools, loadCustomConfig } = require('./Config'); const { findToken, createToken, updateToken } = require('~/models'); const { getMCPManager, getFlowStateManager } = require('~/config'); -const { getCachedTools, loadCustomConfig } = require('./Config'); +const { reinitMCPServer } = require('./Tools/mcp'); const { getLogStores } = require('~/cache'); /** @@ -26,16 +31,13 @@ const { getLogStores } = require('~/cache'); * @param {ServerResponse} params.res - The Express response object for sending events. * @param {string} params.stepId - The ID of the step in the flow. * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. - * @param {string} params.loginFlowId - The ID of the login flow. - * @param {FlowStateManager} params.flowManager - The flow manager instance. */ -function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) { +function createRunStepDeltaEmitter({ res, stepId, toolCall }) { /** - * Creates a function to handle OAuth login requests. * @param {string} authURL - The URL to redirect the user for OAuth authentication. - * @returns {Promise} Returns true to indicate the event was sent successfully. + * @returns {void} */ - return async function (authURL) { + return function (authURL) { /** @type {{ id: string; delta: AgentToolCallDelta }} */ const data = { id: stepId, @@ -46,17 +48,54 @@ function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, sig expires_at: Date.now() + Time.TWO_MINUTES, }, }; - /** Used to ensure the handler (use of `sendEvent`) is only invoked once */ - await flowManager.createFlowWithHandler( - loginFlowId, - 'oauth_login', - async () => { - sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); - logger.debug('Sent OAuth login request to client'); - return true; + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); + }; +} + +/** + * @param {object} params + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.runId - The Run ID, i.e. message ID + * @param {string} params.stepId - The ID of the step in the flow. + * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. + * @param {number} [params.index] + */ +function createRunStepEmitter({ res, runId, stepId, toolCall, index }) { + return function () { + /** @type {import('@librechat/agents').RunStep} */ + const data = { + runId: runId ?? Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + id: stepId, + type: StepTypes.TOOL_CALLS, + index: index ?? 0, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [toolCall], }, - signal, - ); + }; + sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data }); + }; +} + +/** + * Creates a function used to ensure the flow handler is only invoked once + * @param {object} params + * @param {string} params.flowId - The ID of the login flow. + * @param {FlowStateManager} params.flowManager - The flow manager instance. + * @param {(authURL: string) => void} [params.callback] + */ +function createOAuthStart({ flowId, flowManager, callback }) { + /** + * Creates a function to handle OAuth login requests. + * @param {string} authURL - The URL to redirect the user for OAuth authentication. + * @returns {Promise} Returns true to indicate the event was sent successfully. + */ + return async function (authURL) { + await flowManager.createFlowWithHandler(flowId, 'oauth_login', async () => { + callback?.(authURL); + logger.debug('Sent OAuth login request to client'); + return true; + }); }; } @@ -99,23 +138,166 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) { } /** - * Creates a general tool for an entire action set. + * @param {Object} params + * @param {() => void} params.runStepEmitter + * @param {(authURL: string) => void} params.runStepDeltaEmitter + * @returns {(authURL: string) => void} + */ +function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) { + return function (authURL) { + runStepEmitter(); + runStepDeltaEmitter(authURL); + }; +} + +/** + * @param {Object} params + * @param {ServerRequest} params.req - The Express request object, containing user/request info. + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.serverName + * @param {AbortSignal} params.signal + * @param {string} params.model + * @param {number} [params.index] + * @param {Record>} [params.userMCPAuthMap] + * @returns { Promise unknown}>> } An object with `_call` method to execute the tool input. + */ +async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) { + const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; + const flowId = `${req.user?.id}:${serverName}:${Date.now()}`; + const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); + const stepId = 'step_oauth_login_' + serverName; + const toolCall = { + id: flowId, + name: serverName, + type: 'tool_call_chunk', + }; + + const runStepEmitter = createRunStepEmitter({ + res, + index, + runId, + stepId, + toolCall, + }); + const runStepDeltaEmitter = createRunStepDeltaEmitter({ + res, + stepId, + toolCall, + }); + const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }); + const oauthStart = createOAuthStart({ + res, + flowId, + callback, + flowManager, + }); + return await reinitMCPServer({ + req, + signal, + serverName, + oauthStart, + flowManager, + userMCPAuthMap, + forceNew: true, + returnOnOAuth: false, + connectionTimeout: Time.TWO_MINUTES, + }); +} + +/** + * Creates all tools from the specified MCP Server via `toolKey`. * - * @param {Object} params - The parameters for loading action sets. + * This function assumes tools could not be aggregated from the cache of tool definitions, + * i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated. + * + * @param {Object} params + * @param {ServerRequest} params.req - The Express request object, containing user/request info. + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.serverName + * @param {string} params.model + * @param {Providers | EModelEndpoint} params.provider - The provider for the tool. + * @param {number} [params.index] + * @param {AbortSignal} [params.signal] + * @param {Record>} [params.userMCPAuthMap] + * @returns { Promise unknown}>> } An object with `_call` method to execute the tool input. + */ +async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) { + const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }); + if (!result || !result.tools) { + logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); + return; + } + + const serverTools = []; + for (const tool of result.tools) { + const toolInstance = await createMCPTool({ + req, + res, + provider, + userMCPAuthMap, + availableTools: result.availableTools, + toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`, + }); + if (toolInstance) { + serverTools.push(toolInstance); + } + } + + return serverTools; +} + +/** + * Creates a single tool from the specified MCP Server via `toolKey`. + * @param {Object} params * @param {ServerRequest} params.req - The Express request object, containing user/request info. * @param {ServerResponse} params.res - The Express response object for sending events. * @param {string} params.toolKey - The toolKey for the tool. - * @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool. * @param {string} params.model - The model for the tool. + * @param {number} [params.index] + * @param {AbortSignal} [params.signal] + * @param {Providers | EModelEndpoint} params.provider - The provider for the tool. + * @param {LCAvailableTools} [params.availableTools] + * @param {Record>} [params.userMCPAuthMap] * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ -async function createMCPTool({ req, res, toolKey, provider: _provider }) { - const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true }); - const toolDefinition = availableTools?.[toolKey]?.function; +async function createMCPTool({ + req, + res, + index, + signal, + toolKey, + provider, + userMCPAuthMap, + availableTools: tools, +}) { + const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); + const availableTools = + tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true })); + /** @type {LCTool | undefined} */ + let toolDefinition = availableTools?.[toolKey]?.function; if (!toolDefinition) { - logger.error(`Tool ${toolKey} not found in available tools`); - return null; + logger.warn( + `[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`, + ); + const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }); + toolDefinition = result?.availableTools?.[toolKey]?.function; } + + if (!toolDefinition) { + logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`); + return; + } + + return createToolInstance({ + res, + provider, + toolName, + serverName, + toolDefinition, + }); +} + +function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) { /** @type {LCTool} */ const { description, parameters } = toolDefinition; const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; @@ -128,16 +310,8 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) { schema = z.object({ input: z.string().optional() }); } - const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; - if (!req.user?.id) { - logger.error( - `[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`, - ); - throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`); - } - /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise} */ const _call = async (toolArguments, config) => { const userId = config?.configurable?.user?.id || config?.configurable?.user_id; @@ -154,14 +328,16 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) { const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; - const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; - const oauthStart = createOAuthStart({ + const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; + const runStepDeltaEmitter = createRunStepDeltaEmitter({ res, stepId, toolCall, - loginFlowId, + }); + const oauthStart = createOAuthStart({ + flowId, flowManager, - signal: derivedSignal, + callback: runStepDeltaEmitter, }); const oauthEnd = createOAuthEnd({ res, @@ -207,7 +383,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) { return result; } catch (error) { logger.error( - `[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`, + `[MCP][${serverName}][${toolName}][User: ${userId}] Error calling MCP tool:`, error, ); @@ -220,12 +396,12 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) { if (isOAuthError) { throw new Error( - `OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`, + `[MCP][${serverName}][${toolName}] OAuth authentication required. Please check the server logs for the authentication URL.`, ); } throw new Error( - `"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`, + `[MCP][${serverName}][${toolName}] tool call failed${error?.message ? `: ${error?.message}` : '.'}`, ); } finally { // Clean up abort handler to prevent memory leaks @@ -380,6 +556,7 @@ async function getServerConnectionStatus( module.exports = { createMCPTool, + createMCPTools, getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus, diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 2f2062e14..4f6c1ed3e 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -1,9 +1,9 @@ const fs = require('fs'); const path = require('path'); const { sleep } = require('@librechat/agents'); -const { getToolkitKey } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { zodToJsonSchema } = require('zod-to-json-schema'); +const { getToolkitKey, getUserMCPAuthMap } = require('@librechat/api'); const { Calculator } = require('@langchain/community/tools/calculator'); const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools'); const { @@ -33,12 +33,17 @@ const { toolkits, } = require('~/app/clients/tools'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); -const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config'); +const { + getEndpointsConfig, + hasCustomUserVars, + getCachedTools, +} = require('~/server/services/Config'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { isActionDomainAllowed } = require('~/server/services/domains'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); +const { findPluginAuthsByKeys } = require('~/models'); /** * Loads and formats tools from the specified tool directory. @@ -469,11 +474,12 @@ async function processRequiredActions(client, requiredActions) { * @param {Object} params - Run params containing user and request information. * @param {ServerRequest} params.req - The request object. * @param {ServerResponse} params.res - The request object. + * @param {AbortSignal} params.signal * @param {Pick} The agent tools. + * @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record> }>} The agent tools. */ -async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) { +async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) { if (!agent.tools || agent.tools.length === 0) { return {}; } else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) { @@ -523,8 +529,20 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) webSearchCallbacks = createOnSearchResults(res); } + /** @type {Record>} */ + let userMCPAuthMap; + if (await hasCustomUserVars()) { + userMCPAuthMap = await getUserMCPAuthMap({ + tools: agent.tools, + userId: req.user.id, + findPluginAuthsByKeys, + }); + } + const { loadedTools, toolContextMap } = await loadTools({ agent, + signal, + userMCPAuthMap, functions: true, user: req.user.id, tools: _agentTools, @@ -588,6 +606,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) if (!checkCapability(AgentCapabilities.actions)) { return { tools: agentTools, + userMCPAuthMap, toolContextMap, }; } @@ -599,6 +618,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) } return { tools: agentTools, + userMCPAuthMap, toolContextMap, }; } @@ -707,6 +727,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) return { tools: agentTools, toolContextMap, + userMCPAuthMap, }; } diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js new file mode 100644 index 000000000..f6efbb78a --- /dev/null +++ b/api/server/services/Tools/mcp.js @@ -0,0 +1,142 @@ +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys, Constants } = require('librechat-data-provider'); +const { findToken, createToken, updateToken, deleteTokens } = require('~/models'); +const { getMCPManager, getFlowStateManager } = require('~/config'); +const { updateMCPUserTools } = require('~/server/services/Config'); +const { getLogStores } = require('~/cache'); + +/** + * @param {Object} params + * @param {ServerRequest} params.req + * @param {string} params.serverName - The name of the MCP server + * @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish + * @param {AbortSignal} [params.signal] - The abort signal to handle cancellation. + * @param {boolean} [params.forceNew] + * @param {number} [params.connectionTimeout] + * @param {FlowStateManager} [params.flowManager] + * @param {(authURL: string) => Promise} [params.oauthStart] + * @param {Record>} [params.userMCPAuthMap] + */ +async function reinitMCPServer({ + req, + signal, + forceNew, + serverName, + userMCPAuthMap, + connectionTimeout, + returnOnOAuth = true, + oauthStart: _oauthStart, + flowManager: _flowManager, +}) { + /** @type {MCPConnection | null} */ + let userConnection = null; + /** @type {LCAvailableTools | null} */ + let availableTools = null; + /** @type {ReturnType | null} */ + let tools = null; + let oauthRequired = false; + let oauthUrl = null; + try { + const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; + const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS)); + const mcpManager = getMCPManager(); + + const oauthStart = + _oauthStart ?? + (async (authURL) => { + logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`); + oauthUrl = authURL; + oauthRequired = true; + }); + + try { + userConnection = await mcpManager.getUserConnection({ + user: req.user, + signal, + forceNew, + oauthStart, + serverName, + flowManager, + returnOnOAuth, + customUserVars, + connectionTimeout, + tokenMethods: { + findToken, + updateToken, + createToken, + deleteTokens, + }, + }); + + logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); + } catch (err) { + logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`); + logger.info( + `[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, + ); + + const isOAuthError = + err.message?.includes('OAuth') || + err.message?.includes('authentication') || + err.message?.includes('401'); + + const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early'; + + if (isOAuthError || oauthRequired || isOAuthFlowInitiated) { + logger.info( + `[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`, + ); + oauthRequired = true; + } else { + logger.error( + `[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`, + err, + ); + } + } + + if (userConnection && !oauthRequired) { + tools = await userConnection.fetchTools(); + availableTools = await updateMCPUserTools({ + userId: req.user.id, + serverName, + tools, + }); + } + + logger.debug( + `[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, + ); + + const getResponseMessage = () => { + if (oauthRequired) { + return `MCP server '${serverName}' ready for OAuth authentication`; + } + if (userConnection) { + return `MCP server '${serverName}' reinitialized successfully`; + } + return `Failed to reinitialize MCP server '${serverName}'`; + }; + + const result = { + availableTools, + success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)), + message: getResponseMessage(), + oauthRequired, + serverName, + oauthUrl, + tools, + }; + logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, result); + return result; + } catch (error) { + logger.error( + '[MCP Reinitialize] Error loading MCP Tools, servers may still be initializing:', + error, + ); + } +} + +module.exports = { + reinitMCPServer, +}; diff --git a/api/typedefs.js b/api/typedefs.js index b8d2aa348..2703c41d0 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -1115,6 +1115,18 @@ * @memberof typedefs */ +/** + * @exports MCPConnection + * @typedef {import('@librechat/api').MCPConnection} MCPConnection + * @memberof typedefs + */ + +/** + * @exports LCFunctionTool + * @typedef {import('@librechat/api').LCFunctionTool} LCFunctionTool + * @memberof typedefs + */ + /** * @exports FlowStateManager * @typedef {import('@librechat/api').FlowStateManager} FlowStateManager @@ -1825,6 +1837,7 @@ * @param {object} opts - Options for the completion * @param {onTokenProgress} opts.onProgress - Callback function to handle token progress * @param {AbortController} opts.abortController - AbortController instance + * @param {Record>} [opts.userMCPAuthMap] * @returns {Promise} * @memberof typedefs */ diff --git a/client/src/hooks/Chat/useChatFunctions.ts b/client/src/hooks/Chat/useChatFunctions.ts index 37f7b5913..2dcfe3770 100644 --- a/client/src/hooks/Chat/useChatFunctions.ts +++ b/client/src/hooks/Chat/useChatFunctions.ts @@ -230,15 +230,19 @@ export default function useChatFunctions({ const responseMessageId = editedMessageId ?? - (latestMessage?.messageId && isRegenerate ? latestMessage?.messageId + '_' : null) ?? + (latestMessage?.messageId && isRegenerate + ? latestMessage.messageId.replace(/_+$/, '') + '_' + : null) ?? null; + const initialResponseId = + responseMessageId ?? `${isRegenerate ? messageId : intermediateId}`.replace(/_+$/, '') + '_'; const initialResponse: TMessage = { sender: responseSender, text: '', endpoint: endpoint ?? '', parentMessageId: isRegenerate ? messageId : intermediateId, - messageId: responseMessageId ?? `${isRegenerate ? messageId : intermediateId}_`, + messageId: initialResponseId, thread_id, conversationId, unfinished: false, diff --git a/client/src/hooks/SSE/useEventHandlers.ts b/client/src/hooks/SSE/useEventHandlers.ts index 21d39f852..5d459b844 100644 --- a/client/src/hooks/SSE/useEventHandlers.ts +++ b/client/src/hooks/SSE/useEventHandlers.ts @@ -182,7 +182,7 @@ export default function useEventHandlers({ const { token } = useAuthContext(); const contentHandler = useContentHandler({ setMessages, getMessages }); - const stepHandler = useStepHandler({ + const { stepHandler, clearStepMaps } = useStepHandler({ setMessages, getMessages, announcePolite, @@ -806,6 +806,7 @@ export default function useEventHandlers({ ); return { + clearStepMaps, stepHandler, syncHandler, finalHandler, diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts index 9e1cdf1d1..f639e408b 100644 --- a/client/src/hooks/SSE/useSSE.ts +++ b/client/src/hooks/SSE/useSSE.ts @@ -62,6 +62,7 @@ export default function useSSE( } = chatHelpers; const { + clearStepMaps, stepHandler, syncHandler, finalHandler, @@ -101,6 +102,7 @@ export default function useSSE( payload = removeNullishValues(payload) as TPayload; let textIndex = null; + clearStepMaps(); const sse = new SSE(payloadData.server, { payload: JSON.stringify(payload), diff --git a/client/src/hooks/SSE/useStepHandler.ts b/client/src/hooks/SSE/useStepHandler.ts index 8ca8213da..3e73ef205 100644 --- a/client/src/hooks/SSE/useStepHandler.ts +++ b/client/src/hooks/SSE/useStepHandler.ts @@ -1,5 +1,11 @@ import { useCallback, useRef } from 'react'; -import { StepTypes, ContentTypes, ToolCallTypes, getNonEmptyValue } from 'librechat-data-provider'; +import { + Constants, + StepTypes, + ContentTypes, + ToolCallTypes, + getNonEmptyValue, +} from 'librechat-data-provider'; import type { Agents, TMessage, @@ -178,11 +184,12 @@ export default function useStepHandler({ return { ...message, content: updatedContent as TMessageContentParts[] }; }; - return useCallback( + const stepHandler = useCallback( ({ event, data }: TStepEvent, submission: EventSubmission) => { const messages = getMessages() || []; const { userMessage } = submission; setIsSubmitting(true); + let parentMessageId = userMessage.messageId; const currentTime = Date.now(); if (currentTime - lastAnnouncementTimeRef.current > MESSAGE_UPDATE_INTERVAL) { @@ -197,7 +204,11 @@ export default function useStepHandler({ if (event === 'on_run_step') { const runStep = data as Agents.RunStep; - const responseMessageId = runStep.runId ?? ''; + let responseMessageId = runStep.runId ?? ''; + if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) { + responseMessageId = submission?.initialResponse?.messageId ?? ''; + parentMessageId = submission?.initialResponse?.parentMessageId ?? ''; + } if (!responseMessageId) { console.warn('No message id found in run step event'); return; @@ -211,7 +222,7 @@ export default function useStepHandler({ response = { ...responseMessage, - parentMessageId: userMessage.messageId, + parentMessageId, conversationId: userMessage.conversationId, messageId: responseMessageId, content: initialContent, @@ -246,14 +257,18 @@ export default function useStepHandler({ messageMap.current.set(responseMessageId, updatedResponse); const updatedMessages = messages.map((msg) => - msg.messageId === runStep.runId ? updatedResponse : msg, + msg.messageId === responseMessageId ? updatedResponse : msg, ); setMessages(updatedMessages); } } else if (event === 'on_agent_update') { const { agent_update } = data as Agents.AgentUpdate; - const responseMessageId = agent_update.runId || ''; + let responseMessageId = agent_update.runId || ''; + if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) { + responseMessageId = submission?.initialResponse?.messageId ?? ''; + parentMessageId = submission?.initialResponse?.parentMessageId ?? ''; + } if (!responseMessageId) { console.warn('No message id found in agent update event'); return; @@ -271,7 +286,11 @@ export default function useStepHandler({ } else if (event === 'on_message_delta') { const messageDelta = data as Agents.MessageDeltaEvent; const runStep = stepMap.current.get(messageDelta.id); - const responseMessageId = runStep?.runId ?? ''; + let responseMessageId = runStep?.runId ?? ''; + if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) { + responseMessageId = submission?.initialResponse?.messageId ?? ''; + parentMessageId = submission?.initialResponse?.parentMessageId ?? ''; + } if (!runStep || !responseMessageId) { console.warn('No run step or runId found for message delta event'); @@ -299,7 +318,11 @@ export default function useStepHandler({ } else if (event === 'on_reasoning_delta') { const reasoningDelta = data as Agents.ReasoningDeltaEvent; const runStep = stepMap.current.get(reasoningDelta.id); - const responseMessageId = runStep?.runId ?? ''; + let responseMessageId = runStep?.runId ?? ''; + if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) { + responseMessageId = submission?.initialResponse?.messageId ?? ''; + parentMessageId = submission?.initialResponse?.parentMessageId ?? ''; + } if (!runStep || !responseMessageId) { console.warn('No run step or runId found for reasoning delta event'); @@ -327,7 +350,11 @@ export default function useStepHandler({ } else if (event === 'on_run_step_delta') { const runStepDelta = data as Agents.RunStepDeltaEvent; const runStep = stepMap.current.get(runStepDelta.id); - const responseMessageId = runStep?.runId ?? ''; + let responseMessageId = runStep?.runId ?? ''; + if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) { + responseMessageId = submission?.initialResponse?.messageId ?? ''; + parentMessageId = submission?.initialResponse?.parentMessageId ?? ''; + } if (!runStep || !responseMessageId) { console.warn('No run step or runId found for run step delta event'); @@ -366,7 +393,7 @@ export default function useStepHandler({ messageMap.current.set(responseMessageId, updatedResponse); const updatedMessages = messages.map((msg) => - msg.messageId === runStep.runId ? updatedResponse : msg, + msg.messageId === responseMessageId ? updatedResponse : msg, ); setMessages(updatedMessages); @@ -377,7 +404,11 @@ export default function useStepHandler({ const { id: stepId } = result; const runStep = stepMap.current.get(stepId); - const responseMessageId = runStep?.runId ?? ''; + let responseMessageId = runStep?.runId ?? ''; + if (responseMessageId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) { + responseMessageId = submission?.initialResponse?.messageId ?? ''; + parentMessageId = submission?.initialResponse?.parentMessageId ?? ''; + } if (!runStep || !responseMessageId) { console.warn('No run step or runId found for completed tool call event'); @@ -399,7 +430,7 @@ export default function useStepHandler({ messageMap.current.set(responseMessageId, updatedResponse); const updatedMessages = messages.map((msg) => - msg.messageId === runStep.runId ? updatedResponse : msg, + msg.messageId === responseMessageId ? updatedResponse : msg, ); setMessages(updatedMessages); @@ -414,4 +445,11 @@ export default function useStepHandler({ }, [getMessages, setIsSubmitting, lastAnnouncementTimeRef, announcePolite, setMessages], ); + + const clearStepMaps = useCallback(() => { + toolCallIdMap.current.clear(); + messageMap.current.clear(); + stepMap.current.clear(); + }, []); + return { stepHandler, clearStepMaps }; } diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index 1096b019d..4947b6bc6 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -1,5 +1,6 @@ /* MCP */ export * from './mcp/MCPManager'; +export * from './mcp/connection'; export * from './mcp/oauth'; export * from './mcp/auth'; export * from './mcp/zod'; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 368b80e61..6be5c26a3 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -28,6 +28,7 @@ export class MCPConnectionFactory { protected readonly oauthStart?: (authURL: string) => Promise; protected readonly oauthEnd?: () => Promise; protected readonly returnOnOAuth?: boolean; + protected readonly connectionTimeout?: number; /** Creates a new MCP connection with optional OAuth support */ static async create( @@ -47,6 +48,7 @@ export class MCPConnectionFactory { }); this.serverName = basic.serverName; this.useOAuth = !!oauth?.useOAuth; + this.connectionTimeout = oauth?.connectionTimeout; this.logPrefix = oauth?.user ? `[MCP][${basic.serverName}][${oauth.user.id}]` : `[MCP][${basic.serverName}]`; @@ -82,8 +84,9 @@ export class MCPConnectionFactory { if (!this.tokenMethods?.findToken) return null; try { + const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName); const tokens = await this.flowManager!.createFlowWithHandler( - `tokens:${this.userId}:${this.serverName}`, + flowId, 'mcp_get_tokens', async () => { return await MCPTokenStorage.getTokens({ @@ -203,7 +206,7 @@ export class MCPConnectionFactory { /** Attempts to establish connection with timeout handling */ protected async attemptToConnect(connection: MCPConnection): Promise { - const connectTimeout = this.serverConfig.initTimeout ?? 30000; + const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000; const connectionTimeout = new Promise((_, reject) => setTimeout( () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), @@ -347,6 +350,7 @@ export class MCPConnectionFactory { newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata, + this.signal, ); if (typeof this.oauthEnd === 'function') { await this.oauthEnd(); diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 8e2ab12f3..92d6e012e 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -1,13 +1,8 @@ -import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; import { logger } from '@librechat/data-schemas'; -import type { TokenMethods } from '@librechat/data-schemas'; -import type { TUser } from 'librechat-data-provider'; -import type { FlowStateManager } from '~/flow/manager'; -import type { MCPOAuthTokens } from '~/mcp/oauth'; +import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; import { MCPConnection } from './connection'; -import type { RequestBody } from '~/types'; import type * as t from './types'; /** @@ -44,8 +39,9 @@ export abstract class UserConnectionManager { /** Gets or creates a connection for a specific user */ public async getUserConnection({ - user, serverName, + forceNew, + user, flowManager, customUserVars, requestBody, @@ -54,25 +50,18 @@ export abstract class UserConnectionManager { oauthEnd, signal, returnOnOAuth = false, + connectionTimeout, }: { - user: TUser; serverName: string; - flowManager: FlowStateManager; - customUserVars?: Record; - requestBody?: RequestBody; - tokenMethods?: TokenMethods; - oauthStart?: (authURL: string) => Promise; - oauthEnd?: () => Promise; - signal?: AbortSignal; - returnOnOAuth?: boolean; - }): Promise { + forceNew?: boolean; + } & Omit): Promise { const userId = user.id; if (!userId) { throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); } const userServerMap = this.userConnections.get(userId); - let connection = userServerMap?.get(serverName); + let connection = forceNew ? undefined : userServerMap?.get(serverName); const now = Date.now(); // Check if user is idle @@ -131,6 +120,7 @@ export abstract class UserConnectionManager { oauthEnd: oauthEnd, returnOnOAuth: returnOnOAuth, requestBody: requestBody, + connectionTimeout: connectionTimeout, }, ); diff --git a/packages/api/src/mcp/__tests__/auth.test.ts b/packages/api/src/mcp/__tests__/auth.test.ts index 5d3793cfb..04f7d5c81 100644 --- a/packages/api/src/mcp/__tests__/auth.test.ts +++ b/packages/api/src/mcp/__tests__/auth.test.ts @@ -45,7 +45,7 @@ describe('getUserMCPAuthMap', () => { }, ]; - const tools = testCases.map((testCase) => + const toolInstances = testCases.map((testCase) => createMockTool(testCase.normalizedToolName, testCase.originalName), ); @@ -54,7 +54,7 @@ describe('getUserMCPAuthMap', () => { await getUserMCPAuthMap({ userId: 'user123', - tools, + toolInstances, findPluginAuthsByKeys: mockFindPluginAuthsByKeys, }); @@ -69,7 +69,7 @@ describe('getUserMCPAuthMap', () => { describe('Edge Cases', () => { it('should return empty object when no tools have mcpRawServerName', async () => { - const tools = [ + const toolInstances = [ createMockTool('regular_tool', undefined, false), createMockTool('another_tool', undefined, false), createMockTool('test_mcp_Server_no_raw_name', undefined), @@ -77,7 +77,7 @@ describe('getUserMCPAuthMap', () => { const result = await getUserMCPAuthMap({ userId: 'user123', - tools, + toolInstances, findPluginAuthsByKeys: mockFindPluginAuthsByKeys, }); @@ -104,14 +104,14 @@ describe('getUserMCPAuthMap', () => { }); it('should handle database errors gracefully', async () => { - const tools = [createMockTool('test_mcp_Server1', 'Server1')]; + const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')]; const dbError = new Error('Database connection failed'); mockGetPluginAuthMap.mockRejectedValue(dbError); const result = await getUserMCPAuthMap({ userId: 'user123', - tools, + toolInstances, findPluginAuthsByKeys: mockFindPluginAuthsByKeys, }); @@ -119,18 +119,119 @@ describe('getUserMCPAuthMap', () => { }); it('should handle non-Error exceptions gracefully', async () => { - const tools = [createMockTool('test_mcp_Server1', 'Server1')]; + const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')]; mockGetPluginAuthMap.mockRejectedValue('String error'); const result = await getUserMCPAuthMap({ userId: 'user123', - tools, + toolInstances, findPluginAuthsByKeys: mockFindPluginAuthsByKeys, }); expect(result).toEqual({}); }); + + it('should handle mixed null/undefined values in tools array', async () => { + const tools = [ + 'test_mcp_Server1', + null, + 'test_mcp_Server2', + undefined, + 'regular_tool', + 'test_mcp_Server3', + ]; + + mockGetPluginAuthMap.mockResolvedValue({ + mcp_Server1: { API_KEY: 'key1' }, + mcp_Server2: { API_KEY: 'key2' }, + mcp_Server3: { API_KEY: 'key3' }, + }); + + const result = await getUserMCPAuthMap({ + userId: 'user123', + tools: tools as (string | undefined)[], + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(mockGetPluginAuthMap).toHaveBeenCalledWith({ + userId: 'user123', + pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'], + throwError: false, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual({ + mcp_Server1: { API_KEY: 'key1' }, + mcp_Server2: { API_KEY: 'key2' }, + mcp_Server3: { API_KEY: 'key3' }, + }); + }); + + it('should handle mixed null/undefined values in servers array', async () => { + const servers = ['Server1', null, 'Server2', undefined, 'Server3']; + + mockGetPluginAuthMap.mockResolvedValue({ + mcp_Server1: { API_KEY: 'key1' }, + mcp_Server2: { API_KEY: 'key2' }, + mcp_Server3: { API_KEY: 'key3' }, + }); + + const result = await getUserMCPAuthMap({ + userId: 'user123', + servers: servers as (string | undefined)[], + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(mockGetPluginAuthMap).toHaveBeenCalledWith({ + userId: 'user123', + pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'], + throwError: false, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual({ + mcp_Server1: { API_KEY: 'key1' }, + mcp_Server2: { API_KEY: 'key2' }, + mcp_Server3: { API_KEY: 'key3' }, + }); + }); + + it('should handle mixed null/undefined values in toolInstances array', async () => { + const toolInstances = [ + createMockTool('test_mcp_Server1', 'Server1'), + null, + createMockTool('test_mcp_Server2', 'Server2'), + undefined, + createMockTool('regular_tool', undefined, false), + createMockTool('test_mcp_Server3', 'Server3'), + ]; + + mockGetPluginAuthMap.mockResolvedValue({ + mcp_Server1: { API_KEY: 'key1' }, + mcp_Server2: { API_KEY: 'key2' }, + mcp_Server3: { API_KEY: 'key3' }, + }); + + const result = await getUserMCPAuthMap({ + userId: 'user123', + toolInstances: toolInstances as (GenericTool | null)[], + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(mockGetPluginAuthMap).toHaveBeenCalledWith({ + userId: 'user123', + pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'], + throwError: false, + findPluginAuthsByKeys: mockFindPluginAuthsByKeys, + }); + + expect(result).toEqual({ + mcp_Server1: { API_KEY: 'key1' }, + mcp_Server2: { API_KEY: 'key2' }, + mcp_Server3: { API_KEY: 'key3' }, + }); + }); }); describe('Integration', () => { @@ -138,7 +239,7 @@ describe('getUserMCPAuthMap', () => { const originalServerName = 'Connector: Company'; const toolName = 'test_auth_mcp_Connector__Company'; - const tools = [createMockTool(toolName, originalServerName)]; + const toolInstances = [createMockTool(toolName, originalServerName)]; const mockCustomUserVars = { 'mcp_Connector: Company': { @@ -151,7 +252,7 @@ describe('getUserMCPAuthMap', () => { const result = await getUserMCPAuthMap({ userId: 'user123', - tools, + toolInstances, findPluginAuthsByKeys: mockFindPluginAuthsByKeys, }); diff --git a/packages/api/src/mcp/auth.ts b/packages/api/src/mcp/auth.ts index 8221278fd..e10c16b4b 100644 --- a/packages/api/src/mcp/auth.ts +++ b/packages/api/src/mcp/auth.ts @@ -7,33 +7,56 @@ import { getPluginAuthMap } from '~/agents/auth'; export async function getUserMCPAuthMap({ userId, tools, + servers, + toolInstances, findPluginAuthsByKeys, }: { userId: string; - tools: GenericTool[] | undefined; + tools?: (string | undefined)[]; + servers?: (string | undefined)[]; + toolInstances?: (GenericTool | null)[]; findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys']; }) { - if (!tools || tools.length === 0) { - return {}; - } - - const uniqueMcpServers = new Set(); - - for (const tool of tools) { - const mcpTool = tool as GenericTool & { mcpRawServerName?: string }; - if (mcpTool.mcpRawServerName) { - uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`); - } - } - - if (uniqueMcpServers.size === 0) { - return {}; - } - - const mcpPluginKeysToFetch = Array.from(uniqueMcpServers); - let allMcpCustomUserVars: Record> = {}; + let mcpPluginKeysToFetch: string[] = []; try { + const uniqueMcpServers = new Set(); + + if (servers != null && servers.length) { + for (const serverName of servers) { + if (!serverName) { + continue; + } + uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`); + } + } else if (tools != null && tools.length) { + for (const toolName of tools) { + if (!toolName) { + continue; + } + const delimiterIndex = toolName.indexOf(Constants.mcp_delimiter); + if (delimiterIndex === -1) continue; + const mcpServer = toolName.slice(delimiterIndex + Constants.mcp_delimiter.length); + if (!mcpServer) continue; + uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpServer}`); + } + } else if (toolInstances != null && toolInstances.length) { + for (const tool of toolInstances) { + if (!tool) { + continue; + } + const mcpTool = tool as GenericTool & { mcpRawServerName?: string }; + if (mcpTool.mcpRawServerName) { + uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`); + } + } + } + + if (uniqueMcpServers.size === 0) { + return {}; + } + + mcpPluginKeysToFetch = Array.from(uniqueMcpServers); allMcpCustomUserVars = await getPluginAuthMap({ userId, pluginKeys: mcpPluginKeysToFetch, diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 8e2eb00b4..6f641f027 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -446,7 +446,7 @@ export class MCPConnection extends EventEmitter { const serverUrl = this.url; logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`); - const oauthTimeout = this.options.initTimeout ?? 60000; + const oauthTimeout = this.options.initTimeout ?? 60000 * 2; /** Promise that will resolve when OAuth is handled */ const oauthHandledPromise = new Promise((resolve, reject) => { let timeoutId: NodeJS.Timeout | null = null; diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index cdf51d4ef..6230ac15e 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -134,4 +134,5 @@ export interface OAuthConnectionOptions { oauthStart?: (authURL: string) => Promise; oauthEnd?: () => Promise; returnOnOAuth?: boolean; + connectionTimeout?: number; } diff --git a/packages/api/src/tools/format.ts b/packages/api/src/tools/format.ts index 79ea4df08..dce8b9d16 100644 --- a/packages/api/src/tools/format.ts +++ b/packages/api/src/tools/format.ts @@ -1,5 +1,6 @@ import { AuthType, Constants, EToolResources } from 'librechat-data-provider'; -import type { TCustomConfig, TPlugin, FunctionTool } from 'librechat-data-provider'; +import type { TCustomConfig, TPlugin } from 'librechat-data-provider'; +import { LCAvailableTools, LCFunctionTool } from '~/mcp/types'; /** * Filters out duplicate plugins from the list of plugins. @@ -60,7 +61,7 @@ export function convertMCPToolToPlugin({ customConfig, }: { toolKey: string; - toolData: FunctionTool; + toolData: LCFunctionTool; customConfig?: Partial | null; }): TPlugin | undefined { if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) { @@ -112,7 +113,7 @@ export function convertMCPToolsToPlugins({ functionTools, customConfig, }: { - functionTools?: Record; + functionTools?: LCAvailableTools; customConfig?: Partial | null; }): TPlugin[] | undefined { if (!functionTools || typeof functionTools !== 'object') { diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 0be8ce3ff..57474241d 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -1525,6 +1525,8 @@ export enum Constants { CONFIG_VERSION = '1.2.8', /** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */ NO_PARENT = '00000000-0000-0000-0000-000000000000', + /** Standard value to use whatever the submission prelim. `responseMessageId` is */ + USE_PRELIM_RESPONSE_MESSAGE_ID = 'USE_PRELIM_RESPONSE_MESSAGE_ID', /** Standard value for the initial conversationId before a request is sent */ NEW_CONVO = 'new', /** Standard value for the temporary conversationId after a request is sent and before the server responds */ @@ -1551,6 +1553,8 @@ export enum Constants { mcp_delimiter = '_mcp_', /** Prefix for MCP plugins */ mcp_prefix = 'mcp_', + /** Unique value to indicate all MCP servers */ + mcp_all = 'sys__all__sys', /** Placeholder Agent ID for Ephemeral Agents */ EPHEMERAL_AGENT_ID = 'ephemeral', }