mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-11 10:32:37 +01:00
🛂 fix: MCP OAuth Race Conditions, CSRF Fallback, and Token Expiry Handling (#12171)
* fix: Implement race conditions in MCP OAuth flow - Added connection mutex to coalesce concurrent `getUserConnection` calls, preventing multiple simultaneous attempts. - Enhanced flow state management to retry once when a flow state is missing, improving resilience against race conditions. - Introduced `ReauthenticationRequiredError` for better error handling when access tokens are expired or missing. - Updated tests to cover new race condition scenarios and ensure proper handling of OAuth flows. * fix: Stale PENDING flow detection and OAuth URL re-issuance PENDING flows in handleOAuthRequired now check createdAt age — flows older than 2 minutes are treated as stale and replaced instead of joined. Fixes the case where a leftover PENDING flow from a previous session blocks new OAuth initiation. authorizationUrl is now stored in MCPOAuthFlowMetadata so that when a second caller joins an active PENDING flow (e.g., the SSE-emitting path in ToolService), it can re-issue the URL to the user via oauthStart. * fix: CSRF fallback via active PENDING flow in OAuth callback When the OAuth callback arrives without CSRF or session cookies (common in the chat/SSE flow where cookies can't be set on streaming responses), fall back to validating that a PENDING flow exists for the flowId. This is safe because the flow was created server-side after JWT authentication and the authorization code is PKCE-protected. * test: Extract shared OAuth test server helpers Move MockKeyv, getFreePort, trackSockets, and createOAuthMCPServer into a shared helpers/oauthTestServer module. Enhance the test server with refresh token support, token rotation, metadata discovery, and dynamic client registration endpoints. Add InMemoryTokenStore for token storage tests. Refactor MCPOAuthRaceCondition.test.ts to import from shared helpers. * test: Add comprehensive MCP OAuth test modules MCPOAuthTokenStorage — 21 tests for storeTokens/getTokens with InMemoryTokenStore: encrypt/decrypt round-trips, expiry calculation, refresh callback wiring, ReauthenticationRequiredError paths. MCPOAuthFlow — 10 tests against real HTTP server: token refresh with stored client info, refresh token rotation, metadata discovery, dynamic client registration, full store/retrieve/expire/refresh lifecycle. MCPOAuthConnectionEvents — 5 tests for MCPConnection OAuth event cycle with real OAuth-gated MCP server: oauthRequired emission on 401, oauthHandled reconnection, oauthFailed rejection, token expiry detection. MCPOAuthTokenExpiry — 12 tests for the token expiry edge case: refresh success/failure paths, ReauthenticationRequiredError, PENDING flow CSRF fallback, authorizationUrl metadata storage, full re-auth cycle after refresh failure, concurrent expired token coalescing, stale PENDING flow detection. * test: Enhance MCP OAuth connection tests with cooldown reset Added a `beforeEach` hook to clear the cooldown for `MCPConnection` before each test, ensuring a clean state. Updated the race condition handling in the tests to properly clear the timeout, improving reliability in the event data retrieval process. * refactor: PENDING flow management and state recovery in MCP OAuth - Introduced a constant `PENDING_STALE_MS` to define the age threshold for PENDING flows, improving the handling of stale flows. - Updated the logic in `MCPConnectionFactory` and `FlowStateManager` to check the age of PENDING flows before joining or reusing them. - Modified the `completeFlow` method to return false when the flow state is deleted, ensuring graceful handling of race conditions. - Enhanced tests to validate the new behavior and ensure robustness against state recovery issues. * refactor: MCP OAuth flow management and testing - Updated the `completeFlow` method to log warnings when a tool flow state is not found during completion, improving error handling. - Introduced a new `normalizeExpiresAt` function to standardize expiration timestamp handling across the application. - Refactored token expiration checks in `MCPConnectionFactory` to utilize the new normalization function, ensuring consistent behavior. - Added a comprehensive test suite for OAuth callback CSRF fallback logic, validating the handling of PENDING flows and their staleness. - Enhanced existing tests to cover new expiration normalization logic and ensure robust flow state management. * test: Add CSRF fallback tests for active PENDING flows in MCP OAuth - Introduced new tests to validate CSRF fallback behavior when a fresh PENDING flow exists without cookies, ensuring successful OAuth callback handling. - Added scenarios to reject requests when no PENDING flow exists, when only a COMPLETED flow is present, and when a PENDING flow is stale, enhancing the robustness of flow state management. - Improved overall test coverage for OAuth callback logic, reinforcing the handling of CSRF validation failures. * chore: imports order * refactor: Update UserConnectionManager to conditionally manage pending connections - Modified the logic in `UserConnectionManager` to only set pending connections if `forceNew` is false, preventing unnecessary overwrites. - Adjusted the cleanup process to ensure pending connections are only deleted when not forced, enhancing connection management efficiency. * refactor: MCP OAuth flow state management - Introduced a new method `storeStateMapping` in `MCPOAuthHandler` to securely map the OAuth state parameter to the flow ID, improving callback resolution and security against forgery. - Updated the OAuth initiation and callback handling in `mcp.js` to utilize the new state mapping functionality, ensuring robust flow management. - Refactored `MCPConnectionFactory` to store state mappings during flow initialization, enhancing the integrity of the OAuth process. - Adjusted comments to clarify the purpose of state parameters in authorization URLs, reinforcing code readability. * refactor: MCPConnection with OAuth recovery handling - Added `oauthRecovery` flag to manage OAuth recovery state during connection attempts. - Introduced `decrementCycleCount` method to reduce the circuit breaker's cycle count upon successful reconnection after OAuth recovery. - Updated connection logic to reset the `oauthRecovery` flag after handling OAuth, improving state management and connection reliability. * chore: Add debug logging for OAuth recovery cycle count decrement - Introduced a debug log statement in the `MCPConnection` class to track the decrement of the cycle count after a successful reconnection during OAuth recovery. - This enhancement improves observability and aids in troubleshooting connection issues related to OAuth recovery. * test: Add OAuth recovery cycle management tests - Introduced new tests for the OAuth recovery cycle in `MCPConnection`, validating the decrement of cycle counts after successful reconnections. - Added scenarios to ensure that the cycle count is not decremented on OAuth failures, enhancing the robustness of connection management. - Improved test coverage for OAuth reconnect scenarios, ensuring reliable behavior under various conditions. * feat: Implement circuit breaker configuration in MCP - Added circuit breaker settings to `.env.example` for max cycles, cycle window, and cooldown duration. - Refactored `MCPConnection` to utilize the new configuration values from `mcpConfig`, enhancing circuit breaker management. - Improved code maintainability by centralizing circuit breaker parameters in the configuration file. * refactor: Update decrementCycleCount method for circuit breaker management - Changed the visibility of the `decrementCycleCount` method in `MCPConnection` from private to public static, allowing it to be called with a server name parameter. - Updated calls to `decrementCycleCount` in `MCPConnectionFactory` to use the new static method, improving clarity and consistency in circuit breaker management during connection failures and OAuth recovery. - Enhanced the handling of circuit breaker state by ensuring the method checks for the existence of the circuit breaker before decrementing the cycle count. * refactor: cycle count decrement on tool listing failure - Added a call to `MCPConnection.decrementCycleCount` in the `MCPConnectionFactory` to handle cases where unauthenticated tool listing fails, improving circuit breaker management. - This change ensures that the cycle count is decremented appropriately, maintaining the integrity of the connection recovery process. * refactor: Update circuit breaker configuration and logic - Enhanced circuit breaker settings in `.env.example` to include new parameters for failed rounds and backoff strategies. - Refactored `MCPConnection` to utilize the updated configuration values from `mcpConfig`, improving circuit breaker management. - Updated tests to reflect changes in circuit breaker logic, ensuring accurate validation of connection behavior under rapid reconnect scenarios. * feat: Implement state mapping deletion in MCP flow management - Added a new method `deleteStateMapping` in `MCPOAuthHandler` to remove orphaned state mappings when a flow is replaced, preventing old authorization URLs from resolving after a flow restart. - Updated `MCPConnectionFactory` to call `deleteStateMapping` during flow cleanup, ensuring proper management of OAuth states. - Enhanced test coverage for state mapping functionality to validate the new deletion logic.
This commit is contained in:
parent
6167ce6e57
commit
fcb344da47
22 changed files with 3865 additions and 128 deletions
21
.env.example
21
.env.example
|
|
@ -850,3 +850,24 @@ OPENWEATHER_API_KEY=
|
|||
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
|
||||
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
|
||||
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
|
||||
|
||||
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
|
||||
# MCP_CB_MAX_CYCLES=7
|
||||
|
||||
# Circuit breaker: sliding window (ms) for counting cycles
|
||||
# MCP_CB_CYCLE_WINDOW_MS=45000
|
||||
|
||||
# Circuit breaker: cooldown (ms) after the cycle breaker trips
|
||||
# MCP_CB_CYCLE_COOLDOWN_MS=15000
|
||||
|
||||
# Circuit breaker: max consecutive failed connection rounds before backoff
|
||||
# MCP_CB_MAX_FAILED_ROUNDS=3
|
||||
|
||||
# Circuit breaker: sliding window (ms) for counting failed rounds
|
||||
# MCP_CB_FAILED_WINDOW_MS=120000
|
||||
|
||||
# Circuit breaker: base backoff (ms) after failed round threshold is reached
|
||||
# MCP_CB_BASE_BACKOFF_MS=30000
|
||||
|
||||
# Circuit breaker: max backoff cap (ms) for exponential backoff
|
||||
# MCP_CB_MAX_BACKOFF_MS=300000
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => {
|
|||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
resolveStateToFlowId: jest.fn(async (state) => state),
|
||||
storeStateMapping: jest.fn(),
|
||||
deleteStateMapping: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
|
|
@ -180,7 +183,10 @@ describe('MCP Routes', () => {
|
|||
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
||||
authorizationUrl: 'https://oauth.example.com/auth',
|
||||
flowId: 'test-user-id:test-server',
|
||||
flowMetadata: { state: 'random-state-value' },
|
||||
});
|
||||
MCPOAuthHandler.storeStateMapping.mockResolvedValue();
|
||||
mockFlowManager.initFlow = jest.fn().mockResolvedValue();
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'test-user-id',
|
||||
|
|
@ -367,6 +373,121 @@ describe('MCP Routes', () => {
|
|||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||
});
|
||||
|
||||
describe('CSRF fallback via active PENDING flow', () => {
|
||||
it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now(),
|
||||
}),
|
||||
completeFlow: jest.fn().mockResolvedValue(true),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: {},
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({
|
||||
access_token: 'test-token',
|
||||
});
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({});
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||
clearReconnection: jest.fn(),
|
||||
});
|
||||
require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue();
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
|
||||
it('should reject when no PENDING flow exists and no cookies are present', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue(null),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'COMPLETED',
|
||||
createdAt: Date.now(),
|
||||
}),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now() - 3 * 60 * 1000,
|
||||
}),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
// mockRegistryInstance is defined at the top of the file
|
||||
const mockFlowManager = {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ const {
|
|||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
setOAuthSession,
|
||||
PENDING_STALE_MS,
|
||||
getUserMCPAuthMap,
|
||||
validateOAuthCsrf,
|
||||
OAUTH_CSRF_COOKIE,
|
||||
|
|
@ -91,7 +92,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
|||
}
|
||||
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
||||
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: oauthFlowId,
|
||||
flowMetadata,
|
||||
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
serverName,
|
||||
serverUrl,
|
||||
userId,
|
||||
|
|
@ -101,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
|||
|
||||
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
|
||||
|
||||
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager);
|
||||
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
res.redirect(authorizationUrl);
|
||||
} catch (error) {
|
||||
|
|
@ -143,30 +149,52 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
||||
}
|
||||
|
||||
const flowId = state;
|
||||
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager);
|
||||
if (!flowId) {
|
||||
logger.error('[MCP OAuth] Could not resolve state to flow ID', { state });
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId });
|
||||
|
||||
const flowParts = flowId.split(':');
|
||||
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
|
||||
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
|
||||
logger.error('[MCP OAuth] Invalid flow ID format', { flowId });
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
const [flowUserId] = flowParts;
|
||||
if (
|
||||
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
|
||||
!validateOAuthSession(req, flowUserId)
|
||||
) {
|
||||
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
|
||||
flowId,
|
||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||
});
|
||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||
|
||||
const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId);
|
||||
let hasActiveFlow = false;
|
||||
if (!hasCsrf && !hasSession) {
|
||||
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
|
||||
hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS;
|
||||
if (hasActiveFlow) {
|
||||
logger.debug(
|
||||
'[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow',
|
||||
{
|
||||
flowId,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
if (!hasCsrf && !hasSession && !hasActiveFlow) {
|
||||
logger.error(
|
||||
'[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow',
|
||||
{
|
||||
flowId,
|
||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||
},
|
||||
);
|
||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
|
||||
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
|
||||
|
|
@ -281,7 +309,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
const toolFlowId = flowState.metadata?.toolFlowId;
|
||||
if (toolFlowId) {
|
||||
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
|
||||
await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
||||
const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
||||
if (!completed) {
|
||||
logger.warn(
|
||||
'[MCP OAuth] Tool flow state not found during completion — waiter will time out',
|
||||
{ toolFlowId },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Redirect to success page with flowId and serverName */
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ export default {
|
|||
'\\.dev\\.ts$',
|
||||
'\\.helper\\.ts$',
|
||||
'\\.helper\\.d\\.ts$',
|
||||
'/__tests__/helpers/',
|
||||
],
|
||||
coverageReporters: ['text', 'cobertura'],
|
||||
testResultsProcessor: 'jest-junit',
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@
|
|||
"build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
|
||||
"build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
|
||||
"build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
|
||||
"test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
"test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
|
||||
"test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
|
|
|
|||
|
|
@ -3,6 +3,18 @@ import { logger } from '@librechat/data-schemas';
|
|||
import type { StoredDataNoRaw } from 'keyv';
|
||||
import type { FlowState, FlowMetadata, FlowManagerOptions } from './types';
|
||||
|
||||
export const PENDING_STALE_MS = 2 * 60 * 1000;
|
||||
|
||||
const SECONDS_THRESHOLD = 1e10;
|
||||
|
||||
/**
|
||||
* Normalizes an expiration timestamp to milliseconds.
|
||||
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
|
||||
*/
|
||||
export function normalizeExpiresAt(timestamp: number): number {
|
||||
return timestamp < SECONDS_THRESHOLD ? timestamp * 1000 : timestamp;
|
||||
}
|
||||
|
||||
export class FlowStateManager<T = unknown> {
|
||||
private keyv: Keyv;
|
||||
private ttl: number;
|
||||
|
|
@ -45,32 +57,8 @@ export class FlowStateManager<T = unknown> {
|
|||
return `${type}:${flowId}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes an expiration timestamp to milliseconds.
|
||||
* Detects whether the input is in seconds or milliseconds based on magnitude.
|
||||
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
|
||||
* @param timestamp - The expiration timestamp (in seconds or milliseconds)
|
||||
* @returns The timestamp normalized to milliseconds
|
||||
*/
|
||||
private normalizeExpirationTimestamp(timestamp: number): number {
|
||||
const SECONDS_THRESHOLD = 1e10;
|
||||
if (timestamp < SECONDS_THRESHOLD) {
|
||||
return timestamp * 1000;
|
||||
}
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a flow's token has expired based on its expires_at field
|
||||
* @param flowState - The flow state to check
|
||||
* @returns true if the token has expired, false otherwise (including if no expires_at exists)
|
||||
*/
|
||||
private isTokenExpired(flowState: FlowState<T> | undefined): boolean {
|
||||
if (!flowState?.result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typeof flowState.result !== 'object') {
|
||||
if (!flowState?.result || typeof flowState.result !== 'object') {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -79,13 +67,11 @@ export class FlowStateManager<T = unknown> {
|
|||
}
|
||||
|
||||
const expiresAt = (flowState.result as { expires_at: unknown }).expires_at;
|
||||
|
||||
if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const normalizedExpiresAt = this.normalizeExpirationTimestamp(expiresAt);
|
||||
return normalizedExpiresAt < Date.now();
|
||||
return normalizeExpiresAt(expiresAt) < Date.now();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -149,6 +135,8 @@ export class FlowStateManager<T = unknown> {
|
|||
let elapsedTime = 0;
|
||||
let isCleanedUp = false;
|
||||
let intervalId: NodeJS.Timeout | null = null;
|
||||
let missingStateRetried = false;
|
||||
let isRetrying = false;
|
||||
|
||||
// Cleanup function to avoid duplicate cleanup
|
||||
const cleanup = () => {
|
||||
|
|
@ -188,16 +176,29 @@ export class FlowStateManager<T = unknown> {
|
|||
}
|
||||
|
||||
intervalId = setInterval(async () => {
|
||||
if (isCleanedUp) return;
|
||||
if (isCleanedUp || isRetrying) return;
|
||||
|
||||
try {
|
||||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
let flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
||||
if (!flowState) {
|
||||
cleanup();
|
||||
logger.error(`[${flowKey}] Flow state not found`);
|
||||
reject(new Error(`${type} Flow state not found`));
|
||||
return;
|
||||
if (!missingStateRetried) {
|
||||
missingStateRetried = true;
|
||||
isRetrying = true;
|
||||
logger.warn(
|
||||
`[${flowKey}] Flow state not found, retrying once after 500ms (race recovery)`,
|
||||
);
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
isRetrying = false;
|
||||
}
|
||||
|
||||
if (!flowState) {
|
||||
cleanup();
|
||||
logger.error(`[${flowKey}] Flow state not found after retry`);
|
||||
reject(new Error(`${type} Flow state not found`));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (signal?.aborted) {
|
||||
|
|
@ -251,10 +252,10 @@ export class FlowStateManager<T = unknown> {
|
|||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
||||
if (!flowState) {
|
||||
logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', {
|
||||
flowId,
|
||||
type,
|
||||
});
|
||||
logger.warn(
|
||||
'[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping',
|
||||
{ flowId, type },
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -297,7 +298,7 @@ export class FlowStateManager<T = unknown> {
|
|||
async isFlowStale(
|
||||
flowId: string,
|
||||
type: string,
|
||||
staleThresholdMs: number = 2 * 60 * 1000,
|
||||
staleThresholdMs: number = PENDING_STALE_MS,
|
||||
): Promise<{ isStale: boolean; age: number; status?: string }> {
|
||||
const flowKey = this.getFlowKey(flowId, type);
|
||||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import { logger } from '@librechat/data-schemas';
|
|||
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
|
||||
import type { MCPOAuthTokens, OAuthMetadata, MCPOAuthFlowMetadata } from '~/mcp/oauth';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { FlowMetadata } from '~/flow/types';
|
||||
import type * as t from './types';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { PENDING_STALE_MS, normalizeExpiresAt } from '~/flow/manager';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import { MCPConnection } from './connection';
|
||||
|
|
@ -104,6 +104,7 @@ export class MCPConnectionFactory {
|
|||
return { tools, connection, oauthRequired: false, oauthUrl: null };
|
||||
}
|
||||
} catch {
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
logger.debug(
|
||||
`${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`,
|
||||
);
|
||||
|
|
@ -125,7 +126,9 @@ export class MCPConnectionFactory {
|
|||
}
|
||||
return { tools, connection: null, oauthRequired, oauthUrl };
|
||||
}
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
} catch (listError) {
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
|
||||
}
|
||||
|
||||
|
|
@ -265,6 +268,10 @@ export class MCPConnectionFactory {
|
|||
if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
|
||||
return tokens;
|
||||
} catch (error) {
|
||||
if (error instanceof ReauthenticationRequiredError) {
|
||||
logger.info(`${this.logPrefix} ${error.message}, will trigger OAuth flow`);
|
||||
return null;
|
||||
}
|
||||
logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
|
||||
return null;
|
||||
}
|
||||
|
|
@ -306,11 +313,21 @@ export class MCPConnectionFactory {
|
|||
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow?.status === 'PENDING') {
|
||||
const pendingAge = existingFlow.createdAt
|
||||
? Date.now() - existingFlow.createdAt
|
||||
: Infinity;
|
||||
|
||||
if (pendingAge < PENDING_STALE_MS) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Recent PENDING OAuth flow exists (${Math.round(pendingAge / 1000)}s old), skipping new initiation`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
|
||||
`${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will replace`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
|
|
@ -326,11 +343,17 @@ export class MCPConnectionFactory {
|
|||
);
|
||||
|
||||
if (existingFlow) {
|
||||
const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state;
|
||||
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
|
||||
if (oldState) {
|
||||
await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!);
|
||||
}
|
||||
}
|
||||
|
||||
// Store flow state BEFORE redirecting so the callback can find it
|
||||
await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata);
|
||||
const metadataWithUrl = { ...flowMetadata, authorizationUrl };
|
||||
await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
|
||||
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager!);
|
||||
|
||||
// Start monitoring in background — createFlow will find the existing PENDING state
|
||||
// written by initFlow above, so metadata arg is unused (pass {} to make that explicit)
|
||||
|
|
@ -495,11 +518,75 @@ export class MCPConnectionFactory {
|
|||
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow) {
|
||||
const flowMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined;
|
||||
|
||||
if (existingFlow.status === 'PENDING') {
|
||||
const pendingAge = existingFlow.createdAt
|
||||
? Date.now() - existingFlow.createdAt
|
||||
: Infinity;
|
||||
|
||||
if (pendingAge < PENDING_STALE_MS) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found recent PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), joining instead of creating new one`,
|
||||
);
|
||||
|
||||
const storedAuthUrl = flowMeta?.authorizationUrl;
|
||||
if (storedAuthUrl && typeof this.oauthStart === 'function') {
|
||||
logger.info(
|
||||
`${this.logPrefix} Re-issuing stored authorization URL to caller while joining PENDING flow`,
|
||||
);
|
||||
await this.oauthStart(storedAuthUrl);
|
||||
}
|
||||
|
||||
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal);
|
||||
if (typeof this.oauthEnd === 'function') {
|
||||
await this.oauthEnd();
|
||||
}
|
||||
logger.info(
|
||||
`${this.logPrefix} Joined existing OAuth flow completed for ${this.serverName}`,
|
||||
);
|
||||
return {
|
||||
tokens,
|
||||
clientInfo: flowMeta?.clientInfo,
|
||||
metadata: flowMeta?.metadata,
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will delete and start fresh`,
|
||||
);
|
||||
}
|
||||
|
||||
if (existingFlow.status === 'COMPLETED') {
|
||||
const completedAge = existingFlow.completedAt
|
||||
? Date.now() - existingFlow.completedAt
|
||||
: Infinity;
|
||||
const cachedTokens = existingFlow.result as MCPOAuthTokens | null | undefined;
|
||||
const isTokenExpired =
|
||||
cachedTokens?.expires_at != null &&
|
||||
normalizeExpiresAt(cachedTokens.expires_at) < Date.now();
|
||||
|
||||
if (completedAge <= PENDING_STALE_MS && cachedTokens !== undefined && !isTokenExpired) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found non-stale COMPLETED OAuth flow, reusing cached tokens`,
|
||||
);
|
||||
return {
|
||||
tokens: cachedTokens,
|
||||
clientInfo: flowMeta?.clientInfo,
|
||||
metadata: flowMeta?.metadata,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
|
||||
);
|
||||
try {
|
||||
const oldState = flowMeta?.state;
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
if (oldState) {
|
||||
await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
|
||||
}
|
||||
|
|
@ -519,7 +606,9 @@ export class MCPConnectionFactory {
|
|||
);
|
||||
|
||||
// Store flow state BEFORE redirecting so the callback can find it
|
||||
await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata);
|
||||
const metadataWithUrl = { ...flowMetadata, authorizationUrl };
|
||||
await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
|
||||
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager);
|
||||
|
||||
if (typeof this.oauthStart === 'function') {
|
||||
logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { MCPConnection } from './connection';
|
||||
import type * as t from './types';
|
||||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPConnection } from './connection';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
/**
|
||||
|
|
@ -21,6 +21,8 @@ export abstract class UserConnectionManager {
|
|||
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
|
||||
/** Last activity timestamp for users (not per server) */
|
||||
protected userLastActivity: Map<string, number> = new Map();
|
||||
/** In-flight connection promises keyed by `userId:serverName` — coalesces concurrent attempts */
|
||||
protected pendingConnections: Map<string, Promise<MCPConnection>> = new Map();
|
||||
|
||||
/** Updates the last activity timestamp for a user */
|
||||
protected updateUserLastActivity(userId: string): void {
|
||||
|
|
@ -31,29 +33,64 @@ export abstract class UserConnectionManager {
|
|||
);
|
||||
}
|
||||
|
||||
/** Gets or creates a connection for a specific user */
|
||||
public async getUserConnection({
|
||||
serverName,
|
||||
forceNew,
|
||||
user,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
requestBody,
|
||||
tokenMethods,
|
||||
oauthStart,
|
||||
oauthEnd,
|
||||
signal,
|
||||
returnOnOAuth = false,
|
||||
connectionTimeout,
|
||||
}: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
|
||||
/** Gets or creates a connection for a specific user, coalescing concurrent attempts */
|
||||
public async getUserConnection(
|
||||
opts: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
|
||||
): Promise<MCPConnection> {
|
||||
const { serverName, forceNew, user } = opts;
|
||||
const userId = user?.id;
|
||||
if (!userId) {
|
||||
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
|
||||
}
|
||||
|
||||
const lockKey = `${userId}:${serverName}`;
|
||||
|
||||
if (!forceNew) {
|
||||
const pending = this.pendingConnections.get(lockKey);
|
||||
if (pending) {
|
||||
logger.debug(`[MCP][User: ${userId}][${serverName}] Joining in-flight connection attempt`);
|
||||
return pending;
|
||||
}
|
||||
}
|
||||
|
||||
const connectionPromise = this.createUserConnectionInternal(opts, userId);
|
||||
|
||||
if (!forceNew) {
|
||||
this.pendingConnections.set(lockKey, connectionPromise);
|
||||
}
|
||||
|
||||
try {
|
||||
return await connectionPromise;
|
||||
} finally {
|
||||
if (!forceNew && this.pendingConnections.get(lockKey) === connectionPromise) {
|
||||
this.pendingConnections.delete(lockKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async createUserConnectionInternal(
|
||||
{
|
||||
serverName,
|
||||
forceNew,
|
||||
user,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
requestBody,
|
||||
tokenMethods,
|
||||
oauthStart,
|
||||
oauthEnd,
|
||||
signal,
|
||||
returnOnOAuth = false,
|
||||
connectionTimeout,
|
||||
}: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
|
||||
userId: string,
|
||||
): Promise<MCPConnection> {
|
||||
if (await this.appConnections!.has(serverName)) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
|
|
@ -188,6 +225,7 @@ export abstract class UserConnectionManager {
|
|||
|
||||
/** Disconnects and removes a specific user connection */
|
||||
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
|
||||
this.pendingConnections.delete(`${userId}:${serverName}`);
|
||||
const userMap = this.userConnections.get(userId);
|
||||
const connection = userMap?.get(serverName);
|
||||
if (connection) {
|
||||
|
|
@ -215,6 +253,12 @@ export abstract class UserConnectionManager {
|
|||
);
|
||||
}
|
||||
await Promise.allSettled(disconnectPromises);
|
||||
// Clean up any pending connection promises for this user
|
||||
for (const key of this.pendingConnections.keys()) {
|
||||
if (key.startsWith(`${userId}:`)) {
|
||||
this.pendingConnections.delete(key);
|
||||
}
|
||||
}
|
||||
// Ensure user activity timestamp is removed
|
||||
this.userLastActivity.delete(userId);
|
||||
logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ describe('MCPConnectionFactory', () => {
|
|||
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
|
||||
'flow123',
|
||||
'mcp_oauth',
|
||||
mockFlowData.flowMetadata,
|
||||
expect.objectContaining(mockFlowData.flowMetadata),
|
||||
);
|
||||
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
|
||||
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
|
||||
|
|
@ -550,7 +550,7 @@ describe('MCPConnectionFactory', () => {
|
|||
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
|
||||
'flow123',
|
||||
'mcp_oauth',
|
||||
mockFlowData.flowMetadata,
|
||||
expect.objectContaining(mockFlowData.flowMetadata),
|
||||
);
|
||||
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
|
||||
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
|
||||
|
|
|
|||
232
packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts
Normal file
232
packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
/**
|
||||
* Tests for the OAuth callback CSRF fallback logic.
|
||||
*
|
||||
* The callback route validates requests via three mechanisms (in order):
|
||||
* 1. CSRF cookie (HMAC-based, set during initiate)
|
||||
* 2. Session cookie (bound to authenticated userId)
|
||||
* 3. Active PENDING flow in FlowStateManager (fallback for SSE/chat flows)
|
||||
*
|
||||
* This suite tests mechanism 3 — the PENDING flow fallback — including
|
||||
* staleness enforcement and rejection of non-PENDING flows.
|
||||
*
|
||||
* These tests exercise the validation functions directly for fast,
|
||||
* focused coverage. Route-level integration tests using supertest
|
||||
* are in api/server/routes/__tests__/mcp.spec.js ("CSRF fallback
|
||||
* via active PENDING flow" describe block).
|
||||
*/
|
||||
|
||||
import { Keyv } from 'keyv';
|
||||
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
|
||||
import type { Request, Response } from 'express';
|
||||
import {
|
||||
generateOAuthCsrfToken,
|
||||
OAUTH_SESSION_COOKIE,
|
||||
validateOAuthSession,
|
||||
OAUTH_CSRF_COOKIE,
|
||||
validateOAuthCsrf,
|
||||
} from '~/oauth/csrf';
|
||||
import { MockKeyv } from './helpers/oauthTestServer';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
const CSRF_COOKIE_PATH = '/api/mcp';
|
||||
|
||||
function makeReq(cookies: Record<string, string> = {}): Request {
|
||||
return { cookies } as unknown as Request;
|
||||
}
|
||||
|
||||
function makeRes(): Response {
|
||||
const res = {
|
||||
clearCookie: jest.fn(),
|
||||
} as unknown as Response;
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replicate the callback route's three-tier validation logic.
|
||||
* Returns which mechanism (if any) authorized the request.
|
||||
*/
|
||||
async function validateCallback(
|
||||
req: Request,
|
||||
res: Response,
|
||||
flowId: string,
|
||||
flowManager: FlowStateManager,
|
||||
): Promise<'csrf' | 'session' | 'pendingFlow' | false> {
|
||||
const flowUserId = flowId.split(':')[0];
|
||||
|
||||
const hasCsrf = validateOAuthCsrf(req, res, flowId, CSRF_COOKIE_PATH);
|
||||
if (hasCsrf) {
|
||||
return 'csrf';
|
||||
}
|
||||
|
||||
const hasSession = validateOAuthSession(req, flowUserId);
|
||||
if (hasSession) {
|
||||
return 'session';
|
||||
}
|
||||
|
||||
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
|
||||
if (pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS) {
|
||||
return 'pendingFlow';
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
describe('OAuth Callback CSRF Fallback', () => {
|
||||
let flowManager: FlowStateManager;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.JWT_SECRET = 'test-secret-for-csrf';
|
||||
const store = new MockKeyv();
|
||||
flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 300000, ci: true });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
delete process.env.JWT_SECRET;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('CSRF cookie validation (mechanism 1)', () => {
|
||||
it('should accept valid CSRF cookie', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('csrf');
|
||||
});
|
||||
|
||||
it('should reject invalid CSRF cookie', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const req = makeReq({ [OAUTH_CSRF_COOKIE]: 'wrong-token-value' });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Session cookie validation (mechanism 2)', () => {
|
||||
it('should accept valid session cookie when CSRF is absent', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('session');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PENDING flow fallback (mechanism 3)', () => {
|
||||
it('should accept when a fresh PENDING flow exists and no cookies are present', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('pendingFlow');
|
||||
});
|
||||
|
||||
it('should reject when no PENDING flow, no CSRF cookie, and no session cookie', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', { access_token: 'tok' } as never);
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject when only a FAILED flow exists', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
await flowManager.failFlow(flowId, 'mcp_oauth', 'some error');
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
// Artificially age the flow past the staleness threshold
|
||||
const store = (flowManager as unknown as { keyv: { get: (k: string) => Promise<unknown> } })
|
||||
.keyv;
|
||||
const flowState = (await store.get(`mcp_oauth:${flowId}`)) as { createdAt: number };
|
||||
flowState.createdAt = Date.now() - PENDING_STALE_MS - 1000;
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should accept PENDING flow that is just under the staleness threshold', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
// Flow was just created, well under threshold
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('pendingFlow');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Priority ordering', () => {
|
||||
it('should prefer CSRF cookie over PENDING flow', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('csrf');
|
||||
});
|
||||
|
||||
it('should prefer session cookie over PENDING flow when CSRF is absent', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('session');
|
||||
});
|
||||
});
|
||||
});
|
||||
268
packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts
Normal file
268
packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
/**
|
||||
* Tests for MCPConnection OAuth event cycle against a real OAuth-gated MCP server.
|
||||
*
|
||||
* Verifies: oauthRequired emission on 401, oauthHandled reconnection,
|
||||
* oauthFailed rejection, timeout behavior, and token expiry mid-session.
|
||||
*/
|
||||
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
jest.mock('~/auth', () => ({
|
||||
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
|
||||
resolveHostnameSSRF: jest.fn(async () => false),
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/mcpConfig', () => ({
|
||||
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
|
||||
}));
|
||||
|
||||
async function safeDisconnect(conn: MCPConnection | null): Promise<void> {
|
||||
if (!conn) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await conn.disconnect();
|
||||
} catch {
|
||||
// Ignore disconnect errors during cleanup
|
||||
}
|
||||
}
|
||||
|
||||
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
|
||||
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${serverUrl}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const data = (await tokenRes.json()) as { access_token: string };
|
||||
return data.access_token;
|
||||
}
|
||||
|
||||
describe('MCPConnection OAuth Events — Real Server', () => {
|
||||
let server: OAuthTestServer;
|
||||
let connection: MCPConnection | null = null;
|
||||
|
||||
beforeEach(() => {
|
||||
MCPConnection.clearCooldown('test-server');
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await safeDisconnect(connection);
|
||||
connection = null;
|
||||
if (server) {
|
||||
await server.close();
|
||||
}
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('oauthRequired event', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should emit oauthRequired when connecting without a token', async () => {
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
const oauthRequiredPromise = new Promise<{
|
||||
serverName: string;
|
||||
error: Error;
|
||||
serverUrl?: string;
|
||||
userId?: string;
|
||||
}>((resolve) => {
|
||||
connection!.on('oauthRequired', (data) => {
|
||||
resolve(
|
||||
data as {
|
||||
serverName: string;
|
||||
error: Error;
|
||||
serverUrl?: string;
|
||||
userId?: string;
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// Connection will fail with 401, emitting oauthRequired
|
||||
const connectPromise = connection.connect().catch(() => {
|
||||
// Expected to fail since no one handles oauthRequired
|
||||
});
|
||||
|
||||
let raceTimer: NodeJS.Timeout | undefined;
|
||||
const eventData = await Promise.race([
|
||||
oauthRequiredPromise,
|
||||
new Promise<never>((_, reject) => {
|
||||
raceTimer = setTimeout(
|
||||
() => reject(new Error('Timed out waiting for oauthRequired')),
|
||||
10000,
|
||||
);
|
||||
}),
|
||||
]).finally(() => clearTimeout(raceTimer));
|
||||
|
||||
expect(eventData.serverName).toBe('test-server');
|
||||
expect(eventData.error).toBeDefined();
|
||||
|
||||
// Emit oauthFailed to unblock connect()
|
||||
connection.emit('oauthFailed', new Error('test cleanup'));
|
||||
await connectPromise.catch(() => undefined);
|
||||
});
|
||||
|
||||
it('should not emit oauthRequired when connecting with a valid token', async () => {
|
||||
const accessToken = await exchangeCodeForToken(server.url);
|
||||
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
userId: 'user-1',
|
||||
oauthTokens: {
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens,
|
||||
});
|
||||
|
||||
let oauthFired = false;
|
||||
connection.on('oauthRequired', () => {
|
||||
oauthFired = true;
|
||||
});
|
||||
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
expect(oauthFired).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('oauthHandled reconnection', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should succeed on retry after oauthHandled provides valid tokens', async () => {
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
},
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
// First connect fails with 401 → oauthRequired fires
|
||||
let oauthFired = false;
|
||||
connection.on('oauthRequired', () => {
|
||||
oauthFired = true;
|
||||
connection!.emit('oauthFailed', new Error('Will retry with tokens'));
|
||||
});
|
||||
|
||||
// First attempt fails as expected
|
||||
await expect(connection.connect()).rejects.toThrow();
|
||||
expect(oauthFired).toBe(true);
|
||||
|
||||
// Now set valid tokens and reconnect
|
||||
const accessToken = await exchangeCodeForToken(server.url);
|
||||
connection.setOAuthTokens({
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('oauthFailed rejection', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should reject connect() when oauthFailed is emitted', async () => {
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
},
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
connection.on('oauthRequired', () => {
|
||||
connection!.emit('oauthFailed', new Error('User denied OAuth'));
|
||||
});
|
||||
|
||||
await expect(connection.connect()).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Token expiry during session', () => {
|
||||
it('should detect expired token on reconnect and emit oauthRequired', async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 1000 });
|
||||
|
||||
const accessToken = await exchangeCodeForToken(server.url);
|
||||
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
},
|
||||
userId: 'user-1',
|
||||
oauthTokens: {
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens,
|
||||
});
|
||||
|
||||
// Initial connect should succeed
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
await connection.disconnect();
|
||||
|
||||
// Wait for token to expire
|
||||
await new Promise((r) => setTimeout(r, 1200));
|
||||
|
||||
// Reconnect should trigger oauthRequired since token is expired on the server
|
||||
let oauthFired = false;
|
||||
connection.on('oauthRequired', () => {
|
||||
oauthFired = true;
|
||||
connection!.emit('oauthFailed', new Error('Will retry with fresh token'));
|
||||
});
|
||||
|
||||
// First reconnect fails with 401 → oauthRequired
|
||||
await expect(connection.connect()).rejects.toThrow();
|
||||
expect(oauthFired).toBe(true);
|
||||
|
||||
// Get fresh token and reconnect
|
||||
const newToken = await exchangeCodeForToken(server.url);
|
||||
connection.setOAuthTokens({
|
||||
access_token: newToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
538
packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts
Normal file
538
packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts
Normal file
|
|
@ -0,0 +1,538 @@
|
|||
/**
|
||||
* OAuth flow tests against a real HTTP server.
|
||||
*
|
||||
* Tests MCPOAuthHandler.refreshOAuthTokens and MCPTokenStorage lifecycle
|
||||
* using a real test OAuth server (not mocked SDK functions).
|
||||
*/
|
||||
|
||||
import { createHash } from 'crypto';
|
||||
import { Keyv } from 'keyv';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
describe('MCP OAuth Flow — Real HTTP Server', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Token refresh with real server', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should refresh tokens with stored client info via real /token endpoint', async () => {
|
||||
// First get initial tokens
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
expect(initial.refresh_token).toBeDefined();
|
||||
|
||||
// Register a client so we have clientInfo
|
||||
const regRes = await fetch(`${server.url}register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
|
||||
});
|
||||
const clientInfo = (await regRes.json()) as {
|
||||
client_id: string;
|
||||
client_secret: string;
|
||||
};
|
||||
|
||||
// Refresh tokens using the real endpoint
|
||||
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
initial.refresh_token,
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: server.url,
|
||||
clientInfo: {
|
||||
...clientInfo,
|
||||
redirect_uris: ['http://localhost/callback'],
|
||||
},
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: `${server.url}token`,
|
||||
client_id: clientInfo.client_id,
|
||||
client_secret: clientInfo.client_secret,
|
||||
token_exchange_method: 'DefaultPost',
|
||||
},
|
||||
);
|
||||
|
||||
expect(refreshed.access_token).toBeDefined();
|
||||
expect(refreshed.access_token).not.toBe(initial.access_token);
|
||||
expect(refreshed.token_type).toBe('Bearer');
|
||||
expect(refreshed.obtained_at).toBeDefined();
|
||||
});
|
||||
|
||||
it('should get new refresh token when server rotates', async () => {
|
||||
const rotatingServer = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
rotateRefreshTokens: true,
|
||||
});
|
||||
|
||||
try {
|
||||
const code = await rotatingServer.getAuthCode();
|
||||
const tokenRes = await fetch(`${rotatingServer.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
initial.refresh_token,
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: rotatingServer.url,
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: `${rotatingServer.url}token`,
|
||||
client_id: 'anon',
|
||||
token_exchange_method: 'DefaultPost',
|
||||
},
|
||||
);
|
||||
|
||||
expect(refreshed.access_token).not.toBe(initial.access_token);
|
||||
expect(refreshed.refresh_token).toBeDefined();
|
||||
expect(refreshed.refresh_token).not.toBe(initial.refresh_token);
|
||||
} finally {
|
||||
await rotatingServer.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should fail refresh with invalid refresh token', async () => {
|
||||
await expect(
|
||||
MCPOAuthHandler.refreshOAuthTokens(
|
||||
'invalid-refresh-token',
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: server.url,
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: `${server.url}token`,
|
||||
client_id: 'anon',
|
||||
token_exchange_method: 'DefaultPost',
|
||||
},
|
||||
),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('OAuth server metadata discovery', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ issueRefreshTokens: true });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should expose /.well-known/oauth-authorization-server', async () => {
|
||||
const res = await fetch(`${server.url}.well-known/oauth-authorization-server`);
|
||||
expect(res.status).toBe(200);
|
||||
|
||||
const metadata = (await res.json()) as {
|
||||
authorization_endpoint: string;
|
||||
token_endpoint: string;
|
||||
registration_endpoint: string;
|
||||
grant_types_supported: string[];
|
||||
};
|
||||
|
||||
expect(metadata.authorization_endpoint).toContain('/authorize');
|
||||
expect(metadata.token_endpoint).toContain('/token');
|
||||
expect(metadata.registration_endpoint).toContain('/register');
|
||||
expect(metadata.grant_types_supported).toContain('authorization_code');
|
||||
expect(metadata.grant_types_supported).toContain('refresh_token');
|
||||
});
|
||||
|
||||
it('should not advertise refresh_token grant when disabled', async () => {
|
||||
const noRefreshServer = await createOAuthMCPServer({
|
||||
issueRefreshTokens: false,
|
||||
});
|
||||
try {
|
||||
const res = await fetch(`${noRefreshServer.url}.well-known/oauth-authorization-server`);
|
||||
const metadata = (await res.json()) as { grant_types_supported: string[] };
|
||||
expect(metadata.grant_types_supported).not.toContain('refresh_token');
|
||||
} finally {
|
||||
await noRefreshServer.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('Dynamic client registration', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should register a client via /register endpoint', async () => {
|
||||
const res = await fetch(`${server.url}register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
redirect_uris: ['http://localhost/callback'],
|
||||
}),
|
||||
});
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
const client = (await res.json()) as {
|
||||
client_id: string;
|
||||
client_secret: string;
|
||||
redirect_uris: string[];
|
||||
};
|
||||
|
||||
expect(client.client_id).toBeDefined();
|
||||
expect(client.client_secret).toBeDefined();
|
||||
expect(client.redirect_uris).toEqual(['http://localhost/callback']);
|
||||
expect(server.registeredClients.has(client.client_id)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('End-to-End: store, retrieve, expire, refresh cycle', () => {
|
||||
it('should perform full token lifecycle with real server', async () => {
|
||||
const server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 1000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
const tokenStore = new InMemoryTokenStore();
|
||||
|
||||
try {
|
||||
// 1. Get initial tokens via auth code exchange
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// 2. Store tokens
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
tokens: initial,
|
||||
createToken: tokenStore.createToken,
|
||||
});
|
||||
|
||||
// 3. Retrieve — should succeed
|
||||
const valid = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
expect(valid).not.toBeNull();
|
||||
expect(valid!.access_token).toBe(initial.access_token);
|
||||
expect(valid!.refresh_token).toBe(initial.refresh_token);
|
||||
|
||||
// 4. Wait for expiry
|
||||
await new Promise((r) => setTimeout(r, 1200));
|
||||
|
||||
// 5. Retrieve again — should trigger refresh via callback
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const refreshRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
|
||||
if (!refreshRes.ok) {
|
||||
throw new Error(`Refresh failed: ${refreshRes.status}`);
|
||||
}
|
||||
|
||||
const data = (await refreshRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
return {
|
||||
...data,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + data.expires_in * 1000,
|
||||
};
|
||||
};
|
||||
|
||||
const refreshed = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(refreshed).not.toBeNull();
|
||||
expect(refreshed!.access_token).not.toBe(initial.access_token);
|
||||
|
||||
// 6. Verify the refreshed token works against the server
|
||||
const mcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${refreshed!.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
expect(mcpRes.status).toBe(200);
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('completeOAuthFlow via FlowStateManager', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ issueRefreshTokens: true });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should exchange auth code and complete flow in FlowStateManager', async () => {
|
||||
const store = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'test-user:test-server';
|
||||
const code = await server.getAuthCode();
|
||||
|
||||
// Initialize the flow with metadata the handler needs
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverUrl: server.url,
|
||||
clientInfo: {
|
||||
client_id: 'test-client',
|
||||
redirect_uris: ['http://localhost/callback'],
|
||||
},
|
||||
codeVerifier: 'test-verifier',
|
||||
metadata: {
|
||||
token_endpoint: `${server.url}token`,
|
||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
},
|
||||
});
|
||||
|
||||
// The SDK's exchangeAuthorization wants full OAuth metadata,
|
||||
// so we'll test the token exchange directly instead of going through
|
||||
// completeOAuthFlow (which requires full SDK-compatible metadata)
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
const tokens = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
const mcpTokens: MCPOAuthTokens = {
|
||||
...tokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + tokens.expires_in * 1000,
|
||||
};
|
||||
|
||||
// Complete the flow
|
||||
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
|
||||
expect(completed).toBe(true);
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('COMPLETED');
|
||||
expect(state?.result?.access_token).toBe(tokens.access_token);
|
||||
});
|
||||
|
||||
it('should fail flow when authorization code is invalid', async () => {
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: 'grant_type=authorization_code&code=invalid-code',
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(400);
|
||||
const body = (await tokenRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
|
||||
it('should fail when authorization code is reused', async () => {
|
||||
const code = await server.getAuthCode();
|
||||
|
||||
// First exchange succeeds
|
||||
const firstRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
expect(firstRes.status).toBe(200);
|
||||
|
||||
// Second exchange fails
|
||||
const secondRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
expect(secondRes.status).toBe(400);
|
||||
const body = (await secondRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PKCE verification', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
function generatePKCE(): { verifier: string; challenge: string } {
|
||||
const verifier = 'dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk';
|
||||
const challenge = createHash('sha256').update(verifier).digest('base64url');
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
||||
it('should accept valid code_verifier matching code_challenge', async () => {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}&code_verifier=${verifier}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
const data = (await tokenRes.json()) as { access_token: string };
|
||||
expect(data.access_token).toBeDefined();
|
||||
});
|
||||
|
||||
it('should reject wrong code_verifier', async () => {
|
||||
const { challenge } = generatePKCE();
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}&code_verifier=wrong-verifier`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(400);
|
||||
const body = (await tokenRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
|
||||
it('should reject missing code_verifier when code_challenge was provided', async () => {
|
||||
const { challenge } = generatePKCE();
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(400);
|
||||
const body = (await tokenRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
|
||||
it('should still accept codes without PKCE when no code_challenge was provided', async () => {
|
||||
const code = await server.getAuthCode();
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
});
|
||||
});
|
||||
});
|
||||
516
packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts
Normal file
516
packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts
Normal file
|
|
@ -0,0 +1,516 @@
|
|||
/**
|
||||
* Tests for MCP OAuth race condition fixes:
|
||||
*
|
||||
* 1. Connection mutex coalesces concurrent getUserConnection() calls
|
||||
* 2. PENDING OAuth flows are reused, not deleted
|
||||
* 3. No-refresh-token expiry throws ReauthenticationRequiredError
|
||||
* 4. completeFlow recovers when flow state was deleted by a race
|
||||
* 5. monitorFlow retries once when flow state disappears mid-poll
|
||||
*/
|
||||
|
||||
import { Keyv } from 'keyv';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { MockKeyv, createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
jest.mock('~/auth', () => ({
|
||||
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
|
||||
resolveHostnameSSRF: jest.fn(async () => false),
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/mcpConfig', () => ({
|
||||
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
|
||||
}));
|
||||
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
|
||||
describe('MCP OAuth Race Condition Fixes', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Fix 1: Connection mutex coalesces concurrent attempts', () => {
|
||||
it('should return the same pending promise for concurrent getUserConnection calls', async () => {
|
||||
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
|
||||
|
||||
class TestManager extends UserConnectionManager {
|
||||
public createCallCount = 0;
|
||||
|
||||
getPendingConnections() {
|
||||
return this.pendingConnections;
|
||||
}
|
||||
}
|
||||
|
||||
const manager = new TestManager();
|
||||
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
isStale: jest.fn().mockReturnValue(false),
|
||||
};
|
||||
|
||||
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
|
||||
manager.appConnections = mockAppConnections as never;
|
||||
|
||||
const mockConfig = {
|
||||
type: 'streamable-http',
|
||||
url: 'http://localhost:9999/',
|
||||
updatedAt: undefined,
|
||||
dbId: undefined,
|
||||
};
|
||||
|
||||
jest
|
||||
.spyOn(
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
|
||||
'getInstance',
|
||||
)
|
||||
.mockReturnValue({
|
||||
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
});
|
||||
|
||||
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
|
||||
const createSpy = jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
|
||||
manager.createCallCount++;
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
return mockConnection as never;
|
||||
});
|
||||
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
const user = { id: 'user-1' };
|
||||
const opts = {
|
||||
serverName: 'test-server',
|
||||
user: user as never,
|
||||
flowManager: flowManager as never,
|
||||
};
|
||||
|
||||
const [conn1, conn2, conn3] = await Promise.all([
|
||||
manager.getUserConnection(opts),
|
||||
manager.getUserConnection(opts),
|
||||
manager.getUserConnection(opts),
|
||||
]);
|
||||
|
||||
expect(conn1).toBe(conn2);
|
||||
expect(conn2).toBe(conn3);
|
||||
expect(createSpy).toHaveBeenCalledTimes(1);
|
||||
expect(manager.createCallCount).toBe(1);
|
||||
|
||||
createSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should not coalesce when forceNew is true', async () => {
|
||||
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
|
||||
|
||||
class TestManager extends UserConnectionManager {}
|
||||
|
||||
const manager = new TestManager();
|
||||
|
||||
let callCount = 0;
|
||||
const makeConnection = () => ({
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
isStale: jest.fn().mockReturnValue(false),
|
||||
});
|
||||
|
||||
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
|
||||
manager.appConnections = mockAppConnections as never;
|
||||
|
||||
const mockConfig = {
|
||||
type: 'streamable-http',
|
||||
url: 'http://localhost:9999/',
|
||||
updatedAt: undefined,
|
||||
dbId: undefined,
|
||||
};
|
||||
|
||||
jest
|
||||
.spyOn(
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
|
||||
'getInstance',
|
||||
)
|
||||
.mockReturnValue({
|
||||
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
});
|
||||
|
||||
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
|
||||
jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
|
||||
callCount++;
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
return makeConnection() as never;
|
||||
});
|
||||
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
const user = { id: 'user-2' };
|
||||
|
||||
const [conn1, conn2] = await Promise.all([
|
||||
manager.getUserConnection({
|
||||
serverName: 'test-server',
|
||||
forceNew: true,
|
||||
user: user as never,
|
||||
flowManager: flowManager as never,
|
||||
}),
|
||||
manager.getUserConnection({
|
||||
serverName: 'test-server',
|
||||
forceNew: true,
|
||||
user: user as never,
|
||||
flowManager: flowManager as never,
|
||||
}),
|
||||
]);
|
||||
|
||||
expect(callCount).toBe(2);
|
||||
expect(conn1).not.toBe(conn2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fix 2: PENDING flow is reused, not deleted', () => {
|
||||
it('should join an existing PENDING flow via createFlow instead of deleting it', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'test-flow-pending';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
clientInfo: { client_id: 'test-client' },
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('PENDING');
|
||||
|
||||
const deleteSpy = jest.spyOn(flowManager, 'deleteFlow');
|
||||
|
||||
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
|
||||
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', {
|
||||
access_token: 'test-token',
|
||||
token_type: 'Bearer',
|
||||
} as never);
|
||||
|
||||
const result = await monitorPromise;
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({ access_token: 'test-token', token_type: 'Bearer' }),
|
||||
);
|
||||
expect(deleteSpy).not.toHaveBeenCalled();
|
||||
|
||||
deleteSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should delete and recreate FAILED flows', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'test-flow-failed';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
await flowManager.failFlow(flowId, 'mcp_oauth', 'previous error');
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('FAILED');
|
||||
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
const afterDelete = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(afterDelete).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fix 3: completeFlow handles deleted state gracefully', () => {
|
||||
it('should return false when state was deleted by race', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'race-deleted-flow';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
const stateBeforeComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(stateBeforeComplete).toBeUndefined();
|
||||
|
||||
const result = await flowManager.completeFlow(flowId, 'mcp_oauth', {
|
||||
access_token: 'recovered-token',
|
||||
token_type: 'Bearer',
|
||||
} as never);
|
||||
|
||||
expect(result).toBe(false);
|
||||
|
||||
const stateAfterComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(stateAfterComplete).toBeUndefined();
|
||||
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('cannot recover metadata'),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject monitorFlow when state is deleted and not recoverable', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'monitor-retry-flow';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
|
||||
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
|
||||
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
await expect(monitorPromise).rejects.toThrow('Flow state not found');
|
||||
});
|
||||
});
|
||||
|
||||
describe('State mapping cleanup on flow replacement', () => {
|
||||
it('should delete old state mapping when a flow is replaced', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager<MCPOAuthTokens | null>(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
const oldState = 'old-random-state-abc123';
|
||||
const newState = 'new-random-state-xyz789';
|
||||
|
||||
// Simulate initial flow with state mapping
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { state: oldState });
|
||||
await MCPOAuthHandler.storeStateMapping(oldState, flowId, flowManager);
|
||||
|
||||
// Old state should resolve
|
||||
const resolvedBefore = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
|
||||
expect(resolvedBefore).toBe(flowId);
|
||||
|
||||
// Replace the flow: delete old, create new, clean up old state mapping
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
await MCPOAuthHandler.deleteStateMapping(oldState, flowManager);
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { state: newState });
|
||||
await MCPOAuthHandler.storeStateMapping(newState, flowId, flowManager);
|
||||
|
||||
// Old state should no longer resolve
|
||||
const resolvedOld = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
|
||||
expect(resolvedOld).toBeNull();
|
||||
|
||||
// New state should resolve
|
||||
const resolvedNew = await MCPOAuthHandler.resolveStateToFlowId(newState, flowManager);
|
||||
expect(resolvedNew).toBe(flowId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fix 4: ReauthenticationRequiredError for no-refresh-token', () => {
|
||||
it('should throw ReauthenticationRequiredError when access token expired and no refresh token', async () => {
|
||||
const expiredDate = new Date(Date.now() - 60000);
|
||||
|
||||
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
|
||||
if (filter.type === 'mcp_oauth') {
|
||||
return {
|
||||
token: 'enc:expired-access-token',
|
||||
expiresAt: expiredDate,
|
||||
createdAt: new Date(Date.now() - 120000),
|
||||
};
|
||||
}
|
||||
if (filter.type === 'mcp_oauth_refresh') {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
}),
|
||||
).rejects.toThrow('Re-authentication required');
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when access token is missing and no refresh token', async () => {
|
||||
const findToken = jest.fn().mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should not throw when access token is valid', async () => {
|
||||
const futureDate = new Date(Date.now() + 3600000);
|
||||
|
||||
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
|
||||
if (filter.type === 'mcp_oauth') {
|
||||
return {
|
||||
token: 'enc:valid-access-token',
|
||||
expiresAt: futureDate,
|
||||
createdAt: new Date(),
|
||||
};
|
||||
}
|
||||
if (filter.type === 'mcp_oauth_refresh') {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.access_token).toBe('valid-access-token');
|
||||
});
|
||||
});
|
||||
|
||||
describe('E2E: OAuth-gated MCP server with no refresh tokens', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should start OAuth-gated MCP server that validates Bearer tokens', async () => {
|
||||
const res = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'initialize', id: 1 }),
|
||||
});
|
||||
|
||||
expect(res.status).toBe(401);
|
||||
const body = (await res.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_token');
|
||||
});
|
||||
|
||||
it('should issue tokens via authorization code exchange with no refresh token', async () => {
|
||||
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
|
||||
expect(authRes.status).toBe(302);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
expect(code).toBeTruthy();
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
const tokenBody = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
refresh_token?: string;
|
||||
};
|
||||
expect(tokenBody.access_token).toBeTruthy();
|
||||
expect(tokenBody.token_type).toBe('Bearer');
|
||||
expect(tokenBody.refresh_token).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should allow MCP requests with valid Bearer token', async () => {
|
||||
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
const { access_token } = (await tokenRes.json()) as { access_token: string };
|
||||
|
||||
const mcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
expect(mcpRes.status).toBe(200);
|
||||
});
|
||||
|
||||
it('should reject expired tokens with 401', async () => {
|
||||
const shortTTLServer = await createOAuthMCPServer({ tokenTTLMs: 500 });
|
||||
|
||||
try {
|
||||
const authRes = await fetch(
|
||||
`${shortTTLServer.url}authorize?redirect_uri=http://localhost&state=s1`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
|
||||
const tokenRes = await fetch(`${shortTTLServer.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
const { access_token } = (await tokenRes.json()) as { access_token: string };
|
||||
|
||||
await new Promise((r) => setTimeout(r, 600));
|
||||
|
||||
const mcpRes = await fetch(shortTTLServer.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${access_token}`,
|
||||
},
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 2 }),
|
||||
});
|
||||
|
||||
expect(mcpRes.status).toBe(401);
|
||||
} finally {
|
||||
await shortTTLServer.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
654
packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts
Normal file
654
packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts
Normal file
|
|
@ -0,0 +1,654 @@
|
|||
/**
|
||||
* Tests for MCP OAuth token expiry → re-authentication scenarios.
|
||||
*
|
||||
* Reproduces the edge case where:
|
||||
* 1. Tokens are stored (access + refresh)
|
||||
* 2. Access token expires
|
||||
* 3. Refresh attempt fails (server rejects/revokes refresh token)
|
||||
* 4. System must fall back to full OAuth re-auth via handleOAuthRequired
|
||||
* 5. The CSRF cookie may be absent (chat/SSE flow), so the PENDING flow fallback is needed
|
||||
*
|
||||
* Also tests the happy path: access token expired but refresh succeeds.
|
||||
*/
|
||||
|
||||
import { Keyv } from 'keyv';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
|
||||
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { MockKeyv, InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
describe('MCP OAuth Token Expiry Scenarios', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Access token expired + refresh token available + refresh succeeds', () => {
|
||||
let server: OAuthTestServer;
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 500,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should refresh expired access token via real /token endpoint', async () => {
|
||||
// Get initial tokens from real server
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// Store expired access token directly (bypassing storeTokens' expiresIn clamping)
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: `enc:${initial.access_token}`,
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: `enc:${initial.refresh_token}`,
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const res = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(`Refresh failed: ${res.status}`);
|
||||
}
|
||||
const data = (await res.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
};
|
||||
return {
|
||||
...data,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + data.expires_in * 1000,
|
||||
};
|
||||
};
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.access_token).not.toBe(initial.access_token);
|
||||
|
||||
// Verify the refreshed token works against the server
|
||||
const mcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${result!.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
expect(mcpRes.status).toBe(200);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Access token expired + refresh token rejected by server', () => {
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(() => {
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
it('should return null when refresh token is rejected (invalid_grant)', async () => {
|
||||
const server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
|
||||
try {
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// Store expired access token directly
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: `enc:${initial.access_token}`,
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: `enc:${initial.refresh_token}`,
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
// Simulate server revoking the refresh token
|
||||
server.issuedRefreshTokens.clear();
|
||||
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const res = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
if (!res.ok) {
|
||||
const body = (await res.json()) as { error: string };
|
||||
throw new Error(`Token refresh failed: ${body.error}`);
|
||||
}
|
||||
const data = (await res.json()) as MCPOAuthTokens;
|
||||
return data;
|
||||
};
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to refresh tokens'),
|
||||
expect.any(Error),
|
||||
);
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should return null when refresh endpoint returns unauthorized_client', async () => {
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: 'enc:some-refresh-token',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshCallback = jest
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('unauthorized_client: client not authorized for refresh'));
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('does not support refresh tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Access token expired + NO refresh token → ReauthenticationRequiredError', () => {
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(() => {
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when no refresh token stored', async () => {
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError with correct reason for expired token', async () => {
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
}),
|
||||
).rejects.toThrow('access token expired');
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError with correct reason for missing token', async () => {
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
}),
|
||||
).rejects.toThrow('access token missing');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PENDING flow fallback for CSRF-less OAuth callbacks', () => {
|
||||
it('should allow OAuth completion when PENDING flow exists (simulating chat/SSE path)', async () => {
|
||||
const store = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
userId: 'user1',
|
||||
serverUrl: 'https://example.com',
|
||||
state: 'test-state',
|
||||
authorizationUrl: 'https://example.com/authorize?state=user1:test-server',
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('PENDING');
|
||||
|
||||
const tokens: MCPOAuthTokens = {
|
||||
access_token: 'new-access-token',
|
||||
token_type: 'Bearer',
|
||||
refresh_token: 'new-refresh-token',
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + 3600000,
|
||||
};
|
||||
|
||||
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', tokens);
|
||||
expect(completed).toBe(true);
|
||||
|
||||
const completedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(completedState?.status).toBe('COMPLETED');
|
||||
expect((completedState?.result as MCPOAuthTokens | undefined)?.access_token).toBe(
|
||||
'new-access-token',
|
||||
);
|
||||
});
|
||||
|
||||
it('should store authorizationUrl in flow metadata for re-issuance', async () => {
|
||||
const store = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
const authUrl = 'https://auth.example.com/authorize?client_id=abc&state=user1:test-server';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
userId: 'user1',
|
||||
serverUrl: 'https://example.com',
|
||||
state: 'test-state',
|
||||
authorizationUrl: authUrl,
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect((state?.metadata as Record<string, unknown>)?.authorizationUrl).toBe(authUrl);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Full token expiry → refresh failure → re-auth flow', () => {
|
||||
let server: OAuthTestServer;
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should go through full cycle: get tokens → expire → refresh fails → re-auth needed', async () => {
|
||||
// Step 1: Get initial tokens
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// Step 2: Store tokens with valid expiry first
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
tokens: initial,
|
||||
createToken: tokenStore.createToken,
|
||||
});
|
||||
|
||||
// Step 3: Verify tokens work
|
||||
const validResult = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
expect(validResult).not.toBeNull();
|
||||
expect(validResult!.access_token).toBe(initial.access_token);
|
||||
|
||||
// Step 4: Simulate token expiry by directly updating the stored token's expiresAt
|
||||
await tokenStore.updateToken({ userId: 'u1', identifier: 'mcp:test-srv' }, { expiresIn: -1 });
|
||||
|
||||
// Step 5: Revoke refresh token on server side (simulating server-side revocation)
|
||||
server.issuedRefreshTokens.clear();
|
||||
|
||||
// Step 6: Try to get tokens — refresh should fail, return null
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const res = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
if (!res.ok) {
|
||||
const body = (await res.json()) as { error: string };
|
||||
throw new Error(`Refresh failed: ${body.error}`);
|
||||
}
|
||||
const data = (await res.json()) as MCPOAuthTokens;
|
||||
return data;
|
||||
};
|
||||
|
||||
const expiredResult = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
// Refresh failed → returns null → triggers OAuth re-auth flow
|
||||
expect(expiredResult).toBeNull();
|
||||
|
||||
// Step 7: Simulate the re-auth flow via FlowStateManager
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
const flowId = 'u1:test-srv';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-srv',
|
||||
userId: 'u1',
|
||||
serverUrl: server.url,
|
||||
state: 'test-state',
|
||||
authorizationUrl: `${server.url}authorize?state=${flowId}`,
|
||||
});
|
||||
|
||||
// Step 8: Get a new auth code and exchange for tokens (simulating user re-auth)
|
||||
const newCode = await server.getAuthCode();
|
||||
const newTokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${newCode}`,
|
||||
});
|
||||
const newTokens = (await newTokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
// Step 9: Complete the flow
|
||||
const mcpTokens: MCPOAuthTokens = {
|
||||
...newTokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + newTokens.expires_in * 1000,
|
||||
};
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
|
||||
|
||||
// Step 10: Store the new tokens
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
tokens: mcpTokens,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
|
||||
// Step 11: Verify new tokens work
|
||||
const newResult = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
expect(newResult).not.toBeNull();
|
||||
expect(newResult!.access_token).toBe(newTokens.access_token);
|
||||
|
||||
// Step 12: Verify new token works against server
|
||||
const finalMcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${newResult!.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
expect(finalMcpRes.status).toBe(200);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Concurrent token expiry with connection mutex', () => {
|
||||
it('should handle multiple concurrent getTokens calls when token is expired', async () => {
|
||||
const tokenStore = new InMemoryTokenStore();
|
||||
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: 'enc:valid-refresh',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
let refreshCallCount = 0;
|
||||
const refreshCallback = jest.fn().mockImplementation(async () => {
|
||||
refreshCallCount++;
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
return {
|
||||
access_token: `refreshed-token-${refreshCallCount}`,
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + 3600000,
|
||||
};
|
||||
});
|
||||
|
||||
// Fire 3 concurrent getTokens calls via FlowStateManager (like the connection mutex does)
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const getTokensViaFlow = () =>
|
||||
flowManager.createFlowWithHandler('u1:test-srv', 'mcp_get_tokens', async () => {
|
||||
return await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
});
|
||||
|
||||
const [r1, r2, r3] = await Promise.all([
|
||||
getTokensViaFlow(),
|
||||
getTokensViaFlow(),
|
||||
getTokensViaFlow(),
|
||||
]);
|
||||
|
||||
// All should get tokens (either directly or via flow coalescing)
|
||||
expect(r1).not.toBeNull();
|
||||
expect(r2).not.toBeNull();
|
||||
expect(r3).not.toBeNull();
|
||||
|
||||
// The refresh callback should only be called once due to flow coalescing
|
||||
expect(refreshCallback).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Stale PENDING flow detection', () => {
|
||||
it('should treat PENDING flows older than 2 minutes as stale', async () => {
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 300000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
authorizationUrl: 'https://example.com/auth',
|
||||
});
|
||||
|
||||
// Manually age the flow to 3 minutes
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (state) {
|
||||
state.createdAt = Date.now() - 3 * 60 * 1000;
|
||||
await (flowStore as unknown as { set: (k: string, v: unknown) => Promise<void> }).set(
|
||||
`mcp_oauth:${flowId}`,
|
||||
state,
|
||||
);
|
||||
}
|
||||
|
||||
const agedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(agedState?.status).toBe('PENDING');
|
||||
|
||||
const age = agedState?.createdAt ? Date.now() - agedState.createdAt : 0;
|
||||
expect(age).toBeGreaterThan(2 * 60 * 1000);
|
||||
|
||||
// A new flow should be created (the stale one would be deleted + recreated)
|
||||
// This verifies our staleness check threshold
|
||||
expect(age > PENDING_STALE_MS).toBe(true);
|
||||
});
|
||||
|
||||
it('should not treat recent PENDING flows as stale', async () => {
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 300000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
authorizationUrl: 'https://example.com/auth',
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
const age = state?.createdAt ? Date.now() - state.createdAt : Infinity;
|
||||
|
||||
expect(age < PENDING_STALE_MS).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
544
packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts
Normal file
544
packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
/**
|
||||
* Integration tests for MCPTokenStorage.storeTokens() and MCPTokenStorage.getTokens().
|
||||
*
|
||||
* Uses InMemoryTokenStore to exercise encrypt/decrypt round-trips, expiry calculation,
|
||||
* refresh callback wiring, and ReauthenticationRequiredError paths.
|
||||
*/
|
||||
|
||||
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { InMemoryTokenStore } from './helpers/oauthTestServer';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
describe('MCPTokenStorage', () => {
|
||||
let store: InMemoryTokenStore;
|
||||
|
||||
beforeEach(() => {
|
||||
store = new InMemoryTokenStore();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('storeTokens', () => {
|
||||
it('should create new access token with expires_in', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved).not.toBeNull();
|
||||
expect(saved!.token).toBe('enc:at1');
|
||||
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
|
||||
expect(expiresInMs).toBeGreaterThan(3500 * 1000);
|
||||
expect(expiresInMs).toBeLessThanOrEqual(3600 * 1000);
|
||||
});
|
||||
|
||||
it('should create new access token with expires_at (MCPOAuthTokens format)', async () => {
|
||||
const expiresAt = Date.now() + 7200 * 1000;
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'at1',
|
||||
token_type: 'Bearer',
|
||||
expires_at: expiresAt,
|
||||
obtained_at: Date.now(),
|
||||
},
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved).not.toBeNull();
|
||||
const diff = Math.abs(saved!.expiresAt.getTime() - expiresAt);
|
||||
expect(diff).toBeLessThan(2000);
|
||||
});
|
||||
|
||||
it('should default to 1-year expiry when none provided', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer' },
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
|
||||
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
|
||||
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
|
||||
});
|
||||
|
||||
it('should update existing access token', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:old-token',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'new-token', token_type: 'Bearer', expires_in: 7200 },
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved!.token).toBe('enc:new-token');
|
||||
});
|
||||
|
||||
it('should store refresh token alongside access token', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'at1',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'rt1',
|
||||
},
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const refreshSaved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
});
|
||||
expect(refreshSaved).not.toBeNull();
|
||||
expect(refreshSaved!.token).toBe('enc:rt1');
|
||||
});
|
||||
|
||||
it('should skip refresh token when not in response', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const refreshSaved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
});
|
||||
expect(refreshSaved).toBeNull();
|
||||
});
|
||||
|
||||
it('should store client info when provided', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] },
|
||||
});
|
||||
|
||||
const clientSaved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: 'mcp:srv1:client',
|
||||
});
|
||||
expect(clientSaved).not.toBeNull();
|
||||
expect(clientSaved!.token).toContain('enc:');
|
||||
expect(clientSaved!.token).toContain('cid');
|
||||
});
|
||||
|
||||
it('should use existingTokens to skip DB lookups', async () => {
|
||||
const findSpy = jest.fn();
|
||||
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
findToken: findSpy,
|
||||
existingTokens: {
|
||||
accessToken: null,
|
||||
refreshToken: null,
|
||||
clientInfoToken: null,
|
||||
},
|
||||
});
|
||||
|
||||
expect(findSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle invalid NaN expiry date', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'at1',
|
||||
token_type: 'Bearer',
|
||||
expires_at: NaN,
|
||||
obtained_at: Date.now(),
|
||||
},
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved).not.toBeNull();
|
||||
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
|
||||
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
|
||||
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokens', () => {
|
||||
it('should return valid non-expired tokens', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:valid-token',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.access_token).toBe('valid-token');
|
||||
expect(result!.token_type).toBe('Bearer');
|
||||
});
|
||||
|
||||
it('should return tokens with refresh token when available', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:at',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result!.refresh_token).toBe('rt');
|
||||
});
|
||||
|
||||
it('should return tokens without refresh token field when none stored', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:at',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result!.refresh_token).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when expired and no refresh', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when missing and no refresh', async () => {
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should refresh expired access token when refresh token and callback are available', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockResolvedValue({
|
||||
access_token: 'refreshed-at',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + 3600000,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.access_token).toBe('refreshed-at');
|
||||
expect(refreshTokens).toHaveBeenCalledWith(
|
||||
'rt',
|
||||
expect.objectContaining({ userId: 'u1', serverName: 'srv1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null when refresh fails', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockRejectedValue(new Error('refresh failed'));
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when no refreshTokens callback provided', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when no createToken callback provided', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
refreshTokens: jest.fn(),
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should pass client info to refreshTokens metadata', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: 'mcp:srv1:client',
|
||||
token: 'enc:{"client_id":"cid","client_secret":"csec"}',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockResolvedValue({
|
||||
access_token: 'new-at',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
});
|
||||
|
||||
await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(refreshTokens).toHaveBeenCalledWith(
|
||||
'rt',
|
||||
expect.objectContaining({
|
||||
clientInfo: expect.objectContaining({ client_id: 'cid' }),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle unauthorized_client refresh error', async () => {
|
||||
const { logger } = await import('@librechat/data-schemas');
|
||||
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockRejectedValue(new Error('unauthorized_client'));
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('does not support refresh tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('storeTokens + getTokens round-trip', () => {
|
||||
it('should store and retrieve tokens with full encrypt/decrypt cycle', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'my-access-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'my-refresh-token',
|
||||
},
|
||||
createToken: store.createToken,
|
||||
clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] },
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result!.access_token).toBe('my-access-token');
|
||||
expect(result!.refresh_token).toBe('my-refresh-token');
|
||||
expect(result!.token_type).toBe('Bearer');
|
||||
expect(result!.obtained_at).toBeDefined();
|
||||
expect(result!.expires_at).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
449
packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts
Normal file
449
packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import { randomUUID, createHash } from 'crypto';
|
||||
import { z } from 'zod';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import type { FlowState } from '~/flow/types';
|
||||
import type { Socket } from 'net';
|
||||
|
||||
export class MockKeyv<T = unknown> {
|
||||
private store: Map<string, FlowState<T>>;
|
||||
|
||||
constructor() {
|
||||
this.store = new Map();
|
||||
}
|
||||
|
||||
async get(key: string): Promise<FlowState<T> | undefined> {
|
||||
return this.store.get(key);
|
||||
}
|
||||
|
||||
async set(key: string, value: FlowState<T>, _ttl?: number): Promise<true> {
|
||||
this.store.set(key, value);
|
||||
return true;
|
||||
}
|
||||
|
||||
async delete(key: string): Promise<boolean> {
|
||||
return this.store.delete(key);
|
||||
}
|
||||
}
|
||||
|
||||
export function getFreePort(): Promise<number> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const srv = net.createServer();
|
||||
srv.listen(0, '127.0.0.1', () => {
|
||||
const addr = srv.address() as net.AddressInfo;
|
||||
srv.close((err) => (err ? reject(err) : resolve(addr.port)));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export function trackSockets(httpServer: http.Server): () => Promise<void> {
|
||||
const sockets = new Set<Socket>();
|
||||
httpServer.on('connection', (socket: Socket) => {
|
||||
sockets.add(socket);
|
||||
socket.once('close', () => sockets.delete(socket));
|
||||
});
|
||||
return () =>
|
||||
new Promise<void>((resolve) => {
|
||||
for (const socket of sockets) {
|
||||
socket.destroy();
|
||||
}
|
||||
sockets.clear();
|
||||
httpServer.close(() => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
export interface OAuthTestServerOptions {
|
||||
tokenTTLMs?: number;
|
||||
issueRefreshTokens?: boolean;
|
||||
refreshTokenTTLMs?: number;
|
||||
rotateRefreshTokens?: boolean;
|
||||
}
|
||||
|
||||
export interface OAuthTestServer {
|
||||
url: string;
|
||||
port: number;
|
||||
close: () => Promise<void>;
|
||||
issuedTokens: Set<string>;
|
||||
tokenTTL: number;
|
||||
tokenIssueTimes: Map<string, number>;
|
||||
issuedRefreshTokens: Map<string, string>;
|
||||
registeredClients: Map<string, { client_id: string; client_secret: string }>;
|
||||
getAuthCode: () => Promise<string>;
|
||||
}
|
||||
|
||||
async function readRequestBody(req: http.IncomingMessage): Promise<string> {
|
||||
const chunks: Uint8Array[] = [];
|
||||
for await (const chunk of req) {
|
||||
chunks.push(chunk as Uint8Array);
|
||||
}
|
||||
return Buffer.concat(chunks).toString();
|
||||
}
|
||||
|
||||
function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null {
|
||||
if (contentType?.includes('application/x-www-form-urlencoded')) {
|
||||
return new URLSearchParams(body);
|
||||
}
|
||||
if (contentType?.includes('application/json')) {
|
||||
const json = JSON.parse(body) as Record<string, string>;
|
||||
return new URLSearchParams(json);
|
||||
}
|
||||
return new URLSearchParams(body);
|
||||
}
|
||||
|
||||
export async function createOAuthMCPServer(
|
||||
options: OAuthTestServerOptions = {},
|
||||
): Promise<OAuthTestServer> {
|
||||
const {
|
||||
tokenTTLMs = 60000,
|
||||
issueRefreshTokens = false,
|
||||
refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000,
|
||||
rotateRefreshTokens = false,
|
||||
} = options;
|
||||
|
||||
const sessions = new Map<string, StreamableHTTPServerTransport>();
|
||||
const issuedTokens = new Set<string>();
|
||||
const tokenIssueTimes = new Map<string, number>();
|
||||
const issuedRefreshTokens = new Map<string, string>();
|
||||
const refreshTokenIssueTimes = new Map<string, number>();
|
||||
const authCodes = new Map<string, { codeChallenge?: string; codeChallengeMethod?: string }>();
|
||||
const registeredClients = new Map<string, { client_id: string; client_secret: string }>();
|
||||
|
||||
let port = 0;
|
||||
|
||||
const httpServer = http.createServer(async (req, res) => {
|
||||
const url = new URL(req.url ?? '/', `http://${req.headers.host}`);
|
||||
|
||||
if (url.pathname === '/.well-known/oauth-authorization-server' && req.method === 'GET') {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(
|
||||
JSON.stringify({
|
||||
issuer: `http://127.0.0.1:${port}`,
|
||||
authorization_endpoint: `http://127.0.0.1:${port}/authorize`,
|
||||
token_endpoint: `http://127.0.0.1:${port}/token`,
|
||||
registration_endpoint: `http://127.0.0.1:${port}/register`,
|
||||
response_types_supported: ['code'],
|
||||
grant_types_supported: issueRefreshTokens
|
||||
? ['authorization_code', 'refresh_token']
|
||||
: ['authorization_code'],
|
||||
token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post'],
|
||||
code_challenge_methods_supported: ['S256'],
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (url.pathname === '/register' && req.method === 'POST') {
|
||||
const body = await readRequestBody(req);
|
||||
const data = JSON.parse(body) as { redirect_uris?: string[] };
|
||||
const clientId = `client-${randomUUID().slice(0, 8)}`;
|
||||
const clientSecret = `secret-${randomUUID()}`;
|
||||
registeredClients.set(clientId, { client_id: clientId, client_secret: clientSecret });
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(
|
||||
JSON.stringify({
|
||||
client_id: clientId,
|
||||
client_secret: clientSecret,
|
||||
redirect_uris: data.redirect_uris ?? [],
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (url.pathname === '/authorize') {
|
||||
const code = randomUUID();
|
||||
const codeChallenge = url.searchParams.get('code_challenge') ?? undefined;
|
||||
const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined;
|
||||
authCodes.set(code, { codeChallenge, codeChallengeMethod });
|
||||
const redirectUri = url.searchParams.get('redirect_uri') ?? '';
|
||||
const state = url.searchParams.get('state') ?? '';
|
||||
res.writeHead(302, {
|
||||
Location: `${redirectUri}?code=${code}&state=${state}`,
|
||||
});
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
|
||||
if (url.pathname === '/token' && req.method === 'POST') {
|
||||
const body = await readRequestBody(req);
|
||||
const params = parseTokenRequest(body, req.headers['content-type']);
|
||||
if (!params) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_request' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const grantType = params.get('grant_type');
|
||||
|
||||
if (grantType === 'authorization_code') {
|
||||
const code = params.get('code');
|
||||
const codeData = code ? authCodes.get(code) : undefined;
|
||||
if (!code || !codeData) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
|
||||
if (codeData.codeChallenge) {
|
||||
const codeVerifier = params.get('code_verifier');
|
||||
if (!codeVerifier) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
if (codeData.codeChallengeMethod === 'S256') {
|
||||
const expected = createHash('sha256').update(codeVerifier).digest('base64url');
|
||||
if (expected !== codeData.codeChallenge) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
authCodes.delete(code);
|
||||
|
||||
const accessToken = randomUUID();
|
||||
issuedTokens.add(accessToken);
|
||||
tokenIssueTimes.set(accessToken, Date.now());
|
||||
|
||||
const tokenResponse: Record<string, string | number> = {
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
expires_in: Math.ceil(tokenTTLMs / 1000),
|
||||
};
|
||||
|
||||
if (issueRefreshTokens) {
|
||||
const refreshToken = randomUUID();
|
||||
issuedRefreshTokens.set(refreshToken, accessToken);
|
||||
refreshTokenIssueTimes.set(refreshToken, Date.now());
|
||||
tokenResponse.refresh_token = refreshToken;
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(tokenResponse));
|
||||
return;
|
||||
}
|
||||
|
||||
if (grantType === 'refresh_token' && issueRefreshTokens) {
|
||||
const refreshToken = params.get('refresh_token');
|
||||
if (!refreshToken || !issuedRefreshTokens.has(refreshToken)) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const issueTime = refreshTokenIssueTimes.get(refreshToken) ?? 0;
|
||||
if (Date.now() - issueTime > refreshTokenTTLMs) {
|
||||
issuedRefreshTokens.delete(refreshToken);
|
||||
refreshTokenIssueTimes.delete(refreshToken);
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const newAccessToken = randomUUID();
|
||||
issuedTokens.add(newAccessToken);
|
||||
tokenIssueTimes.set(newAccessToken, Date.now());
|
||||
|
||||
const tokenResponse: Record<string, string | number> = {
|
||||
access_token: newAccessToken,
|
||||
token_type: 'Bearer',
|
||||
expires_in: Math.ceil(tokenTTLMs / 1000),
|
||||
};
|
||||
|
||||
if (rotateRefreshTokens) {
|
||||
issuedRefreshTokens.delete(refreshToken);
|
||||
refreshTokenIssueTimes.delete(refreshToken);
|
||||
const newRefreshToken = randomUUID();
|
||||
issuedRefreshTokens.set(newRefreshToken, newAccessToken);
|
||||
refreshTokenIssueTimes.set(newRefreshToken, Date.now());
|
||||
tokenResponse.refresh_token = newRefreshToken;
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(tokenResponse));
|
||||
return;
|
||||
}
|
||||
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'unsupported_grant_type' }));
|
||||
return;
|
||||
}
|
||||
|
||||
// All other paths require Bearer token auth
|
||||
const authHeader = req.headers.authorization;
|
||||
if (!authHeader || !authHeader.startsWith('Bearer ')) {
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_token' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const token = authHeader.slice(7);
|
||||
if (!issuedTokens.has(token)) {
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_token' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const issueTime = tokenIssueTimes.get(token) ?? 0;
|
||||
if (Date.now() - issueTime > tokenTTLMs) {
|
||||
issuedTokens.delete(token);
|
||||
tokenIssueTimes.delete(token);
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_token' }));
|
||||
return;
|
||||
}
|
||||
|
||||
// Authenticated MCP request — route to transport
|
||||
const sid = req.headers['mcp-session-id'] as string | undefined;
|
||||
let transport = sid ? sessions.get(sid) : undefined;
|
||||
|
||||
if (!transport) {
|
||||
transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID(),
|
||||
});
|
||||
const mcp = new McpServer({ name: 'oauth-test-server', version: '0.0.1' });
|
||||
mcp.tool('echo', { message: z.string() }, async (args) => ({
|
||||
content: [{ type: 'text' as const, text: `echo: ${args.message}` }],
|
||||
}));
|
||||
await mcp.connect(transport);
|
||||
}
|
||||
|
||||
await transport.handleRequest(req, res);
|
||||
|
||||
if (transport.sessionId && !sessions.has(transport.sessionId)) {
|
||||
sessions.set(transport.sessionId, transport);
|
||||
transport.onclose = () => sessions.delete(transport!.sessionId!);
|
||||
}
|
||||
});
|
||||
|
||||
const destroySockets = trackSockets(httpServer);
|
||||
port = await getFreePort();
|
||||
await new Promise<void>((resolve) => httpServer.listen(port, '127.0.0.1', resolve));
|
||||
|
||||
return {
|
||||
url: `http://127.0.0.1:${port}/`,
|
||||
port,
|
||||
issuedTokens,
|
||||
tokenTTL: tokenTTLMs,
|
||||
tokenIssueTimes,
|
||||
issuedRefreshTokens,
|
||||
registeredClients,
|
||||
getAuthCode: async () => {
|
||||
const authRes = await fetch(
|
||||
`http://127.0.0.1:${port}/authorize?redirect_uri=http://localhost&state=test`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
return new URL(location).searchParams.get('code') ?? '';
|
||||
},
|
||||
close: async () => {
|
||||
const closing = [...sessions.values()].map((t) => t.close().catch(() => undefined));
|
||||
sessions.clear();
|
||||
await Promise.all(closing);
|
||||
await destroySockets();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export interface InMemoryToken {
|
||||
userId: string;
|
||||
type: string;
|
||||
identifier: string;
|
||||
token: string;
|
||||
expiresAt: Date;
|
||||
createdAt: Date;
|
||||
metadata?: Map<string, unknown> | Record<string, unknown>;
|
||||
}
|
||||
|
||||
export class InMemoryTokenStore {
|
||||
private tokens: Map<string, InMemoryToken> = new Map();
|
||||
|
||||
private key(filter: { userId?: string; type?: string; identifier?: string }): string {
|
||||
return `${filter.userId}:${filter.type}:${filter.identifier}`;
|
||||
}
|
||||
|
||||
findToken = async (filter: {
|
||||
userId?: string;
|
||||
type?: string;
|
||||
identifier?: string;
|
||||
}): Promise<InMemoryToken | null> => {
|
||||
for (const token of this.tokens.values()) {
|
||||
const matchUserId = !filter.userId || token.userId === filter.userId;
|
||||
const matchType = !filter.type || token.type === filter.type;
|
||||
const matchIdentifier = !filter.identifier || token.identifier === filter.identifier;
|
||||
if (matchUserId && matchType && matchIdentifier) {
|
||||
return token;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
createToken = async (data: {
|
||||
userId: string;
|
||||
type: string;
|
||||
identifier: string;
|
||||
token: string;
|
||||
expiresIn?: number;
|
||||
metadata?: Record<string, unknown>;
|
||||
}): Promise<InMemoryToken> => {
|
||||
const expiresIn = data.expiresIn ?? 365 * 24 * 60 * 60;
|
||||
const token: InMemoryToken = {
|
||||
userId: data.userId,
|
||||
type: data.type,
|
||||
identifier: data.identifier,
|
||||
token: data.token,
|
||||
expiresAt: new Date(Date.now() + expiresIn * 1000),
|
||||
createdAt: new Date(),
|
||||
metadata: data.metadata,
|
||||
};
|
||||
this.tokens.set(this.key(data), token);
|
||||
return token;
|
||||
};
|
||||
|
||||
updateToken = async (
|
||||
filter: { userId?: string; type?: string; identifier?: string },
|
||||
data: {
|
||||
userId?: string;
|
||||
type?: string;
|
||||
identifier?: string;
|
||||
token?: string;
|
||||
expiresIn?: number;
|
||||
metadata?: Record<string, unknown>;
|
||||
},
|
||||
): Promise<InMemoryToken> => {
|
||||
const existing = await this.findToken(filter);
|
||||
if (!existing) {
|
||||
throw new Error(`Token not found for filter: ${JSON.stringify(filter)}`);
|
||||
}
|
||||
const existingKey = this.key(existing);
|
||||
const expiresIn =
|
||||
data.expiresIn ?? Math.floor((existing.expiresAt.getTime() - Date.now()) / 1000);
|
||||
const updated: InMemoryToken = {
|
||||
...existing,
|
||||
token: data.token ?? existing.token,
|
||||
expiresAt: data.expiresIn ? new Date(Date.now() + expiresIn * 1000) : existing.expiresAt,
|
||||
metadata: data.metadata ?? existing.metadata,
|
||||
};
|
||||
this.tokens.set(existingKey, updated);
|
||||
return updated;
|
||||
};
|
||||
|
||||
deleteToken = async (filter: {
|
||||
userId: string;
|
||||
type: string;
|
||||
identifier: string;
|
||||
}): Promise<void> => {
|
||||
this.tokens.delete(this.key(filter));
|
||||
};
|
||||
|
||||
getAll(): InMemoryToken[] {
|
||||
return [...this.tokens.values()];
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.tokens.clear();
|
||||
}
|
||||
}
|
||||
|
|
@ -12,8 +12,12 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
|||
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import type { Socket } from 'net';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker';
|
||||
import { createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { mcpConfig } from '~/mcp/mcpConfig';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -143,16 +147,17 @@ afterEach(() => {
|
|||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #2 — Circuit breaker trips after rapid connect/disconnect */
|
||||
/* cycles (5 cycles within 60s -> 30s cooldown) */
|
||||
/* cycles (CB_MAX_CYCLES within window -> cooldown) */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => {
|
||||
it('blocks connection after 5 rapid cycles via static circuit breaker', async () => {
|
||||
it('blocks connection after CB_MAX_CYCLES rapid cycles via static circuit breaker', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('cycling-server', srv.url);
|
||||
|
||||
let completedCycles = 0;
|
||||
let breakerMessage = '';
|
||||
for (let cycle = 0; cycle < 10; cycle++) {
|
||||
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
|
||||
for (let cycle = 0; cycle < maxAttempts; cycle++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
await teardownConnection(conn);
|
||||
|
|
@ -166,7 +171,7 @@ describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => {
|
|||
}
|
||||
|
||||
expect(breakerMessage).toContain('Circuit breaker is open');
|
||||
expect(completedCycles).toBeLessThanOrEqual(5);
|
||||
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
|
||||
|
||||
await srv.close();
|
||||
});
|
||||
|
|
@ -266,12 +271,13 @@ describe('Fix #4: Circuit breaker state persists across instance replacement', (
|
|||
/* recordFailedRound() in the catch path */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #5: Dead server triggers circuit breaker', () => {
|
||||
it('3 failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => {
|
||||
it('failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => {
|
||||
const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000);
|
||||
const spy = jest.spyOn(conn.client, 'connect');
|
||||
|
||||
const totalAttempts = mcpConfig.CB_MAX_FAILED_ROUNDS + 2;
|
||||
const errors: string[] = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
for (let i = 0; i < totalAttempts; i++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
} catch (e) {
|
||||
|
|
@ -279,8 +285,8 @@ describe('Fix #5: Dead server triggers circuit breaker', () => {
|
|||
}
|
||||
}
|
||||
|
||||
expect(spy.mock.calls.length).toBe(3);
|
||||
expect(errors).toHaveLength(5);
|
||||
expect(spy.mock.calls.length).toBe(mcpConfig.CB_MAX_FAILED_ROUNDS);
|
||||
expect(errors).toHaveLength(totalAttempts);
|
||||
expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2);
|
||||
|
||||
await conn.disconnect();
|
||||
|
|
@ -295,7 +301,7 @@ describe('Fix #5: Dead server triggers circuit breaker', () => {
|
|||
userId: 'user-A',
|
||||
});
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
|
||||
try {
|
||||
await userA.connect();
|
||||
} catch {
|
||||
|
|
@ -332,7 +338,7 @@ describe('Fix #5: Dead server triggers circuit breaker', () => {
|
|||
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
|
||||
userId: 'user-A',
|
||||
});
|
||||
for (let i = 0; i < 3; i++) {
|
||||
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
|
||||
try {
|
||||
await userA.connect();
|
||||
} catch {
|
||||
|
|
@ -448,13 +454,14 @@ describe('Fix #6: OAuth failure uses cooldown-based retry', () => {
|
|||
/* Integration: Circuit breaker caps rapid cycling with real transport */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Cascade: Circuit breaker caps rapid cycling', () => {
|
||||
it('breaker trips before 10 cycles complete against a live server', async () => {
|
||||
it('breaker trips before double CB_MAX_CYCLES complete against a live server', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('cascade', srv.url);
|
||||
const spy = jest.spyOn(conn.client, 'connect');
|
||||
|
||||
let completedCycles = 0;
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
|
||||
for (let i = 0; i < maxAttempts; i++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
await teardownConnection(conn);
|
||||
|
|
@ -469,8 +476,8 @@ describe('Cascade: Circuit breaker caps rapid cycling', () => {
|
|||
}
|
||||
}
|
||||
|
||||
expect(completedCycles).toBeLessThanOrEqual(5);
|
||||
expect(spy.mock.calls.length).toBeLessThanOrEqual(5);
|
||||
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
|
||||
expect(spy.mock.calls.length).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
|
||||
|
||||
await srv.close();
|
||||
});
|
||||
|
|
@ -501,6 +508,146 @@ describe('Cascade: Circuit breaker caps rapid cycling', () => {
|
|||
}, 30_000);
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* OAuth: cycle recovery after successful OAuth reconnect */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('OAuth: cycle budget recovery after successful OAuth', () => {
|
||||
let oauthServer: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
oauthServer = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await oauthServer.close();
|
||||
});
|
||||
|
||||
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
|
||||
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
const tokenRes = await fetch(`${serverUrl}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const data = (await tokenRes.json()) as { access_token: string };
|
||||
return data.access_token;
|
||||
}
|
||||
|
||||
it('should decrement cycle count after successful OAuth recovery', async () => {
|
||||
const serverName = 'oauth-cycle-test';
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
|
||||
const conn = new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
// When oauthRequired fires, get a token and emit oauthHandled
|
||||
// This triggers the oauthRecovery path inside connectClient
|
||||
conn.on('oauthRequired', async () => {
|
||||
const accessToken = await exchangeCodeForToken(oauthServer.url);
|
||||
conn.setOAuthTokens({
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
conn.emit('oauthHandled');
|
||||
});
|
||||
|
||||
// connect() → 401 → oauthRequired → oauthHandled → connectClient returns
|
||||
// connect() sees not connected → throws "Connection not established"
|
||||
await expect(conn.connect()).rejects.toThrow('Connection not established');
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
|
||||
const cyclesBeforeRetry = cb.cycleCount;
|
||||
|
||||
// Retry — should succeed and decrement cycle count via oauthRecovery
|
||||
await conn.connect();
|
||||
expect(await conn.isConnected()).toBe(true);
|
||||
|
||||
const cyclesAfterSuccess = cb.cycleCount;
|
||||
// The retry adds +1 cycle (disconnect(false)) then -1 (oauthRecovery decrement)
|
||||
// So cyclesAfterSuccess should equal cyclesBeforeRetry, not cyclesBeforeRetry + 1
|
||||
expect(cyclesAfterSuccess).toBe(cyclesBeforeRetry);
|
||||
|
||||
await teardownConnection(conn);
|
||||
});
|
||||
|
||||
it('should allow more OAuth reconnects than non-OAuth before breaker trips', async () => {
|
||||
const serverName = 'oauth-budget';
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
|
||||
// Each OAuth flow: connect (+1) → 401 → oauthHandled → retry connect (+1) → success (-1) = net 1
|
||||
// Without the decrement it would be net 2 per flow, tripping the breaker after ~2 users
|
||||
let successfulFlows = 0;
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const conn = new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
|
||||
userId: `user-${i}`,
|
||||
});
|
||||
|
||||
conn.on('oauthRequired', async () => {
|
||||
const accessToken = await exchangeCodeForToken(oauthServer.url);
|
||||
conn.setOAuthTokens({
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
conn.emit('oauthHandled');
|
||||
});
|
||||
|
||||
try {
|
||||
// First connect: 401 → oauthHandled → returns without connection
|
||||
await conn.connect().catch(() => {});
|
||||
// Retry: succeeds with token, decrements cycle
|
||||
await conn.connect();
|
||||
successfulFlows++;
|
||||
await teardownConnection(conn);
|
||||
} catch (e) {
|
||||
conn.removeAllListeners();
|
||||
if ((e as Error).message.includes('Circuit breaker is open')) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// With the oauthRecovery decrement, each flow is net ~1 cycle instead of ~2,
|
||||
// so we should get more successful flows before the breaker trips
|
||||
expect(successfulFlows).toBeGreaterThanOrEqual(3);
|
||||
});
|
||||
|
||||
it('should not decrement cycle count when OAuth fails', async () => {
|
||||
const serverName = 'oauth-failed-no-decrement';
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
|
||||
const conn = new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
conn.on('oauthRequired', () => {
|
||||
conn.emit('oauthFailed', new Error('user denied'));
|
||||
});
|
||||
|
||||
await expect(conn.connect()).rejects.toThrow();
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
|
||||
const cyclesAfterFailure = cb.cycleCount;
|
||||
|
||||
// connect() recorded +1 cycle, oauthFailed should NOT decrement
|
||||
expect(cyclesAfterFailure).toBeGreaterThanOrEqual(1);
|
||||
|
||||
conn.removeAllListeners();
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Sanity: Real transport works end-to-end */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
|
|
|||
|
|
@ -82,14 +82,6 @@ interface CircuitBreakerState {
|
|||
failedBackoffUntil: number;
|
||||
}
|
||||
|
||||
const CB_MAX_CYCLES = 5;
|
||||
const CB_CYCLE_WINDOW_MS = 60_000;
|
||||
const CB_CYCLE_COOLDOWN_MS = 30_000;
|
||||
|
||||
const CB_MAX_FAILED_ROUNDS = 3;
|
||||
const CB_FAILED_WINDOW_MS = 120_000;
|
||||
const CB_BASE_BACKOFF_MS = 30_000;
|
||||
const CB_MAX_BACKOFF_MS = 300_000;
|
||||
/** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */
|
||||
const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES;
|
||||
|
||||
|
|
@ -281,6 +273,7 @@ export class MCPConnection extends EventEmitter {
|
|||
private oauthTokens?: MCPOAuthTokens | null;
|
||||
private requestHeaders?: Record<string, string> | null;
|
||||
private oauthRequired = false;
|
||||
private oauthRecovery = false;
|
||||
private readonly useSSRFProtection: boolean;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
|
|
@ -325,17 +318,17 @@ export class MCPConnection extends EventEmitter {
|
|||
private recordCycle(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
|
||||
if (now - cb.cycleWindowStart > mcpConfig.CB_CYCLE_WINDOW_MS) {
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
cb.cycleCount++;
|
||||
if (cb.cycleCount >= CB_MAX_CYCLES) {
|
||||
cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS;
|
||||
if (cb.cycleCount >= mcpConfig.CB_MAX_CYCLES) {
|
||||
cb.cooldownUntil = now + mcpConfig.CB_CYCLE_COOLDOWN_MS;
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`,
|
||||
`${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${mcpConfig.CB_CYCLE_COOLDOWN_MS}ms`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -343,15 +336,16 @@ export class MCPConnection extends EventEmitter {
|
|||
private recordFailedRound(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
|
||||
if (now - cb.failedWindowStart > mcpConfig.CB_FAILED_WINDOW_MS) {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = now;
|
||||
}
|
||||
cb.failedRounds++;
|
||||
if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) {
|
||||
if (cb.failedRounds >= mcpConfig.CB_MAX_FAILED_ROUNDS) {
|
||||
const backoff = Math.min(
|
||||
CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS),
|
||||
CB_MAX_BACKOFF_MS,
|
||||
mcpConfig.CB_BASE_BACKOFF_MS *
|
||||
Math.pow(2, cb.failedRounds - mcpConfig.CB_MAX_FAILED_ROUNDS),
|
||||
mcpConfig.CB_MAX_BACKOFF_MS,
|
||||
);
|
||||
cb.failedBackoffUntil = now + backoff;
|
||||
logger.warn(
|
||||
|
|
@ -367,6 +361,13 @@ export class MCPConnection extends EventEmitter {
|
|||
cb.failedBackoffUntil = 0;
|
||||
}
|
||||
|
||||
public static decrementCycleCount(serverName: string): void {
|
||||
const cb = MCPConnection.circuitBreakers.get(serverName);
|
||||
if (cb && cb.cycleCount > 0) {
|
||||
cb.cycleCount--;
|
||||
}
|
||||
}
|
||||
|
||||
setRequestHeaders(headers: Record<string, string> | null): void {
|
||||
if (!headers) {
|
||||
return;
|
||||
|
|
@ -816,6 +817,13 @@ export class MCPConnection extends EventEmitter {
|
|||
this.emit('connectionChange', 'connected');
|
||||
this.reconnectAttempts = 0;
|
||||
this.resetFailedRounds();
|
||||
if (this.oauthRecovery) {
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
this.oauthRecovery = false;
|
||||
logger.debug(
|
||||
`${this.getLogPrefix()} OAuth recovery: decremented cycle count after successful reconnect`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
// Check if it's a rate limit error - stop immediately to avoid making it worse
|
||||
if (this.isRateLimitError(error)) {
|
||||
|
|
@ -899,9 +907,8 @@ export class MCPConnection extends EventEmitter {
|
|||
try {
|
||||
// Wait for OAuth to be handled
|
||||
await oauthHandledPromise;
|
||||
// Reset the oauthRequired flag
|
||||
this.oauthRequired = false;
|
||||
// Don't throw the error - just return so connection can be retried
|
||||
this.oauthRecovery = true;
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -12,4 +12,18 @@ export const mcpConfig = {
|
|||
USER_CONNECTION_IDLE_TIMEOUT: math(
|
||||
process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000,
|
||||
),
|
||||
/** Max connect/disconnect cycles before the circuit breaker trips. Default: 7 */
|
||||
CB_MAX_CYCLES: math(process.env.MCP_CB_MAX_CYCLES ?? 7),
|
||||
/** Sliding window (ms) for counting cycles. Default: 45s */
|
||||
CB_CYCLE_WINDOW_MS: math(process.env.MCP_CB_CYCLE_WINDOW_MS ?? 45_000),
|
||||
/** Cooldown (ms) after the cycle breaker trips. Default: 15s */
|
||||
CB_CYCLE_COOLDOWN_MS: math(process.env.MCP_CB_CYCLE_COOLDOWN_MS ?? 15_000),
|
||||
/** Max consecutive failed connection rounds before backoff. Default: 3 */
|
||||
CB_MAX_FAILED_ROUNDS: math(process.env.MCP_CB_MAX_FAILED_ROUNDS ?? 3),
|
||||
/** Sliding window (ms) for counting failed rounds. Default: 120s */
|
||||
CB_FAILED_WINDOW_MS: math(process.env.MCP_CB_FAILED_WINDOW_MS ?? 120_000),
|
||||
/** Base backoff (ms) after failed round threshold is reached. Default: 30s */
|
||||
CB_BASE_BACKOFF_MS: math(process.env.MCP_CB_BASE_BACKOFF_MS ?? 30_000),
|
||||
/** Max backoff cap (ms) for exponential backoff. Default: 300s */
|
||||
CB_MAX_BACKOFF_MS: math(process.env.MCP_CB_MAX_BACKOFF_MS ?? 300_000),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -426,8 +426,8 @@ export class MCPOAuthHandler {
|
|||
scope: config.scope,
|
||||
});
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
/** Add cryptographic state parameter to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', state);
|
||||
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
|
||||
|
||||
const flowMetadata: MCPOAuthFlowMetadata = {
|
||||
|
|
@ -505,8 +505,8 @@ export class MCPOAuthHandler {
|
|||
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
|
||||
);
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
/** Add cryptographic state parameter to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', state);
|
||||
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
|
||||
|
||||
if (resourceMetadata?.resource != null && resourceMetadata.resource) {
|
||||
|
|
@ -672,6 +672,44 @@ export class MCPOAuthHandler {
|
|||
return randomBytes(32).toString('base64url');
|
||||
}
|
||||
|
||||
private static readonly STATE_MAP_TYPE = 'mcp_oauth_state';
|
||||
|
||||
/**
|
||||
* Stores a mapping from the opaque OAuth state parameter to the flowId.
|
||||
* This allows the callback to resolve the flowId from an unguessable state
|
||||
* value, preventing attackers from forging callback requests.
|
||||
*/
|
||||
static async storeStateMapping(
|
||||
state: string,
|
||||
flowId: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
): Promise<void> {
|
||||
await flowManager.initFlow(state, this.STATE_MAP_TYPE, { flowId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves an opaque OAuth state parameter back to the original flowId.
|
||||
* Returns null if the state is not found (expired or never stored).
|
||||
*/
|
||||
static async resolveStateToFlowId(
|
||||
state: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
): Promise<string | null> {
|
||||
const mapping = await flowManager.getFlowState(state, this.STATE_MAP_TYPE);
|
||||
return (mapping?.metadata?.flowId as string) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes an orphaned state mapping when a flow is replaced.
|
||||
* Prevents old authorization URLs from resolving after a flow restart.
|
||||
*/
|
||||
static async deleteStateMapping(
|
||||
state: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
): Promise<void> {
|
||||
await flowManager.deleteFlow(state, this.STATE_MAP_TYPE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the default redirect URI for a server
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -4,6 +4,15 @@ import type { TokenMethods, IToken } from '@librechat/data-schemas';
|
|||
import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types';
|
||||
import { isSystemUserId } from '~/mcp/enum';
|
||||
|
||||
export class ReauthenticationRequiredError extends Error {
|
||||
constructor(serverName: string, reason: 'expired' | 'missing') {
|
||||
super(
|
||||
`Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`,
|
||||
);
|
||||
this.name = 'ReauthenticationRequiredError';
|
||||
}
|
||||
}
|
||||
|
||||
interface StoreTokensParams {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
|
|
@ -27,7 +36,12 @@ interface GetTokensParams {
|
|||
findToken: TokenMethods['findToken'];
|
||||
refreshTokens?: (
|
||||
refreshToken: string,
|
||||
metadata: { userId: string; serverName: string; identifier: string },
|
||||
metadata: {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
identifier: string;
|
||||
clientInfo?: OAuthClientInformation;
|
||||
},
|
||||
) => Promise<MCPOAuthTokens>;
|
||||
createToken?: TokenMethods['createToken'];
|
||||
updateToken?: TokenMethods['updateToken'];
|
||||
|
|
@ -273,10 +287,11 @@ export class MCPTokenStorage {
|
|||
});
|
||||
|
||||
if (!refreshTokenData) {
|
||||
const reason = isMissing ? 'missing' : 'expired';
|
||||
logger.info(
|
||||
`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`,
|
||||
`${logPrefix} Access token ${reason} and no refresh token available — re-authentication required`,
|
||||
);
|
||||
return null;
|
||||
throw new ReauthenticationRequiredError(serverName, reason);
|
||||
}
|
||||
|
||||
if (!refreshTokens) {
|
||||
|
|
@ -395,6 +410,9 @@ export class MCPTokenStorage {
|
|||
logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`);
|
||||
return tokens;
|
||||
} catch (error) {
|
||||
if (error instanceof ReauthenticationRequiredError) {
|
||||
throw error;
|
||||
}
|
||||
logger.error(`${logPrefix} Failed to retrieve tokens`, error);
|
||||
return null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata {
|
|||
clientInfo?: OAuthClientInformation;
|
||||
metadata?: OAuthMetadata;
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
authorizationUrl?: string;
|
||||
}
|
||||
|
||||
export interface MCPOAuthTokens extends OAuthTokens {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue