♻️ refactor: MCPManager for Scalability, Fix App-Level Detection, Add Lazy Connections (#8930)

* feat: MCP Connection management overhaul - Making MCPManager manageable

Refactor the monolithic MCPManager into focused, single-responsibility classes:

• MCPServersRegistry: Server configuration discovery and metadata management
• UserConnectionManager: Manages user-level connections
• ConnectionsRepository: Low-level connection pool with lazy loading
• MCPConnectionFactory: Handles MCP connection creation with OAuth support

New Features:
• Lazy loading of app-level connections for horizontal scaling
• Automatic reconnection for app-level connections
• Enhanced OAuth detection with explicit requiresOAuth flag
• Centralized MCP configuration management

Bug Fixes:
• App-level connection detection in MCPManager.callTool
• MCP Connection Reinitialization route behavior

Optimizations:
• MCPConnection.isConnected() caching to reduce overhead
• Concurrent server metadata retrieval instead of sequential

This refactoring addresses scalability bottlenecks and improves reliability
while maintaining backward compatibility with existing configurations.

* feat: Enabled import order in eslint.

* # Moved tests to __tests__ folder
# added tests for MCPServersRegistry.ts

* # Add unit tests for ConnectionsRepository functionality

* # Add unit tests for MCPConnectionFactory functionality

* # Reorganize MCP connection tests and improve error handling

* # reordering imports

* # Update testPathIgnorePatterns in jest.config.mjs to exclude development TypeScript files

* # removed mcp/manager.ts
This commit is contained in:
Theo N. Truong 2025-08-13 09:45:06 -06:00 committed by GitHub
parent 9dbf153489
commit 8780a78165
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 2571 additions and 1468 deletions

View file

@ -698,3 +698,16 @@ OPENWEATHER_API_KEY=
# JINA_API_KEY=your_jina_api_key
# or
# COHERE_API_KEY=your_cohere_api_key
#======================#
# MCP Configuration #
#======================#
# Treat 401/403 responses as OAuth requirement when no oauth metadata found
# MCP_OAUTH_ON_AUTH_ERROR=true
# Timeout for OAuth detection requests in milliseconds
# MCP_OAUTH_DETECTION_TIMEOUT=5000
# Cache connection status checks for this many milliseconds to avoid expensive verification
# MCP_CONNECTION_CHECK_TTL=60000

View file

@ -147,7 +147,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
## 8. Module Import Conventions
- `npm` packages first,
- from shortest line (top) to longest (bottom)
- from longest line (top) to shortest (bottom)
- Followed by typescript types (pertains to data-provider and client workspaces)
- longest line (top) to shortest (bottom)
@ -157,6 +157,8 @@ Apply the following naming conventions to branches, labels, and other Git-relate
- longest line (top) to shortest (bottom)
- imports with alias `~` treated the same as relative import with respect to line length
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
---
Please ensure that you adapt this summary to fit the specific context and nuances of your project.

1
.gitignore vendored
View file

@ -137,3 +137,4 @@ helm/**/.values.yaml
/.openai/
/.tabnine/
/.codeium
*.local.md

View file

@ -1,27 +1,13 @@
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager } = require('@librechat/api');
const logger = require('./winston');
global.EventSource = EventSource;
/** @type {MCPManager} */
let mcpManager = null;
let flowManager = null;
/**
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
* @returns {MCPManager}
*/
function getMCPManager(userId) {
if (!mcpManager) {
mcpManager = MCPManager.getInstance();
} else {
mcpManager.checkIdleConnections(userId);
}
return mcpManager;
}
/**
* @param {Keyv} flowsCache
* @returns {FlowStateManager}
@ -37,6 +23,7 @@ function getFlowStateManager(flowsCache) {
module.exports = {
logger,
getMCPManager,
createMCPManager: MCPManager.createInstance,
getMCPManager: MCPManager.getInstance,
getFlowStateManager,
};

View file

@ -1,7 +1,7 @@
const express = require('express');
const { MongoMemoryServer } = require('mongodb-memory-server');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const express = require('express');
jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {
@ -494,12 +494,9 @@ describe('MCP Routes', () => {
});
it('should return 500 when token retrieval throws an unexpected error', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockRejectedValue(new Error('Database connection failed')),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
getLogStores.mockImplementation(() => {
throw new Error('Database connection failed');
});
const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow');
@ -563,8 +560,8 @@ describe('MCP Routes', () => {
});
describe('POST /oauth/cancel/:serverName', () => {
const { getLogStores } = require('~/cache');
const { MCPOAuthHandler } = require('@librechat/api');
const { getLogStores } = require('~/cache');
it('should cancel OAuth flow successfully', async () => {
const mockFlowManager = {
@ -644,15 +641,15 @@ describe('MCP Routes', () => {
});
describe('POST /:serverName/reinitialize', () => {
const { loadCustomConfig } = require('~/server/services/Config');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
it('should return 404 when server is not found in configuration', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'other-server': {},
},
});
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue(null),
disconnectUserConnection: jest.fn().mockResolvedValue(),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
const response = await request(app).post('/api/mcp/non-existent-server/reinitialize');
@ -663,16 +660,11 @@ describe('MCP Routes', () => {
});
it('should handle OAuth requirement during reinitialize', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'oauth-server': {
customUserVars: {},
},
},
});
const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(),
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {},
}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
mcpConfigs: {},
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
if (oauthStart) {
@ -690,7 +682,7 @@ describe('MCP Routes', () => {
expect(response.status).toBe(200);
expect(response.body).toEqual({
success: 'https://oauth.example.com/auth',
success: true,
message: "MCP server 'oauth-server' ready for OAuth authentication",
serverName: 'oauth-server',
oauthRequired: true,
@ -699,14 +691,9 @@ describe('MCP Routes', () => {
});
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'error-server': {},
},
});
const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(),
getRawConfig: jest.fn().mockReturnValue({}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
mcpConfigs: {},
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
};
@ -724,7 +711,13 @@ describe('MCP Routes', () => {
});
it('should return 500 when unexpected error occurs', async () => {
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
const mockMcpManager = {
getRawConfig: jest.fn().mockImplementation(() => {
throw new Error('Config loading failed');
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).post('/api/mcp/test-server/reinitialize');
@ -747,29 +740,17 @@ describe('MCP Routes', () => {
expect(response.body).toEqual({ error: 'User not authenticated' });
});
it('should handle errors when fetching custom user variables', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': {
customUserVars: {
API_KEY: 'test-key-var',
SECRET_TOKEN: 'test-secret-var',
},
},
},
});
getUserPluginAuthValue
.mockResolvedValueOnce('test-api-key-value')
.mockRejectedValueOnce(new Error('Database error'));
it('should successfully reinitialize server and cache tools', async () => {
const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([]),
fetchTools: jest.fn().mockResolvedValue([
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
]),
};
const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(),
mcpConfigs: {},
getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
disconnectUserConnection: jest.fn().mockResolvedValue(),
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
};
@ -784,38 +765,54 @@ describe('MCP Routes', () => {
const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200);
expect(response.body.success).toBe(true);
expect(response.body).toEqual({
success: true,
message: "MCP server 'test-server' reinitialized successfully",
serverName: 'test-server',
oauthRequired: false,
oauthUrl: null,
});
expect(mockMcpManager.disconnectUserConnection).toHaveBeenCalledWith(
'test-user-id',
'test-server',
);
expect(setCachedTools).toHaveBeenCalled();
});
it('should return failure message when reinitialize completely fails', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': {},
},
});
it('should handle server with custom user variables', async () => {
const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([]),
};
const mockMcpManager = {
disconnectServer: jest.fn().mockResolvedValue(),
mcpConfigs: {},
getUserConnection: jest.fn().mockResolvedValue(null),
getRawConfig: jest.fn().mockReturnValue({
endpoint: 'http://test-server.com',
customUserVars: {
API_KEY: 'some-env-var',
},
}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue(
'api-key-value',
);
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { Constants } = require('librechat-data-provider');
getCachedTools.mockResolvedValue({
[`existing-tool${Constants.mcp_delimiter}test-server`]: { type: 'function' },
});
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
const response = await request(app).post('/api/mcp/test-server/reinitialize');
expect(response.status).toBe(200);
expect(response.body.success).toBe(false);
expect(response.body.message).toBe("Failed to reinitialize MCP server 'test-server'");
expect(response.body.success).toBe(true);
expect(
require('~/server/services/PluginService').getUserPluginAuthValue,
).toHaveBeenCalledWith('test-user-id', 'API_KEY', false);
});
});
@ -984,21 +981,19 @@ describe('MCP Routes', () => {
});
describe('GET /:serverName/auth-values', () => {
const { loadCustomConfig } = require('~/server/services/Config');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
it('should return auth value flags for server', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': {
customUserVars: {
API_KEY: 'some-env-var',
SECRET_TOKEN: 'another-env-var',
},
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {
API_KEY: 'some-env-var',
SECRET_TOKEN: 'another-env-var',
},
},
});
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1017,11 +1012,11 @@ describe('MCP Routes', () => {
});
it('should return 404 when server is not found in configuration', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'other-server': {},
},
});
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue(null),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
@ -1032,16 +1027,15 @@ describe('MCP Routes', () => {
});
it('should handle errors when checking auth values', async () => {
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': {
customUserVars: {
API_KEY: 'some-env-var',
},
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {
API_KEY: 'some-env-var',
},
},
});
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1057,7 +1051,13 @@ describe('MCP Routes', () => {
});
it('should return 500 when auth values check throws unexpected error', async () => {
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
const mockMcpManager = {
getRawConfig: jest.fn().mockImplementation(() => {
throw new Error('Config loading failed');
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1066,14 +1066,13 @@ describe('MCP Routes', () => {
});
it('should handle customUserVars that is not an object', async () => {
const { loadCustomConfig } = require('~/server/services/Config');
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': {
customUserVars: 'not-an-object',
},
},
});
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: 'not-an-object',
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/auth-values');
@ -1097,98 +1096,6 @@ describe('MCP Routes', () => {
});
});
describe('POST /:serverName/reinitialize - Tool Deletion Coverage', () => {
it('should handle null cached tools during reinitialize (triggers || {} fallback)', async () => {
const { loadCustomConfig, getCachedTools } = require('~/server/services/Config');
const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
};
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
disconnectServer: jest.fn(),
initializeServer: jest.fn(),
mcpConfigs: {},
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': { env: { API_KEY: 'test-key' } },
},
});
getCachedTools.mockResolvedValue(null);
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
expect(response.body).toEqual({
message: "MCP server 'test-server' reinitialized successfully",
success: true,
oauthRequired: false,
oauthUrl: null,
serverName: 'test-server',
});
});
it('should delete existing cached tools during successful reinitialize', async () => {
const {
loadCustomConfig,
getCachedTools,
setCachedTools,
} = require('~/server/services/Config');
const mockUserConnection = {
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
};
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
disconnectServer: jest.fn(),
initializeServer: jest.fn(),
mcpConfigs: {},
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
loadCustomConfig.mockResolvedValue({
mcpServers: {
'test-server': { env: { API_KEY: 'test-key' } },
},
});
const existingTools = {
'old-tool_mcp_test-server': { type: 'function' },
'other-tool_mcp_other-server': { type: 'function' },
};
getCachedTools.mockResolvedValue(existingTools);
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
expect(response.body).toEqual({
message: "MCP server 'test-server' reinitialized successfully",
success: true,
oauthRequired: false,
oauthUrl: null,
serverName: 'test-server',
});
expect(setCachedTools).toHaveBeenCalledWith(
expect.objectContaining({
'new-tool_mcp_test-server': expect.any(Object),
'other-tool_mcp_other-server': { type: 'function' },
}),
{ userId: 'test-user-id' },
);
expect(setCachedTools).toHaveBeenCalledWith(
expect.not.objectContaining({
'old-tool_mcp_test-server': expect.anything(),
}),
{ userId: 'test-user-id' },
);
});
});
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
const { MCPOAuthHandler } = require('@librechat/api');

View file

@ -1,11 +1,11 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
const { Router } = require('express');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { setCachedTools, getCachedTools } = require('~/server/services/Config');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { requireJwtAuth } = require('~/server/middleware');
const { getLogStores } = require('~/cache');
@ -315,9 +315,9 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
const printConfig = false;
const config = await loadCustomConfig(printConfig);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
const mcpManager = getMCPManager();
const serverConfig = mcpManager.getRawConfig(serverName);
if (!serverConfig) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
});
@ -325,13 +325,12 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const mcpManager = getMCPManager();
await mcpManager.disconnectServer(serverName);
logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`);
await mcpManager.disconnectUserConnection(user.id, serverName);
logger.info(
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
);
const serverConfig = config.mcpServers[serverName];
mcpManager.mcpConfigs[serverName] = serverConfig;
let customUserVars = {};
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
for (const varName of Object.keys(serverConfig.customUserVars)) {
@ -437,7 +436,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
};
res.json({
success: (userConnection && !oauthRequired) || (oauthRequired && oauthUrl),
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(),
serverName,
oauthRequired,
@ -551,15 +550,14 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
return res.status(401).json({ error: 'User not authenticated' });
}
const printConfig = false;
const config = await loadCustomConfig(printConfig);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
const mcpManager = getMCPManager();
const serverConfig = mcpManager.getRawConfig(serverName);
if (!serverConfig) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
});
}
const serverConfig = config.mcpServers[serverName];
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
const authValueFlags = {};

View file

@ -1,8 +1,7 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, setCachedTools } = require('./Config');
const { CacheKeys } = require('librechat-data-provider');
const { createMCPManager } = require('~/config');
const { getLogStores } = require('~/cache');
/**
@ -31,33 +30,19 @@ async function initializeMCPs(app) {
}
logger.info('Initializing MCP servers...');
const mcpManager = getMCPManager();
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
const mcpManager = await createMCPManager(mcpServers);
try {
await mcpManager.initializeMCPs({
mcpServers: filteredServers,
flowManager,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
});
delete app.locals.mcpConfig;
const availableTools = await getCachedTools();
const cachedTools = await getCachedTools();
if (!availableTools) {
if (!cachedTools) {
logger.warn('No available tools found in cache during MCP initialization');
return;
}
const toolsCopy = { ...availableTools };
await mcpManager.mapAvailableTools(toolsCopy, flowManager);
await setCachedTools(toolsCopy, { isGlobal: true });
const mcpTools = mcpManager.getAppToolFunctions();
await setCachedTools({ ...cachedTools, ...mcpTools }, { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);

View file

@ -2,11 +2,11 @@ import { fileURLToPath } from 'node:url';
import path from 'node:path';
import typescriptEslintEslintPlugin from '@typescript-eslint/eslint-plugin';
import { fixupConfigRules, fixupPluginRules } from '@eslint/compat';
// import perfectionist from 'eslint-plugin-perfectionist';
import perfectionist from 'eslint-plugin-perfectionist';
import reactHooks from 'eslint-plugin-react-hooks';
import prettier from 'eslint-plugin-prettier';
import tsParser from '@typescript-eslint/parser';
import importPlugin from 'eslint-plugin-import';
import prettier from 'eslint-plugin-prettier';
import { FlatCompat } from '@eslint/eslintrc';
import jsxA11Y from 'eslint-plugin-jsx-a11y';
import i18next from 'eslint-plugin-i18next';
@ -62,7 +62,7 @@ export default [
'jsx-a11y': fixupPluginRules(jsxA11Y),
'import/parsers': tsParser,
i18next,
// perfectionist,
perfectionist,
prettier: fixupPluginRules(prettier),
},
@ -140,32 +140,31 @@ export default [
'react/prop-types': 'off',
'react/display-name': 'off',
// 'perfectionist/sort-imports': [
// 'error',
// {
// type: 'line-length',
// order: 'desc',
// newlinesBetween: 'never',
// customGroups: {
// value: {
// react: ['^react$'],
// // react: ['^react$', '^fs', '^zod', '^path'],
// local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'],
// },
// },
// groups: [
// 'react',
// 'builtin',
// 'external',
// ['builtin-type', 'external-type'],
// ['internal-type'],
// 'local',
// ['parent', 'sibling', 'index'],
// 'object',
// 'unknown',
// ],
// },
// ],
'perfectionist/sort-imports': [
'error',
{
type: 'line-length',
order: 'desc',
newlinesBetween: 'never',
customGroups: {
value: {
react: ['^react$'],
local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'],
},
},
groups: [
'react',
'builtin',
'external',
['builtin-type', 'external-type'],
['internal-type'],
'local',
['parent', 'sibling', 'index'],
'object',
'unknown',
],
},
],
// 'perfectionist/sort-named-imports': [
// 'error',

View file

@ -1,7 +1,7 @@
export default {
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
testPathIgnorePatterns: ['/node_modules/', '/dist/'],
testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'],
coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit',
moduleNameMapper: {

View file

@ -1,5 +1,5 @@
/* MCP */
export * from './mcp/manager';
export * from './mcp/MCPManager';
export * from './mcp/oauth';
export * from './mcp/auth';
export * from './mcp/zod';

View file

@ -0,0 +1,87 @@
import { logger } from '@librechat/data-schemas';
import { MCPConnectionFactory, OAuthConnectionOptions } from '~/mcp/MCPConnectionFactory';
import { MCPConnection } from './connection';
import type * as t from './types';
/**
* Manages MCP connections with lazy loading and reconnection.
* Maintains a pool of connections and handles connection lifecycle management.
*/
export class ConnectionsRepository {
protected readonly serverConfigs: Record<string, t.MCPOptions>;
protected connections: Map<string, MCPConnection> = new Map();
protected oauthOpts: OAuthConnectionOptions | undefined;
constructor(serverConfigs: t.MCPServers, oauthOpts?: OAuthConnectionOptions) {
this.serverConfigs = serverConfigs;
this.oauthOpts = oauthOpts;
}
/** Checks whether this repository can connect to a specific server */
has(serverName: string): boolean {
return !!this.serverConfigs[serverName];
}
/** Gets or creates a connection for the specified server with lazy loading */
async get(serverName: string): Promise<MCPConnection> {
const existingConnection = this.connections.get(serverName);
if (existingConnection && (await existingConnection.isConnected())) return existingConnection;
else await this.disconnect(serverName);
const connection = await MCPConnectionFactory.create(
{
serverName,
serverConfig: this.getServerConfig(serverName),
},
this.oauthOpts,
);
this.connections.set(serverName, connection);
return connection;
}
/** Gets or creates connections for multiple servers concurrently */
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
const connections = await Promise.all(connectionPromises);
return new Map(connections as [string, MCPConnection][]);
}
/** Returns all currently loaded connections without creating new ones */
async getLoaded(): Promise<Map<string, MCPConnection>> {
return this.getMany(Array.from(this.connections.keys()));
}
/** Gets or creates connections for all configured servers */
async getAll(): Promise<Map<string, MCPConnection>> {
return this.getMany(Object.keys(this.serverConfigs));
}
/** Disconnects and removes a specific server connection from the pool */
disconnect(serverName: string): Promise<void> {
const connection = this.connections.get(serverName);
if (!connection) return Promise.resolve();
this.connections.delete(serverName);
return connection.disconnect().catch((err) => {
logger.error(`${this.prefix(serverName)} Error disconnecting`, err);
});
}
/** Disconnects all active connections and returns array of disconnect promises */
disconnectAll(): Promise<void>[] {
const serverNames = Array.from(this.connections.keys());
return serverNames.map((serverName) => this.disconnect(serverName));
}
// Retrieves server configuration by name or throws if not found
protected getServerConfig(serverName: string): t.MCPOptions {
const serverConfig = this.serverConfigs[serverName];
if (serverConfig) return serverConfig;
throw new Error(`${this.prefix(serverName)} Server not found in configuration`);
}
// Returns formatted log prefix for server messages
protected prefix(serverName: string): string {
return `[MCP][${serverName}]`;
}
}

