♻️ refactor: MCPManager for Scalability, Fix App-Level Detection, Add Lazy Connections (#8930)

* feat: MCP Connection management overhaul - Making MCPManager manageable

Refactor the monolithic MCPManager into focused, single-responsibility classes:

• MCPServersRegistry: Server configuration discovery and metadata management
• UserConnectionManager: Manages user-level connections
• ConnectionsRepository: Low-level connection pool with lazy loading
• MCPConnectionFactory: Handles MCP connection creation with OAuth support

New Features:
• Lazy loading of app-level connections for horizontal scaling
• Automatic reconnection for app-level connections
• Enhanced OAuth detection with explicit requiresOAuth flag
• Centralized MCP configuration management

Bug Fixes:
• App-level connection detection in MCPManager.callTool
• MCP Connection Reinitialization route behavior

Optimizations:
• MCPConnection.isConnected() caching to reduce overhead
• Concurrent server metadata retrieval instead of sequential

This refactoring addresses scalability bottlenecks and improves reliability
while maintaining backward compatibility with existing configurations.

* feat: Enabled import order in eslint.

* # Moved tests to __tests__ folder
# added tests for MCPServersRegistry.ts

* # Add unit tests for ConnectionsRepository functionality

* # Add unit tests for MCPConnectionFactory functionality

* # Reorganize MCP connection tests and improve error handling

* # reordering imports

* # Update testPathIgnorePatterns in jest.config.mjs to exclude development TypeScript files

* # removed mcp/manager.ts
This commit is contained in:
Theo N. Truong 2025-08-13 09:45:06 -06:00 committed by GitHub
parent 9dbf153489
commit 8780a78165
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 2571 additions and 1468 deletions

View file

@ -698,3 +698,16 @@ OPENWEATHER_API_KEY=
# JINA_API_KEY=your_jina_api_key # JINA_API_KEY=your_jina_api_key
# or # or
# COHERE_API_KEY=your_cohere_api_key # COHERE_API_KEY=your_cohere_api_key
#======================#
# MCP Configuration #
#======================#
# Treat 401/403 responses as OAuth requirement when no oauth metadata found
# MCP_OAUTH_ON_AUTH_ERROR=true
# Timeout for OAuth detection requests in milliseconds
# MCP_OAUTH_DETECTION_TIMEOUT=5000
# Cache connection status checks for this many milliseconds to avoid expensive verification
# MCP_CONNECTION_CHECK_TTL=60000

View file

@ -147,7 +147,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
## 8. Module Import Conventions ## 8. Module Import Conventions
- `npm` packages first, - `npm` packages first,
- from shortest line (top) to longest (bottom) - from longest line (top) to shortest (bottom)
- Followed by typescript types (pertains to data-provider and client workspaces) - Followed by typescript types (pertains to data-provider and client workspaces)
- longest line (top) to shortest (bottom) - longest line (top) to shortest (bottom)
@ -157,6 +157,8 @@ Apply the following naming conventions to branches, labels, and other Git-relate
- longest line (top) to shortest (bottom) - longest line (top) to shortest (bottom)
- imports with alias `~` treated the same as relative import with respect to line length - imports with alias `~` treated the same as relative import with respect to line length
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
--- ---
Please ensure that you adapt this summary to fit the specific context and nuances of your project. Please ensure that you adapt this summary to fit the specific context and nuances of your project.

1
.gitignore vendored
View file

@ -137,3 +137,4 @@ helm/**/.values.yaml
/.openai/ /.openai/
/.tabnine/ /.tabnine/
/.codeium /.codeium
*.local.md

View file

@ -1,27 +1,13 @@
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { EventSource } = require('eventsource'); const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider'); const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager } = require('@librechat/api');
const logger = require('./winston'); const logger = require('./winston');
global.EventSource = EventSource; global.EventSource = EventSource;
/** @type {MCPManager} */ /** @type {MCPManager} */
let mcpManager = null;
let flowManager = null; let flowManager = null;
/**
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
* @returns {MCPManager}
*/
function getMCPManager(userId) {
if (!mcpManager) {
mcpManager = MCPManager.getInstance();
} else {
mcpManager.checkIdleConnections(userId);
}
return mcpManager;
}
/** /**
* @param {Keyv} flowsCache * @param {Keyv} flowsCache
* @returns {FlowStateManager} * @returns {FlowStateManager}
@ -37,6 +23,7 @@ function getFlowStateManager(flowsCache) {
module.exports = { module.exports = {
logger, logger,
getMCPManager, createMCPManager: MCPManager.createInstance,
getMCPManager: MCPManager.getInstance,
getFlowStateManager, getFlowStateManager,
}; };

View file

@ -1,7 +1,7 @@
const express = require('express'); const { MongoMemoryServer } = require('mongodb-memory-server');
const request = require('supertest'); const request = require('supertest');
const mongoose = require('mongoose'); const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server'); const express = require('express');
jest.mock('@librechat/api', () => ({ jest.mock('@librechat/api', () => ({
MCPOAuthHandler: { MCPOAuthHandler: {
@ -494,12 +494,9 @@ describe('MCP Routes', () => {
}); });
it('should return 500 when token retrieval throws an unexpected error', async () => { it('should return 500 when token retrieval throws an unexpected error', async () => {
const mockFlowManager = { getLogStores.mockImplementation(() => {
getFlowState: jest.fn().mockRejectedValue(new Error('Database connection failed')), throw new Error('Database connection failed');
}; });
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow'); const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow');
@ -563,8 +560,8 @@ describe('MCP Routes', () => {
}); });
describe('POST /oauth/cancel/:serverName', () => { describe('POST /oauth/cancel/:serverName', () => {
const { getLogStores } = require('~/cache');
const { MCPOAuthHandler } = require('@librechat/api'); const { MCPOAuthHandler } = require('@librechat/api');
const { getLogStores } = require('~/cache');
it('should cancel OAuth flow successfully', async () => { it('should cancel OAuth flow successfully', async () => {
const mockFlowManager = { const mockFlowManager = {
@ -644,15 +641,15 @@ describe('MCP Routes', () => {
}); });
describe('POST /:serverName/reinitialize', () => { describe('POST /:serverName/reinitialize', () => {
const { loadCustomConfig } = require('~/server/services/Config');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
it('should return 404 when server is not found in configuration', async () => { it('should return 404 when server is not found in configuration', async () => {
loadCustomConfig.mockResolvedValue({ const mockMcpManager = {
mcpServers: { getRawConfig: jest.fn().mockReturnValue(null),
'other-server': {}, disconnectUserConnection: jest.fn().mockResolvedValue(),
}, };
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
const response = await request(app).post('/api/mcp/non-existent-server/reinitialize'); const response = await request(app).post('/api/mcp/non-existent-server/reinitialize');
@ -663,16 +660,11 @@ describe('MCP Routes', () => {
}); });
it('should handle OAuth requirement during reinitialize', async () => { it('should handle OAuth requirement during reinitialize', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'oauth-server': {
customUserVars: {},
},
},
});
const mockMcpManager = { const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(), getRawConfig: jest.fn().mockReturnValue({
customUserVars: {},
}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
mcpConfigs: {}, mcpConfigs: {},
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => { getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
if (oauthStart) { if (oauthStart) {
@ -690,7 +682,7 @@ describe('MCP Routes', () => {
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body).toEqual({ expect(response.body).toEqual({
success: 'https://oauth.example.com/auth', success: true,
message: "MCP server 'oauth-server' ready for OAuth authentication", message: "MCP server 'oauth-server' ready for OAuth authentication",
serverName: 'oauth-server', serverName: 'oauth-server',
oauthRequired: true, oauthRequired: true,
@ -699,14 +691,9 @@ describe('MCP Routes', () => {
}); });
it('should return 500 when reinitialize fails with non-OAuth error', async () => { it('should return 500 when reinitialize fails with non-OAuth error', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'error-server': {},
},
});
const mockMcpManager = { const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(), getRawConfig: jest.fn().mockReturnValue({}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
mcpConfigs: {}, mcpConfigs: {},
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')), getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
}; };
@ -724,7 +711,13 @@ describe('MCP Routes', () => {
}); });
it('should return 500 when unexpected error occurs', async () => { it('should return 500 when unexpected error occurs', async () => {
loadCustomConfig.mockRejectedValue(new Error('Config loading failed')); const mockMcpManager = {
getRawConfig: jest.fn().mockImplementation(() => {
throw new Error('Config loading failed');
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).post('/api/mcp/test-server/reinitialize'); const response = await request(app).post('/api/mcp/test-server/reinitialize');
@ -747,29 +740,17 @@ describe('MCP Routes', () => {
expect(response.body).toEqual({ error: 'User not authenticated' }); expect(response.body).toEqual({ error: 'User not authenticated' });
}); });
it('should handle errors when fetching custom user variables', async () => { it('should successfully reinitialize server and cache tools', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': {
customUserVars: {
API_KEY: 'test-key-var',
SECRET_TOKEN: 'test-secret-var',
},
},
},
});
getUserPluginAuthValue
.mockResolvedValueOnce('test-api-key-value')
.mockRejectedValueOnce(new Error('Database error'));
const mockUserConnection = { const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([]), fetchTools: jest.fn().mockResolvedValue([
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
]),
}; };
const mockMcpManager = { const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(), getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
mcpConfigs: {}, disconnectUserConnection: jest.fn().mockResolvedValue(),
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
}; };
@ -784,38 +765,54 @@ describe('MCP Routes', () => {
const response = await request(app).post('/api/mcp/test-server/reinitialize'); const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body.success).toBe(true); expect(response.body).toEqual({
success: true,
message: "MCP server 'test-server' reinitialized successfully",
serverName: 'test-server',
oauthRequired: false,
oauthUrl: null,
});
expect(mockMcpManager.disconnectUserConnection).toHaveBeenCalledWith(
'test-user-id',
'test-server',
);
expect(setCachedTools).toHaveBeenCalled();
}); });
it('should return failure message when reinitialize completely fails', async () => { it('should handle server with custom user variables', async () => {
loadCustomConfig.mockResolvedValue({ const mockUserConnection = {
mcpServers: { fetchTools: jest.fn().mockResolvedValue([]),
'test-server': {}, };
},
});
const mockMcpManager = { const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(), getRawConfig: jest.fn().mockReturnValue({
mcpConfigs: {}, endpoint: 'http://test-server.com',
getUserConnection: jest.fn().mockResolvedValue(null), customUserVars: {
API_KEY: 'some-env-var',
},
}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
}; };
require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue(
'api-key-value',
);
const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { Constants } = require('librechat-data-provider'); getCachedTools.mockResolvedValue({});
getCachedTools.mockResolvedValue({
[`existing-tool${Constants.mcp_delimiter}test-server`]: { type: 'function' },
});
setCachedTools.mockResolvedValue(); setCachedTools.mockResolvedValue();
const response = await request(app).post('/api/mcp/test-server/reinitialize'); const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body.success).toBe(false); expect(response.body.success).toBe(true);
expect(response.body.message).toBe("Failed to reinitialize MCP server 'test-server'"); expect(
require('~/server/services/PluginService').getUserPluginAuthValue,
).toHaveBeenCalledWith('test-user-id', 'API_KEY', false);
}); });
}); });
@ -984,21 +981,19 @@ describe('MCP Routes', () => {
}); });
describe('GET /:serverName/auth-values', () => { describe('GET /:serverName/auth-values', () => {
const { loadCustomConfig } = require('~/server/services/Config');
const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { getUserPluginAuthValue } = require('~/server/services/PluginService');
it('should return auth value flags for server', async () => { it('should return auth value flags for server', async () => {
loadCustomConfig.mockResolvedValue({ const mockMcpManager = {
mcpServers: { getRawConfig: jest.fn().mockReturnValue({
'test-server': { customUserVars: {
customUserVars: { API_KEY: 'some-env-var',
API_KEY: 'some-env-var', SECRET_TOKEN: 'another-env-var',
SECRET_TOKEN: 'another-env-var',
},
}, },
}, }),
}); };
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce(''); getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
const response = await request(app).get('/api/mcp/test-server/auth-values'); const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1017,11 +1012,11 @@ describe('MCP Routes', () => {
}); });
it('should return 404 when server is not found in configuration', async () => { it('should return 404 when server is not found in configuration', async () => {
loadCustomConfig.mockResolvedValue({ const mockMcpManager = {
mcpServers: { getRawConfig: jest.fn().mockReturnValue(null),
'other-server': {}, };
},
}); require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/non-existent-server/auth-values'); const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
@ -1032,16 +1027,15 @@ describe('MCP Routes', () => {
}); });
it('should handle errors when checking auth values', async () => { it('should handle errors when checking auth values', async () => {
loadCustomConfig.mockResolvedValue({ const mockMcpManager = {
mcpServers: { getRawConfig: jest.fn().mockReturnValue({
'test-server': { customUserVars: {
customUserVars: { API_KEY: 'some-env-var',
API_KEY: 'some-env-var',
},
}, },
}, }),
}); };
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
getUserPluginAuthValue.mockRejectedValue(new Error('Database error')); getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
const response = await request(app).get('/api/mcp/test-server/auth-values'); const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1057,7 +1051,13 @@ describe('MCP Routes', () => {
}); });
it('should return 500 when auth values check throws unexpected error', async () => { it('should return 500 when auth values check throws unexpected error', async () => {
loadCustomConfig.mockRejectedValue(new Error('Config loading failed')); const mockMcpManager = {
getRawConfig: jest.fn().mockImplementation(() => {
throw new Error('Config loading failed');
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/auth-values'); const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1066,14 +1066,13 @@ describe('MCP Routes', () => {
}); });
it('should handle customUserVars that is not an object', async () => { it('should handle customUserVars that is not an object', async () => {
const { loadCustomConfig } = require('~/server/services/Config'); const mockMcpManager = {
loadCustomConfig.mockResolvedValue({ getRawConfig: jest.fn().mockReturnValue({
mcpServers: { customUserVars: 'not-an-object',
'test-server': { }),
customUserVars: 'not-an-object', };
},
}, require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
});
const response = await request(app).get('/api/mcp/test-server/auth-values'); const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1097,98 +1096,6 @@ describe('MCP Routes', () => {
}); });
}); });
describe('POST /:serverName/reinitialize - Tool Deletion Coverage', () => {
it('should handle null cached tools during reinitialize (triggers || {} fallback)', async () => {
const { loadCustomConfig, getCachedTools } = require('~/server/services/Config');
const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
};
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
disconnectServer: jest.fn(),
initializeServer: jest.fn(),
mcpConfigs: {},
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': { env: { API_KEY: 'test-key' } },
},
});
getCachedTools.mockResolvedValue(null);
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
expect(response.body).toEqual({
message: "MCP server 'test-server' reinitialized successfully",
success: true,
oauthRequired: false,
oauthUrl: null,
serverName: 'test-server',
});
});
it('should delete existing cached tools during successful reinitialize', async () => {
const {
loadCustomConfig,
getCachedTools,
setCachedTools,
} = require('~/server/services/Config');
const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
};
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
disconnectServer: jest.fn(),
initializeServer: jest.fn(),
mcpConfigs: {},
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': { env: { API_KEY: 'test-key' } },
},
});
const existingTools = {
'old-tool_mcp_test-server': { type: 'function' },
'other-tool_mcp_other-server': { type: 'function' },
};
getCachedTools.mockResolvedValue(existingTools);
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
expect(response.body).toEqual({
message: "MCP server 'test-server' reinitialized successfully",
success: true,
oauthRequired: false,
oauthUrl: null,
serverName: 'test-server',
});
expect(setCachedTools).toHaveBeenCalledWith(
expect.objectContaining({
'new-tool_mcp_test-server': expect.any(Object),
'other-tool_mcp_other-server': { type: 'function' },
}),
{ userId: 'test-user-id' },
);
expect(setCachedTools).toHaveBeenCalledWith(
expect.not.objectContaining({
'old-tool_mcp_test-server': expect.anything(),
}),
{ userId: 'test-user-id' },
);
});
});
describe('GET /:serverName/oauth/callback - Edge Cases', () => { describe('GET /:serverName/oauth/callback - Edge Cases', () => {
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => { it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
const { MCPOAuthHandler } = require('@librechat/api'); const { MCPOAuthHandler } = require('@librechat/api');

View file

@ -1,11 +1,11 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api'); const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys, Constants } = require('librechat-data-provider'); const { Router } = require('express');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { setCachedTools, getCachedTools } = require('~/server/services/Config');
const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth } = require('~/server/middleware');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
@ -315,9 +315,9 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
const printConfig = false; const mcpManager = getMCPManager();
const config = await loadCustomConfig(printConfig); const serverConfig = mcpManager.getRawConfig(serverName);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) { if (!serverConfig) {
return res.status(404).json({ return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`, error: `MCP server '${serverName}' not found in configuration`,
}); });
@ -325,13 +325,12 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
const flowsCache = getLogStores(CacheKeys.FLOWS); const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache); const flowManager = getFlowStateManager(flowsCache);
const mcpManager = getMCPManager();
await mcpManager.disconnectServer(serverName); await mcpManager.disconnectUserConnection(user.id, serverName);
logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`); logger.info(
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
);
const serverConfig = config.mcpServers[serverName];
mcpManager.mcpConfigs[serverName] = serverConfig;
let customUserVars = {}; let customUserVars = {};
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
for (const varName of Object.keys(serverConfig.customUserVars)) { for (const varName of Object.keys(serverConfig.customUserVars)) {
@ -437,7 +436,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
}; };
res.json({ res.json({
success: (userConnection && !oauthRequired) || (oauthRequired && oauthUrl), success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(), message: getResponseMessage(),
serverName, serverName,
oauthRequired, oauthRequired,
@ -551,15 +550,14 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
return res.status(401).json({ error: 'User not authenticated' }); return res.status(401).json({ error: 'User not authenticated' });
} }
const printConfig = false; const mcpManager = getMCPManager();
const config = await loadCustomConfig(printConfig); const serverConfig = mcpManager.getRawConfig(serverName);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) { if (!serverConfig) {
return res.status(404).json({ return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`, error: `MCP server '${serverName}' not found in configuration`,
}); });
} }
const serverConfig = config.mcpServers[serverName];
const pluginKey = `${Constants.mcp_prefix}${serverName}`; const pluginKey = `${Constants.mcp_prefix}${serverName}`;
const authValueFlags = {}; const authValueFlags = {};

View file

@ -1,8 +1,7 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, setCachedTools } = require('./Config'); const { getCachedTools, setCachedTools } = require('./Config');
const { CacheKeys } = require('librechat-data-provider');
const { createMCPManager } = require('~/config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
/** /**
@ -31,33 +30,19 @@ async function initializeMCPs(app) {
} }
logger.info('Initializing MCP servers...'); logger.info('Initializing MCP servers...');
const mcpManager = getMCPManager(); const mcpManager = await createMCPManager(mcpServers);
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
try { try {
await mcpManager.initializeMCPs({
mcpServers: filteredServers,
flowManager,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
});
delete app.locals.mcpConfig; delete app.locals.mcpConfig;
const availableTools = await getCachedTools(); const cachedTools = await getCachedTools();
if (!availableTools) { if (!cachedTools) {
logger.warn('No available tools found in cache during MCP initialization'); logger.warn('No available tools found in cache during MCP initialization');
return; return;
} }
const toolsCopy = { ...availableTools }; const mcpTools = mcpManager.getAppToolFunctions();
await mcpManager.mapAvailableTools(toolsCopy, flowManager); await setCachedTools({ ...cachedTools, ...mcpTools }, { isGlobal: true });
await setCachedTools(toolsCopy, { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS); await cache.delete(CacheKeys.TOOLS);

View file

@ -2,11 +2,11 @@ import { fileURLToPath } from 'node:url';
import path from 'node:path'; import path from 'node:path';
import typescriptEslintEslintPlugin from '@typescript-eslint/eslint-plugin'; import typescriptEslintEslintPlugin from '@typescript-eslint/eslint-plugin';
import { fixupConfigRules, fixupPluginRules } from '@eslint/compat'; import { fixupConfigRules, fixupPluginRules } from '@eslint/compat';
// import perfectionist from 'eslint-plugin-perfectionist'; import perfectionist from 'eslint-plugin-perfectionist';
import reactHooks from 'eslint-plugin-react-hooks'; import reactHooks from 'eslint-plugin-react-hooks';
import prettier from 'eslint-plugin-prettier';
import tsParser from '@typescript-eslint/parser'; import tsParser from '@typescript-eslint/parser';
import importPlugin from 'eslint-plugin-import'; import importPlugin from 'eslint-plugin-import';
import prettier from 'eslint-plugin-prettier';
import { FlatCompat } from '@eslint/eslintrc'; import { FlatCompat } from '@eslint/eslintrc';
import jsxA11Y from 'eslint-plugin-jsx-a11y'; import jsxA11Y from 'eslint-plugin-jsx-a11y';
import i18next from 'eslint-plugin-i18next'; import i18next from 'eslint-plugin-i18next';
@ -62,7 +62,7 @@ export default [
'jsx-a11y': fixupPluginRules(jsxA11Y), 'jsx-a11y': fixupPluginRules(jsxA11Y),
'import/parsers': tsParser, 'import/parsers': tsParser,
i18next, i18next,
// perfectionist, perfectionist,
prettier: fixupPluginRules(prettier), prettier: fixupPluginRules(prettier),
}, },
@ -140,32 +140,31 @@ export default [
'react/prop-types': 'off', 'react/prop-types': 'off',
'react/display-name': 'off', 'react/display-name': 'off',
// 'perfectionist/sort-imports': [ 'perfectionist/sort-imports': [
// 'error', 'error',
// { {
// type: 'line-length', type: 'line-length',
// order: 'desc', order: 'desc',
// newlinesBetween: 'never', newlinesBetween: 'never',
// customGroups: { customGroups: {
// value: { value: {
// react: ['^react$'], react: ['^react$'],
// // react: ['^react$', '^fs', '^zod', '^path'], local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'],
// local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'], },
// }, },
// }, groups: [
// groups: [ 'react',
// 'react', 'builtin',
// 'builtin', 'external',
// 'external', ['builtin-type', 'external-type'],
// ['builtin-type', 'external-type'], ['internal-type'],
// ['internal-type'], 'local',
// 'local', ['parent', 'sibling', 'index'],
// ['parent', 'sibling', 'index'], 'object',
// 'object', 'unknown',
// 'unknown', ],
// ], },
// }, ],
// ],
// 'perfectionist/sort-named-imports': [ // 'perfectionist/sort-named-imports': [
// 'error', // 'error',

View file

@ -1,7 +1,7 @@
export default { export default {
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'], collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'], coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
testPathIgnorePatterns: ['/node_modules/', '/dist/'], testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'],
coverageReporters: ['text', 'cobertura'], coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit', testResultsProcessor: 'jest-junit',
moduleNameMapper: { moduleNameMapper: {

View file

@ -1,5 +1,5 @@
/* MCP */ /* MCP */
export * from './mcp/manager'; export * from './mcp/MCPManager';
export * from './mcp/oauth'; export * from './mcp/oauth';
export * from './mcp/auth'; export * from './mcp/auth';
export * from './mcp/zod'; export * from './mcp/zod';

View file

@ -0,0 +1,87 @@
import { logger } from '@librechat/data-schemas';
import { MCPConnectionFactory, OAuthConnectionOptions } from '~/mcp/MCPConnectionFactory';
import { MCPConnection } from './connection';
import type * as t from './types';
/**
* Manages MCP connections with lazy loading and reconnection.
* Maintains a pool of connections and handles connection lifecycle management.
*/
export class ConnectionsRepository {
protected readonly serverConfigs: Record<string, t.MCPOptions>;
protected connections: Map<string, MCPConnection> = new Map();
protected oauthOpts: OAuthConnectionOptions | undefined;
constructor(serverConfigs: t.MCPServers, oauthOpts?: OAuthConnectionOptions) {
this.serverConfigs = serverConfigs;
this.oauthOpts = oauthOpts;
}
/** Checks whether this repository can connect to a specific server */
has(serverName: string): boolean {
return !!this.serverConfigs[serverName];
}
/** Gets or creates a connection for the specified server with lazy loading */
async get(serverName: string): Promise<MCPConnection> {
const existingConnection = this.connections.get(serverName);
if (existingConnection && (await existingConnection.isConnected())) return existingConnection;
else await this.disconnect(serverName);
const connection = await MCPConnectionFactory.create(
{
serverName,
serverConfig: this.getServerConfig(serverName),
},
this.oauthOpts,
);
this.connections.set(serverName, connection);
return connection;
}
/** Gets or creates connections for multiple servers concurrently */
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
const connections = await Promise.all(connectionPromises);
return new Map(connections as [string, MCPConnection][]);
}
/** Returns all currently loaded connections without creating new ones */
async getLoaded(): Promise<Map<string, MCPConnection>> {
return this.getMany(Array.from(this.connections.keys()));
}
/** Gets or creates connections for all configured servers */
async getAll(): Promise<Map<string, MCPConnection>> {
return this.getMany(Object.keys(this.serverConfigs));
}
/** Disconnects and removes a specific server connection from the pool */
disconnect(serverName: string): Promise<void> {
const connection = this.connections.get(serverName);
if (!connection) return Promise.resolve();
this.connections.delete(serverName);
return connection.disconnect().catch((err) => {
logger.error(`${this.prefix(serverName)} Error disconnecting`, err);
});
}
/** Disconnects all active connections and returns array of disconnect promises */
disconnectAll(): Promise<void>[] {
const serverNames = Array.from(this.connections.keys());
return serverNames.map((serverName) => this.disconnect(serverName));
}
// Retrieves server configuration by name or throws if not found
protected getServerConfig(serverName: string): t.MCPOptions {
const serverConfig = this.serverConfigs[serverName];
if (serverConfig) return serverConfig;
throw new Error(`${this.prefix(serverName)} Server not found in configuration`);
}
// Returns formatted log prefix for server messages
protected prefix(serverName: string): string {
return `[MCP][${serverName}]`;
}
}

View file

@ -0,0 +1,384 @@
import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { MCPOAuthTokens, MCPOAuthFlowMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
import type * as t from './types';
export interface BasicConnectionOptions {
serverName: string;
serverConfig: t.MCPOptions;
}
export interface OAuthConnectionOptions {
useOAuth: true;
user: TUser;
customUserVars?: Record<string, string>;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
tokenMethods?: TokenMethods;
signal?: AbortSignal;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
returnOnOAuth?: boolean;
}
/**
* Factory for creating MCP connections with optional OAuth authentication.
* Handles OAuth flows, token management, and connection retry logic.
* NOTE: Much of the OAuth logic was extracted from the old MCPManager class as is.
*/
export class MCPConnectionFactory {
protected readonly serverName: string;
protected readonly serverConfig: t.MCPOptions;
protected readonly logPrefix: string;
protected readonly useOAuth: boolean;
// OAuth-related properties (only set when useOAuth is true)
protected readonly userId?: string;
protected readonly flowManager?: FlowStateManager<MCPOAuthTokens | null>;
protected readonly tokenMethods?: TokenMethods;
protected readonly signal?: AbortSignal;
protected readonly oauthStart?: (authURL: string) => Promise<void>;
protected readonly oauthEnd?: () => Promise<void>;
protected readonly returnOnOAuth?: boolean;
/** Creates a new MCP connection with optional OAuth support */
static async create(
basic: BasicConnectionOptions,
oauth?: OAuthConnectionOptions,
): Promise<MCPConnection> {
const factory = new this(basic, oauth);
return factory.createConnection();
}
protected constructor(basic: BasicConnectionOptions, oauth?: OAuthConnectionOptions) {
this.serverConfig = processMCPEnv(basic.serverConfig, oauth?.user, oauth?.customUserVars);
this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth;
this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]`
: `[MCP][${basic.serverName}]`;
if (oauth?.useOAuth) {
this.userId = oauth.user.id;
this.flowManager = oauth.flowManager;
this.tokenMethods = oauth.tokenMethods;
this.signal = oauth.signal;
this.oauthStart = oauth.oauthStart;
this.oauthEnd = oauth.oauthEnd;
this.returnOnOAuth = oauth.returnOnOAuth;
}
}
/** Creates the base MCP connection with OAuth tokens */
protected async createConnection(): Promise<MCPConnection> {
const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null;
const connection = new MCPConnection({
serverName: this.serverName,
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
});
if (this.useOAuth) this.handleOAuthEvents(connection);
await this.attemptToConnect(connection);
return connection;
}
/** Retrieves existing OAuth tokens from storage or returns null */
protected async getOAuthTokens(): Promise<MCPOAuthTokens | null> {
if (!this.tokenMethods?.findToken) return null;
try {
const tokens = await this.flowManager!.createFlowWithHandler(
`tokens:${this.userId}:${this.serverName}`,
'mcp_get_tokens',
async () => {
return await MCPTokenStorage.getTokens({
userId: this.userId!,
serverName: this.serverName,
findToken: this.tokenMethods!.findToken!,
createToken: this.tokenMethods!.createToken,
updateToken: this.tokenMethods!.updateToken,
refreshTokens: this.createRefreshTokensFunction(),
});
},
this.signal,
);
if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
return tokens;
} catch (error) {
logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
return null;
}
}
/** Creates a function to refresh OAuth tokens when they expire */
protected createRefreshTokensFunction(): (
refreshToken: string,
metadata: {
userId: string;
serverName: string;
identifier: string;
clientInfo?: OAuthClientInformation;
},
) => Promise<MCPOAuthTokens> {
return async (refreshToken, metadata) => {
return await MCPOAuthHandler.refreshOAuthTokens(
refreshToken,
{
serverUrl: (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url,
serverName: metadata.serverName,
clientInfo: metadata.clientInfo,
},
this.serverConfig.oauth,
);
};
}
/** Sets up OAuth event handlers for the connection */
protected handleOAuthEvents(connection: MCPConnection): void {
connection.on('oauthRequired', async (data) => {
logger.info(`${this.logPrefix} oauthRequired event received`);
// If we just want to initiate OAuth and return, handle it differently
if (this.returnOnOAuth) {
try {
const config = this.serverConfig;
const { authorizationUrl, flowId, flowMetadata } =
await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth,
);
// Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here
// which is fine - we just need the flow state to exist
});
if (this.oauthStart) {
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
await this.oauthStart(authorizationUrl);
}
// Emit oauthFailed to signal that connection should not proceed
// but OAuth was successfully initiated
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
} catch (error) {
logger.error(`${this.logPrefix} Failed to initiate OAuth flow`, error);
connection.emit('oauthFailed', new Error('OAuth initiation failed'));
return;
}
}
// Normal OAuth handling - wait for completion
const result = await this.handleOAuthRequired();
if (result?.tokens && this.tokenMethods?.createToken) {
try {
connection.setOAuthTokens(result.tokens);
await MCPTokenStorage.storeTokens({
userId: this.userId!,
serverName: this.serverName,
tokens: result.tokens,
createToken: this.tokenMethods.createToken,
updateToken: this.tokenMethods.updateToken,
findToken: this.tokenMethods.findToken,
clientInfo: result.clientInfo,
});
logger.info(`${this.logPrefix} OAuth tokens saved to storage`);
} catch (error) {
logger.error(`${this.logPrefix} Failed to save OAuth tokens to storage`, error);
}
}
// Only emit oauthHandled if we actually got tokens (OAuth succeeded)
if (result?.tokens) {
connection.emit('oauthHandled');
} else {
// OAuth failed, emit oauthFailed to properly reject the promise
logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`);
connection.emit('oauthFailed', new Error('OAuth authentication failed'));
}
});
}
/** Attempts to establish connection with timeout handling */
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
const connectTimeout = this.serverConfig.initTimeout ?? 30000;
const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout(
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
connectTimeout,
),
);
const connectionAttempt = this.connectTo(connection);
await Promise.race([connectionAttempt, connectionTimeout]);
if (await connection.isConnected()) return;
logger.error(`${this.logPrefix} Failed to establish connection.`);
}
// Handles connection attempts with retry logic and OAuth error handling
private async connectTo(connection: MCPConnection): Promise<void> {
const maxAttempts = 3;
let attempts = 0;
let oauthHandled = false;
while (attempts < maxAttempts) {
try {
await connection.connect();
if (await connection.isConnected()) {
return;
}
throw new Error('Connection attempt succeeded but status is not connected');
} catch (error) {
attempts++;
if (this.useOAuth && this.isOAuthError(error)) {
// Only handle OAuth if this is a user connection (has oauthStart handler)
if (this.oauthStart && !oauthHandled) {
const errorWithFlag = error as (Error & { isOAuthError?: boolean }) | undefined;
if (errorWithFlag?.isOAuthError) {
oauthHandled = true;
logger.info(`${this.logPrefix} Handling OAuth`);
await this.handleOAuthRequired();
}
}
// Don't retry on OAuth errors - just throw
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
throw error;
}
if (attempts === maxAttempts) {
logger.error(`${this.logPrefix} Failed to connect after ${maxAttempts} attempts`, error);
throw error;
}
await new Promise((resolve) => setTimeout(resolve, 2000 * attempts));
}
}
}
// Determines if an error indicates OAuth authentication is required
private isOAuthError(error: unknown): boolean {
if (!error || typeof error !== 'object') {
return false;
}
// Check for SSE error with 401 status
if ('message' in error && typeof error.message === 'string') {
return error.message.includes('401') || error.message.includes('Non-200 status code (401)');
}
// Check for error code
if ('code' in error) {
const code = (error as { code?: number }).code;
return code === 401 || code === 403;
}
return false;
}
/** Manages OAuth flow initiation and completion */
protected async handleOAuthRequired(): Promise<{
tokens: MCPOAuthTokens | null;
clientInfo?: OAuthClientInformation;
} | null> {
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
if (!this.flowManager || !serverUrl) {
logger.error(
`${this.logPrefix} OAuth required but flow manager not available or server URL missing for ${this.serverName}`,
);
logger.warn(`${this.logPrefix} Please configure OAuth credentials for ${this.serverName}`);
return null;
}
try {
logger.debug(`${this.logPrefix} Checking for existing OAuth flow for ${this.serverName}...`);
/** Flow ID to check if a flow already exists */
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
/** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.status === 'PENDING') {
logger.debug(
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
);
/** Tokens from existing flow to complete */
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth');
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(
`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`,
);
/** Client information from the existing flow metadata */
const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
const clientInfo = existingMetadata?.clientInfo;
return { tokens, clientInfo };
}
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
const {
authorizationUrl,
flowId: newFlowId,
flowMetadata,
} = await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
serverUrl,
this.userId!,
this.serverConfig.oauth,
);
if (typeof this.oauthStart === 'function') {
logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
await this.oauthStart(authorizationUrl);
} else {
logger.info(`
Please visit the following URL to authenticate:
${authorizationUrl}
${this.logPrefix} Flow ID: ${newFlowId}
`);
}
/** Tokens from the new flow */
const tokens = await this.flowManager.createFlow(
newFlowId,
'mcp_oauth',
flowMetadata as FlowMetadata,
);
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`);
/** Client information from the flow metadata */
const clientInfo = flowMetadata?.clientInfo;
return { tokens, clientInfo };
} catch (error) {
logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error);
return null;
}
}
}

View file

@ -0,0 +1,263 @@
import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { logger } from '@librechat/data-schemas';
import pick from 'lodash/pick';
import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { UserConnectionManager } from '~/mcp/UserConnectionManager';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { CONSTANTS } from './enum';
import type * as t from './types';
/**
* Centralized manager for MCP server connections and tool execution.
* Extends UserConnectionManager to handle both app-level and user-specific connections.
*/
export class MCPManager extends UserConnectionManager {
private static instance: MCPManager | null;
// Connections shared by all users.
private appConnections: ConnectionsRepository | null = null;
/** Creates and initializes the singleton MCPManager instance */
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
if (MCPManager.instance) throw new Error('MCPManager has already been initialized.');
MCPManager.instance = new MCPManager(configs);
await MCPManager.instance.initialize();
return MCPManager.instance;
}
/** Returns the singleton MCPManager instance */
public static getInstance(): MCPManager {
if (!MCPManager.instance) throw new Error('MCPManager has not been initialized.');
return MCPManager.instance;
}
/** Initializes the MCPManager by setting up server registry and app connections */
public async initialize() {
await this.serversRegistry.initialize();
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!);
}
/** Returns all app-level connections */
public async getAllConnections(): Promise<Map<string, MCPConnection>> {
return this.appConnections!.getAll();
}
/** Get servers that require OAuth */
public getOAuthServers(): Set<string> {
return this.serversRegistry.oauthServers!;
}
/** Returns all available tool functions from app-level connections */
public getAppToolFunctions(): t.LCAvailableTools {
return this.serversRegistry.toolFunctions!;
}
/**
* Get instructions for MCP servers
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
* @returns Object mapping server names to their instructions
*/
public getInstructions(serverNames?: string[]): Record<string, string> {
const instructions = this.serversRegistry.serverInstructions!;
if (!serverNames) return instructions;
return pick(instructions, serverNames);
}
/**
* Format MCP server instructions for injection into context
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
* @returns Formatted instructions string ready for context injection
*/
public formatInstructionsForContext(serverNames?: string[]): string {
/** Instructions for specified servers or all stored instructions */
const instructionsToInclude = this.getInstructions(serverNames);
if (Object.keys(instructionsToInclude).length === 0) {
return '';
}
// Format instructions for context injection
const formattedInstructions = Object.entries(instructionsToInclude)
.map(([serverName, instructions]) => {
return `## ${serverName} MCP Server Instructions
${instructions}`;
})
.join('\n\n');
return `# MCP Server Instructions
The following MCP servers are available with their specific instructions:
${formattedInstructions}
Please follow these instructions when using tools from the respective MCP servers.`;
}
/** Loads tools from all app-level connections into the manifest. */
public async loadManifestTools({
serverToolsCallback,
getServerTools,
}: {
flowManager: FlowStateManager<MCPOAuthTokens | null>;
serverToolsCallback?: (serverName: string, tools: t.LCManifestTool[]) => Promise<void>;
getServerTools?: (serverName: string) => Promise<t.LCManifestTool[] | undefined>;
}): Promise<t.LCToolManifest> {
const mcpTools: t.LCManifestTool[] = [];
const connections = await this.appConnections!.getAll();
for (const [serverName, connection] of connections.entries()) {
try {
if (!(await connection.isConnected())) {
logger.warn(
`[MCP][${serverName}] Connection not available for ${serverName} manifest tools.`,
);
if (typeof getServerTools !== 'function') {
logger.warn(
`[MCP][${serverName}] No \`getServerTools\` function provided, skipping tool loading.`,
);
continue;
}
const serverTools = await getServerTools(serverName);
if (serverTools && serverTools.length > 0) {
logger.info(`[MCP][${serverName}] Loaded tools from cache for manifest`);
mcpTools.push(...serverTools);
}
continue;
}
const tools = await connection.fetchTools();
const serverTools: t.LCManifestTool[] = [];
for (const tool of tools) {
const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
const config = this.serversRegistry.parsedConfigs[serverName];
const manifestTool: t.LCManifestTool = {
name: tool.name,
pluginKey,
description: tool.description ?? '',
icon: connection.iconPath,
authConfig: config?.customUserVars
? Object.entries(config.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}))
: undefined,
};
if (config?.chatMenu === false) {
manifestTool.chatMenu = false;
}
mcpTools.push(manifestTool);
serverTools.push(manifestTool);
}
if (typeof serverToolsCallback === 'function') {
await serverToolsCallback(serverName, serverTools);
}
} catch (error) {
logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error);
}
}
return mcpTools;
}
/**
* Calls a tool on an MCP server, using either a user-specific connection
* (if userId is provided) or an app-level connection. Updates the last activity timestamp
* for user-specific connections upon successful call initiation.
*/
async callTool({
user,
serverName,
toolName,
provider,
toolArguments,
options,
tokenMethods,
flowManager,
oauthStart,
oauthEnd,
customUserVars,
}: {
user?: TUser;
serverName: string;
toolName: string;
provider: t.Provider;
toolArguments?: Record<string, unknown>;
options?: RequestOptions;
tokenMethods?: TokenMethods;
customUserVars?: Record<string, string>;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
}): Promise<t.FormattedToolResponse> {
/** User-specific connection */
let connection: MCPConnection | undefined;
const userId = user?.id;
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
try {
if (!this.appConnections?.has(serverName) && userId && user) {
this.updateUserLastActivity(userId);
/** Get or create user-specific connection */
connection = await this.getUserConnection({
user,
serverName,
flowManager,
tokenMethods,
oauthStart,
oauthEnd,
signal: options?.signal,
customUserVars,
});
} else {
/** App-level connection */
connection = await this.appConnections!.get(serverName);
if (!connection) {
throw new McpError(
ErrorCode.InvalidRequest,
`${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`,
);
}
}
if (!(await connection.isConnected())) {
/** May happen if getUserConnection failed silently or app connection dropped */
throw new McpError(
ErrorCode.InternalError, // Use InternalError for connection issues
`${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`,
);
}
const result = await connection.client.request(
{
method: 'tools/call',
params: {
name: toolName,
arguments: toolArguments,
},
},
CallToolResultSchema,
{
timeout: connection.timeout,
...options,
},
);
if (userId) {
this.updateUserLastActivity(userId);
}
this.checkIdleConnections();
return formatToolContent(result as t.MCPToolCallResponse, provider);
} catch (error) {
// Log with context and re-throw or handle as needed
logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
// Rethrowing allows the caller (createMCPTool) to handle the final user message
throw error;
}
}
}

View file

@ -0,0 +1,200 @@
import { logger } from '@librechat/data-schemas';
import mapValues from 'lodash/mapValues';
import pickBy from 'lodash/pickBy';
import pick from 'lodash/pick';
import type { JsonSchemaType } from '~/types';
import type * as t from '~/mcp/types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { type MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
import { CONSTANTS } from '~/mcp/enum';
type ParsedServerConfig = t.MCPOptions & {
url?: string;
requiresOAuth?: boolean;
oauthMetadata?: Record<string, unknown> | null;
capabilities?: string;
tools?: string;
};
/**
* Manages MCP server configurations and metadata discovery.
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
* Determines which servers are for app-level connections.
* Has its own connections repository. All connections are disconnected after initialization.
*/
export class MCPServersRegistry {
private initialized: boolean = false;
private connections: ConnectionsRepository;
public readonly rawConfigs: t.MCPServers;
public readonly parsedConfigs: Record<string, ParsedServerConfig>;
public oauthServers: Set<string> | null = null;
public serverInstructions: Record<string, string> | null = null;
public toolFunctions: t.LCAvailableTools | null = null;
public appServerConfigs: t.MCPServers | null = null;
constructor(configs: t.MCPServers) {
this.rawConfigs = configs;
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv(con));
this.connections = new ConnectionsRepository(configs);
}
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
public async initialize() {
if (this.initialized) return;
this.initialized = true;
const serverNames = Object.keys(this.parsedConfigs);
await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName)));
this.setOAuthServers();
this.setServerInstructions();
this.setAppServerConfigs();
await this.setAppToolFunctions();
this.connections.disconnectAll();
}
// Fetches all metadata for a single server in parallel
private async gatherServerInfo(serverName: string) {
try {
await Promise.allSettled([
this.fetchOAuthRequirement(serverName).catch((error) =>
logger.error(`${this.prefix(serverName)} Failed to fetch OAuth requirement:`, error),
),
this.fetchServerInstructions(serverName).catch((error) =>
logger.error(`${this.prefix(serverName)} Failed to fetch server instructions:`, error),
),
this.fetchServerCapabilities(serverName).catch((error) =>
logger.error(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error),
),
]);
this.logUpdatedConfig(serverName);
} catch (error) {
logger.error(`${this.prefix(serverName)} Failed to initialize server:`, error);
}
}
// Sets app-level server configs (startup enabled, non-OAuth servers)
private setAppServerConfigs() {
const appServers = Object.keys(
pickBy(
this.parsedConfigs,
(config) => config.startup !== false && config.requiresOAuth === false,
),
);
this.appServerConfigs = pick(this.rawConfigs, appServers);
}
// Creates set of server names that require OAuth authentication
private setOAuthServers() {
if (this.oauthServers) return this.oauthServers;
this.oauthServers = new Set(
Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)),
);
return this.oauthServers;
}
// Collects server instructions from all configured servers
private setServerInstructions() {
this.serverInstructions = mapValues(
pickBy(this.parsedConfigs, (config) => config.serverInstructions),
(config) => config.serverInstructions as string,
);
}
// Builds registry of all available tool functions from loaded connections
private async setAppToolFunctions() {
const connections = (await this.connections.getLoaded()).entries();
const allToolFunctions: t.LCAvailableTools = {};
for (const [serverName, conn] of connections) {
try {
const toolFunctions = await this.getToolFunctions(serverName, conn);
Object.assign(allToolFunctions, toolFunctions);
} catch (error) {
logger.error(`${this.prefix(serverName)} Error fetching tool functions:`, error);
}
}
this.toolFunctions = allToolFunctions;
}
// Converts server tools to LibreChat-compatible tool functions format
private async getToolFunctions(
serverName: string,
conn: MCPConnection,
): Promise<t.LCAvailableTools> {
const { tools } = await conn.client.listTools();
const toolFunctions: t.LCAvailableTools = {};
tools.forEach((tool) => {
const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
toolFunctions[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema as JsonSchemaType,
},
};
});
return toolFunctions;
}
// Determines if server requires OAuth if not already specified in the config
private async fetchOAuthRequirement(serverName: string) {
const config = this.parsedConfigs[serverName];
if (config.requiresOAuth != null) return;
if (config.url == null) return (config.requiresOAuth = false);
const result = await detectOAuthRequirement(config.url);
config.requiresOAuth = result.requiresOAuth;
config.oauthMetadata = result.metadata;
}
// Retrieves server instructions from MCP server if enabled in the config
private async fetchServerInstructions(serverName: string) {
const config = this.parsedConfigs[serverName];
if (!config.serverInstructions) return;
if (typeof config.serverInstructions === 'string') return;
const conn = await this.connections.get(serverName);
config.serverInstructions = conn.client.getInstructions();
if (!config.serverInstructions) {
logger.warn(`${this.prefix(serverName)} No server instructions available`);
}
}
// Fetches server capabilities and available tools list
private async fetchServerCapabilities(serverName: string) {
const config = this.parsedConfigs[serverName];
const conn = await this.connections.get(serverName);
const capabilities = conn.client.getServerCapabilities();
if (!capabilities) return;
config.capabilities = JSON.stringify(capabilities);
if (!capabilities.tools) return;
const tools = await conn.client.listTools();
config.tools = tools.tools.map((tool) => tool.name).join(', ');
}
// Logs server configuration summary after initialization
private logUpdatedConfig(serverName: string) {
const prefix = this.prefix(serverName);
const config = this.parsedConfigs[serverName];
logger.info(`${prefix} URL: ${config.url ?? 'N/A'}`);
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
logger.info(`${prefix} Tools: ${config.tools}`);
logger.info(`${prefix} Server Instructions: ${config.serverInstructions ?? 'None'}`);
}
// Returns formatted log prefix for server messages
private prefix(serverName: string): string {
return `[MCP][${serverName}]`;
}
}

View file

@ -0,0 +1,236 @@
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { MCPConnection } from './connection';
import type * as t from './types';
/**
* Abstract base class for managing user-specific MCP connections with lifecycle management.
* Only meant to be extended by MCPManager.
* Much of the logic was move here from the old MCPManager to make it more manageable.
* User connections will soon be ephemeral and not cached anymore:
* https://github.com/danny-avila/LibreChat/discussions/8790
*/
export abstract class UserConnectionManager {
protected readonly serversRegistry: MCPServersRegistry;
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
/** Last activity timestamp for users (not per server) */
protected userLastActivity: Map<string, number> = new Map();
protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
constructor(serverConfigs: t.MCPServers) {
this.serversRegistry = new MCPServersRegistry(serverConfigs);
}
/** fetches am MCP Server config from the registry */
public getRawConfig(serverName: string): t.MCPOptions | undefined {
return this.serversRegistry.rawConfigs[serverName];
}
/** Updates the last activity timestamp for a user */
protected updateUserLastActivity(userId: string): void {
const now = Date.now();
this.userLastActivity.set(userId, now);
logger.debug(
`[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`,
);
}
/** Gets or creates a connection for a specific user */
public async getUserConnection({
user,
serverName,
flowManager,
customUserVars,
tokenMethods,
oauthStart,
oauthEnd,
signal,
returnOnOAuth = false,
}: {
user: TUser;
serverName: string;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
customUserVars?: Record<string, string>;
tokenMethods?: TokenMethods;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
signal?: AbortSignal;
returnOnOAuth?: boolean;
}): Promise<MCPConnection> {
const userId = user.id;
if (!userId) {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
}
const userServerMap = this.userConnections.get(userId);
let connection = userServerMap?.get(serverName);
const now = Date.now();
// Check if user is idle
const lastActivity = this.userLastActivity.get(userId);
if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
logger.info(`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`);
// Disconnect all user connections
try {
await this.disconnectUserConnections(userId);
} catch (err) {
logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err);
}
connection = undefined; // Force creation of a new connection
} else if (connection) {
if (await connection.isConnected()) {
logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`);
this.updateUserLastActivity(userId);
return connection;
} else {
// Connection exists but is not connected, attempt to remove potentially stale entry
logger.warn(
`[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`,
);
this.removeUserConnection(userId, serverName); // Clean up maps
connection = undefined;
}
}
// If no valid connection exists, create a new one
if (!connection) {
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
}
const config = this.serversRegistry.parsedConfigs[serverName];
if (!config) {
throw new McpError(
ErrorCode.InvalidRequest,
`[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`,
);
}
try {
connection = await MCPConnectionFactory.create(
{
serverName: serverName,
serverConfig: config,
},
{
useOAuth: true,
user: user,
customUserVars: customUserVars,
flowManager: flowManager,
tokenMethods: tokenMethods,
signal: signal,
oauthStart: oauthStart,
oauthEnd: oauthEnd,
returnOnOAuth: returnOnOAuth,
},
);
if (!(await connection?.isConnected())) {
throw new Error('Failed to establish connection after initialization attempt.');
}
if (!this.userConnections.has(userId)) {
this.userConnections.set(userId, new Map());
}
this.userConnections.get(userId)?.set(serverName, connection);
logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
// Update timestamp on creation
this.updateUserLastActivity(userId);
return connection;
} catch (error) {
logger.error(`[MCP][User: ${userId}][${serverName}] Failed to establish connection`, error);
// Ensure partial connection state is cleaned up if initialization fails
await connection?.disconnect().catch((disconnectError) => {
logger.error(
`[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`,
disconnectError,
);
});
// Ensure cleanup even if connection attempt fails
this.removeUserConnection(userId, serverName);
throw error; // Re-throw the error to the caller
}
}
/** Returns all connections for a specific user */
public getUserConnections(userId: string) {
return this.userConnections.get(userId);
}
/** Removes a specific user connection entry */
protected removeUserConnection(userId: string, serverName: string): void {
const userMap = this.userConnections.get(userId);
if (userMap) {
userMap.delete(serverName);
if (userMap.size === 0) {
this.userConnections.delete(userId);
// Only remove user activity timestamp if all connections are gone
this.userLastActivity.delete(userId);
}
}
logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`);
}
/** Disconnects and removes a specific user connection */
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
const userMap = this.userConnections.get(userId);
const connection = userMap?.get(serverName);
if (connection) {
logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
await connection.disconnect();
this.removeUserConnection(userId, serverName);
}
}
/** Disconnects and removes all connections for a specific user */
public async disconnectUserConnections(userId: string): Promise<void> {
const userMap = this.userConnections.get(userId);
const disconnectPromises: Promise<void>[] = [];
if (userMap) {
logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`);
const userServers = Array.from(userMap.keys());
for (const serverName of userServers) {
disconnectPromises.push(
this.disconnectUserConnection(userId, serverName).catch((error) => {
logger.error(
`[MCP][User: ${userId}][${serverName}] Error during disconnection:`,
error,
);
}),
);
}
await Promise.allSettled(disconnectPromises);
// Ensure user activity timestamp is removed
this.userLastActivity.delete(userId);
logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
}
}
/** Check for and disconnect idle connections */
protected checkIdleConnections(currentUserId?: string): void {
const now = Date.now();
// Iterate through all users to check for idle ones
for (const [userId, lastActivity] of this.userLastActivity.entries()) {
if (currentUserId && currentUserId === userId) {
continue;
}
if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
logger.info(
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`,
);
// Disconnect all user connections asynchronously (fire and forget)
this.disconnectUserConnections(userId).catch((err) =>
logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err),
);
}
}
}
}

View file

@ -0,0 +1,212 @@
import { logger } from '@librechat/data-schemas';
import { ConnectionsRepository } from '../ConnectionsRepository';
import { MCPConnectionFactory } from '../MCPConnectionFactory';
import { MCPConnection } from '../connection';
import type * as t from '../types';
// Mock external dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
},
}));
jest.mock('../MCPConnectionFactory', () => ({
MCPConnectionFactory: {
create: jest.fn(),
},
}));
jest.mock('../connection');
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('ConnectionsRepository', () => {
let repository: ConnectionsRepository;
let mockServerConfigs: t.MCPServers;
let mockConnection: jest.Mocked<MCPConnection>;
beforeEach(() => {
mockServerConfigs = {
server1: { url: 'http://localhost:3001' },
server2: { command: 'test-command', args: ['--test'] },
server3: { url: 'ws://localhost:8080', type: 'websocket' },
};
mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
repository = new ConnectionsRepository(mockServerConfigs);
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('has', () => {
it('should return true for existing server', () => {
expect(repository.has('server1')).toBe(true);
});
it('should return false for non-existing server', () => {
expect(repository.has('nonexistent')).toBe(false);
});
});
describe('get', () => {
it('should return existing connected connection', async () => {
mockConnection.isConnected.mockResolvedValue(true);
repository['connections'].set('server1', mockConnection);
const result = await repository.get('server1');
expect(result).toBe(mockConnection);
expect(MCPConnectionFactory.create).not.toHaveBeenCalled();
});
it('should create new connection if none exists', async () => {
const result = await repository.get('server1');
expect(result).toBe(mockConnection);
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
},
undefined,
);
expect(repository['connections'].get('server1')).toBe(mockConnection);
});
it('should create new connection if existing connection is not connected', async () => {
const oldConnection = {
isConnected: jest.fn().mockResolvedValue(false),
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
repository['connections'].set('server1', oldConnection);
const result = await repository.get('server1');
expect(result).toBe(mockConnection);
expect(oldConnection.disconnect).toHaveBeenCalled();
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
},
undefined,
);
});
it('should throw error for non-existent server configuration', async () => {
await expect(repository.get('nonexistent')).rejects.toThrow(
'[MCP][nonexistent] Server not found in configuration',
);
});
it('should handle MCPConnectionFactory.create errors', async () => {
const createError = new Error('Connection creation failed');
(MCPConnectionFactory.create as jest.Mock).mockRejectedValue(createError);
await expect(repository.get('server1')).rejects.toThrow('Connection creation failed');
});
});
describe('getMany', () => {
it('should return connections for multiple servers', async () => {
const result = await repository.getMany(['server1', 'server3']);
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(2);
expect(result.get('server1')).toBe(mockConnection);
expect(result.get('server3')).toBe(mockConnection);
});
});
describe('getLoaded', () => {
it('should return connections for loaded servers only', async () => {
// Load one connection
await repository.get('server1');
const result = await repository.getLoaded();
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(1);
expect(result.get('server1')).toBe(mockConnection);
});
it('should return empty map when no connections are loaded', async () => {
const result = await repository.getLoaded();
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(0);
});
});
describe('getAll', () => {
it('should return connections for all configured servers', async () => {
const result = await repository.getAll();
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(3);
expect(result.get('server1')).toBe(mockConnection);
expect(result.get('server2')).toBe(mockConnection);
expect(result.get('server3')).toBe(mockConnection);
});
});
describe('disconnect', () => {
it('should disconnect and remove existing connection', async () => {
repository['connections'].set('server1', mockConnection);
await repository.disconnect('server1');
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(repository['connections'].has('server1')).toBe(false);
});
it('should handle disconnect error gracefully', async () => {
const disconnectError = new Error('Disconnect failed');
mockConnection.disconnect.mockRejectedValue(disconnectError);
repository['connections'].set('server1', mockConnection);
await repository.disconnect('server1');
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(repository['connections'].has('server1')).toBe(false);
expect(mockLogger.error).toHaveBeenCalledWith(
'[MCP][server1] Error disconnecting',
disconnectError,
);
});
});
describe('disconnectAll', () => {
it('should disconnect all active connections', () => {
const mockConnection1 = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
const mockConnection2 = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
const mockConnection3 = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
repository['connections'].set('server1', mockConnection1);
repository['connections'].set('server2', mockConnection2);
repository['connections'].set('server3', mockConnection3);
const promises = repository.disconnectAll();
expect(promises).toHaveLength(3);
expect(Array.isArray(promises)).toBe(true);
});
});
});

View file

@ -0,0 +1,347 @@
import { logger } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPConnectionFactory } from '../MCPConnectionFactory';
import { MCPOAuthHandler } from '~/mcp/oauth';
import { MCPConnection } from '../connection';
import { processMCPEnv } from '~/utils';
import type * as t from '../types';
jest.mock('../connection');
jest.mock('~/mcp/oauth');
jest.mock('~/utils');
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
const mockLogger = logger as jest.Mocked<typeof logger>;
const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction<typeof processMCPEnv>;
const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection>;
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
describe('MCPConnectionFactory', () => {
let mockUser: TUser;
let mockServerConfig: t.MCPOptions;
let mockFlowManager: jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
let mockConnectionInstance: jest.Mocked<MCPConnection>;
beforeEach(() => {
jest.clearAllMocks();
mockUser = {
id: 'user123',
email: 'test@example.com',
} as TUser;
mockServerConfig = {
command: 'node',
args: ['server.js'],
initTimeout: 5000,
} as t.MCPOptions;
mockFlowManager = {
createFlow: jest.fn(),
createFlowWithHandler: jest.fn(),
getFlowState: jest.fn(),
} as unknown as jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
mockConnectionInstance = {
connect: jest.fn(),
isConnected: jest.fn(),
setOAuthTokens: jest.fn(),
on: jest.fn().mockReturnValue(mockConnectionInstance),
emit: jest.fn(),
} as unknown as jest.Mocked<MCPConnection>;
mockMCPConnection.mockImplementation(() => mockConnectionInstance);
mockProcessMCPEnv.mockReturnValue(mockServerConfig);
});
describe('static create method', () => {
it('should create a basic connection without OAuth', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, undefined, undefined);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: undefined,
oauthTokens: null,
});
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
it('should create a connection with OAuth', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockTokens: MCPOAuthTokens = {
access_token: 'access123',
refresh_token: 'refresh123',
token_type: 'Bearer',
obtained_at: Date.now(),
};
mockFlowManager.createFlowWithHandler.mockResolvedValue(mockTokens);
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, mockUser, undefined);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: mockTokens,
});
});
});
describe('OAuth token handling', () => {
it('should return null when no findToken method is provided', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: undefined as unknown as () => Promise<any>,
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(mockFlowManager.createFlowWithHandler).not.toHaveBeenCalled();
});
it('should handle token retrieval errors gracefully', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockFlowManager.createFlowWithHandler.mockRejectedValue(new Error('Token fetch failed'));
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: null,
});
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('No existing tokens found or error loading tokens'),
expect.any(Error),
);
});
});
describe('OAuth event handling', () => {
it('should handle oauthRequired event for returnOnOAuth scenario', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
returnOnOAuth: true,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'flow123',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'random-state',
clientInfo: { client_id: 'client123' },
},
};
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail due to connection not established
}
expect(oauthRequiredHandler!).toBeDefined();
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockMCPOAuthHandler.initiateOAuthFlow).toHaveBeenCalledWith(
'test-server',
'https://api.example.com',
'user123',
undefined,
);
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({
message: 'OAuth flow initiated - return early',
}),
);
});
});
describe('connection retry logic', () => {
it('should establish connection successfully', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig, // Use default 5000ms timeout
};
mockConnectionInstance.connect.mockResolvedValue(undefined);
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockConnectionInstance.connect).toHaveBeenCalledTimes(1);
});
it('should handle OAuth errors during connection attempts', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const oauthError = new Error('Non-200 status code (401)');
(oauthError as unknown as Record<string, unknown>).isOAuthError = true;
mockConnectionInstance.connect.mockRejectedValue(oauthError);
mockConnectionInstance.isConnected.mockResolvedValue(false);
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow(
'Non-200 status code (401)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('OAuth required, stopping connection attempts'),
);
});
});
describe('isOAuthError method', () => {
it('should identify OAuth errors by message content', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const error401 = new Error('401 Unauthorized');
mockConnectionInstance.connect.mockRejectedValue(error401);
mockConnectionInstance.isConnected.mockResolvedValue(false);
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow('401');
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('OAuth required, stopping connection attempts'),
);
});
});
});

View file

@ -0,0 +1,287 @@
import { readFileSync } from 'fs';
import { join } from 'path';
import { logger } from '@librechat/data-schemas';
import { load as yamlLoad } from 'js-yaml';
import { ConnectionsRepository } from '../ConnectionsRepository';
import { MCPServersRegistry } from '../MCPServersRegistry';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { MCPConnection } from '../connection';
import type * as t from '../types';
// Mock external dependencies
jest.mock('../oauth/detectOAuth');
jest.mock('../ConnectionsRepository');
jest.mock('../connection');
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
// Mock processMCPEnv to verify it's called and adds a processed marker
jest.mock('~/utils', () => ({
...jest.requireActual('~/utils'),
processMCPEnv: jest.fn((config) => ({
...config,
_processed: true, // Simple marker to verify processing occurred
})),
}));
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
typeof detectOAuthRequirement
>;
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('MCPServersRegistry - Initialize Function', () => {
let rawConfigs: t.MCPServers;
let expectedParsedConfigs: Record<string, any>;
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
beforeEach(() => {
// Load fixtures
const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml');
const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml');
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
string,
any
>;
// Setup mock connections
mockConnections = new Map();
const serverNames = Object.keys(rawConfigs);
serverNames.forEach((serverName) => {
const mockConnection = {
client: {
listTools: jest.fn(),
getInstructions: jest.fn(),
getServerCapabilities: jest.fn(),
},
} as unknown as jest.Mocked<MCPConnection>;
// Setup mock responses based on expected configs
const expectedConfig = expectedParsedConfigs[serverName];
// Mock listTools response
if (expectedConfig.tools) {
const toolNames = expectedConfig.tools.split(', ');
const tools = toolNames.map((name: string) => ({
name,
description: `Description for ${name}`,
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
},
},
}));
mockConnection.client.listTools.mockResolvedValue({ tools });
} else {
mockConnection.client.listTools.mockResolvedValue({ tools: [] });
}
// Mock getInstructions response
if (expectedConfig.serverInstructions) {
mockConnection.client.getInstructions.mockReturnValue(expectedConfig.serverInstructions);
} else {
mockConnection.client.getInstructions.mockReturnValue(null);
}
// Mock getServerCapabilities response
if (expectedConfig.capabilities) {
const capabilities = JSON.parse(expectedConfig.capabilities);
mockConnection.client.getServerCapabilities.mockReturnValue(capabilities);
} else {
mockConnection.client.getServerCapabilities.mockReturnValue(null);
}
mockConnections.set(serverName, mockConnection);
});
// Setup ConnectionsRepository mock
mockConnectionsRepo = {
get: jest.fn(),
getLoaded: jest.fn(),
disconnectAll: jest.fn(),
} as unknown as jest.Mocked<ConnectionsRepository>;
mockConnectionsRepo.get.mockImplementation((serverName: string) =>
Promise.resolve(mockConnections.get(serverName)!),
);
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
(ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo);
// Setup OAuth detection mock with deterministic results
mockDetectOAuthRequirement.mockImplementation((url: string) => {
const oauthResults: Record<string, any> = {
'https://api.github.com/mcp': {
requiresOAuth: true,
metadata: {
authorization_url: 'https://github.com/login/oauth/authorize',
token_url: 'https://github.com/login/oauth/access_token',
},
},
'https://api.disabled.com/mcp': {
requiresOAuth: false,
metadata: null,
},
'https://api.public.com/mcp': {
requiresOAuth: false,
metadata: null,
},
};
return Promise.resolve(oauthResults[url] || { requiresOAuth: false, metadata: null });
});
// Clear all mocks
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('initialize() method', () => {
it('should only run initialization once', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
await registry.initialize(); // Second call should not re-run
// Verify that connections are only requested for servers that need them
// (servers with serverInstructions=true or all servers for capabilities)
expect(mockConnectionsRepo.get).toHaveBeenCalled();
});
it('should set all public properties correctly after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Verify initial state
expect(registry.oauthServers).toBeNull();
expect(registry.serverInstructions).toBeNull();
expect(registry.toolFunctions).toBeNull();
expect(registry.appServerConfigs).toBeNull();
await registry.initialize();
// Test oauthServers Set
expect(registry.oauthServers).toBeInstanceOf(Set);
expect(registry.oauthServers).toEqual(
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
);
// Test serverInstructions
expect(registry.serverInstructions).toEqual({
oauth_server: 'GitHub MCP server instructions',
stdio_server: 'Follow these instructions for stdio server',
non_oauth_server: 'Public API instructions',
});
// Test appServerConfigs (startup enabled, non-OAuth servers only)
expect(registry.appServerConfigs).toEqual({
stdio_server: rawConfigs.stdio_server,
websocket_server: rawConfigs.websocket_server,
non_oauth_server: rawConfigs.non_oauth_server,
});
// Test toolFunctions (only 2 servers have tools: oauth_server has 1, stdio_server has 2)
const expectedToolFunctions = {
get_repository_mcp_oauth_server: {
type: 'function',
function: {
name: 'get_repository_mcp_oauth_server',
description: 'Description for get_repository',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
file_read_mcp_stdio_server: {
type: 'function',
function: {
name: 'file_read_mcp_stdio_server',
description: 'Description for file_read',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
file_write_mcp_stdio_server: {
type: 'function',
function: {
name: 'file_write_mcp_stdio_server',
description: 'Description for file_write',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
};
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
});
it('should handle errors gracefully and continue initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Make one server throw an error
mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed'));
await registry.initialize();
// Should still initialize successfully
expect(registry.oauthServers).toBeInstanceOf(Set);
expect(registry.toolFunctions).toBeDefined();
// Error should be logged
expect(mockLogger.error).toHaveBeenCalledWith(
expect.stringContaining('[MCP][oauth_server] Failed to fetch OAuth requirement:'),
expect.any(Error),
);
});
it('should disconnect all connections after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1);
});
it('should log configuration updates for each server', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
const serverNames = Object.keys(rawConfigs);
serverNames.forEach((serverName) => {
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] URL:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] OAuth Required:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Capabilities:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Tools:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Server Instructions:`),
);
});
});
it('should have parsedConfigs matching the expected fixture after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
// Compare the actual parsedConfigs against the expected fixture
expect(registry.parsedConfigs).toEqual(expectedParsedConfigs);
});
});
});

View file

@ -1,7 +1,7 @@
import type { PluginAuthMethods } from '@librechat/data-schemas'; import type { PluginAuthMethods } from '@librechat/data-schemas';
import type { GenericTool } from '@librechat/agents'; import type { GenericTool } from '@librechat/agents';
import { getPluginAuthMap } from '~/agents/auth'; import { getPluginAuthMap } from '~/agents/auth';
import { getUserMCPAuthMap } from './auth'; import { getUserMCPAuthMap } from '../auth';
jest.mock('~/agents/auth', () => ({ jest.mock('~/agents/auth', () => ({
getPluginAuthMap: jest.fn(), getPluginAuthMap: jest.fn(),

View file

@ -0,0 +1,76 @@
// Integration tests for OAuth detection against real public MCP servers
// These tests verify the actual behavior against live endpoints
//
// DEVELOPMENT ONLY: This file is excluded from the test suite (.dev.ts extension)
// Use this for development and debugging OAuth detection behavior
//
// To run manually from packages/api directory:
// npx jest --testMatch="**/detectOAuth.integration.dev.ts"
import { detectOAuthRequirement } from '~/mcp/oauth';
describe('OAuth Detection Integration Tests', () => {
const NETWORK_TIMEOUT = 10000;
interface TestServer {
name: string;
url: string;
expectedOAuth: boolean;
expectedMethod: string;
withMeta: boolean;
}
const testServers: TestServer[] = [
{
name: 'GitHub Copilot MCP Server',
url: 'https://api.githubcopilot.com/mcp',
expectedOAuth: true,
expectedMethod: '401-challenge-metadata',
withMeta: true,
},
{
name: 'GitHub API (401 without metadata)',
url: 'https://api.github.com/user',
expectedOAuth: true,
expectedMethod: 'no-metadata-found',
withMeta: false,
},
{
name: 'Stytch Todo MCP Server',
url: 'https://mcp-stytch-consumer-todo-list.maxwell-gerber42.workers.dev',
expectedOAuth: true,
expectedMethod: 'protected-resource-metadata',
withMeta: true,
},
{
name: 'HTTPBin (Non-OAuth)',
url: 'https://httpbin.org',
expectedOAuth: false,
expectedMethod: 'no-metadata-found',
withMeta: false,
},
{
name: 'Unreachable Server',
url: 'https://definitely-not-a-real-server-12345.com',
expectedOAuth: false,
expectedMethod: 'no-metadata-found',
withMeta: false,
},
];
describe('detectOAuthRequirement integration', () => {
testServers.forEach((server) => {
it(
`should handle ${server.name}`,
async () => {
const result = await detectOAuthRequirement(server.url);
expect(result.requiresOAuth).toBe(server.expectedOAuth);
expect(result.method).toBe(server.expectedMethod);
expect(result.metadata == null).toBe(!server.withMeta);
},
NETWORK_TIMEOUT,
);
});
});
});

View file

@ -0,0 +1,74 @@
# Expected parsed MCP server configurations after running initialize()
# These represent the expected state of parsedConfigs after all fetch functions complete
oauth_server:
_processed: true
type: "streamable-http"
url: "https://api.github.com/mcp"
headers:
Authorization: "Bearer {{GITHUB_TOKEN}}"
serverInstructions: "GitHub MCP server instructions"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://github.com/login/oauth/authorize"
token_url: "https://github.com/login/oauth/access_token"
capabilities: '{"tools":{"listChanged":true},"resources":{},"prompts":{}}'
tools: "get_repository"
oauth_predefined:
_processed: true
type: "sse"
url: "https://api.example.com/sse"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://example.com/oauth/authorize"
token_url: "https://example.com/oauth/token"
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
stdio_server:
_processed: true
command: "node"
args: ["server.js"]
env:
API_KEY: "${TEST_API_KEY}"
startup: true
serverInstructions: "Follow these instructions for stdio server"
requiresOAuth: false
capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}'
tools: "file_read, file_write"
websocket_server:
_processed: true
type: "websocket"
url: "ws://localhost:3001/mcp"
startup: true
requiresOAuth: false
oauthMetadata: null
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
disabled_server:
_processed: true
type: "streamable-http"
url: "https://api.disabled.com/mcp"
startup: false
requiresOAuth: false
oauthMetadata: null
non_oauth_server:
_processed: true
type: "streamable-http"
url: "https://api.public.com/mcp"
requiresOAuth: false
serverInstructions: "Public API instructions"
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
oauth_startup_enabled:
_processed: true
type: "sse"
url: "https://api.oauth-startup.com/sse"
requiresOAuth: true
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""

View file

@ -0,0 +1,53 @@
# Raw MCP server configurations used as input to MCPServersRegistry constructor
# These configs test different code paths in the initialization process
# Test OAuth detection with URL - should trigger fetchOAuthRequirement
oauth_server:
type: "streamable-http"
url: "https://api.github.com/mcp"
headers:
Authorization: "Bearer {{GITHUB_TOKEN}}"
serverInstructions: true
# Test OAuth already specified - should skip OAuth detection
oauth_predefined:
type: "sse"
url: "https://api.example.com/sse"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://example.com/oauth/authorize"
token_url: "https://example.com/oauth/token"
# Test stdio server without URL - should set requiresOAuth to false
stdio_server:
command: "node"
args: ["server.js"]
env:
API_KEY: "${TEST_API_KEY}"
startup: true
serverInstructions: "Follow these instructions for stdio server"
# Test websocket server with capabilities but no tools
websocket_server:
type: "websocket"
url: "ws://localhost:3001/mcp"
startup: true
# Test server with startup disabled - should not be included in appServerConfigs
disabled_server:
type: "streamable-http"
url: "https://api.disabled.com/mcp"
startup: false
# Test non-OAuth server - should be included in appServerConfigs
non_oauth_server:
type: "streamable-http"
url: "https://api.public.com/mcp"
requiresOAuth: false
serverInstructions: true
# Test server with OAuth but startup enabled - should not be in appServerConfigs
oauth_startup_enabled:
type: "sse"
url: "https://api.oauth-startup.com/sse"
requiresOAuth: true

View file

@ -1,5 +1,5 @@
import { MCPOAuthHandler } from './handler';
import type { MCPOptions } from 'librechat-data-provider'; import type { MCPOptions } from 'librechat-data-provider';
import { MCPOAuthHandler } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({ jest.mock('@librechat/data-schemas', () => ({
logger: { logger: {

View file

@ -1,4 +1,4 @@
import { normalizeServerName } from './utils'; import { normalizeServerName } from '../utils';
describe('normalizeServerName', () => { describe('normalizeServerName', () => {
it('should not modify server names that already match the pattern', () => { it('should not modify server names that already match the pattern', () => {

View file

@ -2,7 +2,7 @@
// zod.spec.ts // zod.spec.ts
import { z } from 'zod'; import { z } from 'zod';
import type { JsonSchemaType } from '~/types'; import type { JsonSchemaType } from '~/types';
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from './zod'; import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from '../zod';
describe('convertJsonSchemaToZod', () => { describe('convertJsonSchemaToZod', () => {
describe('primitive types', () => { describe('primitive types', () => {

View file

@ -1,17 +1,18 @@
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import { logger } from '@librechat/data-schemas';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { import {
StdioClientTransport, StdioClientTransport,
getDefaultEnvironment, getDefaultEnvironment,
} from '@modelcontextprotocol/sdk/client/stdio.js'; } from '@modelcontextprotocol/sdk/client/stdio.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { logger } from '@librechat/data-schemas';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type { MCPOAuthTokens } from './oauth/types'; import type { MCPOAuthTokens } from './oauth/types';
import { mcpConfig } from './mcpConfig';
import type * as t from './types'; import type * as t from './types';
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions { function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
@ -56,9 +57,17 @@ function isStreamableHTTPOptions(options: t.MCPOptions): options is t.Streamable
} }
const FIVE_MINUTES = 5 * 60 * 1000; const FIVE_MINUTES = 5 * 60 * 1000;
interface MCPConnectionParams {
serverName: string;
serverConfig: t.MCPOptions;
userId?: string;
oauthTokens?: MCPOAuthTokens | null;
}
export class MCPConnection extends EventEmitter { export class MCPConnection extends EventEmitter {
private static instance: MCPConnection | null = null;
public client: Client; public client: Client;
private options: t.MCPOptions;
private transport: Transport | null = null; // Make this nullable private transport: Transport | null = null; // Make this nullable
private connectionState: t.ConnectionState = 'disconnected'; private connectionState: t.ConnectionState = 'disconnected';
private connectPromise: Promise<void> | null = null; private connectPromise: Promise<void> | null = null;
@ -70,26 +79,23 @@ export class MCPConnection extends EventEmitter {
private reconnectAttempts = 0; private reconnectAttempts = 0;
private readonly userId?: string; private readonly userId?: string;
private lastPingTime: number; private lastPingTime: number;
private lastConnectionCheckAt: number = 0;
private oauthTokens?: MCPOAuthTokens | null; private oauthTokens?: MCPOAuthTokens | null;
private oauthRequired = false; private oauthRequired = false;
iconPath?: string; iconPath?: string;
timeout?: number; timeout?: number;
url?: string; url?: string;
constructor( constructor(params: MCPConnectionParams) {
serverName: string,
private readonly options: t.MCPOptions,
userId?: string,
oauthTokens?: MCPOAuthTokens | null,
) {
super(); super();
this.serverName = serverName; this.options = params.serverConfig;
this.userId = userId; this.serverName = params.serverName;
this.iconPath = options.iconPath; this.userId = params.userId;
this.timeout = options.timeout; this.iconPath = params.serverConfig.iconPath;
this.timeout = params.serverConfig.timeout;
this.lastPingTime = Date.now(); this.lastPingTime = Date.now();
if (oauthTokens) { if (params.oauthTokens) {
this.oauthTokens = oauthTokens; this.oauthTokens = params.oauthTokens;
} }
this.client = new Client( this.client = new Client(
{ {
@ -110,28 +116,6 @@ export class MCPConnection extends EventEmitter {
return `[MCP]${userPart}[${this.serverName}]`; return `[MCP]${userPart}[${this.serverName}]`;
} }
public static getInstance(
serverName: string,
options: t.MCPOptions,
userId?: string,
): MCPConnection {
if (!MCPConnection.instance) {
MCPConnection.instance = new MCPConnection(serverName, options, userId);
}
return MCPConnection.instance;
}
public static getExistingInstance(): MCPConnection | null {
return MCPConnection.instance;
}
public static async destroyInstance(): Promise<void> {
if (MCPConnection.instance) {
await MCPConnection.instance.disconnect();
MCPConnection.instance = null;
}
}
private emitError(error: unknown, errorContext: string): void { private emitError(error: unknown, errorContext: string): void {
const errorMessage = error instanceof Error ? error.message : String(error); const errorMessage = error instanceof Error ? error.message : String(error);
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`); logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
@ -589,6 +573,13 @@ export class MCPConnection extends EventEmitter {
return false; return false;
} }
// If we recently checked, skip expensive verification
const now = Date.now();
if (now - this.lastConnectionCheckAt < mcpConfig.CONNECTION_CHECK_TTL) {
return true;
}
this.lastConnectionCheckAt = now;
try { try {
// Try ping first as it's the lightest check // Try ping first as it's the lightest check
await this.client.ping(); await this.client.ping();

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,11 @@
import { math, isEnabled } from '~/utils';
/**
* Centralized configuration for MCP-related environment variables.
* Provides typed access to MCP settings with default values.
*/
export const mcpConfig = {
OAUTH_ON_AUTH_ERROR: isEnabled(process.env.MCP_OAUTH_ON_AUTH_ERROR ?? true),
OAUTH_DETECTION_TIMEOUT: math(process.env.MCP_OAUTH_DETECTION_TIMEOUT ?? 5000),
CONNECTION_CHECK_TTL: math(process.env.MCP_CONNECTION_CHECK_TTL ?? 60000),
};

View file

@ -0,0 +1,120 @@
// ATTENTION: If you modify OAuth detection logic in this file, run the integration tests to verify:
// npx jest --testMatch="**/detectOAuth.integration.dev.ts" (from packages/api directory)
//
// These tests are excluded from CI because they make live HTTP requests to external services,
// which could cause flaky builds due to network issues or changes in third-party endpoints.
// Manual testing ensures the OAuth detection still works against real MCP servers.
import { discoverOAuthProtectedResourceMetadata } from '@modelcontextprotocol/sdk/client/auth.js';
import { mcpConfig } from '../mcpConfig';
export interface OAuthDetectionResult {
requiresOAuth: boolean;
method: 'protected-resource-metadata' | '401-challenge-metadata' | 'no-metadata-found';
metadata?: Record<string, unknown> | null;
}
/**
* Detects if an MCP server requires OAuth authentication using proactive discovery methods.
*
* This function implements a comprehensive OAuth detection strategy:
* 1. Standard Protected Resource Metadata (RFC 9728) - checks /.well-known/oauth-protected-resource
* 2. 401 Challenge Method - checks WWW-Authenticate header for resource_metadata URL
* 3. Optional fallback: treat any 401/403 response as OAuth requirement (if MCP_OAUTH_ON_AUTH_ERROR=true)
*
* @param serverUrl - The MCP server URL to check for OAuth requirements
* @returns Promise<OAuthDetectionResult> - OAuth requirement details
*/
export async function detectOAuthRequirement(serverUrl: string): Promise<OAuthDetectionResult> {
const protectedResourceResult = await checkProtectedResourceMetadata(serverUrl);
if (protectedResourceResult) return protectedResourceResult;
const challengeResult = await check401ChallengeMetadata(serverUrl);
if (challengeResult) return challengeResult;
const fallbackResult = await checkAuthErrorFallback(serverUrl);
if (fallbackResult) return fallbackResult;
// No OAuth detected
return {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// ------------------------ Private helper functions for OAuth detection -------------------------//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Checks for OAuth using standard protected resource metadata (RFC 9728)
async function checkProtectedResourceMetadata(
serverUrl: string,
): Promise<OAuthDetectionResult | null> {
try {
const resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl);
if (!resourceMetadata?.authorization_servers?.length) return null;
return {
requiresOAuth: true,
method: 'protected-resource-metadata',
metadata: resourceMetadata,
};
} catch {
return null;
}
}
// Checks for OAuth using 401 challenge with resource metadata URL
async function check401ChallengeMetadata(serverUrl: string): Promise<OAuthDetectionResult | null> {
try {
const response = await fetch(serverUrl, {
method: 'HEAD',
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
});
if (response.status !== 401) return null;
const wwwAuth = response.headers.get('www-authenticate');
const metadataUrl = wwwAuth?.match(/resource_metadata="([^"]+)"/)?.[1];
if (!metadataUrl) return null;
const metadataResponse = await fetch(metadataUrl, {
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
});
const metadata = await metadataResponse.json();
if (!metadata?.authorization_servers?.length) return null;
return {
requiresOAuth: true,
method: '401-challenge-metadata',
metadata,
};
} catch {
return null;
}
}
// Fallback method: treats any auth error as OAuth requirement if configured
async function checkAuthErrorFallback(serverUrl: string): Promise<OAuthDetectionResult | null> {
try {
if (!mcpConfig.OAUTH_ON_AUTH_ERROR) return null;
const response = await fetch(serverUrl, {
method: 'HEAD',
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
});
if (response.status !== 401 && response.status !== 403) return null;
return {
requiresOAuth: true,
method: 'no-metadata-found',
metadata: null,
};
} catch {
return null;
}
}

View file

@ -1,3 +1,4 @@
export * from './types'; export * from './types';
export * from './handler'; export * from './handler';
export * from './tokens'; export * from './tokens';
export * from './detectOAuth';

View file

@ -3,6 +3,13 @@ import { TokenExchangeMethodEnum } from './types/agents';
import { extractEnvVariable } from './utils'; import { extractEnvVariable } from './utils';
const BaseOptionsSchema = z.object({ const BaseOptionsSchema = z.object({
/**
* Controls whether the MCP server is initialized during application startup.
* - true (default): Server is initialized during app startup and included in app-level connections
* - false: Skips initialization at startup and excludes from app-level connections - useful for servers
* requiring manual authentication (e.g., GitHub PAT tokens) that need to be configured through the UI after startup
*/
startup: z.boolean().optional(),
iconPath: z.string().optional(), iconPath: z.string().optional(),
timeout: z.number().optional(), timeout: z.number().optional(),
initTimeout: z.number().optional(), initTimeout: z.number().optional(),
@ -15,6 +22,11 @@ const BaseOptionsSchema = z.object({
* - string: Use custom instructions (overrides server-provided) * - string: Use custom instructions (overrides server-provided)
*/ */
serverInstructions: z.union([z.boolean(), z.string()]).optional(), serverInstructions: z.union([z.boolean(), z.string()]).optional(),
/**
* Whether this server requires OAuth authentication
* If not specified, will be auto-detected during construction
*/
requiresOAuth: z.boolean().optional(),
/** /**
* OAuth configuration for SSE and Streamable HTTP transports * OAuth configuration for SSE and Streamable HTTP transports
* - Optional: OAuth can be auto-discovered on 401 responses * - Optional: OAuth can be auto-discovered on 401 responses