mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-17 07:58:08 +01:00
🔐 fix: MCP OAuth Tool Discovery and Event Emission (#11599)
* fix: MCP OAuth tool discovery and event emission in event-driven mode - Add discoverServerTools method to MCPManager for tool discovery when OAuth is required - Fix OAuth event emission to send both ON_RUN_STEP and ON_RUN_STEP_DELTA events - Fix hasSubscriber flag reset in GenerationJobManager for proper event buffering - Add ToolDiscoveryOptions and ToolDiscoveryResult types - Update reinitMCPServer to use new discovery method and propagate OAuth URLs * refactor: Update ToolService and MCP modules for improved functionality - Reintroduced Constants in ToolService for better reference management. - Enhanced loadToolDefinitionsWrapper to handle both response and streamId scenarios. - Updated MCP module to correct type definitions for oauthStart parameter. - Improved MCPConnectionFactory to ensure proper disconnection handling during tool discovery. - Adjusted tests to reflect changes in mock implementations and ensure accurate behavior during OAuth handling. * fix: Refine OAuth handling in MCPConnectionFactory and related tests - Updated the OAuth URL assignment logic in reinitMCPServer to prevent overwriting existing URLs. - Enhanced error logging to provide clearer messages when tool discovery fails. - Adjusted tests to reflect changes in OAuth handling, ensuring accurate detection of OAuth requirements without generating URLs in discovery mode. * refactor: Clean up OAuth URL assignment in reinitMCPServer - Removed redundant OAuth URL assignment logic in the reinitMCPServer function to streamline the tool discovery process. - Enhanced error logging for tool discovery failures, improving clarity in debugging and monitoring. * fix: Update response handling in ToolService for event-driven mode - Changed the condition in loadToolDefinitionsWrapper to check for writableEnded instead of headersSent, ensuring proper event emission when the response is still writable. - This adjustment enhances the reliability of event handling during tool execution, particularly in streaming scenarios.
This commit is contained in:
parent
5af1342dbb
commit
d13037881a
12 changed files with 667 additions and 40 deletions
|
|
@ -1,22 +1,28 @@
|
|||
const {
|
||||
sleep,
|
||||
EnvVar,
|
||||
Constants,
|
||||
StepTypes,
|
||||
GraphEvents,
|
||||
createToolSearch,
|
||||
createProgrammaticToolCallingTool,
|
||||
} = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { tool: toolFn, DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const {
|
||||
sendEvent,
|
||||
getToolkitKey,
|
||||
hasCustomUserVars,
|
||||
getUserMCPAuthMap,
|
||||
loadToolDefinitions,
|
||||
GenerationJobManager,
|
||||
isActionDomainAllowed,
|
||||
buildToolClassification,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Tools,
|
||||
Constants,
|
||||
CacheKeys,
|
||||
ErrorTypes,
|
||||
ContentTypes,
|
||||
imageGenTools,
|
||||
|
|
@ -45,6 +51,8 @@ const {
|
|||
getCachedTools,
|
||||
getMCPServerTools,
|
||||
} = require('~/server/services/Config');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
|
||||
const { createOnSearchResults } = require('~/server/services/Tools/search');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
|
|
@ -409,7 +417,9 @@ const isBuiltInTool = (toolName) =>
|
|||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req - The request object
|
||||
* @param {ServerResponse} [params.res] - The response object for SSE events
|
||||
* @param {Object} params.agent - The agent configuration
|
||||
* @param {string|null} [params.streamId] - Stream ID for resumable mode
|
||||
* @returns {Promise<{
|
||||
* toolDefinitions?: import('@librechat/api').LCTool[];
|
||||
* toolRegistry?: Map<string, import('@librechat/api').LCTool>;
|
||||
|
|
@ -417,7 +427,7 @@ const isBuiltInTool = (toolName) =>
|
|||
* hasDeferredTools?: boolean;
|
||||
* }>}
|
||||
*/
|
||||
async function loadToolDefinitionsWrapper({ req, agent }) {
|
||||
async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null }) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return { toolDefinitions: [] };
|
||||
}
|
||||
|
|
@ -473,14 +483,72 @@ async function loadToolDefinitionsWrapper({ req, agent }) {
|
|||
});
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const pendingOAuthServers = new Set();
|
||||
|
||||
const createOAuthEmitter = (serverName) => {
|
||||
return async (authURL) => {
|
||||
const flowId = `${req.user.id}:${serverName}:${Date.now()}`;
|
||||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
id: flowId,
|
||||
name: serverName,
|
||||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
const runStepData = {
|
||||
runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID,
|
||||
id: stepId,
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
index: 0,
|
||||
stepDetails: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [toolCall],
|
||||
},
|
||||
};
|
||||
|
||||
const runStepDeltaData = {
|
||||
id: stepId,
|
||||
delta: {
|
||||
type: StepTypes.TOOL_CALLS,
|
||||
tool_calls: [{ ...toolCall, args: '' }],
|
||||
auth: authURL,
|
||||
expires_at: Date.now() + Time.TWO_MINUTES,
|
||||
},
|
||||
};
|
||||
|
||||
const runStepEvent = { event: GraphEvents.ON_RUN_STEP, data: runStepData };
|
||||
const runStepDeltaEvent = { event: GraphEvents.ON_RUN_STEP_DELTA, data: runStepDeltaData };
|
||||
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, runStepEvent);
|
||||
GenerationJobManager.emitChunk(streamId, runStepDeltaEvent);
|
||||
} else if (res && !res.writableEnded) {
|
||||
sendEvent(res, runStepEvent);
|
||||
sendEvent(res, runStepDeltaEvent);
|
||||
} else {
|
||||
logger.warn(
|
||||
`[Tool Definitions] Cannot emit OAuth event for ${serverName}: no streamId and res not available`,
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
const getOrFetchMCPServerTools = async (userId, serverName) => {
|
||||
const cached = await getMCPServerTools(userId, serverName);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
const oauthStart = async () => {
|
||||
pendingOAuthServers.add(serverName);
|
||||
};
|
||||
|
||||
const result = await reinitMCPServer({
|
||||
user: req.user,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
|
@ -535,7 +603,7 @@ async function loadToolDefinitionsWrapper({ req, agent }) {
|
|||
return definitions;
|
||||
};
|
||||
|
||||
const { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions(
|
||||
let { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions(
|
||||
{
|
||||
userId: req.user.id,
|
||||
agentId: agent.id,
|
||||
|
|
@ -551,6 +619,65 @@ async function loadToolDefinitionsWrapper({ req, agent }) {
|
|||
},
|
||||
);
|
||||
|
||||
if (pendingOAuthServers.size > 0 && (res || streamId)) {
|
||||
const serverNames = Array.from(pendingOAuthServers);
|
||||
logger.info(
|
||||
`[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. Emitting events and waiting.`,
|
||||
);
|
||||
|
||||
const oauthWaitPromises = serverNames.map(async (serverName) => {
|
||||
try {
|
||||
const result = await reinitMCPServer({
|
||||
user: req.user,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
flowManager,
|
||||
returnOnOAuth: false,
|
||||
oauthStart: createOAuthEmitter(serverName),
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
|
||||
if (result?.availableTools) {
|
||||
logger.info(`[Tool Definitions] OAuth completed for ${serverName}, tools available`);
|
||||
return { serverName, success: true };
|
||||
}
|
||||
return { serverName, success: false };
|
||||
} catch (error) {
|
||||
logger.debug(`[Tool Definitions] OAuth wait failed for ${serverName}:`, error?.message);
|
||||
return { serverName, success: false };
|
||||
}
|
||||
});
|
||||
|
||||
const results = await Promise.allSettled(oauthWaitPromises);
|
||||
const successfulServers = results
|
||||
.filter((r) => r.status === 'fulfilled' && r.value.success)
|
||||
.map((r) => r.value.serverName);
|
||||
|
||||
if (successfulServers.length > 0) {
|
||||
logger.info(
|
||||
`[Tool Definitions] Reloading tools after OAuth for: ${successfulServers.join(', ')}`,
|
||||
);
|
||||
const reloadResult = await loadToolDefinitions(
|
||||
{
|
||||
userId: req.user.id,
|
||||
agentId: agent.id,
|
||||
tools: filteredTools,
|
||||
toolOptions: agent.tool_options,
|
||||
deferredToolsEnabled,
|
||||
},
|
||||
{
|
||||
isBuiltInTool,
|
||||
loadAuthValues,
|
||||
getOrFetchMCPServerTools,
|
||||
getActionToolDefinitions,
|
||||
},
|
||||
);
|
||||
toolDefinitions = reloadResult.toolDefinitions;
|
||||
toolRegistry = reloadResult.toolRegistry;
|
||||
hasDeferredTools = reloadResult.hasDeferredTools;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
toolRegistry,
|
||||
userMCPAuthMap,
|
||||
|
|
@ -584,7 +711,7 @@ async function loadAgentTools({
|
|||
definitionsOnly = true,
|
||||
}) {
|
||||
if (definitionsOnly) {
|
||||
return loadToolDefinitionsWrapper({ req, agent });
|
||||
return loadToolDefinitionsWrapper({ req, res, agent, streamId });
|
||||
}
|
||||
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue