LibreChat/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts
Danny Avila fcb344da47
🛂 fix: MCP OAuth Race Conditions, CSRF Fallback, and Token Expiry Handling (#12171)
* fix: Resolve race conditions in MCP OAuth flow

- Added connection mutex to coalesce concurrent `getUserConnection` calls, preventing multiple simultaneous attempts.
- Enhanced flow state management to retry once when a flow state is missing, improving resilience against race conditions.
- Introduced `ReauthenticationRequiredError` for better error handling when access tokens are expired or missing.
- Updated tests to cover new race condition scenarios and ensure proper handling of OAuth flows.
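
As a rough illustration of the coalescing idea in the first bullet, the sketch below shares one in-flight promise per key. The names `coalesce` and `pending` are hypothetical and not taken from the LibreChat source, so the actual `getUserConnection` logic may differ.

```typescript
// Hypothetical sketch: concurrent callers with the same key share one in-flight promise.
const pending = new Map<string, Promise<unknown>>();

function coalesce<T>(key: string, factory: () => Promise<T>): Promise<T> {
  const existing = pending.get(key);
  if (existing) {
    // A call for this key is already running; join it instead of starting another.
    return existing as Promise<T>;
  }
  // Remove the entry once the work settles so later calls start fresh.
  const cleanup = (): void => {
    if (pending.get(key) === promise) {
      pending.delete(key);
    }
  };
  const promise: Promise<T> = factory().then(
    (value) => {
      cleanup();
      return value;
    },
    (error) => {
      cleanup();
      throw error;
    },
  );
  pending.set(key, promise);
  return promise;
}
```

Two calls made back to back with the same key receive the same promise object, so the factory runs only once for that burst.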

* fix: Stale PENDING flow detection and OAuth URL re-issuance

PENDING flows in handleOAuthRequired now check createdAt age — flows
older than 2 minutes are treated as stale and replaced instead of
joined. Fixes the case where a leftover PENDING flow from a previous
session blocks new OAuth initiation.
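
The age check described above might look roughly like the following sketch. The 2-minute threshold comes from this commit message and the constant name mirrors `PENDING_STALE_MS` mentioned later in it; `FlowStateLike`, `decidePendingAction`, and the exact boundary comparison are assumptions made for illustration.

```typescript
// Threshold stated in the commit message: PENDING flows older than 2 minutes are stale.
const PENDING_STALE_MS = 2 * 60 * 1000;

// Hypothetical minimal shape of a stored flow state.
interface FlowStateLike {
  status: string;
  createdAt: number; // epoch milliseconds
}

type PendingDecision = 'start' | 'join' | 'replace';

function decidePendingAction(
  state: FlowStateLike | null,
  now: number = Date.now(),
): PendingDecision {
  // No stored flow: start a new OAuth flow.
  if (!state) {
    return 'start';
  }
  // A fresh PENDING flow can be joined by a second caller.
  if (state.status === 'PENDING' && now - state.createdAt <= PENDING_STALE_MS) {
    return 'join';
  }
  // Assumption: anything else (stale PENDING, COMPLETED, ...) is replaced.
  return 'replace';
}
```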

authorizationUrl is now stored in MCPOAuthFlowMetadata so that when a
second caller joins an active PENDING flow (e.g., the SSE-emitting path
in ToolService), it can re-issue the URL to the user via oauthStart.

* fix: CSRF fallback via active PENDING flow in OAuth callback

When the OAuth callback arrives without CSRF or session cookies (common
in the chat/SSE flow where cookies can't be set on streaming responses),
fall back to validating that a PENDING flow exists for the flowId. This
is safe because the flow was created server-side after JWT authentication
and the authorization code is PKCE-protected.
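
A minimal sketch of that fallback decision, assuming the cookie checks and the flow lookup have already been done elsewhere. The `CallbackCheck` shape, `isCallbackAllowed`, and `MAX_PENDING_AGE_MS` are hypothetical names, and the staleness window is taken from the 2-minute figure earlier in this message.

```typescript
// Hypothetical inputs to the callback check; field names are illustrative only.
interface CallbackCheck {
  csrfCookieValid: boolean;
  sessionCookieValid: boolean;
  flowStatus: string | null; // status of the flow found for flowId, or null if none
  flowAgeMs: number; // age of that flow in milliseconds
}

// Assumed staleness window for PENDING flows.
const MAX_PENDING_AGE_MS = 2 * 60 * 1000;

function isCallbackAllowed(check: CallbackCheck): boolean {
  // Normal path: a valid CSRF or session cookie is present.
  if (check.csrfCookieValid || check.sessionCookieValid) {
    return true;
  }
  // Fallback path: no cookies, so require a fresh PENDING flow for this flowId.
  return check.flowStatus === 'PENDING' && check.flowAgeMs <= MAX_PENDING_AGE_MS;
}
```

Under these assumptions a callback with no cookies is rejected when no flow exists, when the flow is already COMPLETED, or when the PENDING flow is stale.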

* test: Extract shared OAuth test server helpers

Move MockKeyv, getFreePort, trackSockets, and createOAuthMCPServer into
a shared helpers/oauthTestServer module. Enhance the test server with
refresh token support, token rotation, metadata discovery, and dynamic
client registration endpoints. Add InMemoryTokenStore for token storage
tests.

Refactor MCPOAuthRaceCondition.test.ts to import from shared helpers.

* test: Add comprehensive MCP OAuth test modules

MCPOAuthTokenStorage — 21 tests for storeTokens/getTokens with
InMemoryTokenStore: encrypt/decrypt round-trips, expiry calculation,
refresh callback wiring, ReauthenticationRequiredError paths.

MCPOAuthFlow — 10 tests against real HTTP server: token refresh with
stored client info, refresh token rotation, metadata discovery, dynamic
client registration, full store/retrieve/expire/refresh lifecycle.

MCPOAuthConnectionEvents — 5 tests for MCPConnection OAuth event cycle
with real OAuth-gated MCP server: oauthRequired emission on 401,
oauthHandled reconnection, oauthFailed rejection, token expiry detection.

MCPOAuthTokenExpiry — 12 tests for the token expiry edge case: refresh
success/failure paths, ReauthenticationRequiredError, PENDING flow CSRF
fallback, authorizationUrl metadata storage, full re-auth cycle after
refresh failure, concurrent expired token coalescing, stale PENDING
flow detection.

* test: Enhance MCP OAuth connection tests with cooldown reset

Added a `beforeEach` hook to clear the cooldown for `MCPConnection` before each test, ensuring a clean state. Updated the race condition handling in the tests to properly clear the timeout, improving reliability in the event data retrieval process.

* refactor: PENDING flow management and state recovery in MCP OAuth

- Introduced a constant `PENDING_STALE_MS` to define the age threshold for PENDING flows, improving the handling of stale flows.
- Updated the logic in `MCPConnectionFactory` and `FlowStateManager` to check the age of PENDING flows before joining or reusing them.
- Modified the `completeFlow` method to return false when the flow state is deleted, ensuring graceful handling of race conditions.
- Enhanced tests to validate the new behavior and ensure robustness against state recovery issues.

* refactor: MCP OAuth flow management and testing

- Updated the `completeFlow` method to log warnings when a tool flow state is not found during completion, improving error handling.
- Introduced a new `normalizeExpiresAt` function to standardize expiration timestamp handling across the application.
- Refactored token expiration checks in `MCPConnectionFactory` to utilize the new normalization function, ensuring consistent behavior.
- Added a comprehensive test suite for OAuth callback CSRF fallback logic, validating the handling of PENDING flows and their staleness.
- Enhanced existing tests to cover new expiration normalization logic and ensure robust flow state management.
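
The commit message does not show the body of `normalizeExpiresAt`, so the following is only one plausible reading: converting mixed inputs (epoch seconds, epoch milliseconds, or `Date`) into epoch milliseconds. The function names, the accepted input types, and the `1e12` cutoff are all assumptions for illustration.

```typescript
// Hypothetical sketch: the real `normalizeExpiresAt` may follow different rules.
// Assumption: numbers below 1e12 are epoch seconds, larger numbers are epoch milliseconds.
const SECONDS_CUTOFF = 1e12;

function normalizeExpiresAtSketch(expiresAt: number | Date | undefined): number | undefined {
  if (expiresAt === undefined) {
    return undefined;
  }
  if (expiresAt instanceof Date) {
    return expiresAt.getTime();
  }
  return expiresAt < SECONDS_CUTOFF ? expiresAt * 1000 : expiresAt;
}

// Expiry check built on the normalized value; a missing expiry is treated as not expired.
function isExpiredSketch(expiresAt: number | Date | undefined, now: number = Date.now()): boolean {
  const normalized = normalizeExpiresAtSketch(expiresAt);
  return normalized !== undefined && normalized <= now;
}
```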

* test: Add CSRF fallback tests for active PENDING flows in MCP OAuth

- Introduced new tests to validate CSRF fallback behavior when a fresh PENDING flow exists without cookies, ensuring successful OAuth callback handling.
- Added scenarios to reject requests when no PENDING flow exists, when only a COMPLETED flow is present, and when a PENDING flow is stale, enhancing the robustness of flow state management.
- Improved overall test coverage for OAuth callback logic, reinforcing the handling of CSRF validation failures.

* chore: imports order

* refactor: Update UserConnectionManager to conditionally manage pending connections

- Modified the logic in `UserConnectionManager` to only set pending connections if `forceNew` is false, preventing unnecessary overwrites.
- Adjusted the cleanup process to ensure pending connections are only deleted when not forced, enhancing connection management efficiency.

* refactor: MCP OAuth flow state management

- Introduced a new method `storeStateMapping` in `MCPOAuthHandler` to securely map the OAuth state parameter to the flow ID, improving callback resolution and security against forgery.
- Updated the OAuth initiation and callback handling in `mcp.js` to utilize the new state mapping functionality, ensuring robust flow management.
- Refactored `MCPConnectionFactory` to store state mappings during flow initialization, enhancing the integrity of the OAuth process.
- Adjusted comments to clarify the purpose of state parameters in authorization URLs, reinforcing code readability.
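
To make the state-to-flow mapping concrete, here is a small in-memory sketch. `StateMappingStore` and its methods are hypothetical stand-ins; the real `storeStateMapping` presumably writes to the flow store, and its signature is not shown in this commit message.

```typescript
// Hypothetical in-memory stand-in for the store behind `storeStateMapping`.
class StateMappingStore {
  private readonly stateToFlowId = new Map<string, string>();

  // Record which flow a given OAuth `state` value belongs to.
  store(state: string, flowId: string): void {
    this.stateToFlowId.set(state, flowId);
  }

  // Resolve the callback's `state` to a flow ID; unknown values yield null.
  resolve(state: string): string | null {
    return this.stateToFlowId.get(state) ?? null;
  }
}
```

Because the callback resolves the flow only through a `state` value the server itself issued, a forged or guessed `state` does not map to any flow in this sketch.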

* refactor: MCPConnection with OAuth recovery handling

- Added `oauthRecovery` flag to manage OAuth recovery state during connection attempts.
- Introduced `decrementCycleCount` method to reduce the circuit breaker's cycle count upon successful reconnection after OAuth recovery.
- Updated connection logic to reset the `oauthRecovery` flag after handling OAuth, improving state management and connection reliability.

* chore: Add debug logging for OAuth recovery cycle count decrement

- Introduced a debug log statement in the `MCPConnection` class to track the decrement of the cycle count after a successful reconnection during OAuth recovery.
- This enhancement improves observability and aids in troubleshooting connection issues related to OAuth recovery.

* test: Add OAuth recovery cycle management tests

- Introduced new tests for the OAuth recovery cycle in `MCPConnection`, validating the decrement of cycle counts after successful reconnections.
- Added scenarios to ensure that the cycle count is not decremented on OAuth failures, enhancing the robustness of connection management.
- Improved test coverage for OAuth reconnect scenarios, ensuring reliable behavior under various conditions.

* feat: Implement circuit breaker configuration in MCP

- Added circuit breaker settings to `.env.example` for max cycles, cycle window, and cooldown duration.
- Refactored `MCPConnection` to utilize the new configuration values from `mcpConfig`, enhancing circuit breaker management.
- Improved code maintainability by centralizing circuit breaker parameters in the configuration file.

* refactor: Update decrementCycleCount method for circuit breaker management

- Changed the visibility of the `decrementCycleCount` method in `MCPConnection` from private to public static, allowing it to be called with a server name parameter.
- Updated calls to `decrementCycleCount` in `MCPConnectionFactory` to use the new static method, improving clarity and consistency in circuit breaker management during connection failures and OAuth recovery.
- Enhanced the handling of circuit breaker state by ensuring the method checks for the existence of the circuit breaker before decrementing the cycle count.
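
The static, server-keyed shape described above could be sketched as follows. `BreakerRegistry`, `recordCycle`, and `getCycleCount` are hypothetical helpers added for illustration, and flooring the count at zero is an assumption; only the existence check before decrementing is taken from the commit message.

```typescript
// Hypothetical sketch of a per-server circuit breaker registry.
interface BreakerState {
  cycleCount: number;
}

class BreakerRegistry {
  private static readonly breakers = new Map<string, BreakerState>();

  // Count one reconnect cycle for a server, creating its breaker on first use.
  static recordCycle(serverName: string): number {
    const state = BreakerRegistry.breakers.get(serverName) ?? { cycleCount: 0 };
    state.cycleCount += 1;
    BreakerRegistry.breakers.set(serverName, state);
    return state.cycleCount;
  }

  // Mirrors the described behavior: do nothing when no breaker exists for the server.
  static decrementCycleCount(serverName: string): void {
    const state = BreakerRegistry.breakers.get(serverName);
    if (!state) {
      return;
    }
    // Assumption: the count never drops below zero.
    state.cycleCount = Math.max(0, state.cycleCount - 1);
  }

  static getCycleCount(serverName: string): number {
    return BreakerRegistry.breakers.get(serverName)?.cycleCount ?? 0;
  }
}
```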

* refactor: cycle count decrement on tool listing failure

- Added a call to `MCPConnection.decrementCycleCount` in the `MCPConnectionFactory` to handle cases where unauthenticated tool listing fails, improving circuit breaker management.
- This change ensures that the cycle count is decremented appropriately, maintaining the integrity of the connection recovery process.

* refactor: Update circuit breaker configuration and logic

- Enhanced circuit breaker settings in `.env.example` to include new parameters for failed rounds and backoff strategies.
- Refactored `MCPConnection` to utilize the updated configuration values from `mcpConfig`, improving circuit breaker management.
- Updated tests to reflect changes in circuit breaker logic, ensuring accurate validation of connection behavior under rapid reconnect scenarios.

* feat: Implement state mapping deletion in MCP flow management

- Added a new method `deleteStateMapping` in `MCPOAuthHandler` to remove orphaned state mappings when a flow is replaced, preventing old authorization URLs from resolving after a flow restart.
- Updated `MCPConnectionFactory` to call `deleteStateMapping` during flow cleanup, ensuring proper management of OAuth states.
- Enhanced test coverage for state mapping functionality to validate the new deletion logic.
2026-03-10 21:15:01 -04:00

853 lines
29 KiB
TypeScript

import { logger } from '@librechat/data-schemas';
import type { TokenMethods, IUser } from '@librechat/data-schemas';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import type * as t from '~/mcp/types';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPConnection } from '~/mcp/connection';
import { MCPOAuthHandler } from '~/mcp/oauth';
import { processMCPEnv } from '~/utils';
jest.mock('~/mcp/connection');
jest.mock('~/mcp/oauth');
jest.mock('~/utils');
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
const mockLogger = logger as jest.Mocked<typeof logger>;
const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction<typeof processMCPEnv>;
const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection>;
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
describe('MCPConnectionFactory', () => {
let mockUser: IUser | undefined;
let mockServerConfig: t.MCPOptions;
let mockFlowManager: jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
let mockConnectionInstance: jest.Mocked<MCPConnection>;
beforeEach(() => {
jest.clearAllMocks();
mockUser = {
id: 'user123',
email: 'test@example.com',
} as IUser;
mockServerConfig = {
command: 'node',
args: ['server.js'],
initTimeout: 5000,
} as t.MCPOptions;
mockFlowManager = {
initFlow: jest.fn().mockResolvedValue(undefined),
createFlow: jest.fn(),
createFlowWithHandler: jest.fn(),
getFlowState: jest.fn(),
deleteFlow: jest.fn().mockResolvedValue(true),
} as unknown as jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
mockConnectionInstance = {
connect: jest.fn(),
isConnected: jest.fn(),
setOAuthTokens: jest.fn(),
// mockReturnThis avoids reading mockConnectionInstance before it is assigned
on: jest.fn().mockReturnThis(),
once: jest.fn().mockReturnThis(),
off: jest.fn().mockReturnThis(),
removeListener: jest.fn().mockReturnThis(),
emit: jest.fn(),
} as unknown as jest.Mocked<MCPConnection>;
mockMCPConnection.mockImplementation(() => mockConnectionInstance);
mockProcessMCPEnv.mockReturnValue(mockServerConfig);
});
describe('static create method', () => {
it('should create a basic connection without OAuth', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockProcessMCPEnv).toHaveBeenCalledWith({
options: mockServerConfig,
dbSourced: undefined,
});
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: undefined,
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
it('should create a connection with OAuth', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockTokens: MCPOAuthTokens = {
access_token: 'access123',
refresh_token: 'refresh123',
token_type: 'Bearer',
obtained_at: Date.now(),
};
mockFlowManager.createFlowWithHandler.mockResolvedValue(mockTokens);
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockProcessMCPEnv).toHaveBeenCalledWith({
options: mockServerConfig,
user: mockUser,
dbSourced: undefined,
});
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: mockTokens,
useSSRFProtection: false,
});
});
});
describe('OAuth token handling', () => {
it('should return null when no findToken method is provided', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions: t.OAuthConnectionOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: undefined as unknown as TokenMethods['findToken'],
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(mockFlowManager.createFlowWithHandler).not.toHaveBeenCalled();
});
it('should handle token retrieval errors gracefully', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockFlowManager.createFlowWithHandler.mockRejectedValue(new Error('Token fetch failed'));
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockMCPConnection).toHaveBeenCalledWith({
serverName: 'test-server',
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('No existing tokens found or error loading tokens'),
expect.any(Error),
);
});
});
describe('OAuth event handling', () => {
it('should handle oauthRequired event for returnOnOAuth scenario', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
returnOnOAuth: true,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'flow123',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'random-state',
clientInfo: { client_id: 'client123' },
},
};
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
// createFlow runs as a background monitor — simulate it staying pending
mockFlowManager.createFlow.mockReturnValue(new Promise(() => {}));
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail due to connection not established
}
expect(oauthRequiredHandler!).toBeDefined();
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockMCPOAuthHandler.initiateOAuthFlow).toHaveBeenCalledWith(
'test-server',
'https://api.example.com',
'user123',
{},
undefined,
);
// initFlow must be awaited BEFORE the redirect to guarantee state is stored
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
'flow123',
'mcp_oauth',
expect.objectContaining(mockFlowData.flowMetadata),
);
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
.invocationCallOrder[0];
expect(initCallOrder).toBeLessThan(oauthStartCallOrder);
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({
message: 'OAuth flow initiated - return early',
}),
);
});
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
user: mockUser,
};
const oauthOptions: t.OAuthConnectionOptions = {
user: mockUser,
useOAuth: true,
returnOnOAuth: true,
oauthStart: jest.fn(),
flowManager: mockFlowManager,
};
mockFlowManager.getFlowState.mockResolvedValue({
status: 'PENDING',
type: 'mcp_oauth',
metadata: { codeVerifier: 'existing-verifier' },
createdAt: Date.now(),
});
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
);
});
it('should emit oauthFailed when initFlow fails to store flow state (returnOnOAuth)', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
returnOnOAuth: true,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'flow123',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'state123',
},
};
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockFlowManager.initFlow.mockRejectedValue(new Error('Store write failed'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
// initFlow failed, so oauthStart should NOT have been called (redirect never happens)
expect(oauthOptions.oauthStart).not.toHaveBeenCalled();
// createFlow should NOT have been called since initFlow failed first
expect(mockFlowManager.createFlow).not.toHaveBeenCalled();
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({ message: 'OAuth initiation failed' }),
);
expect(mockLogger.error).toHaveBeenCalled();
});
it('should log warnings when background createFlow monitor rejects (returnOnOAuth)', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
returnOnOAuth: true,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'flow123',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'state123',
},
};
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
// Simulate the background monitor timing out
mockFlowManager.createFlow.mockRejectedValue(new Error('mcp_oauth flow timed out'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
// Allow the .catch handler on createFlow to execute
await Promise.resolve();
// initFlow should have succeeded and redirect should have happened
expect(mockFlowManager.initFlow).toHaveBeenCalled();
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
// The background monitor error should be logged, not silently swallowed
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('OAuth flow monitor ended'),
expect.any(Error),
);
});
it('should call initFlow before createFlow in blocking OAuth path (non-returnOnOAuth)', async () => {
const sseConfig = {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions;
const basicOptions = {
serverName: 'test-server',
serverConfig: sseConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
oauthStart: jest.fn(),
oauthEnd: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'flow123',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'random-state',
clientInfo: { client_id: 'client123' },
metadata: {
token_endpoint: 'https://auth.example.com/token',
authorization_endpoint: 'https://auth.example.com/authorize',
},
},
};
const mockTokens: MCPOAuthTokens = {
access_token: 'access123',
refresh_token: 'refresh123',
token_type: 'Bearer',
obtained_at: Date.now(),
};
// processMCPEnv must return config with url so handleOAuthRequired proceeds
mockProcessMCPEnv.mockReturnValue(sseConfig);
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockMCPOAuthHandler.generateFlowId.mockReturnValue('flow123');
mockFlowManager.getFlowState.mockResolvedValue(null);
mockFlowManager.createFlow.mockResolvedValue(mockTokens);
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
// initFlow must be called BEFORE oauthStart and createFlow
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
'flow123',
'mcp_oauth',
expect.objectContaining(mockFlowData.flowMetadata),
);
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
.invocationCallOrder[0];
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
expect(initCallOrder).toBeLessThan(oauthStartCallOrder);
expect(initCallOrder).toBeLessThan(createCallOrder);
// createFlow should receive {} since initFlow already persisted metadata
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
'flow123',
'mcp_oauth',
{},
undefined,
);
});
it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
user: mockUser,
};
const oauthOptions: t.OAuthConnectionOptions = {
user: mockUser,
useOAuth: true,
returnOnOAuth: true,
oauthStart: jest.fn(),
flowManager: mockFlowManager,
};
const mockFlowData = {
authorizationUrl: 'https://auth.example.com',
flowId: 'user123:test-server',
flowMetadata: {
serverName: 'test-server',
userId: 'user123',
serverUrl: 'https://api.example.com',
state: 'test-state',
codeVerifier: 'new-code-verifier-xyz',
clientInfo: { client_id: 'test-client' },
metadata: {
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
issuer: 'https://api.example.com',
},
},
};
mockFlowManager.getFlowState.mockResolvedValue({
status: 'COMPLETED',
type: 'mcp_oauth',
metadata: { codeVerifier: 'old-verifier' },
createdAt: Date.now() - 60000,
});
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockFlowManager.deleteFlow.mockResolvedValue(true);
// createFlow runs as a background monitor — simulate it staying pending
mockFlowManager.createFlow.mockReturnValue(new Promise(() => {}));
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
// initFlow must be called after deleteFlow and before createFlow
const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
expect(deleteCallOrder).toBeLessThan(initCallOrder);
expect(initCallOrder).toBeLessThan(createCallOrder);
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
'user123:test-server',
'mcp_oauth',
expect.objectContaining({
codeVerifier: 'new-code-verifier-xyz',
}),
);
// createFlow finds the existing PENDING state written by initFlow,
// so metadata arg is unused (passed as {})
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
'user123:test-server',
'mcp_oauth',
{},
undefined,
);
});
});
describe('connection retry logic', () => {
it('should establish connection successfully', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig, // Use default 5000ms timeout
};
mockConnectionInstance.connect.mockResolvedValue(undefined);
mockConnectionInstance.isConnected.mockResolvedValue(true);
const connection = await MCPConnectionFactory.create(basicOptions);
expect(connection).toBe(mockConnectionInstance);
expect(mockConnectionInstance.connect).toHaveBeenCalledTimes(1);
});
it('should handle OAuth errors during connection attempts', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
oauthStart: jest.fn(),
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const oauthError = new Error('Non-200 status code (401)');
(oauthError as unknown as Record<string, unknown>).isOAuthError = true;
mockConnectionInstance.connect.mockRejectedValue(oauthError);
mockConnectionInstance.isConnected.mockResolvedValue(false);
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow(
'Non-200 status code (401)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('OAuth required, stopping connection attempts'),
);
});
});
describe('isOAuthError method', () => {
it('should identify OAuth errors by message content', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
const oauthOptions = {
useOAuth: true as const,
user: mockUser,
flowManager: mockFlowManager,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
const error401 = new Error('401 Unauthorized');
mockConnectionInstance.connect.mockRejectedValue(error401);
mockConnectionInstance.isConnected.mockResolvedValue(false);
await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow('401');
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('OAuth required, stopping connection attempts'),
);
});
});
describe('discoverTools static method', () => {
const mockTools = [
{ name: 'tool1', description: 'First tool', inputSchema: { type: 'object' } },
{ name: 'tool2', description: 'Second tool', inputSchema: { type: 'object' } },
];
it('should discover tools from a successfully connected server', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.connect.mockResolvedValue(undefined);
mockConnectionInstance.isConnected.mockResolvedValue(true);
mockConnectionInstance.fetchTools = jest.fn().mockResolvedValue(mockTools);
const result = await MCPConnectionFactory.discoverTools(basicOptions);
expect(result.tools).toEqual(mockTools);
expect(result.oauthRequired).toBe(false);
expect(result.oauthUrl).toBeNull();
expect(result.connection).toBe(mockConnectionInstance);
});
it('should detect OAuth required without generating URL in discovery mode', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: {
...mockServerConfig,
url: 'https://api.example.com',
type: 'sse' as const,
} as t.SSEOptions,
};
const mockOAuthStart = jest.fn().mockResolvedValue(undefined);
const oauthOptions = {
useOAuth: true as const,
user: mockUser as unknown as IUser,
flowManager: mockFlowManager,
oauthStart: mockOAuthStart,
tokenMethods: {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteTokens: jest.fn(),
},
};
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
let oauthHandler: (() => Promise<void>) | undefined;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthHandler = handler as () => Promise<void>;
}
return mockConnectionInstance;
});
mockConnectionInstance.connect.mockImplementation(async () => {
if (oauthHandler) {
await oauthHandler();
}
throw new Error('OAuth required');
});
const result = await MCPConnectionFactory.discoverTools(basicOptions, oauthOptions);
expect(result.connection).toBeNull();
expect(result.tools).toBeNull();
expect(result.oauthRequired).toBe(true);
expect(result.oauthUrl).toBeNull();
expect(mockOAuthStart).not.toHaveBeenCalled();
});
it('should return null tools when discovery fails completely', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);
const result = await MCPConnectionFactory.discoverTools(basicOptions);
expect(result.tools).toBeNull();
expect(result.connection).toBeNull();
expect(result.oauthRequired).toBe(false);
});
it('should handle disconnect errors gracefully during cleanup', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed'));
mockConnectionInstance.isConnected.mockResolvedValue(false);
mockConnectionInstance.disconnect = jest
.fn()
.mockRejectedValue(new Error('Disconnect failed'));
const result = await MCPConnectionFactory.discoverTools(basicOptions);
expect(result.tools).toBeNull();
expect(mockLogger.debug).toHaveBeenCalled();
});
});
});