mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
🚦 feat: Auto-reinitialize MCP Servers on Request (#9226)
This commit is contained in:
parent
ac608ded46
commit
c827fdd10e
28 changed files with 871 additions and 312 deletions
|
@ -1,5 +1,6 @@
|
|||
/* MCP */
|
||||
export * from './mcp/MCPManager';
|
||||
export * from './mcp/connection';
|
||||
export * from './mcp/oauth';
|
||||
export * from './mcp/auth';
|
||||
export * from './mcp/zod';
|
||||
|
|
|
@ -28,6 +28,7 @@ export class MCPConnectionFactory {
|
|||
protected readonly oauthStart?: (authURL: string) => Promise<void>;
|
||||
protected readonly oauthEnd?: () => Promise<void>;
|
||||
protected readonly returnOnOAuth?: boolean;
|
||||
protected readonly connectionTimeout?: number;
|
||||
|
||||
/** Creates a new MCP connection with optional OAuth support */
|
||||
static async create(
|
||||
|
@ -47,6 +48,7 @@ export class MCPConnectionFactory {
|
|||
});
|
||||
this.serverName = basic.serverName;
|
||||
this.useOAuth = !!oauth?.useOAuth;
|
||||
this.connectionTimeout = oauth?.connectionTimeout;
|
||||
this.logPrefix = oauth?.user
|
||||
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
||||
: `[MCP][${basic.serverName}]`;
|
||||
|
@ -82,8 +84,9 @@ export class MCPConnectionFactory {
|
|||
if (!this.tokenMethods?.findToken) return null;
|
||||
|
||||
try {
|
||||
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
|
||||
const tokens = await this.flowManager!.createFlowWithHandler(
|
||||
`tokens:${this.userId}:${this.serverName}`,
|
||||
flowId,
|
||||
'mcp_get_tokens',
|
||||
async () => {
|
||||
return await MCPTokenStorage.getTokens({
|
||||
|
@ -203,7 +206,7 @@ export class MCPConnectionFactory {
|
|||
|
||||
/** Attempts to establish connection with timeout handling */
|
||||
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
|
||||
const connectTimeout = this.serverConfig.initTimeout ?? 30000;
|
||||
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
|
||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||
|
@ -347,6 +350,7 @@ export class MCPConnectionFactory {
|
|||
newFlowId,
|
||||
'mcp_oauth',
|
||||
flowMetadata as FlowMetadata,
|
||||
this.signal,
|
||||
);
|
||||
if (typeof this.oauthEnd === 'function') {
|
||||
await this.oauthEnd();
|
||||
|
|
|
@ -1,13 +1,8 @@
|
|||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { MCPConnection } from './connection';
|
||||
import type { RequestBody } from '~/types';
|
||||
import type * as t from './types';
|
||||
|
||||
/**
|
||||
|
@ -44,8 +39,9 @@ export abstract class UserConnectionManager {
|
|||
|
||||
/** Gets or creates a connection for a specific user */
|
||||
public async getUserConnection({
|
||||
user,
|
||||
serverName,
|
||||
forceNew,
|
||||
user,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
requestBody,
|
||||
|
@ -54,25 +50,18 @@ export abstract class UserConnectionManager {
|
|||
oauthEnd,
|
||||
signal,
|
||||
returnOnOAuth = false,
|
||||
connectionTimeout,
|
||||
}: {
|
||||
user: TUser;
|
||||
serverName: string;
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||
customUserVars?: Record<string, string>;
|
||||
requestBody?: RequestBody;
|
||||
tokenMethods?: TokenMethods;
|
||||
oauthStart?: (authURL: string) => Promise<void>;
|
||||
oauthEnd?: () => Promise<void>;
|
||||
signal?: AbortSignal;
|
||||
returnOnOAuth?: boolean;
|
||||
}): Promise<MCPConnection> {
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
|
||||
const userId = user.id;
|
||||
if (!userId) {
|
||||
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
|
||||
}
|
||||
|
||||
const userServerMap = this.userConnections.get(userId);
|
||||
let connection = userServerMap?.get(serverName);
|
||||
let connection = forceNew ? undefined : userServerMap?.get(serverName);
|
||||
const now = Date.now();
|
||||
|
||||
// Check if user is idle
|
||||
|
@ -131,6 +120,7 @@ export abstract class UserConnectionManager {
|
|||
oauthEnd: oauthEnd,
|
||||
returnOnOAuth: returnOnOAuth,
|
||||
requestBody: requestBody,
|
||||
connectionTimeout: connectionTimeout,
|
||||
},
|
||||
);
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ describe('getUserMCPAuthMap', () => {
|
|||
},
|
||||
];
|
||||
|
||||
const tools = testCases.map((testCase) =>
|
||||
const toolInstances = testCases.map((testCase) =>
|
||||
createMockTool(testCase.normalizedToolName, testCase.originalName),
|
||||
);
|
||||
|
||||
|
@ -54,7 +54,7 @@ describe('getUserMCPAuthMap', () => {
|
|||
|
||||
await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
tools,
|
||||
toolInstances,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
|
@ -69,7 +69,7 @@ describe('getUserMCPAuthMap', () => {
|
|||
|
||||
describe('Edge Cases', () => {
|
||||
it('should return empty object when no tools have mcpRawServerName', async () => {
|
||||
const tools = [
|
||||
const toolInstances = [
|
||||
createMockTool('regular_tool', undefined, false),
|
||||
createMockTool('another_tool', undefined, false),
|
||||
createMockTool('test_mcp_Server_no_raw_name', undefined),
|
||||
|
@ -77,7 +77,7 @@ describe('getUserMCPAuthMap', () => {
|
|||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
tools,
|
||||
toolInstances,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
|
@ -104,14 +104,14 @@ describe('getUserMCPAuthMap', () => {
|
|||
});
|
||||
|
||||
it('should handle database errors gracefully', async () => {
|
||||
const tools = [createMockTool('test_mcp_Server1', 'Server1')];
|
||||
const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
|
||||
const dbError = new Error('Database connection failed');
|
||||
|
||||
mockGetPluginAuthMap.mockRejectedValue(dbError);
|
||||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
tools,
|
||||
toolInstances,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
|
@ -119,18 +119,119 @@ describe('getUserMCPAuthMap', () => {
|
|||
});
|
||||
|
||||
it('should handle non-Error exceptions gracefully', async () => {
|
||||
const tools = [createMockTool('test_mcp_Server1', 'Server1')];
|
||||
const toolInstances = [createMockTool('test_mcp_Server1', 'Server1')];
|
||||
|
||||
mockGetPluginAuthMap.mockRejectedValue('String error');
|
||||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
tools,
|
||||
toolInstances,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should handle mixed null/undefined values in tools array', async () => {
|
||||
const tools = [
|
||||
'test_mcp_Server1',
|
||||
null,
|
||||
'test_mcp_Server2',
|
||||
undefined,
|
||||
'regular_tool',
|
||||
'test_mcp_Server3',
|
||||
];
|
||||
|
||||
mockGetPluginAuthMap.mockResolvedValue({
|
||||
mcp_Server1: { API_KEY: 'key1' },
|
||||
mcp_Server2: { API_KEY: 'key2' },
|
||||
mcp_Server3: { API_KEY: 'key3' },
|
||||
});
|
||||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
tools: tools as (string | undefined)[],
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
|
||||
throwError: false,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
mcp_Server1: { API_KEY: 'key1' },
|
||||
mcp_Server2: { API_KEY: 'key2' },
|
||||
mcp_Server3: { API_KEY: 'key3' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle mixed null/undefined values in servers array', async () => {
|
||||
const servers = ['Server1', null, 'Server2', undefined, 'Server3'];
|
||||
|
||||
mockGetPluginAuthMap.mockResolvedValue({
|
||||
mcp_Server1: { API_KEY: 'key1' },
|
||||
mcp_Server2: { API_KEY: 'key2' },
|
||||
mcp_Server3: { API_KEY: 'key3' },
|
||||
});
|
||||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
servers: servers as (string | undefined)[],
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
|
||||
throwError: false,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
mcp_Server1: { API_KEY: 'key1' },
|
||||
mcp_Server2: { API_KEY: 'key2' },
|
||||
mcp_Server3: { API_KEY: 'key3' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle mixed null/undefined values in toolInstances array', async () => {
|
||||
const toolInstances = [
|
||||
createMockTool('test_mcp_Server1', 'Server1'),
|
||||
null,
|
||||
createMockTool('test_mcp_Server2', 'Server2'),
|
||||
undefined,
|
||||
createMockTool('regular_tool', undefined, false),
|
||||
createMockTool('test_mcp_Server3', 'Server3'),
|
||||
];
|
||||
|
||||
mockGetPluginAuthMap.mockResolvedValue({
|
||||
mcp_Server1: { API_KEY: 'key1' },
|
||||
mcp_Server2: { API_KEY: 'key2' },
|
||||
mcp_Server3: { API_KEY: 'key3' },
|
||||
});
|
||||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
toolInstances: toolInstances as (GenericTool | null)[],
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(mockGetPluginAuthMap).toHaveBeenCalledWith({
|
||||
userId: 'user123',
|
||||
pluginKeys: ['mcp_Server1', 'mcp_Server2', 'mcp_Server3'],
|
||||
throwError: false,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
mcp_Server1: { API_KEY: 'key1' },
|
||||
mcp_Server2: { API_KEY: 'key2' },
|
||||
mcp_Server3: { API_KEY: 'key3' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Integration', () => {
|
||||
|
@ -138,7 +239,7 @@ describe('getUserMCPAuthMap', () => {
|
|||
const originalServerName = 'Connector: Company';
|
||||
const toolName = 'test_auth_mcp_Connector__Company';
|
||||
|
||||
const tools = [createMockTool(toolName, originalServerName)];
|
||||
const toolInstances = [createMockTool(toolName, originalServerName)];
|
||||
|
||||
const mockCustomUserVars = {
|
||||
'mcp_Connector: Company': {
|
||||
|
@ -151,7 +252,7 @@ describe('getUserMCPAuthMap', () => {
|
|||
|
||||
const result = await getUserMCPAuthMap({
|
||||
userId: 'user123',
|
||||
tools,
|
||||
toolInstances,
|
||||
findPluginAuthsByKeys: mockFindPluginAuthsByKeys,
|
||||
});
|
||||
|
||||
|
|
|
@ -7,33 +7,56 @@ import { getPluginAuthMap } from '~/agents/auth';
|
|||
export async function getUserMCPAuthMap({
|
||||
userId,
|
||||
tools,
|
||||
servers,
|
||||
toolInstances,
|
||||
findPluginAuthsByKeys,
|
||||
}: {
|
||||
userId: string;
|
||||
tools: GenericTool[] | undefined;
|
||||
tools?: (string | undefined)[];
|
||||
servers?: (string | undefined)[];
|
||||
toolInstances?: (GenericTool | null)[];
|
||||
findPluginAuthsByKeys: PluginAuthMethods['findPluginAuthsByKeys'];
|
||||
}) {
|
||||
if (!tools || tools.length === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const uniqueMcpServers = new Set<string>();
|
||||
|
||||
for (const tool of tools) {
|
||||
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
|
||||
if (mcpTool.mcpRawServerName) {
|
||||
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (uniqueMcpServers.size === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
|
||||
|
||||
let allMcpCustomUserVars: Record<string, Record<string, string>> = {};
|
||||
let mcpPluginKeysToFetch: string[] = [];
|
||||
try {
|
||||
const uniqueMcpServers = new Set<string>();
|
||||
|
||||
if (servers != null && servers.length) {
|
||||
for (const serverName of servers) {
|
||||
if (!serverName) {
|
||||
continue;
|
||||
}
|
||||
uniqueMcpServers.add(`${Constants.mcp_prefix}${serverName}`);
|
||||
}
|
||||
} else if (tools != null && tools.length) {
|
||||
for (const toolName of tools) {
|
||||
if (!toolName) {
|
||||
continue;
|
||||
}
|
||||
const delimiterIndex = toolName.indexOf(Constants.mcp_delimiter);
|
||||
if (delimiterIndex === -1) continue;
|
||||
const mcpServer = toolName.slice(delimiterIndex + Constants.mcp_delimiter.length);
|
||||
if (!mcpServer) continue;
|
||||
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpServer}`);
|
||||
}
|
||||
} else if (toolInstances != null && toolInstances.length) {
|
||||
for (const tool of toolInstances) {
|
||||
if (!tool) {
|
||||
continue;
|
||||
}
|
||||
const mcpTool = tool as GenericTool & { mcpRawServerName?: string };
|
||||
if (mcpTool.mcpRawServerName) {
|
||||
uniqueMcpServers.add(`${Constants.mcp_prefix}${mcpTool.mcpRawServerName}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (uniqueMcpServers.size === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
mcpPluginKeysToFetch = Array.from(uniqueMcpServers);
|
||||
allMcpCustomUserVars = await getPluginAuthMap({
|
||||
userId,
|
||||
pluginKeys: mcpPluginKeysToFetch,
|
||||
|
|
|
@ -446,7 +446,7 @@ export class MCPConnection extends EventEmitter {
|
|||
const serverUrl = this.url;
|
||||
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
|
||||
|
||||
const oauthTimeout = this.options.initTimeout ?? 60000;
|
||||
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
|
||||
/** Promise that will resolve when OAuth is handled */
|
||||
const oauthHandledPromise = new Promise<void>((resolve, reject) => {
|
||||
let timeoutId: NodeJS.Timeout | null = null;
|
||||
|
|
|
@ -134,4 +134,5 @@ export interface OAuthConnectionOptions {
|
|||
oauthStart?: (authURL: string) => Promise<void>;
|
||||
oauthEnd?: () => Promise<void>;
|
||||
returnOnOAuth?: boolean;
|
||||
connectionTimeout?: number;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import { AuthType, Constants, EToolResources } from 'librechat-data-provider';
|
||||
import type { TCustomConfig, TPlugin, FunctionTool } from 'librechat-data-provider';
|
||||
import type { TCustomConfig, TPlugin } from 'librechat-data-provider';
|
||||
import { LCAvailableTools, LCFunctionTool } from '~/mcp/types';
|
||||
|
||||
/**
|
||||
* Filters out duplicate plugins from the list of plugins.
|
||||
|
@ -60,7 +61,7 @@ export function convertMCPToolToPlugin({
|
|||
customConfig,
|
||||
}: {
|
||||
toolKey: string;
|
||||
toolData: FunctionTool;
|
||||
toolData: LCFunctionTool;
|
||||
customConfig?: Partial<TCustomConfig> | null;
|
||||
}): TPlugin | undefined {
|
||||
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
|
||||
|
@ -112,7 +113,7 @@ export function convertMCPToolsToPlugins({
|
|||
functionTools,
|
||||
customConfig,
|
||||
}: {
|
||||
functionTools?: Record<string, FunctionTool>;
|
||||
functionTools?: LCAvailableTools;
|
||||
customConfig?: Partial<TCustomConfig> | null;
|
||||
}): TPlugin[] | undefined {
|
||||
if (!functionTools || typeof functionTools !== 'object') {
|
||||
|
|
|
@ -1525,6 +1525,8 @@ export enum Constants {
|
|||
CONFIG_VERSION = '1.2.8',
|
||||
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */
|
||||
NO_PARENT = '00000000-0000-0000-0000-000000000000',
|
||||
/** Standard value to use whatever the submission prelim. `responseMessageId` is */
|
||||
USE_PRELIM_RESPONSE_MESSAGE_ID = 'USE_PRELIM_RESPONSE_MESSAGE_ID',
|
||||
/** Standard value for the initial conversationId before a request is sent */
|
||||
NEW_CONVO = 'new',
|
||||
/** Standard value for the temporary conversationId after a request is sent and before the server responds */
|
||||
|
@ -1551,6 +1553,8 @@ export enum Constants {
|
|||
mcp_delimiter = '_mcp_',
|
||||
/** Prefix for MCP plugins */
|
||||
mcp_prefix = 'mcp_',
|
||||
/** Unique value to indicate all MCP servers */
|
||||
mcp_all = 'sys__all__sys',
|
||||
/** Placeholder Agent ID for Ephemeral Agents */
|
||||
EPHEMERAL_AGENT_ID = 'ephemeral',
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue