diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index c49ba4cc31..35bba77ae6 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -4,6 +4,7 @@ const { MCPOAuthHandler } = require('@librechat/api'); const { CacheKeys, Constants } = require('librechat-data-provider'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config'); +const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { getMCPManager, getFlowStateManager } = require('~/config'); const { requireJwtAuth } = require('~/server/middleware'); @@ -468,7 +469,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { /** * Get connection status for all MCP servers - * This endpoint returns the actual connection status from MCPManager without disconnecting idle connections + * This endpoint returns all app level and user-scoped connection statuses from MCPManager without disconnecting idle connections */ router.get('/connection/status', requireJwtAuth, async (req, res) => { try { @@ -478,84 +479,19 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { return res.status(401).json({ error: 'User not authenticated' }); } - const mcpManager = getMCPManager(user.id); + const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( + user.id, + ); const connectionStatus = {}; - const printConfig = false; - const config = await loadCustomConfig(printConfig); - const mcpConfig = config?.mcpServers; - - const appConnections = mcpManager.getAllConnections() || new Map(); - const userConnections = mcpManager.getUserConnections(user.id) || new Map(); - const oauthServers = mcpManager.getOAuthServers() || new Set(); - - if (!mcpConfig) { - return res.status(404).json({ error: 'MCP config not found' }); - } - - // Get flow manager to check for active/timed-out OAuth flows - const flowsCache = getLogStores(CacheKeys.FLOWS); - const flowManager = getFlowStateManager(flowsCache); - for (const [serverName] of Object.entries(mcpConfig)) { - const getConnectionState = (serverName) => - appConnections.get(serverName)?.connectionState ?? - userConnections.get(serverName)?.connectionState ?? - 'disconnected'; - - const baseConnectionState = getConnectionState(serverName); - - let hasActiveOAuthFlow = false; - let hasFailedOAuthFlow = false; - - if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) { - try { - // Check for user-specific OAuth flows - const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); - const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); - if (flowState) { - // Check if flow failed or timed out - const flowAge = Date.now() - flowState.createdAt; - const flowTTL = flowState.ttl || 180000; // Default 3 minutes - - if (flowState.status === 'FAILED' || flowAge > flowTTL) { - hasFailedOAuthFlow = true; - logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, { - flowId, - status: flowState.status, - flowAge, - flowTTL, - timedOut: flowAge > flowTTL, - }); - } else if (flowState.status === 'PENDING') { - hasActiveOAuthFlow = true; - logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, { - flowId, - flowAge, - flowTTL, - }); - } - } - } catch (error) { - logger.error( - `[MCP Connection Status] Error checking OAuth flows for ${serverName}:`, - error, - ); - } - } - - // Determine the final connection state - let finalConnectionState = baseConnectionState; - if (hasFailedOAuthFlow) { - finalConnectionState = 'error'; // Report as error if OAuth failed - } else if (hasActiveOAuthFlow && baseConnectionState === 'disconnected') { - finalConnectionState = 'connecting'; // Still waiting for OAuth - } - - connectionStatus[serverName] = { - requiresOAuth: oauthServers.has(serverName), - connectionState: finalConnectionState, - }; + connectionStatus[serverName] = await getServerConnectionStatus( + user.id, + serverName, + appConnections, + userConnections, + oauthServers, + ); } res.json({ @@ -563,11 +499,67 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { connectionStatus, }); } catch (error) { + if (error.message === 'MCP config not found') { + return res.status(404).json({ error: error.message }); + } logger.error('[MCP Connection Status] Failed to get connection status', error); res.status(500).json({ error: 'Failed to get connection status' }); } }); +/** + * Get connection status for a single MCP server + * This endpoint returns the connection status for a specific server for a given user + */ +router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => { + try { + const user = req.user; + const { serverName } = req.params; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!serverName) { + return res.status(400).json({ error: 'Server name is required' }); + } + + const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( + user.id, + ); + + if (!mcpConfig[serverName]) { + return res + .status(404) + .json({ error: `MCP server '${serverName}' not found in configuration` }); + } + + const serverStatus = await getServerConnectionStatus( + user.id, + serverName, + appConnections, + userConnections, + oauthServers, + ); + + res.json({ + success: true, + serverName, + connectionStatus: serverStatus.connectionState, + requiresOAuth: serverStatus.requiresOAuth, + }); + } catch (error) { + if (error.message === 'MCP config not found') { + return res.status(404).json({ error: error.message }); + } + logger.error( + `[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`, + error, + ); + res.status(500).json({ error: 'Failed to get connection status' }); + } +}); + /** * Check which authentication values exist for a specific MCP server * This endpoint returns only boolean flags indicating if values are set, not the actual values diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 9970981828..f8ec2d04d5 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -12,7 +12,7 @@ const { } = require('@librechat/api'); const { findToken, createToken, updateToken } = require('~/models'); const { getMCPManager, getFlowStateManager } = require('~/config'); -const { getCachedTools } = require('./Config'); +const { getCachedTools, loadCustomConfig } = require('./Config'); const { getLogStores } = require('~/cache'); /** @@ -239,6 +239,123 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) { return toolInstance; } +/** + * Get MCP setup data including config, connections, and OAuth servers + * @param {string} userId - The user ID + * @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers + */ +async function getMCPSetupData(userId) { + const printConfig = false; + const config = await loadCustomConfig(printConfig); + const mcpConfig = config?.mcpServers; + + if (!mcpConfig) { + throw new Error('MCP config not found'); + } + + const mcpManager = getMCPManager(userId); + const appConnections = mcpManager.getAllConnections() || new Map(); + const userConnections = mcpManager.getUserConnections(userId) || new Map(); + const oauthServers = mcpManager.getOAuthServers() || new Set(); + + return { + mcpConfig, + appConnections, + userConnections, + oauthServers, + }; +} + +/** + * Check OAuth flow status for a user and server + * @param {string} userId - The user ID + * @param {string} serverName - The server name + * @returns {Object} Object containing hasActiveFlow and hasFailedFlow flags + */ +async function checkOAuthFlowStatus(userId, serverName) { + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const flowId = MCPOAuthHandler.generateFlowId(userId, serverName); + + try { + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + return { hasActiveFlow: false, hasFailedFlow: false }; + } + + const flowAge = Date.now() - flowState.createdAt; + const flowTTL = flowState.ttl || 180000; // Default 3 minutes + + if (flowState.status === 'FAILED' || flowAge > flowTTL) { + logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, { + flowId, + status: flowState.status, + flowAge, + flowTTL, + timedOut: flowAge > flowTTL, + }); + return { hasActiveFlow: false, hasFailedFlow: true }; + } + + if (flowState.status === 'PENDING') { + logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, { + flowId, + flowAge, + flowTTL, + }); + return { hasActiveFlow: true, hasFailedFlow: false }; + } + + return { hasActiveFlow: false, hasFailedFlow: false }; + } catch (error) { + logger.error(`[MCP Connection Status] Error checking OAuth flows for ${serverName}:`, error); + return { hasActiveFlow: false, hasFailedFlow: false }; + } +} + +/** + * Get connection status for a specific MCP server + * @param {string} userId - The user ID + * @param {string} serverName - The server name + * @param {Map} appConnections - App-level connections + * @param {Map} userConnections - User-level connections + * @param {Set} oauthServers - Set of OAuth servers + * @returns {Object} Object containing requiresOAuth and connectionState + */ +async function getServerConnectionStatus( + userId, + serverName, + appConnections, + userConnections, + oauthServers, +) { + const getConnectionState = () => + appConnections.get(serverName)?.connectionState ?? + userConnections.get(serverName)?.connectionState ?? + 'disconnected'; + + const baseConnectionState = getConnectionState(); + let finalConnectionState = baseConnectionState; + + if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) { + const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName); + + if (hasFailedFlow) { + finalConnectionState = 'error'; + } else if (hasActiveFlow) { + finalConnectionState = 'connecting'; + } + } + + return { + requiresOAuth: oauthServers.has(serverName), + connectionState: finalConnectionState, + }; +} + module.exports = { createMCPTool, + getMCPSetupData, + checkOAuthFlowStatus, + getServerConnectionStatus, }; diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js new file mode 100644 index 0000000000..8c81abd685 --- /dev/null +++ b/api/server/services/MCP.spec.js @@ -0,0 +1,510 @@ +const { logger } = require('@librechat/data-schemas'); +const { MCPOAuthHandler } = require('@librechat/api'); +const { CacheKeys } = require('librechat-data-provider'); +const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP'); + +// Mock all dependencies +jest.mock('@librechat/data-schemas', () => ({ + logger: { + debug: jest.fn(), + error: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + MCPOAuthHandler: { + generateFlowId: jest.fn(), + }, +})); + +jest.mock('librechat-data-provider', () => ({ + CacheKeys: { + FLOWS: 'flows', + }, +})); + +jest.mock('./Config', () => ({ + loadCustomConfig: jest.fn(), +})); + +jest.mock('~/config', () => ({ + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(), +})); + +jest.mock('~/models', () => ({ + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), +})); + +describe('tests for the new helper functions used by the MCP connection status endpoints', () => { + let mockLoadCustomConfig; + let mockGetMCPManager; + let mockGetFlowStateManager; + let mockGetLogStores; + + beforeEach(() => { + jest.clearAllMocks(); + + mockLoadCustomConfig = require('./Config').loadCustomConfig; + mockGetMCPManager = require('~/config').getMCPManager; + mockGetFlowStateManager = require('~/config').getFlowStateManager; + mockGetLogStores = require('~/cache').getLogStores; + }); + + describe('getMCPSetupData', () => { + const mockUserId = 'user-123'; + const mockConfig = { + mcpServers: { + server1: { type: 'stdio' }, + server2: { type: 'http' }, + }, + }; + + beforeEach(() => { + mockGetMCPManager.mockReturnValue({ + getAllConnections: jest.fn(() => new Map()), + getUserConnections: jest.fn(() => new Map()), + getOAuthServers: jest.fn(() => new Set()), + }); + }); + + it('should successfully return MCP setup data', async () => { + mockLoadCustomConfig.mockResolvedValue(mockConfig); + + const mockAppConnections = new Map([['server1', { status: 'connected' }]]); + const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]); + const mockOAuthServers = new Set(['server2']); + + const mockMCPManager = { + getAllConnections: jest.fn(() => mockAppConnections), + getUserConnections: jest.fn(() => mockUserConnections), + getOAuthServers: jest.fn(() => mockOAuthServers), + }; + mockGetMCPManager.mockReturnValue(mockMCPManager); + + const result = await getMCPSetupData(mockUserId); + + expect(mockLoadCustomConfig).toHaveBeenCalledWith(false); + expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); + expect(mockMCPManager.getAllConnections).toHaveBeenCalled(); + expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); + expect(mockMCPManager.getOAuthServers).toHaveBeenCalled(); + + expect(result).toEqual({ + mcpConfig: mockConfig.mcpServers, + appConnections: mockAppConnections, + userConnections: mockUserConnections, + oauthServers: mockOAuthServers, + }); + }); + + it('should throw error when MCP config not found', async () => { + mockLoadCustomConfig.mockResolvedValue({}); + await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found'); + }); + + it('should handle null values from MCP manager gracefully', async () => { + mockLoadCustomConfig.mockResolvedValue(mockConfig); + + const mockMCPManager = { + getAllConnections: jest.fn(() => null), + getUserConnections: jest.fn(() => null), + getOAuthServers: jest.fn(() => null), + }; + mockGetMCPManager.mockReturnValue(mockMCPManager); + + const result = await getMCPSetupData(mockUserId); + + expect(result).toEqual({ + mcpConfig: mockConfig.mcpServers, + appConnections: new Map(), + userConnections: new Map(), + oauthServers: new Set(), + }); + }); + }); + + describe('checkOAuthFlowStatus', () => { + const mockUserId = 'user-123'; + const mockServerName = 'test-server'; + const mockFlowId = 'flow-123'; + + beforeEach(() => { + const mockFlowsCache = {}; + const mockFlowManager = { + getFlowState: jest.fn(), + }; + + mockGetLogStores.mockReturnValue(mockFlowsCache); + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + MCPOAuthHandler.generateFlowId.mockReturnValue(mockFlowId); + }); + + it('should return false flags when no flow state exists', async () => { + const mockFlowManager = { getFlowState: jest.fn(() => null) }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(mockGetLogStores).toHaveBeenCalledWith(CacheKeys.FLOWS); + expect(MCPOAuthHandler.generateFlowId).toHaveBeenCalledWith(mockUserId, mockServerName); + expect(mockFlowManager.getFlowState).toHaveBeenCalledWith(mockFlowId, 'mcp_oauth'); + expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false }); + }); + + it('should detect failed flow when status is FAILED', async () => { + const mockFlowState = { + status: 'FAILED', + createdAt: Date.now() - 60000, // 1 minute ago + ttl: 180000, + }; + const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true }); + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Found failed OAuth flow'), + expect.objectContaining({ + flowId: mockFlowId, + status: 'FAILED', + }), + ); + }); + + it('should detect failed flow when flow has timed out', async () => { + const mockFlowState = { + status: 'PENDING', + createdAt: Date.now() - 200000, // 200 seconds ago (> 180s TTL) + ttl: 180000, + }; + const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true }); + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Found failed OAuth flow'), + expect.objectContaining({ + timedOut: true, + }), + ); + }); + + it('should detect failed flow when TTL not specified and flow exceeds default TTL', async () => { + const mockFlowState = { + status: 'PENDING', + createdAt: Date.now() - 200000, // 200 seconds ago (> 180s default TTL) + // ttl not specified, should use 180000 default + }; + const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true }); + }); + + it('should detect active flow when status is PENDING and within TTL', async () => { + const mockFlowState = { + status: 'PENDING', + createdAt: Date.now() - 60000, // 1 minute ago (< 180s TTL) + ttl: 180000, + }; + const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(result).toEqual({ hasActiveFlow: true, hasFailedFlow: false }); + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Found active OAuth flow'), + expect.objectContaining({ + flowId: mockFlowId, + }), + ); + }); + + it('should return false flags for other statuses', async () => { + const mockFlowState = { + status: 'COMPLETED', + createdAt: Date.now() - 60000, + ttl: 180000, + }; + const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false }); + }); + + it('should handle errors gracefully', async () => { + const mockError = new Error('Flow state error'); + const mockFlowManager = { + getFlowState: jest.fn(() => { + throw mockError; + }), + }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + + const result = await checkOAuthFlowStatus(mockUserId, mockServerName); + + expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false }); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Error checking OAuth flows'), + mockError, + ); + }); + }); + + describe('getServerConnectionStatus', () => { + const mockUserId = 'user-123'; + const mockServerName = 'test-server'; + + it('should return app connection state when available', async () => { + const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]); + const userConnections = new Map(); + const oauthServers = new Set(); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: false, + connectionState: 'connected', + }); + }); + + it('should fallback to user connection state when app connection not available', async () => { + const appConnections = new Map(); + const userConnections = new Map([[mockServerName, { connectionState: 'connecting' }]]); + const oauthServers = new Set(); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: false, + connectionState: 'connecting', + }); + }); + + it('should default to disconnected when no connections exist', async () => { + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set(); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: false, + connectionState: 'disconnected', + }); + }); + + it('should prioritize app connection over user connection', async () => { + const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]); + const userConnections = new Map([[mockServerName, { connectionState: 'disconnected' }]]); + const oauthServers = new Set(); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: false, + connectionState: 'connected', + }); + }); + + it('should indicate OAuth requirement when server is in OAuth servers set', async () => { + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set([mockServerName]); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result.requiresOAuth).toBe(true); + }); + + it('should handle OAuth flow status when disconnected and requires OAuth with failed flow', async () => { + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set([mockServerName]); + + // Mock flow state to return failed flow + const mockFlowManager = { + getFlowState: jest.fn(() => ({ + status: 'FAILED', + createdAt: Date.now() - 60000, + ttl: 180000, + })), + }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + mockGetLogStores.mockReturnValue({}); + MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id'); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: true, + connectionState: 'error', + }); + }); + + it('should handle OAuth flow status when disconnected and requires OAuth with active flow', async () => { + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set([mockServerName]); + + // Mock flow state to return active flow + const mockFlowManager = { + getFlowState: jest.fn(() => ({ + status: 'PENDING', + createdAt: Date.now() - 60000, // 1 minute ago + ttl: 180000, // 3 minutes TTL + })), + }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + mockGetLogStores.mockReturnValue({}); + MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id'); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: true, + connectionState: 'connecting', + }); + }); + + it('should handle OAuth flow status when disconnected and requires OAuth with no flow', async () => { + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set([mockServerName]); + + // Mock flow state to return no flow + const mockFlowManager = { + getFlowState: jest.fn(() => null), + }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + mockGetLogStores.mockReturnValue({}); + MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id'); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: true, + connectionState: 'disconnected', + }); + }); + + it('should not check OAuth flow status when server is connected', async () => { + const mockFlowManager = { + getFlowState: jest.fn(), + }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + mockGetLogStores.mockReturnValue({}); + + const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]); + const userConnections = new Map(); + const oauthServers = new Set([mockServerName]); + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: true, + connectionState: 'connected', + }); + + // Should not call flow manager since server is connected + expect(mockFlowManager.getFlowState).not.toHaveBeenCalled(); + }); + + it('should not check OAuth flow status when server does not require OAuth', async () => { + const mockFlowManager = { + getFlowState: jest.fn(), + }; + mockGetFlowStateManager.mockReturnValue(mockFlowManager); + mockGetLogStores.mockReturnValue({}); + + const appConnections = new Map(); + const userConnections = new Map(); + const oauthServers = new Set(); // Server not in OAuth servers + + const result = await getServerConnectionStatus( + mockUserId, + mockServerName, + appConnections, + userConnections, + oauthServers, + ); + + expect(result).toEqual({ + requiresOAuth: false, + connectionState: 'disconnected', + }); + + // Should not call flow manager since server doesn't require OAuth + expect(mockFlowManager.getFlowState).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/client/src/components/MCP/ServerInitializationSection.tsx b/client/src/components/MCP/ServerInitializationSection.tsx index ea30f71356..36c9ca6b17 100644 --- a/client/src/components/MCP/ServerInitializationSection.tsx +++ b/client/src/components/MCP/ServerInitializationSection.tsx @@ -1,7 +1,8 @@ -import React, { useState, useCallback } from 'react'; +import React, { useCallback } from 'react'; import { Button } from '@librechat/client'; import { RefreshCw, Link } from 'lucide-react'; -import { useLocalize, useMCPServerInitialization } from '~/hooks'; +import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager'; +import { useLocalize } from '~/hooks'; interface ServerInitializationSectionProps { serverName: string; @@ -14,32 +15,27 @@ export default function ServerInitializationSection({ }: ServerInitializationSectionProps) { const localize = useLocalize(); - const [oauthUrl, setOauthUrl] = useState(null); - - // Use the shared initialization hook - const { initializeServer, isLoading, connectionStatus, cancelOAuthFlow, isCancellable } = - useMCPServerInitialization({ - onOAuthStarted: (name, url) => { - // Store the OAuth URL locally for display - setOauthUrl(url); - }, - onSuccess: () => { - // Clear OAuth URL on success - setOauthUrl(null); - }, - }); + // Use the centralized server manager instead of the old initialization hook so we can handle multiple oauth flows at once + const { + initializeServer, + connectionStatus, + cancelOAuthFlow, + isInitializing, + isCancellable, + getOAuthUrl, + } = useMCPServerManager(); const serverStatus = connectionStatus[serverName]; const isConnected = serverStatus?.connectionState === 'connected'; const canCancel = isCancellable(serverName); + const isServerInitializing = isInitializing(serverName); + const serverOAuthUrl = getOAuthUrl(serverName); const handleInitializeClick = useCallback(() => { - setOauthUrl(null); initializeServer(serverName); }, [initializeServer, serverName]); const handleCancelClick = useCallback(() => { - setOauthUrl(null); cancelOAuthFlow(serverName); }, [cancelOAuthFlow, serverName]); @@ -49,11 +45,11 @@ export default function ServerInitializationSection({
); @@ -70,13 +66,13 @@ export default function ServerInitializationSection({ {/* Only show authenticate button when OAuth URL is not present */} - {!oauthUrl && ( + {!serverOAuthUrl && (