diff --git a/.env.example b/.env.example index b851749baf..e746737ea4 100644 --- a/.env.example +++ b/.env.example @@ -850,3 +850,24 @@ OPENWEATHER_API_KEY= # Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it) # When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration # MCP_SKIP_CODE_CHALLENGE_CHECK=false + +# Circuit breaker: max connect/disconnect cycles before tripping (per server) +# MCP_CB_MAX_CYCLES=7 + +# Circuit breaker: sliding window (ms) for counting cycles +# MCP_CB_CYCLE_WINDOW_MS=45000 + +# Circuit breaker: cooldown (ms) after the cycle breaker trips +# MCP_CB_CYCLE_COOLDOWN_MS=15000 + +# Circuit breaker: max consecutive failed connection rounds before backoff +# MCP_CB_MAX_FAILED_ROUNDS=3 + +# Circuit breaker: sliding window (ms) for counting failed rounds +# MCP_CB_FAILED_WINDOW_MS=120000 + +# Circuit breaker: base backoff (ms) after failed round threshold is reached +# MCP_CB_BASE_BACKOFF_MS=30000 + +# Circuit breaker: max backoff cap (ms) for exponential backoff +# MCP_CB_MAX_BACKOFF_MS=300000 diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index e87fcf8f15..009b602604 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => { getFlowState: jest.fn(), completeOAuthFlow: jest.fn(), generateFlowId: jest.fn(), + resolveStateToFlowId: jest.fn(async (state) => state), + storeStateMapping: jest.fn(), + deleteStateMapping: jest.fn(), }, MCPTokenStorage: { storeTokens: jest.fn(), @@ -180,7 +183,10 @@ describe('MCP Routes', () => { MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', flowId: 'test-user-id:test-server', + flowMetadata: { state: 'random-state-value' }, }); + MCPOAuthHandler.storeStateMapping.mockResolvedValue(); + mockFlowManager.initFlow = jest.fn().mockResolvedValue(); const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', @@ -367,6 +373,121 @@ describe('MCP Routes', () => { expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`); }); + describe('CSRF fallback via active PENDING flow', () => { + it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'PENDING', + createdAt: Date.now(), + }), + completeFlow: jest.fn().mockResolvedValue(true), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: {}, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({ + access_token: 'test-token', + }); + MCPTokenStorage.storeTokens.mockResolvedValue(); + mockRegistryInstance.getServerConfig.mockResolvedValue({}); + + const mockMcpManager = { + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }; + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue(); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toContain(`${basePath}/oauth/success`); + }); + + it('should reject when no PENDING flow exists and no cookies are present', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue(null), + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should reject when only a COMPLETED flow exists (not PENDING)', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'COMPLETED', + createdAt: Date.now(), + }), + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'PENDING', + createdAt: Date.now() - 3 * 60 * 1000, + }), + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + }); + it('should handle OAuth callback successfully', async () => { // mockRegistryInstance is defined at the top of the file const mockFlowManager = { diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 2db8c2c462..0afac81192 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -13,6 +13,7 @@ const { MCPOAuthHandler, MCPTokenStorage, setOAuthSession, + PENDING_STALE_MS, getUserMCPAuthMap, validateOAuthCsrf, OAUTH_CSRF_COOKIE, @@ -91,7 +92,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async } const oauthHeaders = await getOAuthHeaders(serverName, userId); - const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow( + const { + authorizationUrl, + flowId: oauthFlowId, + flowMetadata, + } = await MCPOAuthHandler.initiateOAuthFlow( serverName, serverUrl, userId, @@ -101,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); + await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager); setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH); res.redirect(authorizationUrl); } catch (error) { @@ -143,30 +149,52 @@ router.get('/:serverName/oauth/callback', async (req, res) => { return res.redirect(`${basePath}/oauth/error?error=missing_state`); } - const flowId = state; - logger.debug('[MCP OAuth] Using flow ID from state', { flowId }); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager); + if (!flowId) { + logger.error('[MCP OAuth] Could not resolve state to flow ID', { state }); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); + } + logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId }); const flowParts = flowId.split(':'); if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) { - logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId }); + logger.error('[MCP OAuth] Invalid flow ID format', { flowId }); return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } const [flowUserId] = flowParts; - if ( - !validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) && - !validateOAuthSession(req, flowUserId) - ) { - logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', { - flowId, - hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], - hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], - }); - return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + + const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH); + const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId); + let hasActiveFlow = false; + if (!hasCsrf && !hasSession) { + const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); + const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity; + hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS; + if (hasActiveFlow) { + logger.debug( + '[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow', + { + flowId, + }, + ); + } } - const flowsCache = getLogStores(CacheKeys.FLOWS); - const flowManager = getFlowStateManager(flowsCache); + if (!hasCsrf && !hasSession && !hasActiveFlow) { + logger.error( + '[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow', + { + flowId, + hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], + hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], + }, + ); + return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + } logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId); const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager); @@ -281,7 +309,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const toolFlowId = flowState.metadata?.toolFlowId; if (toolFlowId) { logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId }); - await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); + const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); + if (!completed) { + logger.warn( + '[MCP OAuth] Tool flow state not found during completion — waiter will time out', + { toolFlowId }, + ); + } } /** Redirect to success page with flowId and serverName */ diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index 530150a7fa..df9cf6bcc2 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -7,6 +7,7 @@ export default { '\\.dev\\.ts$', '\\.helper\\.ts$', '\\.helper\\.d\\.ts$', + '/__tests__/helpers/', ], coverageReporters: ['text', 'cobertura'], testResultsProcessor: 'jest-junit', diff --git a/packages/api/package.json b/packages/api/package.json index e4ca4ef3c5..46587797a5 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,8 +18,8 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", + "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", "test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", "test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index 4f9023a3d7..b68b9edb7a 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -3,6 +3,18 @@ import { logger } from '@librechat/data-schemas'; import type { StoredDataNoRaw } from 'keyv'; import type { FlowState, FlowMetadata, FlowManagerOptions } from './types'; +export const PENDING_STALE_MS = 2 * 60 * 1000; + +const SECONDS_THRESHOLD = 1e10; + +/** + * Normalizes an expiration timestamp to milliseconds. + * Timestamps below 10 billion are assumed to be in seconds (valid until ~2286). + */ +export function normalizeExpiresAt(timestamp: number): number { + return timestamp < SECONDS_THRESHOLD ? timestamp * 1000 : timestamp; +} + export class FlowStateManager { private keyv: Keyv; private ttl: number; @@ -45,32 +57,8 @@ export class FlowStateManager { return `${type}:${flowId}`; } - /** - * Normalizes an expiration timestamp to milliseconds. - * Detects whether the input is in seconds or milliseconds based on magnitude. - * Timestamps below 10 billion are assumed to be in seconds (valid until ~2286). - * @param timestamp - The expiration timestamp (in seconds or milliseconds) - * @returns The timestamp normalized to milliseconds - */ - private normalizeExpirationTimestamp(timestamp: number): number { - const SECONDS_THRESHOLD = 1e10; - if (timestamp < SECONDS_THRESHOLD) { - return timestamp * 1000; - } - return timestamp; - } - - /** - * Checks if a flow's token has expired based on its expires_at field - * @param flowState - The flow state to check - * @returns true if the token has expired, false otherwise (including if no expires_at exists) - */ private isTokenExpired(flowState: FlowState | undefined): boolean { - if (!flowState?.result) { - return false; - } - - if (typeof flowState.result !== 'object') { + if (!flowState?.result || typeof flowState.result !== 'object') { return false; } @@ -79,13 +67,11 @@ export class FlowStateManager { } const expiresAt = (flowState.result as { expires_at: unknown }).expires_at; - if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) { return false; } - const normalizedExpiresAt = this.normalizeExpirationTimestamp(expiresAt); - return normalizedExpiresAt < Date.now(); + return normalizeExpiresAt(expiresAt) < Date.now(); } /** @@ -149,6 +135,8 @@ export class FlowStateManager { let elapsedTime = 0; let isCleanedUp = false; let intervalId: NodeJS.Timeout | null = null; + let missingStateRetried = false; + let isRetrying = false; // Cleanup function to avoid duplicate cleanup const cleanup = () => { @@ -188,16 +176,29 @@ export class FlowStateManager { } intervalId = setInterval(async () => { - if (isCleanedUp) return; + if (isCleanedUp || isRetrying) return; try { - const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; + let flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; if (!flowState) { - cleanup(); - logger.error(`[${flowKey}] Flow state not found`); - reject(new Error(`${type} Flow state not found`)); - return; + if (!missingStateRetried) { + missingStateRetried = true; + isRetrying = true; + logger.warn( + `[${flowKey}] Flow state not found, retrying once after 500ms (race recovery)`, + ); + await new Promise((r) => setTimeout(r, 500)); + flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; + isRetrying = false; + } + + if (!flowState) { + cleanup(); + logger.error(`[${flowKey}] Flow state not found after retry`); + reject(new Error(`${type} Flow state not found`)); + return; + } } if (signal?.aborted) { @@ -251,10 +252,10 @@ export class FlowStateManager { const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; if (!flowState) { - logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', { - flowId, - type, - }); + logger.warn( + '[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping', + { flowId, type }, + ); return false; } @@ -297,7 +298,7 @@ export class FlowStateManager { async isFlowStale( flowId: string, type: string, - staleThresholdMs: number = 2 * 60 * 1000, + staleThresholdMs: number = PENDING_STALE_MS, ): Promise<{ isStale: boolean; age: number; status?: string }> { const flowKey = this.getFlowKey(flowId, type); const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 03131b659b..0fc86e0315 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -2,11 +2,11 @@ import { logger } from '@librechat/data-schemas'; import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; import type { Tool } from '@modelcontextprotocol/sdk/types.js'; import type { TokenMethods } from '@librechat/data-schemas'; -import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth'; +import type { MCPOAuthTokens, OAuthMetadata, MCPOAuthFlowMetadata } from '~/mcp/oauth'; import type { FlowStateManager } from '~/flow/manager'; -import type { FlowMetadata } from '~/flow/types'; import type * as t from './types'; -import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { PENDING_STALE_MS, normalizeExpiresAt } from '~/flow/manager'; import { sanitizeUrlForLogging } from './utils'; import { withTimeout } from '~/utils/promise'; import { MCPConnection } from './connection'; @@ -104,6 +104,7 @@ export class MCPConnectionFactory { return { tools, connection, oauthRequired: false, oauthUrl: null }; } } catch { + MCPConnection.decrementCycleCount(this.serverName); logger.debug( `${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`, ); @@ -125,7 +126,9 @@ export class MCPConnectionFactory { } return { tools, connection: null, oauthRequired, oauthUrl }; } + MCPConnection.decrementCycleCount(this.serverName); } catch (listError) { + MCPConnection.decrementCycleCount(this.serverName); logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError); } @@ -265,6 +268,10 @@ export class MCPConnectionFactory { if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`); return tokens; } catch (error) { + if (error instanceof ReauthenticationRequiredError) { + logger.info(`${this.logPrefix} ${error.message}, will trigger OAuth flow`); + return null; + } logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error); return null; } @@ -306,11 +313,21 @@ export class MCPConnectionFactory { const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth'); if (existingFlow?.status === 'PENDING') { + const pendingAge = existingFlow.createdAt + ? Date.now() - existingFlow.createdAt + : Infinity; + + if (pendingAge < PENDING_STALE_MS) { + logger.debug( + `${this.logPrefix} Recent PENDING OAuth flow exists (${Math.round(pendingAge / 1000)}s old), skipping new initiation`, + ); + connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); + return; + } + logger.debug( - `${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`, + `${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will replace`, ); - connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); - return; } const { @@ -326,11 +343,17 @@ export class MCPConnectionFactory { ); if (existingFlow) { + const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state; await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth'); + if (oldState) { + await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!); + } } // Store flow state BEFORE redirecting so the callback can find it - await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata); + const metadataWithUrl = { ...flowMetadata, authorizationUrl }; + await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl); + await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager!); // Start monitoring in background — createFlow will find the existing PENDING state // written by initFlow above, so metadata arg is unused (pass {} to make that explicit) @@ -495,11 +518,75 @@ export class MCPConnectionFactory { const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth'); if (existingFlow) { + const flowMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined; + + if (existingFlow.status === 'PENDING') { + const pendingAge = existingFlow.createdAt + ? Date.now() - existingFlow.createdAt + : Infinity; + + if (pendingAge < PENDING_STALE_MS) { + logger.debug( + `${this.logPrefix} Found recent PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), joining instead of creating new one`, + ); + + const storedAuthUrl = flowMeta?.authorizationUrl; + if (storedAuthUrl && typeof this.oauthStart === 'function') { + logger.info( + `${this.logPrefix} Re-issuing stored authorization URL to caller while joining PENDING flow`, + ); + await this.oauthStart(storedAuthUrl); + } + + const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal); + if (typeof this.oauthEnd === 'function') { + await this.oauthEnd(); + } + logger.info( + `${this.logPrefix} Joined existing OAuth flow completed for ${this.serverName}`, + ); + return { + tokens, + clientInfo: flowMeta?.clientInfo, + metadata: flowMeta?.metadata, + }; + } + + logger.debug( + `${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will delete and start fresh`, + ); + } + + if (existingFlow.status === 'COMPLETED') { + const completedAge = existingFlow.completedAt + ? Date.now() - existingFlow.completedAt + : Infinity; + const cachedTokens = existingFlow.result as MCPOAuthTokens | null | undefined; + const isTokenExpired = + cachedTokens?.expires_at != null && + normalizeExpiresAt(cachedTokens.expires_at) < Date.now(); + + if (completedAge <= PENDING_STALE_MS && cachedTokens !== undefined && !isTokenExpired) { + logger.debug( + `${this.logPrefix} Found non-stale COMPLETED OAuth flow, reusing cached tokens`, + ); + return { + tokens: cachedTokens, + clientInfo: flowMeta?.clientInfo, + metadata: flowMeta?.metadata, + }; + } + } + logger.debug( `${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`, ); try { + const oldState = flowMeta?.state; await this.flowManager.deleteFlow(flowId, 'mcp_oauth'); + if (oldState) { + await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager); + } } catch (error) { logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error); } @@ -519,7 +606,9 @@ export class MCPConnectionFactory { ); // Store flow state BEFORE redirecting so the callback can find it - await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata); + const metadataWithUrl = { ...flowMetadata, authorizationUrl }; + await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl); + await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager); if (typeof this.oauthStart === 'function') { logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`); diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 0828b1720a..76523fc0fc 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -1,10 +1,10 @@ import { logger } from '@librechat/data-schemas'; import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; -import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; -import { MCPConnection } from './connection'; import type * as t from './types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPConnection } from './connection'; import { mcpConfig } from './mcpConfig'; /** @@ -21,6 +21,8 @@ export abstract class UserConnectionManager { protected userConnections: Map> = new Map(); /** Last activity timestamp for users (not per server) */ protected userLastActivity: Map = new Map(); + /** In-flight connection promises keyed by `userId:serverName` — coalesces concurrent attempts */ + protected pendingConnections: Map> = new Map(); /** Updates the last activity timestamp for a user */ protected updateUserLastActivity(userId: string): void { @@ -31,29 +33,64 @@ export abstract class UserConnectionManager { ); } - /** Gets or creates a connection for a specific user */ - public async getUserConnection({ - serverName, - forceNew, - user, - flowManager, - customUserVars, - requestBody, - tokenMethods, - oauthStart, - oauthEnd, - signal, - returnOnOAuth = false, - connectionTimeout, - }: { - serverName: string; - forceNew?: boolean; - } & Omit): Promise { + /** Gets or creates a connection for a specific user, coalescing concurrent attempts */ + public async getUserConnection( + opts: { + serverName: string; + forceNew?: boolean; + } & Omit, + ): Promise { + const { serverName, forceNew, user } = opts; const userId = user?.id; if (!userId) { throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); } + const lockKey = `${userId}:${serverName}`; + + if (!forceNew) { + const pending = this.pendingConnections.get(lockKey); + if (pending) { + logger.debug(`[MCP][User: ${userId}][${serverName}] Joining in-flight connection attempt`); + return pending; + } + } + + const connectionPromise = this.createUserConnectionInternal(opts, userId); + + if (!forceNew) { + this.pendingConnections.set(lockKey, connectionPromise); + } + + try { + return await connectionPromise; + } finally { + if (!forceNew && this.pendingConnections.get(lockKey) === connectionPromise) { + this.pendingConnections.delete(lockKey); + } + } + } + + private async createUserConnectionInternal( + { + serverName, + forceNew, + user, + flowManager, + customUserVars, + requestBody, + tokenMethods, + oauthStart, + oauthEnd, + signal, + returnOnOAuth = false, + connectionTimeout, + }: { + serverName: string; + forceNew?: boolean; + } & Omit, + userId: string, + ): Promise { if (await this.appConnections!.has(serverName)) { throw new McpError( ErrorCode.InvalidRequest, @@ -188,6 +225,7 @@ export abstract class UserConnectionManager { /** Disconnects and removes a specific user connection */ public async disconnectUserConnection(userId: string, serverName: string): Promise { + this.pendingConnections.delete(`${userId}:${serverName}`); const userMap = this.userConnections.get(userId); const connection = userMap?.get(serverName); if (connection) { @@ -215,6 +253,12 @@ export abstract class UserConnectionManager { ); } await Promise.allSettled(disconnectPromises); + // Clean up any pending connection promises for this user + for (const key of this.pendingConnections.keys()) { + if (key.startsWith(`${userId}:`)) { + this.pendingConnections.delete(key); + } + } // Ensure user activity timestamp is removed this.userLastActivity.delete(userId); logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index de18e27e89..bceb23b246 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -275,7 +275,7 @@ describe('MCPConnectionFactory', () => { expect(mockFlowManager.initFlow).toHaveBeenCalledWith( 'flow123', 'mcp_oauth', - mockFlowData.flowMetadata, + expect.objectContaining(mockFlowData.flowMetadata), ); const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock @@ -550,7 +550,7 @@ describe('MCPConnectionFactory', () => { expect(mockFlowManager.initFlow).toHaveBeenCalledWith( 'flow123', 'mcp_oauth', - mockFlowData.flowMetadata, + expect.objectContaining(mockFlowData.flowMetadata), ); const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock diff --git a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts new file mode 100644 index 0000000000..cdba06cf8d --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts @@ -0,0 +1,232 @@ +/** + * Tests for the OAuth callback CSRF fallback logic. + * + * The callback route validates requests via three mechanisms (in order): + * 1. CSRF cookie (HMAC-based, set during initiate) + * 2. Session cookie (bound to authenticated userId) + * 3. Active PENDING flow in FlowStateManager (fallback for SSE/chat flows) + * + * This suite tests mechanism 3 — the PENDING flow fallback — including + * staleness enforcement and rejection of non-PENDING flows. + * + * These tests exercise the validation functions directly for fast, + * focused coverage. Route-level integration tests using supertest + * are in api/server/routes/__tests__/mcp.spec.js ("CSRF fallback + * via active PENDING flow" describe block). + */ + +import { Keyv } from 'keyv'; +import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager'; +import type { Request, Response } from 'express'; +import { + generateOAuthCsrfToken, + OAUTH_SESSION_COOKIE, + validateOAuthSession, + OAUTH_CSRF_COOKIE, + validateOAuthCsrf, +} from '~/oauth/csrf'; +import { MockKeyv } from './helpers/oauthTestServer'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +const CSRF_COOKIE_PATH = '/api/mcp'; + +function makeReq(cookies: Record = {}): Request { + return { cookies } as unknown as Request; +} + +function makeRes(): Response { + const res = { + clearCookie: jest.fn(), + } as unknown as Response; + return res; +} + +/** + * Replicate the callback route's three-tier validation logic. + * Returns which mechanism (if any) authorized the request. + */ +async function validateCallback( + req: Request, + res: Response, + flowId: string, + flowManager: FlowStateManager, +): Promise<'csrf' | 'session' | 'pendingFlow' | false> { + const flowUserId = flowId.split(':')[0]; + + const hasCsrf = validateOAuthCsrf(req, res, flowId, CSRF_COOKIE_PATH); + if (hasCsrf) { + return 'csrf'; + } + + const hasSession = validateOAuthSession(req, flowUserId); + if (hasSession) { + return 'session'; + } + + const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); + const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity; + if (pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS) { + return 'pendingFlow'; + } + + return false; +} + +describe('OAuth Callback CSRF Fallback', () => { + let flowManager: FlowStateManager; + + beforeEach(() => { + process.env.JWT_SECRET = 'test-secret-for-csrf'; + const store = new MockKeyv(); + flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 300000, ci: true }); + }); + + afterEach(() => { + delete process.env.JWT_SECRET; + jest.clearAllMocks(); + }); + + describe('CSRF cookie validation (mechanism 1)', () => { + it('should accept valid CSRF cookie', async () => { + const flowId = 'user1:test-server'; + const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('csrf'); + }); + + it('should reject invalid CSRF cookie', async () => { + const flowId = 'user1:test-server'; + const req = makeReq({ [OAUTH_CSRF_COOKIE]: 'wrong-token-value' }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + }); + + describe('Session cookie validation (mechanism 2)', () => { + it('should accept valid session cookie when CSRF is absent', async () => { + const flowId = 'user1:test-server'; + const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('session'); + }); + }); + + describe('PENDING flow fallback (mechanism 3)', () => { + it('should accept when a fresh PENDING flow exists and no cookies are present', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('pendingFlow'); + }); + + it('should reject when no PENDING flow, no CSRF cookie, and no session cookie', async () => { + const flowId = 'user1:test-server'; + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should reject when only a COMPLETED flow exists (not PENDING)', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + await flowManager.completeFlow(flowId, 'mcp_oauth', { access_token: 'tok' } as never); + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should reject when only a FAILED flow exists', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + await flowManager.failFlow(flowId, 'mcp_oauth', 'some error'); + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + // Artificially age the flow past the staleness threshold + const store = (flowManager as unknown as { keyv: { get: (k: string) => Promise } }) + .keyv; + const flowState = (await store.get(`mcp_oauth:${flowId}`)) as { createdAt: number }; + flowState.createdAt = Date.now() - PENDING_STALE_MS - 1000; + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should accept PENDING flow that is just under the staleness threshold', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + // Flow was just created, well under threshold + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('pendingFlow'); + }); + }); + + describe('Priority ordering', () => { + it('should prefer CSRF cookie over PENDING flow', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('csrf'); + }); + + it('should prefer session cookie over PENDING flow when CSRF is absent', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('session'); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts new file mode 100644 index 0000000000..4e168d00f3 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts @@ -0,0 +1,268 @@ +/** + * Tests for MCPConnection OAuth event cycle against a real OAuth-gated MCP server. + * + * Verifies: oauthRequired emission on 401, oauthHandled reconnection, + * oauthFailed rejection, timeout behavior, and token expiry mid-session. + */ + +import { MCPConnection } from '~/mcp/connection'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +jest.mock('~/auth', () => ({ + createSSRFSafeUndiciConnect: jest.fn(() => undefined), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +jest.mock('~/mcp/mcpConfig', () => ({ + mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 }, +})); + +async function safeDisconnect(conn: MCPConnection | null): Promise { + if (!conn) { + return; + } + try { + await conn.disconnect(); + } catch { + // Ignore disconnect errors during cleanup + } +} + +async function exchangeCodeForToken(serverUrl: string): Promise { + const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, { + redirect: 'manual', + }); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${serverUrl}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const data = (await tokenRes.json()) as { access_token: string }; + return data.access_token; +} + +describe('MCPConnection OAuth Events — Real Server', () => { + let server: OAuthTestServer; + let connection: MCPConnection | null = null; + + beforeEach(() => { + MCPConnection.clearCooldown('test-server'); + }); + + afterEach(async () => { + await safeDisconnect(connection); + connection = null; + if (server) { + await server.close(); + } + jest.clearAllMocks(); + }); + + describe('oauthRequired event', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should emit oauthRequired when connecting without a token', async () => { + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { type: 'streamable-http', url: server.url }, + userId: 'user-1', + }); + + const oauthRequiredPromise = new Promise<{ + serverName: string; + error: Error; + serverUrl?: string; + userId?: string; + }>((resolve) => { + connection!.on('oauthRequired', (data) => { + resolve( + data as { + serverName: string; + error: Error; + serverUrl?: string; + userId?: string; + }, + ); + }); + }); + + // Connection will fail with 401, emitting oauthRequired + const connectPromise = connection.connect().catch(() => { + // Expected to fail since no one handles oauthRequired + }); + + let raceTimer: NodeJS.Timeout | undefined; + const eventData = await Promise.race([ + oauthRequiredPromise, + new Promise((_, reject) => { + raceTimer = setTimeout( + () => reject(new Error('Timed out waiting for oauthRequired')), + 10000, + ); + }), + ]).finally(() => clearTimeout(raceTimer)); + + expect(eventData.serverName).toBe('test-server'); + expect(eventData.error).toBeDefined(); + + // Emit oauthFailed to unblock connect() + connection.emit('oauthFailed', new Error('test cleanup')); + await connectPromise.catch(() => undefined); + }); + + it('should not emit oauthRequired when connecting with a valid token', async () => { + const accessToken = await exchangeCodeForToken(server.url); + + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { type: 'streamable-http', url: server.url }, + userId: 'user-1', + oauthTokens: { + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens, + }); + + let oauthFired = false; + connection.on('oauthRequired', () => { + oauthFired = true; + }); + + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + expect(oauthFired).toBe(false); + }); + }); + + describe('oauthHandled reconnection', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should succeed on retry after oauthHandled provides valid tokens', async () => { + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + }, + userId: 'user-1', + }); + + // First connect fails with 401 → oauthRequired fires + let oauthFired = false; + connection.on('oauthRequired', () => { + oauthFired = true; + connection!.emit('oauthFailed', new Error('Will retry with tokens')); + }); + + // First attempt fails as expected + await expect(connection.connect()).rejects.toThrow(); + expect(oauthFired).toBe(true); + + // Now set valid tokens and reconnect + const accessToken = await exchangeCodeForToken(server.url); + connection.setOAuthTokens({ + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + }); + }); + + describe('oauthFailed rejection', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should reject connect() when oauthFailed is emitted', async () => { + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + }, + userId: 'user-1', + }); + + connection.on('oauthRequired', () => { + connection!.emit('oauthFailed', new Error('User denied OAuth')); + }); + + await expect(connection.connect()).rejects.toThrow(); + }); + }); + + describe('Token expiry during session', () => { + it('should detect expired token on reconnect and emit oauthRequired', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 1000 }); + + const accessToken = await exchangeCodeForToken(server.url); + + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + }, + userId: 'user-1', + oauthTokens: { + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens, + }); + + // Initial connect should succeed + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + await connection.disconnect(); + + // Wait for token to expire + await new Promise((r) => setTimeout(r, 1200)); + + // Reconnect should trigger oauthRequired since token is expired on the server + let oauthFired = false; + connection.on('oauthRequired', () => { + oauthFired = true; + connection!.emit('oauthFailed', new Error('Will retry with fresh token')); + }); + + // First reconnect fails with 401 → oauthRequired + await expect(connection.connect()).rejects.toThrow(); + expect(oauthFired).toBe(true); + + // Get fresh token and reconnect + const newToken = await exchangeCodeForToken(server.url); + connection.setOAuthTokens({ + access_token: newToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts new file mode 100644 index 0000000000..8437177c86 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -0,0 +1,538 @@ +/** + * OAuth flow tests against a real HTTP server. + * + * Tests MCPOAuthHandler.refreshOAuthTokens and MCPTokenStorage lifecycle + * using a real test OAuth server (not mocked SDK functions). + */ + +import { createHash } from 'crypto'; +import { Keyv } from 'keyv'; +import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; +import { FlowStateManager } from '~/flow/manager'; +import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +describe('MCP OAuth Flow — Real HTTP Server', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Token refresh with real server', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should refresh tokens with stored client info via real /token endpoint', async () => { + // First get initial tokens + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + expect(initial.refresh_token).toBeDefined(); + + // Register a client so we have clientInfo + const regRes = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const clientInfo = (await regRes.json()) as { + client_id: string; + client_secret: string; + }; + + // Refresh tokens using the real endpoint + const refreshed = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'test-server', + serverUrl: server.url, + clientInfo: { + ...clientInfo, + redirect_uris: ['http://localhost/callback'], + }, + }, + {}, + { + token_url: `${server.url}token`, + client_id: clientInfo.client_id, + client_secret: clientInfo.client_secret, + token_exchange_method: 'DefaultPost', + }, + ); + + expect(refreshed.access_token).toBeDefined(); + expect(refreshed.access_token).not.toBe(initial.access_token); + expect(refreshed.token_type).toBe('Bearer'); + expect(refreshed.obtained_at).toBeDefined(); + }); + + it('should get new refresh token when server rotates', async () => { + const rotatingServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + rotateRefreshTokens: true, + }); + + try { + const code = await rotatingServer.getAuthCode(); + const tokenRes = await fetch(`${rotatingServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const refreshed = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'test-server', + serverUrl: rotatingServer.url, + }, + {}, + { + token_url: `${rotatingServer.url}token`, + client_id: 'anon', + token_exchange_method: 'DefaultPost', + }, + ); + + expect(refreshed.access_token).not.toBe(initial.access_token); + expect(refreshed.refresh_token).toBeDefined(); + expect(refreshed.refresh_token).not.toBe(initial.refresh_token); + } finally { + await rotatingServer.close(); + } + }); + + it('should fail refresh with invalid refresh token', async () => { + await expect( + MCPOAuthHandler.refreshOAuthTokens( + 'invalid-refresh-token', + { + serverName: 'test-server', + serverUrl: server.url, + }, + {}, + { + token_url: `${server.url}token`, + client_id: 'anon', + token_exchange_method: 'DefaultPost', + }, + ), + ).rejects.toThrow(); + }); + }); + + describe('OAuth server metadata discovery', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ issueRefreshTokens: true }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should expose /.well-known/oauth-authorization-server', async () => { + const res = await fetch(`${server.url}.well-known/oauth-authorization-server`); + expect(res.status).toBe(200); + + const metadata = (await res.json()) as { + authorization_endpoint: string; + token_endpoint: string; + registration_endpoint: string; + grant_types_supported: string[]; + }; + + expect(metadata.authorization_endpoint).toContain('/authorize'); + expect(metadata.token_endpoint).toContain('/token'); + expect(metadata.registration_endpoint).toContain('/register'); + expect(metadata.grant_types_supported).toContain('authorization_code'); + expect(metadata.grant_types_supported).toContain('refresh_token'); + }); + + it('should not advertise refresh_token grant when disabled', async () => { + const noRefreshServer = await createOAuthMCPServer({ + issueRefreshTokens: false, + }); + try { + const res = await fetch(`${noRefreshServer.url}.well-known/oauth-authorization-server`); + const metadata = (await res.json()) as { grant_types_supported: string[] }; + expect(metadata.grant_types_supported).not.toContain('refresh_token'); + } finally { + await noRefreshServer.close(); + } + }); + }); + + describe('Dynamic client registration', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer(); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should register a client via /register endpoint', async () => { + const res = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + redirect_uris: ['http://localhost/callback'], + }), + }); + + expect(res.status).toBe(200); + const client = (await res.json()) as { + client_id: string; + client_secret: string; + redirect_uris: string[]; + }; + + expect(client.client_id).toBeDefined(); + expect(client.client_secret).toBeDefined(); + expect(client.redirect_uris).toEqual(['http://localhost/callback']); + expect(server.registeredClients.has(client.client_id)).toBe(true); + }); + }); + + describe('End-to-End: store, retrieve, expire, refresh cycle', () => { + it('should perform full token lifecycle with real server', async () => { + const server = await createOAuthMCPServer({ + tokenTTLMs: 1000, + issueRefreshTokens: true, + }); + const tokenStore = new InMemoryTokenStore(); + + try { + // 1. Get initial tokens via auth code exchange + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // 2. Store tokens + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'test-srv', + tokens: initial, + createToken: tokenStore.createToken, + }); + + // 3. Retrieve — should succeed + const valid = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }); + expect(valid).not.toBeNull(); + expect(valid!.access_token).toBe(initial.access_token); + expect(valid!.refresh_token).toBe(initial.refresh_token); + + // 4. Wait for expiry + await new Promise((r) => setTimeout(r, 1200)); + + // 5. Retrieve again — should trigger refresh via callback + const refreshCallback = async (refreshToken: string): Promise => { + const refreshRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + + if (!refreshRes.ok) { + throw new Error(`Refresh failed: ${refreshRes.status}`); + } + + const data = (await refreshRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token?: string; + }; + + return { + ...data, + obtained_at: Date.now(), + expires_at: Date.now() + data.expires_in * 1000, + }; + }; + + const refreshed = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(refreshed).not.toBeNull(); + expect(refreshed!.access_token).not.toBe(initial.access_token); + + // 6. Verify the refreshed token works against the server + const mcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${refreshed!.access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + expect(mcpRes.status).toBe(200); + } finally { + await server.close(); + } + }); + }); + + describe('completeOAuthFlow via FlowStateManager', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ issueRefreshTokens: true }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should exchange auth code and complete flow in FlowStateManager', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'test-user:test-server'; + const code = await server.getAuthCode(); + + // Initialize the flow with metadata the handler needs + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverUrl: server.url, + clientInfo: { + client_id: 'test-client', + redirect_uris: ['http://localhost/callback'], + }, + codeVerifier: 'test-verifier', + metadata: { + token_endpoint: `${server.url}token`, + token_endpoint_auth_methods_supported: ['client_secret_post'], + }, + }); + + // The SDK's exchangeAuthorization wants full OAuth metadata, + // so we'll test the token exchange directly instead of going through + // completeOAuthFlow (which requires full SDK-compatible metadata) + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + const tokens = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token?: string; + }; + + const mcpTokens: MCPOAuthTokens = { + ...tokens, + obtained_at: Date.now(), + expires_at: Date.now() + tokens.expires_in * 1000, + }; + + // Complete the flow + const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens); + expect(completed).toBe(true); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('COMPLETED'); + expect(state?.result?.access_token).toBe(tokens.access_token); + }); + + it('should fail flow when authorization code is invalid', async () => { + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: 'grant_type=authorization_code&code=invalid-code', + }); + + expect(tokenRes.status).toBe(400); + const body = (await tokenRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + + it('should fail when authorization code is reused', async () => { + const code = await server.getAuthCode(); + + // First exchange succeeds + const firstRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + expect(firstRes.status).toBe(200); + + // Second exchange fails + const secondRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + expect(secondRes.status).toBe(400); + const body = (await secondRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + }); + + describe('PKCE verification', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + afterEach(async () => { + await server.close(); + }); + + function generatePKCE(): { verifier: string; challenge: string } { + const verifier = 'dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk'; + const challenge = createHash('sha256').update(verifier).digest('base64url'); + return { verifier, challenge }; + } + + it('should accept valid code_verifier matching code_challenge', async () => { + const { verifier, challenge } = generatePKCE(); + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&code_verifier=${verifier}`, + }); + + expect(tokenRes.status).toBe(200); + const data = (await tokenRes.json()) as { access_token: string }; + expect(data.access_token).toBeDefined(); + }); + + it('should reject wrong code_verifier', async () => { + const { challenge } = generatePKCE(); + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&code_verifier=wrong-verifier`, + }); + + expect(tokenRes.status).toBe(400); + const body = (await tokenRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + + it('should reject missing code_verifier when code_challenge was provided', async () => { + const { challenge } = generatePKCE(); + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + expect(tokenRes.status).toBe(400); + const body = (await tokenRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + + it('should still accept codes without PKCE when no code_challenge was provided', async () => { + const code = await server.getAuthCode(); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + expect(tokenRes.status).toBe(200); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts new file mode 100644 index 0000000000..85febb3ece --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -0,0 +1,516 @@ +/** + * Tests for MCP OAuth race condition fixes: + * + * 1. Connection mutex coalesces concurrent getUserConnection() calls + * 2. PENDING OAuth flows are reused, not deleted + * 3. No-refresh-token expiry throws ReauthenticationRequiredError + * 4. completeFlow recovers when flow state was deleted by a race + * 5. monitorFlow retries once when flow state disappears mid-poll + */ + +import { Keyv } from 'keyv'; +import { logger } from '@librechat/data-schemas'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; +import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { MockKeyv, createOAuthMCPServer } from './helpers/oauthTestServer'; +import { FlowStateManager } from '~/flow/manager'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +jest.mock('~/auth', () => ({ + createSSRFSafeUndiciConnect: jest.fn(() => undefined), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +jest.mock('~/mcp/mcpConfig', () => ({ + mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 }, +})); + +const mockLogger = logger as jest.Mocked; + +describe('MCP OAuth Race Condition Fixes', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Fix 1: Connection mutex coalesces concurrent attempts', () => { + it('should return the same pending promise for concurrent getUserConnection calls', async () => { + const { UserConnectionManager } = await import('~/mcp/UserConnectionManager'); + + class TestManager extends UserConnectionManager { + public createCallCount = 0; + + getPendingConnections() { + return this.pendingConnections; + } + } + + const manager = new TestManager(); + + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn().mockResolvedValue(undefined), + isStale: jest.fn().mockReturnValue(false), + }; + + const mockAppConnections = { has: jest.fn().mockResolvedValue(false) }; + manager.appConnections = mockAppConnections as never; + + const mockConfig = { + type: 'streamable-http', + url: 'http://localhost:9999/', + updatedAt: undefined, + dbId: undefined, + }; + + jest + .spyOn( + // eslint-disable-next-line @typescript-eslint/no-require-imports + require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry, + 'getInstance', + ) + .mockReturnValue({ + getServerConfig: jest.fn().mockResolvedValue(mockConfig), + shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + }); + + const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); + const createSpy = jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => { + manager.createCallCount++; + await new Promise((r) => setTimeout(r, 100)); + return mockConnection as never; + }); + + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + const user = { id: 'user-1' }; + const opts = { + serverName: 'test-server', + user: user as never, + flowManager: flowManager as never, + }; + + const [conn1, conn2, conn3] = await Promise.all([ + manager.getUserConnection(opts), + manager.getUserConnection(opts), + manager.getUserConnection(opts), + ]); + + expect(conn1).toBe(conn2); + expect(conn2).toBe(conn3); + expect(createSpy).toHaveBeenCalledTimes(1); + expect(manager.createCallCount).toBe(1); + + createSpy.mockRestore(); + }); + + it('should not coalesce when forceNew is true', async () => { + const { UserConnectionManager } = await import('~/mcp/UserConnectionManager'); + + class TestManager extends UserConnectionManager {} + + const manager = new TestManager(); + + let callCount = 0; + const makeConnection = () => ({ + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn().mockResolvedValue(undefined), + isStale: jest.fn().mockReturnValue(false), + }); + + const mockAppConnections = { has: jest.fn().mockResolvedValue(false) }; + manager.appConnections = mockAppConnections as never; + + const mockConfig = { + type: 'streamable-http', + url: 'http://localhost:9999/', + updatedAt: undefined, + dbId: undefined, + }; + + jest + .spyOn( + // eslint-disable-next-line @typescript-eslint/no-require-imports + require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry, + 'getInstance', + ) + .mockReturnValue({ + getServerConfig: jest.fn().mockResolvedValue(mockConfig), + shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + }); + + const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); + jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => { + callCount++; + await new Promise((r) => setTimeout(r, 50)); + return makeConnection() as never; + }); + + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + const user = { id: 'user-2' }; + + const [conn1, conn2] = await Promise.all([ + manager.getUserConnection({ + serverName: 'test-server', + forceNew: true, + user: user as never, + flowManager: flowManager as never, + }), + manager.getUserConnection({ + serverName: 'test-server', + forceNew: true, + user: user as never, + flowManager: flowManager as never, + }), + ]); + + expect(callCount).toBe(2); + expect(conn1).not.toBe(conn2); + }); + }); + + describe('Fix 2: PENDING flow is reused, not deleted', () => { + it('should join an existing PENDING flow via createFlow instead of deleting it', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'test-flow-pending'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + clientInfo: { client_id: 'test-client' }, + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('PENDING'); + + const deleteSpy = jest.spyOn(flowManager, 'deleteFlow'); + + const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {}); + + await new Promise((r) => setTimeout(r, 500)); + + await flowManager.completeFlow(flowId, 'mcp_oauth', { + access_token: 'test-token', + token_type: 'Bearer', + } as never); + + const result = await monitorPromise; + expect(result).toEqual( + expect.objectContaining({ access_token: 'test-token', token_type: 'Bearer' }), + ); + expect(deleteSpy).not.toHaveBeenCalled(); + + deleteSpy.mockRestore(); + }); + + it('should delete and recreate FAILED flows', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'test-flow-failed'; + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + await flowManager.failFlow(flowId, 'mcp_oauth', 'previous error'); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('FAILED'); + + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + + const afterDelete = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(afterDelete).toBeUndefined(); + }); + }); + + describe('Fix 3: completeFlow handles deleted state gracefully', () => { + it('should return false when state was deleted by race', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'race-deleted-flow'; + + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + + const stateBeforeComplete = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(stateBeforeComplete).toBeUndefined(); + + const result = await flowManager.completeFlow(flowId, 'mcp_oauth', { + access_token: 'recovered-token', + token_type: 'Bearer', + } as never); + + expect(result).toBe(false); + + const stateAfterComplete = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(stateAfterComplete).toBeUndefined(); + + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('cannot recover metadata'), + expect.any(Object), + ); + }); + + it('should reject monitorFlow when state is deleted and not recoverable', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'monitor-retry-flow'; + + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + + const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {}); + + await new Promise((r) => setTimeout(r, 500)); + + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + + await expect(monitorPromise).rejects.toThrow('Flow state not found'); + }); + }); + + describe('State mapping cleanup on flow replacement', () => { + it('should delete old state mapping when a flow is replaced', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'user1:test-server'; + const oldState = 'old-random-state-abc123'; + const newState = 'new-random-state-xyz789'; + + // Simulate initial flow with state mapping + await flowManager.initFlow(flowId, 'mcp_oauth', { state: oldState }); + await MCPOAuthHandler.storeStateMapping(oldState, flowId, flowManager); + + // Old state should resolve + const resolvedBefore = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager); + expect(resolvedBefore).toBe(flowId); + + // Replace the flow: delete old, create new, clean up old state mapping + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + await MCPOAuthHandler.deleteStateMapping(oldState, flowManager); + await flowManager.initFlow(flowId, 'mcp_oauth', { state: newState }); + await MCPOAuthHandler.storeStateMapping(newState, flowId, flowManager); + + // Old state should no longer resolve + const resolvedOld = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager); + expect(resolvedOld).toBeNull(); + + // New state should resolve + const resolvedNew = await MCPOAuthHandler.resolveStateToFlowId(newState, flowManager); + expect(resolvedNew).toBe(flowId); + }); + }); + + describe('Fix 4: ReauthenticationRequiredError for no-refresh-token', () => { + it('should throw ReauthenticationRequiredError when access token expired and no refresh token', async () => { + const expiredDate = new Date(Date.now() - 60000); + + const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => { + if (filter.type === 'mcp_oauth') { + return { + token: 'enc:expired-access-token', + expiresAt: expiredDate, + createdAt: new Date(Date.now() - 120000), + }; + } + if (filter.type === 'mcp_oauth_refresh') { + return null; + } + return null; + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }), + ).rejects.toThrow('Re-authentication required'); + }); + + it('should throw ReauthenticationRequiredError when access token is missing and no refresh token', async () => { + const findToken = jest.fn().mockResolvedValue(null); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should not throw when access token is valid', async () => { + const futureDate = new Date(Date.now() + 3600000); + + const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => { + if (filter.type === 'mcp_oauth') { + return { + token: 'enc:valid-access-token', + expiresAt: futureDate, + createdAt: new Date(), + }; + } + if (filter.type === 'mcp_oauth_refresh') { + return null; + } + return null; + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }); + + expect(result).not.toBeNull(); + expect(result?.access_token).toBe('valid-access-token'); + }); + }); + + describe('E2E: OAuth-gated MCP server with no refresh tokens', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should start OAuth-gated MCP server that validates Bearer tokens', async () => { + const res = await fetch(server.url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'initialize', id: 1 }), + }); + + expect(res.status).toBe(401); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('invalid_token'); + }); + + it('should issue tokens via authorization code exchange with no refresh token', async () => { + const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, { + redirect: 'manual', + }); + + expect(authRes.status).toBe(302); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + expect(code).toBeTruthy(); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + expect(tokenRes.status).toBe(200); + const tokenBody = (await tokenRes.json()) as { + access_token: string; + token_type: string; + refresh_token?: string; + }; + expect(tokenBody.access_token).toBeTruthy(); + expect(tokenBody.token_type).toBe('Bearer'); + expect(tokenBody.refresh_token).toBeUndefined(); + }); + + it('should allow MCP requests with valid Bearer token', async () => { + const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, { + redirect: 'manual', + }); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + const { access_token } = (await tokenRes.json()) as { access_token: string }; + + const mcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + + expect(mcpRes.status).toBe(200); + }); + + it('should reject expired tokens with 401', async () => { + const shortTTLServer = await createOAuthMCPServer({ tokenTTLMs: 500 }); + + try { + const authRes = await fetch( + `${shortTTLServer.url}authorize?redirect_uri=http://localhost&state=s1`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + const tokenRes = await fetch(`${shortTTLServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + const { access_token } = (await tokenRes.json()) as { access_token: string }; + + await new Promise((r) => setTimeout(r, 600)); + + const mcpRes = await fetch(shortTTLServer.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${access_token}`, + }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 2 }), + }); + + expect(mcpRes.status).toBe(401); + } finally { + await shortTTLServer.close(); + } + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts new file mode 100644 index 0000000000..986ac4c8b4 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts @@ -0,0 +1,654 @@ +/** + * Tests for MCP OAuth token expiry → re-authentication scenarios. + * + * Reproduces the edge case where: + * 1. Tokens are stored (access + refresh) + * 2. Access token expires + * 3. Refresh attempt fails (server rejects/revokes refresh token) + * 4. System must fall back to full OAuth re-auth via handleOAuthRequired + * 5. The CSRF cookie may be absent (chat/SSE flow), so the PENDING flow fallback is needed + * + * Also tests the happy path: access token expired but refresh succeeds. + */ + +import { Keyv } from 'keyv'; +import { logger } from '@librechat/data-schemas'; +import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager'; +import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { MockKeyv, InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +describe('MCP OAuth Token Expiry Scenarios', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Access token expired + refresh token available + refresh succeeds', () => { + let server: OAuthTestServer; + let tokenStore: InMemoryTokenStore; + + beforeEach(async () => { + server = await createOAuthMCPServer({ + tokenTTLMs: 500, + issueRefreshTokens: true, + }); + tokenStore = new InMemoryTokenStore(); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should refresh expired access token via real /token endpoint', async () => { + // Get initial tokens from real server + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // Store expired access token directly (bypassing storeTokens' expiresIn clamping) + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: `enc:${initial.access_token}`, + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: `enc:${initial.refresh_token}`, + expiresIn: 86400, + }); + + const refreshCallback = async (refreshToken: string): Promise => { + const res = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + if (!res.ok) { + throw new Error(`Refresh failed: ${res.status}`); + } + const data = (await res.json()) as { + access_token: string; + token_type: string; + expires_in: number; + }; + return { + ...data, + obtained_at: Date.now(), + expires_at: Date.now() + data.expires_in * 1000, + }; + }; + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(result).not.toBeNull(); + expect(result!.access_token).not.toBe(initial.access_token); + + // Verify the refreshed token works against the server + const mcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${result!.access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + expect(mcpRes.status).toBe(200); + }); + }); + + describe('Access token expired + refresh token rejected by server', () => { + let tokenStore: InMemoryTokenStore; + + beforeEach(() => { + tokenStore = new InMemoryTokenStore(); + }); + + it('should return null when refresh token is rejected (invalid_grant)', async () => { + const server = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + + try { + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // Store expired access token directly + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: `enc:${initial.access_token}`, + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: `enc:${initial.refresh_token}`, + expiresIn: 86400, + }); + + // Simulate server revoking the refresh token + server.issuedRefreshTokens.clear(); + + const refreshCallback = async (refreshToken: string): Promise => { + const res = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + if (!res.ok) { + const body = (await res.json()) as { error: string }; + throw new Error(`Token refresh failed: ${body.error}`); + } + const data = (await res.json()) as MCPOAuthTokens; + return data; + }; + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(result).toBeNull(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to refresh tokens'), + expect.any(Error), + ); + } finally { + await server.close(); + } + }); + + it('should return null when refresh endpoint returns unauthorized_client', async () => { + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: 'enc:some-refresh-token', + expiresIn: 86400, + }); + + const refreshCallback = jest + .fn() + .mockRejectedValue(new Error('unauthorized_client: client not authorized for refresh')); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(result).toBeNull(); + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining('does not support refresh tokens'), + ); + }); + }); + + describe('Access token expired + NO refresh token → ReauthenticationRequiredError', () => { + let tokenStore: InMemoryTokenStore; + + beforeEach(() => { + tokenStore = new InMemoryTokenStore(); + }); + + it('should throw ReauthenticationRequiredError when no refresh token stored', async () => { + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should throw ReauthenticationRequiredError with correct reason for expired token', async () => { + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }), + ).rejects.toThrow('access token expired'); + }); + + it('should throw ReauthenticationRequiredError with correct reason for missing token', async () => { + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }), + ).rejects.toThrow('access token missing'); + }); + }); + + describe('PENDING flow fallback for CSRF-less OAuth callbacks', () => { + it('should allow OAuth completion when PENDING flow exists (simulating chat/SSE path)', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'user1:test-server'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + userId: 'user1', + serverUrl: 'https://example.com', + state: 'test-state', + authorizationUrl: 'https://example.com/authorize?state=user1:test-server', + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('PENDING'); + + const tokens: MCPOAuthTokens = { + access_token: 'new-access-token', + token_type: 'Bearer', + refresh_token: 'new-refresh-token', + obtained_at: Date.now(), + expires_at: Date.now() + 3600000, + }; + + const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', tokens); + expect(completed).toBe(true); + + const completedState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(completedState?.status).toBe('COMPLETED'); + expect((completedState?.result as MCPOAuthTokens | undefined)?.access_token).toBe( + 'new-access-token', + ); + }); + + it('should store authorizationUrl in flow metadata for re-issuance', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'user1:test-server'; + const authUrl = 'https://auth.example.com/authorize?client_id=abc&state=user1:test-server'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + userId: 'user1', + serverUrl: 'https://example.com', + state: 'test-state', + authorizationUrl: authUrl, + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect((state?.metadata as Record)?.authorizationUrl).toBe(authUrl); + }); + }); + + describe('Full token expiry → refresh failure → re-auth flow', () => { + let server: OAuthTestServer; + let tokenStore: InMemoryTokenStore; + + beforeEach(async () => { + server = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + tokenStore = new InMemoryTokenStore(); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should go through full cycle: get tokens → expire → refresh fails → re-auth needed', async () => { + // Step 1: Get initial tokens + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // Step 2: Store tokens with valid expiry first + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'test-srv', + tokens: initial, + createToken: tokenStore.createToken, + }); + + // Step 3: Verify tokens work + const validResult = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }); + expect(validResult).not.toBeNull(); + expect(validResult!.access_token).toBe(initial.access_token); + + // Step 4: Simulate token expiry by directly updating the stored token's expiresAt + await tokenStore.updateToken({ userId: 'u1', identifier: 'mcp:test-srv' }, { expiresIn: -1 }); + + // Step 5: Revoke refresh token on server side (simulating server-side revocation) + server.issuedRefreshTokens.clear(); + + // Step 6: Try to get tokens — refresh should fail, return null + const refreshCallback = async (refreshToken: string): Promise => { + const res = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + if (!res.ok) { + const body = (await res.json()) as { error: string }; + throw new Error(`Refresh failed: ${body.error}`); + } + const data = (await res.json()) as MCPOAuthTokens; + return data; + }; + + const expiredResult = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + // Refresh failed → returns null → triggers OAuth re-auth flow + expect(expiredResult).toBeNull(); + + // Step 7: Simulate the re-auth flow via FlowStateManager + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + const flowId = 'u1:test-srv'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-srv', + userId: 'u1', + serverUrl: server.url, + state: 'test-state', + authorizationUrl: `${server.url}authorize?state=${flowId}`, + }); + + // Step 8: Get a new auth code and exchange for tokens (simulating user re-auth) + const newCode = await server.getAuthCode(); + const newTokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${newCode}`, + }); + const newTokens = (await newTokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token?: string; + }; + + // Step 9: Complete the flow + const mcpTokens: MCPOAuthTokens = { + ...newTokens, + obtained_at: Date.now(), + expires_at: Date.now() + newTokens.expires_in * 1000, + }; + await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens); + + // Step 10: Store the new tokens + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'test-srv', + tokens: mcpTokens, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + }); + + // Step 11: Verify new tokens work + const newResult = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }); + expect(newResult).not.toBeNull(); + expect(newResult!.access_token).toBe(newTokens.access_token); + + // Step 12: Verify new token works against server + const finalMcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${newResult!.access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + expect(finalMcpRes.status).toBe(200); + }); + }); + + describe('Concurrent token expiry with connection mutex', () => { + it('should handle multiple concurrent getTokens calls when token is expired', async () => { + const tokenStore = new InMemoryTokenStore(); + + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: 'enc:valid-refresh', + expiresIn: 86400, + }); + + let refreshCallCount = 0; + const refreshCallback = jest.fn().mockImplementation(async () => { + refreshCallCount++; + await new Promise((r) => setTimeout(r, 100)); + return { + access_token: `refreshed-token-${refreshCallCount}`, + token_type: 'Bearer', + expires_in: 3600, + obtained_at: Date.now(), + expires_at: Date.now() + 3600000, + }; + }); + + // Fire 3 concurrent getTokens calls via FlowStateManager (like the connection mutex does) + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const getTokensViaFlow = () => + flowManager.createFlowWithHandler('u1:test-srv', 'mcp_get_tokens', async () => { + return await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + }); + + const [r1, r2, r3] = await Promise.all([ + getTokensViaFlow(), + getTokensViaFlow(), + getTokensViaFlow(), + ]); + + // All should get tokens (either directly or via flow coalescing) + expect(r1).not.toBeNull(); + expect(r2).not.toBeNull(); + expect(r3).not.toBeNull(); + + // The refresh callback should only be called once due to flow coalescing + expect(refreshCallback).toHaveBeenCalledTimes(1); + }); + }); + + describe('Stale PENDING flow detection', () => { + it('should treat PENDING flows older than 2 minutes as stale', async () => { + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 300000, + ci: true, + }); + + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + authorizationUrl: 'https://example.com/auth', + }); + + // Manually age the flow to 3 minutes + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (state) { + state.createdAt = Date.now() - 3 * 60 * 1000; + await (flowStore as unknown as { set: (k: string, v: unknown) => Promise }).set( + `mcp_oauth:${flowId}`, + state, + ); + } + + const agedState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(agedState?.status).toBe('PENDING'); + + const age = agedState?.createdAt ? Date.now() - agedState.createdAt : 0; + expect(age).toBeGreaterThan(2 * 60 * 1000); + + // A new flow should be created (the stale one would be deleted + recreated) + // This verifies our staleness check threshold + expect(age > PENDING_STALE_MS).toBe(true); + }); + + it('should not treat recent PENDING flows as stale', async () => { + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 300000, + ci: true, + }); + + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + authorizationUrl: 'https://example.com/auth', + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + const age = state?.createdAt ? Date.now() - state.createdAt : Infinity; + + expect(age < PENDING_STALE_MS).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts new file mode 100644 index 0000000000..3805586453 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts @@ -0,0 +1,544 @@ +/** + * Integration tests for MCPTokenStorage.storeTokens() and MCPTokenStorage.getTokens(). + * + * Uses InMemoryTokenStore to exercise encrypt/decrypt round-trips, expiry calculation, + * refresh callback wiring, and ReauthenticationRequiredError paths. + */ + +import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { InMemoryTokenStore } from './helpers/oauthTestServer'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +describe('MCPTokenStorage', () => { + let store: InMemoryTokenStore; + + beforeEach(() => { + store = new InMemoryTokenStore(); + jest.clearAllMocks(); + }); + + describe('storeTokens', () => { + it('should create new access token with expires_in', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved).not.toBeNull(); + expect(saved!.token).toBe('enc:at1'); + const expiresInMs = saved!.expiresAt.getTime() - Date.now(); + expect(expiresInMs).toBeGreaterThan(3500 * 1000); + expect(expiresInMs).toBeLessThanOrEqual(3600 * 1000); + }); + + it('should create new access token with expires_at (MCPOAuthTokens format)', async () => { + const expiresAt = Date.now() + 7200 * 1000; + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'at1', + token_type: 'Bearer', + expires_at: expiresAt, + obtained_at: Date.now(), + }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved).not.toBeNull(); + const diff = Math.abs(saved!.expiresAt.getTime() - expiresAt); + expect(diff).toBeLessThan(2000); + }); + + it('should default to 1-year expiry when none provided', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer' }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + const oneYearMs = 365 * 24 * 60 * 60 * 1000; + const expiresInMs = saved!.expiresAt.getTime() - Date.now(); + expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000); + }); + + it('should update existing access token', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:old-token', + expiresIn: 3600, + }); + + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'new-token', token_type: 'Bearer', expires_in: 7200 }, + createToken: store.createToken, + updateToken: store.updateToken, + findToken: store.findToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved!.token).toBe('enc:new-token'); + }); + + it('should store refresh token alongside access token', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'at1', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'rt1', + }, + createToken: store.createToken, + }); + + const refreshSaved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + }); + expect(refreshSaved).not.toBeNull(); + expect(refreshSaved!.token).toBe('enc:rt1'); + }); + + it('should skip refresh token when not in response', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + }); + + const refreshSaved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + }); + expect(refreshSaved).toBeNull(); + }); + + it('should store client info when provided', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] }, + }); + + const clientSaved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth_client', + identifier: 'mcp:srv1:client', + }); + expect(clientSaved).not.toBeNull(); + expect(clientSaved!.token).toContain('enc:'); + expect(clientSaved!.token).toContain('cid'); + }); + + it('should use existingTokens to skip DB lookups', async () => { + const findSpy = jest.fn(); + + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + updateToken: store.updateToken, + findToken: findSpy, + existingTokens: { + accessToken: null, + refreshToken: null, + clientInfoToken: null, + }, + }); + + expect(findSpy).not.toHaveBeenCalled(); + }); + + it('should handle invalid NaN expiry date', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'at1', + token_type: 'Bearer', + expires_at: NaN, + obtained_at: Date.now(), + }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved).not.toBeNull(); + const oneYearMs = 365 * 24 * 60 * 60 * 1000; + const expiresInMs = saved!.expiresAt.getTime() - Date.now(); + expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000); + }); + }); + + describe('getTokens', () => { + it('should return valid non-expired tokens', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:valid-token', + expiresIn: 3600, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result).not.toBeNull(); + expect(result!.access_token).toBe('valid-token'); + expect(result!.token_type).toBe('Bearer'); + }); + + it('should return tokens with refresh token when available', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:at', + expiresIn: 3600, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result!.refresh_token).toBe('rt'); + }); + + it('should return tokens without refresh token field when none stored', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:at', + expiresIn: 3600, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result!.refresh_token).toBeUndefined(); + }); + + it('should throw ReauthenticationRequiredError when expired and no refresh', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should throw ReauthenticationRequiredError when missing and no refresh', async () => { + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should refresh expired access token when refresh token and callback are available', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockResolvedValue({ + access_token: 'refreshed-at', + token_type: 'Bearer', + expires_in: 3600, + obtained_at: Date.now(), + expires_at: Date.now() + 3600000, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + updateToken: store.updateToken, + refreshTokens, + }); + + expect(result).not.toBeNull(); + expect(result!.access_token).toBe('refreshed-at'); + expect(refreshTokens).toHaveBeenCalledWith( + 'rt', + expect.objectContaining({ userId: 'u1', serverName: 'srv1' }), + ); + }); + + it('should return null when refresh fails', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockRejectedValue(new Error('refresh failed')); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + updateToken: store.updateToken, + refreshTokens, + }); + + expect(result).toBeNull(); + }); + + it('should return null when no refreshTokens callback provided', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result).toBeNull(); + }); + + it('should return null when no createToken callback provided', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + refreshTokens: jest.fn(), + }); + + expect(result).toBeNull(); + }); + + it('should pass client info to refreshTokens metadata', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_client', + identifier: 'mcp:srv1:client', + token: 'enc:{"client_id":"cid","client_secret":"csec"}', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockResolvedValue({ + access_token: 'new-at', + token_type: 'Bearer', + expires_in: 3600, + }); + + await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + updateToken: store.updateToken, + refreshTokens, + }); + + expect(refreshTokens).toHaveBeenCalledWith( + 'rt', + expect.objectContaining({ + clientInfo: expect.objectContaining({ client_id: 'cid' }), + }), + ); + }); + + it('should handle unauthorized_client refresh error', async () => { + const { logger } = await import('@librechat/data-schemas'); + + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockRejectedValue(new Error('unauthorized_client')); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + refreshTokens, + }); + + expect(result).toBeNull(); + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining('does not support refresh tokens'), + ); + }); + }); + + describe('storeTokens + getTokens round-trip', () => { + it('should store and retrieve tokens with full encrypt/decrypt cycle', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'my-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'my-refresh-token', + }, + createToken: store.createToken, + clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] }, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result!.access_token).toBe('my-access-token'); + expect(result!.refresh_token).toBe('my-refresh-token'); + expect(result!.token_type).toBe('Bearer'); + expect(result!.obtained_at).toBeDefined(); + expect(result!.expires_at).toBeDefined(); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts new file mode 100644 index 0000000000..3b68b2ded4 --- /dev/null +++ b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts @@ -0,0 +1,449 @@ +import * as http from 'http'; +import * as net from 'net'; +import { randomUUID, createHash } from 'crypto'; +import { z } from 'zod'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import type { FlowState } from '~/flow/types'; +import type { Socket } from 'net'; + +export class MockKeyv { + private store: Map>; + + constructor() { + this.store = new Map(); + } + + async get(key: string): Promise | undefined> { + return this.store.get(key); + } + + async set(key: string, value: FlowState, _ttl?: number): Promise { + this.store.set(key, value); + return true; + } + + async delete(key: string): Promise { + return this.store.delete(key); + } +} + +export function getFreePort(): Promise { + return new Promise((resolve, reject) => { + const srv = net.createServer(); + srv.listen(0, '127.0.0.1', () => { + const addr = srv.address() as net.AddressInfo; + srv.close((err) => (err ? reject(err) : resolve(addr.port))); + }); + }); +} + +export function trackSockets(httpServer: http.Server): () => Promise { + const sockets = new Set(); + httpServer.on('connection', (socket: Socket) => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + return () => + new Promise((resolve) => { + for (const socket of sockets) { + socket.destroy(); + } + sockets.clear(); + httpServer.close(() => resolve()); + }); +} + +export interface OAuthTestServerOptions { + tokenTTLMs?: number; + issueRefreshTokens?: boolean; + refreshTokenTTLMs?: number; + rotateRefreshTokens?: boolean; +} + +export interface OAuthTestServer { + url: string; + port: number; + close: () => Promise; + issuedTokens: Set; + tokenTTL: number; + tokenIssueTimes: Map; + issuedRefreshTokens: Map; + registeredClients: Map; + getAuthCode: () => Promise; +} + +async function readRequestBody(req: http.IncomingMessage): Promise { + const chunks: Uint8Array[] = []; + for await (const chunk of req) { + chunks.push(chunk as Uint8Array); + } + return Buffer.concat(chunks).toString(); +} + +function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null { + if (contentType?.includes('application/x-www-form-urlencoded')) { + return new URLSearchParams(body); + } + if (contentType?.includes('application/json')) { + const json = JSON.parse(body) as Record; + return new URLSearchParams(json); + } + return new URLSearchParams(body); +} + +export async function createOAuthMCPServer( + options: OAuthTestServerOptions = {}, +): Promise { + const { + tokenTTLMs = 60000, + issueRefreshTokens = false, + refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000, + rotateRefreshTokens = false, + } = options; + + const sessions = new Map(); + const issuedTokens = new Set(); + const tokenIssueTimes = new Map(); + const issuedRefreshTokens = new Map(); + const refreshTokenIssueTimes = new Map(); + const authCodes = new Map(); + const registeredClients = new Map(); + + let port = 0; + + const httpServer = http.createServer(async (req, res) => { + const url = new URL(req.url ?? '/', `http://${req.headers.host}`); + + if (url.pathname === '/.well-known/oauth-authorization-server' && req.method === 'GET') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + issuer: `http://127.0.0.1:${port}`, + authorization_endpoint: `http://127.0.0.1:${port}/authorize`, + token_endpoint: `http://127.0.0.1:${port}/token`, + registration_endpoint: `http://127.0.0.1:${port}/register`, + response_types_supported: ['code'], + grant_types_supported: issueRefreshTokens + ? ['authorization_code', 'refresh_token'] + : ['authorization_code'], + token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post'], + code_challenge_methods_supported: ['S256'], + }), + ); + return; + } + + if (url.pathname === '/register' && req.method === 'POST') { + const body = await readRequestBody(req); + const data = JSON.parse(body) as { redirect_uris?: string[] }; + const clientId = `client-${randomUUID().slice(0, 8)}`; + const clientSecret = `secret-${randomUUID()}`; + registeredClients.set(clientId, { client_id: clientId, client_secret: clientSecret }); + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + client_id: clientId, + client_secret: clientSecret, + redirect_uris: data.redirect_uris ?? [], + }), + ); + return; + } + + if (url.pathname === '/authorize') { + const code = randomUUID(); + const codeChallenge = url.searchParams.get('code_challenge') ?? undefined; + const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined; + authCodes.set(code, { codeChallenge, codeChallengeMethod }); + const redirectUri = url.searchParams.get('redirect_uri') ?? ''; + const state = url.searchParams.get('state') ?? ''; + res.writeHead(302, { + Location: `${redirectUri}?code=${code}&state=${state}`, + }); + res.end(); + return; + } + + if (url.pathname === '/token' && req.method === 'POST') { + const body = await readRequestBody(req); + const params = parseTokenRequest(body, req.headers['content-type']); + if (!params) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_request' })); + return; + } + + const grantType = params.get('grant_type'); + + if (grantType === 'authorization_code') { + const code = params.get('code'); + const codeData = code ? authCodes.get(code) : undefined; + if (!code || !codeData) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + + if (codeData.codeChallenge) { + const codeVerifier = params.get('code_verifier'); + if (!codeVerifier) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + if (codeData.codeChallengeMethod === 'S256') { + const expected = createHash('sha256').update(codeVerifier).digest('base64url'); + if (expected !== codeData.codeChallenge) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + } + } + + authCodes.delete(code); + + const accessToken = randomUUID(); + issuedTokens.add(accessToken); + tokenIssueTimes.set(accessToken, Date.now()); + + const tokenResponse: Record = { + access_token: accessToken, + token_type: 'Bearer', + expires_in: Math.ceil(tokenTTLMs / 1000), + }; + + if (issueRefreshTokens) { + const refreshToken = randomUUID(); + issuedRefreshTokens.set(refreshToken, accessToken); + refreshTokenIssueTimes.set(refreshToken, Date.now()); + tokenResponse.refresh_token = refreshToken; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify(tokenResponse)); + return; + } + + if (grantType === 'refresh_token' && issueRefreshTokens) { + const refreshToken = params.get('refresh_token'); + if (!refreshToken || !issuedRefreshTokens.has(refreshToken)) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + + const issueTime = refreshTokenIssueTimes.get(refreshToken) ?? 0; + if (Date.now() - issueTime > refreshTokenTTLMs) { + issuedRefreshTokens.delete(refreshToken); + refreshTokenIssueTimes.delete(refreshToken); + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + + const newAccessToken = randomUUID(); + issuedTokens.add(newAccessToken); + tokenIssueTimes.set(newAccessToken, Date.now()); + + const tokenResponse: Record = { + access_token: newAccessToken, + token_type: 'Bearer', + expires_in: Math.ceil(tokenTTLMs / 1000), + }; + + if (rotateRefreshTokens) { + issuedRefreshTokens.delete(refreshToken); + refreshTokenIssueTimes.delete(refreshToken); + const newRefreshToken = randomUUID(); + issuedRefreshTokens.set(newRefreshToken, newAccessToken); + refreshTokenIssueTimes.set(newRefreshToken, Date.now()); + tokenResponse.refresh_token = newRefreshToken; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify(tokenResponse)); + return; + } + + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'unsupported_grant_type' })); + return; + } + + // All other paths require Bearer token auth + const authHeader = req.headers.authorization; + if (!authHeader || !authHeader.startsWith('Bearer ')) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_token' })); + return; + } + + const token = authHeader.slice(7); + if (!issuedTokens.has(token)) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_token' })); + return; + } + + const issueTime = tokenIssueTimes.get(token) ?? 0; + if (Date.now() - issueTime > tokenTTLMs) { + issuedTokens.delete(token); + tokenIssueTimes.delete(token); + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_token' })); + return; + } + + // Authenticated MCP request — route to transport + const sid = req.headers['mcp-session-id'] as string | undefined; + let transport = sid ? sessions.get(sid) : undefined; + + if (!transport) { + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + }); + const mcp = new McpServer({ name: 'oauth-test-server', version: '0.0.1' }); + mcp.tool('echo', { message: z.string() }, async (args) => ({ + content: [{ type: 'text' as const, text: `echo: ${args.message}` }], + })); + await mcp.connect(transport); + } + + await transport.handleRequest(req, res); + + if (transport.sessionId && !sessions.has(transport.sessionId)) { + sessions.set(transport.sessionId, transport); + transport.onclose = () => sessions.delete(transport!.sessionId!); + } + }); + + const destroySockets = trackSockets(httpServer); + port = await getFreePort(); + await new Promise((resolve) => httpServer.listen(port, '127.0.0.1', resolve)); + + return { + url: `http://127.0.0.1:${port}/`, + port, + issuedTokens, + tokenTTL: tokenTTLMs, + tokenIssueTimes, + issuedRefreshTokens, + registeredClients, + getAuthCode: async () => { + const authRes = await fetch( + `http://127.0.0.1:${port}/authorize?redirect_uri=http://localhost&state=test`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + return new URL(location).searchParams.get('code') ?? ''; + }, + close: async () => { + const closing = [...sessions.values()].map((t) => t.close().catch(() => undefined)); + sessions.clear(); + await Promise.all(closing); + await destroySockets(); + }, + }; +} + +export interface InMemoryToken { + userId: string; + type: string; + identifier: string; + token: string; + expiresAt: Date; + createdAt: Date; + metadata?: Map | Record; +} + +export class InMemoryTokenStore { + private tokens: Map = new Map(); + + private key(filter: { userId?: string; type?: string; identifier?: string }): string { + return `${filter.userId}:${filter.type}:${filter.identifier}`; + } + + findToken = async (filter: { + userId?: string; + type?: string; + identifier?: string; + }): Promise => { + for (const token of this.tokens.values()) { + const matchUserId = !filter.userId || token.userId === filter.userId; + const matchType = !filter.type || token.type === filter.type; + const matchIdentifier = !filter.identifier || token.identifier === filter.identifier; + if (matchUserId && matchType && matchIdentifier) { + return token; + } + } + return null; + }; + + createToken = async (data: { + userId: string; + type: string; + identifier: string; + token: string; + expiresIn?: number; + metadata?: Record; + }): Promise => { + const expiresIn = data.expiresIn ?? 365 * 24 * 60 * 60; + const token: InMemoryToken = { + userId: data.userId, + type: data.type, + identifier: data.identifier, + token: data.token, + expiresAt: new Date(Date.now() + expiresIn * 1000), + createdAt: new Date(), + metadata: data.metadata, + }; + this.tokens.set(this.key(data), token); + return token; + }; + + updateToken = async ( + filter: { userId?: string; type?: string; identifier?: string }, + data: { + userId?: string; + type?: string; + identifier?: string; + token?: string; + expiresIn?: number; + metadata?: Record; + }, + ): Promise => { + const existing = await this.findToken(filter); + if (!existing) { + throw new Error(`Token not found for filter: ${JSON.stringify(filter)}`); + } + const existingKey = this.key(existing); + const expiresIn = + data.expiresIn ?? Math.floor((existing.expiresAt.getTime() - Date.now()) / 1000); + const updated: InMemoryToken = { + ...existing, + token: data.token ?? existing.token, + expiresAt: data.expiresIn ? new Date(Date.now() + expiresIn * 1000) : existing.expiresAt, + metadata: data.metadata ?? existing.metadata, + }; + this.tokens.set(existingKey, updated); + return updated; + }; + + deleteToken = async (filter: { + userId: string; + type: string; + identifier: string; + }): Promise => { + this.tokens.delete(this.key(filter)); + }; + + getAll(): InMemoryToken[] { + return [...this.tokens.values()]; + } + + clear(): void { + this.tokens.clear(); + } +} diff --git a/packages/api/src/mcp/__tests__/reconnection-storm.test.ts b/packages/api/src/mcp/__tests__/reconnection-storm.test.ts index c1cf0ec5df..e073dca8a3 100644 --- a/packages/api/src/mcp/__tests__/reconnection-storm.test.ts +++ b/packages/api/src/mcp/__tests__/reconnection-storm.test.ts @@ -12,8 +12,12 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import type { Socket } from 'net'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; import { MCPConnection } from '~/mcp/connection'; +import { mcpConfig } from '~/mcp/mcpConfig'; jest.mock('@librechat/data-schemas', () => ({ logger: { @@ -143,16 +147,17 @@ afterEach(() => { /* ------------------------------------------------------------------ */ /* Fix #2 — Circuit breaker trips after rapid connect/disconnect */ -/* cycles (5 cycles within 60s -> 30s cooldown) */ +/* cycles (CB_MAX_CYCLES within window -> cooldown) */ /* ------------------------------------------------------------------ */ describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => { - it('blocks connection after 5 rapid cycles via static circuit breaker', async () => { + it('blocks connection after CB_MAX_CYCLES rapid cycles via static circuit breaker', async () => { const srv = await startMCPServer(); const conn = createConnection('cycling-server', srv.url); let completedCycles = 0; let breakerMessage = ''; - for (let cycle = 0; cycle < 10; cycle++) { + const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2; + for (let cycle = 0; cycle < maxAttempts; cycle++) { try { await conn.connect(); await teardownConnection(conn); @@ -166,7 +171,7 @@ describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => { } expect(breakerMessage).toContain('Circuit breaker is open'); - expect(completedCycles).toBeLessThanOrEqual(5); + expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES); await srv.close(); }); @@ -266,12 +271,13 @@ describe('Fix #4: Circuit breaker state persists across instance replacement', ( /* recordFailedRound() in the catch path */ /* ------------------------------------------------------------------ */ describe('Fix #5: Dead server triggers circuit breaker', () => { - it('3 failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => { + it('failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => { const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000); const spy = jest.spyOn(conn.client, 'connect'); + const totalAttempts = mcpConfig.CB_MAX_FAILED_ROUNDS + 2; const errors: string[] = []; - for (let i = 0; i < 5; i++) { + for (let i = 0; i < totalAttempts; i++) { try { await conn.connect(); } catch (e) { @@ -279,8 +285,8 @@ describe('Fix #5: Dead server triggers circuit breaker', () => { } } - expect(spy.mock.calls.length).toBe(3); - expect(errors).toHaveLength(5); + expect(spy.mock.calls.length).toBe(mcpConfig.CB_MAX_FAILED_ROUNDS); + expect(errors).toHaveLength(totalAttempts); expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2); await conn.disconnect(); @@ -295,7 +301,7 @@ describe('Fix #5: Dead server triggers circuit breaker', () => { userId: 'user-A', }); - for (let i = 0; i < 3; i++) { + for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) { try { await userA.connect(); } catch { @@ -332,7 +338,7 @@ describe('Fix #5: Dead server triggers circuit breaker', () => { serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never, userId: 'user-A', }); - for (let i = 0; i < 3; i++) { + for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) { try { await userA.connect(); } catch { @@ -448,13 +454,14 @@ describe('Fix #6: OAuth failure uses cooldown-based retry', () => { /* Integration: Circuit breaker caps rapid cycling with real transport */ /* ------------------------------------------------------------------ */ describe('Cascade: Circuit breaker caps rapid cycling', () => { - it('breaker trips before 10 cycles complete against a live server', async () => { + it('breaker trips before double CB_MAX_CYCLES complete against a live server', async () => { const srv = await startMCPServer(); const conn = createConnection('cascade', srv.url); const spy = jest.spyOn(conn.client, 'connect'); let completedCycles = 0; - for (let i = 0; i < 10; i++) { + const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2; + for (let i = 0; i < maxAttempts; i++) { try { await conn.connect(); await teardownConnection(conn); @@ -469,8 +476,8 @@ describe('Cascade: Circuit breaker caps rapid cycling', () => { } } - expect(completedCycles).toBeLessThanOrEqual(5); - expect(spy.mock.calls.length).toBeLessThanOrEqual(5); + expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES); + expect(spy.mock.calls.length).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES); await srv.close(); }); @@ -501,6 +508,146 @@ describe('Cascade: Circuit breaker caps rapid cycling', () => { }, 30_000); }); +/* ------------------------------------------------------------------ */ +/* OAuth: cycle recovery after successful OAuth reconnect */ +/* ------------------------------------------------------------------ */ +describe('OAuth: cycle budget recovery after successful OAuth', () => { + let oauthServer: OAuthTestServer; + + beforeEach(async () => { + oauthServer = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + afterEach(async () => { + await oauthServer.close(); + }); + + async function exchangeCodeForToken(serverUrl: string): Promise { + const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, { + redirect: 'manual', + }); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + const tokenRes = await fetch(`${serverUrl}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const data = (await tokenRes.json()) as { access_token: string }; + return data.access_token; + } + + it('should decrement cycle count after successful OAuth recovery', async () => { + const serverName = 'oauth-cycle-test'; + MCPConnection.clearCooldown(serverName); + + const conn = new MCPConnection({ + serverName, + serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 }, + userId: 'user-1', + }); + + // When oauthRequired fires, get a token and emit oauthHandled + // This triggers the oauthRecovery path inside connectClient + conn.on('oauthRequired', async () => { + const accessToken = await exchangeCodeForToken(oauthServer.url); + conn.setOAuthTokens({ + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + conn.emit('oauthHandled'); + }); + + // connect() → 401 → oauthRequired → oauthHandled → connectClient returns + // connect() sees not connected → throws "Connection not established" + await expect(conn.connect()).rejects.toThrow('Connection not established'); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cb = (MCPConnection as any).circuitBreakers.get(serverName); + const cyclesBeforeRetry = cb.cycleCount; + + // Retry — should succeed and decrement cycle count via oauthRecovery + await conn.connect(); + expect(await conn.isConnected()).toBe(true); + + const cyclesAfterSuccess = cb.cycleCount; + // The retry adds +1 cycle (disconnect(false)) then -1 (oauthRecovery decrement) + // So cyclesAfterSuccess should equal cyclesBeforeRetry, not cyclesBeforeRetry + 1 + expect(cyclesAfterSuccess).toBe(cyclesBeforeRetry); + + await teardownConnection(conn); + }); + + it('should allow more OAuth reconnects than non-OAuth before breaker trips', async () => { + const serverName = 'oauth-budget'; + MCPConnection.clearCooldown(serverName); + + // Each OAuth flow: connect (+1) → 401 → oauthHandled → retry connect (+1) → success (-1) = net 1 + // Without the decrement it would be net 2 per flow, tripping the breaker after ~2 users + let successfulFlows = 0; + for (let i = 0; i < 10; i++) { + const conn = new MCPConnection({ + serverName, + serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 }, + userId: `user-${i}`, + }); + + conn.on('oauthRequired', async () => { + const accessToken = await exchangeCodeForToken(oauthServer.url); + conn.setOAuthTokens({ + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + conn.emit('oauthHandled'); + }); + + try { + // First connect: 401 → oauthHandled → returns without connection + await conn.connect().catch(() => {}); + // Retry: succeeds with token, decrements cycle + await conn.connect(); + successfulFlows++; + await teardownConnection(conn); + } catch (e) { + conn.removeAllListeners(); + if ((e as Error).message.includes('Circuit breaker is open')) { + break; + } + } + } + + // With the oauthRecovery decrement, each flow is net ~1 cycle instead of ~2, + // so we should get more successful flows before the breaker trips + expect(successfulFlows).toBeGreaterThanOrEqual(3); + }); + + it('should not decrement cycle count when OAuth fails', async () => { + const serverName = 'oauth-failed-no-decrement'; + MCPConnection.clearCooldown(serverName); + + const conn = new MCPConnection({ + serverName, + serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 }, + userId: 'user-1', + }); + + conn.on('oauthRequired', () => { + conn.emit('oauthFailed', new Error('user denied')); + }); + + await expect(conn.connect()).rejects.toThrow(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cb = (MCPConnection as any).circuitBreakers.get(serverName); + const cyclesAfterFailure = cb.cycleCount; + + // connect() recorded +1 cycle, oauthFailed should NOT decrement + expect(cyclesAfterFailure).toBeGreaterThanOrEqual(1); + + conn.removeAllListeners(); + }); +}); + /* ------------------------------------------------------------------ */ /* Sanity: Real transport works end-to-end */ /* ------------------------------------------------------------------ */ diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index cac0a4afc5..8dc857cd3b 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -82,14 +82,6 @@ interface CircuitBreakerState { failedBackoffUntil: number; } -const CB_MAX_CYCLES = 5; -const CB_CYCLE_WINDOW_MS = 60_000; -const CB_CYCLE_COOLDOWN_MS = 30_000; - -const CB_MAX_FAILED_ROUNDS = 3; -const CB_FAILED_WINDOW_MS = 120_000; -const CB_BASE_BACKOFF_MS = 30_000; -const CB_MAX_BACKOFF_MS = 300_000; /** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */ const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES; @@ -281,6 +273,7 @@ export class MCPConnection extends EventEmitter { private oauthTokens?: MCPOAuthTokens | null; private requestHeaders?: Record | null; private oauthRequired = false; + private oauthRecovery = false; private readonly useSSRFProtection: boolean; iconPath?: string; timeout?: number; @@ -325,17 +318,17 @@ export class MCPConnection extends EventEmitter { private recordCycle(): void { const cb = this.getCircuitBreaker(); const now = Date.now(); - if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) { + if (now - cb.cycleWindowStart > mcpConfig.CB_CYCLE_WINDOW_MS) { cb.cycleCount = 0; cb.cycleWindowStart = now; } cb.cycleCount++; - if (cb.cycleCount >= CB_MAX_CYCLES) { - cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS; + if (cb.cycleCount >= mcpConfig.CB_MAX_CYCLES) { + cb.cooldownUntil = now + mcpConfig.CB_CYCLE_COOLDOWN_MS; cb.cycleCount = 0; cb.cycleWindowStart = now; logger.warn( - `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`, + `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${mcpConfig.CB_CYCLE_COOLDOWN_MS}ms`, ); } } @@ -343,15 +336,16 @@ export class MCPConnection extends EventEmitter { private recordFailedRound(): void { const cb = this.getCircuitBreaker(); const now = Date.now(); - if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) { + if (now - cb.failedWindowStart > mcpConfig.CB_FAILED_WINDOW_MS) { cb.failedRounds = 0; cb.failedWindowStart = now; } cb.failedRounds++; - if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) { + if (cb.failedRounds >= mcpConfig.CB_MAX_FAILED_ROUNDS) { const backoff = Math.min( - CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS), - CB_MAX_BACKOFF_MS, + mcpConfig.CB_BASE_BACKOFF_MS * + Math.pow(2, cb.failedRounds - mcpConfig.CB_MAX_FAILED_ROUNDS), + mcpConfig.CB_MAX_BACKOFF_MS, ); cb.failedBackoffUntil = now + backoff; logger.warn( @@ -367,6 +361,13 @@ export class MCPConnection extends EventEmitter { cb.failedBackoffUntil = 0; } + public static decrementCycleCount(serverName: string): void { + const cb = MCPConnection.circuitBreakers.get(serverName); + if (cb && cb.cycleCount > 0) { + cb.cycleCount--; + } + } + setRequestHeaders(headers: Record | null): void { if (!headers) { return; @@ -816,6 +817,13 @@ export class MCPConnection extends EventEmitter { this.emit('connectionChange', 'connected'); this.reconnectAttempts = 0; this.resetFailedRounds(); + if (this.oauthRecovery) { + MCPConnection.decrementCycleCount(this.serverName); + this.oauthRecovery = false; + logger.debug( + `${this.getLogPrefix()} OAuth recovery: decremented cycle count after successful reconnect`, + ); + } } catch (error) { // Check if it's a rate limit error - stop immediately to avoid making it worse if (this.isRateLimitError(error)) { @@ -899,9 +907,8 @@ export class MCPConnection extends EventEmitter { try { // Wait for OAuth to be handled await oauthHandledPromise; - // Reset the oauthRequired flag this.oauthRequired = false; - // Don't throw the error - just return so connection can be retried + this.oauthRecovery = true; logger.info( `${this.getLogPrefix()} OAuth handled successfully, connection will be retried`, ); diff --git a/packages/api/src/mcp/mcpConfig.ts b/packages/api/src/mcp/mcpConfig.ts index f3efd3592b..a81752e909 100644 --- a/packages/api/src/mcp/mcpConfig.ts +++ b/packages/api/src/mcp/mcpConfig.ts @@ -12,4 +12,18 @@ export const mcpConfig = { USER_CONNECTION_IDLE_TIMEOUT: math( process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000, ), + /** Max connect/disconnect cycles before the circuit breaker trips. Default: 7 */ + CB_MAX_CYCLES: math(process.env.MCP_CB_MAX_CYCLES ?? 7), + /** Sliding window (ms) for counting cycles. Default: 45s */ + CB_CYCLE_WINDOW_MS: math(process.env.MCP_CB_CYCLE_WINDOW_MS ?? 45_000), + /** Cooldown (ms) after the cycle breaker trips. Default: 15s */ + CB_CYCLE_COOLDOWN_MS: math(process.env.MCP_CB_CYCLE_COOLDOWN_MS ?? 15_000), + /** Max consecutive failed connection rounds before backoff. Default: 3 */ + CB_MAX_FAILED_ROUNDS: math(process.env.MCP_CB_MAX_FAILED_ROUNDS ?? 3), + /** Sliding window (ms) for counting failed rounds. Default: 120s */ + CB_FAILED_WINDOW_MS: math(process.env.MCP_CB_FAILED_WINDOW_MS ?? 120_000), + /** Base backoff (ms) after failed round threshold is reached. Default: 30s */ + CB_BASE_BACKOFF_MS: math(process.env.MCP_CB_BASE_BACKOFF_MS ?? 30_000), + /** Max backoff cap (ms) for exponential backoff. Default: 300s */ + CB_MAX_BACKOFF_MS: math(process.env.MCP_CB_MAX_BACKOFF_MS ?? 300_000), }; diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 83e855591e..366d0d2fde 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -426,8 +426,8 @@ export class MCPOAuthHandler { scope: config.scope, }); - /** Add state parameter with flowId to the authorization URL */ - authorizationUrl.searchParams.set('state', flowId); + /** Add cryptographic state parameter to the authorization URL */ + authorizationUrl.searchParams.set('state', state); logger.debug(`[MCPOAuth] Added state parameter to authorization URL`); const flowMetadata: MCPOAuthFlowMetadata = { @@ -505,8 +505,8 @@ export class MCPOAuthHandler { `[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`, ); - /** Add state parameter with flowId to the authorization URL */ - authorizationUrl.searchParams.set('state', flowId); + /** Add cryptographic state parameter to the authorization URL */ + authorizationUrl.searchParams.set('state', state); logger.debug(`[MCPOAuth] Added state parameter to authorization URL`); if (resourceMetadata?.resource != null && resourceMetadata.resource) { @@ -672,6 +672,44 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } + private static readonly STATE_MAP_TYPE = 'mcp_oauth_state'; + + /** + * Stores a mapping from the opaque OAuth state parameter to the flowId. + * This allows the callback to resolve the flowId from an unguessable state + * value, preventing attackers from forging callback requests. + */ + static async storeStateMapping( + state: string, + flowId: string, + flowManager: FlowStateManager, + ): Promise { + await flowManager.initFlow(state, this.STATE_MAP_TYPE, { flowId }); + } + + /** + * Resolves an opaque OAuth state parameter back to the original flowId. + * Returns null if the state is not found (expired or never stored). + */ + static async resolveStateToFlowId( + state: string, + flowManager: FlowStateManager, + ): Promise { + const mapping = await flowManager.getFlowState(state, this.STATE_MAP_TYPE); + return (mapping?.metadata?.flowId as string) ?? null; + } + + /** + * Deletes an orphaned state mapping when a flow is replaced. + * Prevents old authorization URLs from resolving after a flow restart. + */ + static async deleteStateMapping( + state: string, + flowManager: FlowStateManager, + ): Promise { + await flowManager.deleteFlow(state, this.STATE_MAP_TYPE); + } + /** * Gets the default redirect URI for a server */ diff --git a/packages/api/src/mcp/oauth/tokens.ts b/packages/api/src/mcp/oauth/tokens.ts index 005ed7dd9a..7b1d189347 100644 --- a/packages/api/src/mcp/oauth/tokens.ts +++ b/packages/api/src/mcp/oauth/tokens.ts @@ -4,6 +4,15 @@ import type { TokenMethods, IToken } from '@librechat/data-schemas'; import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types'; import { isSystemUserId } from '~/mcp/enum'; +export class ReauthenticationRequiredError extends Error { + constructor(serverName: string, reason: 'expired' | 'missing') { + super( + `Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`, + ); + this.name = 'ReauthenticationRequiredError'; + } +} + interface StoreTokensParams { userId: string; serverName: string; @@ -27,7 +36,12 @@ interface GetTokensParams { findToken: TokenMethods['findToken']; refreshTokens?: ( refreshToken: string, - metadata: { userId: string; serverName: string; identifier: string }, + metadata: { + userId: string; + serverName: string; + identifier: string; + clientInfo?: OAuthClientInformation; + }, ) => Promise; createToken?: TokenMethods['createToken']; updateToken?: TokenMethods['updateToken']; @@ -273,10 +287,11 @@ export class MCPTokenStorage { }); if (!refreshTokenData) { + const reason = isMissing ? 'missing' : 'expired'; logger.info( - `${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`, + `${logPrefix} Access token ${reason} and no refresh token available — re-authentication required`, ); - return null; + throw new ReauthenticationRequiredError(serverName, reason); } if (!refreshTokens) { @@ -395,6 +410,9 @@ export class MCPTokenStorage { logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`); return tokens; } catch (error) { + if (error instanceof ReauthenticationRequiredError) { + throw error; + } logger.error(`${logPrefix} Failed to retrieve tokens`, error); return null; } diff --git a/packages/api/src/mcp/oauth/types.ts b/packages/api/src/mcp/oauth/types.ts index 178e20e35b..2138b4a782 100644 --- a/packages/api/src/mcp/oauth/types.ts +++ b/packages/api/src/mcp/oauth/types.ts @@ -88,6 +88,7 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata { clientInfo?: OAuthClientInformation; metadata?: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; + authorizationUrl?: string; } export interface MCPOAuthTokens extends OAuthTokens {