From 599f4a11f185d79ca885c60561c20c58fc63a31f Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 12 Feb 2026 14:22:05 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20Secure=20MCP/Act?= =?UTF-8?q?ions=20OAuth=20Flows,=20Resolve=20Race=20Condition=20&=20Tool?= =?UTF-8?q?=20Cache=20Cleanup=20(#11756)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 fix: Update OAuth error message for clarity - Changed the default error message in the OAuth error route from 'Unknown error' to 'Unknown OAuth error' to provide clearer context during authentication failures. * 🔒 feat: Enhance OAuth flow with CSRF protection and session management - Implemented CSRF protection for OAuth flows by introducing `generateOAuthCsrfToken`, `setOAuthCsrfCookie`, and `validateOAuthCsrf` functions. - Added session management for OAuth with `setOAuthSession` and `validateOAuthSession` middleware. - Updated routes to bind CSRF tokens for MCP and action OAuth flows, ensuring secure authentication. - Enhanced tests to validate CSRF handling and session management in OAuth processes. * 🔧 refactor: Invalidate cached tools after user plugin disconnection - Added a call to `invalidateCachedTools` in the `updateUserPluginsController` to ensure that cached tools are refreshed when a user disconnects from an MCP server after a plugin authentication update. This change improves the accuracy of tool data for users. * chore: imports order * fix: domain separator regex usage in ToolService - Moved the declaration of `domainSeparatorRegex` to avoid redundancy in the `loadActionToolsForExecution` function, improving code clarity and performance. * chore: OAuth flow error handling and CSRF token generation - Enhanced the OAuth callback route to validate the flow ID format, ensuring proper error handling for invalid states. - Updated the CSRF token generation function to require a JWT secret, throwing an error if not provided, which improves security and clarity in token generation. - Adjusted tests to reflect changes in flow ID handling and ensure robust validation across various scenarios. --- api/server/controllers/UserController.js | 2 + api/server/middleware/requireJwtAuth.js | 3 - api/server/routes/__tests__/mcp.spec.js | 215 +++++++++++++----- api/server/routes/actions.js | 51 ++++- api/server/routes/mcp.js | 96 ++++++-- api/server/routes/oauth.js | 2 +- api/server/services/ToolService.js | 2 +- .../Chat/Messages/Content/ToolCall.tsx | 47 +++- packages/api/src/mcp/MCPConnectionFactory.ts | 82 +++---- .../__tests__/MCPConnectionFactory.test.ts | 61 ++++- packages/api/src/oauth/csrf.ts | 89 ++++++++ packages/api/src/oauth/index.ts | 1 + packages/data-provider/src/api-endpoints.ts | 5 + packages/data-provider/src/data-service.ts | 8 + 14 files changed, 523 insertions(+), 141 deletions(-) create mode 100644 packages/api/src/oauth/csrf.ts diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 0f17b4d3a9..7a9dd8125e 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -36,6 +36,7 @@ const { const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config'); +const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools'); const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud'); const { processDeleteRequest } = require('~/server/services/Files/process'); const { getAppConfig } = require('~/server/services/Config'); @@ -215,6 +216,7 @@ const updateUserPluginsController = async (req, res) => { `[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`, ); await mcpManager.disconnectUserConnection(user.id, serverName); + await invalidateCachedTools({ userId: user.id, serverName }); } } catch (disconnectError) { logger.error( diff --git a/api/server/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js index ed83c4773e..16b107aefc 100644 --- a/api/server/middleware/requireJwtAuth.js +++ b/api/server/middleware/requireJwtAuth.js @@ -7,16 +7,13 @@ const { isEnabled } = require('@librechat/api'); * Switches between JWT and OpenID authentication based on cookies and environment settings */ const requireJwtAuth = (req, res, next) => { - // Check if token provider is specified in cookies const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; - // Use OpenID authentication if token provider is OpenID and OPENID_REUSE_TOKENS is enabled if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) { return passport.authenticate('openidJwt', { session: false })(req, res, next); } - // Default to standard JWT authentication return passport.authenticate('jwt', { session: false })(req, res, next); }; diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 26d7988f0a..e87fcf8f15 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1,8 +1,18 @@ +const crypto = require('crypto'); const express = require('express'); const request = require('supertest'); const mongoose = require('mongoose'); -const { MongoMemoryServer } = require('mongodb-memory-server'); +const cookieParser = require('cookie-parser'); const { getBasePath } = require('@librechat/api'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +function generateTestCsrfToken(flowId) { + return crypto + .createHmac('sha256', process.env.JWT_SECRET) + .update(flowId) + .digest('hex') + .slice(0, 32); +} const mockRegistryInstance = { getServerConfig: jest.fn(), @@ -130,6 +140,7 @@ describe('MCP Routes', () => { app = express(); app.use(express.json()); + app.use(cookieParser()); app.use((req, res, next) => { req.user = { id: 'test-user-id' }; @@ -168,12 +179,12 @@ describe('MCP Routes', () => { MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(302); @@ -190,7 +201,7 @@ describe('MCP Routes', () => { it('should return 403 when userId does not match authenticated user', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'different-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(403); @@ -228,7 +239,7 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(400); @@ -245,7 +256,7 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(500); @@ -255,7 +266,7 @@ describe('MCP Routes', () => { it('should return 400 when flow state metadata is null', async () => { const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue({ - id: 'test-flow-id', + id: 'test-user-id:test-server', metadata: null, }), }; @@ -265,7 +276,7 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(400); @@ -280,7 +291,7 @@ describe('MCP Routes', () => { it('should redirect to error page when OAuth error is received', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ error: 'access_denied', - state: 'test-flow-id', + state: 'test-user-id:test-server', }); const basePath = getBasePath(); @@ -290,7 +301,7 @@ describe('MCP Routes', () => { it('should redirect to error page when code is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - state: 'test-flow-id', + state: 'test-user-id:test-server', }); const basePath = getBasePath(); @@ -308,15 +319,50 @@ describe('MCP Routes', () => { expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`); }); - it('should redirect to error page when flow state is not found', async () => { - MCPOAuthHandler.getFlowState.mockResolvedValue(null); - + it('should redirect to error page when CSRF cookie is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ code: 'test-auth-code', - state: 'invalid-flow-id', + state: 'test-user-id:test-server', }); const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should redirect to error page when CSRF cookie does not match state', async () => { + const csrfToken = generateTestCsrfToken('different-flow-id'); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: 'test-user-id:test-server', + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should redirect to error page when flow state is not found', async () => { + MCPOAuthHandler.getFlowState.mockResolvedValue(null); + const flowId = 'invalid-flow:id'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); + const basePath = getBasePath(); + expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`); }); @@ -369,16 +415,22 @@ describe('MCP Routes', () => { }); setCachedTools.mockResolvedValue(); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( - 'test-flow-id', + flowId, 'test-auth-code', mockFlowManager, {}, @@ -400,16 +452,24 @@ describe('MCP Routes', () => { 'mcp_oauth', mockTokens, ); - expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith( + 'test-user-id:test-server', + 'mcp_get_tokens', + ); }); it('should redirect to error page when callback processing fails', async () => { MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error')); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); @@ -442,15 +502,21 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); - expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens'); }); it('should handle reconnection failure after OAuth', async () => { @@ -488,16 +554,22 @@ describe('MCP Routes', () => { getCachedTools.mockResolvedValue({}); setCachedTools.mockResolvedValue(); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(MCPTokenStorage.storeTokens).toHaveBeenCalled(); - expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens'); }); it('should redirect to error page if token storage fails', async () => { @@ -530,10 +602,16 @@ describe('MCP Routes', () => { }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); @@ -589,22 +667,27 @@ describe('MCP Routes', () => { clearReconnection: jest.fn(), }); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); - // Verify storeTokens was called with ORIGINAL flow state credentials expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith( expect.objectContaining({ userId: 'test-user-id', serverName: 'test-server', tokens: mockTokens, - clientInfo: clientInfo, // Uses original flow state, not any "updated" credentials + clientInfo: clientInfo, metadata: flowState.metadata, }), ); @@ -631,16 +714,21 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); - // Verify completeOAuthFlow was NOT called (prevented duplicate) expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled(); expect(MCPTokenStorage.storeTokens).not.toHaveBeenCalled(); }); @@ -755,7 +843,7 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/oauth/status/test-flow-id'); + const response = await request(app).get('/api/mcp/oauth/status/test-user-id:test-server'); expect(response.status).toBe(200); expect(response.body).toEqual({ @@ -766,6 +854,13 @@ describe('MCP Routes', () => { }); }); + it('should return 403 when flowId does not match authenticated user', async () => { + const response = await request(app).get('/api/mcp/oauth/status/other-user-id:test-server'); + + expect(response.status).toBe(403); + expect(response.body).toEqual({ error: 'Access denied' }); + }); + it('should return 404 when flow is not found', async () => { const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue(null), @@ -774,7 +869,7 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/oauth/status/non-existent-flow'); + const response = await request(app).get('/api/mcp/oauth/status/test-user-id:non-existent'); expect(response.status).toBe(404); expect(response.body).toEqual({ error: 'Flow not found' }); @@ -788,7 +883,7 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/oauth/status/error-flow-id'); + const response = await request(app).get('/api/mcp/oauth/status/test-user-id:error-server'); expect(response.status).toBe(500); expect(response.body).toEqual({ error: 'Failed to get flow status' }); @@ -1375,7 +1470,7 @@ describe('MCP Routes', () => { refresh_token: 'edge-refresh-token', }; MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({ - id: 'test-flow-id', + id: 'test-user-id:test-server', userId: 'test-user-id', metadata: { serverUrl: 'https://example.com', @@ -1403,8 +1498,12 @@ describe('MCP Routes', () => { }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + const response = await request(app) - .get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id') + .get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`) + .set('Cookie', [`oauth_csrf=${csrfToken}`]) .expect(302); const basePath = getBasePath(); @@ -1424,7 +1523,7 @@ describe('MCP Routes', () => { const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue({ - id: 'test-flow-id', + id: 'test-user-id:test-server', userId: 'test-user-id', metadata: { serverUrl: 'https://example.com', oauth: {} }, clientInfo: {}, @@ -1453,8 +1552,12 @@ describe('MCP Routes', () => { }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + const response = await request(app) - .get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id') + .get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`) + .set('Cookie', [`oauth_csrf=${csrfToken}`]) .expect(302); const basePath = getBasePath(); diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js index 14474a53d3..806edc66cc 100644 --- a/api/server/routes/actions.js +++ b/api/server/routes/actions.js @@ -1,14 +1,47 @@ const express = require('express'); const jwt = require('jsonwebtoken'); -const { getAccessToken, getBasePath } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { + getBasePath, + getAccessToken, + setOAuthSession, + validateOAuthCsrf, + OAUTH_CSRF_COOKIE, + setOAuthCsrfCookie, + validateOAuthSession, + OAUTH_SESSION_COOKIE, +} = require('@librechat/api'); const { findToken, updateToken, createToken } = require('~/models'); +const { requireJwtAuth } = require('~/server/middleware'); const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); const router = express.Router(); const JWT_SECRET = process.env.JWT_SECRET; +const OAUTH_CSRF_COOKIE_PATH = '/api/actions'; + +/** + * Sets a CSRF cookie binding the action OAuth flow to the current browser session. + * Must be called before the user opens the IdP authorization URL. + * + * @route POST /actions/:action_id/oauth/bind + */ +router.post('/:action_id/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => { + try { + const { action_id } = req.params; + const user = req.user; + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + const flowId = `${user.id}:${action_id}`; + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + res.json({ success: true }); + } catch (error) { + logger.error('[Action OAuth] Failed to set CSRF binding cookie', error); + res.status(500).json({ error: 'Failed to bind OAuth flow' }); + } +}); /** * Handles the OAuth callback and exchanges the authorization code for tokens. @@ -45,7 +78,22 @@ router.get('/:action_id/oauth/callback', async (req, res) => { await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter'); return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } + identifier = `${decodedState.user}:${action_id}`; + + if ( + !validateOAuthCsrf(req, res, identifier, OAUTH_CSRF_COOKIE_PATH) && + !validateOAuthSession(req, decodedState.user) + ) { + logger.error('[Action OAuth] CSRF validation failed: no valid CSRF or session cookie', { + identifier, + hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], + hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], + }); + await flowManager.failFlow(identifier, 'oauth', 'CSRF validation failed'); + return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + } + const flowState = await flowManager.getFlowState(identifier, 'oauth'); if (!flowState) { throw new Error('OAuth flow not found'); @@ -71,7 +119,6 @@ router.get('/:action_id/oauth/callback', async (req, res) => { ); await flowManager.completeFlow(identifier, 'oauth', tokenData); - /** Redirect to React success page */ const serverName = flowState.metadata?.action_name || `Action ${action_id}`; const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`; res.redirect(redirectUrl); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index f01c7ff71c..2db8c2c462 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -8,18 +8,32 @@ const { Permissions, } = require('librechat-data-provider'); const { + getBasePath, createSafeUser, MCPOAuthHandler, MCPTokenStorage, - getBasePath, + setOAuthSession, getUserMCPAuthMap, + validateOAuthCsrf, + OAUTH_CSRF_COOKIE, + setOAuthCsrfCookie, generateCheckAccess, + validateOAuthSession, + OAUTH_SESSION_COOKIE, } = require('@librechat/api'); const { - getMCPManager, - getFlowStateManager, + createMCPServerController, + updateMCPServerController, + deleteMCPServerController, + getMCPServersList, + getMCPServerById, + getMCPTools, +} = require('~/server/controllers/mcp'); +const { getOAuthReconnectionManager, getMCPServersRegistry, + getFlowStateManager, + getMCPManager, } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware'); @@ -27,20 +41,14 @@ const { findToken, updateToken, createToken, deleteTokens } = require('~/models' const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { updateMCPServerTools } = require('~/server/services/Config/mcp'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); -const { getMCPTools } = require('~/server/controllers/mcp'); const { findPluginAuthsByKeys } = require('~/models'); const { getRoleByName } = require('~/models/Role'); const { getLogStores } = require('~/cache'); -const { - createMCPServerController, - getMCPServerById, - getMCPServersList, - updateMCPServerController, - deleteMCPServerController, -} = require('~/server/controllers/mcp'); const router = Router(); +const OAUTH_CSRF_COOKIE_PATH = '/api/mcp'; + /** * Get all MCP tools available to the user * Returns only MCP tools, completely decoupled from regular LibreChat tools @@ -53,7 +61,7 @@ router.get('/tools', requireJwtAuth, async (req, res) => { * Initiate OAuth flow * This endpoint is called when the user clicks the auth link in the UI */ -router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { +router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async (req, res) => { try { const { serverName } = req.params; const { userId, flowId } = req.query; @@ -93,7 +101,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); - // Redirect user to the authorization URL + setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH); res.redirect(authorizationUrl); } catch (error) { logger.error('[MCP OAuth] Failed to initiate OAuth', error); @@ -138,6 +146,25 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const flowId = state; logger.debug('[MCP OAuth] Using flow ID from state', { flowId }); + const flowParts = flowId.split(':'); + if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) { + logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId }); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); + } + + const [flowUserId] = flowParts; + if ( + !validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) && + !validateOAuthSession(req, flowUserId) + ) { + logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', { + flowId, + hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], + hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], + }); + return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + } + const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); @@ -302,13 +329,47 @@ router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => { } }); +/** + * Set CSRF binding cookie for OAuth flows initiated outside of HTTP request/response + * (e.g. during chat via SSE). The frontend should call this before opening the OAuth URL + * so the callback can verify the browser matches the flow initiator. + */ +router.post('/:serverName/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => { + try { + const { serverName } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + + res.json({ success: true }); + } catch (error) { + logger.error('[MCP OAuth] Failed to set CSRF binding cookie', error); + res.status(500).json({ error: 'Failed to bind OAuth flow' }); + } +}); + /** * Check OAuth flow status * This endpoint can be used to poll the status of an OAuth flow */ -router.get('/oauth/status/:flowId', async (req, res) => { +router.get('/oauth/status/:flowId', requireJwtAuth, async (req, res) => { try { const { flowId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) { + return res.status(403).json({ error: 'Access denied' }); + } + const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); @@ -375,7 +436,7 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { * Reinitialize MCP server * This endpoint allows reinitializing a specific MCP server */ -router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { +router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => { try { const { serverName } = req.params; const user = createSafeUser(req.user); @@ -421,6 +482,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { const { success, message, oauthRequired, oauthUrl } = result; + if (oauthRequired) { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + } + res.json({ success, message, diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 4a2e2f70c6..f4bb5b6026 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -29,7 +29,7 @@ const oauthHandler = createOAuthHandler(); router.get('/error', (req, res) => { /** A single error message is pushed by passport when authentication fails. */ - const errorMessage = req.session?.messages?.pop() || 'Unknown error'; + const errorMessage = req.session?.messages?.pop() || 'Unknown OAuth error'; logger.error('Error in OAuth authentication:', { message: errorMessage, }); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 7f8c1d0460..eedb95bd4d 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -1339,6 +1339,7 @@ async function loadActionToolsForExecution({ }); } + const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); for (const toolName of actionToolNames) { let currentDomain = ''; for (const domain of domainMap.keys()) { @@ -1355,7 +1356,6 @@ async function loadActionToolsForExecution({ const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = processedActionSets.get(currentDomain); - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_'); const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, ''); const functionSig = functionSignatures.find((sig) => sig.name === functionName); diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index b9feef1bad..c807288b46 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -1,7 +1,12 @@ -import { useMemo, useState, useEffect, useRef, useLayoutEffect } from 'react'; +import { useMemo, useState, useEffect, useRef, useCallback, useLayoutEffect } from 'react'; import { Button } from '@librechat/client'; import { TriangleAlert } from 'lucide-react'; -import { actionDelimiter, actionDomainSeparator, Constants } from 'librechat-data-provider'; +import { + Constants, + dataService, + actionDelimiter, + actionDomainSeparator, +} from 'librechat-data-provider'; import type { TAttachment } from 'librechat-data-provider'; import { useLocalize, useProgress } from '~/hooks'; import { AttachmentGroup } from './Parts'; @@ -36,9 +41,9 @@ export default function ToolCall({ const [isAnimating, setIsAnimating] = useState(false); const prevShowInfoRef = useRef(showInfo); - const { function_name, domain, isMCPToolCall } = useMemo(() => { + const { function_name, domain, isMCPToolCall, mcpServerName } = useMemo(() => { if (typeof name !== 'string') { - return { function_name: '', domain: null, isMCPToolCall: false }; + return { function_name: '', domain: null, isMCPToolCall: false, mcpServerName: '' }; } if (name.includes(Constants.mcp_delimiter)) { const [func, server] = name.split(Constants.mcp_delimiter); @@ -46,6 +51,7 @@ export default function ToolCall({ function_name: func || '', domain: server && (server.replaceAll(actionDomainSeparator, '.') || null), isMCPToolCall: true, + mcpServerName: server || '', }; } const [func, _domain] = name.includes(actionDelimiter) @@ -55,9 +61,40 @@ export default function ToolCall({ function_name: func || '', domain: _domain && (_domain.replaceAll(actionDomainSeparator, '.') || null), isMCPToolCall: false, + mcpServerName: '', }; }, [name]); + const actionId = useMemo(() => { + if (isMCPToolCall || !auth) { + return ''; + } + try { + const url = new URL(auth); + const redirectUri = url.searchParams.get('redirect_uri') || ''; + const match = redirectUri.match(/\/api\/actions\/([^/]+)\/oauth\/callback/); + return match?.[1] || ''; + } catch { + return ''; + } + }, [auth, isMCPToolCall]); + + const handleOAuthClick = useCallback(async () => { + if (!auth) { + return; + } + try { + if (isMCPToolCall && mcpServerName) { + await dataService.bindMCPOAuth(mcpServerName); + } else if (actionId) { + await dataService.bindActionOAuth(actionId); + } + } catch (e) { + logger.error('Failed to bind OAuth CSRF cookie', e); + } + window.open(auth, '_blank', 'noopener,noreferrer'); + }, [auth, isMCPToolCall, mcpServerName, actionId]); + const error = typeof output === 'string' && output.toLowerCase().includes('error processing tool'); @@ -230,7 +267,7 @@ export default function ToolCall({ className="font-mediu inline-flex items-center justify-center rounded-xl px-4 py-2 text-sm" variant="default" rel="noopener noreferrer" - onClick={() => window.open(auth, '_blank', 'noopener,noreferrer')} + onClick={handleOAuthClick} > {localize('com_ui_sign_in_to_domain', { 0: authDomain })} diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 748cd0a967..a8f631614d 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -298,38 +298,45 @@ export class MCPConnectionFactory { const oauthHandler = async (data: { serverUrl?: string }) => { logger.info(`${this.logPrefix} oauthRequired event received`); - // If we just want to initiate OAuth and return, handle it differently if (this.returnOnOAuth) { try { const config = this.serverConfig; - const { authorizationUrl, flowId, flowMetadata } = - await MCPOAuthHandler.initiateOAuthFlow( - this.serverName, - data.serverUrl || '', - this.userId!, - config?.oauth_headers ?? {}, - config?.oauth, + const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName); + const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth'); + + if (existingFlow?.status === 'PENDING') { + logger.debug( + `${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`, ); + connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); + return; + } - // Delete any existing flow state to ensure we start fresh - // This prevents stale codeVerifier issues when re-authenticating - await this.flowManager!.deleteFlow(flowId, 'mcp_oauth'); + const { + authorizationUrl, + flowId: newFlowId, + flowMetadata, + } = await MCPOAuthHandler.initiateOAuthFlow( + this.serverName, + data.serverUrl || '', + this.userId!, + config?.oauth_headers ?? {}, + config?.oauth, + ); - // Create the flow state so the OAuth callback can find it - // We spawn this in the background without waiting for it - // Pass signal so the flow can be aborted if the request is cancelled - this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => { - // The OAuth callback will resolve this flow, so we expect it to timeout here - // or it will be aborted if the request is cancelled - both are fine - }); + if (existingFlow) { + await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth'); + } + + this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch( + () => {}, + ); if (this.oauthStart) { logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`); await this.oauthStart(authorizationUrl); } - // Emit oauthFailed to signal that connection should not proceed - // but OAuth was successfully initiated connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); return; } catch (error) { @@ -391,11 +398,9 @@ export class MCPConnectionFactory { logger.error(`${this.logPrefix} Failed to establish connection.`); } - // Handles connection attempts with retry logic and OAuth error handling private async connectTo(connection: MCPConnection): Promise { const maxAttempts = 3; let attempts = 0; - let oauthHandled = false; while (attempts < maxAttempts) { try { @@ -408,22 +413,6 @@ export class MCPConnectionFactory { attempts++; if (this.useOAuth && this.isOAuthError(error)) { - // For returnOnOAuth mode, let the event handler (handleOAuthEvents) deal with OAuth - // We just need to stop retrying and let the error propagate - if (this.returnOnOAuth) { - logger.info( - `${this.logPrefix} OAuth required (return on OAuth mode), stopping retries`, - ); - throw error; - } - - // Normal flow - wait for OAuth to complete - if (this.oauthStart && !oauthHandled) { - oauthHandled = true; - logger.info(`${this.logPrefix} Handling OAuth`); - await this.handleOAuthRequired(); - } - // Don't retry on OAuth errors - just throw logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`); throw error; } @@ -499,26 +488,15 @@ export class MCPConnectionFactory { /** Check if there's already an ongoing OAuth flow for this flowId */ const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth'); - // If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh - // This ensures the user always gets a new OAuth URL instead of waiting for stale flows if (existingFlow) { logger.debug( - `${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`, + `${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`, ); try { - if (existingFlow.status === 'PENDING') { - await this.flowManager.failFlow( - flowId, - 'mcp_oauth', - new Error('Cancelled for new OAuth request'), - ); - } else { - await this.flowManager.deleteFlow(flowId, 'mcp_oauth'); - } + await this.flowManager.deleteFlow(flowId, 'mcp_oauth'); } catch (error) { - logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error); + logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error); } - // Continue to start a new flow below } logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index 9f824bce23..263c84357a 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -270,7 +270,54 @@ describe('MCPConnectionFactory', () => { ); }); - it('should delete existing flow before creating new OAuth flow to prevent stale codeVerifier', async () => { + it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + user: mockUser, + }; + + const oauthOptions: t.OAuthConnectionOptions = { + user: mockUser, + useOAuth: true, + returnOnOAuth: true, + oauthStart: jest.fn(), + flowManager: mockFlowManager, + }; + + mockFlowManager.getFlowState.mockResolvedValue({ + status: 'PENDING', + type: 'mcp_oauth', + metadata: { codeVerifier: 'existing-verifier' }, + createdAt: Date.now(), + }); + mockConnectionInstance.isConnected.mockResolvedValue(false); + + let oauthRequiredHandler: (data: Record) => Promise; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthRequiredHandler = handler as (data: Record) => Promise; + } + return mockConnectionInstance; + }); + + try { + await MCPConnectionFactory.create(basicOptions, oauthOptions); + } catch { + // Expected to fail + } + + await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); + + expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled(); + expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled(); + expect(mockConnectionInstance.emit).toHaveBeenCalledWith( + 'oauthFailed', + expect.objectContaining({ message: 'OAuth flow initiated - return early' }), + ); + }); + + it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => { const basicOptions = { serverName: 'test-server', serverConfig: mockServerConfig, @@ -303,6 +350,12 @@ describe('MCPConnectionFactory', () => { }, }; + mockFlowManager.getFlowState.mockResolvedValue({ + status: 'COMPLETED', + type: 'mcp_oauth', + metadata: { codeVerifier: 'old-verifier' }, + createdAt: Date.now() - 60000, + }); mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData); mockFlowManager.deleteFlow.mockResolvedValue(true); mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected')); @@ -319,21 +372,17 @@ describe('MCPConnectionFactory', () => { try { await MCPConnectionFactory.create(basicOptions, oauthOptions); } catch { - // Expected to fail due to connection not established + // Expected to fail } await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' }); - // Verify deleteFlow was called with correct parameters expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth'); - // Verify deleteFlow was called before createFlow const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0]; const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0]; expect(deleteCallOrder).toBeLessThan(createCallOrder); - // Verify createFlow was called with fresh metadata - // 4th arg is the abort signal (undefined in this test since no signal was provided) expect(mockFlowManager.createFlow).toHaveBeenCalledWith( 'user123:test-server', 'mcp_oauth', diff --git a/packages/api/src/oauth/csrf.ts b/packages/api/src/oauth/csrf.ts new file mode 100644 index 0000000000..5bf0566b45 --- /dev/null +++ b/packages/api/src/oauth/csrf.ts @@ -0,0 +1,89 @@ +import crypto from 'crypto'; +import type { Request, Response, NextFunction } from 'express'; + +export const OAUTH_CSRF_COOKIE = 'oauth_csrf'; +export const OAUTH_CSRF_MAX_AGE = 10 * 60 * 1000; + +export const OAUTH_SESSION_COOKIE = 'oauth_session'; +export const OAUTH_SESSION_MAX_AGE = 24 * 60 * 60 * 1000; +export const OAUTH_SESSION_COOKIE_PATH = '/api'; + +const isProduction = process.env.NODE_ENV === 'production'; + +/** Generates an HMAC-based token for OAuth CSRF protection */ +export function generateOAuthCsrfToken(flowId: string, secret?: string): string { + const key = secret || process.env.JWT_SECRET; + if (!key) { + throw new Error('JWT_SECRET is required for OAuth CSRF token generation'); + } + return crypto.createHmac('sha256', key).update(flowId).digest('hex').slice(0, 32); +} + +/** Sets a SameSite=Lax CSRF cookie bound to a specific OAuth flow */ +export function setOAuthCsrfCookie(res: Response, flowId: string, cookiePath: string): void { + res.cookie(OAUTH_CSRF_COOKIE, generateOAuthCsrfToken(flowId), { + httpOnly: true, + secure: isProduction, + sameSite: 'lax', + maxAge: OAUTH_CSRF_MAX_AGE, + path: cookiePath, + }); +} + +/** + * Validates the per-flow CSRF cookie against the expected HMAC. + * Uses timing-safe comparison and always clears the cookie to prevent replay. + */ +export function validateOAuthCsrf( + req: Request, + res: Response, + flowId: string, + cookiePath: string, +): boolean { + const cookie = (req.cookies as Record | undefined)?.[OAUTH_CSRF_COOKIE]; + res.clearCookie(OAUTH_CSRF_COOKIE, { path: cookiePath }); + if (!cookie) { + return false; + } + const expected = generateOAuthCsrfToken(flowId); + if (cookie.length !== expected.length) { + return false; + } + return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected)); +} + +/** + * Express middleware that sets the OAuth session cookie after JWT authentication. + * Chain after requireJwtAuth on routes that precede an OAuth redirect (e.g., reinitialize, bind). + */ +export function setOAuthSession(req: Request, res: Response, next: NextFunction): void { + const user = (req as Request & { user?: { id?: string } }).user; + if (user?.id && !(req.cookies as Record | undefined)?.[OAUTH_SESSION_COOKIE]) { + setOAuthSessionCookie(res, user.id); + } + next(); +} + +/** Sets a SameSite=Lax session cookie that binds the browser to the authenticated userId */ +export function setOAuthSessionCookie(res: Response, userId: string): void { + res.cookie(OAUTH_SESSION_COOKIE, generateOAuthCsrfToken(userId), { + httpOnly: true, + secure: isProduction, + sameSite: 'lax', + maxAge: OAUTH_SESSION_MAX_AGE, + path: OAUTH_SESSION_COOKIE_PATH, + }); +} + +/** Validates the session cookie against the expected userId using timing-safe comparison */ +export function validateOAuthSession(req: Request, userId: string): boolean { + const cookie = (req.cookies as Record | undefined)?.[OAUTH_SESSION_COOKIE]; + if (!cookie) { + return false; + } + const expected = generateOAuthCsrfToken(userId); + if (cookie.length !== expected.length) { + return false; + } + return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected)); +} diff --git a/packages/api/src/oauth/index.ts b/packages/api/src/oauth/index.ts index e56053c166..01be92b6e3 100644 --- a/packages/api/src/oauth/index.ts +++ b/packages/api/src/oauth/index.ts @@ -1 +1,2 @@ +export * from './csrf'; export * from './tokens'; diff --git a/packages/data-provider/src/api-endpoints.ts b/packages/data-provider/src/api-endpoints.ts index d49535b094..db6df32015 100644 --- a/packages/data-provider/src/api-endpoints.ts +++ b/packages/data-provider/src/api-endpoints.ts @@ -181,6 +181,11 @@ export const cancelMCPOAuth = (serverName: string) => { return `${BASE_URL}/api/mcp/oauth/cancel/${serverName}`; }; +export const mcpOAuthBind = (serverName: string) => `${BASE_URL}/api/mcp/${serverName}/oauth/bind`; + +export const actionOAuthBind = (actionId: string) => + `${BASE_URL}/api/actions/${actionId}/oauth/bind`; + export const config = () => `${BASE_URL}/api/config`; export const prompts = () => `${BASE_URL}/api/prompts`; diff --git a/packages/data-provider/src/data-service.ts b/packages/data-provider/src/data-service.ts index 8919e2589b..be5cccd43b 100644 --- a/packages/data-provider/src/data-service.ts +++ b/packages/data-provider/src/data-service.ts @@ -178,6 +178,14 @@ export const reinitializeMCPServer = (serverName: string) => { return request.post(endpoints.mcpReinitialize(serverName)); }; +export const bindMCPOAuth = (serverName: string): Promise<{ success: boolean }> => { + return request.post(endpoints.mcpOAuthBind(serverName)); +}; + +export const bindActionOAuth = (actionId: string): Promise<{ success: boolean }> => { + return request.post(endpoints.actionOAuthBind(actionId)); +}; + export const getMCPConnectionStatus = (): Promise => { return request.get(endpoints.mcpConnectionStatus()); };