fix(mcp): clean up OAuth flows on abort and simplify flow handling

- Add abort handler in reconnectServer to clean up mcp_oauth and mcp_get_tokens flows
- Update createAbortHandler to clean up both flow types on tool call abort
- Pass abort signal to createFlow in returnOnOAuth path
- Simplify handleOAuthRequired to always cancel existing flows and start fresh
- This ensures user always gets a new OAuth URL instead of waiting for stale flows
This commit is contained in:
Danny Avila 2025-12-18 19:26:00 -05:00
parent cc931bcf51
commit 39adeac86e
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
2 changed files with 75 additions and 80 deletions

View file

@ -156,7 +156,9 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
return function () { return function () {
logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`); logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`);
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName); const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
// Clean up both mcp_oauth and mcp_get_tokens flows
flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted')); flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted'));
flowManager.failFlow(flowId, 'mcp_get_tokens', new Error('Tool call aborted'));
}; };
} }
@ -204,38 +206,60 @@ async function reconnectServer({
type: 'tool_call_chunk', type: 'tool_call_chunk',
}; };
const runStepEmitter = createRunStepEmitter({ // Set up abort handler to clean up OAuth flows if request is aborted
res, const oauthFlowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
index, const abortHandler = () => {
runId, logger.info(
stepId, `[MCP][User: ${user.id}][${serverName}] Tool loading aborted, cleaning up OAuth flows`,
toolCall, );
streamId, // Clean up both mcp_oauth and mcp_get_tokens flows
}); flowManager.failFlow(oauthFlowId, 'mcp_oauth', new Error('Tool loading aborted'));
const runStepDeltaEmitter = createRunStepDeltaEmitter({ flowManager.failFlow(oauthFlowId, 'mcp_get_tokens', new Error('Tool loading aborted'));
res, };
stepId,
toolCall, if (signal) {
streamId, signal.addEventListener('abort', abortHandler, { once: true });
}); }
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
const oauthStart = createOAuthStart({ try {
res, const runStepEmitter = createRunStepEmitter({
flowId, res,
callback, index,
flowManager, runId,
}); stepId,
return await reinitMCPServer({ toolCall,
user, streamId,
signal, });
serverName, const runStepDeltaEmitter = createRunStepDeltaEmitter({
oauthStart, res,
flowManager, stepId,
userMCPAuthMap, toolCall,
forceNew: true, streamId,
returnOnOAuth: false, });
connectionTimeout: Time.TWO_MINUTES, const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
}); const oauthStart = createOAuthStart({
res,
flowId,
callback,
flowManager,
});
return await reinitMCPServer({
user,
signal,
serverName,
oauthStart,
flowManager,
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
});
} finally {
// Clean up abort handler to prevent memory leaks
if (signal) {
signal.removeEventListener('abort', abortHandler);
}
}
} }
/** /**

View file

@ -1,7 +1,7 @@
import { logger } from '@librechat/data-schemas'; import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { TokenMethods } from '@librechat/data-schemas'; import type { TokenMethods } from '@librechat/data-schemas';
import type { MCPOAuthTokens, MCPOAuthFlowMetadata, OAuthMetadata } from '~/mcp/oauth'; import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager'; import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types'; import type { FlowMetadata } from '~/flow/types';
import type * as t from './types'; import type * as t from './types';
@ -173,9 +173,10 @@ export class MCPConnectionFactory {
// Create the flow state so the OAuth callback can find it // Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it // We spawn this in the background without waiting for it
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => { // Pass signal so the flow can be aborted if the request is cancelled
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here // The OAuth callback will resolve this flow, so we expect it to timeout here
// which is fine - we just need the flow state to exist // or it will be aborted if the request is cancelled - both are fine
}); });
if (this.oauthStart) { if (this.oauthStart) {
@ -354,56 +355,26 @@ export class MCPConnectionFactory {
/** Check if there's already an ongoing OAuth flow for this flowId */ /** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth'); const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.status === 'PENDING') { // If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
if (existingFlow) {
logger.debug( logger.debug(
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`, `${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
); );
/** Tokens from existing flow to complete */ try {
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth'); if (existingFlow.status === 'PENDING') {
if (typeof this.oauthEnd === 'function') { await this.flowManager.failFlow(
await this.oauthEnd(); flowId,
} 'mcp_oauth',
logger.info( new Error('Cancelled for new OAuth request'),
`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`, );
); } else {
/** Client information from the existing flow metadata */
const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
const clientInfo = existingMetadata?.clientInfo;
return { tokens, clientInfo };
}
// Clean up old completed/failed flows, but only if they're actually stale
// This prevents race conditions where we delete a flow that's still being processed
if (existingFlow && existingFlow.status !== 'PENDING') {
const STALE_FLOW_THRESHOLD = 2 * 60 * 1000; // 2 minutes
const { isStale, age, status } = await this.flowManager.isFlowStale(
flowId,
'mcp_oauth',
STALE_FLOW_THRESHOLD,
);
if (isStale) {
try {
await this.flowManager.deleteFlow(flowId, 'mcp_oauth'); await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
logger.debug(
`${this.logPrefix} Cleared stale ${status} OAuth flow (age: ${Math.round(age / 1000)}s)`,
);
} catch (error) {
logger.warn(`${this.logPrefix} Failed to clear stale OAuth flow`, error);
}
} else {
logger.debug(
`${this.logPrefix} Skipping cleanup of recent ${status} flow (age: ${Math.round(age / 1000)}s, threshold: ${STALE_FLOW_THRESHOLD / 1000}s)`,
);
// If flow is recent but not pending, something might be wrong
if (status === 'FAILED') {
logger.warn(
`${this.logPrefix} Recent OAuth flow failed, will retry after ${Math.round((STALE_FLOW_THRESHOLD - age) / 1000)}s`,
);
} }
} catch (error) {
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
} }
// Continue to start a new flow below
} }
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`); logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);