View file

@ -0,0 +1,384 @@
import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { MCPOAuthTokens, MCPOAuthFlowMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
import type * as t from './types';
export interface BasicConnectionOptions {
serverName: string;
serverConfig: t.MCPOptions;
}
export interface OAuthConnectionOptions {
useOAuth: true;
user: TUser;
customUserVars?: Record<string, string>;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
tokenMethods?: TokenMethods;
signal?: AbortSignal;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
returnOnOAuth?: boolean;
}
/**
* Factory for creating MCP connections with optional OAuth authentication.
* Handles OAuth flows, token management, and connection retry logic.
* NOTE: Much of the OAuth logic was extracted from the old MCPManager class as is.
*/
export class MCPConnectionFactory {
protected readonly serverName: string;
protected readonly serverConfig: t.MCPOptions;
protected readonly logPrefix: string;
protected readonly useOAuth: boolean;
// OAuth-related properties (only set when useOAuth is true)
protected readonly userId?: string;
protected readonly flowManager?: FlowStateManager<MCPOAuthTokens | null>;
protected readonly tokenMethods?: TokenMethods;
protected readonly signal?: AbortSignal;
protected readonly oauthStart?: (authURL: string) => Promise<void>;
protected readonly oauthEnd?: () => Promise<void>;
protected readonly returnOnOAuth?: boolean;
/** Creates a new MCP connection with optional OAuth support */
static async create(
basic: BasicConnectionOptions,
oauth?: OAuthConnectionOptions,
): Promise<MCPConnection> {
const factory = new this(basic, oauth);
return factory.createConnection();
}
protected constructor(basic: BasicConnectionOptions, oauth?: OAuthConnectionOptions) {
this.serverConfig = processMCPEnv(basic.serverConfig, oauth?.user, oauth?.customUserVars);
this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth;
this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]`
: `[MCP][${basic.serverName}]`;
if (oauth?.useOAuth) {
this.userId = oauth.user.id;
this.flowManager = oauth.flowManager;
this.tokenMethods = oauth.tokenMethods;
this.signal = oauth.signal;
this.oauthStart = oauth.oauthStart;
this.oauthEnd = oauth.oauthEnd;
this.returnOnOAuth = oauth.returnOnOAuth;
}
}
/** Creates the base MCP connection with OAuth tokens */
protected async createConnection(): Promise<MCPConnection> {
const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null;
const connection = new MCPConnection({
serverName: this.serverName,
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
});
if (this.useOAuth) this.handleOAuthEvents(connection);
await this.attemptToConnect(connection);
return connection;
}
/** Retrieves existing OAuth tokens from storage or returns null */
protected async getOAuthTokens(): Promise<MCPOAuthTokens | null> {
if (!this.tokenMethods?.findToken) return null;
try {
const tokens = await this.flowManager!.createFlowWithHandler(
`tokens:${this.userId}:${this.serverName}`,
'mcp_get_tokens',
async () => {
return await MCPTokenStorage.getTokens({
userId: this.userId!,
serverName: this.serverName,
findToken: this.tokenMethods!.findToken!,
createToken: this.tokenMethods!.createToken,
updateToken: this.tokenMethods!.updateToken,
refreshTokens: this.createRefreshTokensFunction(),
});
},
this.signal,
);
if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
return tokens;
} catch (error) {
logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
return null;
}
}
/** Creates a function to refresh OAuth tokens when they expire */
protected createRefreshTokensFunction(): (
refreshToken: string,
metadata: {
userId: string;
serverName: string;
identifier: string;
clientInfo?: OAuthClientInformation;
},
) => Promise<MCPOAuthTokens> {
return async (refreshToken, metadata) => {
return await MCPOAuthHandler.refreshOAuthTokens(
refreshToken,
{
serverUrl: (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url,
serverName: metadata.serverName,
clientInfo: metadata.clientInfo,
},
this.serverConfig.oauth,
);
};
}
/** Sets up OAuth event handlers for the connection */
protected handleOAuthEvents(connection: MCPConnection): void {
connection.on('oauthRequired', async (data) => {
logger.info(`${this.logPrefix} oauthRequired event received`);
// If we just want to initiate OAuth and return, handle it differently
if (this.returnOnOAuth) {
try {
const config = this.serverConfig;
const { authorizationUrl, flowId, flowMetadata } =
await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth,
);
// Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here
// which is fine - we just need the flow state to exist
});
if (this.oauthStart) {
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
await this.oauthStart(authorizationUrl);
}
// Emit oauthFailed to signal that connection should not proceed
// but OAuth was successfully initiated
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
} catch (error) {
logger.error(`${this.logPrefix} Failed to initiate OAuth flow`, error);
connection.emit('oauthFailed', new Error('OAuth initiation failed'));
return;
}
}
// Normal OAuth handling - wait for completion
const result = await this.handleOAuthRequired();
if (result?.tokens && this.tokenMethods?.createToken) {
try {
connection.setOAuthTokens(result.tokens);
await MCPTokenStorage.storeTokens({
userId: this.userId!,
serverName: this.serverName,
tokens: result.tokens,
createToken: this.tokenMethods.createToken,
updateToken: this.tokenMethods.updateToken,
findToken: this.tokenMethods.findToken,
clientInfo: result.clientInfo,
});
logger.info(`${this.logPrefix} OAuth tokens saved to storage`);
} catch (error) {
logger.error(`${this.logPrefix} Failed to save OAuth tokens to storage`, error);
}
}
// Only emit oauthHandled if we actually got tokens (OAuth succeeded)
if (result?.tokens) {
connection.emit('oauthHandled');
} else {
// OAuth failed, emit oauthFailed to properly reject the promise
logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`);
connection.emit('oauthFailed', new Error('OAuth authentication failed'));
}
});
}
/** Attempts to establish connection with timeout handling */
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
const connectTimeout = this.serverConfig.initTimeout ?? 30000;
const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout(
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
connectTimeout,
),
);
const connectionAttempt = this.connectTo(connection);
await Promise.race([connectionAttempt, connectionTimeout]);
if (await connection.isConnected()) return;
logger.error(`${this.logPrefix} Failed to establish connection.`);
}
// Handles connection attempts with retry logic and OAuth error handling
private async connectTo(connection: MCPConnection): Promise<void> {
const maxAttempts = 3;
let attempts = 0;
let oauthHandled = false;
while (attempts < maxAttempts) {
try {
await connection.connect();
if (await connection.isConnected()) {
return;
}
throw new Error('Connection attempt succeeded but status is not connected');
} catch (error) {
attempts++;
if (this.useOAuth && this.isOAuthError(error)) {
// Only handle OAuth if this is a user connection (has oauthStart handler)
if (this.oauthStart && !oauthHandled) {
const errorWithFlag = error as (Error & { isOAuthError?: boolean }) | undefined;
if (errorWithFlag?.isOAuthError) {
oauthHandled = true;
logger.info(`${this.logPrefix} Handling OAuth`);
await this.handleOAuthRequired();
}
}
// Don't retry on OAuth errors - just throw
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
throw error;
}
if (attempts === maxAttempts) {
logger.error(`${this.logPrefix} Failed to connect after ${maxAttempts} attempts`, error);
throw error;
}
await new Promise((resolve) => setTimeout(resolve, 2000 * attempts));
}
}
}
// Determines if an error indicates OAuth authentication is required
private isOAuthError(error: unknown): boolean {
if (!error || typeof error !== 'object') {
return false;
}
// Check for SSE error with 401 status
if ('message' in error && typeof error.message === 'string') {
return error.message.includes('401') || error.message.includes('Non-200 status code (401)');
}
// Check for error code
if ('code' in error) {
const code = (error as { code?: number }).code;
return code === 401 || code === 403;
}
return false;
}
/** Manages OAuth flow initiation and completion */
protected async handleOAuthRequired(): Promise<{
tokens: MCPOAuthTokens | null;
clientInfo?: OAuthClientInformation;
} | null> {
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
if (!this.flowManager || !serverUrl) {
logger.error(
`${this.logPrefix} OAuth required but flow manager not available or server URL missing for ${this.serverName}`,
);
logger.warn(`${this.logPrefix} Please configure OAuth credentials for ${this.serverName}`);
return null;
}
try {
logger.debug(`${this.logPrefix} Checking for existing OAuth flow for ${this.serverName}...`);
/** Flow ID to check if a flow already exists */
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
/** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.status === 'PENDING') {
logger.debug(
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
);
/** Tokens from existing flow to complete */
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth');
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(
`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`,
);
/** Client information from the existing flow metadata */
const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
const clientInfo = existingMetadata?.clientInfo;
return { tokens, clientInfo };
}
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
const {
authorizationUrl,
flowId: newFlowId,
flowMetadata,
} = await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
serverUrl,
this.userId!,
this.serverConfig.oauth,
);
if (typeof this.oauthStart === 'function') {
logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
await this.oauthStart(authorizationUrl);
} else {
logger.info(`
Please visit the following URL to authenticate:
${authorizationUrl}
${this.logPrefix} Flow ID: ${newFlowId}
`);
}
/** Tokens from the new flow */
const tokens = await this.flowManager.createFlow(
newFlowId,
'mcp_oauth',
flowMetadata as FlowMetadata,
);
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`);
/** Client information from the flow metadata */
const clientInfo = flowMetadata?.clientInfo;
return { tokens, clientInfo };
} catch (error) {
logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error);
return null;
}
}
}

View file

@ -0,0 +1,263 @@
import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { logger } from '@librechat/data-schemas';
import pick from 'lodash/pick';
import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { UserConnectionManager } from '~/mcp/UserConnectionManager';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { CONSTANTS } from './enum';
import type * as t from './types';
/**
* Centralized manager for MCP server connections and tool execution.
* Extends UserConnectionManager to handle both app-level and user-specific connections.
*/
export class MCPManager extends UserConnectionManager {
private static instance: MCPManager | null;
// Connections shared by all users.
private appConnections: ConnectionsRepository | null = null;
/** Creates and initializes the singleton MCPManager instance */
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
if (MCPManager.instance) throw new Error('MCPManager has already been initialized.');
MCPManager.instance = new MCPManager(configs);
await MCPManager.instance.initialize();
return MCPManager.instance;
}
/** Returns the singleton MCPManager instance */
public static getInstance(): MCPManager {
if (!MCPManager.instance) throw new Error('MCPManager has not been initialized.');
return MCPManager.instance;
}
/** Initializes the MCPManager by setting up server registry and app connections */
public async initialize() {
await this.serversRegistry.initialize();
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!);
}
/** Returns all app-level connections */
public async getAllConnections(): Promise<Map<string, MCPConnection>> {
return this.appConnections!.getAll();
}
/** Get servers that require OAuth */
public getOAuthServers(): Set<string> {
return this.serversRegistry.oauthServers!;
}
/** Returns all available tool functions from app-level connections */
public getAppToolFunctions(): t.LCAvailableTools {
return this.serversRegistry.toolFunctions!;
}
/**
* Get instructions for MCP servers
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
* @returns Object mapping server names to their instructions
*/
public getInstructions(serverNames?: string[]): Record<string, string> {
const instructions = this.serversRegistry.serverInstructions!;
if (!serverNames) return instructions;
return pick(instructions, serverNames);
}
/**
* Format MCP server instructions for injection into context
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
* @returns Formatted instructions string ready for context injection
*/
public formatInstructionsForContext(serverNames?: string[]): string {
/** Instructions for specified servers or all stored instructions */
const instructionsToInclude = this.getInstructions(serverNames);
if (Object.keys(instructionsToInclude).length === 0) {
return '';
}
// Format instructions for context injection
const formattedInstructions = Object.entries(instructionsToInclude)
.map(([serverName, instructions]) => {
return `## ${serverName} MCP Server Instructions
${instructions}`;
})
.join('\n\n');
return `# MCP Server Instructions
The following MCP servers are available with their specific instructions:
${formattedInstructions}
Please follow these instructions when using tools from the respective MCP servers.`;
}
/** Loads tools from all app-level connections into the manifest. */
public async loadManifestTools({
serverToolsCallback,
getServerTools,
}: {
flowManager: FlowStateManager<MCPOAuthTokens | null>;
serverToolsCallback?: (serverName: string, tools: t.LCManifestTool[]) => Promise<void>;
getServerTools?: (serverName: string) => Promise<t.LCManifestTool[] | undefined>;
}): Promise<t.LCToolManifest> {
const mcpTools: t.LCManifestTool[] = [];
const connections = await this.appConnections!.getAll();
for (const [serverName, connection] of connections.entries()) {
try {
if (!(await connection.isConnected())) {
logger.warn(
`[MCP][${serverName}] Connection not available for ${serverName} manifest tools.`,
);
if (typeof getServerTools !== 'function') {
logger.warn(
`[MCP][${serverName}] No \`getServerTools\` function provided, skipping tool loading.`,
);
continue;
}
const serverTools = await getServerTools(serverName);
if (serverTools && serverTools.length > 0) {
logger.info(`[MCP][${serverName}] Loaded tools from cache for manifest`);
mcpTools.push(...serverTools);
}
continue;
}
const tools = await connection.fetchTools();
const serverTools: t.LCManifestTool[] = [];
for (const tool of tools) {
const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
const config = this.serversRegistry.parsedConfigs[serverName];
const manifestTool: t.LCManifestTool = {
name: tool.name,
pluginKey,
description: tool.description ?? '',
icon: connection.iconPath,
authConfig: config?.customUserVars
? Object.entries(config.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}))
: undefined,
};
if (config?.chatMenu === false) {
manifestTool.chatMenu = false;
}
mcpTools.push(manifestTool);
serverTools.push(manifestTool);
}
if (typeof serverToolsCallback === 'function') {
await serverToolsCallback(serverName, serverTools);
}
} catch (error) {
logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error);
}
}
return mcpTools;
}
/**
* Calls a tool on an MCP server, using either a user-specific connection
* (if userId is provided) or an app-level connection. Updates the last activity timestamp
* for user-specific connections upon successful call initiation.
*/
async callTool({
user,
serverName,
toolName,
provider,
toolArguments,
options,
tokenMethods,
flowManager,
oauthStart,
oauthEnd,
customUserVars,
}: {
user?: TUser;
serverName: string;
toolName: string;
provider: t.Provider;
toolArguments?: Record<string, unknown>;
options?: RequestOptions;
tokenMethods?: TokenMethods;
customUserVars?: Record<string, string>;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
}): Promise<t.FormattedToolResponse> {
/** User-specific connection */
let connection: MCPConnection | undefined;
const userId = user?.id;
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
try {
if (!this.appConnections?.has(serverName) && userId && user) {
this.updateUserLastActivity(userId);
/** Get or create user-specific connection */
connection = await this.getUserConnection({
user,
serverName,
flowManager,
tokenMethods,
oauthStart,
oauthEnd,
signal: options?.signal,
customUserVars,
});
} else {
/** App-level connection */
connection = await this.appConnections!.get(serverName);
if (!connection) {
throw new McpError(
ErrorCode.InvalidRequest,
`${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`,
);
}
}
if (!(await connection.isConnected())) {
/** May happen if getUserConnection failed silently or app connection dropped */
throw new McpError(
ErrorCode.InternalError, // Use InternalError for connection issues
`${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`,
);
}
const result = await connection.client.request(
{
method: 'tools/call',
params: {
name: toolName,
arguments: toolArguments,
},
},
CallToolResultSchema,
{
timeout: connection.timeout,
...options,
},
);
if (userId) {
this.updateUserLastActivity(userId);
}
this.checkIdleConnections();
return formatToolContent(result as t.MCPToolCallResponse, provider);
} catch (error) {
// Log with context and re-throw or handle as needed
logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
// Rethrowing allows the caller (createMCPTool) to handle the final user message
throw error;
}
}
}

View file

@ -0,0 +1,200 @@
import { logger } from '@librechat/data-schemas';
import mapValues from 'lodash/mapValues';
import pickBy from 'lodash/pickBy';
import pick from 'lodash/pick';
import type { JsonSchemaType } from '~/types';
import type * as t from '~/mcp/types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { type MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
import { CONSTANTS } from '~/mcp/enum';
type ParsedServerConfig = t.MCPOptions & {
url?: string;
requiresOAuth?: boolean;
oauthMetadata?: Record<string, unknown> | null;
capabilities?: string;
tools?: string;
};
/**
* Manages MCP server configurations and metadata discovery.
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
* Determines which servers are for app-level connections.
* Has its own connections repository. All connections are disconnected after initialization.
*/
export class MCPServersRegistry {
private initialized: boolean = false;
private connections: ConnectionsRepository;
public readonly rawConfigs: t.MCPServers;
public readonly parsedConfigs: Record<string, ParsedServerConfig>;
public oauthServers: Set<string> | null = null;
public serverInstructions: Record<string, string> | null = null;
public toolFunctions: t.LCAvailableTools | null = null;
public appServerConfigs: t.MCPServers | null = null;
constructor(configs: t.MCPServers) {
this.rawConfigs = configs;
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv(con));
this.connections = new ConnectionsRepository(configs);
}
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
public async initialize() {
if (this.initialized) return;
this.initialized = true;
const serverNames = Object.keys(this.parsedConfigs);
await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName)));
this.setOAuthServers();
this.setServerInstructions();
this.setAppServerConfigs();
await this.setAppToolFunctions();
this.connections.disconnectAll();
}
// Fetches all metadata for a single server in parallel
private async gatherServerInfo(serverName: string) {
try {
await Promise.allSettled([
this.fetchOAuthRequirement(serverName).catch((error) =>
logger.error(`${this.prefix(serverName)} Failed to fetch OAuth requirement:`, error),
),
this.fetchServerInstructions(serverName).catch((error) =>
logger.error(`${this.prefix(serverName)} Failed to fetch server instructions:`, error),
),
this.fetchServerCapabilities(serverName).catch((error) =>
logger.error(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error),
),
]);
this.logUpdatedConfig(serverName);
} catch (error) {
logger.error(`${this.prefix(serverName)} Failed to initialize server:`, error);
}
}
// Sets app-level server configs (startup enabled, non-OAuth servers)
private setAppServerConfigs() {
const appServers = Object.keys(
pickBy(
this.parsedConfigs,
(config) => config.startup !== false && config.requiresOAuth === false,
),
);
this.appServerConfigs = pick(this.rawConfigs, appServers);
}
// Creates set of server names that require OAuth authentication
private setOAuthServers() {
if (this.oauthServers) return this.oauthServers;
this.oauthServers = new Set(
Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)),
);
return this.oauthServers;
}
// Collects server instructions from all configured servers
private setServerInstructions() {
this.serverInstructions = mapValues(
pickBy(this.parsedConfigs, (config) => config.serverInstructions),
(config) => config.serverInstructions as string,
);
}
// Builds registry of all available tool functions from loaded connections
private async setAppToolFunctions() {
const connections = (await this.connections.getLoaded()).entries();
const allToolFunctions: t.LCAvailableTools = {};
for (const [serverName, conn] of connections) {
try {
const toolFunctions = await this.getToolFunctions(serverName, conn);
Object.assign(allToolFunctions, toolFunctions);
} catch (error) {
logger.error(`${this.prefix(serverName)} Error fetching tool functions:`, error);
}
}
this.toolFunctions = allToolFunctions;
}
// Converts server tools to LibreChat-compatible tool functions format
private async getToolFunctions(
serverName: string,
conn: MCPConnection,
): Promise<t.LCAvailableTools> {
const { tools } = await conn.client.listTools();
const toolFunctions: t.LCAvailableTools = {};
tools.forEach((tool) => {
const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
toolFunctions[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema as JsonSchemaType,
},
};
});
return toolFunctions;
}
// Determines if server requires OAuth if not already specified in the config
private async fetchOAuthRequirement(serverName: string) {
const config = this.parsedConfigs[serverName];
if (config.requiresOAuth != null) return;
if (config.url == null) return (config.requiresOAuth = false);
const result = await detectOAuthRequirement(config.url);
config.requiresOAuth = result.requiresOAuth;
config.oauthMetadata = result.metadata;
}
// Retrieves server instructions from MCP server if enabled in the config
private async fetchServerInstructions(serverName: string) {
const config = this.parsedConfigs[serverName];
if (!config.serverInstructions) return;
if (typeof config.serverInstructions === 'string') return;
const conn = await this.connections.get(serverName);
config.serverInstructions = conn.client.getInstructions();
if (!config.serverInstructions) {
logger.warn(`${this.prefix(serverName)} No server instructions available`);
}
}
// Fetches server capabilities and available tools list
private async fetchServerCapabilities(serverName: string) {
const config = this.parsedConfigs[serverName];
const conn = await this.connections.get(serverName);
const capabilities = conn.client.getServerCapabilities();
if (!capabilities) return;
config.capabilities = JSON.stringify(capabilities);
if (!capabilities.tools) return;
const tools = await conn.client.listTools();
config.tools = tools.tools.map((tool) => tool.name).join(', ');
}
// Logs server configuration summary after initialization
private logUpdatedConfig(serverName: string) {
const prefix = this.prefix(serverName);
const config = this.parsedConfigs[serverName];
logger.info(`${prefix} URL: ${config.url ?? 'N/A'}`);
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
logger.info(`${prefix} Tools: ${config.tools}`);
logger.info(`${prefix} Server Instructions: ${config.serverInstructions ?? 'None'}`);
}
// Returns formatted log prefix for server messages
private prefix(serverName: string): string {
return `[MCP][${serverName}]`;
}
}

