mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-16 16:30:15 +01:00
🚉 feat: MCP Registry Individual Server Init (#9887)
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile, librechat-dev, node) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile.multi, librechat-dev-api, api-build) (push) Waiting to run
Sync Locize Translations & Create Translation PR / Sync Translation Keys with Locize (push) Waiting to run
Sync Locize Translations & Create Translation PR / Create Translation PR on Version Published (push) Blocked by required conditions
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile, librechat-dev, node) (push) Waiting to run
Docker Dev Images Build / build (Dockerfile.multi, librechat-dev-api, api-build) (push) Waiting to run
Sync Locize Translations & Create Translation PR / Sync Translation Keys with Locize (push) Waiting to run
Sync Locize Translations & Create Translation PR / Create Translation PR on Version Published (push) Blocked by required conditions
* initialize servers sequentially * adjust for exported properties that are not nullable anymore * use underscore separator * mock with set * customize init timeout via env var
This commit is contained in:
parent
0b2fde73e3
commit
b8720a9b7a
6 changed files with 257 additions and 81 deletions
|
|
@ -450,7 +450,7 @@ async function getMCPSetupData(userId) {
|
||||||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||||
}
|
}
|
||||||
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
||||||
const oauthServers = mcpManager.getOAuthServers() || new Set();
|
const oauthServers = mcpManager.getOAuthServers();
|
||||||
|
|
||||||
return {
|
return {
|
||||||
mcpConfig,
|
mcpConfig,
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
const mockMCPManager = {
|
const mockMCPManager = {
|
||||||
appConnections: { getAll: jest.fn(() => null) },
|
appConnections: { getAll: jest.fn(() => null) },
|
||||||
getUserConnections: jest.fn(() => null),
|
getUserConnections: jest.fn(() => null),
|
||||||
getOAuthServers: jest.fn(() => null),
|
getOAuthServers: jest.fn(() => new Set()),
|
||||||
};
|
};
|
||||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ export class MCPManager extends UserConnectionManager {
|
||||||
/** Initializes the MCPManager by setting up server registry and app connections */
|
/** Initializes the MCPManager by setting up server registry and app connections */
|
||||||
public async initialize() {
|
public async initialize() {
|
||||||
await this.serversRegistry.initialize();
|
await this.serversRegistry.initialize();
|
||||||
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!);
|
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Retrieves an app-level or user-specific connection based on provided arguments */
|
/** Retrieves an app-level or user-specific connection based on provided arguments */
|
||||||
|
|
@ -63,22 +63,23 @@ export class MCPManager extends UserConnectionManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get servers that require OAuth */
|
/** Get servers that require OAuth */
|
||||||
public getOAuthServers(): Set<string> | null {
|
public getOAuthServers(): Set<string> {
|
||||||
return this.serversRegistry.oauthServers!;
|
return this.serversRegistry.oauthServers;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get all servers */
|
/** Get all servers */
|
||||||
public getAllServers(): t.MCPServers | null {
|
public getAllServers(): t.MCPServers {
|
||||||
return this.serversRegistry.rawConfigs!;
|
return this.serversRegistry.rawConfigs;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns all available tool functions from app-level connections */
|
/** Returns all available tool functions from app-level connections */
|
||||||
public getAppToolFunctions(): t.LCAvailableTools | null {
|
public getAppToolFunctions(): t.LCAvailableTools {
|
||||||
return this.serversRegistry.toolFunctions!;
|
return this.serversRegistry.toolFunctions;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns all available tool functions from all connections available to user */
|
/** Returns all available tool functions from all connections available to user */
|
||||||
public async getAllToolFunctions(userId: string): Promise<t.LCAvailableTools | null> {
|
public async getAllToolFunctions(userId: string): Promise<t.LCAvailableTools | null> {
|
||||||
const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions() ?? {};
|
const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions();
|
||||||
const userConnections = this.getUserConnections(userId);
|
const userConnections = this.getUserConnections(userId);
|
||||||
if (!userConnections || userConnections.size === 0) {
|
if (!userConnections || userConnections.size === 0) {
|
||||||
return allToolFunctions;
|
return allToolFunctions;
|
||||||
|
|
@ -120,7 +121,7 @@ export class MCPManager extends UserConnectionManager {
|
||||||
* @returns Object mapping server names to their instructions
|
* @returns Object mapping server names to their instructions
|
||||||
*/
|
*/
|
||||||
public getInstructions(serverNames?: string[]): Record<string, string> {
|
public getInstructions(serverNames?: string[]): Record<string, string> {
|
||||||
const instructions = this.serversRegistry.serverInstructions!;
|
const instructions = this.serversRegistry.serverInstructions;
|
||||||
if (!serverNames) return instructions;
|
if (!serverNames) return instructions;
|
||||||
return pick(instructions, serverNames);
|
return pick(instructions, serverNames);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import pick from 'lodash/pick';
|
|
||||||
import pickBy from 'lodash/pickBy';
|
|
||||||
import mapValues from 'lodash/mapValues';
|
import mapValues from 'lodash/mapValues';
|
||||||
import { logger } from '@librechat/data-schemas';
|
import { logger } from '@librechat/data-schemas';
|
||||||
import { Constants } from 'librechat-data-provider';
|
import { Constants } from 'librechat-data-provider';
|
||||||
|
|
@ -11,6 +9,14 @@ import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||||
import { processMCPEnv, isEnabled } from '~/utils';
|
import { processMCPEnv, isEnabled } from '~/utils';
|
||||||
|
|
||||||
|
const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000;
|
||||||
|
|
||||||
|
function getMCPInitTimeout(): number {
|
||||||
|
return process.env.MCP_INIT_TIMEOUT_MS != null
|
||||||
|
? parseInt(process.env.MCP_INIT_TIMEOUT_MS)
|
||||||
|
: DEFAULT_MCP_INIT_TIMEOUT_MS;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Manages MCP server configurations and metadata discovery.
|
* Manages MCP server configurations and metadata discovery.
|
||||||
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
||||||
|
|
@ -20,19 +26,21 @@ import { processMCPEnv, isEnabled } from '~/utils';
|
||||||
export class MCPServersRegistry {
|
export class MCPServersRegistry {
|
||||||
private initialized: boolean = false;
|
private initialized: boolean = false;
|
||||||
private connections: ConnectionsRepository;
|
private connections: ConnectionsRepository;
|
||||||
|
private initTimeoutMs: number;
|
||||||
|
|
||||||
public readonly rawConfigs: t.MCPServers;
|
public readonly rawConfigs: t.MCPServers;
|
||||||
public readonly parsedConfigs: Record<string, t.ParsedServerConfig>;
|
public readonly parsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||||
|
|
||||||
public oauthServers: Set<string> | null = null;
|
public oauthServers: Set<string> = new Set();
|
||||||
public serverInstructions: Record<string, string> | null = null;
|
public serverInstructions: Record<string, string> = {};
|
||||||
public toolFunctions: t.LCAvailableTools | null = null;
|
public toolFunctions: t.LCAvailableTools = {};
|
||||||
public appServerConfigs: t.MCPServers | null = null;
|
public appServerConfigs: t.MCPServers = {};
|
||||||
|
|
||||||
constructor(configs: t.MCPServers) {
|
constructor(configs: t.MCPServers) {
|
||||||
this.rawConfigs = configs;
|
this.rawConfigs = configs;
|
||||||
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con }));
|
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con }));
|
||||||
this.connections = new ConnectionsRepository(configs);
|
this.connections = new ConnectionsRepository(configs);
|
||||||
|
this.initTimeoutMs = getMCPInitTimeout();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
||||||
|
|
@ -42,21 +50,43 @@ export class MCPServersRegistry {
|
||||||
|
|
||||||
const serverNames = Object.keys(this.parsedConfigs);
|
const serverNames = Object.keys(this.parsedConfigs);
|
||||||
|
|
||||||
await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName)));
|
await Promise.allSettled(
|
||||||
|
serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)),
|
||||||
this.setOAuthServers();
|
);
|
||||||
this.setServerInstructions();
|
|
||||||
this.setAppServerConfigs();
|
|
||||||
await this.setAppToolFunctions();
|
|
||||||
|
|
||||||
this.connections.disconnectAll();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Fetches all metadata for a single server in parallel */
|
/** Wraps server initialization with a timeout to prevent hanging */
|
||||||
private async gatherServerInfo(serverName: string): Promise<void> {
|
private async initializeServerWithTimeout(serverName: string): Promise<void> {
|
||||||
|
let timeoutId: NodeJS.Timeout | null = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
await Promise.race([
|
||||||
|
this.initializeServer(serverName),
|
||||||
|
new Promise<never>((_, reject) => {
|
||||||
|
timeoutId = setTimeout(() => {
|
||||||
|
reject(new Error('Server initialization timed out'));
|
||||||
|
}, this.initTimeoutMs);
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error);
|
||||||
|
throw error;
|
||||||
|
} finally {
|
||||||
|
if (timeoutId != null) {
|
||||||
|
clearTimeout(timeoutId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Initializes a single server with all its metadata and adds it to appropriate collections */
|
||||||
|
private async initializeServer(serverName: string): Promise<void> {
|
||||||
|
logger.info(`${this.prefix(serverName)} Initializing server`);
|
||||||
|
const start = Date.now();
|
||||||
|
|
||||||
|
const config = this.parsedConfigs[serverName];
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await this.fetchOAuthRequirement(serverName);
|
await this.fetchOAuthRequirement(serverName);
|
||||||
const config = this.parsedConfigs[serverName];
|
|
||||||
|
|
||||||
if (config.startup !== false && !config.requiresOAuth) {
|
if (config.startup !== false && !config.requiresOAuth) {
|
||||||
await Promise.allSettled([
|
await Promise.allSettled([
|
||||||
|
|
@ -73,49 +103,39 @@ export class MCPServersRegistry {
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error);
|
logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/** Sets app-level server configs (startup enabled, non-OAuth servers) */
|
// Add to OAuth servers if needed
|
||||||
private setAppServerConfigs(): void {
|
if (config.requiresOAuth) {
|
||||||
const appServers = Object.keys(
|
this.oauthServers.add(serverName);
|
||||||
pickBy(
|
|
||||||
this.parsedConfigs,
|
|
||||||
(config) => config.startup !== false && config.requiresOAuth === false,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
this.appServerConfigs = pick(this.rawConfigs, appServers);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Creates set of server names that require OAuth authentication */
|
|
||||||
private setOAuthServers(): Set<string> {
|
|
||||||
if (this.oauthServers) return this.oauthServers;
|
|
||||||
this.oauthServers = new Set(
|
|
||||||
Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)),
|
|
||||||
);
|
|
||||||
return this.oauthServers;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Collects server instructions from all configured servers */
|
|
||||||
private setServerInstructions(): void {
|
|
||||||
this.serverInstructions = mapValues(
|
|
||||||
pickBy(this.parsedConfigs, (config) => config.serverInstructions),
|
|
||||||
(config) => config.serverInstructions as string,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Builds registry of all available tool functions from loaded connections */
|
|
||||||
private async setAppToolFunctions(): Promise<void> {
|
|
||||||
const connections = (await this.connections.getLoaded()).entries();
|
|
||||||
const allToolFunctions: t.LCAvailableTools = {};
|
|
||||||
for (const [serverName, conn] of connections) {
|
|
||||||
try {
|
|
||||||
const toolFunctions = await this.getToolFunctions(serverName, conn);
|
|
||||||
Object.assign(allToolFunctions, toolFunctions);
|
|
||||||
} catch (error) {
|
|
||||||
logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
this.toolFunctions = allToolFunctions;
|
|
||||||
|
// Add server instructions if available
|
||||||
|
if (config.serverInstructions != null) {
|
||||||
|
this.serverInstructions[serverName] = config.serverInstructions as string;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to app server configs if eligible (startup enabled, non-OAuth servers)
|
||||||
|
if (config.startup !== false && config.requiresOAuth === false) {
|
||||||
|
this.appServerConfigs[serverName] = this.rawConfigs[serverName];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch tool functions for this server if a connection was established
|
||||||
|
try {
|
||||||
|
const conn = await this.connections.get(serverName);
|
||||||
|
const toolFunctions = await this.getToolFunctions(serverName, conn);
|
||||||
|
Object.assign(this.toolFunctions, toolFunctions);
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect this server's connection after initialization
|
||||||
|
try {
|
||||||
|
await this.connections.disconnect(serverName);
|
||||||
|
} catch (disconnectError) {
|
||||||
|
logger.debug(`${this.prefix(serverName)} Failed to disconnect:`, disconnectError);
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`${this.prefix(serverName)} Initialized server in ${Date.now() - start}ms`);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Converts server tools to LibreChat-compatible tool functions format */
|
/** Converts server tools to LibreChat-compatible tool functions format */
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,7 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
get: jest.fn(),
|
get: jest.fn(),
|
||||||
getLoaded: jest.fn(),
|
getLoaded: jest.fn(),
|
||||||
disconnectAll: jest.fn(),
|
disconnectAll: jest.fn(),
|
||||||
|
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||||
} as unknown as jest.Mocked<ConnectionsRepository>;
|
} as unknown as jest.Mocked<ConnectionsRepository>;
|
||||||
|
|
||||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||||
|
|
@ -160,6 +161,7 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
delete process.env.MCP_INIT_TIMEOUT_MS;
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -179,15 +181,14 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
const registry = new MCPServersRegistry(rawConfigs);
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
// Verify initial state
|
// Verify initial state
|
||||||
expect(registry.oauthServers).toBeNull();
|
expect(registry.oauthServers.size).toBe(0);
|
||||||
expect(registry.serverInstructions).toBeNull();
|
expect(registry.serverInstructions).toEqual({});
|
||||||
expect(registry.toolFunctions).toBeNull();
|
expect(registry.toolFunctions).toEqual({});
|
||||||
expect(registry.appServerConfigs).toBeNull();
|
expect(registry.appServerConfigs).toEqual({});
|
||||||
|
|
||||||
await registry.initialize();
|
await registry.initialize();
|
||||||
|
|
||||||
// Test oauthServers Set
|
// Test oauthServers Set
|
||||||
expect(registry.oauthServers).toBeInstanceOf(Set);
|
|
||||||
expect(registry.oauthServers).toEqual(
|
expect(registry.oauthServers).toEqual(
|
||||||
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
|
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
|
||||||
);
|
);
|
||||||
|
|
@ -228,18 +229,49 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
|
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle errors gracefully and continue initialization', async () => {
|
it('should handle errors gracefully and continue initialization of other servers', async () => {
|
||||||
const registry = new MCPServersRegistry(rawConfigs);
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
// Make one server throw an error
|
// Make one specific server throw an error during OAuth detection
|
||||||
mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed'));
|
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||||
|
if (url === 'https://api.github.com/mcp') {
|
||||||
|
return Promise.reject(new Error('OAuth detection failed'));
|
||||||
|
}
|
||||||
|
// Return normal responses for other servers
|
||||||
|
const oauthResults: Record<string, OAuthDetectionResult> = {
|
||||||
|
'https://api.disabled.com/mcp': {
|
||||||
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
|
metadata: null,
|
||||||
|
},
|
||||||
|
'https://api.public.com/mcp': {
|
||||||
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
|
metadata: null,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
return Promise.resolve(
|
||||||
|
oauthResults[url] ?? {
|
||||||
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
|
metadata: null,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
await registry.initialize();
|
await registry.initialize();
|
||||||
|
|
||||||
// Should still initialize successfully
|
// Should still initialize successfully for other servers
|
||||||
expect(registry.oauthServers).toBeInstanceOf(Set);
|
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||||
expect(registry.toolFunctions).toBeDefined();
|
expect(registry.toolFunctions).toBeDefined();
|
||||||
|
|
||||||
|
// The failed server should not be in oauthServers (since it failed OAuth detection)
|
||||||
|
expect(registry.oauthServers.has('oauth_server')).toBe(false);
|
||||||
|
|
||||||
|
// But other servers should still be processed successfully
|
||||||
|
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
|
||||||
|
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||||
|
|
||||||
// Error should be logged as a warning at the higher level
|
// Error should be logged as a warning at the higher level
|
||||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||||
expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'),
|
expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'),
|
||||||
|
|
@ -247,12 +279,15 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should disconnect all connections after initialization', async () => {
|
it('should disconnect individual connections after each server initialization', async () => {
|
||||||
const registry = new MCPServersRegistry(rawConfigs);
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
await registry.initialize();
|
await registry.initialize();
|
||||||
|
|
||||||
expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1);
|
// Verify disconnect was called for each server during initialization
|
||||||
|
// All servers attempt to connect during initialization for metadata gathering
|
||||||
|
const serverNames = Object.keys(rawConfigs);
|
||||||
|
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should log configuration updates for each startup-enabled server', async () => {
|
it('should log configuration updates for each startup-enabled server', async () => {
|
||||||
|
|
@ -357,5 +392,125 @@ describe('MCPServersRegistry - Initialize Function', () => {
|
||||||
// Verify getInstructions was called for both "true" cases
|
// Verify getInstructions was called for both "true" cases
|
||||||
expect(mockClient.getInstructions).toHaveBeenCalledTimes(2);
|
expect(mockClient.getInstructions).toHaveBeenCalledTimes(2);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should use Promise.allSettled for individual server initialization', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
// Spy on Promise.allSettled to verify it's being used
|
||||||
|
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
// Verify Promise.allSettled was called with an array of server initialization promises
|
||||||
|
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
|
||||||
|
|
||||||
|
// Verify it was called with the correct number of server promises
|
||||||
|
const serverNames = Object.keys(rawConfigs);
|
||||||
|
expect(allSettledSpy).toHaveBeenCalledWith(
|
||||||
|
expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))),
|
||||||
|
);
|
||||||
|
|
||||||
|
allSettledSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should isolate server failures and not affect other servers', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
// Make multiple servers fail in different ways
|
||||||
|
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||||
|
if (serverName === 'stdio_server') {
|
||||||
|
// First server fails
|
||||||
|
throw new Error('Connection failed for stdio_server');
|
||||||
|
}
|
||||||
|
if (serverName === 'websocket_server') {
|
||||||
|
// Second server fails
|
||||||
|
throw new Error('Connection failed for websocket_server');
|
||||||
|
}
|
||||||
|
// Other servers succeed
|
||||||
|
const connection = mockConnections.get(serverName);
|
||||||
|
if (!connection) {
|
||||||
|
throw new Error(`Connection not found for server: ${serverName}`);
|
||||||
|
}
|
||||||
|
return Promise.resolve(connection);
|
||||||
|
});
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
// Despite failures, initialization should complete
|
||||||
|
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||||
|
expect(registry.toolFunctions).toBeDefined();
|
||||||
|
|
||||||
|
// Successful servers should still be processed
|
||||||
|
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||||
|
|
||||||
|
// Failed servers should not crash the whole initialization
|
||||||
|
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('[MCP][stdio_server] Error fetching tool functions:'),
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('[MCP][websocket_server] Error fetching tool functions:'),
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should properly clean up connections even when some servers fail', async () => {
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
// Make some connections fail during disconnect
|
||||||
|
mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => {
|
||||||
|
if (serverName === 'stdio_server') {
|
||||||
|
return Promise.reject(new Error('Disconnect failed'));
|
||||||
|
}
|
||||||
|
return Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
await registry.initialize();
|
||||||
|
|
||||||
|
// Should still attempt to disconnect all servers during initialization
|
||||||
|
const serverNames = Object.keys(rawConfigs);
|
||||||
|
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
|
||||||
|
|
||||||
|
// Failed disconnects should be logged but not crash initialization
|
||||||
|
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('[MCP][stdio_server] Failed to disconnect:'),
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should timeout individual server initialization after configured timeout', async () => {
|
||||||
|
const timeout = 2000;
|
||||||
|
// Create registry with a short timeout for testing
|
||||||
|
process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`;
|
||||||
|
|
||||||
|
const registry = new MCPServersRegistry(rawConfigs);
|
||||||
|
|
||||||
|
// Make one server hang indefinitely during OAuth detection
|
||||||
|
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||||
|
if (url === 'https://api.github.com/mcp') {
|
||||||
|
// Slow init
|
||||||
|
return new Promise((res) => setTimeout(res, timeout * 2));
|
||||||
|
}
|
||||||
|
// Return normal responses for other servers
|
||||||
|
return Promise.resolve({
|
||||||
|
requiresOAuth: false,
|
||||||
|
method: 'no-metadata-found',
|
||||||
|
metadata: null,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const start = Date.now();
|
||||||
|
await registry.initialize();
|
||||||
|
const duration = Date.now() - start;
|
||||||
|
|
||||||
|
// Should complete within reasonable time despite one server hanging
|
||||||
|
// Allow some buffer for test execution overhead
|
||||||
|
expect(duration).toBeLessThan(timeout * 1.5);
|
||||||
|
|
||||||
|
// The timeout should prevent the hanging server from blocking initialization
|
||||||
|
// Other servers should still be processed successfully
|
||||||
|
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
|
||||||
|
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||||
|
}, 10_000); // 10 second Jest timeout
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ export class OAuthReconnectionManager {
|
||||||
|
|
||||||
// 1. derive the servers to reconnect
|
// 1. derive the servers to reconnect
|
||||||
const serversToReconnect = [];
|
const serversToReconnect = [];
|
||||||
for (const serverName of this.mcpManager.getOAuthServers() ?? []) {
|
for (const serverName of this.mcpManager.getOAuthServers()) {
|
||||||
const canReconnect = await this.canReconnect(userId, serverName);
|
const canReconnect = await this.canReconnect(userId, serverName);
|
||||||
if (canReconnect) {
|
if (canReconnect) {
|
||||||
serversToReconnect.push(serverName);
|
serversToReconnect.push(serverName);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue