🔐 fix: MCP OAuth Tool Discovery and Event Emission (#11599)

* fix: MCP OAuth tool discovery and event emission in event-driven mode

- Add discoverServerTools method to MCPManager for tool discovery when OAuth is required
- Fix OAuth event emission to send both ON_RUN_STEP and ON_RUN_STEP_DELTA events
- Fix hasSubscriber flag reset in GenerationJobManager for proper event buffering
- Add ToolDiscoveryOptions and ToolDiscoveryResult types
- Update reinitMCPServer to use new discovery method and propagate OAuth URLs

* refactor: Update ToolService and MCP modules for improved functionality

- Reintroduced Constants in ToolService for better reference management.
- Enhanced loadToolDefinitionsWrapper to handle both response and streamId scenarios.
- Updated MCP module to correct type definitions for oauthStart parameter.
- Improved MCPConnectionFactory to ensure proper disconnection handling during tool discovery.
- Adjusted tests to reflect changes in mock implementations and ensure accurate behavior during OAuth handling.

* fix: Refine OAuth handling in MCPConnectionFactory and related tests

- Updated the OAuth URL assignment logic in reinitMCPServer to prevent overwriting existing URLs.
- Enhanced error logging to provide clearer messages when tool discovery fails.
- Adjusted tests to reflect changes in OAuth handling, ensuring accurate detection of OAuth requirements without generating URLs in discovery mode.

* refactor: Clean up OAuth URL assignment in reinitMCPServer

- Removed redundant OAuth URL assignment logic in the reinitMCPServer function to streamline the tool discovery process.
- Enhanced error logging for tool discovery failures, improving clarity in debugging and monitoring.

* fix: Update response handling in ToolService for event-driven mode

- Changed the condition in loadToolDefinitionsWrapper to check for writableEnded instead of headersSent, ensuring proper event emission when the response is still writable.
- This adjustment enhances the reliability of event handling during tool execution, particularly in streaming scenarios.
This commit is contained in:
Danny Avila 2026-02-01 19:37:04 -05:00 committed by GitHub
parent 5af1342dbb
commit d13037881a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 667 additions and 40 deletions

View file

@ -1,22 +1,28 @@
const {
sleep,
EnvVar,
Constants,
StepTypes,
GraphEvents,
createToolSearch,
createProgrammaticToolCallingTool,
} = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { tool: toolFn, DynamicStructuredTool } = require('@langchain/core/tools');
const {
sendEvent,
getToolkitKey,
hasCustomUserVars,
getUserMCPAuthMap,
loadToolDefinitions,
GenerationJobManager,
isActionDomainAllowed,
buildToolClassification,
} = require('@librechat/api');
const {
Time,
Tools,
Constants,
CacheKeys,
ErrorTypes,
ContentTypes,
imageGenTools,
@ -45,6 +51,8 @@ const {
getCachedTools,
getMCPServerTools,
} = require('~/server/services/Config');
const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
const { createOnSearchResults } = require('~/server/services/Tools/search');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
@ -409,7 +417,9 @@ const isBuiltInTool = (toolName) =>
*
* @param {Object} params
* @param {ServerRequest} params.req - The request object
* @param {ServerResponse} [params.res] - The response object for SSE events
* @param {Object} params.agent - The agent configuration
* @param {string|null} [params.streamId] - Stream ID for resumable mode
* @returns {Promise<{
* toolDefinitions?: import('@librechat/api').LCTool[];
* toolRegistry?: Map<string, import('@librechat/api').LCTool>;
@ -417,7 +427,7 @@ const isBuiltInTool = (toolName) =>
* hasDeferredTools?: boolean;
* }>}
*/
async function loadToolDefinitionsWrapper({ req, agent }) {
async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null }) {
if (!agent.tools || agent.tools.length === 0) {
return { toolDefinitions: [] };
}
@ -473,14 +483,72 @@ async function loadToolDefinitionsWrapper({ req, agent }) {
});
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const pendingOAuthServers = new Set();
const createOAuthEmitter = (serverName) => {
return async (authURL) => {
const flowId = `${req.user.id}:${serverName}:${Date.now()}`;
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
id: flowId,
name: serverName,
type: 'tool_call_chunk',
};
const runStepData = {
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
id: stepId,
type: StepTypes.TOOL_CALLS,
index: 0,
stepDetails: {
type: StepTypes.TOOL_CALLS,
tool_calls: [toolCall],
},
};
const runStepDeltaData = {
id: stepId,
delta: {
type: StepTypes.TOOL_CALLS,
tool_calls: [{ ...toolCall, args: '' }],
auth: authURL,
expires_at: Date.now() + Time.TWO_MINUTES,
},
};
const runStepEvent = { event: GraphEvents.ON_RUN_STEP, data: runStepData };
const runStepDeltaEvent = { event: GraphEvents.ON_RUN_STEP_DELTA, data: runStepDeltaData };
if (streamId) {
GenerationJobManager.emitChunk(streamId, runStepEvent);
GenerationJobManager.emitChunk(streamId, runStepDeltaEvent);
} else if (res && !res.writableEnded) {
sendEvent(res, runStepEvent);
sendEvent(res, runStepDeltaEvent);
} else {
logger.warn(
`[Tool Definitions] Cannot emit OAuth event for ${serverName}: no streamId and res not available`,
);
}
};
};
const getOrFetchMCPServerTools = async (userId, serverName) => {
const cached = await getMCPServerTools(userId, serverName);
if (cached) {
return cached;
}
const oauthStart = async () => {
pendingOAuthServers.add(serverName);
};
const result = await reinitMCPServer({
user: req.user,
oauthStart,
flowManager,
serverName,
userMCPAuthMap,
});
@ -535,7 +603,7 @@ async function loadToolDefinitionsWrapper({ req, agent }) {
return definitions;
};
const { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions(
let { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions(
{
userId: req.user.id,
agentId: agent.id,
@ -551,6 +619,65 @@ async function loadToolDefinitionsWrapper({ req, agent }) {
},
);
if (pendingOAuthServers.size > 0 && (res || streamId)) {
const serverNames = Array.from(pendingOAuthServers);
logger.info(
`[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. Emitting events and waiting.`,
);
const oauthWaitPromises = serverNames.map(async (serverName) => {
try {
const result = await reinitMCPServer({
user: req.user,
serverName,
userMCPAuthMap,
flowManager,
returnOnOAuth: false,
oauthStart: createOAuthEmitter(serverName),
connectionTimeout: Time.TWO_MINUTES,
});
if (result?.availableTools) {
logger.info(`[Tool Definitions] OAuth completed for ${serverName}, tools available`);
return { serverName, success: true };
}
return { serverName, success: false };
} catch (error) {
logger.debug(`[Tool Definitions] OAuth wait failed for ${serverName}:`, error?.message);
return { serverName, success: false };
}
});
const results = await Promise.allSettled(oauthWaitPromises);
const successfulServers = results
.filter((r) => r.status === 'fulfilled' && r.value.success)
.map((r) => r.value.serverName);
if (successfulServers.length > 0) {
logger.info(
`[Tool Definitions] Reloading tools after OAuth for: ${successfulServers.join(', ')}`,
);
const reloadResult = await loadToolDefinitions(
{
userId: req.user.id,
agentId: agent.id,
tools: filteredTools,
toolOptions: agent.tool_options,
deferredToolsEnabled,
},
{
isBuiltInTool,
loadAuthValues,
getOrFetchMCPServerTools,
getActionToolDefinitions,
},
);
toolDefinitions = reloadResult.toolDefinitions;
toolRegistry = reloadResult.toolRegistry;
hasDeferredTools = reloadResult.hasDeferredTools;
}
}
return {
toolRegistry,
userMCPAuthMap,
@ -584,7 +711,7 @@ async function loadAgentTools({
definitionsOnly = true,
}) {
if (definitionsOnly) {
return loadToolDefinitionsWrapper({ req, agent });
return loadToolDefinitionsWrapper({ req, res, agent, streamId });
}
if (!agent.tools || agent.tools.length === 0) {