mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
📮 feat: Custom OAuth Headers Support for MCP Server Config (#10014)
Some checks failed
Sync Locize Translations & Create Translation PR / Sync Translation Keys with Locize (push) Waiting to run
Sync Locize Translations & Create Translation PR / Create Translation PR on Version Published (push) Blocked by required conditions
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Has been cancelled
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Has been cancelled
Docker Dev Images Build / build (Dockerfile, librechat-dev, node) (push) Has been cancelled
Docker Dev Images Build / build (Dockerfile.multi, librechat-dev-api, api-build) (push) Has been cancelled
Some checks failed
Sync Locize Translations & Create Translation PR / Sync Translation Keys with Locize (push) Waiting to run
Sync Locize Translations & Create Translation PR / Create Translation PR on Version Published (push) Blocked by required conditions
Docker Dev Branch Images Build / build (Dockerfile, lc-dev, node) (push) Has been cancelled
Docker Dev Branch Images Build / build (Dockerfile.multi, lc-dev-api, api-build) (push) Has been cancelled
Docker Dev Images Build / build (Dockerfile, librechat-dev, node) (push) Has been cancelled
Docker Dev Images Build / build (Dockerfile.multi, librechat-dev-api, api-build) (push) Has been cancelled
* add oauth_headers field to mcp options * wrap fetch to pass oauth headers * fix order * consolidate headers passing * fix tests
This commit is contained in:
parent
cbd217efae
commit
5ce67b5b71
8 changed files with 304 additions and 35 deletions
|
|
@ -327,16 +327,23 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||||
const revocationEndpointAuthMethodsSupported =
|
const revocationEndpointAuthMethodsSupported =
|
||||||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||||
|
const oauthHeaders = serverConfig.oauth_headers ?? {};
|
||||||
|
|
||||||
if (tokens?.access_token) {
|
if (tokens?.access_token) {
|
||||||
try {
|
try {
|
||||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
|
await MCPOAuthHandler.revokeOAuthToken(
|
||||||
|
serverName,
|
||||||
|
tokens.access_token,
|
||||||
|
'access',
|
||||||
|
{
|
||||||
serverUrl: serverConfig.url,
|
serverUrl: serverConfig.url,
|
||||||
clientId: clientInfo.client_id,
|
clientId: clientInfo.client_id,
|
||||||
clientSecret: clientInfo.client_secret ?? '',
|
clientSecret: clientInfo.client_secret ?? '',
|
||||||
revocationEndpoint,
|
revocationEndpoint,
|
||||||
revocationEndpointAuthMethodsSupported,
|
revocationEndpointAuthMethodsSupported,
|
||||||
});
|
},
|
||||||
|
oauthHeaders,
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||||
}
|
}
|
||||||
|
|
@ -344,13 +351,19 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||||
|
|
||||||
if (tokens?.refresh_token) {
|
if (tokens?.refresh_token) {
|
||||||
try {
|
try {
|
||||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
|
await MCPOAuthHandler.revokeOAuthToken(
|
||||||
|
serverName,
|
||||||
|
tokens.refresh_token,
|
||||||
|
'refresh',
|
||||||
|
{
|
||||||
serverUrl: serverConfig.url,
|
serverUrl: serverConfig.url,
|
||||||
clientId: clientInfo.client_id,
|
clientId: clientInfo.client_id,
|
||||||
clientSecret: clientInfo.client_secret ?? '',
|
clientSecret: clientInfo.client_secret ?? '',
|
||||||
revocationEndpoint,
|
revocationEndpoint,
|
||||||
revocationEndpointAuthMethodsSupported,
|
revocationEndpointAuthMethodsSupported,
|
||||||
});
|
},
|
||||||
|
oauthHeaders,
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -127,8 +127,13 @@ describe('MCP Routes', () => {
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const mockMcpManager = {
|
||||||
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
getLogStores.mockReturnValue({});
|
getLogStores.mockReturnValue({});
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
||||||
authorizationUrl: 'https://oauth.example.com/auth',
|
authorizationUrl: 'https://oauth.example.com/auth',
|
||||||
|
|
@ -146,6 +151,7 @@ describe('MCP Routes', () => {
|
||||||
'test-server',
|
'test-server',
|
||||||
'https://test-server.com',
|
'https://test-server.com',
|
||||||
'test-user-id',
|
'test-user-id',
|
||||||
|
{},
|
||||||
{ clientId: 'test-client-id' },
|
{ clientId: 'test-client-id' },
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
@ -314,6 +320,7 @@ describe('MCP Routes', () => {
|
||||||
};
|
};
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||||
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
};
|
};
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
|
|
@ -336,6 +343,7 @@ describe('MCP Routes', () => {
|
||||||
'test-flow-id',
|
'test-flow-id',
|
||||||
'test-auth-code',
|
'test-auth-code',
|
||||||
mockFlowManager,
|
mockFlowManager,
|
||||||
|
{},
|
||||||
);
|
);
|
||||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
|
|
@ -392,6 +400,11 @@ describe('MCP Routes', () => {
|
||||||
getLogStores.mockReturnValue({});
|
getLogStores.mockReturnValue({});
|
||||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
const mockMcpManager = {
|
||||||
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
|
};
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||||
code: 'test-auth-code',
|
code: 'test-auth-code',
|
||||||
state: 'test-flow-id',
|
state: 'test-flow-id',
|
||||||
|
|
@ -427,6 +440,7 @@ describe('MCP Routes', () => {
|
||||||
|
|
||||||
const mockMcpManager = {
|
const mockMcpManager = {
|
||||||
getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')),
|
getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')),
|
||||||
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
};
|
};
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
|
|
@ -1234,6 +1248,7 @@ describe('MCP Routes', () => {
|
||||||
getUserConnection: jest.fn().mockResolvedValue({
|
getUserConnection: jest.fn().mockResolvedValue({
|
||||||
fetchTools: jest.fn().mockResolvedValue([]),
|
fetchTools: jest.fn().mockResolvedValue([]),
|
||||||
}),
|
}),
|
||||||
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
};
|
};
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
|
|
@ -1281,6 +1296,7 @@ describe('MCP Routes', () => {
|
||||||
.fn()
|
.fn()
|
||||||
.mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]),
|
.mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]),
|
||||||
}),
|
}),
|
||||||
|
getRawConfig: jest.fn().mockReturnValue({}),
|
||||||
};
|
};
|
||||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
||||||
serverName,
|
serverName,
|
||||||
serverUrl,
|
serverUrl,
|
||||||
userId,
|
userId,
|
||||||
|
getOAuthHeaders(serverName),
|
||||||
oauthConfig,
|
oauthConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -132,7 +133,12 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||||
});
|
});
|
||||||
|
|
||||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
const tokens = await MCPOAuthHandler.completeOAuthFlow(
|
||||||
|
flowId,
|
||||||
|
code,
|
||||||
|
flowManager,
|
||||||
|
getOAuthHeaders(serverName),
|
||||||
|
);
|
||||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||||
|
|
||||||
/** Persist tokens immediately so reconnection uses fresh credentials */
|
/** Persist tokens immediately so reconnection uses fresh credentials */
|
||||||
|
|
@ -538,4 +544,10 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
function getOAuthHeaders(serverName) {
|
||||||
|
const mcpManager = getMCPManager();
|
||||||
|
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||||
|
return serverConfig?.oauth_headers ?? {};
|
||||||
|
}
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|
|
||||||
|
|
@ -142,6 +142,7 @@ export class MCPConnectionFactory {
|
||||||
serverName: metadata.serverName,
|
serverName: metadata.serverName,
|
||||||
clientInfo: metadata.clientInfo,
|
clientInfo: metadata.clientInfo,
|
||||||
},
|
},
|
||||||
|
this.serverConfig.oauth_headers ?? {},
|
||||||
this.serverConfig.oauth,
|
this.serverConfig.oauth,
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
@ -161,6 +162,7 @@ export class MCPConnectionFactory {
|
||||||
this.serverName,
|
this.serverName,
|
||||||
data.serverUrl || '',
|
data.serverUrl || '',
|
||||||
this.userId!,
|
this.userId!,
|
||||||
|
config?.oauth_headers ?? {},
|
||||||
config?.oauth,
|
config?.oauth,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -358,6 +360,7 @@ export class MCPConnectionFactory {
|
||||||
this.serverName,
|
this.serverName,
|
||||||
serverUrl,
|
serverUrl,
|
||||||
this.userId!,
|
this.userId!,
|
||||||
|
this.serverConfig.oauth_headers ?? {},
|
||||||
this.serverConfig.oauth,
|
this.serverConfig.oauth,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -255,6 +255,7 @@ describe('MCPConnectionFactory', () => {
|
||||||
'test-server',
|
'test-server',
|
||||||
'https://api.example.com',
|
'https://api.example.com',
|
||||||
'user123',
|
'user123',
|
||||||
|
{},
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
|
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import type { MCPOptions } from 'librechat-data-provider';
|
import type { MCPOptions } from 'librechat-data-provider';
|
||||||
import type { AuthorizationServerMetadata } from '@modelcontextprotocol/sdk/shared/auth.js';
|
import type { AuthorizationServerMetadata } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||||
import { MCPOAuthHandler } from '~/mcp/oauth';
|
import { MCPOAuthFlowMetadata, MCPOAuthHandler, MCPOAuthTokens } from '~/mcp/oauth';
|
||||||
|
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
logger: {
|
logger: {
|
||||||
|
|
@ -14,18 +14,33 @@ jest.mock('@librechat/data-schemas', () => ({
|
||||||
jest.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
|
jest.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
|
||||||
startAuthorization: jest.fn(),
|
startAuthorization: jest.fn(),
|
||||||
discoverAuthorizationServerMetadata: jest.fn(),
|
discoverAuthorizationServerMetadata: jest.fn(),
|
||||||
|
discoverOAuthProtectedResourceMetadata: jest.fn(),
|
||||||
|
registerClient: jest.fn(),
|
||||||
|
exchangeAuthorization: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
import {
|
import {
|
||||||
startAuthorization,
|
startAuthorization,
|
||||||
discoverAuthorizationServerMetadata,
|
discoverAuthorizationServerMetadata,
|
||||||
|
discoverOAuthProtectedResourceMetadata,
|
||||||
|
registerClient,
|
||||||
|
exchangeAuthorization,
|
||||||
} from '@modelcontextprotocol/sdk/client/auth.js';
|
} from '@modelcontextprotocol/sdk/client/auth.js';
|
||||||
|
import { FlowStateManager } from '../../flow/manager';
|
||||||
|
|
||||||
const mockStartAuthorization = startAuthorization as jest.MockedFunction<typeof startAuthorization>;
|
const mockStartAuthorization = startAuthorization as jest.MockedFunction<typeof startAuthorization>;
|
||||||
const mockDiscoverAuthorizationServerMetadata =
|
const mockDiscoverAuthorizationServerMetadata =
|
||||||
discoverAuthorizationServerMetadata as jest.MockedFunction<
|
discoverAuthorizationServerMetadata as jest.MockedFunction<
|
||||||
typeof discoverAuthorizationServerMetadata
|
typeof discoverAuthorizationServerMetadata
|
||||||
>;
|
>;
|
||||||
|
const mockDiscoverOAuthProtectedResourceMetadata =
|
||||||
|
discoverOAuthProtectedResourceMetadata as jest.MockedFunction<
|
||||||
|
typeof discoverOAuthProtectedResourceMetadata
|
||||||
|
>;
|
||||||
|
const mockRegisterClient = registerClient as jest.MockedFunction<typeof registerClient>;
|
||||||
|
const mockExchangeAuthorization = exchangeAuthorization as jest.MockedFunction<
|
||||||
|
typeof exchangeAuthorization
|
||||||
|
>;
|
||||||
|
|
||||||
describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
const mockServerName = 'test-server';
|
const mockServerName = 'test-server';
|
||||||
|
|
@ -60,6 +75,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
mockServerName,
|
mockServerName,
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
mockUserId,
|
mockUserId,
|
||||||
|
{},
|
||||||
baseConfig,
|
baseConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -82,7 +98,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
grant_types_supported: ['authorization_code'],
|
grant_types_supported: ['authorization_code'],
|
||||||
};
|
};
|
||||||
|
|
||||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
mockServerName,
|
||||||
|
mockServerUrl,
|
||||||
|
mockUserId,
|
||||||
|
{},
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
|
|
@ -100,7 +122,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||||
};
|
};
|
||||||
|
|
||||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
mockServerName,
|
||||||
|
mockServerUrl,
|
||||||
|
mockUserId,
|
||||||
|
{},
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
|
|
@ -118,7 +146,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
response_types_supported: ['code', 'token'],
|
response_types_supported: ['code', 'token'],
|
||||||
};
|
};
|
||||||
|
|
||||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
mockServerName,
|
||||||
|
mockServerUrl,
|
||||||
|
mockUserId,
|
||||||
|
{},
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
|
|
@ -136,7 +170,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
code_challenge_methods_supported: ['S256'],
|
code_challenge_methods_supported: ['S256'],
|
||||||
};
|
};
|
||||||
|
|
||||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
mockServerName,
|
||||||
|
mockServerUrl,
|
||||||
|
mockUserId,
|
||||||
|
{},
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
|
|
@ -157,7 +197,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
code_challenge_methods_supported: ['S256'],
|
code_challenge_methods_supported: ['S256'],
|
||||||
};
|
};
|
||||||
|
|
||||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
mockServerName,
|
||||||
|
mockServerUrl,
|
||||||
|
mockUserId,
|
||||||
|
{},
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
|
|
@ -181,7 +227,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
code_challenge_methods_supported: [],
|
code_challenge_methods_supported: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config);
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
mockServerName,
|
||||||
|
mockServerUrl,
|
||||||
|
mockUserId,
|
||||||
|
{},
|
||||||
|
config,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
expect(mockStartAuthorization).toHaveBeenCalledWith(
|
||||||
mockServerUrl,
|
mockServerUrl,
|
||||||
|
|
@ -251,7 +303,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
}),
|
}),
|
||||||
} as Response);
|
} as Response);
|
||||||
|
|
||||||
const result = await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
const result = await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||||
|
|
||||||
// Verify the call was made without Authorization header
|
// Verify the call was made without Authorization header
|
||||||
expect(mockFetch).toHaveBeenCalledWith(
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
|
@ -314,7 +366,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
}),
|
}),
|
||||||
} as Response);
|
} as Response);
|
||||||
|
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||||
|
|
||||||
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
||||||
expect(mockFetch).toHaveBeenCalledWith(
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
|
@ -363,7 +415,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
}),
|
}),
|
||||||
} as Response);
|
} as Response);
|
||||||
|
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||||
|
|
||||||
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
||||||
expect(mockFetch).toHaveBeenCalledWith(
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
|
@ -410,7 +462,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
}),
|
}),
|
||||||
} as Response);
|
} as Response);
|
||||||
|
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||||
|
|
||||||
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`;
|
||||||
expect(mockFetch).toHaveBeenCalledWith(
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
|
@ -457,7 +509,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
}),
|
}),
|
||||||
} as Response);
|
} as Response);
|
||||||
|
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata);
|
await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {});
|
||||||
|
|
||||||
// Verify the call was made without Authorization header
|
// Verify the call was made without Authorization header
|
||||||
expect(mockFetch).toHaveBeenCalledWith(
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
|
@ -498,6 +550,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(
|
await MCPOAuthHandler.refreshOAuthTokens(
|
||||||
mockRefreshToken,
|
mockRefreshToken,
|
||||||
{ serverName: 'test-server' },
|
{ serverName: 'test-server' },
|
||||||
|
{},
|
||||||
config,
|
config,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -539,6 +592,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(
|
await MCPOAuthHandler.refreshOAuthTokens(
|
||||||
mockRefreshToken,
|
mockRefreshToken,
|
||||||
{ serverName: 'test-server' },
|
{ serverName: 'test-server' },
|
||||||
|
{},
|
||||||
config,
|
config,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -575,6 +629,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
await MCPOAuthHandler.refreshOAuthTokens(
|
await MCPOAuthHandler.refreshOAuthTokens(
|
||||||
mockRefreshToken,
|
mockRefreshToken,
|
||||||
{ serverName: 'test-server' },
|
{ serverName: 'test-server' },
|
||||||
|
{},
|
||||||
config,
|
config,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -617,7 +672,9 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
'{"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}',
|
'{"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}',
|
||||||
} as Response);
|
} as Response);
|
||||||
|
|
||||||
await expect(MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata)).rejects.toThrow(
|
await expect(
|
||||||
|
MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}),
|
||||||
|
).rejects.toThrow(
|
||||||
'Token refresh failed: 400 Bad Request - {"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}',
|
'Token refresh failed: 400 Bad Request - {"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
@ -813,4 +870,126 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Custom OAuth Headers', () => {
|
||||||
|
const originalFetch = global.fetch;
|
||||||
|
const mockFetch = jest.fn();
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
global.fetch = mockFetch as unknown as typeof fetch;
|
||||||
|
mockFetch.mockResolvedValue({ ok: true, json: async () => ({}) } as Response);
|
||||||
|
mockDiscoverAuthorizationServerMetadata.mockResolvedValue({
|
||||||
|
issuer: 'http://example.com',
|
||||||
|
authorization_endpoint: 'http://example.com/auth',
|
||||||
|
token_endpoint: 'http://example.com/token',
|
||||||
|
response_types_supported: ['code'],
|
||||||
|
} as AuthorizationServerMetadata);
|
||||||
|
mockStartAuthorization.mockResolvedValue({
|
||||||
|
authorizationUrl: new URL('http://example.com/auth'),
|
||||||
|
codeVerifier: 'test-verifier',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
global.fetch = originalFetch;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('passes headers to client registration', async () => {
|
||||||
|
mockRegisterClient.mockImplementation(async (_, options) => {
|
||||||
|
await options.fetchFn?.('http://example.com/register', {});
|
||||||
|
return { client_id: 'test', redirect_uris: [] };
|
||||||
|
});
|
||||||
|
|
||||||
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
'test-server',
|
||||||
|
'http://example.com',
|
||||||
|
'user-123',
|
||||||
|
{ foo: 'bar' },
|
||||||
|
{},
|
||||||
|
);
|
||||||
|
|
||||||
|
const headers = mockFetch.mock.calls[0][1]?.headers as Headers;
|
||||||
|
expect(headers.get('foo')).toBe('bar');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('passes headers to discovery operations', async () => {
|
||||||
|
mockDiscoverOAuthProtectedResourceMetadata.mockImplementation(async (_, __, fetchFn) => {
|
||||||
|
await fetchFn?.('http://example.com/.well-known/oauth-protected-resource', {});
|
||||||
|
return {
|
||||||
|
resource: 'http://example.com',
|
||||||
|
authorization_servers: ['http://auth.example.com'],
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
|
'test-server',
|
||||||
|
'http://example.com',
|
||||||
|
'user-123',
|
||||||
|
{ foo: 'bar' },
|
||||||
|
{},
|
||||||
|
);
|
||||||
|
|
||||||
|
const allHaveHeader = mockFetch.mock.calls.every((call) => {
|
||||||
|
const headers = call[1]?.headers as Headers;
|
||||||
|
return headers?.get('foo') === 'bar';
|
||||||
|
});
|
||||||
|
expect(allHaveHeader).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('passes headers to token exchange', async () => {
|
||||||
|
const mockFlowManager = {
|
||||||
|
getFlowState: jest.fn().mockResolvedValue({
|
||||||
|
status: 'PENDING',
|
||||||
|
metadata: {
|
||||||
|
serverName: 'test-server',
|
||||||
|
codeVerifier: 'test-verifier',
|
||||||
|
clientInfo: {},
|
||||||
|
metadata: {},
|
||||||
|
} as MCPOAuthFlowMetadata,
|
||||||
|
}),
|
||||||
|
completeFlow: jest.fn(),
|
||||||
|
} as unknown as FlowStateManager<MCPOAuthTokens>;
|
||||||
|
|
||||||
|
mockExchangeAuthorization.mockImplementation(async (_, options) => {
|
||||||
|
await options.fetchFn?.('http://example.com/token', {});
|
||||||
|
return { access_token: 'test-token', token_type: 'Bearer', expires_in: 3600 };
|
||||||
|
});
|
||||||
|
|
||||||
|
await MCPOAuthHandler.completeOAuthFlow('test-flow-id', 'test-auth-code', mockFlowManager, {
|
||||||
|
foo: 'bar',
|
||||||
|
});
|
||||||
|
|
||||||
|
const headers = mockFetch.mock.calls[0][1]?.headers as Headers;
|
||||||
|
expect(headers.get('foo')).toBe('bar');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('passes headers to token refresh', async () => {
|
||||||
|
mockDiscoverAuthorizationServerMetadata.mockImplementation(async (_, options) => {
|
||||||
|
await options?.fetchFn?.('http://example.com/.well-known/oauth-authorization-server', {});
|
||||||
|
return {
|
||||||
|
issuer: 'http://example.com',
|
||||||
|
token_endpoint: 'http://example.com/token',
|
||||||
|
} as AuthorizationServerMetadata;
|
||||||
|
});
|
||||||
|
|
||||||
|
await MCPOAuthHandler.refreshOAuthTokens(
|
||||||
|
'test-refresh-token',
|
||||||
|
{
|
||||||
|
serverName: 'test-server',
|
||||||
|
serverUrl: 'http://example.com',
|
||||||
|
clientInfo: { client_id: 'test-client', client_secret: 'test-secret' },
|
||||||
|
},
|
||||||
|
{ foo: 'bar' },
|
||||||
|
{},
|
||||||
|
);
|
||||||
|
|
||||||
|
const discoveryCall = mockFetch.mock.calls.find((call) =>
|
||||||
|
call[0].toString().includes('.well-known'),
|
||||||
|
);
|
||||||
|
expect(discoveryCall).toBeDefined();
|
||||||
|
const headers = discoveryCall![1]?.headers as Headers;
|
||||||
|
expect(headers.get('foo')).toBe('bar');
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import type {
|
||||||
OAuthMetadata,
|
OAuthMetadata,
|
||||||
} from './types';
|
} from './types';
|
||||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||||
|
import { FetchLike } from '@modelcontextprotocol/sdk/shared/transport';
|
||||||
|
|
||||||
/** Type for the OAuth metadata from the SDK */
|
/** Type for the OAuth metadata from the SDK */
|
||||||
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
||||||
|
|
@ -26,10 +27,29 @@ export class MCPOAuthHandler {
|
||||||
private static readonly FLOW_TYPE = 'mcp_oauth';
|
private static readonly FLOW_TYPE = 'mcp_oauth';
|
||||||
private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes
|
private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a fetch function with custom headers injected
|
||||||
|
*/
|
||||||
|
private static createOAuthFetch(headers: Record<string, string>): FetchLike {
|
||||||
|
return async (url: string | URL, init?: RequestInit): Promise<Response> => {
|
||||||
|
const newHeaders = new Headers(init?.headers ?? {});
|
||||||
|
for (const [key, value] of Object.entries(headers)) {
|
||||||
|
newHeaders.set(key, value);
|
||||||
|
}
|
||||||
|
return fetch(url, {
|
||||||
|
...init,
|
||||||
|
headers: newHeaders,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Discovers OAuth metadata from the server
|
* Discovers OAuth metadata from the server
|
||||||
*/
|
*/
|
||||||
private static async discoverMetadata(serverUrl: string): Promise<{
|
private static async discoverMetadata(
|
||||||
|
serverUrl: string,
|
||||||
|
oauthHeaders: Record<string, string>,
|
||||||
|
): Promise<{
|
||||||
metadata: OAuthMetadata;
|
metadata: OAuthMetadata;
|
||||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||||
authServerUrl: URL;
|
authServerUrl: URL;
|
||||||
|
|
@ -41,12 +61,14 @@ export class MCPOAuthHandler {
|
||||||
let authServerUrl = new URL(serverUrl);
|
let authServerUrl = new URL(serverUrl);
|
||||||
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
|
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
|
||||||
|
|
||||||
|
const fetchFn = this.createOAuthFetch(oauthHeaders);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Try to discover resource metadata first
|
// Try to discover resource metadata first
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[MCPOAuth] Attempting to discover protected resource metadata from ${serverUrl}`,
|
`[MCPOAuth] Attempting to discover protected resource metadata from ${serverUrl}`,
|
||||||
);
|
);
|
||||||
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl);
|
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn);
|
||||||
|
|
||||||
if (resourceMetadata?.authorization_servers?.length) {
|
if (resourceMetadata?.authorization_servers?.length) {
|
||||||
authServerUrl = new URL(resourceMetadata.authorization_servers[0]);
|
authServerUrl = new URL(resourceMetadata.authorization_servers[0]);
|
||||||
|
|
@ -66,7 +88,9 @@ export class MCPOAuthHandler {
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||||
);
|
);
|
||||||
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl);
|
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, {
|
||||||
|
fetchFn,
|
||||||
|
});
|
||||||
|
|
||||||
if (!rawMetadata) {
|
if (!rawMetadata) {
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -92,6 +116,7 @@ export class MCPOAuthHandler {
|
||||||
private static async registerOAuthClient(
|
private static async registerOAuthClient(
|
||||||
serverUrl: string,
|
serverUrl: string,
|
||||||
metadata: OAuthMetadata,
|
metadata: OAuthMetadata,
|
||||||
|
oauthHeaders: Record<string, string>,
|
||||||
resourceMetadata?: OAuthProtectedResourceMetadata,
|
resourceMetadata?: OAuthProtectedResourceMetadata,
|
||||||
redirectUri?: string,
|
redirectUri?: string,
|
||||||
): Promise<OAuthClientInformation> {
|
): Promise<OAuthClientInformation> {
|
||||||
|
|
@ -159,6 +184,7 @@ export class MCPOAuthHandler {
|
||||||
const clientInfo = await registerClient(serverUrl, {
|
const clientInfo = await registerClient(serverUrl, {
|
||||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||||
clientMetadata,
|
clientMetadata,
|
||||||
|
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||||
});
|
});
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -181,7 +207,8 @@ export class MCPOAuthHandler {
|
||||||
serverName: string,
|
serverName: string,
|
||||||
serverUrl: string,
|
serverUrl: string,
|
||||||
userId: string,
|
userId: string,
|
||||||
config: MCPOptions['oauth'] | undefined,
|
oauthHeaders: Record<string, string>,
|
||||||
|
config?: MCPOptions['oauth'],
|
||||||
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
|
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||||
|
|
@ -259,7 +286,10 @@ export class MCPOAuthHandler {
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
|
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
|
||||||
);
|
);
|
||||||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
|
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(
|
||||||
|
serverUrl,
|
||||||
|
oauthHeaders,
|
||||||
|
);
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
|
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||||
|
|
@ -272,6 +302,7 @@ export class MCPOAuthHandler {
|
||||||
const clientInfo = await this.registerOAuthClient(
|
const clientInfo = await this.registerOAuthClient(
|
||||||
authServerUrl.toString(),
|
authServerUrl.toString(),
|
||||||
metadata,
|
metadata,
|
||||||
|
oauthHeaders,
|
||||||
resourceMetadata,
|
resourceMetadata,
|
||||||
redirectUri,
|
redirectUri,
|
||||||
);
|
);
|
||||||
|
|
@ -365,6 +396,7 @@ export class MCPOAuthHandler {
|
||||||
flowId: string,
|
flowId: string,
|
||||||
authorizationCode: string,
|
authorizationCode: string,
|
||||||
flowManager: FlowStateManager<MCPOAuthTokens>,
|
flowManager: FlowStateManager<MCPOAuthTokens>,
|
||||||
|
oauthHeaders: Record<string, string>,
|
||||||
): Promise<MCPOAuthTokens> {
|
): Promise<MCPOAuthTokens> {
|
||||||
try {
|
try {
|
||||||
/** Flow state which contains our metadata */
|
/** Flow state which contains our metadata */
|
||||||
|
|
@ -404,6 +436,7 @@ export class MCPOAuthHandler {
|
||||||
codeVerifier: metadata.codeVerifier,
|
codeVerifier: metadata.codeVerifier,
|
||||||
authorizationCode,
|
authorizationCode,
|
||||||
resource,
|
resource,
|
||||||
|
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||||
});
|
});
|
||||||
|
|
||||||
logger.debug('[MCPOAuth] Raw tokens from exchange:', {
|
logger.debug('[MCPOAuth] Raw tokens from exchange:', {
|
||||||
|
|
@ -476,6 +509,7 @@ export class MCPOAuthHandler {
|
||||||
static async refreshOAuthTokens(
|
static async refreshOAuthTokens(
|
||||||
refreshToken: string,
|
refreshToken: string,
|
||||||
metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation },
|
metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation },
|
||||||
|
oauthHeaders: Record<string, string>,
|
||||||
config?: MCPOptions['oauth'],
|
config?: MCPOptions['oauth'],
|
||||||
): Promise<MCPOAuthTokens> {
|
): Promise<MCPOAuthTokens> {
|
||||||
logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`);
|
logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`);
|
||||||
|
|
@ -509,7 +543,9 @@ export class MCPOAuthHandler {
|
||||||
throw new Error('No token URL available for refresh');
|
throw new Error('No token URL available for refresh');
|
||||||
} else {
|
} else {
|
||||||
/** Auto-discover OAuth configuration for refresh */
|
/** Auto-discover OAuth configuration for refresh */
|
||||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl);
|
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
|
||||||
|
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||||
|
});
|
||||||
if (!oauthMetadata) {
|
if (!oauthMetadata) {
|
||||||
throw new Error('Failed to discover OAuth metadata for token refresh');
|
throw new Error('Failed to discover OAuth metadata for token refresh');
|
||||||
}
|
}
|
||||||
|
|
@ -533,6 +569,7 @@ export class MCPOAuthHandler {
|
||||||
const headers: HeadersInit = {
|
const headers: HeadersInit = {
|
||||||
'Content-Type': 'application/x-www-form-urlencoded',
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
|
...oauthHeaders,
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Handle authentication based on server's advertised methods */
|
/** Handle authentication based on server's advertised methods */
|
||||||
|
|
@ -613,6 +650,7 @@ export class MCPOAuthHandler {
|
||||||
const headers: HeadersInit = {
|
const headers: HeadersInit = {
|
||||||
'Content-Type': 'application/x-www-form-urlencoded',
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
|
...oauthHeaders,
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Handle authentication based on configured methods */
|
/** Handle authentication based on configured methods */
|
||||||
|
|
@ -684,7 +722,9 @@ export class MCPOAuthHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Auto-discover OAuth configuration for refresh */
|
/** Auto-discover OAuth configuration for refresh */
|
||||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl);
|
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
|
||||||
|
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||||
|
});
|
||||||
|
|
||||||
if (!oauthMetadata?.token_endpoint) {
|
if (!oauthMetadata?.token_endpoint) {
|
||||||
throw new Error('No token endpoint found in OAuth metadata');
|
throw new Error('No token endpoint found in OAuth metadata');
|
||||||
|
|
@ -700,6 +740,7 @@ export class MCPOAuthHandler {
|
||||||
const headers: HeadersInit = {
|
const headers: HeadersInit = {
|
||||||
'Content-Type': 'application/x-www-form-urlencoded',
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
Accept: 'application/json',
|
Accept: 'application/json',
|
||||||
|
...oauthHeaders,
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(tokenUrl, {
|
const response = await fetch(tokenUrl, {
|
||||||
|
|
@ -742,6 +783,7 @@ export class MCPOAuthHandler {
|
||||||
revocationEndpoint?: string;
|
revocationEndpoint?: string;
|
||||||
revocationEndpointAuthMethodsSupported?: string[];
|
revocationEndpointAuthMethodsSupported?: string[];
|
||||||
},
|
},
|
||||||
|
oauthHeaders: Record<string, string> = {},
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
// build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided
|
// build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided
|
||||||
const revokeUrl: URL =
|
const revokeUrl: URL =
|
||||||
|
|
@ -759,6 +801,7 @@ export class MCPOAuthHandler {
|
||||||
// init the request headers
|
// init the request headers
|
||||||
const headers: Record<string, string> = {
|
const headers: Record<string, string> = {
|
||||||
'Content-Type': 'application/x-www-form-urlencoded',
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
|
...oauthHeaders,
|
||||||
};
|
};
|
||||||
|
|
||||||
// init the request body
|
// init the request body
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,8 @@ const BaseOptionsSchema = z.object({
|
||||||
revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(),
|
revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(),
|
||||||
})
|
})
|
||||||
.optional(),
|
.optional(),
|
||||||
|
/** Custom headers to send with OAuth requests (registration, discovery, token exchange, etc.) */
|
||||||
|
oauth_headers: z.record(z.string(), z.string()).optional(),
|
||||||
customUserVars: z
|
customUserVars: z
|
||||||
.record(
|
.record(
|
||||||
z.string(),
|
z.string(),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue