🔐 fix: MCP OAuth Tool Discovery and Event Emission (#11599)

* fix: MCP OAuth tool discovery and event emission in event-driven mode

- Add discoverServerTools method to MCPManager for tool discovery when OAuth is required
- Fix OAuth event emission to send both ON_RUN_STEP and ON_RUN_STEP_DELTA events
- Fix hasSubscriber flag reset in GenerationJobManager for proper event buffering
- Add ToolDiscoveryOptions and ToolDiscoveryResult types
- Update reinitMCPServer to use new discovery method and propagate OAuth URLs

* refactor: Update ToolService and MCP modules for improved functionality

- Reintroduced Constants in ToolService for better reference management.
- Enhanced loadToolDefinitionsWrapper to handle both response and streamId scenarios.
- Updated MCP module to correct type definitions for oauthStart parameter.
- Improved MCPConnectionFactory to ensure proper disconnection handling during tool discovery.
- Adjusted tests to reflect changes in mock implementations and ensure accurate behavior during OAuth handling.

* fix: Refine OAuth handling in MCPConnectionFactory and related tests

- Updated the OAuth URL assignment logic in reinitMCPServer to prevent overwriting existing URLs.
- Enhanced error logging to provide clearer messages when tool discovery fails.
- Adjusted tests to reflect changes in OAuth handling, ensuring accurate detection of OAuth requirements without generating URLs in discovery mode.

* refactor: Clean up OAuth URL assignment in reinitMCPServer

- Removed redundant OAuth URL assignment logic in the reinitMCPServer function to streamline the tool discovery process.
- Enhanced error logging for tool discovery failures, improving clarity in debugging and monitoring.

* fix: Update response handling in ToolService for event-driven mode

- Changed the condition in loadToolDefinitionsWrapper to check for writableEnded instead of headersSent, ensuring proper event emission when the response is still writable.
- This adjustment enhances the reliability of event handling during tool execution, particularly in streaming scenarios.
This commit is contained in:
Danny Avila 2026-02-01 19:37:04 -05:00 committed by GitHub
parent 5af1342dbb
commit d13037881a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 667 additions and 40 deletions

View file

@ -1,5 +1,6 @@
import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
@ -11,6 +12,13 @@ import { withTimeout } from '~/utils/promise';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
export interface ToolDiscoveryResult {
tools: Tool[] | null;
connection: MCPConnection | null;
oauthRequired: boolean;
oauthUrl: string | null;
}
/**
* Factory for creating MCP connections with optional OAuth authentication.
* Handles OAuth flows, token management, and connection retry logic.
@ -41,6 +49,137 @@ export class MCPConnectionFactory {
return factory.createConnection();
}
/**
* Discovers tools from an MCP server, even when OAuth is required.
* Per MCP spec, tool listing should be possible without authentication.
* Returns tools if discoverable, plus OAuth status for tool execution.
*/
static async discoverTools(
basic: t.BasicConnectionOptions,
oauth?: Omit<t.OAuthConnectionOptions, 'returnOnOAuth'>,
): Promise<ToolDiscoveryResult> {
const factory = new this(basic, oauth ? { ...oauth, returnOnOAuth: true } : undefined);
return factory.discoverToolsInternal();
}
protected async discoverToolsInternal(): Promise<ToolDiscoveryResult> {
const oauthUrl: string | null = null;
let oauthRequired = false;
const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null;
const connection = new MCPConnection({
serverName: this.serverName,
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
});
const oauthHandler = async () => {
logger.info(
`${this.logPrefix} [Discovery] OAuth required; skipping URL generation in discovery mode`,
);
oauthRequired = true;
connection.emit('oauthFailed', new Error('OAuth required during tool discovery'));
};
if (this.useOAuth) {
connection.on('oauthRequired', oauthHandler);
}
try {
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
await withTimeout(
connection.connect(),
connectTimeout,
`Connection timeout after ${connectTimeout}ms`,
);
if (await connection.isConnected()) {
const tools = await connection.fetchTools();
if (this.useOAuth) {
connection.removeListener('oauthRequired', oauthHandler);
}
return { tools, connection, oauthRequired: false, oauthUrl: null };
}
} catch {
logger.debug(
`${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`,
);
}
try {
const tools = await this.attemptUnauthenticatedToolListing();
if (this.useOAuth) {
connection.removeListener('oauthRequired', oauthHandler);
}
if (tools && tools.length > 0) {
logger.info(
`${this.logPrefix} [Discovery] Successfully discovered ${tools.length} tools without auth`,
);
try {
await connection.disconnect();
} catch {
// Ignore cleanup errors
}
return { tools, connection: null, oauthRequired, oauthUrl };
}
} catch (listError) {
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
}
if (this.useOAuth) {
connection.removeListener('oauthRequired', oauthHandler);
}
try {
await connection.disconnect();
} catch {
// Ignore cleanup errors
}
return { tools: null, connection: null, oauthRequired, oauthUrl };
}
protected async attemptUnauthenticatedToolListing(): Promise<Tool[] | null> {
const unauthConnection = new MCPConnection({
serverName: this.serverName,
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens: null,
});
unauthConnection.on('oauthRequired', () => {
logger.debug(
`${this.logPrefix} [Discovery] Unauthenticated connection requires OAuth, failing fast`,
);
unauthConnection.emit(
'oauthFailed',
new Error('OAuth not supported in unauthenticated discovery'),
);
});
try {
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 15000;
await withTimeout(unauthConnection.connect(), connectTimeout, `Unauth connection timeout`);
if (await unauthConnection.isConnected()) {
const tools = await unauthConnection.fetchTools();
await unauthConnection.disconnect();
return tools;
}
} catch {
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated connection attempt failed`);
}
try {
await unauthConnection.disconnect();
} catch {
// Ignore cleanup errors
}
return null;
}
protected constructor(basic: t.BasicConnectionOptions, oauth?: t.OAuthConnectionOptions) {
this.serverConfig = processMCPEnv({
options: basic.serverConfig,
@ -56,7 +195,7 @@ export class MCPConnectionFactory {
: `[MCP][${basic.serverName}]`;
if (oauth?.useOAuth) {
this.userId = oauth.user.id;
this.userId = oauth.user?.id;
this.flowManager = oauth.flowManager;
this.tokenMethods = oauth.tokenMethods;
this.signal = oauth.signal;

View file

@ -8,11 +8,12 @@ import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from './oauth';
import type { RequestBody } from '~/types';
import type * as t from './types';
import { MCPServersInitializer } from './registry/MCPServersInitializer';
import { MCPServerInspector } from './registry/MCPServerInspector';
import { MCPServersRegistry } from './registry/MCPServersRegistry';
import { UserConnectionManager } from './UserConnectionManager';
import { ConnectionsRepository } from './ConnectionsRepository';
import { MCPServerInspector } from './registry/MCPServerInspector';
import { MCPServersInitializer } from './registry/MCPServersInitializer';
import { MCPServersRegistry } from './registry/MCPServersRegistry';
import { MCPConnectionFactory } from './MCPConnectionFactory';
import { preProcessGraphTokens } from '~/utils/graph';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
@ -68,6 +69,70 @@ export class MCPManager extends UserConnectionManager {
}
}
/**
* Discovers tools from an MCP server, even when OAuth is required.
* Per MCP spec, tool listing should be possible without authentication.
* Use this for agent initialization to get tool schemas before OAuth flow.
*/
public async discoverServerTools(args: t.ToolDiscoveryOptions): Promise<t.ToolDiscoveryResult> {
const { serverName, user } = args;
const logPrefix = user?.id ? `[MCP][User: ${user.id}][${serverName}]` : `[MCP][${serverName}]`;
try {
const existingAppConnection = await this.appConnections?.get(serverName);
if (existingAppConnection && (await existingAppConnection.isConnected())) {
const tools = await existingAppConnection.fetchTools();
return { tools, oauthRequired: false, oauthUrl: null };
}
} catch {
logger.debug(`${logPrefix} [Discovery] App connection not available, trying discovery mode`);
}
const serverConfig = (await MCPServersRegistry.getInstance().getServerConfig(
serverName,
user?.id,
)) as t.MCPOptions | null;
if (!serverConfig) {
logger.warn(`${logPrefix} [Discovery] Server config not found`);
return { tools: null, oauthRequired: false, oauthUrl: null };
}
const useOAuth = Boolean(
serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata,
);
const basic: t.BasicConnectionOptions = { serverName, serverConfig };
if (!useOAuth) {
const result = await MCPConnectionFactory.discoverTools(basic);
return {
tools: result.tools,
oauthRequired: result.oauthRequired,
oauthUrl: result.oauthUrl,
};
}
if (!user || !args.flowManager) {
logger.warn(`${logPrefix} [Discovery] OAuth server requires user and flowManager`);
return { tools: null, oauthRequired: true, oauthUrl: null };
}
const result = await MCPConnectionFactory.discoverTools(basic, {
user,
useOAuth: true,
flowManager: args.flowManager,
tokenMethods: args.tokenMethods,
signal: args.signal,
oauthStart: args.oauthStart,
customUserVars: args.customUserVars,
requestBody: args.requestBody,
connectionTimeout: args.connectionTimeout,
});
return { tools: result.tools, oauthRequired: result.oauthRequired, oauthUrl: result.oauthUrl };
}
/** Returns all available tool functions from app-level connections */
public async getAppToolFunctions(): Promise<t.LCAvailableTools> {
const toolFunctions: t.LCAvailableTools = {};

View file

@ -49,7 +49,7 @@ export abstract class UserConnectionManager {
serverName: string;
forceNew?: boolean;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
const userId = user.id;
const userId = user?.id;
if (!userId) {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
}

View file

@ -1,6 +1,5 @@
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { TokenMethods, IUser } from '@librechat/data-schemas';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import type * as t from '~/mcp/types';
@ -27,7 +26,7 @@ const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
describe('MCPConnectionFactory', () => {
let mockUser: TUser;
let mockUser: IUser | undefined;
let mockServerConfig: t.MCPOptions;
let mockFlowManager: jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
let mockConnectionInstance: jest.Mocked<MCPConnection>;
@ -37,7 +36,7 @@ describe('MCPConnectionFactory', () => {
mockUser = {
id: 'user123',
email: 'test@example.com',
} as TUser;
} as IUser;
mockServerConfig = {
command: 'node',
@ -275,7 +274,7 @@ describe('MCPConnectionFactory', () => {
user: mockUser,
};
const oauthOptions = {
const oauthOptions: t.OAuthConnectionOptions = {
user: mockUser,
useOAuth: true,
returnOnOAuth: true,
@ -424,4 +423,116 @@ describe('MCPConnectionFactory', () => {
);
});
});
describe('discoverTools static method', () => {
const mockTools = [
{ name: 'tool1', description: 'First tool', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Second tool', inputSchema: { type: 'object' } },
];
it('should discover tools from a successfully connected server', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.connect.mockResolvedValue(undefined);
mockConnectionInstance.isConnected.mockResolvedValue(true);
mockConnectionInstance.fetchTools = jest.fn().mockResolvedValue(mockTools);
const result = await MCPConnectionFactory.discoverTools(basicOptions);
expect(result.tools).toEqual(mockTools);
expect(result.oauthRequired).toBe(false);
expect(result.oauthUrl).toBeNull();
expect(result.connection).toBe(mockConnectionInstance);
});
it('should detect OAuth required without generating URL in discovery mode', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const mockOAuthStart = jest.fn().mockResolvedValue(undefined);
const oauthOptions = {
useOAuth: true as const,
user: mockUser as unknown as IUser,
flowManager: mockFlowManager,
oauthStart: mockOAuthStart,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
let oauthHandler: (() => Promise<void>) | undefined;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthHandler = handler as () => Promise<void>;
}
return mockConnectionInstance;
});
mockConnectionInstance.connect.mockImplementation(async () => {
if (oauthHandler) {
await oauthHandler();
}
throw new Error('OAuth required');
});
const result = await MCPConnectionFactory.discoverTools(basicOptions, oauthOptions);
expect(result.connection).toBeNull();
expect(result.tools).toBeNull();
expect(result.oauthRequired).toBe(true);
expect(result.oauthUrl).toBeNull();
expect(mockOAuthStart).not.toHaveBeenCalled();
});
it('should return null tools when discovery fails completely', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
const result = await MCPConnectionFactory.discoverTools(basicOptions);
expect(result.tools).toBeNull();
expect(result.connection).toBeNull();
expect(result.oauthRequired).toBe(false);
});
it('should handle disconnect errors gracefully during cleanup', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest
.fn()
.mockRejectedValue(new Error('Disconnect failed'));
const result = await MCPConnectionFactory.discoverTools(basicOptions);
expect(result.tools).toBeNull();
expect(mockLogger.debug).toHaveBeenCalled();
});
});
});

View file

@ -4,6 +4,7 @@ import type { GraphTokenResolver } from '~/utils/graph';
import type * as t from '~/mcp/types';
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPConnection } from '~/mcp/connection';
import { MCPManager } from '~/mcp/MCPManager';
@ -48,6 +49,7 @@ jest.mock('~/mcp/registry/MCPServersInitializer', () => ({
jest.mock('~/mcp/registry/MCPServerInspector');
jest.mock('~/mcp/ConnectionsRepository');
jest.mock('~/mcp/MCPConnectionFactory');
const mockLogger = logger as jest.Mocked<typeof logger>;
@ -787,4 +789,139 @@ describe('MCPManager', () => {
);
});
});
describe('discoverServerTools', () => {
const mockTools = [
{ name: 'tool1', description: 'First tool', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Second tool', inputSchema: { type: 'object' } },
];
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
fetchTools: jest.fn().mockResolvedValue(mockTools),
} as unknown as MCPConnection;
beforeEach(() => {
(MCPConnectionFactory.discoverTools as jest.Mock) = jest.fn();
});
it('should return tools from existing app connection when available', async () => {
mockAppConnections({
get: jest.fn().mockResolvedValue(mockConnection),
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.discoverServerTools({ serverName });
expect(result.tools).toEqual(mockTools);
expect(result.oauthRequired).toBe(false);
expect(result.oauthUrl).toBeNull();
expect(MCPConnectionFactory.discoverTools).not.toHaveBeenCalled();
});
it('should use MCPConnectionFactory.discoverTools when no app connection available', async () => {
mockAppConnections({
get: jest.fn().mockResolvedValue(null),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'stdio',
command: 'test',
args: [],
});
(MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({
tools: mockTools,
connection: null,
oauthRequired: false,
oauthUrl: null,
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.discoverServerTools({ serverName });
expect(result.tools).toEqual(mockTools);
expect(result.oauthRequired).toBe(false);
expect(MCPConnectionFactory.discoverTools).toHaveBeenCalled();
});
it('should return null tools when server config not found', async () => {
mockAppConnections({
get: jest.fn().mockResolvedValue(null),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(null);
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.discoverServerTools({ serverName });
expect(result.tools).toBeNull();
expect(result.oauthRequired).toBe(false);
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('Server config not found'),
);
});
it('should return OAuth info when server requires OAuth but no user provided', async () => {
mockAppConnections({
get: jest.fn().mockResolvedValue(null),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'sse',
url: 'https://api.example.com',
requiresOAuth: true,
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.discoverServerTools({ serverName });
expect(result.tools).toBeNull();
expect(result.oauthRequired).toBe(true);
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('OAuth server requires user and flowManager'),
);
});
it('should discover tools with OAuth when user and flowManager provided', async () => {
const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser;
const mockFlowManager = {
createFlow: jest.fn(),
getFlowState: jest.fn(),
deleteFlow: jest.fn(),
};
mockAppConnections({
get: jest.fn().mockResolvedValue(null),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'sse',
url: 'https://api.example.com',
requiresOAuth: true,
});
(MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({
tools: mockTools,
connection: null,
oauthRequired: true,
oauthUrl: 'https://auth.example.com/authorize',
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.discoverServerTools({
serverName,
user: mockUser,
flowManager: mockFlowManager as unknown as t.ToolDiscoveryOptions['flowManager'],
});
expect(result.tools).toEqual(mockTools);
expect(result.oauthRequired).toBe(true);
expect(result.oauthUrl).toBe('https://auth.example.com/authorize');
expect(MCPConnectionFactory.discoverTools).toHaveBeenCalledWith(
expect.objectContaining({ serverName }),
expect.objectContaining({ user: mockUser, useOAuth: true }),
);
});
});
});

View file

@ -169,7 +169,7 @@ export interface BasicConnectionOptions {
}
export interface OAuthConnectionOptions {
user: IUser;
user?: IUser;
useOAuth: true;
requestBody?: RequestBody;
customUserVars?: Record<string, string>;
@ -181,3 +181,21 @@ export interface OAuthConnectionOptions {
returnOnOAuth?: boolean;
connectionTimeout?: number;
}
export interface ToolDiscoveryOptions {
serverName: string;
user?: IUser;
flowManager?: FlowStateManager<o.MCPOAuthTokens | null>;
tokenMethods?: TokenMethods;
signal?: AbortSignal;
oauthStart?: (authURL: string) => Promise<void>;
customUserVars?: Record<string, string>;
requestBody?: RequestBody;
connectionTimeout?: number;
}
export interface ToolDiscoveryResult {
tools: Tool[] | null;
oauthRequired: boolean;
oauthUrl: string | null;
}

View file

@ -238,6 +238,7 @@ class GenerationJobManagerClass {
const currentRuntime = this.runtimeState.get(streamId);
if (currentRuntime) {
currentRuntime.syncSent = false;
currentRuntime.hasSubscriber = false;
// Persist syncSent=false to Redis for cross-replica consistency
this.jobStore.updateJob(streamId, { syncSent: false }).catch((err) => {
logger.error(`[GenerationJobManager] Failed to persist syncSent=false:`, err);
@ -435,6 +436,7 @@ class GenerationJobManagerClass {
const currentRuntime = this.runtimeState.get(streamId);
if (currentRuntime) {
currentRuntime.syncSent = false;
currentRuntime.hasSubscriber = false;
// Persist syncSent=false to Redis
this.jobStore.updateJob(streamId, { syncSent: false }).catch((err) => {
logger.error(`[GenerationJobManager] Failed to persist syncSent=false:`, err);
@ -767,7 +769,6 @@ class GenerationJobManagerClass {
for (const bufferedEvent of runtime.earlyEventBuffer) {
onChunk(bufferedEvent);
}
// Clear buffer after replay
runtime.earlyEventBuffer = [];
}
}
@ -822,7 +823,6 @@ class GenerationJobManagerClass {
// Buffer early events if no subscriber yet (replay when first subscriber connects)
if (!runtime.hasSubscriber) {
runtime.earlyEventBuffer.push(event);
// Also emit to transport in case subscriber connects mid-flight
}
this.eventTransport.emitChunk(streamId, event);