From 221e49222d1cef704729200eec3bed12239d548a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 25 Mar 2026 13:18:02 -0400 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20refactor:=20Fast-Fail=20MCP=20Tool?= =?UTF-8?q?=20Discovery=20on=20401=20for=20Non-OAuth=20Servers=20(#12395)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fast-fail MCP discovery for non-OAuth servers on auth errors Always attach oauthHandler in discoverToolsInternal regardless of useOAuth flag. Previously, non-OAuth servers hitting 401 would hang for 30s because connectClient's oauthHandledPromise had no listener to emit oauthFailed, waiting until withTimeout killed it. * chore: import order --- packages/api/src/mcp/MCPConnectionFactory.ts | 20 +++----- .../__tests__/MCPConnectionFactory.test.ts | 48 +++++++++++++++++-- .../MCPOAuthConnectionEvents.test.ts | 29 +++++++++++ 3 files changed, 80 insertions(+), 17 deletions(-) diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 337662c812..eb62514a4e 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -81,7 +81,7 @@ export class MCPConnectionFactory { useSSRFProtection: this.useSSRFProtection, }); - const oauthHandler = async () => { + const oauthHandler = () => { logger.info( `${this.logPrefix} [Discovery] OAuth required; skipping URL generation in discovery mode`, ); @@ -89,9 +89,9 @@ export class MCPConnectionFactory { connection.emit('oauthFailed', new Error('OAuth required during tool discovery')); }; - if (this.useOAuth) { - connection.on('oauthRequired', oauthHandler); - } + // Register unconditionally: non-OAuth servers that return 401 also emit 'oauthRequired', + // and without this listener, connectClient()'s oauthHandledPromise hangs for 30s+. + connection.once('oauthRequired', oauthHandler); try { const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000; @@ -103,9 +103,7 @@ export class MCPConnectionFactory { if (await connection.isConnected()) { const tools = await connection.fetchTools(); - if (this.useOAuth) { - connection.removeListener('oauthRequired', oauthHandler); - } + connection.removeListener('oauthRequired', oauthHandler); return { tools, connection, oauthRequired: false, oauthUrl: null }; } } catch { @@ -117,9 +115,7 @@ export class MCPConnectionFactory { try { const tools = await this.attemptUnauthenticatedToolListing(); - if (this.useOAuth) { - connection.removeListener('oauthRequired', oauthHandler); - } + connection.removeListener('oauthRequired', oauthHandler); if (tools && tools.length > 0) { logger.info( `${this.logPrefix} [Discovery] Successfully discovered ${tools.length} tools without auth`, @@ -137,9 +133,7 @@ export class MCPConnectionFactory { logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError); } - if (this.useOAuth) { - connection.removeListener('oauthRequired', oauthHandler); - } + connection.removeListener('oauthRequired', oauthHandler); try { await connection.disconnect(); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index 75d7b4321d..326b77789e 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -825,17 +825,17 @@ describe('MCPConnectionFactory', () => { mockConnectionInstance.isConnected.mockResolvedValue(false); mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined); - let oauthHandler: (() => Promise) | undefined; - mockConnectionInstance.on.mockImplementation((event, handler) => { + let oauthHandler: (() => void) | undefined; + mockConnectionInstance.once.mockImplementation((event, handler) => { if (event === 'oauthRequired') { - oauthHandler = handler as () => Promise; + oauthHandler = handler as () => void; } return mockConnectionInstance; }); mockConnectionInstance.connect.mockImplementation(async () => { if (oauthHandler) { - await oauthHandler(); + oauthHandler(); } throw new Error('OAuth required'); }); @@ -849,6 +849,46 @@ describe('MCPConnectionFactory', () => { expect(mockOAuthStart).not.toHaveBeenCalled(); }); + it('should fast-fail discovery when non-OAuth server returns 401', async () => { + const basicOptions = { + serverName: 'github', + serverConfig: { + ...mockServerConfig, + url: 'https://api.githubcopilot.com/mcp/', + type: 'streamable-http' as const, + initTimeout: 30000, + } as t.StreamableHTTPOptions, + }; + + mockConnectionInstance.isConnected.mockResolvedValue(false); + mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined); + + let oauthHandler: (() => void) | undefined; + mockConnectionInstance.once.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthHandler = handler as () => void; + } + return mockConnectionInstance; + }); + + mockConnectionInstance.connect.mockImplementation(async () => { + if (oauthHandler) { + oauthHandler(); + } + throw Object.assign(new Error('unauthorized'), { code: 401 }); + }); + + const start = Date.now(); + const result = await MCPConnectionFactory.discoverTools(basicOptions); + const elapsed = Date.now() - start; + + expect(elapsed).toBeLessThan(5000); + expect(result.tools).toBeNull(); + expect(result.oauthRequired).toBe(true); + expect(result.oauthUrl).toBeNull(); + expect(result.connection).toBeNull(); + }); + it('should return null tools when discovery fails completely', async () => { const basicOptions = { serverName: 'test-server', diff --git a/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts index 4e168d00f3..79470337a7 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts @@ -6,8 +6,10 @@ */ import { MCPConnection } from '~/mcp/connection'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { createOAuthMCPServer } from './helpers/oauthTestServer'; import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { StreamableHTTPOptions } from '~/mcp/types'; import type { MCPOAuthTokens } from '~/mcp/oauth'; jest.mock('@librechat/data-schemas', () => ({ @@ -265,4 +267,31 @@ describe('MCPConnection OAuth Events — Real Server', () => { expect(await connection.isConnected()).toBe(true); }); }); + + describe('MCPConnectionFactory.discoverTools — non-OAuth 401 fast-fail', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should fast-fail when a non-OAuth discovery hits 401', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + } as StreamableHTTPOptions, + }; + + const start = Date.now(); + const result = await MCPConnectionFactory.discoverTools(basicOptions); + const elapsed = Date.now() - start; + + expect(elapsed).toBeLessThan(5000); + expect(result.tools).toBeNull(); + expect(result.oauthRequired).toBe(true); + expect(result.oauthUrl).toBeNull(); + expect(result.connection).toBeNull(); + }); + }); });