🛡️ fix: Secure MCP/Actions OAuth Flows, Resolve Race Condition & Tool Cache Cleanup (#11756)
Some checks are pending
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Waiting to run
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Waiting to run

* 🔧 fix: Update OAuth error message for clarity

- Changed the default error message in the OAuth error route from 'Unknown error' to 'Unknown OAuth error' to provide clearer context during authentication failures.

* 🔒 feat: Enhance OAuth flow with CSRF protection and session management

- Implemented CSRF protection for OAuth flows by introducing `generateOAuthCsrfToken`, `setOAuthCsrfCookie`, and `validateOAuthCsrf` functions.
- Added session management for OAuth with `setOAuthSession` and `validateOAuthSession` middleware.
- Updated routes to bind CSRF tokens for MCP and action OAuth flows, ensuring secure authentication.
- Enhanced tests to validate CSRF handling and session management in OAuth processes.

* 🔧 refactor: Invalidate cached tools after user plugin disconnection

- Added a call to `invalidateCachedTools` in the `updateUserPluginsController` to ensure that cached tools are refreshed when a user disconnects from an MCP server after a plugin authentication update. This change improves the accuracy of tool data for users.

* chore: imports order

* fix: domain separator regex usage in ToolService

- Moved the declaration of `domainSeparatorRegex` to avoid redundancy in the `loadActionToolsForExecution` function, improving code clarity and performance.

* chore: OAuth flow error handling and CSRF token generation

- Enhanced the OAuth callback route to validate the flow ID format, ensuring proper error handling for invalid states.
- Updated the CSRF token generation function to require a JWT secret, throwing an error if not provided, which improves security and clarity in token generation.
- Adjusted tests to reflect changes in flow ID handling and ensure robust validation across various scenarios.
This commit is contained in:
Danny Avila 2026-02-12 14:22:05 -05:00 committed by GitHub
parent 72a30cd9c4
commit 599f4a11f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 523 additions and 141 deletions

View file

@ -36,6 +36,7 @@ const {
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { getAppConfig } = require('~/server/services/Config');
@ -215,6 +216,7 @@ const updateUserPluginsController = async (req, res) => {
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
);
await mcpManager.disconnectUserConnection(user.id, serverName);
await invalidateCachedTools({ userId: user.id, serverName });
}
} catch (disconnectError) {
logger.error(

View file

@ -7,16 +7,13 @@ const { isEnabled } = require('@librechat/api');
* Switches between JWT and OpenID authentication based on cookies and environment settings
*/
const requireJwtAuth = (req, res, next) => {
// Check if token provider is specified in cookies
const cookieHeader = req.headers.cookie;
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
// Use OpenID authentication if token provider is OpenID and OPENID_REUSE_TOKENS is enabled
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
return passport.authenticate('openidJwt', { session: false })(req, res, next);
}
// Default to standard JWT authentication
return passport.authenticate('jwt', { session: false })(req, res, next);
};

View file

@ -1,8 +1,18 @@
const crypto = require('crypto');
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const cookieParser = require('cookie-parser');
const { getBasePath } = require('@librechat/api');
const { MongoMemoryServer } = require('mongodb-memory-server');
function generateTestCsrfToken(flowId) {
return crypto
.createHmac('sha256', process.env.JWT_SECRET)
.update(flowId)
.digest('hex')
.slice(0, 32);
}
const mockRegistryInstance = {
getServerConfig: jest.fn(),
@ -130,6 +140,7 @@ describe('MCP Routes', () => {
app = express();
app.use(express.json());
app.use(cookieParser());
app.use((req, res, next) => {
req.user = { id: 'test-user-id' };
@ -168,12 +179,12 @@ describe('MCP Routes', () => {
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(302);
@ -190,7 +201,7 @@ describe('MCP Routes', () => {
it('should return 403 when userId does not match authenticated user', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'different-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(403);
@ -228,7 +239,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(400);
@ -245,7 +256,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(500);
@ -255,7 +266,7 @@ describe('MCP Routes', () => {
it('should return 400 when flow state metadata is null', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
id: 'test-flow-id',
id: 'test-user-id:test-server',
metadata: null,
}),
};
@ -265,7 +276,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
flowId: 'test-flow-id',
flowId: 'test-user-id:test-server',
});
expect(response.status).toBe(400);
@ -280,7 +291,7 @@ describe('MCP Routes', () => {
it('should redirect to error page when OAuth error is received', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
error: 'access_denied',
state: 'test-flow-id',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
@ -290,7 +301,7 @@ describe('MCP Routes', () => {
it('should redirect to error page when code is missing', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
state: 'test-flow-id',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
@ -308,15 +319,50 @@ describe('MCP Routes', () => {
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
});
it('should redirect to error page when flow state is not found', async () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(null);
it('should redirect to error page when CSRF cookie is missing', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'invalid-flow-id',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should redirect to error page when CSRF cookie does not match state', async () => {
const csrfToken = generateTestCsrfToken('different-flow-id');
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: 'test-user-id:test-server',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should redirect to error page when flow state is not found', async () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(null);
const flowId = 'invalid-flow:id';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
});
@ -369,16 +415,22 @@ describe('MCP Routes', () => {
});
setCachedTools.mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
'test-flow-id',
flowId,
'test-auth-code',
mockFlowManager,
{},
@ -400,16 +452,24 @@ describe('MCP Routes', () => {
'mcp_oauth',
mockTokens,
);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(
'test-user-id:test-server',
'mcp_get_tokens',
);
});
it('should redirect to error page when callback processing fails', async () => {
MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error'));
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
@ -442,15 +502,21 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens');
});
it('should handle reconnection failure after OAuth', async () => {
@ -488,16 +554,22 @@ describe('MCP Routes', () => {
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens');
});
it('should redirect to error page if token storage fails', async () => {
@ -530,10 +602,16 @@ describe('MCP Routes', () => {
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
@ -589,22 +667,27 @@ describe('MCP Routes', () => {
clearReconnection: jest.fn(),
});
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
// Verify storeTokens was called with ORIGINAL flow state credentials
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
userId: 'test-user-id',
serverName: 'test-server',
tokens: mockTokens,
clientInfo: clientInfo, // Uses original flow state, not any "updated" credentials
clientInfo: clientInfo,
metadata: flowState.metadata,
}),
);
@ -631,16 +714,21 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.query({
code: 'test-auth-code',
state: flowId,
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
// Verify completeOAuthFlow was NOT called (prevented duplicate)
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
expect(MCPTokenStorage.storeTokens).not.toHaveBeenCalled();
});
@ -755,7 +843,7 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/status/test-flow-id');
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:test-server');
expect(response.status).toBe(200);
expect(response.body).toEqual({
@ -766,6 +854,13 @@ describe('MCP Routes', () => {
});
});
it('should return 403 when flowId does not match authenticated user', async () => {
const response = await request(app).get('/api/mcp/oauth/status/other-user-id:test-server');
expect(response.status).toBe(403);
expect(response.body).toEqual({ error: 'Access denied' });
});
it('should return 404 when flow is not found', async () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue(null),
@ -774,7 +869,7 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/status/non-existent-flow');
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:non-existent');
expect(response.status).toBe(404);
expect(response.body).toEqual({ error: 'Flow not found' });
@ -788,7 +883,7 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app).get('/api/mcp/oauth/status/error-flow-id');
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:error-server');
expect(response.status).toBe(500);
expect(response.body).toEqual({ error: 'Failed to get flow status' });
@ -1375,7 +1470,7 @@ describe('MCP Routes', () => {
refresh_token: 'edge-refresh-token',
};
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
id: 'test-flow-id',
id: 'test-user-id:test-server',
userId: 'test-user-id',
metadata: {
serverUrl: 'https://example.com',
@ -1403,8 +1498,12 @@ describe('MCP Routes', () => {
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
.get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`)
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.expect(302);
const basePath = getBasePath();
@ -1424,7 +1523,7 @@ describe('MCP Routes', () => {
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
id: 'test-flow-id',
id: 'test-user-id:test-server',
userId: 'test-user-id',
metadata: { serverUrl: 'https://example.com', oauth: {} },
clientInfo: {},
@ -1453,8 +1552,12 @@ describe('MCP Routes', () => {
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const flowId = 'test-user-id:test-server';
const csrfToken = generateTestCsrfToken(flowId);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
.get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`)
.set('Cookie', [`oauth_csrf=${csrfToken}`])
.expect(302);
const basePath = getBasePath();

View file

@ -1,14 +1,47 @@
const express = require('express');
const jwt = require('jsonwebtoken');
const { getAccessToken, getBasePath } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const {
getBasePath,
getAccessToken,
setOAuthSession,
validateOAuthCsrf,
OAUTH_CSRF_COOKIE,
setOAuthCsrfCookie,
validateOAuthSession,
OAUTH_SESSION_COOKIE,
} = require('@librechat/api');
const { findToken, updateToken, createToken } = require('~/models');
const { requireJwtAuth } = require('~/server/middleware');
const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const router = express.Router();
const JWT_SECRET = process.env.JWT_SECRET;
const OAUTH_CSRF_COOKIE_PATH = '/api/actions';
/**
* Sets a CSRF cookie binding the action OAuth flow to the current browser session.
* Must be called before the user opens the IdP authorization URL.
*
* @route POST /actions/:action_id/oauth/bind
*/
router.post('/:action_id/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { action_id } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const flowId = `${user.id}:${action_id}`;
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
res.json({ success: true });
} catch (error) {
logger.error('[Action OAuth] Failed to set CSRF binding cookie', error);
res.status(500).json({ error: 'Failed to bind OAuth flow' });
}
});
/**
* Handles the OAuth callback and exchanges the authorization code for tokens.
@ -45,7 +78,22 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
identifier = `${decodedState.user}:${action_id}`;
if (
!validateOAuthCsrf(req, res, identifier, OAUTH_CSRF_COOKIE_PATH) &&
!validateOAuthSession(req, decodedState.user)
) {
logger.error('[Action OAuth] CSRF validation failed: no valid CSRF or session cookie', {
identifier,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
});
await flowManager.failFlow(identifier, 'oauth', 'CSRF validation failed');
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
}
const flowState = await flowManager.getFlowState(identifier, 'oauth');
if (!flowState) {
throw new Error('OAuth flow not found');
@ -71,7 +119,6 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
);
await flowManager.completeFlow(identifier, 'oauth', tokenData);
/** Redirect to React success page */
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);

View file

@ -8,18 +8,32 @@ const {
Permissions,
} = require('librechat-data-provider');
const {
getBasePath,
createSafeUser,
MCPOAuthHandler,
MCPTokenStorage,
getBasePath,
setOAuthSession,
getUserMCPAuthMap,
validateOAuthCsrf,
OAUTH_CSRF_COOKIE,
setOAuthCsrfCookie,
generateCheckAccess,
validateOAuthSession,
OAUTH_SESSION_COOKIE,
} = require('@librechat/api');
const {
getMCPManager,
getFlowStateManager,
createMCPServerController,
updateMCPServerController,
deleteMCPServerController,
getMCPServersList,
getMCPServerById,
getMCPTools,
} = require('~/server/controllers/mcp');
const {
getOAuthReconnectionManager,
getMCPServersRegistry,
getFlowStateManager,
getMCPManager,
} = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware');
@ -27,20 +41,14 @@ const { findToken, updateToken, createToken, deleteTokens } = require('~/models'
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { getMCPTools } = require('~/server/controllers/mcp');
const { findPluginAuthsByKeys } = require('~/models');
const { getRoleByName } = require('~/models/Role');
const { getLogStores } = require('~/cache');
const {
createMCPServerController,
getMCPServerById,
getMCPServersList,
updateMCPServerController,
deleteMCPServerController,
} = require('~/server/controllers/mcp');
const router = Router();
const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';
/**
* Get all MCP tools available to the user
* Returns only MCP tools, completely decoupled from regular LibreChat tools
@ -53,7 +61,7 @@ router.get('/tools', requireJwtAuth, async (req, res) => {
* Initiate OAuth flow
* This endpoint is called when the user clicks the auth link in the UI
*/
router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { serverName } = req.params;
const { userId, flowId } = req.query;
@ -93,7 +101,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
// Redirect user to the authorization URL
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
res.redirect(authorizationUrl);
} catch (error) {
logger.error('[MCP OAuth] Failed to initiate OAuth', error);
@ -138,6 +146,25 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const flowId = state;
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
const flowParts = flowId.split(':');
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
const [flowUserId] = flowParts;
if (
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
!validateOAuthSession(req, flowUserId)
) {
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
flowId,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
});
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
@ -302,13 +329,47 @@ router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => {
}
});
/**
* Set CSRF binding cookie for OAuth flows initiated outside of HTTP request/response
* (e.g. during chat via SSE). The frontend should call this before opening the OAuth URL
* so the callback can verify the browser matches the flow initiator.
*/
router.post('/:serverName/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
res.json({ success: true });
} catch (error) {
logger.error('[MCP OAuth] Failed to set CSRF binding cookie', error);
res.status(500).json({ error: 'Failed to bind OAuth flow' });
}
});
/**
* Check OAuth flow status
* This endpoint can be used to poll the status of an OAuth flow
*/
router.get('/oauth/status/:flowId', async (req, res) => {
router.get('/oauth/status/:flowId', requireJwtAuth, async (req, res) => {
try {
const { flowId } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) {
return res.status(403).json({ error: 'Access denied' });
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
@ -375,7 +436,7 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
* Reinitialize MCP server
* This endpoint allows reinitializing a specific MCP server
*/
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => {
try {
const { serverName } = req.params;
const user = createSafeUser(req.user);
@ -421,6 +482,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
const { success, message, oauthRequired, oauthUrl } = result;
if (oauthRequired) {
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
}
res.json({
success,
message,

View file

@ -29,7 +29,7 @@ const oauthHandler = createOAuthHandler();
router.get('/error', (req, res) => {
/** A single error message is pushed by passport when authentication fails. */
const errorMessage = req.session?.messages?.pop() || 'Unknown error';
const errorMessage = req.session?.messages?.pop() || 'Unknown OAuth error';
logger.error('Error in OAuth authentication:', {
message: errorMessage,
});

View file

@ -1339,6 +1339,7 @@ async function loadActionToolsForExecution({
});
}
const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
for (const toolName of actionToolNames) {
let currentDomain = '';
for (const domain of domainMap.keys()) {
@ -1355,7 +1356,6 @@ async function loadActionToolsForExecution({
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
processedActionSets.get(currentDomain);
const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_');
const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);

View file

@ -1,7 +1,12 @@
import { useMemo, useState, useEffect, useRef, useLayoutEffect } from 'react';
import { useMemo, useState, useEffect, useRef, useCallback, useLayoutEffect } from 'react';
import { Button } from '@librechat/client';
import { TriangleAlert } from 'lucide-react';
import { actionDelimiter, actionDomainSeparator, Constants } from 'librechat-data-provider';
import {
Constants,
dataService,
actionDelimiter,
actionDomainSeparator,
} from 'librechat-data-provider';
import type { TAttachment } from 'librechat-data-provider';
import { useLocalize, useProgress } from '~/hooks';
import { AttachmentGroup } from './Parts';
@ -36,9 +41,9 @@ export default function ToolCall({
const [isAnimating, setIsAnimating] = useState(false);
const prevShowInfoRef = useRef<boolean>(showInfo);
const { function_name, domain, isMCPToolCall } = useMemo(() => {
const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => {
if (typeof name !== 'string') {
return { function_name: '', domain: null, isMCPToolCall: false };
return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' };
}
if (name.includes(Constants.mcp_delimiter)) {
const [func, server] = name.split(Constants.mcp_delimiter);
@ -46,6 +51,7 @@ export default function ToolCall({
function_name: func || '',
domain: server && (server.replaceAll(actionDomainSeparator, '.') || null),
isMCPToolCall: true,
mcpServerName: server || '',
};
}
const [func, _domain] = name.includes(actionDelimiter)
@ -55,9 +61,40 @@ export default function ToolCall({
function_name: func || '',
domain: _domain && (_domain.replaceAll(actionDomainSeparator, '.') || null),
isMCPToolCall: false,
mcpServerName: '',
};
}, [name]);
const actionId = useMemo(() => {
if (isMCPToolCall || !auth) {
return '';
}
try {
const url = new URL(auth);
const redirectUri = url.searchParams.get('redirect_uri') || '';
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
return match?.[1] || '';
} catch {
return '';
}
}, [auth, isMCPToolCall]);
const handleOAuthClick = useCallback(async () => {
if (!auth) {
return;
}
try {
if (isMCPToolCall && mcpServerName) {
await dataService.bindMCPOAuth(mcpServerName);
} else if (actionId) {
await dataService.bindActionOAuth(actionId);
}
} catch (e) {
logger.error('Failed to bind OAuth CSRF cookie', e);
}
window.open(auth, '_blank', 'noopener,noreferrer');
}, [auth, isMCPToolCall, mcpServerName, actionId]);
const error =
typeof output === 'string' && output.toLowerCase().includes('error processing tool');
@ -230,7 +267,7 @@ export default function ToolCall({
className="font-mediu inline-flex items-center justify-center rounded-xl px-4 py-2 text-sm"
variant="default"
rel="noopener noreferrer"
onClick={() => window.open(auth, '_blank', 'noopener,noreferrer')}
onClick={handleOAuthClick}
>
{localize('com_ui_sign_in_to_domain', { 0: authDomain })}
</Button>

View file

@ -298,38 +298,45 @@ export class MCPConnectionFactory {
const oauthHandler = async (data: { serverUrl?: string }) => {
logger.info(`${this.logPrefix} oauthRequired event received`);
// If we just want to initiate OAuth and return, handle it differently
if (this.returnOnOAuth) {
try {
const config = this.serverConfig;
const { authorizationUrl, flowId, flowMetadata } =
await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth_headers ?? {},
config?.oauth,
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
if (existingFlow?.status === 'PENDING') {
logger.debug(
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
);
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
}
// Delete any existing flow state to ensure we start fresh
// This prevents stale codeVerifier issues when re-authenticating
await this.flowManager!.deleteFlow(flowId, 'mcp_oauth');
const {
authorizationUrl,
flowId: newFlowId,
flowMetadata,
} = await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth_headers ?? {},
config?.oauth,
);
// Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it
// Pass signal so the flow can be aborted if the request is cancelled
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here
// or it will be aborted if the request is cancelled - both are fine
});
if (existingFlow) {
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
}
this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch(
() => {},
);
if (this.oauthStart) {
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
await this.oauthStart(authorizationUrl);
}
// Emit oauthFailed to signal that connection should not proceed
// but OAuth was successfully initiated
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
} catch (error) {
@ -391,11 +398,9 @@ export class MCPConnectionFactory {
logger.error(`${this.logPrefix} Failed to establish connection.`);
}
// Handles connection attempts with retry logic and OAuth error handling
private async connectTo(connection: MCPConnection): Promise<void> {
const maxAttempts = 3;
let attempts = 0;
let oauthHandled = false;
while (attempts < maxAttempts) {
try {
@ -408,22 +413,6 @@ export class MCPConnectionFactory {
attempts++;
if (this.useOAuth && this.isOAuthError(error)) {
// For returnOnOAuth mode, let the event handler (handleOAuthEvents) deal with OAuth
// We just need to stop retrying and let the error propagate
if (this.returnOnOAuth) {
logger.info(
`${this.logPrefix} OAuth required (return on OAuth mode), stopping retries`,
);
throw error;
}
// Normal flow - wait for OAuth to complete
if (this.oauthStart && !oauthHandled) {
oauthHandled = true;
logger.info(`${this.logPrefix} Handling OAuth`);
await this.handleOAuthRequired();
}
// Don't retry on OAuth errors - just throw
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
throw error;
}
@ -499,26 +488,15 @@ export class MCPConnectionFactory {
/** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
if (existingFlow) {
logger.debug(
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
);
try {
if (existingFlow.status === 'PENDING') {
await this.flowManager.failFlow(
flowId,
'mcp_oauth',
new Error('Cancelled for new OAuth request'),
);
} else {
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
}
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
} catch (error) {
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
}
// Continue to start a new flow below
}
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);

View file

@ -270,7 +270,54 @@ describe('MCPConnectionFactory', () => {
);
});
it('should delete existing flow before creating new OAuth flow to prevent stale codeVerifier', async () => {
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
user: mockUser,
};
const oauthOptions: t.OAuthConnectionOptions = {
user: mockUser,
useOAuth: true,
returnOnOAuth: true,
oauthStart: jest.fn(),
flowManager: mockFlowManager,
};
mockFlowManager.getFlowState.mockResolvedValue({
status: 'PENDING',
type: 'mcp_oauth',
metadata: { codeVerifier: 'existing-verifier' },
createdAt: Date.now(),
});
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
);
});
it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
@ -303,6 +350,12 @@ describe('MCPConnectionFactory', () => {
},
};
mockFlowManager.getFlowState.mockResolvedValue({
status: 'COMPLETED',
type: 'mcp_oauth',
metadata: { codeVerifier: 'old-verifier' },
createdAt: Date.now() - 60000,
});
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockFlowManager.deleteFlow.mockResolvedValue(true);
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
@ -319,21 +372,17 @@ describe('MCPConnectionFactory', () => {
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail due to connection not established
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
// Verify deleteFlow was called with correct parameters
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
// Verify deleteFlow was called before createFlow
const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
expect(deleteCallOrder).toBeLessThan(createCallOrder);
// Verify createFlow was called with fresh metadata
// 4th arg is the abort signal (undefined in this test since no signal was provided)
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
'user123:test-server',
'mcp_oauth',

View file

@ -0,0 +1,89 @@
import crypto from 'crypto';
import type { Request, Response, NextFunction } from 'express';
export const OAUTH_CSRF_COOKIE = 'oauth_csrf';
export const OAUTH_CSRF_MAX_AGE = 10 * 60 * 1000;
export const OAUTH_SESSION_COOKIE = 'oauth_session';
export const OAUTH_SESSION_MAX_AGE = 24 * 60 * 60 * 1000;
export const OAUTH_SESSION_COOKIE_PATH = '/api';
const isProduction = process.env.NODE_ENV === 'production';
/** Generates an HMAC-based token for OAuth CSRF protection */
export function generateOAuthCsrfToken(flowId: string, secret?: string): string {
const key = secret || process.env.JWT_SECRET;
if (!key) {
throw new Error('JWT_SECRET is required for OAuth CSRF token generation');
}
return crypto.createHmac('sha256', key).update(flowId).digest('hex').slice(0, 32);
}
/** Sets a SameSite=Lax CSRF cookie bound to a specific OAuth flow */
export function setOAuthCsrfCookie(res: Response, flowId: string, cookiePath: string): void {
res.cookie(OAUTH_CSRF_COOKIE, generateOAuthCsrfToken(flowId), {
httpOnly: true,
secure: isProduction,
sameSite: 'lax',
maxAge: OAUTH_CSRF_MAX_AGE,
path: cookiePath,
});
}
/**
* Validates the per-flow CSRF cookie against the expected HMAC.
* Uses timing-safe comparison and always clears the cookie to prevent replay.
*/
export function validateOAuthCsrf(
req: Request,
res: Response,
flowId: string,
cookiePath: string,
): boolean {
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_CSRF_COOKIE];
res.clearCookie(OAUTH_CSRF_COOKIE, { path: cookiePath });
if (!cookie) {
return false;
}
const expected = generateOAuthCsrfToken(flowId);
if (cookie.length !== expected.length) {
return false;
}
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
}
/**
* Express middleware that sets the OAuth session cookie after JWT authentication.
* Chain after requireJwtAuth on routes that precede an OAuth redirect (e.g., reinitialize, bind).
*/
export function setOAuthSession(req: Request, res: Response, next: NextFunction): void {
const user = (req as Request & { user?: { id?: string } }).user;
if (user?.id && !(req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE]) {
setOAuthSessionCookie(res, user.id);
}
next();
}
/** Sets a SameSite=Lax session cookie that binds the browser to the authenticated userId */
export function setOAuthSessionCookie(res: Response, userId: string): void {
res.cookie(OAUTH_SESSION_COOKIE, generateOAuthCsrfToken(userId), {
httpOnly: true,
secure: isProduction,
sameSite: 'lax',
maxAge: OAUTH_SESSION_MAX_AGE,
path: OAUTH_SESSION_COOKIE_PATH,
});
}
/** Validates the session cookie against the expected userId using timing-safe comparison */
export function validateOAuthSession(req: Request, userId: string): boolean {
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE];
if (!cookie) {
return false;
}
const expected = generateOAuthCsrfToken(userId);
if (cookie.length !== expected.length) {
return false;
}
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
}

View file

@ -1 +1,2 @@
export * from './csrf';
export * from './tokens';

View file

@ -181,6 +181,11 @@ export const cancelMCPOAuth = (serverName: string) => {
return `${BASE_URL}/api/mcp/oauth/cancel/${serverName}`;
};
export const mcpOAuthBind = (serverName: string) => `${BASE_URL}/api/mcp/${serverName}/oauth/bind`;
export const actionOAuthBind = (actionId: string) =>
`${BASE_URL}/api/actions/${actionId}/oauth/bind`;
export const config = () => `${BASE_URL}/api/config`;
export const prompts = () => `${BASE_URL}/api/prompts`;

View file

@ -178,6 +178,14 @@ export const reinitializeMCPServer = (serverName: string) => {
return request.post(endpoints.mcpReinitialize(serverName));
};
export const bindMCPOAuth = (serverName: string): Promise<{ success: boolean }> => {
return request.post(endpoints.mcpOAuthBind(serverName));
};
export const bindActionOAuth = (actionId: string): Promise<{ success: boolean }> => {
return request.post(endpoints.actionOAuthBind(actionId));
};
export const getMCPConnectionStatus = (): Promise<q.MCPConnectionStatusResponse> => {
return request.get(endpoints.mcpConnectionStatus());
};