LibreChat/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts
Danny Avila 599f4a11f1
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run
🛡️ fix: Secure MCP/Actions OAuth Flows, Resolve Race Condition & Tool Cache Cleanup (#11756)
* 🔧 fix: Update OAuth error message for clarity

- Changed the default error message in the OAuth error route from 'Unknown error' to 'Unknown OAuth error' to provide clearer context during authentication failures.

* 🔒 feat: Enhance OAuth flow with CSRF protection and session management

- Implemented CSRF protection for OAuth flows by introducing `generateOAuthCsrfToken`, `setOAuthCsrfCookie`, and `validateOAuthCsrf` functions.
- Added session management for OAuth with `setOAuthSession` and `validateOAuthSession` middleware.
- Updated routes to bind CSRF tokens for MCP and action OAuth flows, ensuring secure authentication.
- Enhanced tests to validate CSRF handling and session management in OAuth processes.

* 🔧 refactor: Invalidate cached tools after user plugin disconnection

- Added a call to `invalidateCachedTools` in the `updateUserPluginsController` to ensure that cached tools are refreshed when a user disconnects from an MCP server after a plugin authentication update. This change improves the accuracy of tool data for users.

* chore: imports order

* fix: domain separator regex usage in ToolService

- Moved the declaration of `domainSeparatorRegex` to avoid redundancy in the `loadActionToolsForExecution` function, improving code clarity and performance.

* chore: OAuth flow error handling and CSRF token generation

- Enhanced the OAuth callback route to validate the flow ID format, ensuring proper error handling for invalid states.
- Updated the CSRF token generation function to require a JWT secret, throwing an error if not provided, which improves security and clarity in token generation.
- Adjusted tests to reflect changes in flow ID handling and ensure robust validation across various scenarios.
2026-02-12 14:22:05 -05:00

590 lines
20 KiB
TypeScript

import { logger } from '@librechat/data-schemas';
import type { TokenMethods, IUser } from '@librechat/data-schemas';
import type { FlowStateManager } from '~/flow/manager';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import type * as t from '~/mcp/types';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPConnection } from '~/mcp/connection';
import { MCPOAuthHandler } from '~/mcp/oauth';
import { processMCPEnv } from '~/utils';
// Swap the factory's collaborators for Jest doubles. Jest hoists every `jest.mock` call above
// the imports, so all four registrations are in place before any module is loaded.
jest.mock('@librechat/data-schemas', () => ({
  logger: {
    debug: jest.fn(),
    error: jest.fn(),
    info: jest.fn(),
    warn: jest.fn(),
  },
}));
jest.mock('~/mcp/connection');
jest.mock('~/mcp/oauth');
jest.mock('~/utils');

// Typed views of the mocked exports, so tests can configure return values and inspect calls.
const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection>;
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction<typeof processMCPEnv>;
const mockLogger = logger as jest.Mocked<typeof logger>;
/**
 * Unit tests for `MCPConnectionFactory`.
 *
 * `MCPConnection`, `MCPOAuthHandler`, `processMCPEnv` and the logger are replaced by the
 * `jest.mock` registrations at the top of this file. These tests therefore exercise only the
 * factory's own orchestration:
 * - option plumbing into `MCPConnection`
 * - OAuth token lookup
 * - `oauthRequired` handling
 * - stopping connection attempts on OAuth errors
 * - tool discovery
 */
