mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-29 11:57:20 +02:00
⚡ refactor: Fast-Fail MCP Tool Discovery on 401 for Non-OAuth Servers (#12395)
* fix: fast-fail MCP discovery for non-OAuth servers on auth errors Always attach oauthHandler in discoverToolsInternal regardless of useOAuth flag. Previously, non-OAuth servers hitting 401 would hang for 30s because connectClient's oauthHandledPromise had no listener to emit oauthFailed, waiting until withTimeout killed it. * chore: import order
This commit is contained in:
parent
3f805d68a1
commit
221e49222d
3 changed files with 80 additions and 17 deletions
|
|
@ -81,7 +81,7 @@ export class MCPConnectionFactory {
|
|||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
const oauthHandler = async () => {
|
||||
const oauthHandler = () => {
|
||||
logger.info(
|
||||
`${this.logPrefix} [Discovery] OAuth required; skipping URL generation in discovery mode`,
|
||||
);
|
||||
|
|
@ -89,9 +89,9 @@ export class MCPConnectionFactory {
|
|||
connection.emit('oauthFailed', new Error('OAuth required during tool discovery'));
|
||||
};
|
||||
|
||||
if (this.useOAuth) {
|
||||
connection.on('oauthRequired', oauthHandler);
|
||||
}
|
||||
// Register unconditionally: non-OAuth servers that return 401 also emit 'oauthRequired',
|
||||
// and without this listener, connectClient()'s oauthHandledPromise hangs for 30s+.
|
||||
connection.once('oauthRequired', oauthHandler);
|
||||
|
||||
try {
|
||||
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
|
||||
|
|
@ -103,9 +103,7 @@ export class MCPConnectionFactory {
|
|||
|
||||
if (await connection.isConnected()) {
|
||||
const tools = await connection.fetchTools();
|
||||
if (this.useOAuth) {
|
||||
connection.removeListener('oauthRequired', oauthHandler);
|
||||
}
|
||||
connection.removeListener('oauthRequired', oauthHandler);
|
||||
return { tools, connection, oauthRequired: false, oauthUrl: null };
|
||||
}
|
||||
} catch {
|
||||
|
|
@ -117,9 +115,7 @@ export class MCPConnectionFactory {
|
|||
|
||||
try {
|
||||
const tools = await this.attemptUnauthenticatedToolListing();
|
||||
if (this.useOAuth) {
|
||||
connection.removeListener('oauthRequired', oauthHandler);
|
||||
}
|
||||
connection.removeListener('oauthRequired', oauthHandler);
|
||||
if (tools && tools.length > 0) {
|
||||
logger.info(
|
||||
`${this.logPrefix} [Discovery] Successfully discovered ${tools.length} tools without auth`,
|
||||
|
|
@ -137,9 +133,7 @@ export class MCPConnectionFactory {
|
|||
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
|
||||
}
|
||||
|
||||
if (this.useOAuth) {
|
||||
connection.removeListener('oauthRequired', oauthHandler);
|
||||
}
|
||||
connection.removeListener('oauthRequired', oauthHandler);
|
||||
|
||||
try {
|
||||
await connection.disconnect();
|
||||
|
|
|
|||
|
|
@ -825,17 +825,17 @@ describe('MCPConnectionFactory', () => {
|
|||
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
|
||||
|
||||
let oauthHandler: (() => Promise<void>) | undefined;
|
||||
mockConnectionInstance.on.mockImplementation((event, handler) => {
|
||||
let oauthHandler: (() => void) | undefined;
|
||||
mockConnectionInstance.once.mockImplementation((event, handler) => {
|
||||
if (event === 'oauthRequired') {
|
||||
oauthHandler = handler as () => Promise<void>;
|
||||
oauthHandler = handler as () => void;
|
||||
}
|
||||
return mockConnectionInstance;
|
||||
});
|
||||
|
||||
mockConnectionInstance.connect.mockImplementation(async () => {
|
||||
if (oauthHandler) {
|
||||
await oauthHandler();
|
||||
oauthHandler();
|
||||
}
|
||||
throw new Error('OAuth required');
|
||||
});
|
||||
|
|
@ -849,6 +849,46 @@ describe('MCPConnectionFactory', () => {
|
|||
expect(mockOAuthStart).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fast-fail discovery when non-OAuth server returns 401', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'github',
|
||||
serverConfig: {
|
||||
...mockServerConfig,
|
||||
url: 'https://api.githubcopilot.com/mcp/',
|
||||
type: 'streamable-http' as const,
|
||||
initTimeout: 30000,
|
||||
} as t.StreamableHTTPOptions,
|
||||
};
|
||||
|
||||
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
|
||||
|
||||
let oauthHandler: (() => void) | undefined;
|
||||
mockConnectionInstance.once.mockImplementation((event, handler) => {
|
||||
if (event === 'oauthRequired') {
|
||||
oauthHandler = handler as () => void;
|
||||
}
|
||||
return mockConnectionInstance;
|
||||
});
|
||||
|
||||
mockConnectionInstance.connect.mockImplementation(async () => {
|
||||
if (oauthHandler) {
|
||||
oauthHandler();
|
||||
}
|
||||
throw Object.assign(new Error('unauthorized'), { code: 401 });
|
||||
});
|
||||
|
||||
const start = Date.now();
|
||||
const result = await MCPConnectionFactory.discoverTools(basicOptions);
|
||||
const elapsed = Date.now() - start;
|
||||
|
||||
expect(elapsed).toBeLessThan(5000);
|
||||
expect(result.tools).toBeNull();
|
||||
expect(result.oauthRequired).toBe(true);
|
||||
expect(result.oauthUrl).toBeNull();
|
||||
expect(result.connection).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null tools when discovery fails completely', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@
|
|||
*/
|
||||
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { StreamableHTTPOptions } from '~/mcp/types';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
|
|
@ -265,4 +267,31 @@ describe('MCPConnection OAuth Events — Real Server', () => {
|
|||
expect(await connection.isConnected()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('MCPConnectionFactory.discoverTools — non-OAuth 401 fast-fail', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should fast-fail when a non-OAuth discovery hits 401', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
} as StreamableHTTPOptions,
|
||||
};
|
||||
|
||||
const start = Date.now();
|
||||
const result = await MCPConnectionFactory.discoverTools(basicOptions);
|
||||
const elapsed = Date.now() - start;
|
||||
|
||||
expect(elapsed).toBeLessThan(5000);
|
||||
expect(result.tools).toBeNull();
|
||||
expect(result.oauthRequired).toBe(true);
|
||||
expect(result.oauthUrl).toBeNull();
|
||||
expect(result.connection).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue