From 7844a93f8b536293c02268f3025336f1f1004cd6 Mon Sep 17 00:00:00 2001 From: Artyom Bogachenko <32168471+SpectralOne@users.noreply.github.com> Date: Thu, 25 Dec 2025 20:24:01 +0300 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20fix:=20use=20DOMAIN=5FCLIE?= =?UTF-8?q?NT=20for=20MCP=20OAuth=20Redirects=20(#11057)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Artyom Bogachenco --- api/server/routes/__tests__/mcp.spec.js | 42 +++++++++++++++++-------- api/server/routes/actions.js | 13 ++++---- api/server/routes/mcp.js | 20 +++++++----- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 1da1e0aa86..3b0d20feac 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -2,6 +2,7 @@ const express = require('express'); const request = require('supertest'); const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); +const { getBasePath } = require('@librechat/api'); const mockRegistryInstance = { getServerConfig: jest.fn(), @@ -281,27 +282,30 @@ describe('MCP Routes', () => { error: 'access_denied', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/error?error=access_denied'); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`); }); it('should redirect to error page when code is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/error?error=missing_code'); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_code`); }); it('should redirect to error page when state is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ code: 'test-auth-code', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/error?error=missing_state'); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`); }); it('should redirect to error page when flow state is not found', async () => { @@ -311,9 +315,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'invalid-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/error?error=invalid_state'); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`); }); it('should handle OAuth callback successfully', async () => { @@ -368,9 +373,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( 'test-flow-id', 'test-auth-code', @@ -404,9 +410,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/error?error=callback_failed'); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`); }); it('should handle system-level OAuth completion', async () => { @@ -439,9 +446,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); }); @@ -484,9 +492,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(MCPTokenStorage.storeTokens).toHaveBeenCalled(); expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); }); @@ -525,9 +534,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/error?error=callback_failed'); + expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`); expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled(); }); @@ -583,9 +593,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); // Verify storeTokens was called with ORIGINAL flow state credentials expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith( @@ -624,9 +635,10 @@ describe('MCP Routes', () => { code: 'test-auth-code', state: 'test-flow-id', }); + const basePath = getBasePath(); expect(response.status).toBe(302); - expect(response.headers.location).toBe('/oauth/success?serverName=test-server'); + expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); // Verify completeOAuthFlow was NOT called (prevented duplicate) expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled(); @@ -1395,8 +1407,10 @@ describe('MCP Routes', () => { .get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id') .expect(302); + const basePath = getBasePath(); + expect(mockFlowManager.completeFlow).not.toHaveBeenCalled(); - expect(response.headers.location).toContain('/oauth/success'); + expect(response.headers.location).toContain(`${basePath}/oauth/success`); }); it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => { @@ -1443,7 +1457,9 @@ describe('MCP Routes', () => { .get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id') .expect(302); - expect(response.headers.location).toContain('/oauth/success'); + const basePath = getBasePath(); + + expect(response.headers.location).toContain(`${basePath}/oauth/success`); }); }); diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js index 9f94f617ce..14474a53d3 100644 --- a/api/server/routes/actions.js +++ b/api/server/routes/actions.js @@ -1,6 +1,6 @@ const express = require('express'); const jwt = require('jsonwebtoken'); -const { getAccessToken } = require('@librechat/api'); +const { getAccessToken, getBasePath } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); const { findToken, updateToken, createToken } = require('~/models'); @@ -24,6 +24,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => { const { code, state } = req.query; const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); + const basePath = getBasePath(); let identifier = action_id; try { let decodedState; @@ -32,17 +33,17 @@ router.get('/:action_id/oauth/callback', async (req, res) => { } catch (err) { logger.error('Error verifying state parameter:', err); await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter'); - return res.redirect('/oauth/error?error=invalid_state'); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } if (decodedState.action_id !== action_id) { await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter'); - return res.redirect('/oauth/error?error=invalid_state'); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } if (!decodedState.user) { await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter'); - return res.redirect('/oauth/error?error=invalid_state'); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } identifier = `${decodedState.user}:${action_id}`; const flowState = await flowManager.getFlowState(identifier, 'oauth'); @@ -72,12 +73,12 @@ router.get('/:action_id/oauth/callback', async (req, res) => { /** Redirect to React success page */ const serverName = flowState.metadata?.action_name || `Action ${action_id}`; - const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`; + const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`; res.redirect(redirectUrl); } catch (error) { logger.error('Error in OAuth callback:', error); await flowManager.failFlow(identifier, 'oauth', error); - res.redirect('/oauth/error?error=callback_failed'); + res.redirect(`${basePath}/oauth/error?error=callback_failed`); } }); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 0cee7f991a..f01c7ff71c 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -11,6 +11,7 @@ const { createSafeUser, MCPOAuthHandler, MCPTokenStorage, + getBasePath, getUserMCPAuthMap, generateCheckAccess, } = require('@librechat/api'); @@ -105,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { * This handles the OAuth callback after the user has authorized the application */ router.get('/:serverName/oauth/callback', async (req, res) => { + const basePath = getBasePath(); try { const { serverName } = req.params; const { code, state, error: oauthError } = req.query; @@ -118,17 +120,19 @@ router.get('/:serverName/oauth/callback', async (req, res) => { if (oauthError) { logger.error('[MCP OAuth] OAuth error received', { error: oauthError }); - return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`); + return res.redirect( + `${basePath}/oauth/error?error=${encodeURIComponent(String(oauthError))}`, + ); } if (!code || typeof code !== 'string') { logger.error('[MCP OAuth] Missing or invalid code'); - return res.redirect('/oauth/error?error=missing_code'); + return res.redirect(`${basePath}/oauth/error?error=missing_code`); } if (!state || typeof state !== 'string') { logger.error('[MCP OAuth] Missing or invalid state'); - return res.redirect('/oauth/error?error=missing_state'); + return res.redirect(`${basePath}/oauth/error?error=missing_state`); } const flowId = state; @@ -142,7 +146,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => { if (!flowState) { logger.error('[MCP OAuth] Flow state not found for flowId:', flowId); - return res.redirect('/oauth/error?error=invalid_state'); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } logger.debug('[MCP OAuth] Flow state details', { @@ -160,7 +164,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => { flowId, serverName, }); - return res.redirect(`/oauth/success?serverName=${encodeURIComponent(serverName)}`); + return res.redirect(`${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`); } logger.debug('[MCP OAuth] Completing OAuth flow'); @@ -254,11 +258,11 @@ router.get('/:serverName/oauth/callback', async (req, res) => { } /** Redirect to success page with flowId and serverName */ - const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`; + const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`; res.redirect(redirectUrl); } catch (error) { logger.error('[MCP OAuth] OAuth callback error', error); - res.redirect('/oauth/error?error=callback_failed'); + res.redirect(`${basePath}/oauth/error?error=callback_failed`); } }); @@ -588,7 +592,7 @@ async function getOAuthHeaders(serverName, userId) { return serverConfig?.oauth_headers ?? {}; } -/** +/** MCP Server CRUD Routes (User-Managed MCP Servers) */