describe('MCPConnectionFactory', () => {
  /** Shape of the listener the factory registers for the connection's `oauthRequired` event. */
  type OAuthRequiredHandler = (data: Record<string, unknown>) => Promise<void>;

  let mockUser: IUser;
  let mockServerConfig: t.MCPOptions;
  let mockFlowManager: jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;
  let mockConnectionInstance: jest.Mocked<MCPConnection>;

  /** Builds a fresh set of `jest.fn()` token-persistence methods for OAuth options. */
  const createTokenMethods = () => ({
    findToken: jest.fn(),
    createToken: jest.fn(),
    updateToken: jest.fn(),
    deleteTokens: jest.fn(),
  });

  /**
   * Records the listener the factory attaches via `connection.on('oauthRequired', ...)`.
   * Must be called before `MCPConnectionFactory.create`, because it replaces the `on` mock.
   *
   * @returns A trigger that invokes the recorded listener with `data`. It throws if the factory
   *   never registered one, so the test fails with a clear message instead of a `TypeError`.
   */
  const captureOAuthRequiredHandler = (): OAuthRequiredHandler => {
    let captured: OAuthRequiredHandler | undefined;
    mockConnectionInstance.on.mockImplementation((event, handler) => {
      if (event === 'oauthRequired') {
        captured = handler as OAuthRequiredHandler;
      }
      return mockConnectionInstance;
    });
    return async (data) => {
      if (!captured) {
        throw new Error('MCPConnectionFactory did not register an oauthRequired listener');
      }
      await captured(data);
    };
  };

  beforeEach(() => {
    jest.clearAllMocks();

    mockUser = {
      id: 'user123',
      email: 'test@example.com',
    } as IUser;

    mockServerConfig = {
      command: 'node',
      args: ['server.js'],
      initTimeout: 5000,
    } as t.MCPOptions;

    mockFlowManager = {
      createFlow: jest.fn(),
      createFlowWithHandler: jest.fn(),
      getFlowState: jest.fn(),
      deleteFlow: jest.fn().mockResolvedValue(true),
    } as unknown as jest.Mocked<FlowStateManager<MCPOAuthTokens | null>>;

    // `mockReturnThis` makes listener methods return the instance they were called on, so
    // chaining (`conn.on(...).on(...)`) stays on *this* mock. Referencing `mockConnectionInstance`
    // inside its own initializer would capture the previous test's instance (or `undefined`).
    mockConnectionInstance = {
      connect: jest.fn(),
      isConnected: jest.fn(),
      setOAuthTokens: jest.fn(),
      on: jest.fn().mockReturnThis(),
      once: jest.fn().mockReturnThis(),
      off: jest.fn().mockReturnThis(),
      removeListener: jest.fn().mockReturnThis(),
      emit: jest.fn(),
    } as unknown as jest.Mocked<MCPConnection>;

    mockMCPConnection.mockImplementation(() => mockConnectionInstance);
    mockProcessMCPEnv.mockReturnValue(mockServerConfig);
  });

  describe('static create method', () => {
    it('should create a basic connection without OAuth', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      mockConnectionInstance.isConnected.mockResolvedValue(true);

      const connection = await MCPConnectionFactory.create(basicOptions);

      expect(connection).toBe(mockConnectionInstance);
      expect(mockProcessMCPEnv).toHaveBeenCalledWith({ options: mockServerConfig });
      expect(mockMCPConnection).toHaveBeenCalledWith({
        serverName: 'test-server',
        serverConfig: mockServerConfig,
        userId: undefined,
        oauthTokens: null,
        useSSRFProtection: false,
      });
      expect(mockConnectionInstance.connect).toHaveBeenCalled();
    });

    it('should create a connection with OAuth', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      const oauthOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        tokenMethods: createTokenMethods(),
      };
      const mockTokens: MCPOAuthTokens = {
        access_token: 'access123',
        refresh_token: 'refresh123',
        token_type: 'Bearer',
        obtained_at: Date.now(),
      };
      mockFlowManager.createFlowWithHandler.mockResolvedValue(mockTokens);
      mockConnectionInstance.isConnected.mockResolvedValue(true);

      const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);

      expect(connection).toBe(mockConnectionInstance);
      expect(mockProcessMCPEnv).toHaveBeenCalledWith({ options: mockServerConfig, user: mockUser });
      expect(mockMCPConnection).toHaveBeenCalledWith({
        serverName: 'test-server',
        serverConfig: mockServerConfig,
        userId: 'user123',
        oauthTokens: mockTokens,
        useSSRFProtection: false,
      });
    });
  });

  describe('OAuth token handling', () => {
    it('should return null when no findToken method is provided', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      const oauthOptions: t.OAuthConnectionOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        tokenMethods: {
          ...createTokenMethods(),
          findToken: undefined as unknown as TokenMethods['findToken'],
        },
      };
      mockConnectionInstance.isConnected.mockResolvedValue(true);

      await MCPConnectionFactory.create(basicOptions, oauthOptions);

      expect(mockFlowManager.createFlowWithHandler).not.toHaveBeenCalled();
      // Without a token lookup, the connection must be built with no tokens at all.
      expect(mockMCPConnection).toHaveBeenCalledWith(
        expect.objectContaining({ oauthTokens: null }),
      );
    });

    it('should handle token retrieval errors gracefully', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      const oauthOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        tokenMethods: createTokenMethods(),
      };
      mockFlowManager.createFlowWithHandler.mockRejectedValue(new Error('Token fetch failed'));
      mockConnectionInstance.isConnected.mockResolvedValue(true);

      const connection = await MCPConnectionFactory.create(basicOptions, oauthOptions);

      expect(connection).toBe(mockConnectionInstance);
      expect(mockMCPConnection).toHaveBeenCalledWith({
        serverName: 'test-server',
        serverConfig: mockServerConfig,
        userId: 'user123',
        oauthTokens: null,
        useSSRFProtection: false,
      });
      expect(mockLogger.debug).toHaveBeenCalledWith(
        expect.stringContaining('No existing tokens found or error loading tokens'),
        expect.any(Error),
      );
    });
  });

  describe('OAuth event handling', () => {
    it('should handle oauthRequired event for returnOnOAuth scenario', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: {
          ...mockServerConfig,
          url: 'https://api.example.com',
          type: 'sse' as const,
        } as t.SSEOptions,
      };
      const oauthOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        returnOnOAuth: true,
        oauthStart: jest.fn(),
        tokenMethods: createTokenMethods(),
      };
      const mockFlowData = {
        authorizationUrl: 'https://auth.example.com',
        flowId: 'flow123',
        flowMetadata: {
          serverName: 'test-server',
          userId: 'user123',
          serverUrl: 'https://api.example.com',
          state: 'random-state',
          clientInfo: { client_id: 'client123' },
        },
      };
      mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
      mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
      mockConnectionInstance.isConnected.mockResolvedValue(false);
      const triggerOAuthRequired = captureOAuthRequiredHandler();

      try {
        await MCPConnectionFactory.create(basicOptions, oauthOptions);
      } catch {
        // Expected to fail due to connection not established
      }
      await triggerOAuthRequired({ serverUrl: 'https://api.example.com' });

      expect(mockMCPOAuthHandler.initiateOAuthFlow).toHaveBeenCalledWith(
        'test-server',
        'https://api.example.com',
        'user123',
        {},
        undefined,
      );
      expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
      expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
        'oauthFailed',
        expect.objectContaining({
          message: 'OAuth flow initiated - return early',
        }),
      );
    });

    it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
        user: mockUser,
      };
      const oauthOptions: t.OAuthConnectionOptions = {
        user: mockUser,
        useOAuth: true,
        returnOnOAuth: true,
        oauthStart: jest.fn(),
        flowManager: mockFlowManager,
      };
      mockFlowManager.getFlowState.mockResolvedValue({
        status: 'PENDING',
        type: 'mcp_oauth',
        metadata: { codeVerifier: 'existing-verifier' },
        createdAt: Date.now(),
      });
      mockConnectionInstance.isConnected.mockResolvedValue(false);
      const triggerOAuthRequired = captureOAuthRequiredHandler();

      try {
        await MCPConnectionFactory.create(basicOptions, oauthOptions);
      } catch {
        // Expected to fail
      }
      await triggerOAuthRequired({ serverUrl: 'https://api.example.com' });

      expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
      expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
      expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
        'oauthFailed',
        expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
      );
    });

    it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
        user: mockUser,
      };
      const oauthOptions: t.OAuthConnectionOptions = {
        user: mockUser,
        useOAuth: true,
        returnOnOAuth: true,
        oauthStart: jest.fn(),
        flowManager: mockFlowManager,
      };
      const mockFlowData = {
        authorizationUrl: 'https://auth.example.com',
        flowId: 'user123:test-server',
        flowMetadata: {
          serverName: 'test-server',
          userId: 'user123',
          serverUrl: 'https://api.example.com',
          state: 'test-state',
          codeVerifier: 'new-code-verifier-xyz',
          clientInfo: { client_id: 'test-client' },
          metadata: {
            authorization_endpoint: 'https://auth.example.com/authorize',
            token_endpoint: 'https://auth.example.com/token',
            issuer: 'https://api.example.com',
          },
        },
      };
      mockFlowManager.getFlowState.mockResolvedValue({
        status: 'COMPLETED',
        type: 'mcp_oauth',
        metadata: { codeVerifier: 'old-verifier' },
        createdAt: Date.now() - 60000,
      });
      mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
      mockFlowManager.deleteFlow.mockResolvedValue(true);
      mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
      mockConnectionInstance.isConnected.mockResolvedValue(false);
      const triggerOAuthRequired = captureOAuthRequiredHandler();

      try {
        await MCPConnectionFactory.create(basicOptions, oauthOptions);
      } catch {
        // Expected to fail
      }
      await triggerOAuthRequired({ serverUrl: 'https://api.example.com' });

      expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
      // The stale flow must be removed before the replacement flow is created.
      const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
      const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
      expect(deleteCallOrder).toBeLessThan(createCallOrder);
      expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
        'user123:test-server',
        'mcp_oauth',
        expect.objectContaining({
          codeVerifier: 'new-code-verifier-xyz',
        }),
        undefined,
      );
    });
  });

  describe('connection retry logic', () => {
    it('should establish connection successfully', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig, // Use default 5000ms timeout
      };
      mockConnectionInstance.connect.mockResolvedValue(undefined);
      mockConnectionInstance.isConnected.mockResolvedValue(true);

      const connection = await MCPConnectionFactory.create(basicOptions);

      expect(connection).toBe(mockConnectionInstance);
      expect(mockConnectionInstance.connect).toHaveBeenCalledTimes(1);
    });

    it('should handle OAuth errors during connection attempts', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      const oauthOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        oauthStart: jest.fn(),
        tokenMethods: createTokenMethods(),
      };
      const oauthError = new Error('Non-200 status code (401)');
      (oauthError as unknown as Record<string, unknown>).isOAuthError = true;
      mockConnectionInstance.connect.mockRejectedValue(oauthError);
      mockConnectionInstance.isConnected.mockResolvedValue(false);

      await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow(
        'Non-200 status code (401)',
      );
      expect(mockLogger.info).toHaveBeenCalledWith(
        expect.stringContaining('OAuth required, stopping connection attempts'),
      );
    });
  });

  describe('isOAuthError method', () => {
    it('should identify OAuth errors by message content', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      const oauthOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        tokenMethods: createTokenMethods(),
      };
      const error401 = new Error('401 Unauthorized');
      mockConnectionInstance.connect.mockRejectedValue(error401);
      mockConnectionInstance.isConnected.mockResolvedValue(false);

      await expect(MCPConnectionFactory.create(basicOptions, oauthOptions)).rejects.toThrow('401');
      expect(mockLogger.info).toHaveBeenCalledWith(
        expect.stringContaining('OAuth required, stopping connection attempts'),
      );
    });
  });

  describe('discoverTools static method', () => {
    const mockTools = [
      { name: 'tool1', description: 'First tool', inputSchema: { type: 'object' } },
      { name: 'tool2', description: 'Second tool', inputSchema: { type: 'object' } },
    ];

    it('should discover tools from a successfully connected server', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      mockConnectionInstance.connect.mockResolvedValue(undefined);
      mockConnectionInstance.isConnected.mockResolvedValue(true);
      mockConnectionInstance.fetchTools = jest.fn().mockResolvedValue(mockTools);

      const result = await MCPConnectionFactory.discoverTools(basicOptions);

      expect(result.tools).toEqual(mockTools);
      expect(result.oauthRequired).toBe(false);
      expect(result.oauthUrl).toBeNull();
      expect(result.connection).toBe(mockConnectionInstance);
    });

    it('should detect OAuth required without generating URL in discovery mode', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: {
          ...mockServerConfig,
          url: 'https://api.example.com',
          type: 'sse' as const,
        } as t.SSEOptions,
      };
      const mockOAuthStart = jest.fn().mockResolvedValue(undefined);
      const oauthOptions = {
        useOAuth: true as const,
        user: mockUser,
        flowManager: mockFlowManager,
        oauthStart: mockOAuthStart,
        tokenMethods: createTokenMethods(),
      };
      mockConnectionInstance.isConnected.mockResolvedValue(false);
      mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);

      // Simulate a server that raises `oauthRequired` (with no payload) while connecting.
      let oauthHandler: (() => Promise<void>) | undefined;
      mockConnectionInstance.on.mockImplementation((event, handler) => {
        if (event === 'oauthRequired') {
          oauthHandler = handler as () => Promise<void>;
        }
        return mockConnectionInstance;
      });
      mockConnectionInstance.connect.mockImplementation(async () => {
        if (oauthHandler) {
          await oauthHandler();
        }
        throw new Error('OAuth required');
      });

      const result = await MCPConnectionFactory.discoverTools(basicOptions, oauthOptions);

      expect(result.connection).toBeNull();
      expect(result.tools).toBeNull();
      expect(result.oauthRequired).toBe(true);
      expect(result.oauthUrl).toBeNull();
      expect(mockOAuthStart).not.toHaveBeenCalled();
    });

    it('should return null tools when discovery fails completely', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed'));
      mockConnectionInstance.isConnected.mockResolvedValue(false);
      mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined);

      const result = await MCPConnectionFactory.discoverTools(basicOptions);

      expect(result.tools).toBeNull();
      expect(result.connection).toBeNull();
      expect(result.oauthRequired).toBe(false);
    });

    it('should handle disconnect errors gracefully during cleanup', async () => {
      const basicOptions = {
        serverName: 'test-server',
        serverConfig: mockServerConfig,
      };
      mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed'));
      mockConnectionInstance.isConnected.mockResolvedValue(false);
      mockConnectionInstance.disconnect = jest
        .fn()
        .mockRejectedValue(new Error('Disconnect failed'));

      const result = await MCPConnectionFactory.discoverTools(basicOptions);

      expect(result.tools).toBeNull();
      expect(mockLogger.debug).toHaveBeenCalled();
    });
  });
});