mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
♻️ refactor: MCPManager for Scalability, Fix App-Level Detection, Add Lazy Connections (#8930)
* feat: MCP Connection management overhaul - Making MCPManager manageable Refactor the monolithic MCPManager into focused, single-responsibility classes: • MCPServersRegistry: Server configuration discovery and metadata management • UserConnectionManager: Manages user-level connections • ConnectionsRepository: Low-level connection pool with lazy loading • MCPConnectionFactory: Handles MCP connection creation with OAuth support New Features: • Lazy loading of app-level connections for horizontal scaling • Automatic reconnection for app-level connections • Enhanced OAuth detection with explicit requiresOAuth flag • Centralized MCP configuration management Bug Fixes: • App-level connection detection in MCPManager.callTool • MCP Connection Reinitialization route behavior Optimizations: • MCPConnection.isConnected() caching to reduce overhead • Concurrent server metadata retrieval instead of sequential This refactoring addresses scalability bottlenecks and improves reliability while maintaining backward compatibility with existing configurations. * feat: Enabled import order in eslint. * # Moved tests to __tests__ folder # added tests for MCPServersRegistry.ts * # Add unit tests for ConnectionsRepository functionality * # Add unit tests for MCPConnectionFactory functionality * # Reorganize MCP connection tests and improve error handling * # reordering imports * # Update testPathIgnorePatterns in jest.config.mjs to exclude development TypeScript files * # removed mcp/manager.ts
This commit is contained in:
parent
9dbf153489
commit
8780a78165
32 changed files with 2571 additions and 1468 deletions
13
.env.example
13
.env.example
|
@ -698,3 +698,16 @@ OPENWEATHER_API_KEY=
|
||||||
# JINA_API_KEY=your_jina_api_key
|
# JINA_API_KEY=your_jina_api_key
|
||||||
# or
|
# or
|
||||||
# COHERE_API_KEY=your_cohere_api_key
|
# COHERE_API_KEY=your_cohere_api_key
|
||||||
|
|
||||||
|
#======================#
|
||||||
|
# MCP Configuration #
|
||||||
|
#======================#
|
||||||
|
|
||||||
|
# Treat 401/403 responses as OAuth requirement when no oauth metadata found
|
||||||
|
# MCP_OAUTH_ON_AUTH_ERROR=true
|
||||||
|
|
||||||
|
# Timeout for OAuth detection requests in milliseconds
|
||||||
|
# MCP_OAUTH_DETECTION_TIMEOUT=5000
|
||||||
|
|
||||||
|
# Cache connection status checks for this many milliseconds to avoid expensive verification
|
||||||
|
# MCP_CONNECTION_CHECK_TTL=60000
|
||||||
|
|
4
.github/CONTRIBUTING.md
vendored
4
.github/CONTRIBUTING.md
vendored
|
@ -147,7 +147,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||||
## 8. Module Import Conventions
|
## 8. Module Import Conventions
|
||||||
|
|
||||||
- `npm` packages first,
|
- `npm` packages first,
|
||||||
- from shortest line (top) to longest (bottom)
|
- from longest line (top) to shortest (bottom)
|
||||||
|
|
||||||
- Followed by typescript types (pertains to data-provider and client workspaces)
|
- Followed by typescript types (pertains to data-provider and client workspaces)
|
||||||
- longest line (top) to shortest (bottom)
|
- longest line (top) to shortest (bottom)
|
||||||
|
@ -157,6 +157,8 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||||
- longest line (top) to shortest (bottom)
|
- longest line (top) to shortest (bottom)
|
||||||
- imports with alias `~` treated the same as relative import with respect to line length
|
- imports with alias `~` treated the same as relative import with respect to line length
|
||||||
|
|
||||||
|
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
Please ensure that you adapt this summary to fit the specific context and nuances of your project.
|
Please ensure that you adapt this summary to fit the specific context and nuances of your project.
|
||||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -137,3 +137,4 @@ helm/**/.values.yaml
|
||||||
/.openai/
|
/.openai/
|
||||||
/.tabnine/
|
/.tabnine/
|
||||||
/.codeium
|
/.codeium
|
||||||
|
*.local.md
|
||||||
|
|
|
@ -1,27 +1,13 @@
|
||||||
|
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||||
const { EventSource } = require('eventsource');
|
const { EventSource } = require('eventsource');
|
||||||
const { Time } = require('librechat-data-provider');
|
const { Time } = require('librechat-data-provider');
|
||||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
|
||||||
const logger = require('./winston');
|
const logger = require('./winston');
|
||||||
|
|
||||||
global.EventSource = EventSource;
|
global.EventSource = EventSource;
|
||||||
|
|
||||||
/** @type {MCPManager} */
|
/** @type {MCPManager} */
|
||||||
let mcpManager = null;
|
|
||||||
let flowManager = null;
|
let flowManager = null;
|
||||||
|
|
||||||
/**
|
|
||||||
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
|
|
||||||
* @returns {MCPManager}
|
|
||||||
*/
|
|
||||||
function getMCPManager(userId) {
|
|
||||||
if (!mcpManager) {
|
|
||||||
mcpManager = MCPManager.getInstance();
|
|
||||||
} else {
|
|
||||||
mcpManager.checkIdleConnections(userId);
|
|
||||||
}
|
|
||||||
return mcpManager;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {Keyv} flowsCache
|
* @param {Keyv} flowsCache
|
||||||
* @returns {FlowStateManager}
|
* @returns {FlowStateManager}
|
||||||
|
@ -37,6 +23,7 @@ function getFlowStateManager(flowsCache) {
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
logger,
|
logger,
|
||||||
getMCPManager,
|
createMCPManager: MCPManager.createInstance,
|
||||||
|
getMCPManager: MCPManager.getInstance,
|
||||||
getFlowStateManager,
|
getFlowStateManager,
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
const express = require('express');
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
const request = require('supertest');
|
const request = require('supertest');
|
||||||
const mongoose = require('mongoose');
|
const mongoose = require('mongoose');
|
||||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
const express = require('express');
|
||||||
|
|
||||||
jest.mock('@librechat/api', () => ({
|
jest.mock('@librechat/api', () => ({
|
||||||
MCPOAuthHandler: {
|
MCPOAuthHandler: {
|
||||||
|
@ -494,12 +494,9 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return 500 when token retrieval throws an unexpected error', async () => {
|
it('should return 500 when token retrieval throws an unexpected error', async () => {
|
||||||
const mockFlowManager = {
|
getLogStores.mockImplementation(() => {
|
||||||
getFlowState: jest.fn().mockRejectedValue(new Error('Database connection failed')),
|
throw new Error('Database connection failed');
|
||||||
};
|
});
|
||||||
|
|
||||||
getLogStores.mockReturnValue({});
|
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow');
|
const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow');
|
||||||
|
|
||||||
|
@ -563,8 +560,8 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('POST /oauth/cancel/:serverName', () => {
|
describe('POST /oauth/cancel/:serverName', () => {
|
||||||
const { getLogStores } = require('~/cache');
|
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
const { MCPOAuthHandler } = require('@librechat/api');
|
||||||
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
it('should cancel OAuth flow successfully', async () => {
|
it('should cancel OAuth flow successfully', async () => {
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
|
@ -644,15 +641,15 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('POST /:serverName/reinitialize', () => {
|
describe('POST /:serverName/reinitialize', () => {
|
||||||
const { loadCustomConfig } = require('~/server/services/Config');
|
|
||||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
|
||||||
|
|
||||||
it('should return 404 when server is not found in configuration', async () => {
|
it('should return 404 when server is not found in configuration', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
const mockMcpManager = {
|
||||||
mcpServers: {
|
getRawConfig: jest.fn().mockReturnValue(null),
|
||||||
'other-server': {},
|
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||||
},
|
};
|
||||||
});
|
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||||
|
require('~/cache').getLogStores.mockReturnValue({});
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/non-existent-server/reinitialize');
|
const response = await request(app).post('/api/mcp/non-existent-server/reinitialize');
|
||||||
|
|
||||||
|
@ -663,16 +660,11 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle OAuth requirement during reinitialize', async () => {
|
it('should handle OAuth requirement during reinitialize', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
|
||||||
mcpServers: {
|
|
||||||
'oauth-server': {
|
|
||||||
customUserVars: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
disconnectServer: jest.fn().mockResolvedValue(),
|
getRawConfig: jest.fn().mockReturnValue({
|
||||||
|
customUserVars: {},
|
||||||
|
}),
|
||||||
|
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||||
mcpConfigs: {},
|
mcpConfigs: {},
|
||||||
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
|
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
|
||||||
if (oauthStart) {
|
if (oauthStart) {
|
||||||
|
@ -690,7 +682,7 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body).toEqual({
|
expect(response.body).toEqual({
|
||||||
success: 'https://oauth.example.com/auth',
|
success: true,
|
||||||
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
||||||
serverName: 'oauth-server',
|
serverName: 'oauth-server',
|
||||||
oauthRequired: true,
|
oauthRequired: true,
|
||||||
|
@ -699,14 +691,9 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
|
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
|
||||||
mcpServers: {
|
|
||||||
'error-server': {},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
disconnectServer: jest.fn().mockResolvedValue(),
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
|
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||||
mcpConfigs: {},
|
mcpConfigs: {},
|
||||||
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
|
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
|
||||||
};
|
};
|
||||||
|
@ -724,7 +711,13 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return 500 when unexpected error occurs', async () => {
|
it('should return 500 when unexpected error occurs', async () => {
|
||||||
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
|
const mockMcpManager = {
|
||||||
|
getRawConfig: jest.fn().mockImplementation(() => {
|
||||||
|
throw new Error('Config loading failed');
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||||
|
|
||||||
|
@ -747,29 +740,17 @@ describe('MCP Routes', () => {
|
||||||
expect(response.body).toEqual({ error: 'User not authenticated' });
|
expect(response.body).toEqual({ error: 'User not authenticated' });
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle errors when fetching custom user variables', async () => {
|
it('should successfully reinitialize server and cache tools', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
|
||||||
mcpServers: {
|
|
||||||
'test-server': {
|
|
||||||
customUserVars: {
|
|
||||||
API_KEY: 'test-key-var',
|
|
||||||
SECRET_TOKEN: 'test-secret-var',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
getUserPluginAuthValue
|
|
||||||
.mockResolvedValueOnce('test-api-key-value')
|
|
||||||
.mockRejectedValueOnce(new Error('Database error'));
|
|
||||||
|
|
||||||
const mockUserConnection = {
|
const mockUserConnection = {
|
||||||
fetchTools: jest.fn().mockResolvedValue([]),
|
fetchTools: jest.fn().mockResolvedValue([
|
||||||
|
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
|
||||||
|
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
|
||||||
|
]),
|
||||||
};
|
};
|
||||||
|
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
disconnectServer: jest.fn().mockResolvedValue(),
|
getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
|
||||||
mcpConfigs: {},
|
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -784,38 +765,54 @@ describe('MCP Routes', () => {
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body.success).toBe(true);
|
expect(response.body).toEqual({
|
||||||
|
success: true,
|
||||||
|
message: "MCP server 'test-server' reinitialized successfully",
|
||||||
|
serverName: 'test-server',
|
||||||
|
oauthRequired: false,
|
||||||
|
oauthUrl: null,
|
||||||
|
});
|
||||||
|
expect(mockMcpManager.disconnectUserConnection).toHaveBeenCalledWith(
|
||||||
|
'test-user-id',
|
||||||
|
'test-server',
|
||||||
|
);
|
||||||
|
expect(setCachedTools).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return failure message when reinitialize completely fails', async () => {
|
it('should handle server with custom user variables', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
const mockUserConnection = {
|
||||||
mcpServers: {
|
fetchTools: jest.fn().mockResolvedValue([]),
|
||||||
'test-server': {},
|
};
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
disconnectServer: jest.fn().mockResolvedValue(),
|
getRawConfig: jest.fn().mockReturnValue({
|
||||||
mcpConfigs: {},
|
endpoint: 'http://test-server.com',
|
||||||
getUserConnection: jest.fn().mockResolvedValue(null),
|
customUserVars: {
|
||||||
|
API_KEY: 'some-env-var',
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||||
|
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||||
};
|
};
|
||||||
|
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||||
require('~/cache').getLogStores.mockReturnValue({});
|
require('~/cache').getLogStores.mockReturnValue({});
|
||||||
|
require('~/server/services/PluginService').getUserPluginAuthValue.mockResolvedValue(
|
||||||
|
'api-key-value',
|
||||||
|
);
|
||||||
|
|
||||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||||
const { Constants } = require('librechat-data-provider');
|
getCachedTools.mockResolvedValue({});
|
||||||
getCachedTools.mockResolvedValue({
|
|
||||||
[`existing-tool${Constants.mcp_delimiter}test-server`]: { type: 'function' },
|
|
||||||
});
|
|
||||||
setCachedTools.mockResolvedValue();
|
setCachedTools.mockResolvedValue();
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body.success).toBe(false);
|
expect(response.body.success).toBe(true);
|
||||||
expect(response.body.message).toBe("Failed to reinitialize MCP server 'test-server'");
|
expect(
|
||||||
|
require('~/server/services/PluginService').getUserPluginAuthValue,
|
||||||
|
).toHaveBeenCalledWith('test-user-id', 'API_KEY', false);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -984,21 +981,19 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('GET /:serverName/auth-values', () => {
|
describe('GET /:serverName/auth-values', () => {
|
||||||
const { loadCustomConfig } = require('~/server/services/Config');
|
|
||||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||||
|
|
||||||
it('should return auth value flags for server', async () => {
|
it('should return auth value flags for server', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
const mockMcpManager = {
|
||||||
mcpServers: {
|
getRawConfig: jest.fn().mockReturnValue({
|
||||||
'test-server': {
|
customUserVars: {
|
||||||
customUserVars: {
|
API_KEY: 'some-env-var',
|
||||||
API_KEY: 'some-env-var',
|
SECRET_TOKEN: 'another-env-var',
|
||||||
SECRET_TOKEN: 'another-env-var',
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
});
|
};
|
||||||
|
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
|
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||||
|
@ -1017,11 +1012,11 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return 404 when server is not found in configuration', async () => {
|
it('should return 404 when server is not found in configuration', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
const mockMcpManager = {
|
||||||
mcpServers: {
|
getRawConfig: jest.fn().mockReturnValue(null),
|
||||||
'other-server': {},
|
};
|
||||||
},
|
|
||||||
});
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
|
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
|
||||||
|
|
||||||
|
@ -1032,16 +1027,15 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle errors when checking auth values', async () => {
|
it('should handle errors when checking auth values', async () => {
|
||||||
loadCustomConfig.mockResolvedValue({
|
const mockMcpManager = {
|
||||||
mcpServers: {
|
getRawConfig: jest.fn().mockReturnValue({
|
||||||
'test-server': {
|
customUserVars: {
|
||||||
customUserVars: {
|
API_KEY: 'some-env-var',
|
||||||
API_KEY: 'some-env-var',
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
});
|
};
|
||||||
|
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
|
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||||
|
@ -1057,7 +1051,13 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return 500 when auth values check throws unexpected error', async () => {
|
it('should return 500 when auth values check throws unexpected error', async () => {
|
||||||
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
|
const mockMcpManager = {
|
||||||
|
getRawConfig: jest.fn().mockImplementation(() => {
|
||||||
|
throw new Error('Config loading failed');
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||||
|
|
||||||
|
@ -1066,14 +1066,13 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle customUserVars that is not an object', async () => {
|
it('should handle customUserVars that is not an object', async () => {
|
||||||
const { loadCustomConfig } = require('~/server/services/Config');
|
const mockMcpManager = {
|
||||||
loadCustomConfig.mockResolvedValue({
|
getRawConfig: jest.fn().mockReturnValue({
|
||||||
mcpServers: {
|
customUserVars: 'not-an-object',
|
||||||
'test-server': {
|
}),
|
||||||
customUserVars: 'not-an-object',
|
};
|
||||||
},
|
|
||||||
},
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
});
|
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||||
|
|
||||||
|
@ -1097,98 +1096,6 @@ describe('MCP Routes', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('POST /:serverName/reinitialize - Tool Deletion Coverage', () => {
|
|
||||||
it('should handle null cached tools during reinitialize (triggers || {} fallback)', async () => {
|
|
||||||
const { loadCustomConfig, getCachedTools } = require('~/server/services/Config');
|
|
||||||
|
|
||||||
const mockUserConnection = {
|
|
||||||
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
|
|
||||||
};
|
|
||||||
|
|
||||||
const mockMcpManager = {
|
|
||||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
|
||||||
disconnectServer: jest.fn(),
|
|
||||||
initializeServer: jest.fn(),
|
|
||||||
mcpConfigs: {},
|
|
||||||
};
|
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
|
||||||
|
|
||||||
loadCustomConfig.mockResolvedValue({
|
|
||||||
mcpServers: {
|
|
||||||
'test-server': { env: { API_KEY: 'test-key' } },
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
getCachedTools.mockResolvedValue(null);
|
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
|
|
||||||
|
|
||||||
expect(response.body).toEqual({
|
|
||||||
message: "MCP server 'test-server' reinitialized successfully",
|
|
||||||
success: true,
|
|
||||||
oauthRequired: false,
|
|
||||||
oauthUrl: null,
|
|
||||||
serverName: 'test-server',
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should delete existing cached tools during successful reinitialize', async () => {
|
|
||||||
const {
|
|
||||||
loadCustomConfig,
|
|
||||||
getCachedTools,
|
|
||||||
setCachedTools,
|
|
||||||
} = require('~/server/services/Config');
|
|
||||||
|
|
||||||
const mockUserConnection = {
|
|
||||||
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
|
|
||||||
};
|
|
||||||
|
|
||||||
const mockMcpManager = {
|
|
||||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
|
||||||
disconnectServer: jest.fn(),
|
|
||||||
initializeServer: jest.fn(),
|
|
||||||
mcpConfigs: {},
|
|
||||||
};
|
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
|
||||||
|
|
||||||
loadCustomConfig.mockResolvedValue({
|
|
||||||
mcpServers: {
|
|
||||||
'test-server': { env: { API_KEY: 'test-key' } },
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const existingTools = {
|
|
||||||
'old-tool_mcp_test-server': { type: 'function' },
|
|
||||||
'other-tool_mcp_other-server': { type: 'function' },
|
|
||||||
};
|
|
||||||
getCachedTools.mockResolvedValue(existingTools);
|
|
||||||
|
|
||||||
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
|
|
||||||
|
|
||||||
expect(response.body).toEqual({
|
|
||||||
message: "MCP server 'test-server' reinitialized successfully",
|
|
||||||
success: true,
|
|
||||||
oauthRequired: false,
|
|
||||||
oauthUrl: null,
|
|
||||||
serverName: 'test-server',
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(setCachedTools).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({
|
|
||||||
'new-tool_mcp_test-server': expect.any(Object),
|
|
||||||
'other-tool_mcp_other-server': { type: 'function' },
|
|
||||||
}),
|
|
||||||
{ userId: 'test-user-id' },
|
|
||||||
);
|
|
||||||
expect(setCachedTools).toHaveBeenCalledWith(
|
|
||||||
expect.not.objectContaining({
|
|
||||||
'old-tool_mcp_test-server': expect.anything(),
|
|
||||||
}),
|
|
||||||
{ userId: 'test-user-id' },
|
|
||||||
);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
||||||
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
const { MCPOAuthHandler } = require('@librechat/api');
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
const { Router } = require('express');
|
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { MCPOAuthHandler } = require('@librechat/api');
|
const { MCPOAuthHandler } = require('@librechat/api');
|
||||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
const { Router } = require('express');
|
||||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
|
||||||
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
|
|
||||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||||
|
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||||
|
const { setCachedTools, getCachedTools } = require('~/server/services/Config');
|
||||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||||
|
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||||
const { requireJwtAuth } = require('~/server/middleware');
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
@ -315,9 +315,9 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
|
|
||||||
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
||||||
|
|
||||||
const printConfig = false;
|
const mcpManager = getMCPManager();
|
||||||
const config = await loadCustomConfig(printConfig);
|
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
if (!serverConfig) {
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: `MCP server '${serverName}' not found in configuration`,
|
error: `MCP server '${serverName}' not found in configuration`,
|
||||||
});
|
});
|
||||||
|
@ -325,13 +325,12 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
|
|
||||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||||
const flowManager = getFlowStateManager(flowsCache);
|
const flowManager = getFlowStateManager(flowsCache);
|
||||||
const mcpManager = getMCPManager();
|
|
||||||
|
|
||||||
await mcpManager.disconnectServer(serverName);
|
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||||
logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`);
|
logger.info(
|
||||||
|
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
||||||
|
);
|
||||||
|
|
||||||
const serverConfig = config.mcpServers[serverName];
|
|
||||||
mcpManager.mcpConfigs[serverName] = serverConfig;
|
|
||||||
let customUserVars = {};
|
let customUserVars = {};
|
||||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
||||||
|
@ -437,7 +436,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
res.json({
|
res.json({
|
||||||
success: (userConnection && !oauthRequired) || (oauthRequired && oauthUrl),
|
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
|
||||||
message: getResponseMessage(),
|
message: getResponseMessage(),
|
||||||
serverName,
|
serverName,
|
||||||
oauthRequired,
|
oauthRequired,
|
||||||
|
@ -551,15 +550,14 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||||
return res.status(401).json({ error: 'User not authenticated' });
|
return res.status(401).json({ error: 'User not authenticated' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const printConfig = false;
|
const mcpManager = getMCPManager();
|
||||||
const config = await loadCustomConfig(printConfig);
|
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
if (!serverConfig) {
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: `MCP server '${serverName}' not found in configuration`,
|
error: `MCP server '${serverName}' not found in configuration`,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const serverConfig = config.mcpServers[serverName];
|
|
||||||
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
|
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
|
||||||
const authValueFlags = {};
|
const authValueFlags = {};
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { CacheKeys } = require('librechat-data-provider');
|
|
||||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
|
||||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
|
||||||
const { getCachedTools, setCachedTools } = require('./Config');
|
const { getCachedTools, setCachedTools } = require('./Config');
|
||||||
|
const { CacheKeys } = require('librechat-data-provider');
|
||||||
|
const { createMCPManager } = require('~/config');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -31,33 +30,19 @@ async function initializeMCPs(app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('Initializing MCP servers...');
|
logger.info('Initializing MCP servers...');
|
||||||
const mcpManager = getMCPManager();
|
const mcpManager = await createMCPManager(mcpServers);
|
||||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
|
||||||
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await mcpManager.initializeMCPs({
|
|
||||||
mcpServers: filteredServers,
|
|
||||||
flowManager,
|
|
||||||
tokenMethods: {
|
|
||||||
findToken,
|
|
||||||
updateToken,
|
|
||||||
createToken,
|
|
||||||
deleteTokens,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
delete app.locals.mcpConfig;
|
delete app.locals.mcpConfig;
|
||||||
const availableTools = await getCachedTools();
|
const cachedTools = await getCachedTools();
|
||||||
|
|
||||||
if (!availableTools) {
|
if (!cachedTools) {
|
||||||
logger.warn('No available tools found in cache during MCP initialization');
|
logger.warn('No available tools found in cache during MCP initialization');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const toolsCopy = { ...availableTools };
|
const mcpTools = mcpManager.getAppToolFunctions();
|
||||||
await mcpManager.mapAvailableTools(toolsCopy, flowManager);
|
await setCachedTools({ ...cachedTools, ...mcpTools }, { isGlobal: true });
|
||||||
await setCachedTools(toolsCopy, { isGlobal: true });
|
|
||||||
|
|
||||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
await cache.delete(CacheKeys.TOOLS);
|
await cache.delete(CacheKeys.TOOLS);
|
||||||
|
|
|
@ -2,11 +2,11 @@ import { fileURLToPath } from 'node:url';
|
||||||
import path from 'node:path';
|
import path from 'node:path';
|
||||||
import typescriptEslintEslintPlugin from '@typescript-eslint/eslint-plugin';
|
import typescriptEslintEslintPlugin from '@typescript-eslint/eslint-plugin';
|
||||||
import { fixupConfigRules, fixupPluginRules } from '@eslint/compat';
|
import { fixupConfigRules, fixupPluginRules } from '@eslint/compat';
|
||||||
// import perfectionist from 'eslint-plugin-perfectionist';
|
import perfectionist from 'eslint-plugin-perfectionist';
|
||||||
import reactHooks from 'eslint-plugin-react-hooks';
|
import reactHooks from 'eslint-plugin-react-hooks';
|
||||||
import prettier from 'eslint-plugin-prettier';
|
|
||||||
import tsParser from '@typescript-eslint/parser';
|
import tsParser from '@typescript-eslint/parser';
|
||||||
import importPlugin from 'eslint-plugin-import';
|
import importPlugin from 'eslint-plugin-import';
|
||||||
|
import prettier from 'eslint-plugin-prettier';
|
||||||
import { FlatCompat } from '@eslint/eslintrc';
|
import { FlatCompat } from '@eslint/eslintrc';
|
||||||
import jsxA11Y from 'eslint-plugin-jsx-a11y';
|
import jsxA11Y from 'eslint-plugin-jsx-a11y';
|
||||||
import i18next from 'eslint-plugin-i18next';
|
import i18next from 'eslint-plugin-i18next';
|
||||||
|
@ -62,7 +62,7 @@ export default [
|
||||||
'jsx-a11y': fixupPluginRules(jsxA11Y),
|
'jsx-a11y': fixupPluginRules(jsxA11Y),
|
||||||
'import/parsers': tsParser,
|
'import/parsers': tsParser,
|
||||||
i18next,
|
i18next,
|
||||||
// perfectionist,
|
perfectionist,
|
||||||
prettier: fixupPluginRules(prettier),
|
prettier: fixupPluginRules(prettier),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -140,32 +140,31 @@ export default [
|
||||||
'react/prop-types': 'off',
|
'react/prop-types': 'off',
|
||||||
'react/display-name': 'off',
|
'react/display-name': 'off',
|
||||||
|
|
||||||
// 'perfectionist/sort-imports': [
|
'perfectionist/sort-imports': [
|
||||||
// 'error',
|
'error',
|
||||||
// {
|
{
|
||||||
// type: 'line-length',
|
type: 'line-length',
|
||||||
// order: 'desc',
|
order: 'desc',
|
||||||
// newlinesBetween: 'never',
|
newlinesBetween: 'never',
|
||||||
// customGroups: {
|
customGroups: {
|
||||||
// value: {
|
value: {
|
||||||
// react: ['^react$'],
|
react: ['^react$'],
|
||||||
// // react: ['^react$', '^fs', '^zod', '^path'],
|
local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'],
|
||||||
// local: ['^(\\.{1,2}|~)/', '^librechat-data-provider'],
|
},
|
||||||
// },
|
},
|
||||||
// },
|
groups: [
|
||||||
// groups: [
|
'react',
|
||||||
// 'react',
|
'builtin',
|
||||||
// 'builtin',
|
'external',
|
||||||
// 'external',
|
['builtin-type', 'external-type'],
|
||||||
// ['builtin-type', 'external-type'],
|
['internal-type'],
|
||||||
// ['internal-type'],
|
'local',
|
||||||
// 'local',
|
['parent', 'sibling', 'index'],
|
||||||
// ['parent', 'sibling', 'index'],
|
'object',
|
||||||
// 'object',
|
'unknown',
|
||||||
// 'unknown',
|
],
|
||||||
// ],
|
},
|
||||||
// },
|
],
|
||||||
// ],
|
|
||||||
|
|
||||||
// 'perfectionist/sort-named-imports': [
|
// 'perfectionist/sort-named-imports': [
|
||||||
// 'error',
|
// 'error',
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
export default {
|
export default {
|
||||||
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
|
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
|
||||||
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
|
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
|
||||||
testPathIgnorePatterns: ['/node_modules/', '/dist/'],
|
testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'],
|
||||||
coverageReporters: ['text', 'cobertura'],
|
coverageReporters: ['text', 'cobertura'],
|
||||||
testResultsProcessor: 'jest-junit',
|
testResultsProcessor: 'jest-junit',
|
||||||
moduleNameMapper: {
|
moduleNameMapper: {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/* MCP */
|
/* MCP */
|
||||||
export * from './mcp/manager';
|
export * from './mcp/MCPManager';
|
||||||
export * from './mcp/oauth';
|
export * from './mcp/oauth';
|
||||||
export * from './mcp/auth';
|
export * from './mcp/auth';
|
||||||
export * from './mcp/zod';
|
export * from './mcp/zod';
|
||||||
|
|
87
packages/api/src/mcp/ConnectionsRepository.ts
Normal file
87
packages/api/src/mcp/ConnectionsRepository.ts
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { MCPConnectionFactory, OAuthConnectionOptions } from '~/mcp/MCPConnectionFactory';
|
||||||
|
import { MCPConnection } from './connection';
|
||||||
|
import type * as t from './types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Manages MCP connections with lazy loading and reconnection.
|
||||||
|
* Maintains a pool of connections and handles connection lifecycle management.
|
||||||
|
*/
|
||||||
|
export class ConnectionsRepository {
|
||||||
|
protected readonly serverConfigs: Record<string, t.MCPOptions>;
|
||||||
|
protected connections: Map<string, MCPConnection> = new Map();
|
||||||
|
protected oauthOpts: OAuthConnectionOptions | undefined;
|
||||||
|
|
||||||
|
constructor(serverConfigs: t.MCPServers, oauthOpts?: OAuthConnectionOptions) {
|
||||||
|
this.serverConfigs = serverConfigs;
|
||||||
|
this.oauthOpts = oauthOpts;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Checks whether this repository can connect to a specific server */
|
||||||
|
has(serverName: string): boolean {
|
||||||
|
return !!this.serverConfigs[serverName];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets or creates a connection for the specified server with lazy loading */
|
||||||
|
async get(serverName: string): Promise<MCPConnection> {
|
||||||
|
const existingConnection = this.connections.get(serverName);
|
||||||
|
if (existingConnection && (await existingConnection.isConnected())) return existingConnection;
|
||||||
|
else await this.disconnect(serverName);
|
||||||
|
|
||||||
|
const connection = await MCPConnectionFactory.create(
|
||||||
|
{
|
||||||
|
serverName,
|
||||||
|
serverConfig: this.getServerConfig(serverName),
|
||||||
|
},
|
||||||
|
this.oauthOpts,
|
||||||
|
);
|
||||||
|
|
||||||
|
this.connections.set(serverName, connection);
|
||||||
|
return connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets or creates connections for multiple servers concurrently */
|
||||||
|
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
|
||||||
|
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
|
||||||
|
const connections = await Promise.all(connectionPromises);
|
||||||
|
return new Map(connections as [string, MCPConnection][]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns all currently loaded connections without creating new ones */
|
||||||
|
async getLoaded(): Promise<Map<string, MCPConnection>> {
|
||||||
|
return this.getMany(Array.from(this.connections.keys()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets or creates connections for all configured servers */
|
||||||
|
async getAll(): Promise<Map<string, MCPConnection>> {
|
||||||
|
return this.getMany(Object.keys(this.serverConfigs));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Disconnects and removes a specific server connection from the pool */
|
||||||
|
disconnect(serverName: string): Promise<void> {
|
||||||
|
const connection = this.connections.get(serverName);
|
||||||
|
if (!connection) return Promise.resolve();
|
||||||
|
this.connections.delete(serverName);
|
||||||
|
return connection.disconnect().catch((err) => {
|
||||||
|
logger.error(`${this.prefix(serverName)} Error disconnecting`, err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Disconnects all active connections and returns array of disconnect promises */
|
||||||
|
disconnectAll(): Promise<void>[] {
|
||||||
|
const serverNames = Array.from(this.connections.keys());
|
||||||
|
return serverNames.map((serverName) => this.disconnect(serverName));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves server configuration by name or throws if not found
|
||||||
|
protected getServerConfig(serverName: string): t.MCPOptions {
|
||||||
|
const serverConfig = this.serverConfigs[serverName];
|
||||||
|
if (serverConfig) return serverConfig;
|
||||||
|
throw new Error(`${this.prefix(serverName)} Server not found in configuration`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns formatted log prefix for server messages
|
||||||
|
protected prefix(serverName: string): string {
|
||||||
|
return `[MCP][${serverName}]`;
|
||||||
|
}
|
||||||
|
}
|
384
packages/api/src/mcp/MCPConnectionFactory.ts
Normal file
384
packages/api/src/mcp/MCPConnectionFactory.ts
Normal file
|
@ -0,0 +1,384 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||||
|
import type { TokenMethods } from '@librechat/data-schemas';
|
||||||
|
import type { TUser } from 'librechat-data-provider';
|
||||||
|
import type { MCPOAuthTokens, MCPOAuthFlowMetadata } from '~/mcp/oauth';
|
||||||
|
import type { FlowStateManager } from '~/flow/manager';
|
||||||
|
import type { FlowMetadata } from '~/flow/types';
|
||||||
|
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||||
|
import { MCPConnection } from './connection';
|
||||||
|
import { processMCPEnv } from '~/utils';
|
||||||
|
import type * as t from './types';
|
||||||
|
|
||||||
|
export interface BasicConnectionOptions {
|
||||||
|
serverName: string;
|
||||||
|
serverConfig: t.MCPOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface OAuthConnectionOptions {
|
||||||
|
useOAuth: true;
|
||||||
|
user: TUser;
|
||||||
|
customUserVars?: Record<string, string>;
|
||||||
|
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||||
|
tokenMethods?: TokenMethods;
|
||||||
|
signal?: AbortSignal;
|
||||||
|
oauthStart?: (authURL: string) => Promise<void>;
|
||||||
|
oauthEnd?: () => Promise<void>;
|
||||||
|
returnOnOAuth?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory for creating MCP connections with optional OAuth authentication.
|
||||||
|
* Handles OAuth flows, token management, and connection retry logic.
|
||||||
|
* NOTE: Much of the OAuth logic was extracted from the old MCPManager class as is.
|
||||||
|
*/
|
||||||
|
export class MCPConnectionFactory {
|
||||||
|
protected readonly serverName: string;
|
||||||
|
protected readonly serverConfig: t.MCPOptions;
|
||||||
|
protected readonly logPrefix: string;
|
||||||
|
protected readonly useOAuth: boolean;
|
||||||
|
|
||||||
|
// OAuth-related properties (only set when useOAuth is true)
|
||||||
|
protected readonly userId?: string;
|
||||||
|
protected readonly flowManager?: FlowStateManager<MCPOAuthTokens | null>;
|
||||||
|
protected readonly tokenMethods?: TokenMethods;
|
||||||
|
protected readonly signal?: AbortSignal;
|
||||||
|
protected readonly oauthStart?: (authURL: string) => Promise<void>;
|
||||||
|
protected readonly oauthEnd?: () => Promise<void>;
|
||||||
|
protected readonly returnOnOAuth?: boolean;
|
||||||
|
|
||||||
|
/** Creates a new MCP connection with optional OAuth support */
|
||||||
|
static async create(
|
||||||
|
basic: BasicConnectionOptions,
|
||||||
|
oauth?: OAuthConnectionOptions,
|
||||||
|
): Promise<MCPConnection> {
|
||||||
|
const factory = new this(basic, oauth);
|
||||||
|
return factory.createConnection();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected constructor(basic: BasicConnectionOptions, oauth?: OAuthConnectionOptions) {
|
||||||
|
this.serverConfig = processMCPEnv(basic.serverConfig, oauth?.user, oauth?.customUserVars);
|
||||||
|
this.serverName = basic.serverName;
|
||||||
|
this.useOAuth = !!oauth?.useOAuth;
|
||||||
|
this.logPrefix = oauth?.user
|
||||||
|
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
||||||
|
: `[MCP][${basic.serverName}]`;
|
||||||
|
|
||||||
|
if (oauth?.useOAuth) {
|
||||||
|
this.userId = oauth.user.id;
|
||||||
|
this.flowManager = oauth.flowManager;
|
||||||
|
this.tokenMethods = oauth.tokenMethods;
|
||||||
|
this.signal = oauth.signal;
|
||||||
|
this.oauthStart = oauth.oauthStart;
|
||||||
|
this.oauthEnd = oauth.oauthEnd;
|
||||||
|
this.returnOnOAuth = oauth.returnOnOAuth;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Creates the base MCP connection with OAuth tokens */
|
||||||
|
protected async createConnection(): Promise<MCPConnection> {
|
||||||
|
const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null;
|
||||||
|
const connection = new MCPConnection({
|
||||||
|
serverName: this.serverName,
|
||||||
|
serverConfig: this.serverConfig,
|
||||||
|
userId: this.userId,
|
||||||
|
oauthTokens,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (this.useOAuth) this.handleOAuthEvents(connection);
|
||||||
|
await this.attemptToConnect(connection);
|
||||||
|
return connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Retrieves existing OAuth tokens from storage or returns null */
|
||||||
|
protected async getOAuthTokens(): Promise<MCPOAuthTokens | null> {
|
||||||
|
if (!this.tokenMethods?.findToken) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const tokens = await this.flowManager!.createFlowWithHandler(
|
||||||
|
`tokens:${this.userId}:${this.serverName}`,
|
||||||
|
'mcp_get_tokens',
|
||||||
|
async () => {
|
||||||
|
return await MCPTokenStorage.getTokens({
|
||||||
|
userId: this.userId!,
|
||||||
|
serverName: this.serverName,
|
||||||
|
findToken: this.tokenMethods!.findToken!,
|
||||||
|
createToken: this.tokenMethods!.createToken,
|
||||||
|
updateToken: this.tokenMethods!.updateToken,
|
||||||
|
refreshTokens: this.createRefreshTokensFunction(),
|
||||||
|
});
|
||||||
|
},
|
||||||
|
this.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
|
||||||
|
return tokens;
|
||||||
|
} catch (error) {
|
||||||
|
logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Creates a function to refresh OAuth tokens when they expire */
|
||||||
|
protected createRefreshTokensFunction(): (
|
||||||
|
refreshToken: string,
|
||||||
|
metadata: {
|
||||||
|
userId: string;
|
||||||
|
serverName: string;
|
||||||
|
identifier: string;
|
||||||
|
clientInfo?: OAuthClientInformation;
|
||||||
|
},
|
||||||
|
) => Promise<MCPOAuthTokens> {
|
||||||
|
return async (refreshToken, metadata) => {
|
||||||
|
return await MCPOAuthHandler.refreshOAuthTokens(
|
||||||
|
refreshToken,
|
||||||
|
{
|
||||||
|
serverUrl: (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url,
|
||||||
|
serverName: metadata.serverName,
|
||||||
|
clientInfo: metadata.clientInfo,
|
||||||
|
},
|
||||||
|
this.serverConfig.oauth,
|
||||||
|
);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sets up OAuth event handlers for the connection */
|
||||||
|
protected handleOAuthEvents(connection: MCPConnection): void {
|
||||||
|
connection.on('oauthRequired', async (data) => {
|
||||||
|
logger.info(`${this.logPrefix} oauthRequired event received`);
|
||||||
|
|
||||||
|
// If we just want to initiate OAuth and return, handle it differently
|
||||||
|
if (this.returnOnOAuth) {
|
||||||
|
try {
|
||||||
|
const config = this.serverConfig;
|
||||||
|
const { authorizationUrl, flowId, flowMetadata } =
|
||||||
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
this.serverName,
|
||||||
|
data.serverUrl || '',
|
||||||
|
this.userId!,
|
||||||
|
config?.oauth,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create the flow state so the OAuth callback can find it
|
||||||
|
// We spawn this in the background without waiting for it
|
||||||
|
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => {
|
||||||
|
// The OAuth callback will resolve this flow, so we expect it to timeout here
|
||||||
|
// which is fine - we just need the flow state to exist
|
||||||
|
});
|
||||||
|
|
||||||
|
if (this.oauthStart) {
|
||||||
|
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
|
||||||
|
await this.oauthStart(authorizationUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit oauthFailed to signal that connection should not proceed
|
||||||
|
// but OAuth was successfully initiated
|
||||||
|
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||||
|
return;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`${this.logPrefix} Failed to initiate OAuth flow`, error);
|
||||||
|
connection.emit('oauthFailed', new Error('OAuth initiation failed'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal OAuth handling - wait for completion
|
||||||
|
const result = await this.handleOAuthRequired();
|
||||||
|
|
||||||
|
if (result?.tokens && this.tokenMethods?.createToken) {
|
||||||
|
try {
|
||||||
|
connection.setOAuthTokens(result.tokens);
|
||||||
|
await MCPTokenStorage.storeTokens({
|
||||||
|
userId: this.userId!,
|
||||||
|
serverName: this.serverName,
|
||||||
|
tokens: result.tokens,
|
||||||
|
createToken: this.tokenMethods.createToken,
|
||||||
|
updateToken: this.tokenMethods.updateToken,
|
||||||
|
findToken: this.tokenMethods.findToken,
|
||||||
|
clientInfo: result.clientInfo,
|
||||||
|
});
|
||||||
|
logger.info(`${this.logPrefix} OAuth tokens saved to storage`);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`${this.logPrefix} Failed to save OAuth tokens to storage`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only emit oauthHandled if we actually got tokens (OAuth succeeded)
|
||||||
|
if (result?.tokens) {
|
||||||
|
connection.emit('oauthHandled');
|
||||||
|
} else {
|
||||||
|
// OAuth failed, emit oauthFailed to properly reject the promise
|
||||||
|
logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`);
|
||||||
|
connection.emit('oauthFailed', new Error('OAuth authentication failed'));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Attempts to establish connection with timeout handling */
|
||||||
|
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
|
||||||
|
const connectTimeout = this.serverConfig.initTimeout ?? 30000;
|
||||||
|
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||||
|
setTimeout(
|
||||||
|
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||||
|
connectTimeout,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
const connectionAttempt = this.connectTo(connection);
|
||||||
|
await Promise.race([connectionAttempt, connectionTimeout]);
|
||||||
|
|
||||||
|
if (await connection.isConnected()) return;
|
||||||
|
logger.error(`${this.logPrefix} Failed to establish connection.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles connection attempts with retry logic and OAuth error handling
|
||||||
|
private async connectTo(connection: MCPConnection): Promise<void> {
|
||||||
|
const maxAttempts = 3;
|
||||||
|
let attempts = 0;
|
||||||
|
let oauthHandled = false;
|
||||||
|
|
||||||
|
while (attempts < maxAttempts) {
|
||||||
|
try {
|
||||||
|
await connection.connect();
|
||||||
|
if (await connection.isConnected()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
throw new Error('Connection attempt succeeded but status is not connected');
|
||||||
|
} catch (error) {
|
||||||
|
attempts++;
|
||||||
|
|
||||||
|
if (this.useOAuth && this.isOAuthError(error)) {
|
||||||
|
// Only handle OAuth if this is a user connection (has oauthStart handler)
|
||||||
|
if (this.oauthStart && !oauthHandled) {
|
||||||
|
const errorWithFlag = error as (Error & { isOAuthError?: boolean }) | undefined;
|
||||||
|
if (errorWithFlag?.isOAuthError) {
|
||||||
|
oauthHandled = true;
|
||||||
|
logger.info(`${this.logPrefix} Handling OAuth`);
|
||||||
|
await this.handleOAuthRequired();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Don't retry on OAuth errors - just throw
|
||||||
|
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (attempts === maxAttempts) {
|
||||||
|
logger.error(`${this.logPrefix} Failed to connect after ${maxAttempts} attempts`, error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 2000 * attempts));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determines if an error indicates OAuth authentication is required
|
||||||
|
private isOAuthError(error: unknown): boolean {
|
||||||
|
if (!error || typeof error !== 'object') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for SSE error with 401 status
|
||||||
|
if ('message' in error && typeof error.message === 'string') {
|
||||||
|
return error.message.includes('401') || error.message.includes('Non-200 status code (401)');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for error code
|
||||||
|
if ('code' in error) {
|
||||||
|
const code = (error as { code?: number }).code;
|
||||||
|
return code === 401 || code === 403;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Manages OAuth flow initiation and completion */
|
||||||
|
protected async handleOAuthRequired(): Promise<{
|
||||||
|
tokens: MCPOAuthTokens | null;
|
||||||
|
clientInfo?: OAuthClientInformation;
|
||||||
|
} | null> {
|
||||||
|
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
|
||||||
|
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
|
||||||
|
|
||||||
|
if (!this.flowManager || !serverUrl) {
|
||||||
|
logger.error(
|
||||||
|
`${this.logPrefix} OAuth required but flow manager not available or server URL missing for ${this.serverName}`,
|
||||||
|
);
|
||||||
|
logger.warn(`${this.logPrefix} Please configure OAuth credentials for ${this.serverName}`);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
logger.debug(`${this.logPrefix} Checking for existing OAuth flow for ${this.serverName}...`);
|
||||||
|
|
||||||
|
/** Flow ID to check if a flow already exists */
|
||||||
|
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
|
||||||
|
|
||||||
|
/** Check if there's already an ongoing OAuth flow for this flowId */
|
||||||
|
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||||
|
if (existingFlow && existingFlow.status === 'PENDING') {
|
||||||
|
logger.debug(
|
||||||
|
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
|
||||||
|
);
|
||||||
|
/** Tokens from existing flow to complete */
|
||||||
|
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth');
|
||||||
|
if (typeof this.oauthEnd === 'function') {
|
||||||
|
await this.oauthEnd();
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
/** Client information from the existing flow metadata */
|
||||||
|
const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
|
||||||
|
const clientInfo = existingMetadata?.clientInfo;
|
||||||
|
|
||||||
|
return { tokens, clientInfo };
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
|
||||||
|
const {
|
||||||
|
authorizationUrl,
|
||||||
|
flowId: newFlowId,
|
||||||
|
flowMetadata,
|
||||||
|
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
this.serverName,
|
||||||
|
serverUrl,
|
||||||
|
this.userId!,
|
||||||
|
this.serverConfig.oauth,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (typeof this.oauthStart === 'function') {
|
||||||
|
logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
|
||||||
|
await this.oauthStart(authorizationUrl);
|
||||||
|
} else {
|
||||||
|
logger.info(`
|
||||||
|
═══════════════════════════════════════════════════════════════════════
|
||||||
|
Please visit the following URL to authenticate:
|
||||||
|
|
||||||
|
${authorizationUrl}
|
||||||
|
|
||||||
|
${this.logPrefix} Flow ID: ${newFlowId}
|
||||||
|
═══════════════════════════════════════════════════════════════════════
|
||||||
|
`);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Tokens from the new flow */
|
||||||
|
const tokens = await this.flowManager.createFlow(
|
||||||
|
newFlowId,
|
||||||
|
'mcp_oauth',
|
||||||
|
flowMetadata as FlowMetadata,
|
||||||
|
);
|
||||||
|
if (typeof this.oauthEnd === 'function') {
|
||||||
|
await this.oauthEnd();
|
||||||
|
}
|
||||||
|
logger.info(`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`);
|
||||||
|
|
||||||
|
/** Client information from the flow metadata */
|
||||||
|
const clientInfo = flowMetadata?.clientInfo;
|
||||||
|
|
||||||
|
return { tokens, clientInfo };
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
263
packages/api/src/mcp/MCPManager.ts
Normal file
263
packages/api/src/mcp/MCPManager.ts
Normal file
|
@ -0,0 +1,263 @@
|
||||||
|
import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import pick from 'lodash/pick';
|
||||||
|
import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
|
||||||
|
import type { TokenMethods } from '@librechat/data-schemas';
|
||||||
|
import type { TUser } from 'librechat-data-provider';
|
||||||
|
import type { FlowStateManager } from '~/flow/manager';
|
||||||
|
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||||
|
import { UserConnectionManager } from '~/mcp/UserConnectionManager';
|
||||||
|
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||||
|
import { formatToolContent } from './parsers';
|
||||||
|
import { MCPConnection } from './connection';
|
||||||
|
import { CONSTANTS } from './enum';
|
||||||
|
import type * as t from './types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Centralized manager for MCP server connections and tool execution.
|
||||||
|
* Extends UserConnectionManager to handle both app-level and user-specific connections.
|
||||||
|
*/
|
||||||
|
export class MCPManager extends UserConnectionManager {
|
||||||
|
private static instance: MCPManager | null;
|
||||||
|
// Connections shared by all users.
|
||||||
|
private appConnections: ConnectionsRepository | null = null;
|
||||||
|
|
||||||
|
/** Creates and initializes the singleton MCPManager instance */
|
||||||
|
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
|
||||||
|
if (MCPManager.instance) throw new Error('MCPManager has already been initialized.');
|
||||||
|
MCPManager.instance = new MCPManager(configs);
|
||||||
|
await MCPManager.instance.initialize();
|
||||||
|
return MCPManager.instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the singleton MCPManager instance */
|
||||||
|
public static getInstance(): MCPManager {
|
||||||
|
if (!MCPManager.instance) throw new Error('MCPManager has not been initialized.');
|
||||||
|
return MCPManager.instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Initializes the MCPManager by setting up server registry and app connections */
|
||||||
|
public async initialize() {
|
||||||
|
await this.serversRegistry.initialize();
|
||||||
|
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns all app-level connections */
|
||||||
|
public async getAllConnections(): Promise<Map<string, MCPConnection>> {
|
||||||
|
return this.appConnections!.getAll();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get servers that require OAuth */
|
||||||
|
public getOAuthServers(): Set<string> {
|
||||||
|
return this.serversRegistry.oauthServers!;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns all available tool functions from app-level connections */
|
||||||
|
public getAppToolFunctions(): t.LCAvailableTools {
|
||||||
|
return this.serversRegistry.toolFunctions!;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get instructions for MCP servers
|
||||||
|
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
|
||||||
|
* @returns Object mapping server names to their instructions
|
||||||
|
*/
|
||||||
|
public getInstructions(serverNames?: string[]): Record<string, string> {
|
||||||
|
const instructions = this.serversRegistry.serverInstructions!;
|
||||||
|
if (!serverNames) return instructions;
|
||||||
|
return pick(instructions, serverNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format MCP server instructions for injection into context
|
||||||
|
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
|
||||||
|
* @returns Formatted instructions string ready for context injection
|
||||||
|
*/
|
||||||
|
public formatInstructionsForContext(serverNames?: string[]): string {
|
||||||
|
/** Instructions for specified servers or all stored instructions */
|
||||||
|
const instructionsToInclude = this.getInstructions(serverNames);
|
||||||
|
|
||||||
|
if (Object.keys(instructionsToInclude).length === 0) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format instructions for context injection
|
||||||
|
const formattedInstructions = Object.entries(instructionsToInclude)
|
||||||
|
.map(([serverName, instructions]) => {
|
||||||
|
return `## ${serverName} MCP Server Instructions
|
||||||
|
|
||||||
|
${instructions}`;
|
||||||
|
})
|
||||||
|
.join('\n\n');
|
||||||
|
|
||||||
|
return `# MCP Server Instructions
|
||||||
|
|
||||||
|
The following MCP servers are available with their specific instructions:
|
||||||
|
|
||||||
|
${formattedInstructions}
|
||||||
|
|
||||||
|
Please follow these instructions when using tools from the respective MCP servers.`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Loads tools from all app-level connections into the manifest. */
|
||||||
|
public async loadManifestTools({
|
||||||
|
serverToolsCallback,
|
||||||
|
getServerTools,
|
||||||
|
}: {
|
||||||
|
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||||
|
serverToolsCallback?: (serverName: string, tools: t.LCManifestTool[]) => Promise<void>;
|
||||||
|
getServerTools?: (serverName: string) => Promise<t.LCManifestTool[] | undefined>;
|
||||||
|
}): Promise<t.LCToolManifest> {
|
||||||
|
const mcpTools: t.LCManifestTool[] = [];
|
||||||
|
const connections = await this.appConnections!.getAll();
|
||||||
|
for (const [serverName, connection] of connections.entries()) {
|
||||||
|
try {
|
||||||
|
if (!(await connection.isConnected())) {
|
||||||
|
logger.warn(
|
||||||
|
`[MCP][${serverName}] Connection not available for ${serverName} manifest tools.`,
|
||||||
|
);
|
||||||
|
if (typeof getServerTools !== 'function') {
|
||||||
|
logger.warn(
|
||||||
|
`[MCP][${serverName}] No \`getServerTools\` function provided, skipping tool loading.`,
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const serverTools = await getServerTools(serverName);
|
||||||
|
if (serverTools && serverTools.length > 0) {
|
||||||
|
logger.info(`[MCP][${serverName}] Loaded tools from cache for manifest`);
|
||||||
|
mcpTools.push(...serverTools);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tools = await connection.fetchTools();
|
||||||
|
const serverTools: t.LCManifestTool[] = [];
|
||||||
|
for (const tool of tools) {
|
||||||
|
const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
|
||||||
|
|
||||||
|
const config = this.serversRegistry.parsedConfigs[serverName];
|
||||||
|
const manifestTool: t.LCManifestTool = {
|
||||||
|
name: tool.name,
|
||||||
|
pluginKey,
|
||||||
|
description: tool.description ?? '',
|
||||||
|
icon: connection.iconPath,
|
||||||
|
authConfig: config?.customUserVars
|
||||||
|
? Object.entries(config.customUserVars).map(([key, value]) => ({
|
||||||
|
authField: key,
|
||||||
|
label: value.title || key,
|
||||||
|
description: value.description || '',
|
||||||
|
}))
|
||||||
|
: undefined,
|
||||||
|
};
|
||||||
|
if (config?.chatMenu === false) {
|
||||||
|
manifestTool.chatMenu = false;
|
||||||
|
}
|
||||||
|
mcpTools.push(manifestTool);
|
||||||
|
serverTools.push(manifestTool);
|
||||||
|
}
|
||||||
|
if (typeof serverToolsCallback === 'function') {
|
||||||
|
await serverToolsCallback(serverName, serverTools);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mcpTools;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls a tool on an MCP server, using either a user-specific connection
|
||||||
|
* (if userId is provided) or an app-level connection. Updates the last activity timestamp
|
||||||
|
* for user-specific connections upon successful call initiation.
|
||||||
|
*/
|
||||||
|
async callTool({
|
||||||
|
user,
|
||||||
|
serverName,
|
||||||
|
toolName,
|
||||||
|
provider,
|
||||||
|
toolArguments,
|
||||||
|
options,
|
||||||
|
tokenMethods,
|
||||||
|
flowManager,
|
||||||
|
oauthStart,
|
||||||
|
oauthEnd,
|
||||||
|
customUserVars,
|
||||||
|
}: {
|
||||||
|
user?: TUser;
|
||||||
|
serverName: string;
|
||||||
|
toolName: string;
|
||||||
|
provider: t.Provider;
|
||||||
|
toolArguments?: Record<string, unknown>;
|
||||||
|
options?: RequestOptions;
|
||||||
|
tokenMethods?: TokenMethods;
|
||||||
|
customUserVars?: Record<string, string>;
|
||||||
|
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||||
|
oauthStart?: (authURL: string) => Promise<void>;
|
||||||
|
oauthEnd?: () => Promise<void>;
|
||||||
|
}): Promise<t.FormattedToolResponse> {
|
||||||
|
/** User-specific connection */
|
||||||
|
let connection: MCPConnection | undefined;
|
||||||
|
const userId = user?.id;
|
||||||
|
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!this.appConnections?.has(serverName) && userId && user) {
|
||||||
|
this.updateUserLastActivity(userId);
|
||||||
|
/** Get or create user-specific connection */
|
||||||
|
connection = await this.getUserConnection({
|
||||||
|
user,
|
||||||
|
serverName,
|
||||||
|
flowManager,
|
||||||
|
tokenMethods,
|
||||||
|
oauthStart,
|
||||||
|
oauthEnd,
|
||||||
|
signal: options?.signal,
|
||||||
|
customUserVars,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
/** App-level connection */
|
||||||
|
connection = await this.appConnections!.get(serverName);
|
||||||
|
if (!connection) {
|
||||||
|
throw new McpError(
|
||||||
|
ErrorCode.InvalidRequest,
|
||||||
|
`${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(await connection.isConnected())) {
|
||||||
|
/** May happen if getUserConnection failed silently or app connection dropped */
|
||||||
|
throw new McpError(
|
||||||
|
ErrorCode.InternalError, // Use InternalError for connection issues
|
||||||
|
`${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await connection.client.request(
|
||||||
|
{
|
||||||
|
method: 'tools/call',
|
||||||
|
params: {
|
||||||
|
name: toolName,
|
||||||
|
arguments: toolArguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CallToolResultSchema,
|
||||||
|
{
|
||||||
|
timeout: connection.timeout,
|
||||||
|
...options,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (userId) {
|
||||||
|
this.updateUserLastActivity(userId);
|
||||||
|
}
|
||||||
|
this.checkIdleConnections();
|
||||||
|
return formatToolContent(result as t.MCPToolCallResponse, provider);
|
||||||
|
} catch (error) {
|
||||||
|
// Log with context and re-throw or handle as needed
|
||||||
|
logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
|
||||||
|
// Rethrowing allows the caller (createMCPTool) to handle the final user message
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
200
packages/api/src/mcp/MCPServersRegistry.ts
Normal file
200
packages/api/src/mcp/MCPServersRegistry.ts
Normal file
|
@ -0,0 +1,200 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import mapValues from 'lodash/mapValues';
|
||||||
|
import pickBy from 'lodash/pickBy';
|
||||||
|
import pick from 'lodash/pick';
|
||||||
|
import type { JsonSchemaType } from '~/types';
|
||||||
|
import type * as t from '~/mcp/types';
|
||||||
|
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||||
|
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
|
import { type MCPConnection } from './connection';
|
||||||
|
import { processMCPEnv } from '~/utils';
|
||||||
|
import { CONSTANTS } from '~/mcp/enum';
|
||||||
|
|
||||||
|
type ParsedServerConfig = t.MCPOptions & {
|
||||||
|
url?: string;
|
||||||
|
requiresOAuth?: boolean;
|
||||||
|
oauthMetadata?: Record<string, unknown> | null;
|
||||||
|
capabilities?: string;
|
||||||
|
tools?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Manages MCP server configurations and metadata discovery.
|
||||||
|
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
||||||
|
* Determines which servers are for app-level connections.
|
||||||
|
* Has its own connections repository. All connections are disconnected after initialization.
|
||||||
|
*/
|
||||||
|
export class MCPServersRegistry {
|
||||||
|
private initialized: boolean = false;
|
||||||
|
private connections: ConnectionsRepository;
|
||||||
|
|
||||||
|
public readonly rawConfigs: t.MCPServers;
|
||||||
|
public readonly parsedConfigs: Record<string, ParsedServerConfig>;
|
||||||
|
|
||||||
|
public oauthServers: Set<string> | null = null;
|
||||||
|
public serverInstructions: Record<string, string> | null = null;
|
||||||
|
public toolFunctions: t.LCAvailableTools | null = null;
|
||||||
|
public appServerConfigs: t.MCPServers | null = null;
|
||||||
|
|
||||||
|
constructor(configs: t.MCPServers) {
|
||||||
|
this.rawConfigs = configs;
|
||||||
|
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv(con));
|
||||||
|
this.connections = new ConnectionsRepository(configs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
||||||
|
public async initialize() {
|
||||||
|
if (this.initialized) return;
|
||||||
|
this.initialized = true;
|
||||||
|
|
||||||
|
const serverNames = Object.keys(this.parsedConfigs);
|
||||||
|
|
||||||
|
await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName)));
|
||||||
|
|
||||||
|
this.setOAuthServers();
|
||||||
|
this.setServerInstructions();
|
||||||
|
this.setAppServerConfigs();
|
||||||
|
await this.setAppToolFunctions();
|
||||||
|
|
||||||
|
this.connections.disconnectAll();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetches all metadata for a single server in parallel
|
||||||
|
private async gatherServerInfo(serverName: string) {
|
||||||
|
try {
|
||||||
|
await Promise.allSettled([
|
||||||
|
this.fetchOAuthRequirement(serverName).catch((error) =>
|
||||||
|
logger.error(`${this.prefix(serverName)} Failed to fetch OAuth requirement:`, error),
|
||||||
|
),
|
||||||
|
this.fetchServerInstructions(serverName).catch((error) =>
|
||||||
|
logger.error(`${this.prefix(serverName)} Failed to fetch server instructions:`, error),
|
||||||
|
),
|
||||||
|
this.fetchServerCapabilities(serverName).catch((error) =>
|
||||||
|
logger.error(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
this.logUpdatedConfig(serverName);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`${this.prefix(serverName)} Failed to initialize server:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets app-level server configs (startup enabled, non-OAuth servers)
|
||||||
|
private setAppServerConfigs() {
|
||||||
|
const appServers = Object.keys(
|
||||||
|
pickBy(
|
||||||
|
this.parsedConfigs,
|
||||||
|
(config) => config.startup !== false && config.requiresOAuth === false,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
this.appServerConfigs = pick(this.rawConfigs, appServers);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates set of server names that require OAuth authentication
|
||||||
|
private setOAuthServers() {
|
||||||
|
if (this.oauthServers) return this.oauthServers;
|
||||||
|
this.oauthServers = new Set(
|
||||||
|
Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)),
|
||||||
|
);
|
||||||
|
return this.oauthServers;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collects server instructions from all configured servers
|
||||||
|
private setServerInstructions() {
|
||||||
|
this.serverInstructions = mapValues(
|
||||||
|
pickBy(this.parsedConfigs, (config) => config.serverInstructions),
|
||||||
|
(config) => config.serverInstructions as string,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds registry of all available tool functions from loaded connections
|
||||||
|
private async setAppToolFunctions() {
|
||||||
|
const connections = (await this.connections.getLoaded()).entries();
|
||||||
|
const allToolFunctions: t.LCAvailableTools = {};
|
||||||
|
for (const [serverName, conn] of connections) {
|
||||||
|
try {
|
||||||
|
const toolFunctions = await this.getToolFunctions(serverName, conn);
|
||||||
|
Object.assign(allToolFunctions, toolFunctions);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`${this.prefix(serverName)} Error fetching tool functions:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.toolFunctions = allToolFunctions;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts server tools to LibreChat-compatible tool functions format
|
||||||
|
private async getToolFunctions(
|
||||||
|
serverName: string,
|
||||||
|
conn: MCPConnection,
|
||||||
|
): Promise<t.LCAvailableTools> {
|
||||||
|
const { tools } = await conn.client.listTools();
|
||||||
|
|
||||||
|
const toolFunctions: t.LCAvailableTools = {};
|
||||||
|
tools.forEach((tool) => {
|
||||||
|
const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
|
||||||
|
toolFunctions[name] = {
|
||||||
|
type: 'function',
|
||||||
|
['function']: {
|
||||||
|
name,
|
||||||
|
description: tool.description,
|
||||||
|
parameters: tool.inputSchema as JsonSchemaType,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
return toolFunctions;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determines if server requires OAuth if not already specified in the config
|
||||||
|
private async fetchOAuthRequirement(serverName: string) {
|
||||||
|
const config = this.parsedConfigs[serverName];
|
||||||
|
if (config.requiresOAuth != null) return;
|
||||||
|
if (config.url == null) return (config.requiresOAuth = false);
|
||||||
|
|
||||||
|
const result = await detectOAuthRequirement(config.url);
|
||||||
|
config.requiresOAuth = result.requiresOAuth;
|
||||||
|
config.oauthMetadata = result.metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves server instructions from MCP server if enabled in the config
|
||||||
|
private async fetchServerInstructions(serverName: string) {
|
||||||
|
const config = this.parsedConfigs[serverName];
|
||||||
|
if (!config.serverInstructions) return;
|
||||||
|
if (typeof config.serverInstructions === 'string') return;
|
||||||
|
|
||||||
|
const conn = await this.connections.get(serverName);
|
||||||
|
config.serverInstructions = conn.client.getInstructions();
|
||||||
|
if (!config.serverInstructions) {
|
||||||
|
logger.warn(`${this.prefix(serverName)} No server instructions available`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetches server capabilities and available tools list
|
||||||
|
private async fetchServerCapabilities(serverName: string) {
|
||||||
|
const config = this.parsedConfigs[serverName];
|
||||||
|
const conn = await this.connections.get(serverName);
|
||||||
|
const capabilities = conn.client.getServerCapabilities();
|
||||||
|
if (!capabilities) return;
|
||||||
|
config.capabilities = JSON.stringify(capabilities);
|
||||||
|
if (!capabilities.tools) return;
|
||||||
|
const tools = await conn.client.listTools();
|
||||||
|
config.tools = tools.tools.map((tool) => tool.name).join(', ');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logs server configuration summary after initialization
|
||||||
|
private logUpdatedConfig(serverName: string) {
|
||||||
|
const prefix = this.prefix(serverName);
|
||||||
|
const config = this.parsedConfigs[serverName];
|
||||||
|
logger.info(`${prefix} URL: ${config.url ?? 'N/A'}`);
|
||||||
|
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||||
|
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||||
|
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||||
|
logger.info(`${prefix} Server Instructions: ${config.serverInstructions ?? 'None'}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns formatted log prefix for server messages
|
||||||
|
private prefix(serverName: string): string {
|
||||||
|
return `[MCP][${serverName}]`;
|
||||||
|
}
|
||||||
|
}
|
236
packages/api/src/mcp/UserConnectionManager.ts
Normal file
236
packages/api/src/mcp/UserConnectionManager.ts
Normal file
|
@ -0,0 +1,236 @@
|
||||||
|
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import type { TokenMethods } from '@librechat/data-schemas';
|
||||||
|
import type { TUser } from 'librechat-data-provider';
|
||||||
|
import type { FlowStateManager } from '~/flow/manager';
|
||||||
|
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||||
|
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||||
|
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||||
|
import { MCPConnection } from './connection';
|
||||||
|
import type * as t from './types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract base class for managing user-specific MCP connections with lifecycle management.
|
||||||
|
* Only meant to be extended by MCPManager.
|
||||||
|
* Much of the logic was move here from the old MCPManager to make it more manageable.
|
||||||
|
* User connections will soon be ephemeral and not cached anymore:
|
||||||
|
* https://github.com/danny-avila/LibreChat/discussions/8790
|
||||||
|
*/
|
||||||
|
export abstract class UserConnectionManager {
|
||||||
|
protected readonly serversRegistry: MCPServersRegistry;
|
||||||
|
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
|
||||||
|
/** Last activity timestamp for users (not per server) */
|
||||||
|
protected userLastActivity: Map<string, number> = new Map();
|
||||||
|
protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
|
||||||
|
|
||||||
|
constructor(serverConfigs: t.MCPServers) {
|
||||||
|
this.serversRegistry = new MCPServersRegistry(serverConfigs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** fetches am MCP Server config from the registry */
|
||||||
|
public getRawConfig(serverName: string): t.MCPOptions | undefined {
|
||||||
|
return this.serversRegistry.rawConfigs[serverName];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the last activity timestamp for a user */
|
||||||
|
protected updateUserLastActivity(userId: string): void {
|
||||||
|
const now = Date.now();
|
||||||
|
this.userLastActivity.set(userId, now);
|
||||||
|
logger.debug(
|
||||||
|
`[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets or creates a connection for a specific user */
|
||||||
|
public async getUserConnection({
|
||||||
|
user,
|
||||||
|
serverName,
|
||||||
|
flowManager,
|
||||||
|
customUserVars,
|
||||||
|
tokenMethods,
|
||||||
|
oauthStart,
|
||||||
|
oauthEnd,
|
||||||
|
signal,
|
||||||
|
returnOnOAuth = false,
|
||||||
|
}: {
|
||||||
|
user: TUser;
|
||||||
|
serverName: string;
|
||||||
|
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||||
|
customUserVars?: Record<string, string>;
|
||||||
|
tokenMethods?: TokenMethods;
|
||||||
|
oauthStart?: (authURL: string) => Promise<void>;
|
||||||
|
oauthEnd?: () => Promise<void>;
|
||||||
|
signal?: AbortSignal;
|
||||||
|
returnOnOAuth?: boolean;
|
||||||
|
}): Promise<MCPConnection> {
|
||||||
|
const userId = user.id;
|
||||||
|
if (!userId) {
|
||||||
|
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const userServerMap = this.userConnections.get(userId);
|
||||||
|
let connection = userServerMap?.get(serverName);
|
||||||
|
const now = Date.now();
|
||||||
|
|
||||||
|
// Check if user is idle
|
||||||
|
const lastActivity = this.userLastActivity.get(userId);
|
||||||
|
if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
|
||||||
|
logger.info(`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`);
|
||||||
|
// Disconnect all user connections
|
||||||
|
try {
|
||||||
|
await this.disconnectUserConnections(userId);
|
||||||
|
} catch (err) {
|
||||||
|
logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err);
|
||||||
|
}
|
||||||
|
connection = undefined; // Force creation of a new connection
|
||||||
|
} else if (connection) {
|
||||||
|
if (await connection.isConnected()) {
|
||||||
|
logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`);
|
||||||
|
this.updateUserLastActivity(userId);
|
||||||
|
return connection;
|
||||||
|
} else {
|
||||||
|
// Connection exists but is not connected, attempt to remove potentially stale entry
|
||||||
|
logger.warn(
|
||||||
|
`[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`,
|
||||||
|
);
|
||||||
|
this.removeUserConnection(userId, serverName); // Clean up maps
|
||||||
|
connection = undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no valid connection exists, create a new one
|
||||||
|
if (!connection) {
|
||||||
|
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const config = this.serversRegistry.parsedConfigs[serverName];
|
||||||
|
if (!config) {
|
||||||
|
throw new McpError(
|
||||||
|
ErrorCode.InvalidRequest,
|
||||||
|
`[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
connection = await MCPConnectionFactory.create(
|
||||||
|
{
|
||||||
|
serverName: serverName,
|
||||||
|
serverConfig: config,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
useOAuth: true,
|
||||||
|
user: user,
|
||||||
|
customUserVars: customUserVars,
|
||||||
|
flowManager: flowManager,
|
||||||
|
tokenMethods: tokenMethods,
|
||||||
|
signal: signal,
|
||||||
|
oauthStart: oauthStart,
|
||||||
|
oauthEnd: oauthEnd,
|
||||||
|
returnOnOAuth: returnOnOAuth,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!(await connection?.isConnected())) {
|
||||||
|
throw new Error('Failed to establish connection after initialization attempt.');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!this.userConnections.has(userId)) {
|
||||||
|
this.userConnections.set(userId, new Map());
|
||||||
|
}
|
||||||
|
this.userConnections.get(userId)?.set(serverName, connection);
|
||||||
|
|
||||||
|
logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
|
||||||
|
// Update timestamp on creation
|
||||||
|
this.updateUserLastActivity(userId);
|
||||||
|
return connection;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`[MCP][User: ${userId}][${serverName}] Failed to establish connection`, error);
|
||||||
|
// Ensure partial connection state is cleaned up if initialization fails
|
||||||
|
await connection?.disconnect().catch((disconnectError) => {
|
||||||
|
logger.error(
|
||||||
|
`[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`,
|
||||||
|
disconnectError,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
// Ensure cleanup even if connection attempt fails
|
||||||
|
this.removeUserConnection(userId, serverName);
|
||||||
|
throw error; // Re-throw the error to the caller
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns all connections for a specific user */
|
||||||
|
public getUserConnections(userId: string) {
|
||||||
|
return this.userConnections.get(userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Removes a specific user connection entry */
|
||||||
|
protected removeUserConnection(userId: string, serverName: string): void {
|
||||||
|
const userMap = this.userConnections.get(userId);
|
||||||
|
if (userMap) {
|
||||||
|
userMap.delete(serverName);
|
||||||
|
if (userMap.size === 0) {
|
||||||
|
this.userConnections.delete(userId);
|
||||||
|
// Only remove user activity timestamp if all connections are gone
|
||||||
|
this.userLastActivity.delete(userId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Disconnects and removes a specific user connection */
|
||||||
|
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
|
||||||
|
const userMap = this.userConnections.get(userId);
|
||||||
|
const connection = userMap?.get(serverName);
|
||||||
|
if (connection) {
|
||||||
|
logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
|
||||||
|
await connection.disconnect();
|
||||||
|
this.removeUserConnection(userId, serverName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Disconnects and removes all connections for a specific user */
|
||||||
|
public async disconnectUserConnections(userId: string): Promise<void> {
|
||||||
|
const userMap = this.userConnections.get(userId);
|
||||||
|
const disconnectPromises: Promise<void>[] = [];
|
||||||
|
if (userMap) {
|
||||||
|
logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`);
|
||||||
|
const userServers = Array.from(userMap.keys());
|
||||||
|
for (const serverName of userServers) {
|
||||||
|
disconnectPromises.push(
|
||||||
|
this.disconnectUserConnection(userId, serverName).catch((error) => {
|
||||||
|
logger.error(
|
||||||
|
`[MCP][User: ${userId}][${serverName}] Error during disconnection:`,
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
await Promise.allSettled(disconnectPromises);
|
||||||
|
// Ensure user activity timestamp is removed
|
||||||
|
this.userLastActivity.delete(userId);
|
||||||
|
logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Check for and disconnect idle connections */
|
||||||
|
protected checkIdleConnections(currentUserId?: string): void {
|
||||||
|
const now = Date.now();
|
||||||
|
|
||||||
|
// Iterate through all users to check for idle ones
|
||||||
|
for (const [userId, lastActivity] of this.userLastActivity.entries()) {
|
||||||
|
if (currentUserId && currentUserId === userId) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
|
||||||
|
logger.info(
|
||||||
|
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`,
|
||||||
|
);
|
||||||
|
// Disconnect all user connections asynchronously (fire and forget)
|
||||||
|
this.disconnectUserConnections(userId).catch((err) =>
|
||||||
|
logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
212
packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts
Normal file
212
packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts
Normal file
|
@ -0,0 +1,212 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { ConnectionsRepository } from '../ConnectionsRepository';
|
||||||
|
import { MCPConnectionFactory } from '../MCPConnectionFactory';
|
||||||
|
import { MCPConnection } from '../connection';
|
||||||
|
import type * as t from '../types';
|
||||||
|
|
||||||
|
// Mock external dependencies
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: {
|
||||||
|
error: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('../MCPConnectionFactory', () => ({
|
||||||
|
MCPConnectionFactory: {
|
||||||
|
create: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('../connection');
|
||||||
|
|
||||||
|
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||||
|
|
||||||
|
describe('ConnectionsRepository', () => {
|
||||||
|
let repository: ConnectionsRepository;
|
||||||
|
let mockServerConfigs: t.MCPServers;
|
||||||
|
let mockConnection: jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockServerConfigs = {
|
||||||
|
server1: { url: 'http://localhost:3001' },
|
||||||
|
server2: { command: 'test-command', args: ['--test'] },
|
||||||
|
server3: { url: 'ws://localhost:8080', type: 'websocket' },
|
||||||
|
};
|
||||||
|
|
||||||
|
mockConnection = {
|
||||||
|
isConnected: jest.fn().mockResolvedValue(true),
|
||||||
|
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
|
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
|
||||||
|
|
||||||
|
repository = new ConnectionsRepository(mockServerConfigs);
|
||||||
|
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('has', () => {
|
||||||
|
it('should return true for existing server', () => {
|
||||||
|
expect(repository.has('server1')).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false for non-existing server', () => {
|
||||||
|
expect(repository.has('nonexistent')).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('get', () => {
|
||||||
|
it('should return existing connected connection', async () => {
|
||||||
|
mockConnection.isConnected.mockResolvedValue(true);
|
||||||
|
repository['connections'].set('server1', mockConnection);
|
||||||
|
|
||||||
|
const result = await repository.get('server1');
|
||||||
|
|
||||||
|
expect(result).toBe(mockConnection);
|
||||||
|
expect(MCPConnectionFactory.create).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create new connection if none exists', async () => {
|
||||||
|
const result = await repository.get('server1');
|
||||||
|
|
||||||
|
expect(result).toBe(mockConnection);
|
||||||
|
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
serverName: 'server1',
|
||||||
|
serverConfig: mockServerConfigs.server1,
|
||||||
|
},
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
expect(repository['connections'].get('server1')).toBe(mockConnection);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create new connection if existing connection is not connected', async () => {
|
||||||
|
const oldConnection = {
|
||||||
|
isConnected: jest.fn().mockResolvedValue(false),
|
||||||
|
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
repository['connections'].set('server1', oldConnection);
|
||||||
|
|
||||||
|
const result = await repository.get('server1');
|
||||||
|
|
||||||
|
expect(result).toBe(mockConnection);
|
||||||
|
expect(oldConnection.disconnect).toHaveBeenCalled();
|
||||||
|
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
serverName: 'server1',
|
||||||
|
serverConfig: mockServerConfigs.server1,
|
||||||
|
},
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw error for non-existent server configuration', async () => {
|
||||||
|
await expect(repository.get('nonexistent')).rejects.toThrow(
|
||||||
|
'[MCP][nonexistent] Server not found in configuration',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle MCPConnectionFactory.create errors', async () => {
|
||||||
|
const createError = new Error('Connection creation failed');
|
||||||
|
(MCPConnectionFactory.create as jest.Mock).mockRejectedValue(createError);
|
||||||
|
|
||||||
|
await expect(repository.get('server1')).rejects.toThrow('Connection creation failed');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getMany', () => {
|
||||||
|
it('should return connections for multiple servers', async () => {
|
||||||
|
const result = await repository.getMany(['server1', 'server3']);
|
||||||
|
|
||||||
|
expect(result).toBeInstanceOf(Map);
|
||||||
|
expect(result.size).toBe(2);
|
||||||
|
expect(result.get('server1')).toBe(mockConnection);
|
||||||
|
expect(result.get('server3')).toBe(mockConnection);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getLoaded', () => {
|
||||||
|
it('should return connections for loaded servers only', async () => {
|
||||||
|
// Load one connection
|
||||||
|
await repository.get('server1');
|
||||||
|
|
||||||
|
const result = await repository.getLoaded();
|
||||||
|
|
||||||
|
expect(result).toBeInstanceOf(Map);
|
||||||
|
expect(result.size).toBe(1);
|
||||||
|
expect(result.get('server1')).toBe(mockConnection);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return empty map when no connections are loaded', async () => {
|
||||||
|
const result = await repository.getLoaded();
|
||||||
|
|
||||||
|
expect(result).toBeInstanceOf(Map);
|
||||||
|
expect(result.size).toBe(0);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getAll', () => {
|
||||||
|
it('should return connections for all configured servers', async () => {
|
||||||
|
const result = await repository.getAll();
|
||||||
|
|
||||||
|
expect(result).toBeInstanceOf(Map);
|
||||||
|
expect(result.size).toBe(3);
|
||||||
|
expect(result.get('server1')).toBe(mockConnection);
|
||||||
|
expect(result.get('server2')).toBe(mockConnection);
|
||||||
|
expect(result.get('server3')).toBe(mockConnection);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('disconnect', () => {
|
||||||
|
it('should disconnect and remove existing connection', async () => {
|
||||||
|
repository['connections'].set('server1', mockConnection);
|
||||||
|
|
||||||
|
await repository.disconnect('server1');
|
||||||
|
|
||||||
|
expect(mockConnection.disconnect).toHaveBeenCalled();
|
||||||
|
expect(repository['connections'].has('server1')).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle disconnect error gracefully', async () => {
|
||||||
|
const disconnectError = new Error('Disconnect failed');
|
||||||
|
mockConnection.disconnect.mockRejectedValue(disconnectError);
|
||||||
|
repository['connections'].set('server1', mockConnection);
|
||||||
|
|
||||||
|
await repository.disconnect('server1');
|
||||||
|
|
||||||
|
expect(mockConnection.disconnect).toHaveBeenCalled();
|
||||||
|
expect(repository['connections'].has('server1')).toBe(false);
|
||||||
|
expect(mockLogger.error).toHaveBeenCalledWith(
|
||||||
|
'[MCP][server1] Error disconnecting',
|
||||||
|
disconnectError,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('disconnectAll', () => {
|
||||||
|
it('should disconnect all active connections', () => {
|
||||||
|
const mockConnection1 = {
|
||||||
|
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
const mockConnection2 = {
|
||||||
|
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
const mockConnection3 = {
|
||||||
|
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
|
repository['connections'].set('server1', mockConnection1);
|
||||||
|
repository['connections'].set('server2', mockConnection2);
|
||||||
|
repository['connections'].set('server3', mockConnection3);
|
||||||
|
|
||||||
|
const promises = repository.disconnectAll();
|
||||||
|
|
||||||
|
expect(promises).toHaveLength(3);
|
||||||
|
expect(Array.isArray(promises)).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
347
packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts
Normal file
347
packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts
Normal file
|
@ -0,0 +1,347 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import type { TUser } from 'librechat-data-provider';
|
||||||
|
import type { FlowStateManager } from '~/flow/manager';
|
||||||
|
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||||
|
import { MCPConnectionFactory } from '../MCPConnectionFactory';
|
||||||
|
import { MCPOAuthHandler } from '~/mcp/oauth';
|
||||||
|
import { MCPConnection } from '../connection';
|
||||||
|
import { processMCPEnv } from '~/utils';
|
||||||
|
import type * as t from '../types';
|
||||||
|
|
||||||
|
jest.mock('../connection');
|
||||||
|
jest.mock('~/mcp/oauth');
|
||||||
|
jest.mock('~/utils');
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: {
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
debug: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||||
|
const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction<typeof processMCPEnv>;
|
||||||
|
const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection>;
|
||||||
|
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
|
||||||
|
|
||||||
|
describe('MCPConnectionFactory', () => {
|
||||||
|
let mockUser: TUser;
|
||||||
|
let mockServerConfig: t.MCPOptions;
|
||||||
|
let mockFlowManager: jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
|
||||||
|
let mockConnectionInstance: jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
mockUser = {
|
||||||
|
id: 'user123',
|
||||||
|
email: 'test@example.com',
|
||||||
|
} as TUser;
|
||||||
|
|
||||||
|
mockServerConfig = {
|
||||||
|
command: 'node',
|
||||||
|
args: ['server.js'],
|
||||||
|
initTimeout: 5000,
|
||||||
|
} as t.MCPOptions;
|
||||||
|
|
||||||
|
mockFlowManager = {
|
||||||
|
createFlow: jest.fn(),
|
||||||
|
createFlowWithHandler: jest.fn(),
|
||||||
|
getFlowState: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
|
||||||
|
|
||||||
|
mockConnectionInstance = {
|
||||||
|
connect: jest.fn(),
|
||||||
|
isConnected: jest.fn(),
|
||||||
|
setOAuthTokens: jest.fn(),
|
||||||
|
on: jest.fn().mockReturnValue(mockConnectionInstance),
|
||||||
|
emit: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
|
mockMCPConnection.mockImplementation(() => mockConnectionInstance);
|
||||||
|
mockProcessMCPEnv.mockReturnValue(mockServerConfig);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('static create method', () => {
|
||||||
|
it('should create a basic connection without OAuth', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const connection = await MCPConnectionFactory.create(basicOptions);
|
||||||
|
|
||||||
|
expect(connection).toBe(mockConnectionInstance);
|
||||||
|
expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, undefined, undefined);
|
||||||
|
expect(mockMCPConnection).toHaveBeenCalledWith({
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
userId: undefined,
|
||||||
|
oauthTokens: null,
|
||||||
|
});
|
||||||
|
expect(mockConnectionInstance.connect).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create a connection with OAuth', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthOptions = {
|
||||||
|
useOAuth: true as const,
|
||||||
|
user: mockUser,
|
||||||
|
flowManager: mockFlowManager,
|
||||||
|
tokenMethods: {
|
||||||
|
findToken: jest.fn(),
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteTokens: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockTokens: MCPOAuthTokens = {
|
||||||
|
access_token: 'access123',
|
||||||
|
refresh_token: 'refresh123',
|
||||||
|
token_type: 'Bearer',
|
||||||
|
obtained_at: Date.now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
mockFlowManager.createFlowWithHandler.mockResolvedValue(mockTokens);
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||||
|
|
||||||
|
expect(connection).toBe(mockConnectionInstance);
|
||||||
|
expect(mockProcessMCPEnv).toHaveBeenCalledWith(mockServerConfig, mockUser, undefined);
|
||||||
|
expect(mockMCPConnection).toHaveBeenCalledWith({
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
userId: 'user123',
|
||||||
|
oauthTokens: mockTokens,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('OAuth token handling', () => {
|
||||||
|
it('should return null when no findToken method is provided', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthOptions = {
|
||||||
|
useOAuth: true as const,
|
||||||
|
user: mockUser,
|
||||||
|
flowManager: mockFlowManager,
|
||||||
|
tokenMethods: {
|
||||||
|
findToken: undefined as unknown as () => Promise<any>,
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteTokens: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(true);
|
||||||
|
|
||||||
|
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||||
|
|
||||||
|
expect(mockFlowManager.createFlowWithHandler).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle token retrieval errors gracefully', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthOptions = {
|
||||||
|
useOAuth: true as const,
|
||||||
|
user: mockUser,
|
||||||
|
flowManager: mockFlowManager,
|
||||||
|
tokenMethods: {
|
||||||
|
findToken: jest.fn(),
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteTokens: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockFlowManager.createFlowWithHandler.mockRejectedValue(new Error('Token fetch failed'));
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||||
|
|
||||||
|
expect(connection).toBe(mockConnectionInstance);
|
||||||
|
expect(mockMCPConnection).toHaveBeenCalledWith({
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
userId: 'user123',
|
||||||
|
oauthTokens: null,
|
||||||
|
});
|
||||||
|
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('No existing tokens found or error loading tokens'),
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('OAuth event handling', () => {
|
||||||
|
it('should handle oauthRequired event for returnOnOAuth scenario', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: {
|
||||||
|
...mockServerConfig,
|
||||||
|
url: 'https://api.example.com',
|
||||||
|
type: 'sse' as const,
|
||||||
|
} as t.SSEOptions,
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthOptions = {
|
||||||
|
useOAuth: true as const,
|
||||||
|
user: mockUser,
|
||||||
|
flowManager: mockFlowManager,
|
||||||
|
returnOnOAuth: true,
|
||||||
|
oauthStart: jest.fn(),
|
||||||
|
tokenMethods: {
|
||||||
|
findToken: jest.fn(),
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteTokens: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockFlowData = {
|
||||||
|
authorizationUrl: 'https://auth.example.com',
|
||||||
|
flowId: 'flow123',
|
||||||
|
flowMetadata: {
|
||||||
|
serverName: 'test-server',
|
||||||
|
userId: 'user123',
|
||||||
|
serverUrl: 'https://api.example.com',
|
||||||
|
state: 'random-state',
|
||||||
|
clientInfo: { client_id: 'client123' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
|
||||||
|
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||||
|
|
||||||
|
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
|
||||||
|
mockConnectionInstance.on.mockImplementation((event, handler) => {
|
||||||
|
if (event === 'oauthRequired') {
|
||||||
|
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
|
||||||
|
}
|
||||||
|
return mockConnectionInstance;
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||||
|
} catch {
|
||||||
|
// Expected to fail due to connection not established
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(oauthRequiredHandler!).toBeDefined();
|
||||||
|
|
||||||
|
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||||
|
|
||||||
|
expect(mockMCPOAuthHandler.initiateOAuthFlow).toHaveBeenCalledWith(
|
||||||
|
'test-server',
|
||||||
|
'https://api.example.com',
|
||||||
|
'user123',
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
|
||||||
|
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
|
||||||
|
'oauthFailed',
|
||||||
|
expect.objectContaining({
|
||||||
|
message: 'OAuth flow initiated - return early',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('connection retry logic', () => {
|
||||||
|
it('should establish connection successfully', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig, // Use default 5000ms timeout
|
||||||
|
};
|
||||||
|
|
||||||
|
mockConnectionInstance.connect.mockResolvedValue(undefined);
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const connection = await MCPConnectionFactory.create(basicOptions);
|
||||||
|
|
||||||
|
expect(connection).toBe(mockConnectionInstance);
|
||||||
|
expect(mockConnectionInstance.connect).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle OAuth errors during connection attempts', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthOptions = {
|
||||||
|
useOAuth: true as const,
|
||||||
|
user: mockUser,
|
||||||
|
flowManager: mockFlowManager,
|
||||||
|
oauthStart: jest.fn(),
|
||||||
|
tokenMethods: {
|
||||||
|
findToken: jest.fn(),
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteTokens: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthError = new Error('Non-200 status code (401)');
|
||||||
|
(oauthError as unknown as Record<string, unknown>).isOAuthError = true;
|
||||||
|
|
||||||
|
mockConnectionInstance.connect.mockRejectedValue(oauthError);
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||||
|
|
||||||
|
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow(
|
||||||
|
'Non-200 status code (401)',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('OAuth required, stopping connection attempts'),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('isOAuthError method', () => {
|
||||||
|
it('should identify OAuth errors by message content', async () => {
|
||||||
|
const basicOptions = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverConfig: mockServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
const oauthOptions = {
|
||||||
|
useOAuth: true as const,
|
||||||
|
user: mockUser,
|
||||||
|
flowManager: mockFlowManager,
|
||||||
|
tokenMethods: {
|
||||||
|
findToken: jest.fn(),
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteTokens: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const error401 = new Error('401 Unauthorized');
|
||||||
|
|
||||||
|
mockConnectionInstance.connect.mockRejectedValue(error401);
|
||||||
|
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||||
|
|
||||||
|
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow('401');
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('OAuth required, stopping connection attempts'),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
287
packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts
Normal file
287
packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts
Normal file
|
@ -0,0 +1,287 @@
|
||||||
|
import { readFileSync } from 'fs';
|
||||||
|
import { join } from 'path';
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import { load as yamlLoad } from 'js-yaml';
|
||||||
|
import { ConnectionsRepository } from '../ConnectionsRepository';
|
||||||
|
import { MCPServersRegistry } from '../MCPServersRegistry';
|
||||||
|
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
|
import { MCPConnection } from '../connection';
|
||||||
|
import type * as t from '../types';
|
||||||
|
|
||||||
|
// Mock external dependencies
|
||||||
|
jest.mock('../oauth/detectOAuth');
|
||||||
|
jest.mock('../ConnectionsRepository');
|
||||||
|
jest.mock('../connection');
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: {
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
debug: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock processMCPEnv to verify it's called and adds a processed marker
|
||||||
|
jest.mock('~/utils', () => ({
|
||||||
|
...jest.requireActual('~/utils'),
|
||||||
|
processMCPEnv: jest.fn((config) => ({
|
||||||
|
...config,
|
||||||
|
_processed: true, // Simple marker to verify processing occurred
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
|
||||||
|
typeof detectOAuthRequirement
|
||||||
|
>;
|
||||||
|
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||||
|
|
||||||
|
describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
|
let rawConfigs: t.MCPServers;
|
||||||
|
let expectedParsedConfigs: Record<string, any>;
|
||||||
|
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
|
||||||
|
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Load fixtures
|
||||||
|
const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml');
|
||||||
|
const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml');
|
||||||
|
|
||||||
|
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
|
||||||
|
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
|
||||||
|
string,
|
||||||
|
any
|
||||||
|
>;
|
||||||
|
|
||||||
|
// Setup mock connections
|
||||||
|
mockConnections = new Map();
|
||||||
|
const serverNames = Object.keys(rawConfigs);
|
||||||
|
|
||||||
|
serverNames.forEach((serverName) => {
|
||||||
|
const mockConnection = {
|
||||||
|
client: {
|
||||||
|
listTools: jest.fn(),
|
||||||
|
getInstructions: jest.fn(),
|
||||||
|
getServerCapabilities: jest.fn(),
|
||||||
|
},
|
||||||
|
} as unknown as jest.Mocked<MCPConnection>;
|
||||||
|
|
||||||
|
// Setup mock responses based on expected configs
|
||||||
|
const expectedConfig = expectedParsedConfigs[serverName];
|
||||||
|
|
||||||
|
// Mock listTools response
|
||||||
|
if (expectedConfig.tools) {
|
||||||
|
const toolNames = expectedConfig.tools.split(', ');
|
||||||
|
const tools = toolNames.map((name: string) => ({
|
||||||
|
name,
|
||||||
|
description: `Description for ${name}`,
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
input: { type: 'string' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
mockConnection.client.listTools.mockResolvedValue({ tools });
|
||||||
|
} else {
|
||||||
|
mockConnection.client.listTools.mockResolvedValue({ tools: [] });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock getInstructions response
|
||||||
|
if (expectedConfig.serverInstructions) {
|
||||||
|
mockConnection.client.getInstructions.mockReturnValue(expectedConfig.serverInstructions);
|
||||||
|
} else {
|
||||||
|
mockConnection.client.getInstructions.mockReturnValue(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock getServerCapabilities response
|
||||||
|
if (expectedConfig.capabilities) {
|
||||||
|
const capabilities = JSON.parse(expectedConfig.capabilities);
|
||||||
|
mockConnection.client.getServerCapabilities.mockReturnValue(capabilities);
|
||||||
|
} else {
|
||||||
|
mockConnection.client.getServerCapabilities.mockReturnValue(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
mockConnections.set(serverName, mockConnection);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Setup ConnectionsRepository mock
|
||||||
|
mockConnectionsRepo = {
|
||||||
|
get: jest.fn(),
|
||||||
|
getLoaded: jest.fn(),
|
||||||
|
disconnectAll: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<ConnectionsRepository>;
|
||||||
|
|
||||||
|
mockConnectionsRepo.get.mockImplementation((serverName: string) =>
|
||||||
|
Promise.resolve(mockConnections.get(serverName)!),
|
||||||
|
);
|
||||||
|
|
||||||
|
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
|
||||||
|
|
||||||
|
(ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo);
|
||||||
|
|
||||||
|
// Setup OAuth detection mock with deterministic results
|
||||||
|
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||||
|
const oauthResults: Record<string, any> = {
|
||||||
|
'https://api.github.com/mcp': {
|
||||||
|
requiresOAuth: true,
|
||||||
|
metadata: {
|
||||||
|
authorization_url: 'https://github.com/login/oauth/authorize',
|
||||||
|
token_url: 'https://github.com/login/oauth/access_token',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'https://api.disabled.com/mcp': {
|
||||||
|
requiresOAuth: false,
|
||||||
|
metadata: null,
|
||||||
|
},
|
||||||
|
'https://api.public.com/mcp': {
|
||||||
|
requiresOAuth: false,
|
||||||
|
metadata: null,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return Promise.resolve(oauthResults[url] || { requiresOAuth: false, metadata: null });
|
||||||
|
});
|
||||||
|
|
||||||
|
// Clear all mocks
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('initialize() method', () => {
|
||||||
|
it('should only run initialization once', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
await registry.initialize(); // Second call should not re-run
|
||||||
|
|
||||||
|
// Verify that connections are only requested for servers that need them
|
||||||
|
// (servers with serverInstructions=true or all servers for capabilities)
|
||||||
|
expect(mockConnectionsRepo.get).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should set all public properties correctly after initialization', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
// Verify initial state
|
||||||
|
expect(registry.oauthServers).toBeNull();
|
||||||
|
expect(registry.serverInstructions).toBeNull();
|
||||||
|
expect(registry.toolFunctions).toBeNull();
|
||||||
|
expect(registry.appServerConfigs).toBeNull();
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
// Test oauthServers Set
|
||||||
|
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||||
|
expect(registry.oauthServers).toEqual(
|
||||||
|
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test serverInstructions
|
||||||
|
expect(registry.serverInstructions).toEqual({
|
||||||
|
oauth_server: 'GitHub MCP server instructions',
|
||||||
|
stdio_server: 'Follow these instructions for stdio server',
|
||||||
|
non_oauth_server: 'Public API instructions',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Test appServerConfigs (startup enabled, non-OAuth servers only)
|
||||||
|
expect(registry.appServerConfigs).toEqual({
|
||||||
|
stdio_server: rawConfigs.stdio_server,
|
||||||
|
websocket_server: rawConfigs.websocket_server,
|
||||||
|
non_oauth_server: rawConfigs.non_oauth_server,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Test toolFunctions (only 2 servers have tools: oauth_server has 1, stdio_server has 2)
|
||||||
|
const expectedToolFunctions = {
|
||||||
|
get_repository_mcp_oauth_server: {
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'get_repository_mcp_oauth_server',
|
||||||
|
description: 'Description for get_repository',
|
||||||
|
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
file_read_mcp_stdio_server: {
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'file_read_mcp_stdio_server',
|
||||||
|
description: 'Description for file_read',
|
||||||
|
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
file_write_mcp_stdio_server: {
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'file_write_mcp_stdio_server',
|
||||||
|
description: 'Description for file_write',
|
||||||
|
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle errors gracefully and continue initialization', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
// Make one server throw an error
|
||||||
|
mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed'));
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
// Should still initialize successfully
|
||||||
|
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||||
|
expect(registry.toolFunctions).toBeDefined();
|
||||||
|
|
||||||
|
// Error should be logged
|
||||||
|
expect(mockLogger.error).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('[MCP][oauth_server] Failed to fetch OAuth requirement:'),
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should disconnect all connections after initialization', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should log configuration updates for each server', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
const serverNames = Object.keys(rawConfigs);
|
||||||
|
serverNames.forEach((serverName) => {
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(`[MCP][${serverName}] URL:`),
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(`[MCP][${serverName}] OAuth Required:`),
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(`[MCP][${serverName}] Capabilities:`),
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(`[MCP][${serverName}] Tools:`),
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(`[MCP][${serverName}] Server Instructions:`),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should have parsedConfigs matching the expected fixture after initialization', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
// Compare the actual parsedConfigs against the expected fixture
|
||||||
|
expect(registry.parsedConfigs).toEqual(expectedParsedConfigs);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
|
@ -1,7 +1,7 @@
|
||||||
import type { PluginAuthMethods } from '@librechat/data-schemas';
|
import type { PluginAuthMethods } from '@librechat/data-schemas';
|
||||||
import type { GenericTool } from '@librechat/agents';
|
import type { GenericTool } from '@librechat/agents';
|
||||||
import { getPluginAuthMap } from '~/agents/auth';
|
import { getPluginAuthMap } from '~/agents/auth';
|
||||||
import { getUserMCPAuthMap } from './auth';
|
import { getUserMCPAuthMap } from '../auth';
|
||||||
|
|
||||||
jest.mock('~/agents/auth', () => ({
|
jest.mock('~/agents/auth', () => ({
|
||||||
getPluginAuthMap: jest.fn(),
|
getPluginAuthMap: jest.fn(),
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Integration tests for OAuth detection against real public MCP servers
|
||||||
|
// These tests verify the actual behavior against live endpoints
|
||||||
|
//
|
||||||
|
// DEVELOPMENT ONLY: This file is excluded from the test suite (.dev.ts extension)
|
||||||
|
// Use this for development and debugging OAuth detection behavior
|
||||||
|
//
|
||||||
|
// To run manually from packages/api directory:
|
||||||
|
// npx jest --testMatch="**/detectOAuth.integration.dev.ts"
|
||||||
|
|
||||||
|
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
|
|
||||||
|
describe('OAuth Detection Integration Tests', () => {
|
||||||
|
const NETWORK_TIMEOUT = 10000;
|
||||||
|
|
||||||
|
interface TestServer {
|
||||||
|
name: string;
|
||||||
|
url: string;
|
||||||
|
expectedOAuth: boolean;
|
||||||
|
expectedMethod: string;
|
||||||
|
withMeta: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const testServers: TestServer[] = [
|
||||||
|
{
|
||||||
|
name: 'GitHub Copilot MCP Server',
|
||||||
|
url: 'https://api.githubcopilot.com/mcp',
|
||||||
|
expectedOAuth: true,
|
||||||
|
expectedMethod: '401-challenge-metadata',
|
||||||
|
withMeta: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'GitHub API (401 without metadata)',
|
||||||
|
url: 'https://api.github.com/user',
|
||||||
|
expectedOAuth: true,
|
||||||
|
expectedMethod: 'no-metadata-found',
|
||||||
|
withMeta: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'Stytch Todo MCP Server',
|
||||||
|
url: 'https://mcp-stytch-consumer-todo-list.maxwell-gerber42.workers.dev',
|
||||||
|
expectedOAuth: true,
|
||||||
|
expectedMethod: 'protected-resource-metadata',
|
||||||
|
withMeta: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'HTTPBin (Non-OAuth)',
|
||||||
|
url: 'https://httpbin.org',
|
||||||
|
expectedOAuth: false,
|
||||||
|
expectedMethod: 'no-metadata-found',
|
||||||
|
withMeta: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'Unreachable Server',
|
||||||
|
url: 'https://definitely-not-a-real-server-12345.com',
|
||||||
|
expectedOAuth: false,
|
||||||
|
expectedMethod: 'no-metadata-found',
|
||||||
|
withMeta: false,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
describe('detectOAuthRequirement integration', () => {
|
||||||
|
testServers.forEach((server) => {
|
||||||
|
it(
|
||||||
|
`should handle ${server.name}`,
|
||||||
|
async () => {
|
||||||
|
const result = await detectOAuthRequirement(server.url);
|
||||||
|
|
||||||
|
expect(result.requiresOAuth).toBe(server.expectedOAuth);
|
||||||
|
expect(result.method).toBe(server.expectedMethod);
|
||||||
|
expect(result.metadata == null).toBe(!server.withMeta);
|
||||||
|
},
|
||||||
|
NETWORK_TIMEOUT,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
|
@ -0,0 +1,74 @@
|
||||||
|
# Expected parsed MCP server configurations after running initialize()
|
||||||
|
# These represent the expected state of parsedConfigs after all fetch functions complete
|
||||||
|
|
||||||
|
oauth_server:
|
||||||
|
_processed: true
|
||||||
|
type: "streamable-http"
|
||||||
|
url: "https://api.github.com/mcp"
|
||||||
|
headers:
|
||||||
|
Authorization: "Bearer {{GITHUB_TOKEN}}"
|
||||||
|
serverInstructions: "GitHub MCP server instructions"
|
||||||
|
requiresOAuth: true
|
||||||
|
oauthMetadata:
|
||||||
|
authorization_url: "https://github.com/login/oauth/authorize"
|
||||||
|
token_url: "https://github.com/login/oauth/access_token"
|
||||||
|
capabilities: '{"tools":{"listChanged":true},"resources":{},"prompts":{}}'
|
||||||
|
tools: "get_repository"
|
||||||
|
|
||||||
|
oauth_predefined:
|
||||||
|
_processed: true
|
||||||
|
type: "sse"
|
||||||
|
url: "https://api.example.com/sse"
|
||||||
|
requiresOAuth: true
|
||||||
|
oauthMetadata:
|
||||||
|
authorization_url: "https://example.com/oauth/authorize"
|
||||||
|
token_url: "https://example.com/oauth/token"
|
||||||
|
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||||
|
tools: ""
|
||||||
|
|
||||||
|
stdio_server:
|
||||||
|
_processed: true
|
||||||
|
command: "node"
|
||||||
|
args: ["server.js"]
|
||||||
|
env:
|
||||||
|
API_KEY: "${TEST_API_KEY}"
|
||||||
|
startup: true
|
||||||
|
serverInstructions: "Follow these instructions for stdio server"
|
||||||
|
requiresOAuth: false
|
||||||
|
capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}'
|
||||||
|
tools: "file_read, file_write"
|
||||||
|
|
||||||
|
websocket_server:
|
||||||
|
_processed: true
|
||||||
|
type: "websocket"
|
||||||
|
url: "ws://localhost:3001/mcp"
|
||||||
|
startup: true
|
||||||
|
requiresOAuth: false
|
||||||
|
oauthMetadata: null
|
||||||
|
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||||
|
tools: ""
|
||||||
|
|
||||||
|
disabled_server:
|
||||||
|
_processed: true
|
||||||
|
type: "streamable-http"
|
||||||
|
url: "https://api.disabled.com/mcp"
|
||||||
|
startup: false
|
||||||
|
requiresOAuth: false
|
||||||
|
oauthMetadata: null
|
||||||
|
|
||||||
|
non_oauth_server:
|
||||||
|
_processed: true
|
||||||
|
type: "streamable-http"
|
||||||
|
url: "https://api.public.com/mcp"
|
||||||
|
requiresOAuth: false
|
||||||
|
serverInstructions: "Public API instructions"
|
||||||
|
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||||
|
tools: ""
|
||||||
|
|
||||||
|
oauth_startup_enabled:
|
||||||
|
_processed: true
|
||||||
|
type: "sse"
|
||||||
|
url: "https://api.oauth-startup.com/sse"
|
||||||
|
requiresOAuth: true
|
||||||
|
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||||
|
tools: ""
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Raw MCP server configurations used as input to MCPServersRegistry constructor
|
||||||
|
# These configs test different code paths in the initialization process
|
||||||
|
|
||||||
|
# Test OAuth detection with URL - should trigger fetchOAuthRequirement
|
||||||
|
oauth_server:
|
||||||
|
type: "streamable-http"
|
||||||
|
url: "https://api.github.com/mcp"
|
||||||
|
headers:
|
||||||
|
Authorization: "Bearer {{GITHUB_TOKEN}}"
|
||||||
|
serverInstructions: true
|
||||||
|
|
||||||
|
# Test OAuth already specified - should skip OAuth detection
|
||||||
|
oauth_predefined:
|
||||||
|
type: "sse"
|
||||||
|
url: "https://api.example.com/sse"
|
||||||
|
requiresOAuth: true
|
||||||
|
oauthMetadata:
|
||||||
|
authorization_url: "https://example.com/oauth/authorize"
|
||||||
|
token_url: "https://example.com/oauth/token"
|
||||||
|
|
||||||
|
# Test stdio server without URL - should set requiresOAuth to false
|
||||||
|
stdio_server:
|
||||||
|
command: "node"
|
||||||
|
args: ["server.js"]
|
||||||
|
env:
|
||||||
|
API_KEY: "${TEST_API_KEY}"
|
||||||
|
startup: true
|
||||||
|
serverInstructions: "Follow these instructions for stdio server"
|
||||||
|
|
||||||
|
# Test websocket server with capabilities but no tools
|
||||||
|
websocket_server:
|
||||||
|
type: "websocket"
|
||||||
|
url: "ws://localhost:3001/mcp"
|
||||||
|
startup: true
|
||||||
|
|
||||||
|
# Test server with startup disabled - should not be included in appServerConfigs
|
||||||
|
disabled_server:
|
||||||
|
type: "streamable-http"
|
||||||
|
url: "https://api.disabled.com/mcp"
|
||||||
|
startup: false
|
||||||
|
|
||||||
|
# Test non-OAuth server - should be included in appServerConfigs
|
||||||
|
non_oauth_server:
|
||||||
|
type: "streamable-http"
|
||||||
|
url: "https://api.public.com/mcp"
|
||||||
|
requiresOAuth: false
|
||||||
|
serverInstructions: true
|
||||||
|
|
||||||
|
# Test server with OAuth but startup enabled - should not be in appServerConfigs
|
||||||
|
oauth_startup_enabled:
|
||||||
|
type: "sse"
|
||||||
|
url: "https://api.oauth-startup.com/sse"
|
||||||
|
requiresOAuth: true
|
|
@ -1,5 +1,5 @@
|
||||||
import { MCPOAuthHandler } from './handler';
|
|
||||||
import type { MCPOptions } from 'librechat-data-provider';
|
import type { MCPOptions } from 'librechat-data-provider';
|
||||||
|
import { MCPOAuthHandler } from '~/mcp/oauth';
|
||||||
|
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
logger: {
|
logger: {
|
|
@ -1,4 +1,4 @@
|
||||||
import { normalizeServerName } from './utils';
|
import { normalizeServerName } from '../utils';
|
||||||
|
|
||||||
describe('normalizeServerName', () => {
|
describe('normalizeServerName', () => {
|
||||||
it('should not modify server names that already match the pattern', () => {
|
it('should not modify server names that already match the pattern', () => {
|
|
@ -2,7 +2,7 @@
|
||||||
// zod.spec.ts
|
// zod.spec.ts
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import type { JsonSchemaType } from '~/types';
|
import type { JsonSchemaType } from '~/types';
|
||||||
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from './zod';
|
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from '../zod';
|
||||||
|
|
||||||
describe('convertJsonSchemaToZod', () => {
|
describe('convertJsonSchemaToZod', () => {
|
||||||
describe('primitive types', () => {
|
describe('primitive types', () => {
|
|
@ -1,17 +1,18 @@
|
||||||
import { EventEmitter } from 'events';
|
import { EventEmitter } from 'events';
|
||||||
import { logger } from '@librechat/data-schemas';
|
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
|
||||||
import {
|
import {
|
||||||
StdioClientTransport,
|
StdioClientTransport,
|
||||||
getDefaultEnvironment,
|
getDefaultEnvironment,
|
||||||
} from '@modelcontextprotocol/sdk/client/stdio.js';
|
} from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
|
||||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||||
|
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||||
|
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||||
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import type { MCPOAuthTokens } from './oauth/types';
|
import type { MCPOAuthTokens } from './oauth/types';
|
||||||
|
import { mcpConfig } from './mcpConfig';
|
||||||
import type * as t from './types';
|
import type * as t from './types';
|
||||||
|
|
||||||
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
|
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
|
||||||
|
@ -56,9 +57,17 @@ function isStreamableHTTPOptions(options: t.MCPOptions): options is t.Streamable
|
||||||
}
|
}
|
||||||
|
|
||||||
const FIVE_MINUTES = 5 * 60 * 1000;
|
const FIVE_MINUTES = 5 * 60 * 1000;
|
||||||
|
|
||||||
|
interface MCPConnectionParams {
|
||||||
|
serverName: string;
|
||||||
|
serverConfig: t.MCPOptions;
|
||||||
|
userId?: string;
|
||||||
|
oauthTokens?: MCPOAuthTokens | null;
|
||||||
|
}
|
||||||
|
|
||||||
export class MCPConnection extends EventEmitter {
|
export class MCPConnection extends EventEmitter {
|
||||||
private static instance: MCPConnection | null = null;
|
|
||||||
public client: Client;
|
public client: Client;
|
||||||
|
private options: t.MCPOptions;
|
||||||
private transport: Transport | null = null; // Make this nullable
|
private transport: Transport | null = null; // Make this nullable
|
||||||
private connectionState: t.ConnectionState = 'disconnected';
|
private connectionState: t.ConnectionState = 'disconnected';
|
||||||
private connectPromise: Promise<void> | null = null;
|
private connectPromise: Promise<void> | null = null;
|
||||||
|
@ -70,26 +79,23 @@ export class MCPConnection extends EventEmitter {
|
||||||
private reconnectAttempts = 0;
|
private reconnectAttempts = 0;
|
||||||
private readonly userId?: string;
|
private readonly userId?: string;
|
||||||
private lastPingTime: number;
|
private lastPingTime: number;
|
||||||
|
private lastConnectionCheckAt: number = 0;
|
||||||
private oauthTokens?: MCPOAuthTokens | null;
|
private oauthTokens?: MCPOAuthTokens | null;
|
||||||
private oauthRequired = false;
|
private oauthRequired = false;
|
||||||
iconPath?: string;
|
iconPath?: string;
|
||||||
timeout?: number;
|
timeout?: number;
|
||||||
url?: string;
|
url?: string;
|
||||||
|
|
||||||
constructor(
|
constructor(params: MCPConnectionParams) {
|
||||||
serverName: string,
|
|
||||||
private readonly options: t.MCPOptions,
|
|
||||||
userId?: string,
|
|
||||||
oauthTokens?: MCPOAuthTokens | null,
|
|
||||||
) {
|
|
||||||
super();
|
super();
|
||||||
this.serverName = serverName;
|
this.options = params.serverConfig;
|
||||||
this.userId = userId;
|
this.serverName = params.serverName;
|
||||||
this.iconPath = options.iconPath;
|
this.userId = params.userId;
|
||||||
this.timeout = options.timeout;
|
this.iconPath = params.serverConfig.iconPath;
|
||||||
|
this.timeout = params.serverConfig.timeout;
|
||||||
this.lastPingTime = Date.now();
|
this.lastPingTime = Date.now();
|
||||||
if (oauthTokens) {
|
if (params.oauthTokens) {
|
||||||
this.oauthTokens = oauthTokens;
|
this.oauthTokens = params.oauthTokens;
|
||||||
}
|
}
|
||||||
this.client = new Client(
|
this.client = new Client(
|
||||||
{
|
{
|
||||||
|
@ -110,28 +116,6 @@ export class MCPConnection extends EventEmitter {
|
||||||
return `[MCP]${userPart}[${this.serverName}]`;
|
return `[MCP]${userPart}[${this.serverName}]`;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static getInstance(
|
|
||||||
serverName: string,
|
|
||||||
options: t.MCPOptions,
|
|
||||||
userId?: string,
|
|
||||||
): MCPConnection {
|
|
||||||
if (!MCPConnection.instance) {
|
|
||||||
MCPConnection.instance = new MCPConnection(serverName, options, userId);
|
|
||||||
}
|
|
||||||
return MCPConnection.instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static getExistingInstance(): MCPConnection | null {
|
|
||||||
return MCPConnection.instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static async destroyInstance(): Promise<void> {
|
|
||||||
if (MCPConnection.instance) {
|
|
||||||
await MCPConnection.instance.disconnect();
|
|
||||||
MCPConnection.instance = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private emitError(error: unknown, errorContext: string): void {
|
private emitError(error: unknown, errorContext: string): void {
|
||||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||||
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
||||||
|
@ -589,6 +573,13 @@ export class MCPConnection extends EventEmitter {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we recently checked, skip expensive verification
|
||||||
|
const now = Date.now();
|
||||||
|
if (now - this.lastConnectionCheckAt < mcpConfig.CONNECTION_CHECK_TTL) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
this.lastConnectionCheckAt = now;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Try ping first as it's the lightest check
|
// Try ping first as it's the lightest check
|
||||||
await this.client.ping();
|
await this.client.ping();
|
||||||
|
|
File diff suppressed because it is too large
Load diff
11
packages/api/src/mcp/mcpConfig.ts
Normal file
11
packages/api/src/mcp/mcpConfig.ts
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
import { math, isEnabled } from '~/utils';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Centralized configuration for MCP-related environment variables.
|
||||||
|
* Provides typed access to MCP settings with default values.
|
||||||
|
*/
|
||||||
|
export const mcpConfig = {
|
||||||
|
OAUTH_ON_AUTH_ERROR: isEnabled(process.env.MCP_OAUTH_ON_AUTH_ERROR ?? true),
|
||||||
|
OAUTH_DETECTION_TIMEOUT: math(process.env.MCP_OAUTH_DETECTION_TIMEOUT ?? 5000),
|
||||||
|
CONNECTION_CHECK_TTL: math(process.env.MCP_CONNECTION_CHECK_TTL ?? 60000),
|
||||||
|
};
|
120
packages/api/src/mcp/oauth/detectOAuth.ts
Normal file
120
packages/api/src/mcp/oauth/detectOAuth.ts
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
// ATTENTION: If you modify OAuth detection logic in this file, run the integration tests to verify:
|
||||||
|
// npx jest --testMatch="**/detectOAuth.integration.dev.ts" (from packages/api directory)
|
||||||
|
//
|
||||||
|
// These tests are excluded from CI because they make live HTTP requests to external services,
|
||||||
|
// which could cause flaky builds due to network issues or changes in third-party endpoints.
|
||||||
|
// Manual testing ensures the OAuth detection still works against real MCP servers.
|
||||||
|
|
||||||
|
import { discoverOAuthProtectedResourceMetadata } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||||
|
import { mcpConfig } from '../mcpConfig';
|
||||||
|
|
||||||
|
export interface OAuthDetectionResult {
|
||||||
|
requiresOAuth: boolean;
|
||||||
|
method: 'protected-resource-metadata' | '401-challenge-metadata' | 'no-metadata-found';
|
||||||
|
metadata?: Record<string, unknown> | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Detects if an MCP server requires OAuth authentication using proactive discovery methods.
|
||||||
|
*
|
||||||
|
* This function implements a comprehensive OAuth detection strategy:
|
||||||
|
* 1. Standard Protected Resource Metadata (RFC 9728) - checks /.well-known/oauth-protected-resource
|
||||||
|
* 2. 401 Challenge Method - checks WWW-Authenticate header for resource_metadata URL
|
||||||
|
* 3. Optional fallback: treat any 401/403 response as OAuth requirement (if MCP_OAUTH_ON_AUTH_ERROR=true)
|
||||||
|
*
|
||||||
|
* @param serverUrl - The MCP server URL to check for OAuth requirements
|
||||||
|
* @returns Promise<OAuthDetectionResult> - OAuth requirement details
|
||||||
|
*/
|
||||||
|
export async function detectOAuthRequirement(serverUrl: string): Promise<OAuthDetectionResult> {
|
||||||
|
const protectedResourceResult = await checkProtectedResourceMetadata(serverUrl);
|
||||||
|
if (protectedResourceResult) return protectedResourceResult;
|
||||||
|
|
||||||
|
const challengeResult = await check401ChallengeMetadata(serverUrl);
|
||||||
|
if (challengeResult) return challengeResult;
|
||||||
|
|
||||||
|
const fallbackResult = await checkAuthErrorFallback(serverUrl);
|
||||||
|
if (fallbackResult) return fallbackResult;
|
||||||
|
|
||||||
|
// No OAuth detected
|
||||||
|
return {
|
||||||
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
|
metadata: null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// ------------------------ Private helper functions for OAuth detection -------------------------//
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Checks for OAuth using standard protected resource metadata (RFC 9728)
|
||||||
|
async function checkProtectedResourceMetadata(
|
||||||
|
serverUrl: string,
|
||||||
|
): Promise<OAuthDetectionResult | null> {
|
||||||
|
try {
|
||||||
|
const resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl);
|
||||||
|
|
||||||
|
if (!resourceMetadata?.authorization_servers?.length) return null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
requiresOAuth: true,
|
||||||
|
method: 'protected-resource-metadata',
|
||||||
|
metadata: resourceMetadata,
|
||||||
|
};
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks for OAuth using 401 challenge with resource metadata URL
|
||||||
|
async function check401ChallengeMetadata(serverUrl: string): Promise<OAuthDetectionResult | null> {
|
||||||
|
try {
|
||||||
|
const response = await fetch(serverUrl, {
|
||||||
|
method: 'HEAD',
|
||||||
|
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.status !== 401) return null;
|
||||||
|
|
||||||
|
const wwwAuth = response.headers.get('www-authenticate');
|
||||||
|
const metadataUrl = wwwAuth?.match(/resource_metadata="([^"]+)"/)?.[1];
|
||||||
|
if (!metadataUrl) return null;
|
||||||
|
|
||||||
|
const metadataResponse = await fetch(metadataUrl, {
|
||||||
|
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
|
||||||
|
});
|
||||||
|
const metadata = await metadataResponse.json();
|
||||||
|
|
||||||
|
if (!metadata?.authorization_servers?.length) return null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
requiresOAuth: true,
|
||||||
|
method: '401-challenge-metadata',
|
||||||
|
metadata,
|
||||||
|
};
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback method: treats any auth error as OAuth requirement if configured
|
||||||
|
async function checkAuthErrorFallback(serverUrl: string): Promise<OAuthDetectionResult | null> {
|
||||||
|
try {
|
||||||
|
if (!mcpConfig.OAUTH_ON_AUTH_ERROR) return null;
|
||||||
|
|
||||||
|
const response = await fetch(serverUrl, {
|
||||||
|
method: 'HEAD',
|
||||||
|
signal: AbortSignal.timeout(mcpConfig.OAUTH_DETECTION_TIMEOUT),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.status !== 401 && response.status !== 403) return null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
requiresOAuth: true,
|
||||||
|
method: 'no-metadata-found',
|
||||||
|
metadata: null,
|
||||||
|
};
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
export * from './types';
|
export * from './types';
|
||||||
export * from './handler';
|
export * from './handler';
|
||||||
export * from './tokens';
|
export * from './tokens';
|
||||||
|
export * from './detectOAuth';
|
||||||
|
|
|
@ -3,6 +3,13 @@ import { TokenExchangeMethodEnum } from './types/agents';
|
||||||
import { extractEnvVariable } from './utils';
|
import { extractEnvVariable } from './utils';
|
||||||
|
|
||||||
const BaseOptionsSchema = z.object({
|
const BaseOptionsSchema = z.object({
|
||||||
|
/**
|
||||||
|
* Controls whether the MCP server is initialized during application startup.
|
||||||
|
* - true (default): Server is initialized during app startup and included in app-level connections
|
||||||
|
* - false: Skips initialization at startup and excludes from app-level connections - useful for servers
|
||||||
|
* requiring manual authentication (e.g., GitHub PAT tokens) that need to be configured through the UI after startup
|
||||||
|
*/
|
||||||
|
startup: z.boolean().optional(),
|
||||||
iconPath: z.string().optional(),
|
iconPath: z.string().optional(),
|
||||||
timeout: z.number().optional(),
|
timeout: z.number().optional(),
|
||||||
initTimeout: z.number().optional(),
|
initTimeout: z.number().optional(),
|
||||||
|
@ -15,6 +22,11 @@ const BaseOptionsSchema = z.object({
|
||||||
* - string: Use custom instructions (overrides server-provided)
|
* - string: Use custom instructions (overrides server-provided)
|
||||||
*/
|
*/
|
||||||
serverInstructions: z.union([z.boolean(), z.string()]).optional(),
|
serverInstructions: z.union([z.boolean(), z.string()]).optional(),
|
||||||
|
/**
|
||||||
|
* Whether this server requires OAuth authentication
|
||||||
|
* If not specified, will be auto-detected during construction
|
||||||
|
*/
|
||||||
|
requiresOAuth: z.boolean().optional(),
|
||||||
/**
|
/**
|
||||||
* OAuth configuration for SSE and Streamable HTTP transports
|
* OAuth configuration for SSE and Streamable HTTP transports
|
||||||
* - Optional: OAuth can be auto-discovered on 401 responses
|
* - Optional: OAuth can be auto-discovered on 401 responses
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue