🚦 feat: Auto-reinitialize MCP Servers on Request (#9226)

This commit is contained in:
Danny Avila 2025-08-23 03:27:05 -04:00
parent ac608ded46
commit c827fdd10e
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
28 changed files with 871 additions and 312 deletions

View file

@ -1,5 +1,6 @@
/* MCP */
export * from './mcp/MCPManager';
export * from './mcp/connection';
export * from './mcp/oauth';
export * from './mcp/auth';
export * from './mcp/zod';

View file

@ -28,6 +28,7 @@ export class MCPConnectionFactory {
protected readonly oauthStart?: (authURL: string) => Promise<void>;
protected readonly oauthEnd?: () => Promise<void>;
protected readonly returnOnOAuth?: boolean;
protected readonly connectionTimeout?: number;
/** Creates a new MCP connection with optional OAuth support */
static async create(
@ -47,6 +48,7 @@ export class MCPConnectionFactory {
});
this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth;
this.connectionTimeout = oauth?.connectionTimeout;
this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]`
: `[MCP][${basic.serverName}]`;
@ -82,8 +84,9 @@ export class MCPConnectionFactory {
if (!this.tokenMethods?.findToken) return null;
try {
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
const tokens = await this.flowManager!.createFlowWithHandler(
`tokens:${this.userId}:${this.serverName}`,
flowId,
'mcp_get_tokens',
async () => {
return await MCPTokenStorage.getTokens({
@ -203,7 +206,7 @@ export class MCPConnectionFactory {
/** Attempts to establish connection with timeout handling */
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
const connectTimeout = this.serverConfig.initTimeout ?? 30000;
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout(
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
@ -347,6 +350,7 @@ export class MCPConnectionFactory {
newFlowId,
'mcp_oauth',
flowMetadata as FlowMetadata,
this.signal,
);
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();

View file

@ -1,13 +1,8 @@
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { MCPConnection } from './connection';
import type { RequestBody } from '~/types';
import type * as t from './types';
/**
@ -44,8 +39,9 @@ export abstract class UserConnectionManager {
/** Gets or creates a connection for a specific user */
public async getUserConnection({
user,
serverName,
forceNew,
user,
flowManager,
customUserVars,
requestBody,
@ -54,25 +50,18 @@ export abstract class UserConnectionManager {
oauthEnd,
signal,
returnOnOAuth = false,
connectionTimeout,
}: {
user: TUser;
serverName: string;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
customUserVars?: Record<string, string>;
requestBody?: RequestBody;
tokenMethods?: TokenMethods;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
signal?: AbortSignal;
returnOnOAuth?: boolean;
}): Promise<MCPConnection> {
forceNew?: boolean;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
const userId = user.id;
if (!userId) {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
}
const userServerMap = this.userConnections.get(userId);
let connection = userServerMap?.get(serverName);
let connection = forceNew ? undefined : userServerMap?.get(serverName);
const now = Date.now();
// Check if user is idle
@ -131,6 +120,7 @@ export abstract class UserConnectionManager {
oauthEnd: oauthEnd,
returnOnOAuth: returnOnOAuth,
requestBody: requestBody,
connectionTimeout: connectionTimeout,
},
);

View file

@ -45,7 +45,7 @@ describe('getUserMCPAuthMap', () => {
},
];
const tools = testCases.map((testCase) =>
const toolInstances = testCases.map((testCase) =>
createMockTool(testCase.normalizedToolName, testCase.originalName),
);
@ -54,7 +54,7 @@ describe('getUserMCPAuthMap', () => {
await getUserMCPAuthMap({
userId: 'user123',
tools,
toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
@ -69,7 +69,7 @@ describe('getUserMCPAuthMap', () => {
describe('Edge Cases', () => {
it('should return empty object when no tools have mcpRawServerName', async () => {
const tools = [
const toolInstances = [
createMockTool('regular_tool', undefined, false),
createMockTool('another_tool', undefined, false),
createMockTool('test_mcp_Server_no_raw_name', undefined),
@ -77,7 +77,7 @@ describe('getUserMCPAuthMap', () => {
const result = await getUserMCPAuthMap({
userId: 'user123',
tools,
toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
@ -104,14 +104,14 @@ describe('getUserMCPAuthMap', () => {
});
it('should handle database errors gracefully', async () => {
const tools = [createMockTool('test_mcp_Server1', 'Server1')];
const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
const dbError = new Error('Database connection failed');
mockGetPluginAuthMap.mockRejectedValue(dbError);
const result = await getUserMCPAuthMap({
userId: 'user123',
tools,
toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
@ -119,18 +119,119 @@ describe('getUserMCPAuthMap', () => {
});
it('should handle non-Error exceptions gracefully', async () => {
const tools = [createMockTool('test_mcp_Server1', 'Server1')];
const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
mockGetPluginAuthMap.mockRejectedValue('String error');
const result = await getUserMCPAuthMap({
userId: 'user123',
tools,
toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({});
});
it('should handle mixed null/undefined values in tools array', async () => {
const tools = [
'test_mcp_Server1',
null,
'test_mcp_Server2',
undefined,
'regular_tool',
'test_mcp_Server3',
];
mockGetPluginAuthMap.mockResolvedValue({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
const result = await getUserMCPAuthMap({
userId: 'user123',
tools: tools as (string | undefined)[],
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
userId: 'user123',
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
throwError: false,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
});
it('should handle mixed null/undefined values in servers array', async () => {
const servers = ['Server1', null, 'Server2', undefined, 'Server3'];
mockGetPluginAuthMap.mockResolvedValue({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
const result = await getUserMCPAuthMap({
userId: 'user123',
servers: servers as (string | undefined)[],
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
userId: 'user123',
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
throwError: false,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
});
it('should handle mixed null/undefined values in toolInstances array', async () => {
const toolInstances = [
createMockTool('test_mcp_Server1', 'Server1'),
null,
createMockTool('test_mcp_Server2', 'Server2'),
undefined,
createMockTool('regular_tool', undefined, false),
createMockTool('test_mcp_Server3', 'Server3'),
];
mockGetPluginAuthMap.mockResolvedValue({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
const result = await getUserMCPAuthMap({
userId: 'user123',
toolInstances: toolInstances as (GenericTool | null)[],
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
userId: 'user123',
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
throwError: false,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});
expect(result).toEqual({
mcp_Server1: { API_KEY: 'key1' },
mcp_Server2: { API_KEY: 'key2' },
mcp_Server3: { API_KEY: 'key3' },
});
});
});
describe('Integration', () => {
@ -138,7 +239,7 @@ describe('getUserMCPAuthMap', () => {
const originalServerName = 'Connector: Company';
const toolName = 'test_auth_mcp_Connector__Company';
const tools = [createMockTool(toolName, originalServerName)];
const toolInstances = [createMockTool(toolName, originalServerName)];
const mockCustomUserVars = {
'mcp_Connector: Company': {
@ -151,7 +252,7 @@ describe('getUserMCPAuthMap', () => {
const result = await getUserMCPAuthMap({
userId: 'user123',
tools,
toolInstances,
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
});

View file

@ -7,33 +7,56 @@ import { getPluginAuthMap } from '~/agents/auth';
export async function getUserMCPAuthMap({
userId,
tools,
servers,
toolInstances,
findPluginAuthsByKeys,
}: {
userId: string;
tools: GenericTool[] | undefined;
tools?: (string | undefined)[];
servers?: (string | undefined)[];
toolInstances?: (GenericTool | null)[];
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
}) {
if (!tools || tools.length === 0) {
return {};
}
const uniqueMcpServers = new Set<string>();
for (const tool of tools) {
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
if (mcpTool.mcpRawServerName) {
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
}
}
if (uniqueMcpServers.size === 0) {
return {};
}
const mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
let allMcpCustomUserVars: Record<string, Record<string, string>> = {};
let mcpPluginKeysToFetch: string[] = [];
try {
const uniqueMcpServers = new Set<string>();
if (servers != null && servers.length) {
for (const serverName of servers) {
if (!serverName) {
continue;
}
uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`);
}
} else if (tools != null && tools.length) {
for (const toolName of tools) {
if (!toolName) {
continue;
}
const delimiterIndex = toolName.indexOf(Constants.mcp_delimiter);
if (delimiterIndex === -1) continue;
const mcpServer = toolName.slice(delimiterIndex + Constants.mcp_delimiter.length);
if (!mcpServer) continue;
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpServer}`);
}
} else if (toolInstances != null && toolInstances.length) {
for (const tool of toolInstances) {
if (!tool) {
continue;
}
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
if (mcpTool.mcpRawServerName) {
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
}
}
}
if (uniqueMcpServers.size === 0) {
return {};
}
mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
allMcpCustomUserVars = await getPluginAuthMap({
userId,
pluginKeys: mcpPluginKeysToFetch,

View file

@ -446,7 +446,7 @@ export class MCPConnection extends EventEmitter {
const serverUrl = this.url;
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
const oauthTimeout = this.options.initTimeout ?? 60000;
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
/** Promise that will resolve when OAuth is handled */
const oauthHandledPromise = new Promise<void>((resolve, reject) => {
let timeoutId: NodeJS.Timeout | null = null;

View file

@ -134,4 +134,5 @@ export interface OAuthConnectionOptions {
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
returnOnOAuth?: boolean;
connectionTimeout?: number;
}

View file

@ -1,5 +1,6 @@
import { AuthType, Constants, EToolResources } from 'librechat-data-provider';
import type { TCustomConfig, TPlugin, FunctionTool } from 'librechat-data-provider';
import type { TCustomConfig, TPlugin } from 'librechat-data-provider';
import { LCAvailableTools, LCFunctionTool } from '~/mcp/types';
/**
* Filters out duplicate plugins from the list of plugins.
@ -60,7 +61,7 @@ export function convertMCPToolToPlugin({
customConfig,
}: {
toolKey: string;
toolData: FunctionTool;
toolData: LCFunctionTool;
customConfig?: Partial<TCustomConfig> | null;
}): TPlugin | undefined {
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
@ -112,7 +113,7 @@ export function convertMCPToolsToPlugins({
functionTools,
customConfig,
}: {
functionTools?: Record<string, FunctionTool>;
functionTools?: LCAvailableTools;
customConfig?: Partial<TCustomConfig> | null;
}): TPlugin[] | undefined {
if (!functionTools || typeof functionTools !== 'object') {

View file

@ -1525,6 +1525,8 @@ export enum Constants {
CONFIG_VERSION = '1.2.8',
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
NO_PARENT = '00000000-0000-0000-0000-000000000000',
/** Standard value to use whatever the submission prelim. `responseMessageId` is */
USE_PRELIM_RESPONSE_MESSAGE_ID = 'USE_PRELIM_RESPONSE_MESSAGE_ID',
/** Standard value for the initial conversationId before a request is sent */
NEW_CONVO = 'new',
/** Standard value for the temporary conversationId after a request is sent and before the server responds */
@ -1551,6 +1553,8 @@ export enum Constants {
mcp_delimiter = '_mcp_',
/** Prefix for MCP plugins */
mcp_prefix = 'mcp_',
/** Unique value to indicate all MCP servers */
mcp_all = 'sys__all__sys',
/** Placeholder Agent ID for Ephemeral Agents */
EPHEMERAL_AGENT_ID = 'ephemeral',
}