diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 1dbc7633ac..fdd6d227b6 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -97,7 +97,7 @@ function createServerToolsCallback() { return; } await mcpToolsCache.set(serverName, serverTools); - logger.debug(`MCP tools for ${serverName} added to cache.`); + logger.warn(`MCP tools for ${serverName} added to cache.`); } catch (error) { logger.error('Error retrieving MCP tools from cache:', error); } @@ -143,7 +143,7 @@ const getAvailableTools = async (req, res) => { const cache = getLogStores(CacheKeys.CONFIG_STORE); const cachedToolsArray = await cache.get(CacheKeys.TOOLS); const cachedUserTools = await getCachedTools({ userId }); - const userPlugins = convertMCPToolsToPlugins(cachedUserTools, customConfig); + const userPlugins = await convertMCPToolsToPlugins(cachedUserTools, customConfig, userId); if (cachedToolsArray && userPlugins) { const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]); @@ -202,23 +202,102 @@ const getAvailableTools = async (req, res) => { const serverName = parts[parts.length - 1]; const serverConfig = customConfig?.mcpServers?.[serverName]; - if (!serverConfig?.customUserVars) { + logger.warn( + `[getAvailableTools] Processing MCP tool:`, + JSON.stringify({ + pluginKey: plugin.pluginKey, + serverName, + hasServerConfig: !!serverConfig, + hasCustomUserVars: !!serverConfig?.customUserVars, + }), + ); + + if (!serverConfig) { + logger.warn( + `[getAvailableTools] No server config found for ${serverName}, skipping auth check`, + ); toolsOutput.push(toolToAdd); continue; } - const customVarKeys = Object.keys(serverConfig.customUserVars); + // Handle MCP servers with customUserVars (user-level auth required) + if (serverConfig.customUserVars) { + logger.warn(`[getAvailableTools] Processing user-level MCP server: ${serverName}`); + const customVarKeys = Object.keys(serverConfig.customUserVars); - if (customVarKeys.length === 0) { - toolToAdd.authConfig = []; - toolToAdd.authenticated = true; - } else { + // Build authConfig for MCP tools toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({ authField: key, label: value.title || key, description: value.description || '', })); - toolToAdd.authenticated = false; + + // Check actual connection status for MCP tools with auth requirements + if (userId) { + try { + const mcpManager = getMCPManager(userId); + const connectionStatus = await mcpManager.getUserConnectionStatus(userId, serverName); + toolToAdd.authenticated = connectionStatus.connected; + logger.warn(`[getAvailableTools] User-level connection status for ${serverName}:`, { + connected: connectionStatus.connected, + hasConnection: connectionStatus.hasConnection, + }); + } catch (error) { + logger.error( + `[getAvailableTools] Error checking connection status for ${serverName}:`, + error, + ); + toolToAdd.authenticated = false; + } + } else { + // For non-authenticated requests, default to false + toolToAdd.authenticated = false; + } + } else { + // Handle app-level MCP servers (no auth required) + logger.warn(`[getAvailableTools] Processing app-level MCP server: ${serverName}`); + toolToAdd.authConfig = []; + + // Check if the app-level connection is active + try { + const mcpManager = getMCPManager(); + const allConnections = mcpManager.getAllConnections(); + logger.warn(`[getAvailableTools] All app-level connections:`, { + connectionNames: Array.from(allConnections.keys()), + serverName, + }); + + const appConnection = mcpManager.getConnection(serverName); + logger.warn(`[getAvailableTools] Checking app-level connection for ${serverName}:`, { + hasConnection: !!appConnection, + connectionState: appConnection?.getConnectionState?.(), + }); + + if (appConnection) { + const connectionState = appConnection.getConnectionState(); + logger.warn(`[getAvailableTools] App-level connection status for ${serverName}:`, { + connectionState, + hasConnection: !!appConnection, + }); + + // For app-level connections, consider them authenticated if they're in 'connected' state + // This is more reliable than isConnected() which does network calls + toolToAdd.authenticated = connectionState === 'connected'; + logger.warn(`[getAvailableTools] Final authenticated status for ${serverName}:`, { + authenticated: toolToAdd.authenticated, + connectionState, + }); + } else { + logger.warn(`[getAvailableTools] No app-level connection found for ${serverName}`); + toolToAdd.authenticated = false; + } + } catch (error) { + logger.error( + `[getAvailableTools] Error checking app-level connection status for ${serverName}:`, + error, + ); + toolToAdd.authenticated = false; + } } toolsOutput.push(toolToAdd); @@ -241,7 +320,7 @@ const getAvailableTools = async (req, res) => { * @param {Object} customConfig - Custom configuration for MCP servers * @returns {Array} Array of plugin objects */ -function convertMCPToolsToPlugins(functionTools, customConfig) { +async function convertMCPToolsToPlugins(functionTools, customConfig, userId = null) { const plugins = []; for (const [toolKey, toolData] of Object.entries(functionTools)) { @@ -257,7 +336,7 @@ function convertMCPToolsToPlugins(functionTools, customConfig) { name: parts[0], // Use the tool name without server suffix pluginKey: toolKey, description: functionData.description || '', - authenticated: true, + authenticated: false, // Default to false, will be updated based on connection status icon: undefined, }; @@ -265,6 +344,7 @@ function convertMCPToolsToPlugins(functionTools, customConfig) { const serverConfig = customConfig?.mcpServers?.[serverName]; if (!serverConfig?.customUserVars) { plugin.authConfig = []; + plugin.authenticated = true; // No auth required plugins.push(plugin); continue; } @@ -272,12 +352,30 @@ function convertMCPToolsToPlugins(functionTools, customConfig) { const customVarKeys = Object.keys(serverConfig.customUserVars); if (customVarKeys.length === 0) { plugin.authConfig = []; + plugin.authenticated = true; // No auth required } else { plugin.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({ authField: key, label: value.title || key, description: value.description || '', })); + + // Check actual connection status for MCP tools with auth requirements + if (userId) { + try { + const mcpManager = getMCPManager(userId); + const connectionStatus = await mcpManager.getUserConnectionStatus(userId, serverName); + plugin.authenticated = connectionStatus.connected; + } catch (error) { + logger.error( + `[convertMCPToolsToPlugins] Error checking connection status for ${serverName}:`, + error, + ); + plugin.authenticated = false; + } + } else { + plugin.authenticated = false; + } } plugins.push(plugin); diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 69791dd7a5..e1d115dab9 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -180,14 +180,18 @@ const updateUserPluginsController = async (req, res) => { try { const mcpManager = getMCPManager(user.id); if (mcpManager) { + // Extract server name from pluginKey (e.g., "mcp_myserver" -> "myserver") + const serverName = pluginKey.replace(Constants.mcp_prefix, ''); + logger.info( - `[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`, + `[updateUserPluginsController] Disconnecting MCP connection for user ${user.id} and server ${serverName} after plugin auth update for ${pluginKey}.`, ); - await mcpManager.disconnectUserConnections(user.id); + // COMMENTED OUT: Don't kill the server connection on revoke + // await mcpManager.disconnectUserConnection(user.id, serverName); } } catch (disconnectError) { logger.error( - `[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`, + `[updateUserPluginsController] Error disconnecting MCP connection for user ${user.id} after plugin auth update:`, disconnectError, ); // Do not fail the request for this, but log it. diff --git a/api/server/routes/config.js b/api/server/routes/config.js index dd93037dd9..bb8c366cb0 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -106,6 +106,7 @@ router.get('/', async function (req, res) { const serverConfig = config.mcpServers[serverName]; payload.mcpServers[serverName] = { customUserVars: serverConfig?.customUserVars || {}, + requiresOAuth: req.app.locals.mcpOAuthRequirements?.[serverName] || false, }; } } diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index b9084f982d..39074c3e6d 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -4,7 +4,7 @@ const { MCPOAuthHandler } = require('@librechat/api'); const { CacheKeys, Constants } = require('librechat-data-provider'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config'); -const { getUserPluginAuthValue } = require('~/server/services/PluginService'); +const { getUserPluginAuthValueByPlugin } = require('~/server/services/PluginService'); const { getMCPManager, getFlowStateManager } = require('~/config'); const { requireJwtAuth } = require('~/server/middleware'); const { getLogStores } = require('~/cache'); @@ -206,10 +206,91 @@ router.get('/oauth/status/:flowId', async (req, res) => { }); /** - * Reinitialize MCP server - * This endpoint allows reinitializing a specific MCP server + * Get connection status for all MCP servers + * This endpoint returns the actual connection status from MCPManager */ -router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { +router.get('/connection/status', requireJwtAuth, async (req, res) => { + try { + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + const mcpManager = getMCPManager(); + const connectionStatus = {}; + + // Get all MCP server names from custom config + const config = await loadCustomConfig(); + const mcpConfig = config?.mcpServers; + + if (mcpConfig) { + for (const [serverName, config] of Object.entries(mcpConfig)) { + try { + // Check if this is an app-level connection (exists in mcpManager.connections) + const appConnection = mcpManager.getConnection(serverName); + const hasAppConnection = !!appConnection; + + // Check if this is a user-level connection (exists in mcpManager.userConnections) + const userConnection = mcpManager.getUserConnectionIfExists(user.id, serverName); + const hasUserConnection = !!userConnection; + + // Determine if connected based on actual connection state + let connected = false; + if (hasAppConnection) { + connected = await appConnection.isConnected(); + } else if (hasUserConnection) { + connected = await userConnection.isConnected(); + } + + // Determine if this server requires user authentication + const hasAuthConfig = + config.customUserVars && Object.keys(config.customUserVars).length > 0; + const requiresOAuth = req.app.locals.mcpOAuthRequirements?.[serverName] || false; + + connectionStatus[serverName] = { + connected, + hasAuthConfig, + hasConnection: hasAppConnection || hasUserConnection, + isAppLevel: hasAppConnection, + isUserLevel: hasUserConnection, + requiresOAuth, + }; + } catch (error) { + logger.error( + `[MCP Connection Status] Error checking connection for ${serverName}:`, + error, + ); + connectionStatus[serverName] = { + connected: false, + hasAuthConfig: config.customUserVars && Object.keys(config.customUserVars).length > 0, + hasConnection: false, + isAppLevel: false, + isUserLevel: false, + requiresOAuth: req.app.locals.mcpOAuthRequirements?.[serverName] || false, + error: error.message, + }; + } + } + } + + logger.info(`[MCP Connection Status] Returning status for user ${user.id}:`, connectionStatus); + + res.json({ + success: true, + connectionStatus, + }); + } catch (error) { + logger.error('[MCP Connection Status] Failed to get connection status', error); + res.status(500).json({ error: 'Failed to get connection status' }); + } +}); + +/** + * Check which authentication values exist for a specific MCP server + * This endpoint returns only boolean flags indicating if values are set, not the actual values + */ +router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { try { const { serverName } = req.params; const user = req.user; @@ -218,10 +299,206 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { return res.status(401).json({ error: 'User not authenticated' }); } + const config = await loadCustomConfig(); + if (!config || !config.mcpServers || !config.mcpServers[serverName]) { + return res.status(404).json({ + error: `MCP server '${serverName}' not found in configuration`, + }); + } + + const serverConfig = config.mcpServers[serverName]; + const pluginKey = `${Constants.mcp_prefix}${serverName}`; + const authValueFlags = {}; + + // Check existence of saved values for each custom user variable (don't fetch actual values) + if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { + for (const varName of Object.keys(serverConfig.customUserVars)) { + try { + const value = await getUserPluginAuthValueByPlugin(user.id, varName, pluginKey, false); + // Only store boolean flag indicating if value exists + authValueFlags[varName] = !!(value && value.length > 0); + } catch (err) { + logger.error( + `[MCP Auth Value Flags] Error checking ${varName} for user ${user.id}:`, + err, + ); + // Default to false if we can't check + authValueFlags[varName] = false; + } + } + } + + res.json({ + success: true, + serverName, + authValueFlags, + }); + } catch (error) { + logger.error( + `[MCP Auth Value Flags] Failed to check auth value flags for ${req.params.serverName}`, + error, + ); + res.status(500).json({ error: 'Failed to check auth value flags' }); + } +}); + +/** + * Check if a specific MCP server requires OAuth + * This endpoint checks if a specific MCP server requires OAuth authentication + */ +router.get('/:serverName/oauth/required', requireJwtAuth, async (req, res) => { + try { + const { serverName } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + const mcpManager = getMCPManager(); + const requiresOAuth = await mcpManager.isOAuthRequired(serverName); + + res.json({ + success: true, + serverName, + requiresOAuth, + }); + } catch (error) { + logger.error( + `[MCP OAuth Required] Failed to check OAuth requirement for ${req.params.serverName}`, + error, + ); + res.status(500).json({ error: 'Failed to check OAuth requirement' }); + } +}); + +/** + * Complete MCP server reinitialization after OAuth + * This endpoint completes the reinitialization process after OAuth authentication + */ +router.post('/:serverName/reinitialize/complete', requireJwtAuth, async (req, res) => { + let responseSent = false; + + try { + const { serverName } = req.params; + const user = req.user; + + if (!user?.id) { + responseSent = true; + return res.status(401).json({ error: 'User not authenticated' }); + } + + logger.info(`[MCP Complete Reinitialize] Starting completion for ${serverName}`); + + const mcpManager = getMCPManager(); + + // Wait for connection to be established via event-driven approach + const userConnection = await new Promise((resolve, reject) => { + // Set a reasonable timeout (10 seconds) + const timeout = setTimeout(() => { + mcpManager.removeListener('connectionEstablished', connectionHandler); + reject(new Error('Timeout waiting for connection establishment')); + }, 10000); + + const connectionHandler = ({ + userId: eventUserId, + serverName: eventServerName, + connection, + }) => { + if (eventUserId === user.id && eventServerName === serverName) { + clearTimeout(timeout); + mcpManager.removeListener('connectionEstablished', connectionHandler); + resolve(connection); + } + }; + + // Check if connection already exists + const existingConnection = mcpManager.getUserConnectionIfExists(user.id, serverName); + if (existingConnection) { + clearTimeout(timeout); + resolve(existingConnection); + return; + } + + // Listen for the connection establishment event + mcpManager.on('connectionEstablished', connectionHandler); + }); + + if (!userConnection) { + responseSent = true; + return res.status(404).json({ error: 'User connection not found' }); + } + + const userTools = (await getCachedTools({ userId: user.id })) || {}; + + // Remove any old tools from this server in the user's cache + const mcpDelimiter = Constants.mcp_delimiter; + for (const key of Object.keys(userTools)) { + if (key.endsWith(`${mcpDelimiter}${serverName}`)) { + delete userTools[key]; + } + } + + // Add the new tools from this server + const tools = await userConnection.fetchTools(); + for (const tool of tools) { + const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`; + userTools[name] = { + type: 'function', + ['function']: { + name, + description: tool.description, + parameters: tool.inputSchema, + }, + }; + } + + // Save the updated user tool cache + await setCachedTools(userTools, { userId: user.id }); + + responseSent = true; + res.json({ + success: true, + message: `MCP server '${serverName}' reinitialized successfully`, + serverName, + }); + } catch (error) { + logger.error( + `[MCP Complete Reinitialize] Error completing reinitialization for ${req.params.serverName}:`, + error, + ); + + if (!responseSent) { + res.status(500).json({ + success: false, + message: 'Failed to complete MCP server reinitialization', + serverName: req.params.serverName, + }); + } + } +}); + +/** + * Reinitialize MCP server + * This endpoint allows reinitializing a specific MCP server + */ +router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { + let responseSent = false; + + try { + const { serverName } = req.params; + const user = req.user; + + if (!user?.id) { + responseSent = true; + return res.status(401).json({ error: 'User not authenticated' }); + } + logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const config = await loadCustomConfig(); if (!config || !config.mcpServers || !config.mcpServers[serverName]) { + responseSent = true; return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, }); @@ -231,6 +508,21 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { const flowManager = getFlowStateManager(flowsCache); const mcpManager = getMCPManager(); + // Clean up any stale OAuth flows for this server + try { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + const existingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (existingFlow && existingFlow.status === 'PENDING') { + logger.info(`[MCP Reinitialize] Cleaning up stale OAuth flow for ${serverName}`); + await flowManager.failFlow(flowId, 'mcp_oauth', new Error('OAuth flow interrupted')); + } + } catch (error) { + logger.warn( + `[MCP Reinitialize] Error cleaning up stale OAuth flow for ${serverName}:`, + error, + ); + } + await mcpManager.disconnectServer(serverName); logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`); @@ -240,7 +532,8 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { for (const varName of Object.keys(serverConfig.customUserVars)) { try { - const value = await getUserPluginAuthValue(user.id, varName, false); + const pluginKey = `${Constants.mcp_prefix}${serverName}`; + const value = await getUserPluginAuthValueByPlugin(user.id, varName, pluginKey, false); if (value) { customUserVars[varName] = value; } @@ -251,6 +544,8 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { } let userConnection = null; + let oauthRequired = false; + try { userConnection = await mcpManager.getUserConnection({ user, @@ -263,9 +558,79 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { createToken, deleteTokens, }, + oauthStart: (authURL) => { + // This will be called if OAuth is required + oauthRequired = true; + responseSent = true; + logger.info(`[MCP Reinitialize] OAuth required for ${serverName}, auth URL: ${authURL}`); + + // Get the flow ID for polling + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + + // Return the OAuth response immediately - client will poll for completion + res.json({ + success: false, + oauthRequired: true, + authURL, + flowId, + message: `OAuth authentication required for MCP server '${serverName}'`, + serverName, + }); + }, + oauthEnd: () => { + // This will be called when OAuth flow completes + logger.info(`[MCP Reinitialize] OAuth flow completed for ${serverName}`); + }, }); + + // If response was already sent for OAuth, don't continue + if (responseSent) { + return; + } } catch (err) { logger.error(`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`, err); + + // Check if this is an OAuth error + if (err.message && err.message.includes('OAuth required')) { + // Try to get the OAuth URL from the flow manager + try { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + const existingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); + + if (existingFlow && existingFlow.metadata) { + const { serverUrl, oauth: oauthConfig } = existingFlow.metadata; + if (serverUrl && oauthConfig) { + const { authorizationUrl: authUrl } = await MCPOAuthHandler.initiateOAuthFlow( + serverName, + serverUrl, + user.id, + oauthConfig, + ); + + return res.json({ + success: false, + oauthRequired: true, + authURL: authUrl, + flowId, + message: `OAuth authentication required for MCP server '${serverName}'`, + serverName, + }); + } + } + } catch (oauthErr) { + logger.error(`[MCP Reinitialize] Error getting OAuth URL for ${serverName}:`, oauthErr); + } + + responseSent = true; + return res.status(401).json({ + success: false, + oauthRequired: true, + message: `OAuth authentication required for MCP server '${serverName}'`, + serverName, + }); + } + + responseSent = true; return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); } @@ -296,6 +661,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { // Save the updated user tool cache await setCachedTools(userTools, { userId: user.id }); + responseSent = true; res.json({ success: true, message: `MCP server '${serverName}' reinitialized successfully`, @@ -303,7 +669,9 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { }); } catch (error) { logger.error('[MCP Reinitialize] Unexpected error', error); - res.status(500).json({ error: 'Internal server error' }); + if (!responseSent) { + res.status(500).json({ error: 'Internal server error' }); + } } }); diff --git a/api/server/services/Config/loadCustomConfig.js b/api/server/services/Config/loadCustomConfig.js index 393281daf2..2de36591b1 100644 --- a/api/server/services/Config/loadCustomConfig.js +++ b/api/server/services/Config/loadCustomConfig.js @@ -108,8 +108,6 @@ https://www.librechat.ai/docs/configuration/stt_tts`); return null; } else { - logger.info('Custom config file loaded:'); - logger.info(JSON.stringify(customConfig, null, 2)); logger.debug('Custom config:', customConfig); } diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index af42e0471c..1e276de118 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -41,6 +41,38 @@ const getUserPluginAuthValue = async (userId, authField, throwError = true) => { } }; +/** + * Asynchronously retrieves and decrypts the authentication value for a user's specific plugin, based on a specified authentication field and plugin key. + * + * @param {string} userId - The unique identifier of the user for whom the plugin authentication value is to be retrieved. + * @param {string} authField - The specific authentication field (e.g., 'API_KEY', 'URL') whose value is to be retrieved and decrypted. + * @param {string} pluginKey - The plugin key to filter by (e.g., 'mcp_github-mcp'). + * @param {boolean} throwError - Whether to throw an error if the authentication value does not exist. Defaults to `true`. + * @returns {Promise} A promise that resolves to the decrypted authentication value if found, or `null` if no such authentication value exists for the given user, field, and plugin. + * + * @throws {Error} Throws an error if there's an issue during the retrieval or decryption process, or if the authentication value does not exist. + * @async + */ +const getUserPluginAuthValueByPlugin = async (userId, authField, pluginKey, throwError = true) => { + try { + const pluginAuth = await findOnePluginAuth({ userId, authField, pluginKey }); + if (!pluginAuth) { + throw new Error( + `No plugin auth ${authField} found for user ${userId} and plugin ${pluginKey}`, + ); + } + + const decryptedValue = await decrypt(pluginAuth.value); + return decryptedValue; + } catch (err) { + if (!throwError) { + return null; + } + logger.error('[getUserPluginAuthValueByPlugin]', err); + throw err; + } +}; + // const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { // try { // const encryptedValue = encrypt(value); @@ -119,6 +151,7 @@ const deleteUserPluginAuth = async (userId, authField, all = false, pluginKey) = module.exports = { getUserPluginAuthValue, + getUserPluginAuthValueByPlugin, updateUserPluginAuth, deleteUserPluginAuth, }; diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index 18edb2449d..d5e2545aa5 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -10,6 +10,15 @@ const { getLogStores } = require('~/cache'); * @param {import('express').Application} app - Express app instance */ async function initializeMCPs(app) { + // TEMPORARY: Reset all OAuth tokens for fresh testing + try { + logger.info('[MCP] Resetting all OAuth tokens for fresh testing...'); + await deleteTokens({}); + logger.info('[MCP] All OAuth tokens reset successfully'); + } catch (error) { + logger.error('[MCP] Error resetting OAuth tokens:', error); + } + const mcpServers = app.locals.mcpConfig; if (!mcpServers) { return; @@ -36,7 +45,7 @@ async function initializeMCPs(app) { const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null; try { - await mcpManager.initializeMCPs({ + const oauthRequirements = await mcpManager.initializeMCPs({ mcpServers: filteredServers, flowManager, tokenMethods: { @@ -64,6 +73,9 @@ async function initializeMCPs(app) { logger.debug('Cleared tools array cache after MCP initialization'); logger.info('MCP servers initialized successfully'); + + // Store OAuth requirement information in app locals for client access + app.locals.mcpOAuthRequirements = oauthRequirements; } catch (error) { logger.error('Failed to initialize MCP servers:', error); } diff --git a/client/src/components/Chat/Input/MCPSelect.tsx b/client/src/components/Chat/Input/MCPSelect.tsx index 0a03decd53..8f2d967883 100644 --- a/client/src/components/Chat/Input/MCPSelect.tsx +++ b/client/src/components/Chat/Input/MCPSelect.tsx @@ -1,9 +1,12 @@ -import React, { memo, useCallback, useState } from 'react'; -import { SettingsIcon } from 'lucide-react'; +import React, { memo, useCallback, useState, useMemo } from 'react'; +import { SettingsIcon, PlugZap } from 'lucide-react'; import { Constants } from 'librechat-data-provider'; import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query'; +import { useMCPConnectionStatusQuery, useMCPAuthValuesQuery } from '~/data-provider'; +import { useQueryClient } from '@tanstack/react-query'; +import { QueryKeys } from 'librechat-data-provider'; import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider'; -import MCPConfigDialog, { type ConfigFieldDetail } from '~/components/ui/MCPConfigDialog'; +import { MCPConfigDialog, type ConfigFieldDetail } from '~/components/ui/MCP'; import { useToastContext, useBadgeRowContext } from '~/Providers'; import MultiSelect from '~/components/ui/MultiSelect'; import { MCPIcon } from '~/components/svg'; @@ -18,15 +21,47 @@ function MCPSelect() { const localize = useLocalize(); const { showToast } = useToastContext(); const { mcpSelect, startupConfig } = useBadgeRowContext(); - const { mcpValues, setMCPValues, mcpServerNames, mcpToolDetails, isPinned } = mcpSelect; + const { mcpValues, setMCPValues, mcpToolDetails, isPinned } = mcpSelect; + + // Get real connection status from MCPManager + const { data: statusQuery } = useMCPConnectionStatusQuery(); + + const mcpServerStatuses = statusQuery?.connectionStatus || {}; + + console.log('mcpServerStatuses', mcpServerStatuses); + console.log('statusQuery', statusQuery); const [isConfigModalOpen, setIsConfigModalOpen] = useState(false); const [selectedToolForConfig, setSelectedToolForConfig] = useState(null); + // Fetch auth values for the selected server + const { data: authValuesData } = useMCPAuthValuesQuery(selectedToolForConfig?.name || '', { + enabled: isConfigModalOpen && !!selectedToolForConfig?.name, + }); + + const queryClient = useQueryClient(); + const updateUserPluginsMutation = useUpdateUserPluginsMutation({ - onSuccess: () => { - setIsConfigModalOpen(false); + onSuccess: async (data, variables) => { showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' }); + + // // For 'uninstall' actions (revoke), remove the server from selected values + // if (variables.action === 'uninstall') { + // const serverName = variables.pluginKey.replace(Constants.mcp_prefix, ''); + // const currentValues = mcpValues ?? []; + // const filteredValues = currentValues.filter((name) => name !== serverName); + // setMCPValues(filteredValues); + // } + + // Wait for all refetches to complete before ending loading state + await Promise.all([ + queryClient.invalidateQueries([QueryKeys.tools]), + queryClient.refetchQueries([QueryKeys.tools]), + queryClient.invalidateQueries([QueryKeys.mcpAuthValues]), + queryClient.refetchQueries([QueryKeys.mcpAuthValues]), + queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]), + queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]), + ]); }, onError: (error: unknown) => { console.error('Error updating MCP auth:', error); @@ -53,10 +88,12 @@ function MCPSelect() { const handleConfigSave = useCallback( (targetName: string, authData: Record) => { if (selectedToolForConfig && selectedToolForConfig.name === targetName) { - const basePluginKey = getBaseMCPPluginKey(selectedToolForConfig.pluginKey); - + // Use the pluginKey directly since it's already in the correct format + console.log( + `[MCP Select] Saving config for ${targetName}, pluginKey: ${`${Constants.mcp_prefix}${targetName}`}`, + ); const payload: TUpdateUserPlugins = { - pluginKey: basePluginKey, + pluginKey: `${Constants.mcp_prefix}${targetName}`, action: 'install', auth: authData, }; @@ -69,10 +106,12 @@ function MCPSelect() { const handleConfigRevoke = useCallback( (targetName: string) => { if (selectedToolForConfig && selectedToolForConfig.name === targetName) { - const basePluginKey = getBaseMCPPluginKey(selectedToolForConfig.pluginKey); - + // Use the pluginKey directly since it's already in the correct format + console.log( + `[MCP Select] Revoking config for ${targetName}, pluginKey: ${`${Constants.mcp_prefix}${targetName}`}`, + ); const payload: TUpdateUserPlugins = { - pluginKey: basePluginKey, + pluginKey: `${Constants.mcp_prefix}${targetName}`, action: 'uninstall', auth: {}, }; @@ -82,49 +121,138 @@ function MCPSelect() { [selectedToolForConfig, updateUserPluginsMutation], ); + // Create stable callback references to prevent stale closures + const handleSave = useCallback( + (authData: Record) => { + if (selectedToolForConfig) { + handleConfigSave(selectedToolForConfig.name, authData); + } + }, + [selectedToolForConfig, handleConfigSave], + ); + + const handleRevoke = useCallback(() => { + if (selectedToolForConfig) { + handleConfigRevoke(selectedToolForConfig.name); + } + }, [selectedToolForConfig, handleConfigRevoke]); + + // Only allow connected servers to be selected + const handleSetSelectedValues = useCallback( + (values: string[]) => { + // Filter to only include connected servers + const connectedValues = values.filter((serverName) => { + const serverStatus = mcpServerStatuses?.[serverName]; + return serverStatus?.connected || false; + }); + setMCPValues(connectedValues); + }, + [setMCPValues, mcpServerStatuses], + ); + const renderItemContent = useCallback( (serverName: string, defaultContent: React.ReactNode) => { - const tool = mcpToolDetails?.find((t) => t.name === serverName); - const hasAuthConfig = tool?.authConfig && tool.authConfig.length > 0; + const serverStatus = mcpServerStatuses?.[serverName]; + const connected = serverStatus?.connected || false; + const hasAuthConfig = serverStatus?.hasAuthConfig || false; - // Common wrapper for the main content (check mark + text) - // Ensures Check & Text are adjacent and the group takes available space. - const mainContentWrapper = ( -
{defaultContent}
- ); + // Icon logic: + // - connected with auth config = gear (green) + // - connected without auth config = no icon (just text) + // - not connected = zap (orange) + let icon: React.ReactNode = null; + let tooltip = 'Configure server'; - if (tool && hasAuthConfig) { - return ( -
- {mainContentWrapper} + if (connected) { + if (hasAuthConfig) { + icon = ; + tooltip = 'Configure connected server'; + } else { + // No icon for connected servers without auth config + tooltip = 'Connected server (no configuration needed)'; + } + } else { + icon = ; + tooltip = 'Configure server'; + } + + const onClick = () => { + const serverConfig = startupConfig?.mcpServers?.[serverName]; + if (serverConfig) { + const serverTool = { + name: serverName, + pluginKey: `${Constants.mcp_prefix}${serverName}`, + authConfig: Object.entries(serverConfig.customUserVars || {}).map(([key, config]) => ({ + authField: key, + label: config.title, + description: config.description, + requiresOAuth: serverConfig.requiresOAuth || false, + })), + authenticated: connected, + }; + setSelectedToolForConfig(serverTool); + setIsConfigModalOpen(true); + } + }; + + return ( +
+
+ {defaultContent} +
+ {icon && ( -
- ); - } - // For items without a settings icon, return the consistently wrapped main content. - return mainContentWrapper; + )} +
+ ); }, - [mcpToolDetails, setSelectedToolForConfig, setIsConfigModalOpen], + [mcpServerStatuses, setSelectedToolForConfig, setIsConfigModalOpen, startupConfig], ); - // Don't render if no servers are selected and not pinned - if ((!mcpValues || mcpValues.length === 0) && !isPinned) { + // Memoize schema and initial values to prevent unnecessary re-renders + const fieldsSchema = useMemo(() => { + const schema: Record = {}; + if (selectedToolForConfig?.authConfig) { + selectedToolForConfig.authConfig.forEach((field) => { + schema[field.authField] = { + title: field.label, + description: field.description, + }; + }); + } + return schema; + }, [selectedToolForConfig?.authConfig]); + + const initialValues = useMemo(() => { + const initial: Record = {}; + // Always start with empty values for security - never prefill sensitive data + if (selectedToolForConfig?.authConfig) { + selectedToolForConfig.authConfig.forEach((field) => { + initial[field.authField] = ''; + }); + } + return initial; + }, [selectedToolForConfig?.authConfig]); + + // Don't render if no MCP servers are available at all + if (!mcpServerStatuses || Object.keys(mcpServerStatuses).length === 0) { return null; } - if (!mcpToolDetails || mcpToolDetails.length === 0) { + // Don't render if no servers are selected and not pinned + if ((!mcpValues || mcpValues.length === 0) && !isPinned) { return null; } @@ -133,9 +261,9 @@ function MCPSelect() { return ( <> { - const schema: Record = {}; - if (selectedToolForConfig?.authConfig) { - selectedToolForConfig.authConfig.forEach((field) => { - schema[field.authField] = { - title: field.label, - description: field.description, - }; - }); - } - return schema; - })()} - initialValues={(() => { - const initial: Record = {}; - // Note: Actual initial values might need to be fetched if they are stored user-specifically - if (selectedToolForConfig?.authConfig) { - selectedToolForConfig.authConfig.forEach((field) => { - initial[field.authField] = ''; // Or fetched value - }); - } - return initial; - })()} - onSave={(authData) => { - if (selectedToolForConfig) { - handleConfigSave(selectedToolForConfig.name, authData); - } - }} - onRevoke={() => { - if (selectedToolForConfig) { - handleConfigRevoke(selectedToolForConfig.name); - } - }} + fieldsSchema={fieldsSchema} + initialValues={initialValues} + onSave={handleSave} + onRevoke={handleRevoke} isSubmitting={updateUserPluginsMutation.isLoading} + isConnected={mcpServerStatuses?.[selectedToolForConfig.name]?.connected || false} + authConfig={selectedToolForConfig.authConfig} /> )} diff --git a/client/src/components/SidePanel/MCP/MCPPanel.tsx b/client/src/components/SidePanel/MCP/MCPPanel.tsx index 0a8ca856f6..9630e77727 100644 --- a/client/src/components/SidePanel/MCP/MCPPanel.tsx +++ b/client/src/components/SidePanel/MCP/MCPPanel.tsx @@ -6,6 +6,8 @@ import { useUpdateUserPluginsMutation, useReinitializeMCPServerMutation, } from 'librechat-data-provider/react-query'; +import { useQueryClient } from '@tanstack/react-query'; +import { QueryKeys } from 'librechat-data-provider'; import type { TUpdateUserPlugins } from 'librechat-data-provider'; import { Button, Input, Label } from '~/components/ui'; import { useGetStartupConfig } from '~/data-provider'; @@ -29,6 +31,7 @@ export default function MCPPanel() { ); const [rotatingServers, setRotatingServers] = useState>(new Set()); const reinitializeMCPMutation = useReinitializeMCPServerMutation(); + const queryClient = useQueryClient(); const mcpServerDefinitions = useMemo(() => { if (!startupConfig?.mcpServers) { @@ -50,11 +53,43 @@ export default function MCPPanel() { }, [startupConfig?.mcpServers]); const updateUserPluginsMutation = useUpdateUserPluginsMutation({ - onSuccess: () => { + onSuccess: async (data, variables) => { showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' }); + + // Refetch tools query to refresh authentication state in the dropdown + queryClient.refetchQueries([QueryKeys.tools]); + + // For 'uninstall' actions (revoke), remove the server from selected values + if (variables.action === 'uninstall') { + const serverName = variables.pluginKey.replace(Constants.mcp_prefix, ''); + // Note: MCPPanel doesn't directly manage selected values, but this ensures + // the tools query is refreshed so MCPSelect will pick up the changes + } + + // Only reinitialize for 'install' actions (save), not 'uninstall' actions (revoke) + if (variables.action === 'install') { + // Extract server name from pluginKey (e.g., "mcp_myServer" -> "myServer") + const serverName = variables.pluginKey.replace(Constants.mcp_prefix, ''); + + // Reinitialize the MCP server to pick up the new authentication values + try { + await reinitializeMCPMutation.mutateAsync(serverName); + console.log( + `[MCP Panel] Successfully reinitialized server ${serverName} after auth update`, + ); + } catch (error) { + console.error( + `[MCP Panel] Error reinitializing server ${serverName} after auth update:`, + error, + ); + // Don't show error toast to user as the auth update was successful + } + } + // For 'uninstall' actions (revoke), the backend already disconnects the connections + // so no additional action is needed here }, - onError: (error) => { - console.error('Error updating MCP custom user variables:', error); + onError: (error: unknown) => { + console.error('Error updating MCP auth:', error); showToast({ message: localize('com_nav_mcp_vars_update_error'), status: 'error', @@ -98,17 +133,79 @@ export default function MCPPanel() { async (serverName: string) => { setRotatingServers((prev) => new Set(prev).add(serverName)); try { - await reinitializeMCPMutation.mutateAsync(serverName); - showToast({ - message: `MCP server '${serverName}' reinitialized successfully`, - status: 'success', - }); + const response = await reinitializeMCPMutation.mutateAsync(serverName); + + // Check if OAuth is required + if (response.oauthRequired) { + if (response.authorizationUrl) { + // Show OAuth URL to user + showToast({ + message: `OAuth required for ${serverName}. Please visit the authorization URL.`, + status: 'info', + }); + + // Open OAuth URL in new window/tab + window.open(response.authorizationUrl, '_blank', 'noopener,noreferrer'); + + // Show a more detailed message with the URL + setTimeout(() => { + showToast({ + message: `OAuth URL opened for ${serverName}. Complete authentication and try reinitializing again.`, + status: 'info', + }); + }, 1000); + } else { + showToast({ + message: `OAuth authentication required for ${serverName}. Please configure OAuth credentials.`, + status: 'warning', + }); + } + } else if (response.oauthCompleted) { + showToast({ + message: + response.message || + `MCP server '${serverName}' reinitialized successfully after OAuth`, + status: 'success', + }); + } else { + showToast({ + message: response.message || `MCP server '${serverName}' reinitialized successfully`, + status: 'success', + }); + } } catch (error) { console.error('Error reinitializing MCP server:', error); - showToast({ - message: 'Failed to reinitialize MCP server', - status: 'error', - }); + + // Check if the error response contains OAuth information + if (error?.response?.data?.oauthRequired) { + const errorData = error.response.data; + if (errorData.authorizationUrl) { + showToast({ + message: `OAuth required for ${serverName}. Please visit the authorization URL.`, + status: 'info', + }); + + // Open OAuth URL in new window/tab + window.open(errorData.authorizationUrl, '_blank', 'noopener,noreferrer'); + + setTimeout(() => { + showToast({ + message: `OAuth URL opened for ${serverName}. Complete authentication and try reinitializing again.`, + status: 'info', + }); + }, 1000); + } else { + showToast({ + message: errorData.message || `OAuth authentication required for ${serverName}`, + status: 'warning', + }); + } + } else { + showToast({ + message: 'Failed to reinitialize MCP server', + status: 'error', + }); + } } finally { setRotatingServers((prev) => { const next = new Set(prev); diff --git a/client/src/components/ui/MCP/CustomUserVarsSection.tsx b/client/src/components/ui/MCP/CustomUserVarsSection.tsx new file mode 100644 index 0000000000..db62486e59 --- /dev/null +++ b/client/src/components/ui/MCP/CustomUserVarsSection.tsx @@ -0,0 +1,161 @@ +import React, { useMemo } from 'react'; +import { useForm, Controller } from 'react-hook-form'; +import { Input, Label, Button } from '~/components/ui'; +import { useLocalize } from '~/hooks'; +import { useMCPAuthValuesQuery } from '~/data-provider/Tools/queries'; + +export interface CustomUserVarConfig { + title: string; + description?: string; +} + +interface CustomUserVarsSectionProps { + serverName: string; + fields: Record; + onSave: (authData: Record) => void; + onRevoke: () => void; + isSubmitting?: boolean; +} + +interface AuthFieldProps { + name: string; + config: CustomUserVarConfig; + hasValue: boolean; + control: any; + errors: any; +} + +function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps) { + const localize = useLocalize(); + + return ( +
+
+ + {hasValue ? ( +
+
+ {localize('com_ui_set')} +
+ ) : ( +
+
+ {localize('com_ui_unset')} +
+ )} +
+ ( + + )} + /> + {config.description && ( +

+ )} + {errors[name] &&

{errors[name]?.message}

} +
+ ); +} + +export default function CustomUserVarsSection({ + serverName, + fields, + onSave, + onRevoke, + isSubmitting = false, +}: CustomUserVarsSectionProps) { + const localize = useLocalize(); + + // Fetch auth value flags for the server + const { data: authValuesData } = useMCPAuthValuesQuery(serverName, { + enabled: !!serverName, + }); + + const { + control, + handleSubmit, + reset, + formState: { errors }, + } = useForm>({ + defaultValues: useMemo(() => { + const initial: Record = {}; + Object.keys(fields).forEach((key) => { + initial[key] = ''; + }); + return initial; + }, [fields]), + }); + + const onFormSubmit = (data: Record) => { + onSave(data); + }; + + const handleRevokeClick = () => { + onRevoke(); + // Reset form after revoke + reset(); + }; + + // Don't render if no fields to configure + if (!fields || Object.keys(fields).length === 0) { + return null; + } + + return ( +
+
+ {Object.entries(fields).map(([key, config]) => { + const hasValue = authValuesData?.authValueFlags?.[key] || false; + + return ( + + ); + })} + + +
+ + +
+
+ ); +} diff --git a/client/src/components/ui/MCP/MCPConfigDialog.tsx b/client/src/components/ui/MCP/MCPConfigDialog.tsx new file mode 100644 index 0000000000..08efb6a85c --- /dev/null +++ b/client/src/components/ui/MCP/MCPConfigDialog.tsx @@ -0,0 +1,108 @@ +import React, { useMemo, useCallback } from 'react'; +import { useLocalize } from '~/hooks'; +import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries'; +import { CustomUserVarsSection, ServerInitializationSection } from './'; +import { useQueryClient } from '@tanstack/react-query'; +import { QueryKeys } from 'librechat-data-provider'; + +import { + OGDialog, + OGDialogContent, + OGDialogHeader, + OGDialogTitle, + OGDialogDescription, +} from '~/components/ui/OriginalDialog'; + +export interface ConfigFieldDetail { + title: string; + description: string; +} + +interface MCPConfigDialogProps { + isOpen: boolean; + onOpenChange: (isOpen: boolean) => void; + fieldsSchema: Record; + initialValues: Record; + onSave: (updatedValues: Record) => void; + isSubmitting?: boolean; + onRevoke?: () => void; + serverName: string; + isConnected?: boolean; + authConfig?: Array<{ + authField: string; + label: string; + description: string; + requiresOAuth?: boolean; + }>; +} + +export default function MCPConfigDialog({ + isOpen, + onOpenChange, + fieldsSchema, + onSave, + isSubmitting = false, + onRevoke, + serverName, +}: MCPConfigDialogProps) { + const localize = useLocalize(); + const queryClient = useQueryClient(); + + // Get connection status to determine OAuth requirements with aggressive refresh + const { data: statusQuery, refetch: refetchConnectionStatus } = useMCPConnectionStatusQuery({ + refetchOnMount: true, + refetchOnWindowFocus: true, + staleTime: 0, + cacheTime: 0, + }); + const mcpServerStatuses = statusQuery?.connectionStatus || {}; + + // Derive real-time connection status and OAuth requirements + const serverStatus = mcpServerStatuses[serverName]; + const isRealTimeConnected = serverStatus?.connected || false; + const requiresOAuth = useMemo(() => { + return serverStatus?.requiresOAuth || false; + }, [serverStatus?.requiresOAuth]); + + const hasFields = Object.keys(fieldsSchema).length > 0; + const dialogTitle = hasFields + ? localize('com_ui_configure_mcp_variables_for', { 0: serverName }) + : `${serverName} MCP Server`; + const dialogDescription = hasFields + ? localize('com_ui_mcp_dialog_desc') + : `Manage connection and settings for the ${serverName} MCP server.`; + + return ( + + + +
+ {dialogTitle} + {isRealTimeConnected && ( +
+
+ {localize('com_ui_active')} +
+ )} +
+ {dialogDescription} + + + {/* Content */} +
+ {/* Custom User Variables Section */} + {})} + isSubmitting={isSubmitting} + /> +
+ + {/* Server Initialization Section */} + + + + ); +} diff --git a/client/src/components/ui/MCP/ServerInitializationSection.tsx b/client/src/components/ui/MCP/ServerInitializationSection.tsx new file mode 100644 index 0000000000..8367c57d0d --- /dev/null +++ b/client/src/components/ui/MCP/ServerInitializationSection.tsx @@ -0,0 +1,228 @@ +import React, { useState, useEffect, useCallback } from 'react'; +import { Button } from '~/components/ui'; +import { useLocalize } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import { + useReinitializeMCPServerMutation, + useMCPOAuthStatusQuery, + useCompleteMCPServerReinitializeMutation, +} from 'librechat-data-provider/react-query'; +import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries'; +import { useQueryClient } from '@tanstack/react-query'; +import { QueryKeys } from 'librechat-data-provider'; + +import { RefreshCw, Link } from 'lucide-react'; + +interface ServerInitializationSectionProps { + serverName: string; + requiresOAuth: boolean; +} + +export default function ServerInitializationSection({ + serverName, + requiresOAuth, +}: ServerInitializationSectionProps) { + const localize = useLocalize(); + const { showToast } = useToastContext(); + const queryClient = useQueryClient(); + + const [oauthUrl, setOauthUrl] = useState(null); + const [oauthFlowId, setOauthFlowId] = useState(null); + + const { data: statusQuery } = useMCPConnectionStatusQuery(); + const mcpServerStatuses = statusQuery?.connectionStatus || {}; + const serverStatus = mcpServerStatuses[serverName]; + const isConnected = serverStatus?.connected || false; + + // Helper function to invalidate caches after successful connection + const handleSuccessfulConnection = useCallback( + async (message: string) => { + showToast({ message, status: 'success' }); + + // Force immediate refetch to update UI + await Promise.all([ + queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]), + queryClient.refetchQueries([QueryKeys.tools]), + ]); + }, + [showToast, queryClient], + ); + + // Main initialization mutation + const reinitializeMutation = useReinitializeMCPServerMutation(); + + // OAuth completion mutation (stores our tools) + const completeReinitializeMutation = useCompleteMCPServerReinitializeMutation(); + + // Override the mutation success handlers + const handleInitializeServer = useCallback(() => { + // Reset OAuth state before starting + setOauthUrl(null); + setOauthFlowId(null); + + // Trigger initialization + reinitializeMutation.mutate(serverName, { + onSuccess: (response) => { + if (response.oauthRequired) { + if (response.authURL && response.flowId) { + setOauthUrl(response.authURL); + setOauthFlowId(response.flowId); + // Keep loading state - OAuth completion will handle success + } else { + showToast({ + message: `OAuth authentication required for ${serverName}. Please configure OAuth credentials.`, + status: 'warning', + }); + } + } else if (response.success) { + handleSuccessfulConnection( + response.message || `MCP server '${serverName}' initialized successfully`, + ); + } + }, + onError: (error: any) => { + console.error('Error initializing MCP server:', error); + showToast({ + message: 'Failed to initialize MCP server', + status: 'error', + }); + }, + }); + }, [reinitializeMutation, serverName, showToast, handleSuccessfulConnection]); + + // OAuth status polling (only when we have a flow ID) + const oauthStatusQuery = useMCPOAuthStatusQuery(oauthFlowId || '', { + enabled: !!oauthFlowId, + refetchInterval: oauthFlowId ? 2000 : false, + retry: false, + onSuccess: (data) => { + if (data?.completed) { + // Immediately reset OAuth state to stop polling + setOauthUrl(null); + setOauthFlowId(null); + + // OAuth completed, trigger completion mutation + completeReinitializeMutation.mutate(serverName, { + onSuccess: (response) => { + handleSuccessfulConnection( + response.message || `MCP server '${serverName}' initialized successfully after OAuth`, + ); + }, + onError: (error: any) => { + // Check if it initialized anyway + if (isConnected) { + handleSuccessfulConnection('MCP server initialized successfully after OAuth'); + return; + } + + console.error('Error completing MCP initialization:', error); + showToast({ + message: 'Failed to complete MCP server initialization after OAuth', + status: 'error', + }); + + // OAuth state already reset above + }, + }); + } else if (data?.failed) { + showToast({ + message: `OAuth authentication failed: ${data.error || 'Unknown error'}`, + status: 'error', + }); + // Reset OAuth state on failure + setOauthUrl(null); + setOauthFlowId(null); + } + }, + }); + + // Reset OAuth state when component unmounts or server changes + useEffect(() => { + return () => { + setOauthUrl(null); + setOauthFlowId(null); + }; + }, [serverName]); + + const isLoading = + reinitializeMutation.isLoading || + completeReinitializeMutation.isLoading || + (!!oauthFlowId && oauthStatusQuery.isFetching); + + // Show subtle reinitialize option if connected + if (isConnected) { + return ( +
+ +
+ ); + } + + return ( +
+
+
+ + {requiresOAuth + ? `${serverName} not authenticated (OAuth Required)` + : `${serverName} not initialized`} + +
+ {/* Only show authenticate button when OAuth URL is not present */} + {!oauthUrl && ( + + )} +
+ + {/* OAuth URL display */} + {oauthUrl && ( +
+
+
+ +
+ + {localize('com_ui_authorization_url')} + +
+
+ +
+

+ {localize('com_ui_oauth_flow_desc')} +

+
+ )} +
+ ); +} diff --git a/client/src/components/ui/MCP/index.ts b/client/src/components/ui/MCP/index.ts new file mode 100644 index 0000000000..43ce6bb7da --- /dev/null +++ b/client/src/components/ui/MCP/index.ts @@ -0,0 +1,5 @@ +export { default as MCPConfigDialog } from './MCPConfigDialog'; +export { default as CustomUserVarsSection } from './CustomUserVarsSection'; +export { default as ServerInitializationSection } from './ServerInitializationSection'; +export type { ConfigFieldDetail } from './MCPConfigDialog'; +export type { CustomUserVarConfig } from './CustomUserVarsSection'; diff --git a/client/src/components/ui/MCPConfigDialog.tsx b/client/src/components/ui/MCPConfigDialog.tsx deleted file mode 100644 index d1a53bd902..0000000000 --- a/client/src/components/ui/MCPConfigDialog.tsx +++ /dev/null @@ -1,122 +0,0 @@ -import React, { useEffect } from 'react'; -import { useForm, Controller } from 'react-hook-form'; -import { Input, Label, OGDialog, Button } from '~/components/ui'; -import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; -import { useLocalize } from '~/hooks'; - -export interface ConfigFieldDetail { - title: string; - description: string; -} - -interface MCPConfigDialogProps { - isOpen: boolean; - onOpenChange: (isOpen: boolean) => void; - fieldsSchema: Record; - initialValues: Record; - onSave: (updatedValues: Record) => void; - isSubmitting?: boolean; - onRevoke?: () => void; - serverName: string; -} - -export default function MCPConfigDialog({ - isOpen, - onOpenChange, - fieldsSchema, - initialValues, - onSave, - isSubmitting = false, - onRevoke, - serverName, -}: MCPConfigDialogProps) { - const localize = useLocalize(); - const { - control, - handleSubmit, - reset, - formState: { errors, _ }, - } = useForm>({ - defaultValues: initialValues, - }); - - useEffect(() => { - if (isOpen) { - reset(initialValues); - } - }, [isOpen, initialValues, reset]); - - const onFormSubmit = (data: Record) => { - onSave(data); - }; - - const handleRevoke = () => { - if (onRevoke) { - onRevoke(); - } - }; - - const dialogTitle = localize('com_ui_configure_mcp_variables_for', { 0: serverName }); - const dialogDescription = localize('com_ui_mcp_dialog_desc'); - - return ( - - - {Object.entries(fieldsSchema).map(([key, details]) => ( -
- - ( - - )} - /> - {details.description && ( -

- )} - {errors[key] &&

{errors[key]?.message}

} -
- ))} - - } - selection={{ - selectHandler: handleSubmit(onFormSubmit), - selectClasses: 'bg-green-500 hover:bg-green-600 text-white', - selectText: isSubmitting ? localize('com_ui_saving') : localize('com_ui_save'), - }} - buttons={ - onRevoke && ( - - ) - } - footerClassName="flex justify-end gap-2 px-6 pb-6 pt-2" - showCancelButton={true} - /> -
- ); -} diff --git a/client/src/data-provider/Tools/queries.ts b/client/src/data-provider/Tools/queries.ts index 1aea8e9f5a..1e77ed72f5 100644 --- a/client/src/data-provider/Tools/queries.ts +++ b/client/src/data-provider/Tools/queries.ts @@ -40,3 +40,76 @@ export const useGetToolCalls = ( }, ); }; + +/** + * Hook for getting MCP connection status + */ +export const useMCPConnectionStatusQuery = ( + config?: UseQueryOptions, +): QueryObserverResult => { + return useQuery( + [QueryKeys.mcpConnectionStatus], + () => dataService.getMCPConnectionStatus(), + { + // refetchOnWindowFocus: false, + // refetchOnReconnect: false, + // refetchOnMount: true, + ...config, + }, + ); +}; + +/** + * Hook for getting MCP auth value flags for a specific server + */ +export const useMCPAuthValuesQuery = ( + serverName: string, + config?: UseQueryOptions< + { success: boolean; serverName: string; authValueFlags: Record }, + unknown, + { success: boolean; serverName: string; authValueFlags: Record } + >, +): QueryObserverResult< + { success: boolean; serverName: string; authValueFlags: Record }, + unknown +> => { + return useQuery< + { success: boolean; serverName: string; authValueFlags: Record }, + unknown, + { success: boolean; serverName: string; authValueFlags: Record } + >([QueryKeys.mcpAuthValues, serverName], () => dataService.getMCPAuthValues(serverName), { + // refetchOnWindowFocus: false, + // refetchOnReconnect: false, + // refetchOnMount: true, + enabled: !!serverName, + ...config, + }); +}; + +/** + * Hook for getting MCP OAuth status for a specific flow + */ +export const useMCPOAuthStatusQuery = ( + flowId: string, + config?: UseQueryOptions< + { status: string; completed: boolean; failed: boolean; error?: string }, + unknown, + { status: string; completed: boolean; failed: boolean; error?: string } + >, +): QueryObserverResult< + { status: string; completed: boolean; failed: boolean; error?: string }, + unknown +> => { + return useQuery< + { status: string; completed: boolean; failed: boolean; error?: string }, + unknown, + { status: string; completed: boolean; failed: boolean; error?: string } + >([QueryKeys.mcpOAuthStatus, flowId], () => dataService.getMCPOAuthStatus(flowId), { + refetchOnWindowFocus: false, + refetchOnReconnect: false, + refetchOnMount: true, + staleTime: 1000, // Consider data stale after 1 second for polling + enabled: !!flowId, + ...config, + }); +}; diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 811fa0df62..fef7a767c0 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -524,7 +524,9 @@ "com_ui_2fa_verified": "Successfully verified Two-Factor Authentication", "com_ui_accept": "I accept", "com_ui_action_button": "Action Button", + "com_ui_active": "Active", "com_ui_add": "Add", + "com_ui_authenticate": "Authenticate", "com_ui_add_mcp": "Add MCP", "com_ui_add_mcp_server": "Add MCP Server", "com_ui_add_model_preset": "Add a model or preset for an additional response", @@ -844,6 +846,7 @@ "com_ui_max_tags": "Maximum number allowed is {{0}}, using latest values.", "com_ui_mcp_dialog_desc": "Please enter the necessary information below.", "com_ui_mcp_enter_var": "Enter value for {{0}}", + "com_ui_mcp_initialize": "Initialize", "com_ui_mcp_server_not_found": "Server not found.", "com_ui_mcp_servers": "MCP Servers", "com_ui_mcp_url": "MCP Server URL", @@ -959,6 +962,13 @@ "com_ui_save_submit": "Save & Submit", "com_ui_saved": "Saved!", "com_ui_saving": "Saving...", + "com_ui_set": "Set", + "com_ui_unset": "Unset", + "com_ui_configuration": "Configuration", + "com_ui_mcp_auth_desc": "Configure authentication credentials for this MCP server.", + "com_ui_authorization_url": "Authorization URL", + "com_ui_continue_oauth": "Continue OAuth Flow", + "com_ui_oauth_flow_desc": "Click the button above to continue the OAuth flow in a new tab.", "com_ui_schema": "Schema", "com_ui_scope": "Scope", "com_ui_search": "Search", diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 99e59b5467..295c9fe989 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -69,6 +69,7 @@ export class MCPConnection extends EventEmitter { private lastPingTime: number; private oauthTokens?: MCPOAuthTokens | null; private oauthRequired = false; + private oauthTimeoutId: NodeJS.Timeout | null = null; iconPath?: string; timeout?: number; url?: string; @@ -421,6 +422,7 @@ export class MCPConnection extends EventEmitter { const cleanup = () => { if (timeoutId) { clearTimeout(timeoutId); + this.oauthTimeoutId = null; } if (oauthHandledListener) { this.off('oauthHandled', oauthHandledListener); @@ -448,11 +450,26 @@ export class MCPConnection extends EventEmitter { reject(new Error(`OAuth handling timeout after ${oauthTimeout}ms`)); }, oauthTimeout); + // Store the timeout ID for potential cancellation + this.oauthTimeoutId = timeoutId; + // Listen for both success and failure events this.once('oauthHandled', oauthHandledListener); this.once('oauthFailed', oauthFailedListener); }); + // Check if there are any listeners for oauthRequired event + const hasOAuthListeners = this.listenerCount('oauthRequired') > 0; + + if (!hasOAuthListeners) { + // No OAuth handler available (like during startup), immediately fail + logger.warn( + `${this.getLogPrefix()} OAuth required but no handler available, failing immediately`, + ); + this.oauthRequired = false; + throw error; + } + // Emit the event this.emit('oauthRequired', { serverName: this.serverName, @@ -517,10 +534,11 @@ export class MCPConnection extends EventEmitter { try { await this.disconnect(); await this.connectClient(); - if (!(await this.isConnected())) { + if (!(await this.isConnected()) && !(this.isInitializing && this.oauthTokens)) { throw new Error('Connection not established'); } } catch (error) { + console.log('connection detail', this.oauthRequired); logger.error(`${this.getLogPrefix()} Connection failed:`, error); throw error; } @@ -545,6 +563,37 @@ export class MCPConnection extends EventEmitter { public async disconnect(): Promise { try { + // Cancel any pending OAuth timeout + if (this.oauthTimeoutId) { + clearTimeout(this.oauthTimeoutId); + this.oauthTimeoutId = null; + } + + if (this.transport) { + await this.client.close(); + this.transport = null; + } + if (this.connectionState === 'disconnected') { + return; + } + this.connectionState = 'disconnected'; + this.emit('connectionChange', 'disconnected'); + } finally { + this.connectPromise = null; + } + } + + public async disconnectAndStopReconnecting(): Promise { + try { + // Stop any reconnection attempts + this.shouldStopReconnecting = true; + + // Cancel any pending OAuth timeout + if (this.oauthTimeoutId) { + clearTimeout(this.oauthTimeoutId); + this.oauthTimeoutId = null; + } + if (this.transport) { await this.client.close(); this.transport = null; @@ -650,6 +699,16 @@ export class MCPConnection extends EventEmitter { this.oauthTokens = tokens; } + /** Check if OAuth is required for this connection */ + public getOAuthRequired(): boolean { + return this.oauthRequired; + } + + /** Get the current connection state */ + public getConnectionState(): t.ConnectionState { + return this.connectionState; + } + private isOAuthError(error: unknown): boolean { if (!error || typeof error !== 'object') { return false; diff --git a/packages/api/src/mcp/manager.ts b/packages/api/src/mcp/manager.ts index a3bee8a18c..f9ef5135b2 100644 --- a/packages/api/src/mcp/manager.ts +++ b/packages/api/src/mcp/manager.ts @@ -15,8 +15,9 @@ import { MCPTokenStorage } from './oauth/tokens'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; +import { EventEmitter } from 'events'; -export class MCPManager { +export class MCPManager extends EventEmitter { private static instance: MCPManager | null = null; /** App-level connections initialized at startup */ private connections: Map = new Map(); @@ -29,6 +30,10 @@ export class MCPManager { /** Store MCP server instructions */ private serverInstructions: Map = new Map(); + constructor() { + super(); + } + public static getInstance(): MCPManager { if (!MCPManager.instance) { MCPManager.instance = new MCPManager(); @@ -47,7 +52,7 @@ export class MCPManager { mcpServers: t.MCPServers; flowManager: FlowStateManager; tokenMethods?: TokenMethods; - }): Promise { + }): Promise> { this.mcpConfigs = mcpServers; if (!flowManager) { @@ -59,6 +64,7 @@ export class MCPManager { } const entries = Object.entries(mcpServers); const initializedServers = new Set(); + const oauthSkippedServers = new Set(); const connectionResults = await Promise.allSettled( entries.map(async ([serverName, config], i) => { try { @@ -70,19 +76,46 @@ export class MCPManager { }); initializedServers.add(i); } catch (error) { - logger.error(`[MCP][${serverName}] Initialization failed`, error); + // Check if this is an OAuth skipped error + if ( + error instanceof Error && + (error as Error & { isOAuthSkipped?: boolean }).isOAuthSkipped + ) { + oauthSkippedServers.add(i); + } else { + logger.error(`[MCP][${serverName}] Initialization failed`, error); + // Debug: Log the actual error for filesystem server + if (serverName === 'filesystem') { + logger.error(`[MCP][${serverName}] Error details:`, { + message: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + isOAuthError: this.isOAuthError(error), + }); + } + } } }), ); const failedConnections = connectionResults.filter( - (result): result is PromiseRejectedResult => result.status === 'rejected', + (result): result is PromiseRejectedResult => + result.status === 'rejected' && + !( + result.reason instanceof Error && + (result.reason as Error & { isOAuthSkipped?: boolean }).isOAuthSkipped + ), ); logger.info( `[MCP] Initialized ${initializedServers.size}/${entries.length} app-level server(s)`, ); + if (oauthSkippedServers.size > 0) { + logger.info( + `[MCP] ${oauthSkippedServers.size}/${entries.length} app-level server(s) skipped for OAuth`, + ); + } + if (failedConnections.length > 0) { logger.warn( `[MCP] ${failedConnections.length}/${entries.length} app-level server(s) failed to initialize`, @@ -92,6 +125,8 @@ export class MCPManager { entries.forEach(([serverName], index) => { if (initializedServers.has(index)) { logger.info(`[MCP][${serverName}] ✓ Initialized`); + } else if (oauthSkippedServers.has(index)) { + logger.info(`[MCP][${serverName}] OAuth Required`); } else { logger.info(`[MCP][${serverName}] ✗ Failed`); } @@ -99,9 +134,16 @@ export class MCPManager { if (initializedServers.size === entries.length) { logger.info('[MCP] All app-level servers initialized successfully'); - } else if (initializedServers.size === 0) { + } else if (initializedServers.size === 0 && oauthSkippedServers.size === 0) { logger.warn('[MCP] No app-level servers initialized'); } + + // Return OAuth requirement map + const oauthRequirementMap: Record = {}; + entries.forEach(([serverName], index) => { + oauthRequirementMap[serverName] = oauthSkippedServers.has(index); + }); + return oauthRequirementMap; } /** Initializes a single MCP server connection (app-level) */ @@ -166,40 +208,18 @@ export class MCPManager { logger.info(`[MCP][${serverName}] Loaded OAuth tokens`); } const connection = new MCPConnection(serverName, processedConfig, undefined, tokens); - logger.info(`[MCP][${serverName}] Setting up OAuth event listener`); - connection.on('oauthRequired', async (data) => { - logger.debug(`[MCP][${serverName}] oauthRequired event received`); - const result = await this.handleOAuthRequired({ - ...data, - flowManager, - }); - if (result?.tokens && tokenMethods?.createToken) { - try { - connection.setOAuthTokens(result.tokens); - await MCPTokenStorage.storeTokens({ - userId: CONSTANTS.SYSTEM_USER_ID, - serverName, - tokens: result.tokens, - createToken: tokenMethods.createToken, - updateToken: tokenMethods.updateToken, - findToken: tokenMethods.findToken, - clientInfo: result.clientInfo, - }); - logger.info(`[MCP][${serverName}] OAuth tokens saved to storage`); - } catch (error) { - logger.error(`[MCP][${serverName}] Failed to save OAuth tokens to storage`, error); - } - } - // Only emit oauthHandled if we actually got tokens (OAuth succeeded) - if (result?.tokens) { - connection.emit('oauthHandled'); - } else { - // OAuth failed, emit oauthFailed to properly reject the promise - logger.warn(`[MCP][${serverName}] OAuth failed, emitting oauthFailed event`); - connection.emit('oauthFailed', new Error('OAuth authentication failed')); - } + // Track OAuth skipped state explicitly + let oauthSkipped = false; + + connection.on('oauthRequired', async () => { + logger.debug(`[MCP][${serverName}] oauthRequired event received`); + oauthSkipped = true; + // Emit event to signal that initialization should be skipped + connection.emit('oauthSkipped'); + return; }); + try { const connectTimeout = processedConfig.initTimeout ?? 30000; const connectionTimeout = new Promise((_, reject) => @@ -208,13 +228,35 @@ export class MCPManager { connectTimeout, ), ); + + // Listen for oauthSkipped event to stop initialization + const oauthSkippedPromise = new Promise((resolve) => { + connection.once('oauthSkipped', () => { + logger.debug(`[MCP][${serverName}] OAuth skipped, stopping initialization`); + resolve(); + }); + }); + const connectionAttempt = this.initializeServer({ connection, logPrefix: `[MCP][${serverName}]`, flowManager, handleOAuth: false, }); - await Promise.race([connectionAttempt, connectionTimeout]); + + // Race between connection attempt, timeout, and oauthSkipped + await Promise.race([connectionAttempt, connectionTimeout, oauthSkippedPromise]); + + // Check if OAuth was explicitly skipped + if (oauthSkipped) { + // Throw a special error to signal OAuth was skipped + const oauthSkippedError = new Error(`OAuth required for ${serverName}`) as Error & { + isOAuthSkipped: boolean; + }; + oauthSkippedError.isOAuthSkipped = true; + throw oauthSkippedError; + } + if (await connection.isConnected()) { this.connections.set(serverName, connection); @@ -269,6 +311,17 @@ export class MCPManager { logger.info(`[MCP][${serverName}] ✗ Failed`); } } catch (error) { + // Debug: Log the actual error for filesystem server + if (serverName === 'filesystem') { + logger.error(`[MCP][${serverName}] Error details:`, { + message: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + isOAuthError: this.isOAuthError(error), + errorType: error?.constructor?.name, + errorKeys: error && typeof error === 'object' ? Object.keys(error) : [], + oauthSkipped, + }); + } logger.error(`[MCP][${serverName}] Initialization failed`, error); throw error; } @@ -340,6 +393,21 @@ export class MCPManager { return false; } + // Debug: Log error details for filesystem server + if (error && typeof error === 'object' && 'message' in error) { + const errorMessage = (error as { message?: string }).message; + if (errorMessage && errorMessage.includes('filesystem')) { + logger.debug('[MCP] isOAuthError check for filesystem:', { + message: errorMessage, + hasCode: 'code' in error, + code: (error as { code?: number }).code, + includes401: errorMessage.includes('401'), + includes403: errorMessage.includes('403'), + includesNon200: errorMessage.includes('Non-200 status code (401)'), + }); + } + } + // Check for SSE error with 401 status if ('message' in error && typeof error.message === 'string') { return error.message.includes('401') || error.message.includes('Non-200 status code (401)'); @@ -578,6 +646,10 @@ export class MCPManager { this.userConnections.get(userId)?.set(serverName, connection); logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`); + + // Emit event that connection is established for waiting endpoints + this.emit('connectionEstablished', { userId, serverName, connection }); + // Update timestamp on creation this.updateUserLastActivity(userId); return connection; @@ -618,7 +690,7 @@ export class MCPManager { const connection = userMap?.get(serverName); if (connection) { logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`); - await connection.disconnect(); + await connection.disconnectAndStopReconnecting(); this.removeUserConnection(userId, serverName); } } @@ -649,6 +721,7 @@ export class MCPManager { /** Returns the app-level connection (used for mapping tools, etc.) */ public getConnection(serverName: string): MCPConnection | undefined { + console.log(this.connections); return this.connections.get(serverName); } @@ -657,6 +730,12 @@ export class MCPManager { return this.connections; } + /** Returns the user-level connection if it exists (does not create one) */ + public getUserConnectionIfExists(userId: string, serverName: string): MCPConnection | undefined { + const userMap = this.userConnections.get(userId); + return userMap?.get(serverName); + } + /** Attempts to reconnect an app-level connection if it's disconnected */ private async isConnectionActive({ serverName, @@ -928,24 +1007,62 @@ export class MCPManager { /** Disconnects all app-level and user-level connections */ public async disconnectAll(): Promise { - logger.info('[MCP] Disconnecting all app-level and user-level connections...'); - - const userDisconnectPromises = Array.from(this.userConnections.keys()).map((userId) => - this.disconnectUserConnections(userId), - ); - await Promise.allSettled(userDisconnectPromises); - this.userLastActivity.clear(); + logger.info('[MCP] Disconnecting all connections...'); // Disconnect all app-level connections - const appDisconnectPromises = Array.from(this.connections.values()).map((connection) => - connection.disconnect().catch((error) => { - logger.error(`[MCP][${connection.serverName}] Error during disconnectAll:`, error); - }), - ); - await Promise.allSettled(appDisconnectPromises); + const appConnections = Array.from(this.connections.values()); + await Promise.allSettled(appConnections.map((connection) => connection.disconnect())); this.connections.clear(); - logger.info('[MCP] All connections processed for disconnection.'); + // Disconnect all user-level connections + const userConnections = Array.from(this.userConnections.values()).flatMap((userMap) => + Array.from(userMap.values()), + ); + await Promise.allSettled(userConnections.map((connection) => connection.disconnect())); + this.userConnections.clear(); + + // Clear activity timestamps + this.userLastActivity.clear(); + + logger.info('[MCP] All connections disconnected'); + } + + /** + * Get connection status for a specific user and server + */ + public async getUserConnectionStatus( + userId: string, + serverName: string, + ): Promise<{ + connected: boolean; + hasConnection: boolean; + }> { + const userConnections = this.userConnections.get(userId); + const connection = userConnections?.get(serverName); + + if (!connection) { + return { + connected: false, + hasConnection: false, + }; + } + + try { + const isConnected = await connection.isConnected(); + return { + connected: isConnected, + hasConnection: true, + }; + } catch (error) { + logger.error( + `[MCP] Error checking connection status for user ${userId}, server ${serverName}:`, + error, + ); + return { + connected: false, + hasConnection: true, + }; + } } /** Destroys the singleton instance and disconnects all connections */ diff --git a/packages/data-provider/src/api-endpoints.ts b/packages/data-provider/src/api-endpoints.ts index 3930132ce0..bd3435afa0 100644 --- a/packages/data-provider/src/api-endpoints.ts +++ b/packages/data-provider/src/api-endpoints.ts @@ -134,6 +134,15 @@ export const plugins = () => '/api/plugins'; export const mcpReinitialize = (serverName: string) => `/api/mcp/${serverName}/reinitialize`; +export const mcpReinitializeComplete = (serverName: string) => + `/api/mcp/${serverName}/reinitialize/complete`; + +export const mcpConnectionStatus = () => '/api/mcp/connection/status'; + +export const mcpAuthValues = (serverName: string) => `/api/mcp/${serverName}/auth-values`; + +export const mcpOAuthStatus = (flowId: string) => `/api/mcp/oauth/status/${flowId}`; + export const config = () => '/api/config'; export const prompts = () => '/api/prompts'; diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index a6ac8b3ecf..a301e79d1e 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -606,6 +606,7 @@ export type TStartupConfig = { description: string; } >; + requiresOAuth?: boolean; } >; mcpPlaceholder?: string; diff --git a/packages/data-provider/src/data-service.ts b/packages/data-provider/src/data-service.ts index d9e8ff7d12..1bcae5520c 100644 --- a/packages/data-provider/src/data-service.ts +++ b/packages/data-provider/src/data-service.ts @@ -145,6 +145,26 @@ export const reinitializeMCPServer = (serverName: string) => { return request.post(endpoints.mcpReinitialize(serverName)); }; +export const completeMCPServerReinitialize = (serverName: string) => { + return request.post(endpoints.mcpReinitializeComplete(serverName)); +}; + +export const getMCPConnectionStatus = (): Promise => { + return request.get(endpoints.mcpConnectionStatus()); +}; + +export const getMCPAuthValues = ( + serverName: string, +): Promise<{ success: boolean; serverName: string; authValueFlags: Record }> => { + return request.get(endpoints.mcpAuthValues(serverName)); +}; + +export const getMCPOAuthStatus = ( + flowId: string, +): Promise<{ status: string; completed: boolean; failed: boolean; error?: string }> => { + return request.get(endpoints.mcpOAuthStatus(flowId)); +}; + /* Config */ export const getStartupConfig = (): Promise< diff --git a/packages/data-provider/src/keys.ts b/packages/data-provider/src/keys.ts index ec94c0f0ff..f75583ec7d 100644 --- a/packages/data-provider/src/keys.ts +++ b/packages/data-provider/src/keys.ts @@ -46,6 +46,9 @@ export enum QueryKeys { health = 'health', userTerms = 'userTerms', banner = 'banner', + mcpConnectionStatus = 'mcpConnectionStatus', + mcpAuthValues = 'mcpAuthValues', + mcpOAuthStatus = 'mcpOAuthStatus', /* Memories */ memories = 'memories', } diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index 696777131d..64044b2703 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -8,6 +8,12 @@ const BaseOptionsSchema = z.object({ initTimeout: z.number().optional(), /** Controls visibility in chat dropdown menu (MCPSelect) */ chatMenu: z.boolean().optional(), + /** + * Controls whether the MCP server should be initialized on startup + * - true: Initialize on startup (default) + * - false: Skip initialization on startup (can be initialized later) + */ + startup: z.boolean().optional(), /** * Controls server instruction behavior: * - undefined/not set: No instructions included (default) diff --git a/packages/data-provider/src/react-query/react-query-service.ts b/packages/data-provider/src/react-query/react-query-service.ts index ca7d4374fe..df2c725bed 100644 --- a/packages/data-provider/src/react-query/react-query-service.ts +++ b/packages/data-provider/src/react-query/react-query-service.ts @@ -311,13 +311,22 @@ export const useUpdateUserPluginsMutation = ( ...options, onSuccess: (...args) => { queryClient.invalidateQueries([QueryKeys.user]); + queryClient.refetchQueries([QueryKeys.tools]); onSuccess?.(...args); }, }); }; export const useReinitializeMCPServerMutation = (): UseMutationResult< - { success: boolean; message: string; serverName: string }, + { + success: boolean; + message: string; + serverName: string; + oauthRequired?: boolean; + oauthCompleted?: boolean; + authURL?: string; + flowId?: string; + }, unknown, string, unknown @@ -330,6 +339,54 @@ export const useReinitializeMCPServerMutation = (): UseMutationResult< }); }; +export const useCompleteMCPServerReinitializeMutation = (): UseMutationResult< + { + success: boolean; + message: string; + serverName: string; + }, + unknown, + string, + unknown +> => { + const queryClient = useQueryClient(); + return useMutation( + (serverName: string) => dataService.completeMCPServerReinitialize(serverName), + { + onSuccess: () => { + queryClient.refetchQueries([QueryKeys.tools]); + queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]); + }, + }, + ); +}; + +export const useMCPOAuthStatusQuery = ( + flowId: string, + config?: UseQueryOptions< + { status: string; completed: boolean; failed: boolean; error?: string }, + unknown, + { status: string; completed: boolean; failed: boolean; error?: string } + >, +): QueryObserverResult< + { status: string; completed: boolean; failed: boolean; error?: string }, + unknown +> => { + return useQuery< + { status: string; completed: boolean; failed: boolean; error?: string }, + unknown, + { status: string; completed: boolean; failed: boolean; error?: string } + >([QueryKeys.mcpOAuthStatus, flowId], () => dataService.getMCPOAuthStatus(flowId), { + refetchOnWindowFocus: false, + refetchOnReconnect: false, + refetchOnMount: true, + staleTime: 1000, // Consider data stale after 1 second for polling + enabled: !!flowId, + refetchInterval: flowId ? 2000 : false, // Poll every 2 seconds when OAuth is active + ...config, + }); +}; + export const useGetCustomConfigSpeechQuery = ( config?: UseQueryOptions, ): QueryObserverResult => { diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index a1494d68f5..c845227904 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -417,6 +417,7 @@ export const tPluginAuthConfigSchema = z.object({ authField: z.string(), label: z.string(), description: z.string(), + requiresOAuth: z.boolean().optional(), }); export type TPluginAuthConfig = z.infer; diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts index b1fbc42653..b49c8860da 100644 --- a/packages/data-provider/src/types.ts +++ b/packages/data-provider/src/types.ts @@ -632,3 +632,14 @@ export type TBalanceResponse = { lastRefill?: Date; refillAmount?: number; }; + +export type TMCPConnectionStatus = { + connected: boolean; + hasAuthConfig: boolean; + hasConnection: boolean; + isAppLevel: boolean; + isUserLevel: boolean; + requiresOAuth: boolean; +}; + +export type TMCPConnectionStatusResponse = Record; diff --git a/packages/data-schemas/src/methods/pluginAuth.ts b/packages/data-schemas/src/methods/pluginAuth.ts index 5355fec50c..c1af1f3fea 100644 --- a/packages/data-schemas/src/methods/pluginAuth.ts +++ b/packages/data-schemas/src/methods/pluginAuth.ts @@ -61,15 +61,28 @@ export function createPluginAuthMethods(mongoose: typeof import('mongoose')) { }: UpdatePluginAuthParams): Promise { try { const PluginAuth: Model = mongoose.models.PluginAuth; - const existingAuth = await PluginAuth.findOne({ userId, pluginKey, authField }).lean(); + + // First try to find existing record by { userId, authField } (for backward compatibility) + let existingAuth = await PluginAuth.findOne({ userId, authField }).lean(); + + // If not found and pluginKey is provided, try to find by { userId, pluginKey, authField } + if (!existingAuth && pluginKey) { + existingAuth = await PluginAuth.findOne({ userId, pluginKey, authField }).lean(); + } if (existingAuth) { + // Update existing record, preserving the original structure + const updateQuery = existingAuth.pluginKey + ? { userId, pluginKey: existingAuth.pluginKey, authField } + : { userId, authField }; + return await PluginAuth.findOneAndUpdate( - { userId, pluginKey, authField }, + updateQuery, { $set: { value } }, { new: true, upsert: true }, ).lean(); } else { + // Create new record const newPluginAuth = await new PluginAuth({ userId, authField, @@ -109,7 +122,16 @@ export function createPluginAuthMethods(mongoose: typeof import('mongoose')) { throw new Error('authField is required when all is false'); } - return await PluginAuth.deleteOne({ userId, authField }); + // Build the filter based on available parameters + const filter: { userId: string; authField: string; pluginKey?: string } = { + userId, + authField, + }; + if (pluginKey) { + filter.pluginKey = pluginKey; + } + + return await PluginAuth.deleteOne(filter); } catch (error) { throw new Error( `Failed to delete plugin auth: ${error instanceof Error ? error.message : 'Unknown error'}`,