View file

@ -0,0 +1,236 @@
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { MCPConnection } from './connection';
import type * as t from './types';
/**
* Abstract base class for managing user-specific MCP connections with lifecycle management.
* Only meant to be extended by MCPManager.
* Much of the logic was move here from the old MCPManager to make it more manageable.
* User connections will soon be ephemeral and not cached anymore:
* https://github.com/danny-avila/LibreChat/discussions/8790
*/
export abstract class UserConnectionManager {
protected readonly serversRegistry: MCPServersRegistry;
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
/** Last activity timestamp for users (not per server) */
protected userLastActivity: Map<string, number> = new Map();
protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
constructor(serverConfigs: t.MCPServers) {
this.serversRegistry = new MCPServersRegistry(serverConfigs);
}
/** fetches am MCP Server config from the registry */
public getRawConfig(serverName: string): t.MCPOptions | undefined {
return this.serversRegistry.rawConfigs[serverName];
}
/** Updates the last activity timestamp for a user */
protected updateUserLastActivity(userId: string): void {
const now = Date.now();
this.userLastActivity.set(userId, now);
logger.debug(
`[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`,
);
}
/** Gets or creates a connection for a specific user */
public async getUserConnection({
user,
serverName,
flowManager,
customUserVars,
tokenMethods,
oauthStart,
oauthEnd,
signal,
returnOnOAuth = false,
}: {
user: TUser;
serverName: string;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
customUserVars?: Record<string, string>;
tokenMethods?: TokenMethods;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
signal?: AbortSignal;
returnOnOAuth?: boolean;
}): Promise<MCPConnection> {
const userId = user.id;
if (!userId) {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
}
const userServerMap = this.userConnections.get(userId);
let connection = userServerMap?.get(serverName);
const now = Date.now();
// Check if user is idle
const lastActivity = this.userLastActivity.get(userId);
if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
logger.info(`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`);
// Disconnect all user connections
try {
await this.disconnectUserConnections(userId);
} catch (err) {
logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err);
}
connection = undefined; // Force creation of a new connection
} else if (connection) {
if (await connection.isConnected()) {
logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`);
this.updateUserLastActivity(userId);
return connection;
} else {
// Connection exists but is not connected, attempt to remove potentially stale entry
logger.warn(
`[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`,
);
this.removeUserConnection(userId, serverName); // Clean up maps
connection = undefined;
}
}
// If no valid connection exists, create a new one
if (!connection) {
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
}
const config = this.serversRegistry.parsedConfigs[serverName];
if (!config) {
throw new McpError(
ErrorCode.InvalidRequest,
`[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`,
);
}
try {
connection = await MCPConnectionFactory.create(
{
serverName: serverName,
serverConfig: config,
},
{
useOAuth: true,
user: user,
customUserVars: customUserVars,
flowManager: flowManager,
tokenMethods: tokenMethods,
signal: signal,
oauthStart: oauthStart,
oauthEnd: oauthEnd,
returnOnOAuth: returnOnOAuth,
},
);
if (!(await connection?.isConnected())) {
throw new Error('Failed to establish connection after initialization attempt.');
}
if (!this.userConnections.has(userId)) {
this.userConnections.set(userId, new Map());
}
this.userConnections.get(userId)?.set(serverName, connection);
logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
// Update timestamp on creation
this.updateUserLastActivity(userId);
return connection;
} catch (error) {
logger.error(`[MCP][User: ${userId}][${serverName}] Failed to establish connection`, error);
// Ensure partial connection state is cleaned up if initialization fails
await connection?.disconnect().catch((disconnectError) => {
logger.error(
`[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`,
disconnectError,
);
});
// Ensure cleanup even if connection attempt fails
this.removeUserConnection(userId, serverName);
throw error; // Re-throw the error to the caller
}
}
/** Returns all connections for a specific user */
public getUserConnections(userId: string) {
return this.userConnections.get(userId);
}
/** Removes a specific user connection entry */
protected removeUserConnection(userId: string, serverName: string): void {
const userMap = this.userConnections.get(userId);
if (userMap) {
userMap.delete(serverName);
if (userMap.size === 0) {
this.userConnections.delete(userId);
// Only remove user activity timestamp if all connections are gone
this.userLastActivity.delete(userId);
}
}
logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`);
}
/** Disconnects and removes a specific user connection */
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
const userMap = this.userConnections.get(userId);
const connection = userMap?.get(serverName);
if (connection) {
logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
await connection.disconnect();
this.removeUserConnection(userId, serverName);
}
}
/** Disconnects and removes all connections for a specific user */
public async disconnectUserConnections(userId: string): Promise<void> {
const userMap = this.userConnections.get(userId);
const disconnectPromises: Promise<void>[] = [];
if (userMap) {
logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`);
const userServers = Array.from(userMap.keys());
for (const serverName of userServers) {
disconnectPromises.push(
this.disconnectUserConnection(userId, serverName).catch((error) => {
logger.error(
`[MCP][User: ${userId}][${serverName}] Error during disconnection:`,
error,
);
}),
);
}
await Promise.allSettled(disconnectPromises);
// Ensure user activity timestamp is removed
this.userLastActivity.delete(userId);
logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
}
}
/** Check for and disconnect idle connections */
protected checkIdleConnections(currentUserId?: string): void {
const now = Date.now();
// Iterate through all users to check for idle ones
for (const [userId, lastActivity] of this.userLastActivity.entries()) {
if (currentUserId && currentUserId === userId) {
continue;
}
if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
logger.info(
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`,
);
// Disconnect all user connections asynchronously (fire and forget)
this.disconnectUserConnections(userId).catch((err) =>
logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err),
);
}
}
}
}

View file

@ -0,0 +1,212 @@
import { logger } from '@librechat/data-schemas';
import { ConnectionsRepository } from '../ConnectionsRepository';
import { MCPConnectionFactory } from '../MCPConnectionFactory';
import { MCPConnection } from '../connection';
import type * as t from '../types';
// Mock external dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
},
}));
jest.mock('../MCPConnectionFactory', () => ({
MCPConnectionFactory: {
create: jest.fn(),
},
}));
jest.mock('../connection');
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('ConnectionsRepository', () => {
let repository: ConnectionsRepository;
let mockServerConfigs: t.MCPServers;
let mockConnection: jest.Mocked<MCPConnection>;
beforeEach(() => {
mockServerConfigs = {
server1: { url: 'http://localhost:3001' },
server2: { command: 'test-command', args: ['--test'] },
server3: { url: 'ws://localhost:8080', type: 'websocket' },
};
mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
repository = new ConnectionsRepository(mockServerConfigs);
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('has', () => {
it('should return true for existing server', () => {
expect(repository.has('server1')).toBe(true);
});
it('should return false for non-existing server', () => {
expect(repository.has('nonexistent')).toBe(false);
});
});
describe('get', () => {
it('should return existing connected connection', async () => {
mockConnection.isConnected.mockResolvedValue(true);
repository['connections'].set('server1', mockConnection);
const result = await repository.get('server1');
expect(result).toBe(mockConnection);
expect(MCPConnectionFactory.create).not.toHaveBeenCalled();
});
it('should create new connection if none exists', async () => {
const result = await repository.get('server1');
expect(result).toBe(mockConnection);
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
},
undefined,
);
expect(repository['connections'].get('server1')).toBe(mockConnection);
});
it('should create new connection if existing connection is not connected', async () => {
const oldConnection = {
isConnected: jest.fn().mockResolvedValue(false),
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
repository['connections'].set('server1', oldConnection);
const result = await repository.get('server1');
expect(result).toBe(mockConnection);
expect(oldConnection.disconnect).toHaveBeenCalled();
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
},
undefined,
);
});
it('should throw error for non-existent server configuration', async () => {
await expect(repository.get('nonexistent')).rejects.toThrow(
'[MCP][nonexistent] Server not found in configuration',
);
});
it('should handle MCPConnectionFactory.create errors', async () => {
const createError = new Error('Connection creation failed');
(MCPConnectionFactory.create as jest.Mock).mockRejectedValue(createError);
await expect(repository.get('server1')).rejects.toThrow('Connection creation failed');
});
});
describe('getMany', () => {
it('should return connections for multiple servers', async () => {
const result = await repository.getMany(['server1', 'server3']);
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(2);
expect(result.get('server1')).toBe(mockConnection);
expect(result.get('server3')).toBe(mockConnection);
});
});
describe('getLoaded', () => {
it('should return connections for loaded servers only', async () => {
// Load one connection
await repository.get('server1');
const result = await repository.getLoaded();
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(1);
expect(result.get('server1')).toBe(mockConnection);
});
it('should return empty map when no connections are loaded', async () => {
const result = await repository.getLoaded();
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(0);
});
});
describe('getAll', () => {
it('should return connections for all configured servers', async () => {
const result = await repository.getAll();
expect(result).toBeInstanceOf(Map);
expect(result.size).toBe(3);
expect(result.get('server1')).toBe(mockConnection);
expect(result.get('server2')).toBe(mockConnection);
expect(result.get('server3')).toBe(mockConnection);
});
});
describe('disconnect', () => {
it('should disconnect and remove existing connection', async () => {
repository['connections'].set('server1', mockConnection);
await repository.disconnect('server1');
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(repository['connections'].has('server1')).toBe(false);
});
it('should handle disconnect error gracefully', async () => {
const disconnectError = new Error('Disconnect failed');
mockConnection.disconnect.mockRejectedValue(disconnectError);
repository['connections'].set('server1', mockConnection);
await repository.disconnect('server1');
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(repository['connections'].has('server1')).toBe(false);
expect(mockLogger.error).toHaveBeenCalledWith(
'[MCP][server1] Error disconnecting',
disconnectError,
);
});
});
describe('disconnectAll', () => {
it('should disconnect all active connections', () => {
const mockConnection1 = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
const mockConnection2 = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
const mockConnection3 = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
repository['connections'].set('server1', mockConnection1);
repository['connections'].set('server2', mockConnection2);
repository['connections'].set('server3', mockConnection3);
const promises = repository.disconnectAll();
expect(promises).toHaveLength(3);
expect(Array.isArray(promises)).toBe(true);
});
});
});

View file

@ -0,0 +1,347 @@
import { logger } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPConnectionFactory } from '../MCPConnectionFactory';
import { MCPOAuthHandler } from '~/mcp/oauth';
import { MCPConnection } from '../connection';
import { processMCPEnv } from '~/utils';
import type * as t from '../types';
jest.mock('../connection');
jest.mock('~/mcp/oauth');
jest.mock('~/utils');
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
const mockLogger = logger as jest.Mocked<typeof logger>;
const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction<typeof processMCPEnv>;
const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection>;
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
describe('MCPConnectionFactory', () => {
let mockUser: TUser;
let mockServerConfig: t.MCPOptions;
let mockFlowManager: jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
let mockConnectionInstance: jest.Mocked<MCPConnection>;
beforeEach(() => {
jest.clearAllMocks();
mockUser = {
id: 'user123',
email: 'test@example.com',
} as TUser;
mockServerConfig = {
command: 'node',
args: ['server.js'],
initTimeout: 5000,
} as t.MCPOptions;
mockFlowManager = {
createFlow: jest.fn(),
createFlowWithHandler: jest.fn(),
getFlowState: jest.fn(),
} as unknown as jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
mockConnectionInstance = {
connect: jest.fn(),
isConnected: jest.fn(),
setOAuthTokens: jest.fn(),
on: jest.fn().mockReturnValue(mockConnectionInstance),
emit: jest.fn(),
} as unknown as jest.Mocked<MCPConnection>;
mockMCPConnection.mockImplementation(() => mockConnectionInstance);
mockProcessMCPEnv.mockReturnValue(mockServerConfig);
});
describe('static create method', () => {
it('should create a basic connection without OAuth', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, undefined, undefined);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: undefined,
oauthTokens: null,
});
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
it('should create a connection with OAuth', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockTokens: MCPOAuthTokens = {
access_token: 'access123',
refresh_token: 'refresh123',
token_type: 'Bearer',
obtained_at: Date.now(),
};
mockFlowManager.createFlowWithHandler.mockResolvedValue(mockTokens);
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, mockUser, undefined);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: mockTokens,
});
});
});
describe('OAuth token handling', () => {
it('should return null when no findToken method is provided', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: undefined as unknown as () => Promise<any>,
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(mockFlowManager.createFlowWithHandler).not.toHaveBeenCalled();
});
it('should handle token retrieval errors gracefully', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockFlowManager.createFlowWithHandler.mockRejectedValue(new Error('Token fetch failed'));
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: null,
});
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('No existing tokens found or error loading tokens'),
expect.any(Error),
);
});
});
describe('OAuth event handling', () => {
it('should handle oauthRequired event for returnOnOAuth scenario', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
returnOnOAuth: true,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'flow123',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'random-state',
clientInfo: { client_id: 'client123' },
},
};
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail due to connection not established
}
expect(oauthRequiredHandler!).toBeDefined();
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockMCPOAuthHandler.initiateOAuthFlow).toHaveBeenCalledWith(
'test-server',
'https://api.example.com',
'user123',
undefined,
);
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({
message: 'OAuth flow initiated - return early',
}),
);
});
});
describe('connection retry logic', () => {
it('should establish connection successfully', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig, // Use default 5000ms timeout
};
mockConnectionInstance.connect.mockResolvedValue(undefined);
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockConnectionInstance.connect).toHaveBeenCalledTimes(1);
});
it('should handle OAuth errors during connection attempts', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const oauthError = new Error('Non-200 status code (401)');
(oauthError as unknown as Record<string, unknown>).isOAuthError = true;
mockConnectionInstance.connect.mockRejectedValue(oauthError);
mockConnectionInstance.isConnected.mockResolvedValue(false);
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow(
'Non-200 status code (401)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('OAuth required, stopping connection attempts'),
);
});
});
describe('isOAuthError method', () => {
it('should identify OAuth errors by message content', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const error401 = new Error('401 Unauthorized');
mockConnectionInstance.connect.mockRejectedValue(error401);
mockConnectionInstance.isConnected.mockResolvedValue(false);
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow('401');
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('OAuth required, stopping connection attempts'),
);
});
});
});

View file

@ -0,0 +1,287 @@
import { readFileSync } from 'fs';
import { join } from 'path';
import { logger } from '@librechat/data-schemas';
import { load as yamlLoad } from 'js-yaml';
import { ConnectionsRepository } from '../ConnectionsRepository';
import { MCPServersRegistry } from '../MCPServersRegistry';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { MCPConnection } from '../connection';
import type * as t from '../types';
// Mock external dependencies
jest.mock('../oauth/detectOAuth');
jest.mock('../ConnectionsRepository');
jest.mock('../connection');
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
// Mock processMCPEnv to verify it's called and adds a processed marker
jest.mock('~/utils', () => ({
...jest.requireActual('~/utils'),
processMCPEnv: jest.fn((config) => ({
...config,
_processed: true, // Simple marker to verify processing occurred
})),
}));
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
typeof detectOAuthRequirement
>;
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('MCPServersRegistry - Initialize Function', () => {
let rawConfigs: t.MCPServers;
let expectedParsedConfigs: Record<string, any>;
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
beforeEach(() => {
// Load fixtures
const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml');
const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml');
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
string,
any
>;
// Setup mock connections
mockConnections = new Map();
const serverNames = Object.keys(rawConfigs);
serverNames.forEach((serverName) => {
const mockConnection = {
client: {
listTools: jest.fn(),
getInstructions: jest.fn(),
getServerCapabilities: jest.fn(),
},
} as unknown as jest.Mocked<MCPConnection>;
// Setup mock responses based on expected configs
const expectedConfig = expectedParsedConfigs[serverName];
// Mock listTools response
if (expectedConfig.tools) {
const toolNames = expectedConfig.tools.split(', ');
const tools = toolNames.map((name: string) => ({
name,
description: `Description for ${name}`,
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
},
},
}));
mockConnection.client.listTools.mockResolvedValue({ tools });
} else {
mockConnection.client.listTools.mockResolvedValue({ tools: [] });
}
// Mock getInstructions response
if (expectedConfig.serverInstructions) {
mockConnection.client.getInstructions.mockReturnValue(expectedConfig.serverInstructions);
} else {
mockConnection.client.getInstructions.mockReturnValue(null);
}
// Mock getServerCapabilities response
if (expectedConfig.capabilities) {
const capabilities = JSON.parse(expectedConfig.capabilities);
mockConnection.client.getServerCapabilities.mockReturnValue(capabilities);
} else {
mockConnection.client.getServerCapabilities.mockReturnValue(null);
}
mockConnections.set(serverName, mockConnection);
});
// Setup ConnectionsRepository mock
mockConnectionsRepo = {
get: jest.fn(),
getLoaded: jest.fn(),
disconnectAll: jest.fn(),
} as unknown as jest.Mocked<ConnectionsRepository>;
mockConnectionsRepo.get.mockImplementation((serverName: string) =>
Promise.resolve(mockConnections.get(serverName)!),
);
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
(ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo);
// Setup OAuth detection mock with deterministic results
mockDetectOAuthRequirement.mockImplementation((url: string) => {
const oauthResults: Record<string, any> = {
'https://api.github.com/mcp': {
requiresOAuth: true,
metadata: {
authorization_url: 'https://github.com/login/oauth/authorize',
token_url: 'https://github.com/login/oauth/access_token',
},
},
'https://api.disabled.com/mcp': {
requiresOAuth: false,
metadata: null,
},
'https://api.public.com/mcp': {
requiresOAuth: false,
metadata: null,
},
};
return Promise.resolve(oauthResults[url] || { requiresOAuth: false, metadata: null });
});
// Clear all mocks
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('initialize() method', () => {
it('should only run initialization once', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
await registry.initialize(); // Second call should not re-run
// Verify that connections are only requested for servers that need them
// (servers with serverInstructions=true or all servers for capabilities)
expect(mockConnectionsRepo.get).toHaveBeenCalled();
});
it('should set all public properties correctly after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Verify initial state
expect(registry.oauthServers).toBeNull();
expect(registry.serverInstructions).toBeNull();
expect(registry.toolFunctions).toBeNull();
expect(registry.appServerConfigs).toBeNull();
await registry.initialize();
// Test oauthServers Set
expect(registry.oauthServers).toBeInstanceOf(Set);
expect(registry.oauthServers).toEqual(
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
);
// Test serverInstructions
expect(registry.serverInstructions).toEqual({
oauth_server: 'GitHub MCP server instructions',
stdio_server: 'Follow these instructions for stdio server',
non_oauth_server: 'Public API instructions',
});
// Test appServerConfigs (startup enabled, non-OAuth servers only)
expect(registry.appServerConfigs).toEqual({
stdio_server: rawConfigs.stdio_server,
websocket_server: rawConfigs.websocket_server,
non_oauth_server: rawConfigs.non_oauth_server,
});
// Test toolFunctions (only 2 servers have tools: oauth_server has 1, stdio_server has 2)
const expectedToolFunctions = {
get_repository_mcp_oauth_server: {
type: 'function',
function: {
name: 'get_repository_mcp_oauth_server',
description: 'Description for get_repository',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
file_read_mcp_stdio_server: {
type: 'function',
function: {
name: 'file_read_mcp_stdio_server',
description: 'Description for file_read',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
file_write_mcp_stdio_server: {
type: 'function',
function: {
name: 'file_write_mcp_stdio_server',
description: 'Description for file_write',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
};
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
});
it('should handle errors gracefully and continue initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Make one server throw an error
mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed'));
await registry.initialize();
// Should still initialize successfully
expect(registry.oauthServers).toBeInstanceOf(Set);
expect(registry.toolFunctions).toBeDefined();
// Error should be logged
expect(mockLogger.error).toHaveBeenCalledWith(
expect.stringContaining('[MCP][oauth_server] Failed to fetch OAuth requirement:'),
expect.any(Error),
);
});
it('should disconnect all connections after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1);
});
it('should log configuration updates for each server', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
const serverNames = Object.keys(rawConfigs);
serverNames.forEach((serverName) => {
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] URL:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] OAuth Required:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Capabilities:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Tools:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Server Instructions:`),
);
});
});
it('should have parsedConfigs matching the expected fixture after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
// Compare the actual parsedConfigs against the expected fixture
expect(registry.parsedConfigs).toEqual(expectedParsedConfigs);
});
});
});

View file

@ -1,7 +1,7 @@
import type { PluginAuthMethods } from '@librechat/data-schemas';
import type { GenericTool } from '@librechat/agents';
import { getPluginAuthMap } from '~/agents/auth';
import { getUserMCPAuthMap } from './auth';
import { getUserMCPAuthMap } from '../auth';
jest.mock('~/agents/auth', () => ({
getPluginAuthMap: jest.fn(),

View file

@ -0,0 +1,76 @@
// Integration tests for OAuth detection against real public MCP servers
// These tests verify the actual behavior against live endpoints
//
// DEVELOPMENT ONLY: This file is excluded from the test suite (.dev.ts extension)
// Use this for development and debugging OAuth detection behavior
//
// To run manually from packages/api directory:
// npx jest --testMatch="**/detectOAuth.integration.dev.ts"
import { detectOAuthRequirement } from '~/mcp/oauth';
describe('OAuth Detection Integration Tests', () => {
const NETWORK_TIMEOUT = 10000;
interface TestServer {
name: string;
url: string;
expectedOAuth: boolean;
expectedMethod: string;
withMeta: boolean;
}
const testServers: TestServer[] = [
{
name: 'GitHub Copilot MCP Server',
url: 'https://api.githubcopilot.com/mcp',
expectedOAuth: true,
expectedMethod: '401-challenge-metadata',
withMeta: true,
},
{
name: 'GitHub API (401 without metadata)',
url: 'https://api.github.com/user',
expectedOAuth: true,
expectedMethod: 'no-metadata-found',
withMeta: false,
},
{
name: 'Stytch Todo MCP Server',
url: 'https://mcp-stytch-consumer-todo-list.maxwell-gerber42.workers.dev',
expectedOAuth: true,
expectedMethod: 'protected-resource-metadata',
withMeta: true,
},
{
name: 'HTTPBin (Non-OAuth)',
url: 'https://httpbin.org',
expectedOAuth: false,
expectedMethod: 'no-metadata-found',
withMeta: false,
},
{
name: 'Unreachable Server',
url: 'https://definitely-not-a-real-server-12345.com',
expectedOAuth: false,
expectedMethod: 'no-metadata-found',
withMeta: false,
},
];
describe('detectOAuthRequirement integration', () => {
testServers.forEach((server) => {
it(
`should handle ${server.name}`,
async () => {
const result = await detectOAuthRequirement(server.url);
expect(result.requiresOAuth).toBe(server.expectedOAuth);
expect(result.method).toBe(server.expectedMethod);
expect(result.metadata == null).toBe(!server.withMeta);
},
NETWORK_TIMEOUT,
);
});
});
});

View file

@ -0,0 +1,74 @@
# Expected parsed MCP server configurations after running initialize()
# These represent the expected state of parsedConfigs after all fetch functions complete
oauth_server:
_processed: true
type: "streamable-http"
url: "https://api.github.com/mcp"
headers:
Authorization: "Bearer {{GITHUB_TOKEN}}"
serverInstructions: "GitHub MCP server instructions"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://github.com/login/oauth/authorize"
token_url: "https://github.com/login/oauth/access_token"
capabilities: '{"tools":{"listChanged":true},"resources":{},"prompts":{}}'
tools: "get_repository"
oauth_predefined:
_processed: true
type: "sse"
url: "https://api.example.com/sse"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://example.com/oauth/authorize"
token_url: "https://example.com/oauth/token"
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
stdio_server:
_processed: true
command: "node"
args: ["server.js"]
env:
API_KEY: "${TEST_API_KEY}"
startup: true
serverInstructions: "Follow these instructions for stdio server"
requiresOAuth: false
capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}'
tools: "file_read, file_write"
websocket_server:
_processed: true
type: "websocket"
url: "ws://localhost:3001/mcp"
startup: true
requiresOAuth: false
oauthMetadata: null
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
disabled_server:
_processed: true
type: "streamable-http"
url: "https://api.disabled.com/mcp"
startup: false
requiresOAuth: false
oauthMetadata: null
non_oauth_server:
_processed: true
type: "streamable-http"
url: "https://api.public.com/mcp"
requiresOAuth: false
serverInstructions: "Public API instructions"
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
oauth_startup_enabled:
_processed: true
type: "sse"
url: "https://api.oauth-startup.com/sse"
requiresOAuth: true
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""

View file

@ -0,0 +1,53 @@
# Raw MCP server configurations used as input to MCPServersRegistry constructor
# These configs test different code paths in the initialization process
# Test OAuth detection with URL - should trigger fetchOAuthRequirement
oauth_server:
type: "streamable-http"
url: "https://api.github.com/mcp"
headers:
Authorization: "Bearer {{GITHUB_TOKEN}}"
serverInstructions: true
# Test OAuth already specified - should skip OAuth detection
oauth_predefined:
type: "sse"
url: "https://api.example.com/sse"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://example.com/oauth/authorize"
token_url: "https://example.com/oauth/token"
# Test stdio server without URL - should set requiresOAuth to false
stdio_server:
command: "node"
args: ["server.js"]
env:
API_KEY: "${TEST_API_KEY}"
startup: true
serverInstructions: "Follow these instructions for stdio server"
# Test websocket server with capabilities but no tools
websocket_server:
type: "websocket"
url: "ws://localhost:3001/mcp"
startup: true
# Test server with startup disabled - should not be included in appServerConfigs
disabled_server:
type: "streamable-http"
url: "https://api.disabled.com/mcp"
startup: false
# Test non-OAuth server - should be included in appServerConfigs
non_oauth_server:
type: "streamable-http"
url: "https://api.public.com/mcp"
requiresOAuth: false
serverInstructions: true
# Test server with OAuth but startup enabled - should not be in appServerConfigs
oauth_startup_enabled:
type: "sse"
url: "https://api.oauth-startup.com/sse"
requiresOAuth: true

View file

@ -1,5 +1,5 @@
import { MCPOAuthHandler } from './handler';
import type { MCPOptions } from 'librechat-data-provider';
import { MCPOAuthHandler } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {

View file

@ -1,4 +1,4 @@
import { normalizeServerName } from './utils';
import { normalizeServerName } from '../utils';
describe('normalizeServerName', () => {
it('should not modify server names that already match the pattern', () => {

View file

@ -2,7 +2,7 @@
// zod.spec.ts
import { z } from 'zod';
import type { JsonSchemaType } from '~/types';
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from './zod';
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from '../zod';
describe('convertJsonSchemaToZod', () => {
describe('primitive types', () => {

View file

@ -1,17 +1,18 @@
import { EventEmitter } from 'events';
import { logger } from '@librechat/data-schemas';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import {
StdioClientTransport,
getDefaultEnvironment,
} from '@modelcontextprotocol/sdk/client/stdio.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { logger } from '@librechat/data-schemas';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type { MCPOAuthTokens } from './oauth/types';
import { mcpConfig } from './mcpConfig';
import type * as t from './types';
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
@ -56,9 +57,17 @@ function isStreamableHTTPOptions(options: t.MCPOptions): options is t.Streamable
}
const FIVE_MINUTES = 5 * 60 * 1000;
interface MCPConnectionParams {
serverName: string;
serverConfig: t.MCPOptions;
userId?: string;
oauthTokens?: MCPOAuthTokens | null;
}
export class MCPConnection extends EventEmitter {
private static instance: MCPConnection | null = null;
public client: Client;
private options: t.MCPOptions;
private transport: Transport | null = null; // Make this nullable
private connectionState: t.ConnectionState = 'disconnected';
private connectPromise: Promise<void> | null = null;
@ -70,26 +79,23 @@ export class MCPConnection extends EventEmitter {
private reconnectAttempts = 0;
private readonly userId?: string;
private lastPingTime: number;
private lastConnectionCheckAt: number = 0;
private oauthTokens?: MCPOAuthTokens | null;
private oauthRequired = false;
iconPath?: string;
timeout?: number;
url?: string;
constructor(
serverName: string,
private readonly options: t.MCPOptions,
userId?: string,
oauthTokens?: MCPOAuthTokens | null,
) {
constructor(params: MCPConnectionParams) {
super();
this.serverName = serverName;
this.userId = userId;
this.iconPath = options.iconPath;
this.timeout = options.timeout;
this.options = params.serverConfig;
this.serverName = params.serverName;
this.userId = params.userId;
this.iconPath = params.serverConfig.iconPath;
this.timeout = params.serverConfig.timeout;
this.lastPingTime = Date.now();
if (oauthTokens) {
this.oauthTokens = oauthTokens;
if (params.oauthTokens) {
this.oauthTokens = params.oauthTokens;
}
this.client = new Client(
{
@ -110,28 +116,6 @@ export class MCPConnection extends EventEmitter {
return `[MCP]${userPart}[${this.serverName}]`;
}
public static getInstance(
serverName: string,
options: t.MCPOptions,
userId?: string,
): MCPConnection {
if (!MCPConnection.instance) {
MCPConnection.instance = new MCPConnection(serverName, options, userId);
}
return MCPConnection.instance;
}
public static getExistingInstance(): MCPConnection | null {
return MCPConnection.instance;
}
public static async destroyInstance(): Promise<void> {
if (MCPConnection.instance) {
await MCPConnection.instance.disconnect();
MCPConnection.instance = null;
}
}
private emitError(error: unknown, errorContext: string): void {
const errorMessage = error instanceof Error ? error.message : String(error);
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
@ -589,6 +573,13 @@ export class MCPConnection extends EventEmitter {
return false;
}
// If we recently checked, skip expensive verification
const now = Date.now();
if (now - this.lastConnectionCheckAt < mcpConfig.CONNECTION_CHECK_TTL) {
return true;
}
this.lastConnectionCheckAt = now;
try {
// Try ping first as it's the lightest check
await this.client.ping();

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,11 @@
import { math, isEnabled } from '~/utils';
/**
* Centralized configuration for MCP-related environment variables.
* Provides typed access to MCP settings with default values.
*/
export const mcpConfig = {
OAUTH_ON_AUTH_ERROR: isEnabled(process.env.MCP_OAUTH_ON_AUTH_ERROR ?? true),
OAUTH_DETECTION_TIMEOUT: math(process.env.MCP_OAUTH_DETECTION_TIMEOUT ?? 5000),
CONNECTION_CHECK_TTL: math(process.env.MCP_CONNECTION_CHECK_TTL ?? 60000),
};

View file

@ -0,0 +1,120 @@
// ATTENTION: If you modify OAuth detection logic in this file, run the integration tests to verify:
// npx jest --testMatch="**/detectOAuth.integration.dev.ts" (from packages/api directory)
//
// These tests are excluded from CI because they make live HTTP requests to external services,
// which could cause flaky builds due to network issues or changes in third-party endpoints.
// Manual testing ensures the OAuth detection still works against real MCP servers.
import { discoverOAuthProtectedResourceMetadata } from '@modelcontextprotocol/sdk/client/auth.js';
import { mcpConfig } from '../mcpConfig';
export interface OAuthDetectionResult {
requiresOAuth: boolean;
method: 'protected-resource-metadata' | '401-challenge-metadata' | 'no-metadata-found';
metadata?: Record<string, unknown> | null;
}
/**
* Detects if an MCP server requires OAuth authentication using proactive discovery methods.
*
* This function implements a comprehensive OAuth detection strategy:
* 1. Standard Protected Resource Metadata (RFC 9728) - checks /.well-known/oauth-protected-resource
* 2. 401 Challenge Method - checks WWW-Authenticate header for resource_metadata URL
* 3. Optional fallback: treat any 401/403 response as OAuth requirement (if MCP_OAUTH_ON_AUTH_ERROR=true)
*
* @param serverUrl - The MCP server URL to check for OAuth requirements
* @returns Promise<OAuthDetectionResult> - OAuth requirement details
*/
export async function detectOAuthRequirement(serverUrl: string): Promise<OAuthDetectionResult> {
const protectedResourceResult = await checkProtectedResourceMetadata(serverUrl);
if (protectedResourceResult) return protectedResourceResult;
const challengeResult = await check401ChallengeMetadata(serverUrl);
if (challengeResult) return challengeResult;
const fallbackResult = await checkAuthErrorFallback(serverUrl);
if (fallbackResult) return fallbackResult;
// No OAuth detected
return {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// ------------------------ Private helper functions for OAuth detection -------------------------//
////////////////////////////////////////////////////////////////////////////////////////////////////
// Checks for OAuth using standard protected resource metadata (RFC 9728)
async function checkProtectedResourceMetadata(
serverUrl: string,
): Promise<OAuthDetectionResult | null> {
try {
const resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl);
if (!resourceMetadata?.authorization_servers?.length) return null;
return {
requiresOAuth: true,
method: 'protected-resource-metadata',
metadata: resourceMetadata,
};
} catch {
return null;
}
}
// Checks for OAuth using 401 challenge with resource metadata URL
async function check401ChallengeMetadata(serverUrl: string): Promise<OAuthDetectionResult | null> {
try {
const response = await fetch(serverUrl, {
method: 'HEAD',
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
});
if (response.status !== 401) return null;
const wwwAuth = response.headers.get('www-authenticate');
const metadataUrl = wwwAuth?.match(/resource_metadata="([^"]+)"/)?.[1];
if (!metadataUrl) return null;
const metadataResponse = await fetch(metadataUrl, {
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
});
const metadata = await metadataResponse.json();
if (!metadata?.authorization_servers?.length) return null;
return {
requiresOAuth: true,
method: '401-challenge-metadata',
metadata,
};
} catch {
return null;
}
}
// Fallback method: treats any auth error as OAuth requirement if configured
async function checkAuthErrorFallback(serverUrl: string): Promise<OAuthDetectionResult | null> {
try {
if (!mcpConfig.OAUTH_ON_AUTH_ERROR) return null;
const response = await fetch(serverUrl, {
method: 'HEAD',
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
});
if (response.status !== 401 && response.status !== 403) return null;
return {
requiresOAuth: true,
method: 'no-metadata-found',
metadata: null,
};
} catch {
return null;
}
}

View file

@ -1,3 +1,4 @@
export * from './types';
export * from './handler';
export * from './tokens';
export * from './detectOAuth';

View file

@ -3,6 +3,13 @@ import { TokenExchangeMethodEnum } from './types/agents';
import { extractEnvVariable } from './utils';
const BaseOptionsSchema = z.object({
/**
* Controls whether the MCP server is initialized during application startup.
* - true (default): Server is initialized during app startup and included in app-level connections
* - false: Skips initialization at startup and excludes from app-level connections - useful for servers
* requiring manual authentication (e.g., GitHub PAT tokens) that need to be configured through the UI after startup
*/
startup: z.boolean().optional(),
iconPath: z.string().optional(),
timeout: z.number().optional(),
initTimeout: z.number().optional(),
@ -15,6 +22,11 @@ const BaseOptionsSchema = z.object({
* - string: Use custom instructions (overrides server-provided)
*/
serverInstructions: z.union([z.boolean(), z.string()]).optional(),
/**
* Whether this server requires OAuth authentication
* If not specified, will be auto-detected during construction
*/
requiresOAuth: z.boolean().optional(),
/**
* OAuth configuration for SSE and Streamable HTTP transports
* - Optional: OAuth can be auto-discovered on 401 responses