mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-02-13 13:04:24 +01:00
🛡️ fix: Secure MCP/Actions OAuth Flows, Resolve Race Condition & Tool Cache Cleanup (#11756)
* 🔧 fix: Update OAuth error message for clarity - Changed the default error message in the OAuth error route from 'Unknown error' to 'Unknown OAuth error' to provide clearer context during authentication failures. * 🔒 feat: Enhance OAuth flow with CSRF protection and session management - Implemented CSRF protection for OAuth flows by introducing `generateOAuthCsrfToken`, `setOAuthCsrfCookie`, and `validateOAuthCsrf` functions. - Added session management for OAuth with `setOAuthSession` and `validateOAuthSession` middleware. - Updated routes to bind CSRF tokens for MCP and action OAuth flows, ensuring secure authentication. - Enhanced tests to validate CSRF handling and session management in OAuth processes. * 🔧 refactor: Invalidate cached tools after user plugin disconnection - Added a call to `invalidateCachedTools` in the `updateUserPluginsController` to ensure that cached tools are refreshed when a user disconnects from an MCP server after a plugin authentication update. This change improves the accuracy of tool data for users. * chore: imports order * fix: domain separator regex usage in ToolService - Moved the declaration of `domainSeparatorRegex` to avoid redundancy in the `loadActionToolsForExecution` function, improving code clarity and performance. * chore: OAuth flow error handling and CSRF token generation - Enhanced the OAuth callback route to validate the flow ID format, ensuring proper error handling for invalid states. - Updated the CSRF token generation function to require a JWT secret, throwing an error if not provided, which improves security and clarity in token generation. - Adjusted tests to reflect changes in flow ID handling and ensure robust validation across various scenarios.
This commit is contained in:
parent
72a30cd9c4
commit
599f4a11f1
14 changed files with 523 additions and 141 deletions
|
|
@ -36,6 +36,7 @@ const {
|
|||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
|
||||
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
|
@ -215,6 +216,7 @@ const updateUserPluginsController = async (req, res) => {
|
|||
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
await invalidateCachedTools({ userId: user.id, serverName });
|
||||
}
|
||||
} catch (disconnectError) {
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -7,16 +7,13 @@ const { isEnabled } = require('@librechat/api');
|
|||
* Switches between JWT and OpenID authentication based on cookies and environment settings
|
||||
*/
|
||||
const requireJwtAuth = (req, res, next) => {
|
||||
// Check if token provider is specified in cookies
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null;
|
||||
|
||||
// Use OpenID authentication if token provider is OpenID and OPENID_REUSE_TOKENS is enabled
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
return passport.authenticate('openidJwt', { session: false })(req, res, next);
|
||||
}
|
||||
|
||||
// Default to standard JWT authentication
|
||||
return passport.authenticate('jwt', { session: false })(req, res, next);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
const crypto = require('crypto');
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const cookieParser = require('cookie-parser');
|
||||
const { getBasePath } = require('@librechat/api');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
function generateTestCsrfToken(flowId) {
|
||||
return crypto
|
||||
.createHmac('sha256', process.env.JWT_SECRET)
|
||||
.update(flowId)
|
||||
.digest('hex')
|
||||
.slice(0, 32);
|
||||
}
|
||||
|
||||
const mockRegistryInstance = {
|
||||
getServerConfig: jest.fn(),
|
||||
|
|
@ -130,6 +140,7 @@ describe('MCP Routes', () => {
|
|||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use(cookieParser());
|
||||
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: 'test-user-id' };
|
||||
|
|
@ -168,12 +179,12 @@ describe('MCP Routes', () => {
|
|||
|
||||
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
||||
authorizationUrl: 'https://oauth.example.com/auth',
|
||||
flowId: 'test-flow-id',
|
||||
flowId: 'test-user-id:test-server',
|
||||
});
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'test-user-id',
|
||||
flowId: 'test-flow-id',
|
||||
flowId: 'test-user-id:test-server',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
|
|
@ -190,7 +201,7 @@ describe('MCP Routes', () => {
|
|||
it('should return 403 when userId does not match authenticated user', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'different-user-id',
|
||||
flowId: 'test-flow-id',
|
||||
flowId: 'test-user-id:test-server',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
|
|
@ -228,7 +239,7 @@ describe('MCP Routes', () => {
|
|||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'test-user-id',
|
||||
flowId: 'test-flow-id',
|
||||
flowId: 'test-user-id:test-server',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
|
|
@ -245,7 +256,7 @@ describe('MCP Routes', () => {
|
|||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'test-user-id',
|
||||
flowId: 'test-flow-id',
|
||||
flowId: 'test-user-id:test-server',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
|
|
@ -255,7 +266,7 @@ describe('MCP Routes', () => {
|
|||
it('should return 400 when flow state metadata is null', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
id: 'test-flow-id',
|
||||
id: 'test-user-id:test-server',
|
||||
metadata: null,
|
||||
}),
|
||||
};
|
||||
|
|
@ -265,7 +276,7 @@ describe('MCP Routes', () => {
|
|||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'test-user-id',
|
||||
flowId: 'test-flow-id',
|
||||
flowId: 'test-user-id:test-server',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
|
|
@ -280,7 +291,7 @@ describe('MCP Routes', () => {
|
|||
it('should redirect to error page when OAuth error is received', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
error: 'access_denied',
|
||||
state: 'test-flow-id',
|
||||
state: 'test-user-id:test-server',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
|
|
@ -290,7 +301,7 @@ describe('MCP Routes', () => {
|
|||
|
||||
it('should redirect to error page when code is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
state: 'test-flow-id',
|
||||
state: 'test-user-id:test-server',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
|
|
@ -308,15 +319,50 @@ describe('MCP Routes', () => {
|
|||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when flow state is not found', async () => {
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(null);
|
||||
|
||||
it('should redirect to error page when CSRF cookie is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'invalid-flow-id',
|
||||
state: 'test-user-id:test-server',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should redirect to error page when CSRF cookie does not match state', async () => {
|
||||
const csrfToken = generateTestCsrfToken('different-flow-id');
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-user-id:test-server',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should redirect to error page when flow state is not found', async () => {
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(null);
|
||||
const flowId = 'invalid-flow:id';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||
});
|
||||
|
|
@ -369,16 +415,22 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
setCachedTools.mockResolvedValue();
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
'test-flow-id',
|
||||
flowId,
|
||||
'test-auth-code',
|
||||
mockFlowManager,
|
||||
{},
|
||||
|
|
@ -400,16 +452,24 @@ describe('MCP Routes', () => {
|
|||
'mcp_oauth',
|
||||
mockTokens,
|
||||
);
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(
|
||||
'test-user-id:test-server',
|
||||
'mcp_get_tokens',
|
||||
);
|
||||
});
|
||||
|
||||
it('should redirect to error page when callback processing fails', async () => {
|
||||
MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error'));
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
|
|
@ -442,15 +502,21 @@ describe('MCP Routes', () => {
|
|||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
it('should handle reconnection failure after OAuth', async () => {
|
||||
|
|
@ -488,16 +554,22 @@ describe('MCP Routes', () => {
|
|||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
it('should redirect to error page if token storage fails', async () => {
|
||||
|
|
@ -530,10 +602,16 @@ describe('MCP Routes', () => {
|
|||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
|
|
@ -589,22 +667,27 @@ describe('MCP Routes', () => {
|
|||
clearReconnection: jest.fn(),
|
||||
});
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify storeTokens was called with ORIGINAL flow state credentials
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
userId: 'test-user-id',
|
||||
serverName: 'test-server',
|
||||
tokens: mockTokens,
|
||||
clientInfo: clientInfo, // Uses original flow state, not any "updated" credentials
|
||||
clientInfo: clientInfo,
|
||||
metadata: flowState.metadata,
|
||||
}),
|
||||
);
|
||||
|
|
@ -631,16 +714,21 @@ describe('MCP Routes', () => {
|
|||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
code: 'test-auth-code',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify completeOAuthFlow was NOT called (prevented duplicate)
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
|
||||
expect(MCPTokenStorage.storeTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
|
@ -755,7 +843,7 @@ describe('MCP Routes', () => {
|
|||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/oauth/status/test-flow-id');
|
||||
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:test-server');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual({
|
||||
|
|
@ -766,6 +854,13 @@ describe('MCP Routes', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('should return 403 when flowId does not match authenticated user', async () => {
|
||||
const response = await request(app).get('/api/mcp/oauth/status/other-user-id:test-server');
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body).toEqual({ error: 'Access denied' });
|
||||
});
|
||||
|
||||
it('should return 404 when flow is not found', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue(null),
|
||||
|
|
@ -774,7 +869,7 @@ describe('MCP Routes', () => {
|
|||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/oauth/status/non-existent-flow');
|
||||
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:non-existent');
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(response.body).toEqual({ error: 'Flow not found' });
|
||||
|
|
@ -788,7 +883,7 @@ describe('MCP Routes', () => {
|
|||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/oauth/status/error-flow-id');
|
||||
const response = await request(app).get('/api/mcp/oauth/status/test-user-id:error-server');
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.body).toEqual({ error: 'Failed to get flow status' });
|
||||
|
|
@ -1375,7 +1470,7 @@ describe('MCP Routes', () => {
|
|||
refresh_token: 'edge-refresh-token',
|
||||
};
|
||||
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
|
||||
id: 'test-flow-id',
|
||||
id: 'test-user-id:test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: {
|
||||
serverUrl: 'https://example.com',
|
||||
|
|
@ -1403,8 +1498,12 @@ describe('MCP Routes', () => {
|
|||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`)
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.expect(302);
|
||||
|
||||
const basePath = getBasePath();
|
||||
|
|
@ -1424,7 +1523,7 @@ describe('MCP Routes', () => {
|
|||
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
id: 'test-flow-id',
|
||||
id: 'test-user-id:test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { serverUrl: 'https://example.com', oauth: {} },
|
||||
clientInfo: {},
|
||||
|
|
@ -1453,8 +1552,12 @@ describe('MCP Routes', () => {
|
|||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`)
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.expect(302);
|
||||
|
||||
const basePath = getBasePath();
|
||||
|
|
|
|||
|
|
@ -1,14 +1,47 @@
|
|||
const express = require('express');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { getAccessToken, getBasePath } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
getBasePath,
|
||||
getAccessToken,
|
||||
setOAuthSession,
|
||||
validateOAuthCsrf,
|
||||
OAUTH_CSRF_COOKIE,
|
||||
setOAuthCsrfCookie,
|
||||
validateOAuthSession,
|
||||
OAUTH_SESSION_COOKIE,
|
||||
} = require('@librechat/api');
|
||||
const { findToken, updateToken, createToken } = require('~/models');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = express.Router();
|
||||
const JWT_SECRET = process.env.JWT_SECRET;
|
||||
const OAUTH_CSRF_COOKIE_PATH = '/api/actions';
|
||||
|
||||
/**
|
||||
* Sets a CSRF cookie binding the action OAuth flow to the current browser session.
|
||||
* Must be called before the user opens the IdP authorization URL.
|
||||
*
|
||||
* @route POST /actions/:action_id/oauth/bind
|
||||
*/
|
||||
router.post('/:action_id/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => {
|
||||
try {
|
||||
const { action_id } = req.params;
|
||||
const user = req.user;
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
const flowId = `${user.id}:${action_id}`;
|
||||
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
res.json({ success: true });
|
||||
} catch (error) {
|
||||
logger.error('[Action OAuth] Failed to set CSRF binding cookie', error);
|
||||
res.status(500).json({ error: 'Failed to bind OAuth flow' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Handles the OAuth callback and exchanges the authorization code for tokens.
|
||||
|
|
@ -45,7 +78,22 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
identifier = `${decodedState.user}:${action_id}`;
|
||||
|
||||
if (
|
||||
!validateOAuthCsrf(req, res, identifier, OAUTH_CSRF_COOKIE_PATH) &&
|
||||
!validateOAuthSession(req, decodedState.user)
|
||||
) {
|
||||
logger.error('[Action OAuth] CSRF validation failed: no valid CSRF or session cookie', {
|
||||
identifier,
|
||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||
});
|
||||
await flowManager.failFlow(identifier, 'oauth', 'CSRF validation failed');
|
||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||
}
|
||||
|
||||
const flowState = await flowManager.getFlowState(identifier, 'oauth');
|
||||
if (!flowState) {
|
||||
throw new Error('OAuth flow not found');
|
||||
|
|
@ -71,7 +119,6 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
);
|
||||
await flowManager.completeFlow(identifier, 'oauth', tokenData);
|
||||
|
||||
/** Redirect to React success page */
|
||||
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
|
|
|
|||
|
|
@ -8,18 +8,32 @@ const {
|
|||
Permissions,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getBasePath,
|
||||
createSafeUser,
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
getBasePath,
|
||||
setOAuthSession,
|
||||
getUserMCPAuthMap,
|
||||
validateOAuthCsrf,
|
||||
OAUTH_CSRF_COOKIE,
|
||||
setOAuthCsrfCookie,
|
||||
generateCheckAccess,
|
||||
validateOAuthSession,
|
||||
OAUTH_SESSION_COOKIE,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
getMCPManager,
|
||||
getFlowStateManager,
|
||||
createMCPServerController,
|
||||
updateMCPServerController,
|
||||
deleteMCPServerController,
|
||||
getMCPServersList,
|
||||
getMCPServerById,
|
||||
getMCPTools,
|
||||
} = require('~/server/controllers/mcp');
|
||||
const {
|
||||
getOAuthReconnectionManager,
|
||||
getMCPServersRegistry,
|
||||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware');
|
||||
|
|
@ -27,20 +41,14 @@ const { findToken, updateToken, createToken, deleteTokens } = require('~/models'
|
|||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { getMCPTools } = require('~/server/controllers/mcp');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const {
|
||||
createMCPServerController,
|
||||
getMCPServerById,
|
||||
getMCPServersList,
|
||||
updateMCPServerController,
|
||||
deleteMCPServerController,
|
||||
} = require('~/server/controllers/mcp');
|
||||
|
||||
const router = Router();
|
||||
|
||||
const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
* Returns only MCP tools, completely decoupled from regular LibreChat tools
|
||||
|
|
@ -53,7 +61,7 @@ router.get('/tools', requireJwtAuth, async (req, res) => {
|
|||
* Initiate OAuth flow
|
||||
* This endpoint is called when the user clicks the auth link in the UI
|
||||
*/
|
||||
router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
||||
router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const { userId, flowId } = req.query;
|
||||
|
|
@ -93,7 +101,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
|||
|
||||
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
|
||||
|
||||
// Redirect user to the authorization URL
|
||||
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
res.redirect(authorizationUrl);
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] Failed to initiate OAuth', error);
|
||||
|
|
@ -138,6 +146,25 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
const flowId = state;
|
||||
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
|
||||
|
||||
const flowParts = flowId.split(':');
|
||||
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
|
||||
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
const [flowUserId] = flowParts;
|
||||
if (
|
||||
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
|
||||
!validateOAuthSession(req, flowUserId)
|
||||
) {
|
||||
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
|
||||
flowId,
|
||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||
});
|
||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
|
|
@ -302,13 +329,47 @@ router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Set CSRF binding cookie for OAuth flows initiated outside of HTTP request/response
|
||||
* (e.g. during chat via SSE). The frontend should call this before opening the OAuth URL
|
||||
* so the callback can verify the browser matches the flow initiator.
|
||||
*/
|
||||
router.post('/:serverName/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
|
||||
res.json({ success: true });
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] Failed to set CSRF binding cookie', error);
|
||||
res.status(500).json({ error: 'Failed to bind OAuth flow' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Check OAuth flow status
|
||||
* This endpoint can be used to poll the status of an OAuth flow
|
||||
*/
|
||||
router.get('/oauth/status/:flowId', async (req, res) => {
|
||||
router.get('/oauth/status/:flowId', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { flowId } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) {
|
||||
return res.status(403).json({ error: 'Access denied' });
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
|
|
@ -375,7 +436,7 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
|||
* Reinitialize MCP server
|
||||
* This endpoint allows reinitializing a specific MCP server
|
||||
*/
|
||||
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const user = createSafeUser(req.user);
|
||||
|
|
@ -421,6 +482,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
|||
|
||||
const { success, message, oauthRequired, oauthUrl } = result;
|
||||
|
||||
if (oauthRequired) {
|
||||
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
}
|
||||
|
||||
res.json({
|
||||
success,
|
||||
message,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ const oauthHandler = createOAuthHandler();
|
|||
|
||||
router.get('/error', (req, res) => {
|
||||
/** A single error message is pushed by passport when authentication fails. */
|
||||
const errorMessage = req.session?.messages?.pop() || 'Unknown error';
|
||||
const errorMessage = req.session?.messages?.pop() || 'Unknown OAuth error';
|
||||
logger.error('Error in OAuth authentication:', {
|
||||
message: errorMessage,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1339,6 +1339,7 @@ async function loadActionToolsForExecution({
|
|||
});
|
||||
}
|
||||
|
||||
const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
|
||||
for (const toolName of actionToolNames) {
|
||||
let currentDomain = '';
|
||||
for (const domain of domainMap.keys()) {
|
||||
|
|
@ -1355,7 +1356,6 @@ async function loadActionToolsForExecution({
|
|||
|
||||
const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } =
|
||||
processedActionSets.get(currentDomain);
|
||||
const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g');
|
||||
const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_');
|
||||
const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, '');
|
||||
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
import { useMemo, useState, useEffect, useRef, useLayoutEffect } from 'react';
|
||||
import { useMemo, useState, useEffect, useRef, useCallback, useLayoutEffect } from 'react';
|
||||
import { Button } from '@librechat/client';
|
||||
import { TriangleAlert } from 'lucide-react';
|
||||
import { actionDelimiter, actionDomainSeparator, Constants } from 'librechat-data-provider';
|
||||
import {
|
||||
Constants,
|
||||
dataService,
|
||||
actionDelimiter,
|
||||
actionDomainSeparator,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TAttachment } from 'librechat-data-provider';
|
||||
import { useLocalize, useProgress } from '~/hooks';
|
||||
import { AttachmentGroup } from './Parts';
|
||||
|
|
@ -36,9 +41,9 @@ export default function ToolCall({
|
|||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const prevShowInfoRef = useRef<boolean>(showInfo);
|
||||
|
||||
const { function_name, domain, isMCPToolCall } = useMemo(() => {
|
||||
const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => {
|
||||
if (typeof name !== 'string') {
|
||||
return { function_name: '', domain: null, isMCPToolCall: false };
|
||||
return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' };
|
||||
}
|
||||
if (name.includes(Constants.mcp_delimiter)) {
|
||||
const [func, server] = name.split(Constants.mcp_delimiter);
|
||||
|
|
@ -46,6 +51,7 @@ export default function ToolCall({
|
|||
function_name: func || '',
|
||||
domain: server && (server.replaceAll(actionDomainSeparator, '.') || null),
|
||||
isMCPToolCall: true,
|
||||
mcpServerName: server || '',
|
||||
};
|
||||
}
|
||||
const [func, _domain] = name.includes(actionDelimiter)
|
||||
|
|
@ -55,9 +61,40 @@ export default function ToolCall({
|
|||
function_name: func || '',
|
||||
domain: _domain && (_domain.replaceAll(actionDomainSeparator, '.') || null),
|
||||
isMCPToolCall: false,
|
||||
mcpServerName: '',
|
||||
};
|
||||
}, [name]);
|
||||
|
||||
const actionId = useMemo(() => {
|
||||
if (isMCPToolCall || !auth) {
|
||||
return '';
|
||||
}
|
||||
try {
|
||||
const url = new URL(auth);
|
||||
const redirectUri = url.searchParams.get('redirect_uri') || '';
|
||||
const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/);
|
||||
return match?.[1] || '';
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
}, [auth, isMCPToolCall]);
|
||||
|
||||
const handleOAuthClick = useCallback(async () => {
|
||||
if (!auth) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (isMCPToolCall && mcpServerName) {
|
||||
await dataService.bindMCPOAuth(mcpServerName);
|
||||
} else if (actionId) {
|
||||
await dataService.bindActionOAuth(actionId);
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('Failed to bind OAuth CSRF cookie', e);
|
||||
}
|
||||
window.open(auth, '_blank', 'noopener,noreferrer');
|
||||
}, [auth, isMCPToolCall, mcpServerName, actionId]);
|
||||
|
||||
const error =
|
||||
typeof output === 'string' && output.toLowerCase().includes('error processing tool');
|
||||
|
||||
|
|
@ -230,7 +267,7 @@ export default function ToolCall({
|
|||
className="font-mediu inline-flex items-center justify-center rounded-xl px-4 py-2 text-sm"
|
||||
variant="default"
|
||||
rel="noopener noreferrer"
|
||||
onClick={() => window.open(auth, '_blank', 'noopener,noreferrer')}
|
||||
onClick={handleOAuthClick}
|
||||
>
|
||||
{localize('com_ui_sign_in_to_domain', { 0: authDomain })}
|
||||
</Button>
|
||||
|
|
|
|||
|
|
@ -298,38 +298,45 @@ export class MCPConnectionFactory {
|
|||
const oauthHandler = async (data: { serverUrl?: string }) => {
|
||||
logger.info(`${this.logPrefix} oauthRequired event received`);
|
||||
|
||||
// If we just want to initiate OAuth and return, handle it differently
|
||||
if (this.returnOnOAuth) {
|
||||
try {
|
||||
const config = this.serverConfig;
|
||||
const { authorizationUrl, flowId, flowMetadata } =
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
|
||||
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow?.status === 'PENDING') {
|
||||
logger.debug(
|
||||
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Delete any existing flow state to ensure we start fresh
|
||||
// This prevents stale codeVerifier issues when re-authenticating
|
||||
await this.flowManager!.deleteFlow(flowId, 'mcp_oauth');
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: newFlowId,
|
||||
flowMetadata,
|
||||
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
);
|
||||
|
||||
// Create the flow state so the OAuth callback can find it
|
||||
// We spawn this in the background without waiting for it
|
||||
// Pass signal so the flow can be aborted if the request is cancelled
|
||||
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
|
||||
// The OAuth callback will resolve this flow, so we expect it to timeout here
|
||||
// or it will be aborted if the request is cancelled - both are fine
|
||||
});
|
||||
if (existingFlow) {
|
||||
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
|
||||
}
|
||||
|
||||
this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch(
|
||||
() => {},
|
||||
);
|
||||
|
||||
if (this.oauthStart) {
|
||||
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
|
||||
await this.oauthStart(authorizationUrl);
|
||||
}
|
||||
|
||||
// Emit oauthFailed to signal that connection should not proceed
|
||||
// but OAuth was successfully initiated
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
} catch (error) {
|
||||
|
|
@ -391,11 +398,9 @@ export class MCPConnectionFactory {
|
|||
logger.error(`${this.logPrefix} Failed to establish connection.`);
|
||||
}
|
||||
|
||||
// Handles connection attempts with retry logic and OAuth error handling
|
||||
private async connectTo(connection: MCPConnection): Promise<void> {
|
||||
const maxAttempts = 3;
|
||||
let attempts = 0;
|
||||
let oauthHandled = false;
|
||||
|
||||
while (attempts < maxAttempts) {
|
||||
try {
|
||||
|
|
@ -408,22 +413,6 @@ export class MCPConnectionFactory {
|
|||
attempts++;
|
||||
|
||||
if (this.useOAuth && this.isOAuthError(error)) {
|
||||
// For returnOnOAuth mode, let the event handler (handleOAuthEvents) deal with OAuth
|
||||
// We just need to stop retrying and let the error propagate
|
||||
if (this.returnOnOAuth) {
|
||||
logger.info(
|
||||
`${this.logPrefix} OAuth required (return on OAuth mode), stopping retries`,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Normal flow - wait for OAuth to complete
|
||||
if (this.oauthStart && !oauthHandled) {
|
||||
oauthHandled = true;
|
||||
logger.info(`${this.logPrefix} Handling OAuth`);
|
||||
await this.handleOAuthRequired();
|
||||
}
|
||||
// Don't retry on OAuth errors - just throw
|
||||
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
|
||||
throw error;
|
||||
}
|
||||
|
|
@ -499,26 +488,15 @@ export class MCPConnectionFactory {
|
|||
/** Check if there's already an ongoing OAuth flow for this flowId */
|
||||
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
|
||||
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
|
||||
if (existingFlow) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
|
||||
);
|
||||
try {
|
||||
if (existingFlow.status === 'PENDING') {
|
||||
await this.flowManager.failFlow(
|
||||
flowId,
|
||||
'mcp_oauth',
|
||||
new Error('Cancelled for new OAuth request'),
|
||||
);
|
||||
} else {
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
}
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
} catch (error) {
|
||||
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
|
||||
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
|
||||
}
|
||||
// Continue to start a new flow below
|
||||
}
|
||||
|
||||
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
|
||||
|
|
|
|||
|
|
@ -270,7 +270,54 @@ describe('MCPConnectionFactory', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should delete existing flow before creating new OAuth flow to prevent stale codeVerifier', async () => {
|
||||
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
user: mockUser,
|
||||
};
|
||||
|
||||
const oauthOptions: t.OAuthConnectionOptions = {
|
||||
user: mockUser,
|
||||
useOAuth: true,
|
||||
returnOnOAuth: true,
|
||||
oauthStart: jest.fn(),
|
||||
flowManager: mockFlowManager,
|
||||
};
|
||||
|
||||
mockFlowManager.getFlowState.mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
type: 'mcp_oauth',
|
||||
metadata: { codeVerifier: 'existing-verifier' },
|
||||
createdAt: Date.now(),
|
||||
});
|
||||
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||
|
||||
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
|
||||
mockConnectionInstance.on.mockImplementation((event, handler) => {
|
||||
if (event === 'oauthRequired') {
|
||||
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
|
||||
}
|
||||
return mockConnectionInstance;
|
||||
});
|
||||
|
||||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
|
||||
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
|
||||
'oauthFailed',
|
||||
expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
|
|
@ -303,6 +350,12 @@ describe('MCPConnectionFactory', () => {
|
|||
},
|
||||
};
|
||||
|
||||
mockFlowManager.getFlowState.mockResolvedValue({
|
||||
status: 'COMPLETED',
|
||||
type: 'mcp_oauth',
|
||||
metadata: { codeVerifier: 'old-verifier' },
|
||||
createdAt: Date.now() - 60000,
|
||||
});
|
||||
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
|
||||
mockFlowManager.deleteFlow.mockResolvedValue(true);
|
||||
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
|
||||
|
|
@ -319,21 +372,17 @@ describe('MCPConnectionFactory', () => {
|
|||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected to fail due to connection not established
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
// Verify deleteFlow was called with correct parameters
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
|
||||
|
||||
// Verify deleteFlow was called before createFlow
|
||||
const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
|
||||
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
|
||||
expect(deleteCallOrder).toBeLessThan(createCallOrder);
|
||||
|
||||
// Verify createFlow was called with fresh metadata
|
||||
// 4th arg is the abort signal (undefined in this test since no signal was provided)
|
||||
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
|
||||
'user123:test-server',
|
||||
'mcp_oauth',
|
||||
|
|
|
|||
89
packages/api/src/oauth/csrf.ts
Normal file
89
packages/api/src/oauth/csrf.ts
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import crypto from 'crypto';
|
||||
import type { Request, Response, NextFunction } from 'express';
|
||||
|
||||
export const OAUTH_CSRF_COOKIE = 'oauth_csrf';
|
||||
export const OAUTH_CSRF_MAX_AGE = 10 * 60 * 1000;
|
||||
|
||||
export const OAUTH_SESSION_COOKIE = 'oauth_session';
|
||||
export const OAUTH_SESSION_MAX_AGE = 24 * 60 * 60 * 1000;
|
||||
export const OAUTH_SESSION_COOKIE_PATH = '/api';
|
||||
|
||||
const isProduction = process.env.NODE_ENV === 'production';
|
||||
|
||||
/** Generates an HMAC-based token for OAuth CSRF protection */
|
||||
export function generateOAuthCsrfToken(flowId: string, secret?: string): string {
|
||||
const key = secret || process.env.JWT_SECRET;
|
||||
if (!key) {
|
||||
throw new Error('JWT_SECRET is required for OAuth CSRF token generation');
|
||||
}
|
||||
return crypto.createHmac('sha256', key).update(flowId).digest('hex').slice(0, 32);
|
||||
}
|
||||
|
||||
/** Sets a SameSite=Lax CSRF cookie bound to a specific OAuth flow */
|
||||
export function setOAuthCsrfCookie(res: Response, flowId: string, cookiePath: string): void {
|
||||
res.cookie(OAUTH_CSRF_COOKIE, generateOAuthCsrfToken(flowId), {
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'lax',
|
||||
maxAge: OAUTH_CSRF_MAX_AGE,
|
||||
path: cookiePath,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the per-flow CSRF cookie against the expected HMAC.
|
||||
* Uses timing-safe comparison and always clears the cookie to prevent replay.
|
||||
*/
|
||||
export function validateOAuthCsrf(
|
||||
req: Request,
|
||||
res: Response,
|
||||
flowId: string,
|
||||
cookiePath: string,
|
||||
): boolean {
|
||||
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_CSRF_COOKIE];
|
||||
res.clearCookie(OAUTH_CSRF_COOKIE, { path: cookiePath });
|
||||
if (!cookie) {
|
||||
return false;
|
||||
}
|
||||
const expected = generateOAuthCsrfToken(flowId);
|
||||
if (cookie.length !== expected.length) {
|
||||
return false;
|
||||
}
|
||||
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
|
||||
}
|
||||
|
||||
/**
|
||||
* Express middleware that sets the OAuth session cookie after JWT authentication.
|
||||
* Chain after requireJwtAuth on routes that precede an OAuth redirect (e.g., reinitialize, bind).
|
||||
*/
|
||||
export function setOAuthSession(req: Request, res: Response, next: NextFunction): void {
|
||||
const user = (req as Request & { user?: { id?: string } }).user;
|
||||
if (user?.id && !(req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE]) {
|
||||
setOAuthSessionCookie(res, user.id);
|
||||
}
|
||||
next();
|
||||
}
|
||||
|
||||
/** Sets a SameSite=Lax session cookie that binds the browser to the authenticated userId */
|
||||
export function setOAuthSessionCookie(res: Response, userId: string): void {
|
||||
res.cookie(OAUTH_SESSION_COOKIE, generateOAuthCsrfToken(userId), {
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'lax',
|
||||
maxAge: OAUTH_SESSION_MAX_AGE,
|
||||
path: OAUTH_SESSION_COOKIE_PATH,
|
||||
});
|
||||
}
|
||||
|
||||
/** Validates the session cookie against the expected userId using timing-safe comparison */
|
||||
export function validateOAuthSession(req: Request, userId: string): boolean {
|
||||
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE];
|
||||
if (!cookie) {
|
||||
return false;
|
||||
}
|
||||
const expected = generateOAuthCsrfToken(userId);
|
||||
if (cookie.length !== expected.length) {
|
||||
return false;
|
||||
}
|
||||
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
|
||||
}
|
||||
|
|
@ -1 +1,2 @@
|
|||
export * from './csrf';
|
||||
export * from './tokens';
|
||||
|
|
|
|||
|
|
@ -181,6 +181,11 @@ export const cancelMCPOAuth = (serverName: string) => {
|
|||
return `${BASE_URL}/api/mcp/oauth/cancel/${serverName}`;
|
||||
};
|
||||
|
||||
export const mcpOAuthBind = (serverName: string) => `${BASE_URL}/api/mcp/${serverName}/oauth/bind`;
|
||||
|
||||
export const actionOAuthBind = (actionId: string) =>
|
||||
`${BASE_URL}/api/actions/${actionId}/oauth/bind`;
|
||||
|
||||
export const config = () => `${BASE_URL}/api/config`;
|
||||
|
||||
export const prompts = () => `${BASE_URL}/api/prompts`;
|
||||
|
|
|
|||
|
|
@ -178,6 +178,14 @@ export const reinitializeMCPServer = (serverName: string) => {
|
|||
return request.post(endpoints.mcpReinitialize(serverName));
|
||||
};
|
||||
|
||||
export const bindMCPOAuth = (serverName: string): Promise<{ success: boolean }> => {
|
||||
return request.post(endpoints.mcpOAuthBind(serverName));
|
||||
};
|
||||
|
||||
export const bindActionOAuth = (actionId: string): Promise<{ success: boolean }> => {
|
||||
return request.post(endpoints.actionOAuthBind(actionId));
|
||||
};
|
||||
|
||||
export const getMCPConnectionStatus = (): Promise<q.MCPConnectionStatusResponse> => {
|
||||
return request.get(endpoints.mcpConnectionStatus());
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue