mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-14 14:38:11 +01:00
🛡️ fix: Secure MCP/Actions OAuth Flows, Resolve Race Condition & Tool Cache Cleanup (#11756)
* 🔧 fix: Update OAuth error message for clarity - Changed the default error message in the OAuth error route from 'Unknown error' to 'Unknown OAuth error' to provide clearer context during authentication failures. * 🔒 feat: Enhance OAuth flow with CSRF protection and session management - Implemented CSRF protection for OAuth flows by introducing `generateOAuthCsrfToken`, `setOAuthCsrfCookie`, and `validateOAuthCsrf` functions. - Added session management for OAuth with `setOAuthSession` and `validateOAuthSession` middleware. - Updated routes to bind CSRF tokens for MCP and action OAuth flows, ensuring secure authentication. - Enhanced tests to validate CSRF handling and session management in OAuth processes. * 🔧 refactor: Invalidate cached tools after user plugin disconnection - Added a call to `invalidateCachedTools` in the `updateUserPluginsController` to ensure that cached tools are refreshed when a user disconnects from an MCP server after a plugin authentication update. This change improves the accuracy of tool data for users. * chore: imports order * fix: domain separator regex usage in ToolService - Moved the declaration of `domainSeparatorRegex` to avoid redundancy in the `loadActionToolsForExecution` function, improving code clarity and performance. * chore: OAuth flow error handling and CSRF token generation - Enhanced the OAuth callback route to validate the flow ID format, ensuring proper error handling for invalid states. - Updated the CSRF token generation function to require a JWT secret, throwing an error if not provided, which improves security and clarity in token generation. - Adjusted tests to reflect changes in flow ID handling and ensure robust validation across various scenarios.
This commit is contained in:
parent
72a30cd9c4
commit
599f4a11f1
14 changed files with 523 additions and 141 deletions
|
|
@ -298,38 +298,45 @@ export class MCPConnectionFactory {
|
|||
const oauthHandler = async (data: { serverUrl?: string }) => {
|
||||
logger.info(`${this.logPrefix} oauthRequired event received`);
|
||||
|
||||
// If we just want to initiate OAuth and return, handle it differently
|
||||
if (this.returnOnOAuth) {
|
||||
try {
|
||||
const config = this.serverConfig;
|
||||
const { authorizationUrl, flowId, flowMetadata } =
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
|
||||
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow?.status === 'PENDING') {
|
||||
logger.debug(
|
||||
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Delete any existing flow state to ensure we start fresh
|
||||
// This prevents stale codeVerifier issues when re-authenticating
|
||||
await this.flowManager!.deleteFlow(flowId, 'mcp_oauth');
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: newFlowId,
|
||||
flowMetadata,
|
||||
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
);
|
||||
|
||||
// Create the flow state so the OAuth callback can find it
|
||||
// We spawn this in the background without waiting for it
|
||||
// Pass signal so the flow can be aborted if the request is cancelled
|
||||
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
|
||||
// The OAuth callback will resolve this flow, so we expect it to timeout here
|
||||
// or it will be aborted if the request is cancelled - both are fine
|
||||
});
|
||||
if (existingFlow) {
|
||||
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
|
||||
}
|
||||
|
||||
this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch(
|
||||
() => {},
|
||||
);
|
||||
|
||||
if (this.oauthStart) {
|
||||
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
|
||||
await this.oauthStart(authorizationUrl);
|
||||
}
|
||||
|
||||
// Emit oauthFailed to signal that connection should not proceed
|
||||
// but OAuth was successfully initiated
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
} catch (error) {
|
||||
|
|
@ -391,11 +398,9 @@ export class MCPConnectionFactory {
|
|||
logger.error(`${this.logPrefix} Failed to establish connection.`);
|
||||
}
|
||||
|
||||
// Handles connection attempts with retry logic and OAuth error handling
|
||||
private async connectTo(connection: MCPConnection): Promise<void> {
|
||||
const maxAttempts = 3;
|
||||
let attempts = 0;
|
||||
let oauthHandled = false;
|
||||
|
||||
while (attempts < maxAttempts) {
|
||||
try {
|
||||
|
|
@ -408,22 +413,6 @@ export class MCPConnectionFactory {
|
|||
attempts++;
|
||||
|
||||
if (this.useOAuth && this.isOAuthError(error)) {
|
||||
// For returnOnOAuth mode, let the event handler (handleOAuthEvents) deal with OAuth
|
||||
// We just need to stop retrying and let the error propagate
|
||||
if (this.returnOnOAuth) {
|
||||
logger.info(
|
||||
`${this.logPrefix} OAuth required (return on OAuth mode), stopping retries`,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Normal flow - wait for OAuth to complete
|
||||
if (this.oauthStart && !oauthHandled) {
|
||||
oauthHandled = true;
|
||||
logger.info(`${this.logPrefix} Handling OAuth`);
|
||||
await this.handleOAuthRequired();
|
||||
}
|
||||
// Don't retry on OAuth errors - just throw
|
||||
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
|
||||
throw error;
|
||||
}
|
||||
|
|
@ -499,26 +488,15 @@ export class MCPConnectionFactory {
|
|||
/** Check if there's already an ongoing OAuth flow for this flowId */
|
||||
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
|
||||
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
|
||||
if (existingFlow) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
|
||||
);
|
||||
try {
|
||||
if (existingFlow.status === 'PENDING') {
|
||||
await this.flowManager.failFlow(
|
||||
flowId,
|
||||
'mcp_oauth',
|
||||
new Error('Cancelled for new OAuth request'),
|
||||
);
|
||||
} else {
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
}
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
} catch (error) {
|
||||
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
|
||||
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
|
||||
}
|
||||
// Continue to start a new flow below
|
||||
}
|
||||
|
||||
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
|
||||
|
|
|
|||
|
|
@ -270,7 +270,54 @@ describe('MCPConnectionFactory', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should delete existing flow before creating new OAuth flow to prevent stale codeVerifier', async () => {
|
||||
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
user: mockUser,
|
||||
};
|
||||
|
||||
const oauthOptions: t.OAuthConnectionOptions = {
|
||||
user: mockUser,
|
||||
useOAuth: true,
|
||||
returnOnOAuth: true,
|
||||
oauthStart: jest.fn(),
|
||||
flowManager: mockFlowManager,
|
||||
};
|
||||
|
||||
mockFlowManager.getFlowState.mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
type: 'mcp_oauth',
|
||||
metadata: { codeVerifier: 'existing-verifier' },
|
||||
createdAt: Date.now(),
|
||||
});
|
||||
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||
|
||||
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
|
||||
mockConnectionInstance.on.mockImplementation((event, handler) => {
|
||||
if (event === 'oauthRequired') {
|
||||
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
|
||||
}
|
||||
return mockConnectionInstance;
|
||||
});
|
||||
|
||||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
|
||||
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
|
||||
'oauthFailed',
|
||||
expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
|
|
@ -303,6 +350,12 @@ describe('MCPConnectionFactory', () => {
|
|||
},
|
||||
};
|
||||
|
||||
mockFlowManager.getFlowState.mockResolvedValue({
|
||||
status: 'COMPLETED',
|
||||
type: 'mcp_oauth',
|
||||
metadata: { codeVerifier: 'old-verifier' },
|
||||
createdAt: Date.now() - 60000,
|
||||
});
|
||||
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
|
||||
mockFlowManager.deleteFlow.mockResolvedValue(true);
|
||||
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
|
||||
|
|
@ -319,21 +372,17 @@ describe('MCPConnectionFactory', () => {
|
|||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected to fail due to connection not established
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
// Verify deleteFlow was called with correct parameters
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
|
||||
|
||||
// Verify deleteFlow was called before createFlow
|
||||
const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
|
||||
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
|
||||
expect(deleteCallOrder).toBeLessThan(createCallOrder);
|
||||
|
||||
// Verify createFlow was called with fresh metadata
|
||||
// 4th arg is the abort signal (undefined in this test since no signal was provided)
|
||||
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
|
||||
'user123:test-server',
|
||||
'mcp_oauth',
|
||||
|
|
|
|||
89
packages/api/src/oauth/csrf.ts
Normal file
89
packages/api/src/oauth/csrf.ts
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import crypto from 'crypto';
|
||||
import type { Request, Response, NextFunction } from 'express';
|
||||
|
||||
export const OAUTH_CSRF_COOKIE = 'oauth_csrf';
|
||||
export const OAUTH_CSRF_MAX_AGE = 10 * 60 * 1000;
|
||||
|
||||
export const OAUTH_SESSION_COOKIE = 'oauth_session';
|
||||
export const OAUTH_SESSION_MAX_AGE = 24 * 60 * 60 * 1000;
|
||||
export const OAUTH_SESSION_COOKIE_PATH = '/api';
|
||||
|
||||
const isProduction = process.env.NODE_ENV === 'production';
|
||||
|
||||
/** Generates an HMAC-based token for OAuth CSRF protection */
|
||||
export function generateOAuthCsrfToken(flowId: string, secret?: string): string {
|
||||
const key = secret || process.env.JWT_SECRET;
|
||||
if (!key) {
|
||||
throw new Error('JWT_SECRET is required for OAuth CSRF token generation');
|
||||
}
|
||||
return crypto.createHmac('sha256', key).update(flowId).digest('hex').slice(0, 32);
|
||||
}
|
||||
|
||||
/** Sets a SameSite=Lax CSRF cookie bound to a specific OAuth flow */
|
||||
export function setOAuthCsrfCookie(res: Response, flowId: string, cookiePath: string): void {
|
||||
res.cookie(OAUTH_CSRF_COOKIE, generateOAuthCsrfToken(flowId), {
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'lax',
|
||||
maxAge: OAUTH_CSRF_MAX_AGE,
|
||||
path: cookiePath,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the per-flow CSRF cookie against the expected HMAC.
|
||||
* Uses timing-safe comparison and always clears the cookie to prevent replay.
|
||||
*/
|
||||
export function validateOAuthCsrf(
|
||||
req: Request,
|
||||
res: Response,
|
||||
flowId: string,
|
||||
cookiePath: string,
|
||||
): boolean {
|
||||
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_CSRF_COOKIE];
|
||||
res.clearCookie(OAUTH_CSRF_COOKIE, { path: cookiePath });
|
||||
if (!cookie) {
|
||||
return false;
|
||||
}
|
||||
const expected = generateOAuthCsrfToken(flowId);
|
||||
if (cookie.length !== expected.length) {
|
||||
return false;
|
||||
}
|
||||
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
|
||||
}
|
||||
|
||||
/**
|
||||
* Express middleware that sets the OAuth session cookie after JWT authentication.
|
||||
* Chain after requireJwtAuth on routes that precede an OAuth redirect (e.g., reinitialize, bind).
|
||||
*/
|
||||
export function setOAuthSession(req: Request, res: Response, next: NextFunction): void {
|
||||
const user = (req as Request & { user?: { id?: string } }).user;
|
||||
if (user?.id && !(req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE]) {
|
||||
setOAuthSessionCookie(res, user.id);
|
||||
}
|
||||
next();
|
||||
}
|
||||
|
||||
/** Sets a SameSite=Lax session cookie that binds the browser to the authenticated userId */
|
||||
export function setOAuthSessionCookie(res: Response, userId: string): void {
|
||||
res.cookie(OAUTH_SESSION_COOKIE, generateOAuthCsrfToken(userId), {
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'lax',
|
||||
maxAge: OAUTH_SESSION_MAX_AGE,
|
||||
path: OAUTH_SESSION_COOKIE_PATH,
|
||||
});
|
||||
}
|
||||
|
||||
/** Validates the session cookie against the expected userId using timing-safe comparison */
|
||||
export function validateOAuthSession(req: Request, userId: string): boolean {
|
||||
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE];
|
||||
if (!cookie) {
|
||||
return false;
|
||||
}
|
||||
const expected = generateOAuthCsrfToken(userId);
|
||||
if (cookie.length !== expected.length) {
|
||||
return false;
|
||||
}
|
||||
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
|
||||
}
|
||||
|
|
@ -1 +1,2 @@
|
|||
export * from './csrf';
|
||||
export * from './tokens';
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue