🔌 fix: Shared MCP Server Connection Management (#9822)

- Fixed a bug in reinitMCPServer where a user connection was created for an app-level server whenever this server is reinitialized
- Made MCPManager.getUserConnection to return an error if the connection is app-level
- Add MCPManager.getConnection to return either an app connection or a user connection based on the serverName
- Made MCPManager.appConnections public to avoid unnecessary wrapper methods.
This commit is contained in:
Theo N. Truong 2025-09-26 06:24:36 -06:00 committed by GitHub
parent 4f3683fd9a
commit 3219734b9e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 56 additions and 42 deletions

View file

@ -442,10 +442,10 @@ async function getMCPSetupData(userId) {
} }
const mcpManager = getMCPManager(userId); const mcpManager = getMCPManager(userId);
/** @type {ReturnType<MCPManager['getAllConnections']>} */ /** @type {Map<string, import('@librechat/api').MCPConnection>} */
let appConnections = new Map(); let appConnections = new Map();
try { try {
appConnections = (await mcpManager.getAllConnections()) || new Map(); appConnections = (await mcpManager.appConnections?.getAll()) || new Map();
} catch (error) { } catch (error) {
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
} }

View file

@ -123,7 +123,7 @@ describe('tests for the new helper functions used by the MCP connection status e
beforeEach(() => { beforeEach(() => {
mockGetAppConfig = require('./Config').getAppConfig; mockGetAppConfig = require('./Config').getAppConfig;
mockGetMCPManager.mockReturnValue({ mockGetMCPManager.mockReturnValue({
getAllConnections: jest.fn(() => new Map()), appConnections: { getAll: jest.fn(() => new Map()) },
getUserConnections: jest.fn(() => new Map()), getUserConnections: jest.fn(() => new Map()),
getOAuthServers: jest.fn(() => new Set()), getOAuthServers: jest.fn(() => new Set()),
}); });
@ -137,7 +137,7 @@ describe('tests for the new helper functions used by the MCP connection status e
const mockOAuthServers = new Set(['server2']); const mockOAuthServers = new Set(['server2']);
const mockMCPManager = { const mockMCPManager = {
getAllConnections: jest.fn(() => mockAppConnections), appConnections: { getAll: jest.fn(() => mockAppConnections) },
getUserConnections: jest.fn(() => mockUserConnections), getUserConnections: jest.fn(() => mockUserConnections),
getOAuthServers: jest.fn(() => mockOAuthServers), getOAuthServers: jest.fn(() => mockOAuthServers),
}; };
@ -147,7 +147,7 @@ describe('tests for the new helper functions used by the MCP connection status e
expect(mockGetAppConfig).toHaveBeenCalled(); expect(mockGetAppConfig).toHaveBeenCalled();
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.getAllConnections).toHaveBeenCalled(); expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled(); expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
@ -168,7 +168,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetAppConfig.mockResolvedValue({ mcpConfig: mockConfig.mcpServers }); mockGetAppConfig.mockResolvedValue({ mcpConfig: mockConfig.mcpServers });
const mockMCPManager = { const mockMCPManager = {
getAllConnections: jest.fn(() => null), appConnections: { getAll: jest.fn(() => null) },
getUserConnections: jest.fn(() => null), getUserConnections: jest.fn(() => null),
getOAuthServers: jest.fn(() => null), getOAuthServers: jest.fn(() => null),
}; };

View file

@ -29,7 +29,7 @@ async function reinitMCPServer({
flowManager: _flowManager, flowManager: _flowManager,
}) { }) {
/** @type {MCPConnection | null} */ /** @type {MCPConnection | null} */
let userConnection = null; let connection = null;
/** @type {LCAvailableTools | null} */ /** @type {LCAvailableTools | null} */
let availableTools = null; let availableTools = null;
/** @type {ReturnType<MCPConnection['fetchTools']> | null} */ /** @type {ReturnType<MCPConnection['fetchTools']> | null} */
@ -50,7 +50,7 @@ async function reinitMCPServer({
}); });
try { try {
userConnection = await mcpManager.getUserConnection({ connection = await mcpManager.getConnection({
user, user,
signal, signal,
forceNew, forceNew,
@ -70,7 +70,7 @@ async function reinitMCPServer({
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) { } catch (err) {
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`); logger.info(`[MCP Reinitialize] getConnection threw error: ${err.message}`);
logger.info( logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, `[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
); );
@ -95,8 +95,8 @@ async function reinitMCPServer({
} }
} }
if (userConnection && !oauthRequired) { if (connection && !oauthRequired) {
tools = await userConnection.fetchTools(); tools = await connection.fetchTools();
availableTools = await updateMCPServerTools({ availableTools = await updateMCPServerTools({
serverName, serverName,
tools, tools,
@ -111,7 +111,7 @@ async function reinitMCPServer({
if (oauthRequired) { if (oauthRequired) {
return `MCP server '${serverName}' ready for OAuth authentication`; return `MCP server '${serverName}' ready for OAuth authentication`;
} }
if (userConnection) { if (connection) {
return `MCP server '${serverName}' reinitialized successfully`; return `MCP server '${serverName}' reinitialized successfully`;
} }
return `Failed to reinitialize MCP server '${serverName}'`; return `Failed to reinitialize MCP server '${serverName}'`;
@ -119,7 +119,7 @@ async function reinitMCPServer({
const result = { const result = {
availableTools, availableTools,
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)), success: Boolean((connection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(), message: getResponseMessage(),
oauthRequired, oauthRequired,
serverName, serverName,

View file

@ -20,8 +20,6 @@ import { processMCPEnv } from '~/utils/env';
*/ */
export class MCPManager extends UserConnectionManager { export class MCPManager extends UserConnectionManager {
private static instance: MCPManager | null; private static instance: MCPManager | null;
// Connections shared by all users.
private appConnections: ConnectionsRepository | null = null;
/** Creates and initializes the singleton MCPManager instance */ /** Creates and initializes the singleton MCPManager instance */
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> { public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
@ -43,9 +41,25 @@ export class MCPManager extends UserConnectionManager {
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!); this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!);
} }
/** Returns all app-level connections */ /** Retrieves an app-level or user-specific connection based on provided arguments */
public async getAllConnections(): Promise<Map<string, MCPConnection> | null> { public async getConnection(
return this.appConnections!.getAll(); args: {
serverName: string;
user?: TUser;
forceNew?: boolean;
flowManager?: FlowStateManager<MCPOAuthTokens | null>;
} & Omit<t.OAuthConnectionOptions, 'useOAuth' | 'user' | 'flowManager'>,
): Promise<MCPConnection> {
if (this.appConnections!.has(args.serverName)) {
return this.appConnections!.get(args.serverName);
} else if (args.user?.id) {
return this.getUserConnection(args as Parameters<typeof this.getUserConnection>[0]);
} else {
throw new McpError(
ErrorCode.InvalidRequest,
`No connection found for server ${args.serverName}`,
);
}
} }
/** Get servers that require OAuth */ /** Get servers that require OAuth */
@ -180,12 +194,11 @@ Please follow these instructions when using tools from the respective MCP server
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`; const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
try { try {
if (!this.appConnections?.has(serverName) && userId && user) { if (userId && user) this.updateUserLastActivity(userId);
this.updateUserLastActivity(userId);
/** Get or create user-specific connection */ connection = await this.getConnection({
connection = await this.getUserConnection({
user,
serverName, serverName,
user,
flowManager, flowManager,
tokenMethods, tokenMethods,
oauthStart, oauthStart,
@ -194,16 +207,6 @@ Please follow these instructions when using tools from the respective MCP server
customUserVars, customUserVars,
requestBody, requestBody,
}); });
} else {
/** App-level connection */
connection = await this.appConnections!.get(serverName);
if (!connection) {
throw new McpError(
ErrorCode.InvalidRequest,
`${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`,
);
}
}
if (!(await connection.isConnected())) { if (!(await connection.isConnected())) {
/** May happen if getUserConnection failed silently or app connection dropped */ /** May happen if getUserConnection failed silently or app connection dropped */

View file

@ -4,6 +4,7 @@ import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { MCPConnection } from './connection'; import { MCPConnection } from './connection';
import type * as t from './types'; import type * as t from './types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
/** /**
* Abstract base class for managing user-specific MCP connections with lifecycle management. * Abstract base class for managing user-specific MCP connections with lifecycle management.
@ -14,6 +15,9 @@ import type * as t from './types';
*/ */
export abstract class UserConnectionManager { export abstract class UserConnectionManager {
protected readonly serversRegistry: MCPServersRegistry; protected readonly serversRegistry: MCPServersRegistry;
// Connections shared by all users.
public appConnections: ConnectionsRepository | null = null;
// Connections per userId -> serverName -> connection
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map(); protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
/** Last activity timestamp for users (not per server) */ /** Last activity timestamp for users (not per server) */
protected userLastActivity: Map<string, number> = new Map(); protected userLastActivity: Map<string, number> = new Map();
@ -60,6 +64,13 @@ export abstract class UserConnectionManager {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
} }
if (this.appConnections!.has(serverName)) {
throw new McpError(
ErrorCode.InvalidRequest,
`[MCP][User: ${userId}] Trying to create user-specific connection for app-level server "${serverName}"`,
);
}
const userServerMap = this.userConnections.get(userId); const userServerMap = this.userConnections.get(userId);
let connection = forceNew ? undefined : userServerMap?.get(serverName); let connection = forceNew ? undefined : userServerMap?.get(serverName);
const now = Date.now(); const now = Date.now();