From fa9e1b228a09fb02541068902635a97686eb32cc Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 13 Mar 2026 23:18:56 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=AA=20fix:=20MCP=20API=20Responses=20a?= =?UTF-8?q?nd=20OAuth=20Validation=20(#12217)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔒 fix: Validate MCP Configs in Server Responses * 🔒 fix: Enhance OAuth URL Validation in MCPOAuthHandler - Introduced validation for OAuth URLs to ensure they do not target private or internal addresses, enhancing security against SSRF attacks. - Updated the OAuth flow to validate both authorization and token URLs before use, ensuring compliance with security standards. - Refactored redirect URI handling to streamline the OAuth client registration process. - Added comprehensive error handling for invalid URLs, improving robustness in OAuth interactions. * 🔒 feat: Implement Permission Checks for MCP Server Management - Added permission checkers for MCP server usage and creation, enhancing access control. - Updated routes for reinitializing MCP servers and retrieving authentication values to include these permission checks, ensuring only authorized users can access these functionalities. - Refactored existing permission logic to improve clarity and maintainability. * 🔒 fix: Enhance MCP Server Response Validation and Redaction - Updated MCP route tests to use `toMatchObject` for better validation of server response structures, ensuring consistency in expected properties. - Refactored the `redactServerSecrets` function to streamline the removal of sensitive information, ensuring that user-sourced API keys are properly redacted while retaining their source. - Improved OAuth security tests to validate rejection of private URLs across multiple endpoints, enhancing protection against SSRF vulnerabilities. - Added comprehensive tests for the `redactServerSecrets` function to ensure proper handling of various server configurations, reinforcing security measures. * chore: eslint * 🔒 fix: Enhance OAuth Server URL Validation in MCPOAuthHandler - Added validation for discovered authorization server URLs to ensure they meet security standards. - Improved logging to provide clearer insights when an authorization server is found from resource metadata. - Refactored the handling of authorization server URLs to enhance robustness against potential security vulnerabilities. * 🔒 test: Bypass SSRF validation for MCP OAuth Flow tests - Mocked SSRF validation functions to allow tests to use real local HTTP servers, facilitating more accurate testing of the MCP OAuth flow. - Updated test setup to ensure compatibility with the new mocking strategy, enhancing the reliability of the tests. * 🔒 fix: Add Validation for OAuth Metadata Endpoints in MCPOAuthHandler - Implemented checks for the presence and validity of registration and token endpoints in the OAuth metadata, enhancing security by ensuring that these URLs are properly validated before use. - Improved error handling and logging to provide better insights during the OAuth metadata processing, reinforcing the robustness of the OAuth flow. * 🔒 refactor: Simplify MCP Auth Values Endpoint Logic - Removed redundant permission checks for accessing the MCP server resource in the auth-values endpoint, streamlining the request handling process. - Consolidated error handling and response structure for improved clarity and maintainability. - Enhanced logging for better insights during the authentication value checks, reinforcing the robustness of the endpoint. * 🔒 test: Refactor LeaderElection Integration Tests for Improved Cleanup - Moved Redis key cleanup to the beforeEach hook to ensure a clean state before each test. - Enhanced afterEach logic to handle instance resignations and Redis key deletion more robustly, improving test reliability and maintainability. --- api/server/controllers/mcp.js | 14 +- api/server/routes/__tests__/mcp.spec.js | 118 ++++++++- api/server/routes/mcp.js | 143 +++++------ .../LeaderElection.cache_integration.spec.ts | 18 +- .../src/mcp/__tests__/MCPOAuthFlow.test.ts | 7 + .../mcp/__tests__/MCPOAuthSecurity.test.ts | 228 ++++++++++++++++++ packages/api/src/mcp/__tests__/utils.test.ts | 201 ++++++++++++++- packages/api/src/mcp/oauth/handler.ts | 60 ++++- .../__tests__/ServerConfigsDB.test.ts | 98 ++++++++ packages/api/src/mcp/utils.ts | 60 +++++ 10 files changed, 845 insertions(+), 102 deletions(-) create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index e5dfff61ca..729f01da9d 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -7,9 +7,11 @@ */ const { logger } = require('@librechat/data-schemas'); const { + MCPErrorCodes, + redactServerSecrets, + redactAllServerSecrets, isMCPDomainNotAllowedError, isMCPInspectionFailedError, - MCPErrorCodes, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); @@ -181,10 +183,8 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - // 2. Get all server configs from registry (YAML + DB) const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); - - return res.json(serverConfigs); + return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); res.status(500).json({ error: error.message }); @@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => { ); res.status(201).json({ serverName: result.serverName, - ...result.config, + ...redactServerSecrets(result.config), }); } catch (error) { logger.error('[createMCPServer]', error); @@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => { return res.status(404).json({ message: 'MCP server not found' }); } - res.status(200).json(parsedConfig); + res.status(200).json(redactServerSecrets(parsedConfig)); } catch (error) { logger.error('[getMCPServerById]', error); res.status(500).json({ message: error.message }); @@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => { userId, ); - res.status(200).json(parsedConfig); + res.status(200).json(redactServerSecrets(parsedConfig)); } catch (error) { logger.error('[updateMCPServer]', error); const mcpErrorResponse = handleMCPError(error, res); diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index e0cb680169..1ad8cac087 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1693,12 +1693,14 @@ describe('MCP Routes', () => { it('should return all server configs for authenticated user', async () => { const mockServerConfigs = { 'server-1': { - endpoint: 'http://server1.com', - name: 'Server 1', + type: 'sse', + url: 'http://server1.com/sse', + title: 'Server 1', }, 'server-2': { - endpoint: 'http://server2.com', - name: 'Server 2', + type: 'sse', + url: 'http://server2.com/sse', + title: 'Server 2', }, }; @@ -1707,7 +1709,18 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/servers'); expect(response.status).toBe(200); - expect(response.body).toEqual(mockServerConfigs); + expect(response.body['server-1']).toMatchObject({ + type: 'sse', + url: 'http://server1.com/sse', + title: 'Server 1', + }); + expect(response.body['server-2']).toMatchObject({ + type: 'sse', + url: 'http://server2.com/sse', + title: 'Server 2', + }); + expect(response.body['server-1'].headers).toBeUndefined(); + expect(response.body['server-2'].headers).toBeUndefined(); expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); }); @@ -1762,10 +1775,10 @@ describe('MCP Routes', () => { const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); expect(response.status).toBe(201); - expect(response.body).toEqual({ - serverName: 'test-sse-server', - ...validConfig, - }); + expect(response.body.serverName).toBe('test-sse-server'); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://mcp-server.example.com/sse'); + expect(response.body.title).toBe('Test SSE Server'); expect(mockRegistryInstance.addServer).toHaveBeenCalledWith( 'temp_server_name', expect.objectContaining({ @@ -1864,6 +1877,33 @@ describe('MCP Routes', () => { expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); }); + it('should redact secrets from create response', async () => { + const validConfig = { + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Test Server', + }; + + mockRegistryInstance.addServer.mockResolvedValue({ + serverName: 'test-server', + config: { + ...validConfig, + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'admin-secret-key' }, + oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' }, + headers: { Authorization: 'Bearer leaked-token' }, + }, + }); + + const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); + + expect(response.status).toBe(201); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.headers).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_id).toBe('cid'); + }); + it('should return 500 when registry throws error', async () => { const validConfig = { type: 'sse', @@ -1893,7 +1933,9 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/servers/test-server'); expect(response.status).toBe(200); - expect(response.body).toEqual(mockConfig); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://mcp-server.example.com/sse'); + expect(response.body.title).toBe('Test Server'); expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', @@ -1909,6 +1951,29 @@ describe('MCP Routes', () => { expect(response.body).toEqual({ message: 'MCP server not found' }); }); + it('should redact secrets from get response', async () => { + mockRegistryInstance.getServerConfig.mockResolvedValue({ + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Secret Server', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'decrypted-admin-key' }, + oauth: { client_id: 'cid', client_secret: 'decrypted-oauth-secret' }, + headers: { Authorization: 'Bearer internal-token' }, + oauth_headers: { 'X-OAuth': 'secret-value' }, + }); + + const response = await request(app).get('/api/mcp/servers/secret-server'); + + expect(response.status).toBe(200); + expect(response.body.title).toBe('Secret Server'); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.oauth?.client_id).toBe('cid'); + expect(response.body.headers).toBeUndefined(); + expect(response.body.oauth_headers).toBeUndefined(); + }); + it('should return 500 when registry throws error', async () => { mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error')); @@ -1935,7 +2000,9 @@ describe('MCP Routes', () => { .send({ config: updatedConfig }); expect(response.status).toBe(200); - expect(response.body).toEqual(updatedConfig); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://updated-mcp-server.example.com/sse'); + expect(response.body.title).toBe('Updated Server'); expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith( 'test-server', expect.objectContaining({ @@ -1947,6 +2014,35 @@ describe('MCP Routes', () => { ); }); + it('should redact secrets from update response', async () => { + const validConfig = { + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Updated Server', + }; + + mockRegistryInstance.updateServer.mockResolvedValue({ + ...validConfig, + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'preserved-admin-key' }, + oauth: { client_id: 'cid', client_secret: 'preserved-oauth-secret' }, + headers: { Authorization: 'Bearer internal-token' }, + env: { DATABASE_URL: 'postgres://admin:pass@localhost/db' }, + }); + + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ config: validConfig }); + + expect(response.status).toBe(200); + expect(response.body.title).toBe('Updated Server'); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.oauth?.client_id).toBe('cid'); + expect(response.body.headers).toBeUndefined(); + expect(response.body.env).toBeUndefined(); + }); + it('should return 400 for invalid configuration', async () => { const invalidConfig = { type: 'sse', diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 0afac81192..57a99d199a 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -50,6 +50,18 @@ const router = Router(); const OAUTH_CSRF_COOKIE_PATH = '/api/mcp'; +const checkMCPUsePermissions = generateCheckAccess({ + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkMCPCreate = generateCheckAccess({ + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); + /** * Get all MCP tools available to the user * Returns only MCP tools, completely decoupled from regular LibreChat tools @@ -470,69 +482,75 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { * Reinitialize MCP server * This endpoint allows reinitializing a specific MCP server */ -router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => { - try { - const { serverName } = req.params; - const user = createSafeUser(req.user); +router.post( + '/:serverName/reinitialize', + requireJwtAuth, + checkMCPUsePermissions, + setOAuthSession, + async (req, res) => { + try { + const { serverName } = req.params; + const user = createSafeUser(req.user); - if (!user.id) { - return res.status(401).json({ error: 'User not authenticated' }); - } + if (!user.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } - logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); + logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); - const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); - if (!serverConfig) { - return res.status(404).json({ - error: `MCP server '${serverName}' not found in configuration`, + const mcpManager = getMCPManager(); + const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + if (!serverConfig) { + return res.status(404).json({ + error: `MCP server '${serverName}' not found in configuration`, + }); + } + + await mcpManager.disconnectUserConnection(user.id, serverName); + logger.info( + `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, + ); + + /** @type {Record> | undefined} */ + let userMCPAuthMap; + if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { + userMCPAuthMap = await getUserMCPAuthMap({ + userId: user.id, + servers: [serverName], + findPluginAuthsByKeys, + }); + } + + const result = await reinitMCPServer({ + user, + serverName, + userMCPAuthMap, }); - } - await mcpManager.disconnectUserConnection(user.id, serverName); - logger.info( - `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, - ); + if (!result) { + return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); + } - /** @type {Record> | undefined} */ - let userMCPAuthMap; - if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { - userMCPAuthMap = await getUserMCPAuthMap({ - userId: user.id, - servers: [serverName], - findPluginAuthsByKeys, + const { success, message, oauthRequired, oauthUrl } = result; + + if (oauthRequired) { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + } + + res.json({ + success, + message, + oauthUrl, + serverName, + oauthRequired, }); + } catch (error) { + logger.error('[MCP Reinitialize] Unexpected error', error); + res.status(500).json({ error: 'Internal server error' }); } - - const result = await reinitMCPServer({ - user, - serverName, - userMCPAuthMap, - }); - - if (!result) { - return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); - } - - const { success, message, oauthRequired, oauthUrl } = result; - - if (oauthRequired) { - const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); - setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); - } - - res.json({ - success, - message, - oauthUrl, - serverName, - oauthRequired, - }); - } catch (error) { - logger.error('[MCP Reinitialize] Unexpected error', error); - res.status(500).json({ error: 'Internal server error' }); - } -}); + }, +); /** * Get connection status for all MCP servers @@ -639,7 +657,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => * Check which authentication values exist for a specific MCP server * This endpoint returns only boolean flags indicating if values are set, not the actual values */ -router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { +router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, async (req, res) => { try { const { serverName } = req.params; const user = req.user; @@ -696,19 +714,6 @@ async function getOAuthHeaders(serverName, userId) { MCP Server CRUD Routes (User-Managed MCP Servers) */ -// Permission checkers for MCP server management -const checkMCPUsePermissions = generateCheckAccess({ - permissionType: PermissionTypes.MCP_SERVERS, - permissions: [Permissions.USE], - getRoleByName, -}); - -const checkMCPCreate = generateCheckAccess({ - permissionType: PermissionTypes.MCP_SERVERS, - permissions: [Permissions.USE, Permissions.CREATE], - getRoleByName, -}); - /** * Get list of accessible MCP servers * @route GET /api/mcp/servers diff --git a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts index 9bad4dcfac..f1558db795 100644 --- a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts +++ b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts @@ -32,14 +32,22 @@ describe('LeaderElection with Redis', () => { process.setMaxListeners(200); }); - afterEach(async () => { - await Promise.all(instances.map((instance) => instance.resign())); - instances = []; - - // Clean up: clear the leader key directly from Redis + beforeEach(async () => { if (keyvRedisClient) { await keyvRedisClient.del(LeaderElection.LEADER_KEY); } + new LeaderElection().clearRefreshTimer(); + }); + + afterEach(async () => { + try { + await Promise.all(instances.map((instance) => instance.resign())); + } finally { + instances = []; + if (keyvRedisClient) { + await keyvRedisClient.del(LeaderElection.LEADER_KEY); + } + } }); afterAll(async () => { diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index 8437177c86..f73a5ed3e8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -24,6 +24,13 @@ jest.mock('@librechat/data-schemas', () => ({ decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); +/** Bypass SSRF validation — these tests use real local HTTP servers. */ +jest.mock('~/auth', () => ({ + ...jest.requireActual('~/auth'), + isSSRFTarget: jest.fn(() => false), + resolveHostnameSSRF: jest.fn(async () => false), +})); + describe('MCP OAuth Flow — Real HTTP Server', () => { afterEach(() => { jest.clearAllMocks(); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts new file mode 100644 index 0000000000..a5188e24b0 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -0,0 +1,228 @@ +/** + * Tests verifying MCP OAuth security hardening: + * + * 1. SSRF via OAuth URLs — validates that the OAuth handler rejects + * token_url, authorization_url, and revocation_endpoint values + * pointing to private/internal addresses. + * + * 2. redirect_uri manipulation — validates that user-supplied redirect_uri + * is ignored in favor of the server-controlled default. + */ + +import * as http from 'http'; +import * as net from 'net'; +import { TokenExchangeMethodEnum } from 'librechat-data-provider'; +import type { Socket } from 'net'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; +import { MCPOAuthHandler } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +/** + * Mock only the DNS-dependent resolveHostnameSSRF; keep isSSRFTarget real. + * SSRF tests use literal private IPs (127.0.0.1, 169.254.169.254, 10.0.0.1) + * which are caught by isSSRFTarget before resolveHostnameSSRF is reached. + * This avoids non-deterministic DNS lookups in test execution. + */ +jest.mock('~/auth', () => ({ + ...jest.requireActual('~/auth'), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +function getFreePort(): Promise { + return new Promise((resolve, reject) => { + const srv = net.createServer(); + srv.listen(0, '127.0.0.1', () => { + const addr = srv.address() as net.AddressInfo; + srv.close((err) => (err ? reject(err) : resolve(addr.port))); + }); + }); +} + +function trackSockets(httpServer: http.Server): () => Promise { + const sockets = new Set(); + httpServer.on('connection', (socket: Socket) => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + return () => + new Promise((resolve) => { + for (const socket of sockets) { + socket.destroy(); + } + sockets.clear(); + httpServer.close(() => resolve()); + }); +} + +describe('MCP OAuth SSRF protection', () => { + let oauthServer: OAuthTestServer; + let ssrfTargetServer: http.Server; + let ssrfTargetPort: number; + let ssrfRequestReceived: boolean; + let destroySSRFSockets: () => Promise; + + beforeEach(async () => { + ssrfRequestReceived = false; + + oauthServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + + ssrfTargetPort = await getFreePort(); + ssrfTargetServer = http.createServer((_req, res) => { + ssrfRequestReceived = true; + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + access_token: 'ssrf-token', + token_type: 'Bearer', + expires_in: 3600, + }), + ); + }); + destroySSRFSockets = trackSockets(ssrfTargetServer); + await new Promise((resolve) => + ssrfTargetServer.listen(ssrfTargetPort, '127.0.0.1', resolve), + ); + }); + + afterEach(async () => { + try { + await oauthServer.close(); + } finally { + await destroySSRFSockets(); + } + }); + + it('should reject token_url pointing to a private IP (refreshOAuthTokens)', async () => { + const code = await oauthServer.getAuthCode(); + const tokenRes = await fetch(`${oauthServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const regRes = await fetch(`${oauthServer.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const clientInfo = (await regRes.json()) as { + client_id: string; + client_secret: string; + }; + + const ssrfTokenUrl = `http://127.0.0.1:${ssrfTargetPort}/latest/meta-data/iam/security-credentials/`; + + await expect( + MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'ssrf-test-server', + serverUrl: oauthServer.url, + clientInfo: { + ...clientInfo, + redirect_uris: ['http://localhost/callback'], + }, + }, + {}, + { + token_url: ssrfTokenUrl, + client_id: clientInfo.client_id, + client_secret: clientInfo.client_secret, + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + }, + ), + ).rejects.toThrow(/targets a blocked address/); + + expect(ssrfRequestReceived).toBe(false); + }); + + it('should reject private authorization_url in initiateOAuthFlow', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://169.254.169.254/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should reject private token_url in initiateOAuthFlow', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'https://auth.example.com/authorize', + token_url: `http://127.0.0.1:${ssrfTargetPort}/token`, + client_id: 'client', + client_secret: 'secret', + }, + ), + ).rejects.toThrow(/targets a blocked address/); + + expect(ssrfRequestReceived).toBe(false); + }); + + it('should reject private revocationEndpoint in revokeOAuthToken', async () => { + await expect( + MCPOAuthHandler.revokeOAuthToken('test-server', 'some-token', 'access', { + serverUrl: 'https://mcp.example.com/', + clientId: 'client', + clientSecret: 'secret', + revocationEndpoint: 'http://10.0.0.1/revoke', + }), + ).rejects.toThrow(/targets a blocked address/); + }); +}); + +describe('MCP OAuth redirect_uri enforcement', () => { + it('should ignore attacker-supplied redirect_uri and use the server default', async () => { + const attackerRedirectUri = 'https://attacker.example.com/steal-code'; + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'victim-server', + 'https://mcp.example.com/', + 'victim-user-id', + {}, + { + authorization_url: 'https://auth.example.com/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'attacker-client', + client_secret: 'attacker-secret', + redirect_uri: attackerRedirectUri, + }, + ); + + const authUrl = new URL(result.authorizationUrl); + const expectedRedirectUri = `${process.env.DOMAIN_SERVER || 'http://localhost:3080'}/api/mcp/victim-server/oauth/callback`; + expect(authUrl.searchParams.get('redirect_uri')).toBe(expectedRedirectUri); + expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri); + }); +}); diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index 716a230ebe..e4fb31bdad 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -1,4 +1,5 @@ -import { normalizeServerName } from '../utils'; +import { normalizeServerName, redactServerSecrets, redactAllServerSecrets } from '~/mcp/utils'; +import type { ParsedServerConfig } from '~/mcp/types'; describe('normalizeServerName', () => { it('should not modify server names that already match the pattern', () => { @@ -26,3 +27,201 @@ describe('normalizeServerName', () => { expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/); }); }); + +describe('redactServerSecrets', () => { + it('should strip apiKey.key from admin-sourced keys', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'super-secret-api-key', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.source).toBe('admin'); + expect(redacted.apiKey?.authorization_type).toBe('bearer'); + }); + + it('should strip oauth.client_secret', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + oauth: { + client_id: 'my-client', + client_secret: 'super-secret-oauth', + scope: 'read', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.oauth?.client_secret).toBeUndefined(); + expect(redacted.oauth?.client_id).toBe('my-client'); + expect(redacted.oauth?.scope).toBe('read'); + }); + + it('should strip both apiKey.key and oauth.client_secret simultaneously', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { + source: 'admin', + authorization_type: 'custom', + custom_header: 'X-API-Key', + key: 'secret-key', + }, + oauth: { + client_id: 'cid', + client_secret: 'csecret', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.custom_header).toBe('X-API-Key'); + expect(redacted.oauth?.client_secret).toBeUndefined(); + expect(redacted.oauth?.client_id).toBe('cid'); + }); + + it('should exclude headers from SSE configs', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'SSE Server', + }; + (config as ParsedServerConfig & { headers: Record }).headers = { + Authorization: 'Bearer admin-token-123', + 'X-Custom': 'safe-value', + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).headers).toBeUndefined(); + expect(redacted.title).toBe('SSE Server'); + }); + + it('should exclude env from stdio configs', () => { + const config: ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['server.js'], + env: { DATABASE_URL: 'postgres://admin:password@localhost/db', PATH: '/usr/bin' }, + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).env).toBeUndefined(); + expect((redacted as Record).command).toBeUndefined(); + expect((redacted as Record).args).toBeUndefined(); + }); + + it('should exclude oauth_headers', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + oauth_headers: { Authorization: 'Bearer oauth-admin-token' }, + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).oauth_headers).toBeUndefined(); + }); + + it('should strip apiKey.key even for user-sourced keys', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { source: 'user', authorization_type: 'bearer', key: 'my-own-key' }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.source).toBe('user'); + }); + + it('should not mutate the original config', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'secret' }, + oauth: { client_id: 'cid', client_secret: 'csecret' }, + }; + redactServerSecrets(config); + expect(config.apiKey?.key).toBe('secret'); + expect(config.oauth?.client_secret).toBe('csecret'); + }); + + it('should preserve all safe metadata fields', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'My Server', + description: 'A test server', + iconPath: '/icons/test.png', + chatMenu: true, + requiresOAuth: false, + capabilities: '{"tools":{}}', + tools: 'tool_a, tool_b', + dbId: 'abc123', + updatedAt: 1700000000000, + consumeOnly: false, + inspectionFailed: false, + customUserVars: { API_KEY: { title: 'API Key', description: 'Your key' } }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.title).toBe('My Server'); + expect(redacted.description).toBe('A test server'); + expect(redacted.iconPath).toBe('/icons/test.png'); + expect(redacted.chatMenu).toBe(true); + expect(redacted.requiresOAuth).toBe(false); + expect(redacted.capabilities).toBe('{"tools":{}}'); + expect(redacted.tools).toBe('tool_a, tool_b'); + expect(redacted.dbId).toBe('abc123'); + expect(redacted.updatedAt).toBe(1700000000000); + expect(redacted.consumeOnly).toBe(false); + expect(redacted.inspectionFailed).toBe(false); + expect(redacted.customUserVars).toEqual(config.customUserVars); + }); + + it('should pass URLs through unchanged', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://mcp.example.com/sse?param=value', + }; + const redacted = redactServerSecrets(config); + expect(redacted.url).toBe('https://mcp.example.com/sse?param=value'); + }); + + it('should only include explicitly allowlisted fields', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Test', + }; + (config as Record).someNewSensitiveField = 'leaked-value'; + const redacted = redactServerSecrets(config); + expect((redacted as Record).someNewSensitiveField).toBeUndefined(); + expect(redacted.title).toBe('Test'); + }); +}); + +describe('redactAllServerSecrets', () => { + it('should redact secrets from all configs in the map', () => { + const configs: Record = { + 'server-a': { + type: 'sse', + url: 'https://a.com/mcp', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'key-a' }, + }, + 'server-b': { + type: 'sse', + url: 'https://b.com/mcp', + oauth: { client_id: 'cid-b', client_secret: 'secret-b' }, + }, + 'server-c': { + type: 'stdio', + command: 'node', + args: ['c.js'], + }, + }; + const redacted = redactAllServerSecrets(configs); + expect(redacted['server-a'].apiKey?.key).toBeUndefined(); + expect(redacted['server-a'].apiKey?.source).toBe('admin'); + expect(redacted['server-b'].oauth?.client_secret).toBeUndefined(); + expect(redacted['server-b'].oauth?.client_id).toBe('cid-b'); + expect((redacted['server-c'] as Record).command).toBeUndefined(); + }); +}); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 366d0d2fde..8d863bfe79 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -24,6 +24,7 @@ import { selectRegistrationAuthMethod, inferClientAuthMethod, } from './methods'; +import { isSSRFTarget, resolveHostnameSSRF } from '~/auth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; /** Type for the OAuth metadata from the SDK */ @@ -144,7 +145,9 @@ export class MCPOAuthHandler { resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn); if (resourceMetadata?.authorization_servers?.length) { - authServerUrl = new URL(resourceMetadata.authorization_servers[0]); + const discoveredAuthServer = resourceMetadata.authorization_servers[0]; + await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server'); + authServerUrl = new URL(discoveredAuthServer); logger.debug( `[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`, ); @@ -200,6 +203,19 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] OAuth metadata discovered successfully`); const metadata = await OAuthMetadataSchema.parseAsync(rawMetadata); + const endpointChecks: Promise[] = []; + if (metadata.registration_endpoint) { + endpointChecks.push( + this.validateOAuthUrl(metadata.registration_endpoint, 'registration_endpoint'), + ); + } + if (metadata.token_endpoint) { + endpointChecks.push(this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint')); + } + if (endpointChecks.length > 0) { + await Promise.all(endpointChecks); + } + logger.debug(`[MCPOAuth] OAuth metadata parsed successfully`); return { metadata: metadata as unknown as OAuthMetadata, @@ -355,10 +371,14 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] Generated flowId: ${flowId}, state: ${state}`); try { - // Check if we have pre-configured OAuth settings if (config?.authorization_url && config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`); + await Promise.all([ + this.validateOAuthUrl(config.authorization_url, 'authorization_url'), + this.validateOAuthUrl(config.token_url, 'token_url'), + ]); + const skipCodeChallengeCheck = config?.skip_code_challenge_check === true || process.env.MCP_SKIP_CODE_CHALLENGE_CHECK === 'true'; @@ -410,10 +430,11 @@ export class MCPOAuthHandler { code_challenge_methods_supported: codeChallengeMethodsSupported, }; logger.debug(`[MCPOAuth] metadata for "${serverName}": ${JSON.stringify(metadata)}`); + const redirectUri = this.getDefaultRedirectUri(serverName); const clientInfo: OAuthClientInformation = { client_id: config.client_id, client_secret: config.client_secret, - redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)], + redirect_uris: [redirectUri], scope: config.scope, token_endpoint_auth_method: tokenEndpointAuthMethod, }; @@ -422,7 +443,7 @@ export class MCPOAuthHandler { const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { metadata: metadata as unknown as SDKOAuthMetadata, clientInformation: clientInfo, - redirectUrl: clientInfo.redirect_uris?.[0] || this.getDefaultRedirectUri(serverName), + redirectUrl: redirectUri, scope: config.scope, }); @@ -462,8 +483,7 @@ export class MCPOAuthHandler { `[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`, ); - /** Dynamic client registration based on the discovered metadata */ - const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName); + const redirectUri = this.getDefaultRedirectUri(serverName); logger.debug(`[MCPOAuth] Registering OAuth client with redirect URI: ${redirectUri}`); const clientInfo = await this.registerOAuthClient( @@ -672,6 +692,24 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } + /** Validates an OAuth URL is not targeting a private/internal address */ + private static async validateOAuthUrl(url: string, fieldName: string): Promise { + let hostname: string; + try { + hostname = new URL(url).hostname; + } catch { + throw new Error(`Invalid OAuth ${fieldName}: ${sanitizeUrlForLogging(url)}`); + } + + if (isSSRFTarget(hostname)) { + throw new Error(`OAuth ${fieldName} targets a blocked address`); + } + + if (await resolveHostnameSSRF(hostname)) { + throw new Error(`OAuth ${fieldName} resolves to a private IP address`); + } + } + private static readonly STATE_MAP_TYPE = 'mcp_oauth_state'; /** @@ -783,10 +821,10 @@ export class MCPOAuthHandler { scope: metadata.clientInfo.scope, }); - /** Use the stored client information and metadata to determine the token URL */ let tokenUrl: string; let authMethods: string[] | undefined; if (config?.token_url) { + await this.validateOAuthUrl(config.token_url, 'token_url'); tokenUrl = config.token_url; authMethods = config.token_endpoint_auth_methods_supported; } else if (!metadata.serverUrl) { @@ -813,6 +851,7 @@ export class MCPOAuthHandler { tokenUrl = oauthMetadata.token_endpoint; authMethods = oauthMetadata.token_endpoint_auth_methods_supported; } + await this.validateOAuthUrl(tokenUrl, 'token_url'); } const body = new URLSearchParams({ @@ -886,10 +925,10 @@ export class MCPOAuthHandler { return this.processRefreshResponse(tokens, metadata.serverName, 'stored client info'); } - // Fallback: If we have pre-configured OAuth settings, use them if (config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`); + await this.validateOAuthUrl(config.token_url, 'token_url'); const tokenUrl = new URL(config.token_url); const body = new URLSearchParams({ @@ -987,6 +1026,7 @@ export class MCPOAuthHandler { } else { tokenUrl = new URL(oauthMetadata.token_endpoint); } + await this.validateOAuthUrl(tokenUrl.href, 'token_url'); const body = new URLSearchParams({ grant_type: 'refresh_token', @@ -1036,7 +1076,9 @@ export class MCPOAuthHandler { }, oauthHeaders: Record = {}, ): Promise { - // build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided + if (metadata.revocationEndpoint != null) { + await this.validateOAuthUrl(metadata.revocationEndpoint, 'revocation_endpoint'); + } const revokeUrl: URL = metadata.revocationEndpoint != null ? new URL(metadata.revocationEndpoint) diff --git a/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts b/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts index 1c755ae0f0..38ed51cd99 100644 --- a/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts +++ b/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts @@ -1456,4 +1456,102 @@ describe('ServerConfigsDB', () => { expect(retrieved?.apiKey?.key).toBeUndefined(); }); }); + + describe('DB layer returns decrypted secrets (redaction is at controller layer)', () => { + it('should return decrypted apiKey.key to VIEW-only user via get()', async () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Secret API Key Server', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'admin-secret-api-key', + }, + }; + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.get(created.serverName, userId2); + expect(result).toBeDefined(); + expect(result?.apiKey?.key).toBe('admin-secret-api-key'); + }); + + it('should return decrypted oauth.client_secret to VIEW-only user via get()', async () => { + const config = createSSEConfig('Secret OAuth Server', 'Test', { + client_id: 'my-client-id', + client_secret: 'admin-oauth-secret', + }); + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.get(created.serverName, userId2); + expect(result).toBeDefined(); + expect(result?.oauth?.client_secret).toBe('admin-oauth-secret'); + }); + + it('should return decrypted secrets to VIEW-only user via getAll()', async () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Shared Secret Server', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'shared-api-key', + }, + oauth: { + client_id: 'shared-client', + client_secret: 'shared-oauth-secret', + }, + }; + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.getAll(userId2); + const serverConfig = result[created.serverName]; + expect(serverConfig).toBeDefined(); + expect(serverConfig?.apiKey?.key).toBe('shared-api-key'); + expect(serverConfig?.oauth?.client_secret).toBe('shared-oauth-secret'); + }); + }); }); diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index fddebb9db3..c517388a76 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -1,6 +1,66 @@ import { Constants } from 'librechat-data-provider'; +import type { ParsedServerConfig } from '~/mcp/types'; export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); + +/** + * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; + * new fields added to ParsedServerConfig are excluded by default until allowlisted here. + * + * URLs are returned as-is: DB-stored configs reject ${VAR} patterns at validation time + * (MCPServerUserInputSchema), and YAML configs are admin-managed. Env variable resolution + * is handled at the schema/input boundary, not the output boundary. + */ +export function redactServerSecrets(config: ParsedServerConfig): Partial { + const safe: Partial = { + type: config.type, + url: config.url, + title: config.title, + description: config.description, + iconPath: config.iconPath, + chatMenu: config.chatMenu, + requiresOAuth: config.requiresOAuth, + capabilities: config.capabilities, + tools: config.tools, + toolFunctions: config.toolFunctions, + initDuration: config.initDuration, + updatedAt: config.updatedAt, + dbId: config.dbId, + consumeOnly: config.consumeOnly, + inspectionFailed: config.inspectionFailed, + customUserVars: config.customUserVars, + serverInstructions: config.serverInstructions, + }; + + if (config.apiKey) { + safe.apiKey = { + source: config.apiKey.source, + authorization_type: config.apiKey.authorization_type, + ...(config.apiKey.custom_header && { custom_header: config.apiKey.custom_header }), + }; + } + + if (config.oauth) { + const { client_secret: _secret, ...safeOAuth } = config.oauth; + safe.oauth = safeOAuth; + } + + return Object.fromEntries( + Object.entries(safe).filter(([, v]) => v !== undefined), + ) as Partial; +} + +/** Applies allowlist-based sanitization to a map of server configs. */ +export function redactAllServerSecrets( + configs: Record, +): Record> { + const result: Record> = {}; + for (const [key, config] of Object.entries(configs)) { + result[key] = redactServerSecrets(config); + } + return result; +} + /** * Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$ * This is required for Azure OpenAI models with Tool Calling