🏹 feat: Concurrent MCP Initialization Support (#8677)

*  feat: Enhance MCP Connection Status Management

- Introduced new functions to retrieve and manage connection status for multiple MCP servers, including OAuth flow checks and server-specific status retrieval.
- Refactored the MCP connection status endpoints to support both all servers and individual server queries.
- Replaced the old server initialization hook with a new `useMCPServerManager` hook for improved state management and handling of multiple OAuth flows.
- Updated the MCPPanel component to utilize the new context provider for better state handling and UI updates.
- Fixed a number of UI bugs when initializing servers

* 🗣️ i18n: Remove unused strings from translation.json

* refactor: move helper functions out of the route module into mcp service file

* ci: add tests for newly added functions in mcp service file

* fix: memoize setMCPValues to avoid render loop
This commit is contained in:
Dustin Healy 2025-07-28 09:25:34 -07:00 committed by GitHub
parent 37aba18a96
commit 0ef3fefaec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 1092 additions and 542 deletions

View file

@ -4,6 +4,7 @@ const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys, Constants } = require('librechat-data-provider'); const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config'); const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth } = require('~/server/middleware');
@ -468,7 +469,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
/** /**
* Get connection status for all MCP servers * Get connection status for all MCP servers
* This endpoint returns the actual connection status from MCPManager without disconnecting idle connections * This endpoint returns all app level and user-scoped connection statuses from MCPManager without disconnecting idle connections
*/ */
router.get('/connection/status', requireJwtAuth, async (req, res) => { router.get('/connection/status', requireJwtAuth, async (req, res) => {
try { try {
@ -478,96 +479,87 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
return res.status(401).json({ error: 'User not authenticated' }); return res.status(401).json({ error: 'User not authenticated' });
} }
const mcpManager = getMCPManager(user.id); const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
user.id,
);
const connectionStatus = {}; const connectionStatus = {};
const printConfig = false;
const config = await loadCustomConfig(printConfig);
const mcpConfig = config?.mcpServers;
const appConnections = mcpManager.getAllConnections() || new Map();
const userConnections = mcpManager.getUserConnections(user.id) || new Map();
const oauthServers = mcpManager.getOAuthServers() || new Set();
if (!mcpConfig) {
return res.status(404).json({ error: 'MCP config not found' });
}
// Get flow manager to check for active/timed-out OAuth flows
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
for (const [serverName] of Object.entries(mcpConfig)) { for (const [serverName] of Object.entries(mcpConfig)) {
const getConnectionState = (serverName) => connectionStatus[serverName] = await getServerConnectionStatus(
appConnections.get(serverName)?.connectionState ?? user.id,
userConnections.get(serverName)?.connectionState ?? serverName,
'disconnected'; appConnections,
userConnections,
const baseConnectionState = getConnectionState(serverName); oauthServers,
let hasActiveOAuthFlow = false;
let hasFailedOAuthFlow = false;
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
try {
// Check for user-specific OAuth flows
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (flowState) {
// Check if flow failed or timed out
const flowAge = Date.now() - flowState.createdAt;
const flowTTL = flowState.ttl || 180000; // Default 3 minutes
if (flowState.status === 'FAILED' || flowAge > flowTTL) {
hasFailedOAuthFlow = true;
logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, {
flowId,
status: flowState.status,
flowAge,
flowTTL,
timedOut: flowAge > flowTTL,
});
} else if (flowState.status === 'PENDING') {
hasActiveOAuthFlow = true;
logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, {
flowId,
flowAge,
flowTTL,
});
}
}
} catch (error) {
logger.error(
`[MCP Connection Status] Error checking OAuth flows for ${serverName}:`,
error,
); );
} }
}
// Determine the final connection state
let finalConnectionState = baseConnectionState;
if (hasFailedOAuthFlow) {
finalConnectionState = 'error'; // Report as error if OAuth failed
} else if (hasActiveOAuthFlow && baseConnectionState === 'disconnected') {
finalConnectionState = 'connecting'; // Still waiting for OAuth
}
connectionStatus[serverName] = {
requiresOAuth: oauthServers.has(serverName),
connectionState: finalConnectionState,
};
}
res.json({ res.json({
success: true, success: true,
connectionStatus, connectionStatus,
}); });
} catch (error) { } catch (error) {
if (error.message === 'MCP config not found') {
return res.status(404).json({ error: error.message });
}
logger.error('[MCP Connection Status] Failed to get connection status', error); logger.error('[MCP Connection Status] Failed to get connection status', error);
res.status(500).json({ error: 'Failed to get connection status' }); res.status(500).json({ error: 'Failed to get connection status' });
} }
}); });
/**
* Get connection status for a single MCP server
* This endpoint returns the connection status for a specific server for a given user
*/
router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => {
try {
const user = req.user;
const { serverName } = req.params;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
if (!serverName) {
return res.status(400).json({ error: 'Server name is required' });
}
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
user.id,
);
if (!mcpConfig[serverName]) {
return res
.status(404)
.json({ error: `MCP server '${serverName}' not found in configuration` });
}
const serverStatus = await getServerConnectionStatus(
user.id,
serverName,
appConnections,
userConnections,
oauthServers,
);
res.json({
success: true,
serverName,
connectionStatus: serverStatus.connectionState,
requiresOAuth: serverStatus.requiresOAuth,
});
} catch (error) {
if (error.message === 'MCP config not found') {
return res.status(404).json({ error: error.message });
}
logger.error(
`[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`,
error,
);
res.status(500).json({ error: 'Failed to get connection status' });
}
});
/** /**
* Check which authentication values exist for a specific MCP server * Check which authentication values exist for a specific MCP server
* This endpoint returns only boolean flags indicating if values are set, not the actual values * This endpoint returns only boolean flags indicating if values are set, not the actual values

View file

@ -12,7 +12,7 @@ const {
} = require('@librechat/api'); } = require('@librechat/api');
const { findToken, createToken, updateToken } = require('~/models'); const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools } = require('./Config'); const { getCachedTools, loadCustomConfig } = require('./Config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
/** /**
@ -239,6 +239,123 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
return toolInstance; return toolInstance;
} }
/**
* Get MCP setup data including config, connections, and OAuth servers
* @param {string} userId - The user ID
* @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers
*/
async function getMCPSetupData(userId) {
const printConfig = false;
const config = await loadCustomConfig(printConfig);
const mcpConfig = config?.mcpServers;
if (!mcpConfig) {
throw new Error('MCP config not found');
}
const mcpManager = getMCPManager(userId);
const appConnections = mcpManager.getAllConnections() || new Map();
const userConnections = mcpManager.getUserConnections(userId) || new Map();
const oauthServers = mcpManager.getOAuthServers() || new Set();
return {
mcpConfig,
appConnections,
userConnections,
oauthServers,
};
}
/**
* Check OAuth flow status for a user and server
* @param {string} userId - The user ID
* @param {string} serverName - The server name
* @returns {Object} Object containing hasActiveFlow and hasFailedFlow flags
*/
async function checkOAuthFlowStatus(userId, serverName) {
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
try {
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (!flowState) {
return { hasActiveFlow: false, hasFailedFlow: false };
}
const flowAge = Date.now() - flowState.createdAt;
const flowTTL = flowState.ttl || 180000; // Default 3 minutes
if (flowState.status === 'FAILED' || flowAge > flowTTL) {
logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, {
flowId,
status: flowState.status,
flowAge,
flowTTL,
timedOut: flowAge > flowTTL,
});
return { hasActiveFlow: false, hasFailedFlow: true };
}
if (flowState.status === 'PENDING') {
logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, {
flowId,
flowAge,
flowTTL,
});
return { hasActiveFlow: true, hasFailedFlow: false };
}
return { hasActiveFlow: false, hasFailedFlow: false };
} catch (error) {
logger.error(`[MCP Connection Status] Error checking OAuth flows for ${serverName}:`, error);
return { hasActiveFlow: false, hasFailedFlow: false };
}
}
/**
* Get connection status for a specific MCP server
* @param {string} userId - The user ID
* @param {string} serverName - The server name
* @param {Map} appConnections - App-level connections
* @param {Map} userConnections - User-level connections
* @param {Set} oauthServers - Set of OAuth servers
* @returns {Object} Object containing requiresOAuth and connectionState
*/
async function getServerConnectionStatus(
userId,
serverName,
appConnections,
userConnections,
oauthServers,
) {
const getConnectionState = () =>
appConnections.get(serverName)?.connectionState ??
userConnections.get(serverName)?.connectionState ??
'disconnected';
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}
return {
requiresOAuth: oauthServers.has(serverName),
connectionState: finalConnectionState,
};
}
module.exports = { module.exports = {
createMCPTool, createMCPTool,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
}; };

View file

@ -0,0 +1,510 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP');
// Mock all dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
error: jest.fn(),
},
}));
jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
}));
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
FLOWS: 'flows',
},
}));
jest.mock('./Config', () => ({
loadCustomConfig: jest.fn(),
}));
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(),
}));
jest.mock('~/models', () => ({
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
}));
describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
let mockLoadCustomConfig;
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
beforeEach(() => {
jest.clearAllMocks();
mockLoadCustomConfig = require('./Config').loadCustomConfig;
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
});
describe('getMCPSetupData', () => {
const mockUserId = 'user-123';
const mockConfig = {
mcpServers: {
server1: { type: 'stdio' },
server2: { type: 'http' },
},
};
beforeEach(() => {
mockGetMCPManager.mockReturnValue({
getAllConnections: jest.fn(() => new Map()),
getUserConnections: jest.fn(() => new Map()),
getOAuthServers: jest.fn(() => new Set()),
});
});
it('should successfully return MCP setup data', async () => {
mockLoadCustomConfig.mockResolvedValue(mockConfig);
const mockAppConnections = new Map([['server1', { status: 'connected' }]]);
const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]);
const mockOAuthServers = new Set(['server2']);
const mockMCPManager = {
getAllConnections: jest.fn(() => mockAppConnections),
getUserConnections: jest.fn(() => mockUserConnections),
getOAuthServers: jest.fn(() => mockOAuthServers),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
const result = await getMCPSetupData(mockUserId);
expect(mockLoadCustomConfig).toHaveBeenCalledWith(false);
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.getAllConnections).toHaveBeenCalled();
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
expect(result).toEqual({
mcpConfig: mockConfig.mcpServers,
appConnections: mockAppConnections,
userConnections: mockUserConnections,
oauthServers: mockOAuthServers,
});
});
it('should throw error when MCP config not found', async () => {
mockLoadCustomConfig.mockResolvedValue({});
await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found');
});
it('should handle null values from MCP manager gracefully', async () => {
mockLoadCustomConfig.mockResolvedValue(mockConfig);
const mockMCPManager = {
getAllConnections: jest.fn(() => null),
getUserConnections: jest.fn(() => null),
getOAuthServers: jest.fn(() => null),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
const result = await getMCPSetupData(mockUserId);
expect(result).toEqual({
mcpConfig: mockConfig.mcpServers,
appConnections: new Map(),
userConnections: new Map(),
oauthServers: new Set(),
});
});
});
describe('checkOAuthFlowStatus', () => {
const mockUserId = 'user-123';
const mockServerName = 'test-server';
const mockFlowId = 'flow-123';
beforeEach(() => {
const mockFlowsCache = {};
const mockFlowManager = {
getFlowState: jest.fn(),
};
mockGetLogStores.mockReturnValue(mockFlowsCache);
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.generateFlowId.mockReturnValue(mockFlowId);
});
it('should return false flags when no flow state exists', async () => {
const mockFlowManager = { getFlowState: jest.fn(() => null) };
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(mockGetLogStores).toHaveBeenCalledWith(CacheKeys.FLOWS);
expect(MCPOAuthHandler.generateFlowId).toHaveBeenCalledWith(mockUserId, mockServerName);
expect(mockFlowManager.getFlowState).toHaveBeenCalledWith(mockFlowId, 'mcp_oauth');
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
});
it('should detect failed flow when status is FAILED', async () => {
const mockFlowState = {
status: 'FAILED',
createdAt: Date.now() - 60000, // 1 minute ago
ttl: 180000,
};
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
expect(logger.debug).toHaveBeenCalledWith(
expect.stringContaining('Found failed OAuth flow'),
expect.objectContaining({
flowId: mockFlowId,
status: 'FAILED',
}),
);
});
it('should detect failed flow when flow has timed out', async () => {
const mockFlowState = {
status: 'PENDING',
createdAt: Date.now() - 200000, // 200 seconds ago (> 180s TTL)
ttl: 180000,
};
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
expect(logger.debug).toHaveBeenCalledWith(
expect.stringContaining('Found failed OAuth flow'),
expect.objectContaining({
timedOut: true,
}),
);
});
it('should detect failed flow when TTL not specified and flow exceeds default TTL', async () => {
const mockFlowState = {
status: 'PENDING',
createdAt: Date.now() - 200000, // 200 seconds ago (> 180s default TTL)
// ttl not specified, should use 180000 default
};
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
});
it('should detect active flow when status is PENDING and within TTL', async () => {
const mockFlowState = {
status: 'PENDING',
createdAt: Date.now() - 60000, // 1 minute ago (< 180s TTL)
ttl: 180000,
};
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(result).toEqual({ hasActiveFlow: true, hasFailedFlow: false });
expect(logger.debug).toHaveBeenCalledWith(
expect.stringContaining('Found active OAuth flow'),
expect.objectContaining({
flowId: mockFlowId,
}),
);
});
it('should return false flags for other statuses', async () => {
const mockFlowState = {
status: 'COMPLETED',
createdAt: Date.now() - 60000,
ttl: 180000,
};
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
});
it('should handle errors gracefully', async () => {
const mockError = new Error('Flow state error');
const mockFlowManager = {
getFlowState: jest.fn(() => {
throw mockError;
}),
};
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining('Error checking OAuth flows'),
mockError,
);
});
});
describe('getServerConnectionStatus', () => {
const mockUserId = 'user-123';
const mockServerName = 'test-server';
it('should return app connection state when available', async () => {
const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]);
const userConnections = new Map();
const oauthServers = new Set();
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: false,
connectionState: 'connected',
});
});
it('should fallback to user connection state when app connection not available', async () => {
const appConnections = new Map();
const userConnections = new Map([[mockServerName, { connectionState: 'connecting' }]]);
const oauthServers = new Set();
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: false,
connectionState: 'connecting',
});
});
it('should default to disconnected when no connections exist', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set();
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: false,
connectionState: 'disconnected',
});
});
it('should prioritize app connection over user connection', async () => {
const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]);
const userConnections = new Map([[mockServerName, { connectionState: 'disconnected' }]]);
const oauthServers = new Set();
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: false,
connectionState: 'connected',
});
});
it('should indicate OAuth requirement when server is in OAuth servers set', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result.requiresOAuth).toBe(true);
});
it('should handle OAuth flow status when disconnected and requires OAuth with failed flow', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock flow state to return failed flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
status: 'FAILED',
createdAt: Date.now() - 60000,
ttl: 180000,
})),
};
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
mockGetLogStores.mockReturnValue({});
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'error',
});
});
it('should handle OAuth flow status when disconnected and requires OAuth with active flow', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock flow state to return active flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
status: 'PENDING',
createdAt: Date.now() - 60000, // 1 minute ago
ttl: 180000, // 3 minutes TTL
})),
};
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
mockGetLogStores.mockReturnValue({});
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connecting',
});
});
it('should handle OAuth flow status when disconnected and requires OAuth with no flow', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock flow state to return no flow
const mockFlowManager = {
getFlowState: jest.fn(() => null),
};
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
mockGetLogStores.mockReturnValue({});
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'disconnected',
});
});
it('should not check OAuth flow status when server is connected', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),
};
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
mockGetLogStores.mockReturnValue({});
const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]);
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connected',
});
// Should not call flow manager since server is connected
expect(mockFlowManager.getFlowState).not.toHaveBeenCalled();
});
it('should not check OAuth flow status when server does not require OAuth', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),
};
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
mockGetLogStores.mockReturnValue({});
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set(); // Server not in OAuth servers
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: false,
connectionState: 'disconnected',
});
// Should not call flow manager since server doesn't require OAuth
expect(mockFlowManager.getFlowState).not.toHaveBeenCalled();
});
});
});

View file

@ -1,7 +1,8 @@
import React, { useState, useCallback } from 'react'; import React, { useCallback } from 'react';
import { Button } from '@librechat/client'; import { Button } from '@librechat/client';
import { RefreshCw, Link } from 'lucide-react'; import { RefreshCw, Link } from 'lucide-react';
import { useLocalize, useMCPServerInitialization } from '~/hooks'; import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
import { useLocalize } from '~/hooks';
interface ServerInitializationSectionProps { interface ServerInitializationSectionProps {
serverName: string; serverName: string;
@ -14,32 +15,27 @@ export default function ServerInitializationSection({
}: ServerInitializationSectionProps) { }: ServerInitializationSectionProps) {
const localize = useLocalize(); const localize = useLocalize();
const [oauthUrl, setOauthUrl] = useState<string | null>(null); // Use the centralized server manager instead of the old initialization hook so we can handle multiple oauth flows at once
const {
// Use the shared initialization hook initializeServer,
const { initializeServer, isLoading, connectionStatus, cancelOAuthFlow, isCancellable } = connectionStatus,
useMCPServerInitialization({ cancelOAuthFlow,
onOAuthStarted: (name, url) => { isInitializing,
// Store the OAuth URL locally for display isCancellable,
setOauthUrl(url); getOAuthUrl,
}, } = useMCPServerManager();
onSuccess: () => {
// Clear OAuth URL on success
setOauthUrl(null);
},
});
const serverStatus = connectionStatus[serverName]; const serverStatus = connectionStatus[serverName];
const isConnected = serverStatus?.connectionState === 'connected'; const isConnected = serverStatus?.connectionState === 'connected';
const canCancel = isCancellable(serverName); const canCancel = isCancellable(serverName);
const isServerInitializing = isInitializing(serverName);
const serverOAuthUrl = getOAuthUrl(serverName);
const handleInitializeClick = useCallback(() => { const handleInitializeClick = useCallback(() => {
setOauthUrl(null);
initializeServer(serverName); initializeServer(serverName);
}, [initializeServer, serverName]); }, [initializeServer, serverName]);
const handleCancelClick = useCallback(() => { const handleCancelClick = useCallback(() => {
setOauthUrl(null);
cancelOAuthFlow(serverName); cancelOAuthFlow(serverName);
}, [cancelOAuthFlow, serverName]); }, [cancelOAuthFlow, serverName]);
@ -49,11 +45,11 @@ export default function ServerInitializationSection({
<div className="flex justify-start"> <div className="flex justify-start">
<button <button
onClick={handleInitializeClick} onClick={handleInitializeClick}
disabled={isLoading} disabled={isServerInitializing}
className="flex items-center gap-1 text-xs text-gray-400 hover:text-gray-600 disabled:opacity-50 dark:text-gray-500 dark:hover:text-gray-400" className="flex items-center gap-1 text-xs text-gray-400 hover:text-gray-600 disabled:opacity-50 dark:text-gray-500 dark:hover:text-gray-400"
> >
<RefreshCw className={`h-3 w-3 ${isLoading ? 'animate-spin' : ''}`} /> <RefreshCw className={`h-3 w-3 ${isServerInitializing ? 'animate-spin' : ''}`} />
{isLoading ? localize('com_ui_loading') : localize('com_ui_reinitialize')} {isServerInitializing ? localize('com_ui_loading') : localize('com_ui_reinitialize')}
</button> </button>
</div> </div>
); );
@ -70,13 +66,13 @@ export default function ServerInitializationSection({
</span> </span>
</div> </div>
{/* Only show authenticate button when OAuth URL is not present */} {/* Only show authenticate button when OAuth URL is not present */}
{!oauthUrl && ( {!serverOAuthUrl && (
<Button <Button
onClick={handleInitializeClick} onClick={handleInitializeClick}
disabled={isLoading} disabled={isServerInitializing}
className="flex items-center gap-2 bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 dark:hover:bg-blue-800" className="flex items-center gap-2 bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 dark:hover:bg-blue-800"
> >
{isLoading ? ( {isServerInitializing ? (
<> <>
<RefreshCw className="h-4 w-4 animate-spin" /> <RefreshCw className="h-4 w-4 animate-spin" />
{localize('com_ui_loading')} {localize('com_ui_loading')}
@ -94,7 +90,7 @@ export default function ServerInitializationSection({
</div> </div>
{/* OAuth URL display */} {/* OAuth URL display */}
{oauthUrl && ( {serverOAuthUrl && (
<div className="mt-4 rounded-lg border border-blue-200 bg-blue-50 p-3 dark:border-blue-700 dark:bg-blue-900/20"> <div className="mt-4 rounded-lg border border-blue-200 bg-blue-50 p-3 dark:border-blue-700 dark:bg-blue-900/20">
<div className="mb-2 flex items-center gap-2"> <div className="mb-2 flex items-center gap-2">
<div className="flex h-4 w-4 items-center justify-center rounded-full bg-blue-500"> <div className="flex h-4 w-4 items-center justify-center rounded-full bg-blue-500">
@ -106,7 +102,7 @@ export default function ServerInitializationSection({
</div> </div>
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<Button <Button
onClick={() => window.open(oauthUrl, '_blank', 'noopener,noreferrer')} onClick={() => window.open(serverOAuthUrl, '_blank', 'noopener,noreferrer')}
className="flex-1 bg-blue-600 text-white hover:bg-blue-700 dark:hover:bg-blue-800" className="flex-1 bg-blue-600 text-white hover:bg-blue-700 dark:hover:bg-blue-800"
> >
{localize('com_ui_continue_oauth')} {localize('com_ui_continue_oauth')}

View file

@ -8,11 +8,12 @@ import type { TUpdateUserPlugins } from 'librechat-data-provider';
import ServerInitializationSection from '~/components/MCP/ServerInitializationSection'; import ServerInitializationSection from '~/components/MCP/ServerInitializationSection';
import CustomUserVarsSection from '~/components/MCP/CustomUserVarsSection'; import CustomUserVarsSection from '~/components/MCP/CustomUserVarsSection';
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries'; import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
import BadgeRowProvider from '~/Providers/BadgeRowContext';
import { useGetStartupConfig } from '~/data-provider'; import { useGetStartupConfig } from '~/data-provider';
import MCPPanelSkeleton from './MCPPanelSkeleton'; import MCPPanelSkeleton from './MCPPanelSkeleton';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
export default function MCPPanel() { function MCPPanelContent() {
const localize = useLocalize(); const localize = useLocalize();
const { showToast } = useToastContext(); const { showToast } = useToastContext();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
@ -205,3 +206,11 @@ export default function MCPPanel() {
); );
} }
} }
export default function MCPPanel() {
return (
<BadgeRowProvider>
<MCPPanelContent />
</BadgeRowProvider>
);
}

View file

@ -1 +1 @@
export { useMCPServerInitialization } from './useMCPServerInitialization'; export { useMCPServerManager } from './useMCPServerManager';

View file

@ -1,317 +0,0 @@
import { useCallback, useState, useEffect, useMemo } from 'react';
import { useToastContext } from '@librechat/client';
import { QueryKeys } from 'librechat-data-provider';
import { useQueryClient } from '@tanstack/react-query';
import {
useReinitializeMCPServerMutation,
useCancelMCPOAuthMutation,
} from 'librechat-data-provider/react-query';
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
import { useLocalize } from '~/hooks';
import { logger } from '~/utils';
interface UseMCPServerInitializationOptions {
onSuccess?: (serverName: string) => void;
onOAuthStarted?: (serverName: string, oauthUrl: string) => void;
onError?: (serverName: string, error: any) => void;
}
export function useMCPServerInitialization(options?: UseMCPServerInitializationOptions) {
const localize = useLocalize();
const { showToast } = useToastContext();
const queryClient = useQueryClient();
// OAuth state management
const [oauthPollingServers, setOauthPollingServers] = useState<Map<string, string>>(new Map());
const [oauthStartTimes, setOauthStartTimes] = useState<Map<string, number>>(new Map());
const [initializingServers, setInitializingServers] = useState<Set<string>>(new Set());
const [cancellableServers, setCancellableServers] = useState<Set<string>>(new Set());
// Get connection status
const { data: connectionStatusData } = useMCPConnectionStatusQuery();
const connectionStatus = useMemo(
() => connectionStatusData?.connectionStatus || {},
[connectionStatusData],
);
// Main initialization mutation
const reinitializeMutation = useReinitializeMCPServerMutation();
// Track which server is currently being processed
const [currentProcessingServer, setCurrentProcessingServer] = useState<string | null>(null);
// Cancel OAuth mutation
const cancelOAuthMutation = useCancelMCPOAuthMutation();
// Helper function to clean up OAuth state
const cleanupOAuthState = useCallback((serverName: string) => {
setOauthPollingServers((prev) => {
const newMap = new Map(prev);
newMap.delete(serverName);
return newMap;
});
setOauthStartTimes((prev) => {
const newMap = new Map(prev);
newMap.delete(serverName);
return newMap;
});
setInitializingServers((prev) => {
const newSet = new Set(prev);
newSet.delete(serverName);
return newSet;
});
setCancellableServers((prev) => {
const newSet = new Set(prev);
newSet.delete(serverName);
return newSet;
});
}, []);
// Cancel OAuth flow
const cancelOAuthFlow = useCallback(
(serverName: string) => {
logger.info(`[MCP OAuth] User cancelling OAuth flow for ${serverName}`);
cancelOAuthMutation.mutate(serverName, {
onSuccess: () => {
cleanupOAuthState(serverName);
showToast({
message: localize('com_ui_mcp_oauth_cancelled', { 0: serverName }),
status: 'info',
});
},
onError: (error) => {
logger.error(`[MCP OAuth] Failed to cancel OAuth flow for ${serverName}:`, error);
// Clean up state anyway
cleanupOAuthState(serverName);
},
});
},
[cancelOAuthMutation, cleanupOAuthState, showToast, localize],
);
// Helper function to handle successful connection
const handleSuccessfulConnection = useCallback(
async (serverName: string, message: string) => {
showToast({ message, status: 'success' });
// Force immediate refetch to update UI
await Promise.all([
queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]),
queryClient.refetchQueries([QueryKeys.tools]),
]);
// Clean up OAuth state
cleanupOAuthState(serverName);
// Call optional success callback
options?.onSuccess?.(serverName);
},
[showToast, queryClient, options, cleanupOAuthState],
);
// Helper function to handle OAuth timeout/failure
const handleOAuthFailure = useCallback(
(serverName: string, isTimeout: boolean) => {
logger.warn(
`[MCP OAuth] OAuth ${isTimeout ? 'timed out' : 'failed'} for ${serverName}, stopping poll`,
);
// Clean up OAuth state
cleanupOAuthState(serverName);
// Show error toast
showToast({
message: isTimeout
? localize('com_ui_mcp_oauth_timeout', { 0: serverName })
: localize('com_ui_mcp_init_failed'),
status: 'error',
});
},
[showToast, localize, cleanupOAuthState],
);
// Poll for OAuth completion
useEffect(() => {
if (oauthPollingServers.size === 0) {
return;
}
const pollInterval = setInterval(() => {
// Check each polling server
oauthPollingServers.forEach((oauthUrl, serverName) => {
const serverStatus = connectionStatus[serverName];
// Check for client-side timeout (3 minutes)
const startTime = oauthStartTimes.get(serverName);
const hasTimedOut = startTime && Date.now() - startTime > 180000; // 3 minutes
if (serverStatus?.connectionState === 'connected') {
// OAuth completed successfully
handleSuccessfulConnection(
serverName,
localize('com_ui_mcp_authenticated_success', { 0: serverName }),
);
} else if (serverStatus?.connectionState === 'error' || hasTimedOut) {
// OAuth failed or timed out
handleOAuthFailure(serverName, !!hasTimedOut);
}
setCancellableServers((prev) => new Set(prev).add(serverName));
});
queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
}, 3500);
return () => {
clearInterval(pollInterval);
};
}, [
oauthPollingServers,
oauthStartTimes,
connectionStatus,
queryClient,
handleSuccessfulConnection,
handleOAuthFailure,
localize,
]);
// Initialize server function
const initializeServer = useCallback(
(serverName: string) => {
// Prevent spam - check if already initializing
if (initializingServers.has(serverName)) {
return;
}
if (connectionStatus[serverName]?.requiresOAuth) {
setCancellableServers((prev) => new Set(prev).add(serverName));
}
// Add to initializing set
setInitializingServers((prev) => new Set(prev).add(serverName));
// If there's already a server being processed, that one will be cancelled
if (currentProcessingServer && currentProcessingServer !== serverName) {
// Clean up the cancelled server's state immediately
showToast({
message: localize('com_ui_mcp_init_cancelled', { 0: currentProcessingServer }),
status: 'warning',
});
cleanupOAuthState(currentProcessingServer);
}
// Track the current server being processed
setCurrentProcessingServer(serverName);
reinitializeMutation.mutate(serverName, {
onSuccess: (response: any) => {
// Clear current processing server
setCurrentProcessingServer(null);
if (response.success) {
if (response.oauthRequired && response.oauthUrl) {
// OAuth required - store URL and start polling
setOauthPollingServers((prev) => new Map(prev).set(serverName, response.oauthUrl));
// Track when OAuth started for timeout detection
setOauthStartTimes((prev) => new Map(prev).set(serverName, Date.now()));
// Call optional OAuth callback or open URL directly
if (options?.onOAuthStarted) {
options.onOAuthStarted(serverName, response.oauthUrl);
} else {
window.open(response.oauthUrl, '_blank', 'noopener,noreferrer');
}
showToast({
message: localize('com_ui_connecting'),
status: 'info',
});
} else if (response.oauthRequired) {
// OAuth required but no URL - shouldn't happen
showToast({
message: localize('com_ui_mcp_oauth_no_url'),
status: 'warning',
});
// Remove from initializing since it failed
setInitializingServers((prev) => {
const newSet = new Set(prev);
newSet.delete(serverName);
return newSet;
});
} else {
// Successful connection without OAuth
handleSuccessfulConnection(
serverName,
response.message || localize('com_ui_mcp_initialized_success', { 0: serverName }),
);
}
} else {
// Remove from initializing if not successful
setInitializingServers((prev) => {
const newSet = new Set(prev);
newSet.delete(serverName);
return newSet;
});
}
},
onError: (error: any) => {
console.error(`Error initializing MCP server ${serverName}:`, error);
setCurrentProcessingServer(null);
const isCancelled =
error?.name === 'CanceledError' ||
error?.code === 'ERR_CANCELED' ||
error?.message?.includes('cancel') ||
error?.message?.includes('abort');
if (isCancelled) {
showToast({
message: localize('com_ui_mcp_init_cancelled', { 0: serverName }),
status: 'warning',
});
} else {
showToast({
message: localize('com_ui_mcp_init_failed'),
status: 'error',
});
}
// Clean up OAuth state using helper function
cleanupOAuthState(serverName);
// Call optional error callback
options?.onError?.(serverName, error);
},
});
},
[
initializingServers,
connectionStatus,
currentProcessingServer,
reinitializeMutation,
showToast,
localize,
cleanupOAuthState,
options,
handleSuccessfulConnection,
],
);
return {
initializeServer,
isInitializing: (serverName: string) => initializingServers.has(serverName),
isCancellable: (serverName: string) => cancellableServers.has(serverName),
initializingServers,
oauthPollingServers,
oauthStartTimes,
connectionStatus,
isLoading: reinitializeMutation.isLoading,
cancelOAuthFlow,
};
}

View file

@ -1,33 +1,52 @@
import { useCallback, useState, useMemo, useRef } from 'react'; import { useCallback, useState, useMemo, useRef, useEffect } from 'react';
import { useToastContext } from '@librechat/client'; import { useToastContext } from '@librechat/client';
import { useQueryClient } from '@tanstack/react-query'; import { useQueryClient } from '@tanstack/react-query';
import { Constants, QueryKeys } from 'librechat-data-provider'; import { Constants, QueryKeys } from 'librechat-data-provider';
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query'; import {
useUpdateUserPluginsMutation,
useReinitializeMCPServerMutation,
useCancelMCPOAuthMutation,
} from 'librechat-data-provider/react-query';
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider'; import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider';
import type { ConfigFieldDetail } from '~/components/MCP/MCPConfigDialog'; import type { ConfigFieldDetail } from '~/components/MCP/MCPConfigDialog';
import { useLocalize, useMCPServerInitialization } from '~/hooks'; import { useLocalize } from '~/hooks';
import { useBadgeRowContext } from '~/Providers'; import { useBadgeRowContext } from '~/Providers';
interface ServerState {
isInitializing: boolean;
oauthUrl: string | null;
oauthStartTime: number | null;
isCancellable: boolean;
pollInterval: NodeJS.Timeout | null;
}
export function useMCPServerManager() { export function useMCPServerManager() {
const localize = useLocalize(); const localize = useLocalize();
const { showToast } = useToastContext(); const { showToast } = useToastContext();
const { mcpSelect, startupConfig } = useBadgeRowContext(); const { mcpSelect, startupConfig } = useBadgeRowContext();
const { mcpValues, setMCPValues, mcpToolDetails, isPinned, setIsPinned } = mcpSelect; const { mcpValues, setMCPValues, mcpToolDetails, isPinned, setIsPinned } = mcpSelect;
const queryClient = useQueryClient();
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false);
const [selectedToolForConfig, setSelectedToolForConfig] = useState<TPlugin | null>(null);
const previousFocusRef = useRef<HTMLElement | null>(null);
const mcpValuesRef = useRef(mcpValues);
// fixes the issue where OAuth flows would deselect all the servers except the one that is being authenticated on success
useEffect(() => {
mcpValuesRef.current = mcpValues;
}, [mcpValues]);
const configuredServers = useMemo(() => { const configuredServers = useMemo(() => {
if (!startupConfig?.mcpServers) { if (!startupConfig?.mcpServers) return [];
return [];
}
return Object.entries(startupConfig.mcpServers) return Object.entries(startupConfig.mcpServers)
.filter(([, config]) => config.chatMenu !== false) .filter(([, config]) => config.chatMenu !== false)
.map(([serverName]) => serverName); .map(([serverName]) => serverName);
}, [startupConfig?.mcpServers]); }, [startupConfig?.mcpServers]);
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false); const reinitializeMutation = useReinitializeMCPServerMutation();
const [selectedToolForConfig, setSelectedToolForConfig] = useState<TPlugin | null>(null); const cancelOAuthMutation = useCancelMCPOAuthMutation();
const previousFocusRef = useRef<HTMLElement | null>(null);
const queryClient = useQueryClient();
const updateUserPluginsMutation = useUpdateUserPluginsMutation({ const updateUserPluginsMutation = useUpdateUserPluginsMutation({
onSuccess: async () => { onSuccess: async () => {
@ -48,52 +67,278 @@ export function useMCPServerManager() {
}, },
}); });
const { initializeServer, isInitializing, connectionStatus, cancelOAuthFlow, isCancellable } = const [serverStates, setServerStates] = useState<Record<string, ServerState>>(() => {
useMCPServerInitialization({ const initialStates: Record<string, ServerState> = {};
onSuccess: (serverName) => { configuredServers.forEach((serverName) => {
initialStates[serverName] = {
isInitializing: false,
oauthUrl: null,
oauthStartTime: null,
isCancellable: false,
pollInterval: null,
};
});
return initialStates;
});
const { data: connectionStatusData } = useMCPConnectionStatusQuery();
const connectionStatus = useMemo(
() => connectionStatusData?.connectionStatus || {},
[connectionStatusData?.connectionStatus],
);
useEffect(() => {
if (!mcpValues?.length) return;
const connectedSelected = mcpValues.filter(
(serverName) => connectionStatus[serverName]?.connectionState === 'connected',
);
if (connectedSelected.length !== mcpValues.length) {
setMCPValues(connectedSelected);
}
}, [connectionStatus, mcpValues, setMCPValues]);
const updateServerState = useCallback((serverName: string, updates: Partial<ServerState>) => {
setServerStates((prev) => {
const newStates = { ...prev };
const currentState = newStates[serverName] || {
isInitializing: false,
oauthUrl: null,
oauthStartTime: null,
isCancellable: false,
pollInterval: null,
};
newStates[serverName] = { ...currentState, ...updates };
return newStates;
});
}, []);
const cleanupServerState = useCallback(
(serverName: string) => {
const state = serverStates[serverName];
if (state?.pollInterval) {
clearInterval(state.pollInterval);
}
updateServerState(serverName, {
isInitializing: false,
oauthUrl: null,
oauthStartTime: null,
isCancellable: false,
pollInterval: null,
});
},
[serverStates, updateServerState],
);
const startServerPolling = useCallback(
(serverName: string) => {
const pollInterval = setInterval(async () => {
try {
await queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
const freshConnectionData = queryClient.getQueryData([
QueryKeys.mcpConnectionStatus,
]) as any;
const freshConnectionStatus = freshConnectionData?.connectionStatus || {};
const state = serverStates[serverName];
const serverStatus = freshConnectionStatus[serverName];
if (serverStatus?.connectionState === 'connected') {
clearInterval(pollInterval);
showToast({
message: localize('com_ui_mcp_authenticated_success', { 0: serverName }),
status: 'success',
});
const currentValues = mcpValuesRef.current ?? [];
if (!currentValues.includes(serverName)) {
setMCPValues([...currentValues, serverName]);
}
// This delay is to ensure UI has updated with new connection status before cleanup
// Otherwise servers will show as disconnected for a second after OAuth flow completes
setTimeout(() => {
cleanupServerState(serverName);
}, 1000);
return;
}
if (state?.oauthStartTime && Date.now() - state.oauthStartTime > 180000) {
showToast({
message: localize('com_ui_mcp_oauth_timeout', { 0: serverName }),
status: 'error',
});
cleanupServerState(serverName);
return;
}
if (serverStatus?.connectionState === 'error') {
showToast({
message: localize('com_ui_mcp_init_failed'),
status: 'error',
});
cleanupServerState(serverName);
}
} catch (error) {
console.error(`[MCP Manager] Error polling server ${serverName}:`, error);
}
}, 3500);
updateServerState(serverName, { pollInterval });
},
[
queryClient,
serverStates,
showToast,
localize,
setMCPValues,
cleanupServerState,
updateServerState,
],
);
const initializeServer = useCallback(
async (serverName: string) => {
updateServerState(serverName, { isInitializing: true });
try {
const response = await reinitializeMutation.mutateAsync(serverName);
if (response.success) {
if (response.oauthRequired && response.oauthUrl) {
updateServerState(serverName, {
oauthUrl: response.oauthUrl,
oauthStartTime: Date.now(),
isCancellable: true,
isInitializing: true,
});
window.open(response.oauthUrl, '_blank', 'noopener,noreferrer');
startServerPolling(serverName);
} else {
await queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
showToast({
message: localize('com_ui_mcp_initialized_success', { 0: serverName }),
status: 'success',
});
const currentValues = mcpValues ?? []; const currentValues = mcpValues ?? [];
if (!currentValues.includes(serverName)) { if (!currentValues.includes(serverName)) {
setMCPValues([...currentValues, serverName]); setMCPValues([...currentValues, serverName]);
} }
},
onError: (serverName) => {
const tool = mcpToolDetails?.find((t) => t.name === serverName);
const serverConfig = startupConfig?.mcpServers?.[serverName];
const serverStatus = connectionStatus[serverName];
const hasAuthConfig = cleanupServerState(serverName);
(tool?.authConfig && tool.authConfig.length > 0) ||
(serverConfig?.customUserVars && Object.keys(serverConfig.customUserVars).length > 0);
const wouldShowButton =
!serverStatus ||
serverStatus.connectionState === 'disconnected' ||
serverStatus.connectionState === 'error' ||
(serverStatus.connectionState === 'connected' && hasAuthConfig);
if (!wouldShowButton) {
return;
} }
}
const configTool = tool || { } catch (error) {
name: serverName, console.error(`[MCP Manager] Failed to initialize ${serverName}:`, error);
pluginKey: `${Constants.mcp_prefix}${serverName}`, showToast({
authConfig: serverConfig?.customUserVars message: localize('com_ui_mcp_init_failed', { 0: serverName }),
? Object.entries(serverConfig.customUserVars).map(([key, config]) => ({ status: 'error',
authField: key,
label: config.title,
description: config.description,
}))
: [],
authenticated: false,
};
previousFocusRef.current = document.activeElement as HTMLElement;
setSelectedToolForConfig(configTool);
setIsConfigModalOpen(true);
},
}); });
cleanupServerState(serverName);
}
},
[
updateServerState,
reinitializeMutation,
startServerPolling,
queryClient,
showToast,
localize,
mcpValues,
cleanupServerState,
setMCPValues,
],
);
const cancelOAuthFlow = useCallback(
(serverName: string) => {
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
cleanupServerState(serverName);
cancelOAuthMutation.mutate(serverName);
showToast({
message: localize('com_ui_mcp_oauth_cancelled', { 0: serverName }),
status: 'warning',
});
},
[queryClient, cleanupServerState, showToast, localize, cancelOAuthMutation],
);
const isInitializing = useCallback(
(serverName: string) => {
return serverStates[serverName]?.isInitializing || false;
},
[serverStates],
);
const isCancellable = useCallback(
(serverName: string) => {
return serverStates[serverName]?.isCancellable || false;
},
[serverStates],
);
const getOAuthUrl = useCallback(
(serverName: string) => {
return serverStates[serverName]?.oauthUrl || null;
},
[serverStates],
);
const placeholderText = useMemo(
() => startupConfig?.interface?.mcpServers?.placeholder || localize('com_ui_mcp_servers'),
[startupConfig?.interface?.mcpServers?.placeholder, localize],
);
const batchToggleServers = useCallback(
(serverNames: string[]) => {
const connectedServers: string[] = [];
const disconnectedServers: string[] = [];
serverNames.forEach((serverName) => {
const serverStatus = connectionStatus[serverName];
if (serverStatus?.connectionState === 'connected') {
connectedServers.push(serverName);
} else {
disconnectedServers.push(serverName);
}
});
setMCPValues(connectedServers);
disconnectedServers.forEach((serverName) => {
initializeServer(serverName);
});
},
[connectionStatus, setMCPValues, initializeServer],
);
const toggleServerSelection = useCallback(
(serverName: string) => {
const currentValues = mcpValues ?? [];
const isCurrentlySelected = currentValues.includes(serverName);
if (isCurrentlySelected) {
const filteredValues = currentValues.filter((name) => name !== serverName);
setMCPValues(filteredValues);
} else {
const serverStatus = connectionStatus[serverName];
if (serverStatus?.connectionState === 'connected') {
setMCPValues([...currentValues, serverName]);
} else {
initializeServer(serverName);
}
}
},
[mcpValues, setMCPValues, connectionStatus, initializeServer],
);
const handleConfigSave = useCallback( const handleConfigSave = useCallback(
(targetName: string, authData: Record<string, string>) => { (targetName: string, authData: Record<string, string>) => {
@ -155,48 +400,6 @@ export function useMCPServerManager() {
} }
}, []); }, []);
const toggleServerSelection = useCallback(
(serverName: string) => {
const currentValues = mcpValues ?? [];
const serverStatus = connectionStatus[serverName];
if (currentValues.includes(serverName)) {
const filteredValues = currentValues.filter((name) => name !== serverName);
setMCPValues(filteredValues);
} else {
if (serverStatus?.connectionState === 'connected') {
setMCPValues([...currentValues, serverName]);
} else {
initializeServer(serverName);
}
}
},
[connectionStatus, mcpValues, setMCPValues, initializeServer],
);
const batchToggleServers = useCallback(
(serverNames: string[]) => {
const connectedServers: string[] = [];
const disconnectedServers: string[] = [];
serverNames.forEach((serverName) => {
const serverStatus = connectionStatus[serverName];
if (serverStatus?.connectionState === 'connected') {
connectedServers.push(serverName);
} else {
disconnectedServers.push(serverName);
}
});
setMCPValues(connectedServers);
disconnectedServers.forEach((serverName) => {
initializeServer(serverName);
});
},
[connectionStatus, setMCPValues, initializeServer],
);
const getServerStatusIconProps = useCallback( const getServerStatusIconProps = useCallback(
(serverName: string) => { (serverName: string) => {
const tool = mcpToolDetails?.find((t) => t.name === serverName); const tool = mcpToolDetails?.find((t) => t.name === serverName);
@ -255,11 +458,6 @@ export function useMCPServerManager() {
], ],
); );
const placeholderText = useMemo(
() => startupConfig?.interface?.mcpServers?.placeholder || localize('com_ui_mcp_servers'),
[startupConfig?.interface?.mcpServers?.placeholder, localize],
);
const getConfigDialogProps = useCallback(() => { const getConfigDialogProps = useCallback(() => {
if (!selectedToolForConfig) return null; if (!selectedToolForConfig) return null;
@ -302,27 +500,31 @@ export function useMCPServerManager() {
]); ]);
return { return {
// Data
configuredServers, configuredServers,
connectionStatus,
initializeServer,
cancelOAuthFlow,
isInitializing,
isCancellable,
getOAuthUrl,
mcpValues, mcpValues,
setMCPValues,
mcpToolDetails, mcpToolDetails,
isPinned, isPinned,
setIsPinned, setIsPinned,
startupConfig,
connectionStatus,
placeholderText, placeholderText,
// Handlers
toggleServerSelection,
batchToggleServers, batchToggleServers,
getServerStatusIconProps, toggleServerSelection,
// Dialog state
selectedToolForConfig,
isConfigModalOpen,
getConfigDialogProps,
// Utilities
localize, localize,
isConfigModalOpen,
handleDialogOpenChange,
selectedToolForConfig,
setSelectedToolForConfig,
handleSave,
handleRevoke,
getServerStatusIconProps,
getConfigDialogProps,
}; };
} }

View file

@ -81,13 +81,21 @@ export function useMCPSelect({ conversationId }: UseMCPSelectOptions) {
[setEphemeralAgent], [setEphemeralAgent],
); );
const [mcpValues, setMCPValues] = useLocalStorage<string[]>( const [mcpValues, setMCPValuesRaw] = useLocalStorage<string[]>(
`${LocalStorageKeys.LAST_MCP_}${key}`, `${LocalStorageKeys.LAST_MCP_}${key}`,
mcpState, mcpState,
setSelectedValues, setSelectedValues,
storageCondition, storageCondition,
); );
const setMCPValuesRawRef = useRef(setMCPValuesRaw);
setMCPValuesRawRef.current = setMCPValuesRaw;
// Create a stable memoized setter to avoid re-creating it on every render and causing an infinite render loop
const setMCPValues = useCallback((value: string[]) => {
setMCPValuesRawRef.current(value);
}, []);
const [isPinned, setIsPinned] = useLocalStorage<boolean>( const [isPinned, setIsPinned] = useLocalStorage<boolean>(
`${LocalStorageKeys.PIN_MCP_}${key}`, `${LocalStorageKeys.PIN_MCP_}${key}`,
true, true,

View file

@ -857,13 +857,11 @@
"com_ui_mcp_not_authenticated": "{{0}} not authenticated (OAuth Required)", "com_ui_mcp_not_authenticated": "{{0}} not authenticated (OAuth Required)",
"com_ui_mcp_not_initialized": "{{0}} not initialized", "com_ui_mcp_not_initialized": "{{0}} not initialized",
"com_ui_mcp_oauth_cancelled": "OAuth login cancelled for {{0}}", "com_ui_mcp_oauth_cancelled": "OAuth login cancelled for {{0}}",
"com_ui_mcp_oauth_no_url": "OAuth authentication required but no URL provided",
"com_ui_mcp_oauth_timeout": "OAuth login timed out for {{0}}", "com_ui_mcp_oauth_timeout": "OAuth login timed out for {{0}}",
"com_ui_mcp_server_not_found": "Server not found.", "com_ui_mcp_server_not_found": "Server not found.",
"com_ui_mcp_servers": "MCP Servers", "com_ui_mcp_servers": "MCP Servers",
"com_ui_mcp_update_var": "Update {{0}}", "com_ui_mcp_update_var": "Update {{0}}",
"com_ui_mcp_url": "MCP Server URL", "com_ui_mcp_url": "MCP Server URL",
"com_ui_mcp_init_cancelled": "MCP server '{{0}}' initialization was cancelled due to simultaneous request",
"com_ui_medium": "Medium", "com_ui_medium": "Medium",
"com_ui_memories": "Memories", "com_ui_memories": "Memories",
"com_ui_memories_allow_create": "Allow creating Memories", "com_ui_memories_allow_create": "Allow creating Memories",

View file

@ -134,6 +134,8 @@ export const plugins = () => '/api/plugins';
export const mcpReinitialize = (serverName: string) => `/api/mcp/${serverName}/reinitialize`; export const mcpReinitialize = (serverName: string) => `/api/mcp/${serverName}/reinitialize`;
export const mcpConnectionStatus = () => '/api/mcp/connection/status'; export const mcpConnectionStatus = () => '/api/mcp/connection/status';
export const mcpServerConnectionStatus = (serverName: string) =>
`/api/mcp/connection/status/${serverName}`;
export const mcpAuthValues = (serverName: string) => { export const mcpAuthValues = (serverName: string) => {
return `/api/mcp/${serverName}/auth-values`; return `/api/mcp/${serverName}/auth-values`;
}; };

View file

@ -149,6 +149,12 @@ export const getMCPConnectionStatus = (): Promise<q.MCPConnectionStatusResponse>
return request.get(endpoints.mcpConnectionStatus()); return request.get(endpoints.mcpConnectionStatus());
}; };
export const getMCPServerConnectionStatus = (
serverName: string,
): Promise<q.MCPServerConnectionStatusResponse> => {
return request.get(endpoints.mcpServerConnectionStatus(serverName));
};
export const getMCPAuthValues = (serverName: string): Promise<q.MCPAuthValuesResponse> => { export const getMCPAuthValues = (serverName: string): Promise<q.MCPAuthValuesResponse> => {
return request.get(endpoints.mcpAuthValues(serverName)); return request.get(endpoints.mcpAuthValues(serverName));
}; };

View file

@ -6,6 +6,7 @@ import type {
} from '@tanstack/react-query'; } from '@tanstack/react-query';
import { Constants, initialModelsConfig } from '../config'; import { Constants, initialModelsConfig } from '../config';
import { defaultOrderQuery } from '../types/assistants'; import { defaultOrderQuery } from '../types/assistants';
import { MCPServerConnectionStatusResponse } from '../types/queries';
import * as dataService from '../data-service'; import * as dataService from '../data-service';
import * as m from '../types/mutations'; import * as m from '../types/mutations';
import { QueryKeys } from '../keys'; import { QueryKeys } from '../keys';
@ -380,3 +381,21 @@ export const useUpdateFeedbackMutation = (
}, },
); );
}; };
export const useMCPServerConnectionStatusQuery = (
serverName: string,
config?: UseQueryOptions<MCPServerConnectionStatusResponse>,
): QueryObserverResult<MCPServerConnectionStatusResponse> => {
return useQuery<MCPServerConnectionStatusResponse>(
[QueryKeys.mcpConnectionStatus, serverName],
() => dataService.getMCPServerConnectionStatus(serverName),
{
refetchOnWindowFocus: false,
refetchOnReconnect: false,
refetchOnMount: false,
staleTime: 10000, // 10 seconds
enabled: !!serverName,
...config,
},
);
};

View file

@ -0,0 +1 @@
export * from './queries';

View file

@ -135,6 +135,13 @@ export interface MCPConnectionStatusResponse {
connectionStatus: Record<string, MCPServerStatus>; connectionStatus: Record<string, MCPServerStatus>;
} }
export interface MCPServerConnectionStatusResponse {
success: boolean;
serverName: string;
connectionStatus: string;
requiresOAuth: boolean;
}
export interface MCPAuthValuesResponse { export interface MCPAuthValuesResponse {
success: boolean; success: boolean;
serverName: string; serverName: string;