mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
🔌 feat: Revoke MCP OAuth Credentials (#9464)
* revocation metadata fields * store metadata * get client info and meta * revoke oauth tokens * delete flow * uninstall oauth mcp * revoke button * revoke oauth refactor, add comments, test * adjust for clarity * test deleteFlow * handle metadata type * no mutation * adjust for clarity * styling * restructure for clarity * move token-specific stuff * use mcpmanager's oauth servers * fix typo * fix addressing of oauth prop * log prefix * remove debug log
This commit is contained in:
parent
5667cc9702
commit
04c3a5a861
12 changed files with 725 additions and 6 deletions
|
@ -1,5 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
|
||||
const {
|
||||
webSearchKeys,
|
||||
extractWebSearchEnvVars,
|
||||
normalizeHttpError,
|
||||
MCPTokenStorage,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
updateUser,
|
||||
|
@ -16,11 +21,17 @@ const { verifyEmail, resendVerificationEmail } = require('~/server/services/Auth
|
|||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { Tools, Constants, FileSources } = require('librechat-data-provider');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User } = require('~/db/models');
|
||||
const { Transaction, Balance, User, Token } = require('~/db/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { deleteAllSharedLinks } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { clearMCPServerTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { findToken } = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
|
@ -162,6 +173,15 @@ const updateUserPluginsController = async (req, res) => {
|
|||
);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
}
|
||||
try {
|
||||
// if the MCP server uses OAuth, perform a full cleanup and token revocation
|
||||
await maybeUninstallOAuthMCP(user.id, pluginKey, appConfig);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[updateUserPluginsController] Error uninstalling OAuth MCP for ${pluginKey}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// This handles:
|
||||
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
|
||||
|
@ -269,6 +289,97 @@ const resendVerificationController = async (req, res) => {
|
|||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* OAuth MCP specific uninstall logic
|
||||
*/
|
||||
const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||
if (!pluginKey.startsWith(Constants.mcp_prefix)) {
|
||||
// this is not an MCP server, so nothing to do here
|
||||
return;
|
||||
}
|
||||
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!mcpManager.getOAuthServers().has(serverName)) {
|
||||
// this server does not use OAuth, so nothing to do here as well
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. get client info used for revocation (client id, secret)
|
||||
const clientTokenData = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
});
|
||||
if (clientTokenData == null) {
|
||||
return;
|
||||
}
|
||||
const { clientInfo, clientMetadata } = clientTokenData;
|
||||
|
||||
// 2. get decrypted tokens before deletion
|
||||
const tokens = await MCPTokenStorage.getTokens({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
});
|
||||
|
||||
// 3. revoke OAuth tokens at the provider
|
||||
const revocationEndpoint =
|
||||
serverConfig.oauth?.revocation_endpoint ?? clientMetadata.revocation_endpoint;
|
||||
const revocationEndpointAuthMethodsSupported =
|
||||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||
|
||||
if (tokens?.access_token) {
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
|
||||
serverUrl: serverConfig.url,
|
||||
clientId: clientInfo.client_id,
|
||||
clientSecret: clientInfo.client_secret ?? '',
|
||||
revocationEndpoint,
|
||||
revocationEndpointAuthMethodsSupported,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens?.refresh_token) {
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
|
||||
serverUrl: serverConfig.url,
|
||||
clientId: clientInfo.client_id,
|
||||
clientSecret: clientInfo.client_secret ?? '',
|
||||
revocationEndpoint,
|
||||
revocationEndpointAuthMethodsSupported,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. delete tokens from the DB after revocation attempts
|
||||
await MCPTokenStorage.deleteUserTokens({
|
||||
userId,
|
||||
serverName,
|
||||
deleteToken: async (filter) => {
|
||||
await Token.deleteOne(filter);
|
||||
},
|
||||
});
|
||||
|
||||
// 5. clear the flow state for the OAuth tokens
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
// 6. clear the tools cache for the server
|
||||
await clearMCPServerTools({ userId, serverName });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getUserController,
|
||||
getTermsStatusController,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import React, { useState, useMemo, useCallback } from 'react';
|
||||
import { ChevronLeft } from 'lucide-react';
|
||||
import { ChevronLeft, Trash2 } from 'lucide-react';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { Button, useToastContext } from '@librechat/client';
|
||||
import { Constants, QueryKeys } from 'librechat-data-provider';
|
||||
|
@ -123,6 +123,7 @@ function MCPPanelContent() {
|
|||
}
|
||||
|
||||
const serverStatus = connectionStatus?.[selectedServerNameForEditing];
|
||||
const isConnected = serverStatus?.connectionState === 'connected';
|
||||
|
||||
return (
|
||||
<div className="h-auto max-w-full space-y-4 overflow-x-hidden py-2">
|
||||
|
@ -159,6 +160,17 @@ function MCPPanelContent() {
|
|||
Object.keys(serverBeingEdited.config.customUserVars).length > 0
|
||||
}
|
||||
/>
|
||||
{serverStatus?.requiresOAuth && isConnected && (
|
||||
<Button
|
||||
className="w-full"
|
||||
size="sm"
|
||||
variant="destructive"
|
||||
onClick={() => handleConfigRevoke(selectedServerNameForEditing)}
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
{localize('com_ui_oauth_revoke')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
|
|
|
@ -1046,6 +1046,7 @@
|
|||
"com_ui_oauth_error_title": "Authentication Failed",
|
||||
"com_ui_oauth_success_description": "Your authentication was successful. This window will close in",
|
||||
"com_ui_oauth_success_title": "Authentication Successful",
|
||||
"com_ui_oauth_revoke": "Revoke",
|
||||
"com_ui_of": "of",
|
||||
"com_ui_off": "Off",
|
||||
"com_ui_offline": "Offline",
|
||||
|
|
|
@ -149,4 +149,36 @@ describe('FlowStateManager', () => {
|
|||
await expect(flowPromise).rejects.toThrow('failure');
|
||||
}, 15000);
|
||||
});
|
||||
|
||||
describe('deleteFlow', () => {
|
||||
const flowId = 'test-flow-123';
|
||||
const type = 'test-type';
|
||||
const flowKey = `${type}:${flowId}`;
|
||||
|
||||
it('deletes an existing flow', async () => {
|
||||
await store.set(flowKey, { type, status: 'PENDING', metadata: {}, createdAt: Date.now() });
|
||||
expect(await store.get(flowKey)).toBeDefined();
|
||||
|
||||
const result = await flowManager.deleteFlow(flowId, type);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(await store.get(flowKey)).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns false if the deletion errors', async () => {
|
||||
jest.spyOn(store, 'delete').mockRejectedValue(new Error('Deletion failed'));
|
||||
|
||||
const result = await flowManager.deleteFlow(flowId, type);
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('does nothing if the flow does not exist', async () => {
|
||||
expect(await store.get(flowKey)).toBeUndefined();
|
||||
|
||||
const result = await flowManager.deleteFlow(flowId, type);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -241,4 +241,19 @@ export class FlowStateManager<T = unknown> {
|
|||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a flow state
|
||||
*/
|
||||
async deleteFlow(flowId: string, type: string): Promise<boolean> {
|
||||
const flowKey = this.getFlowKey(flowId, type);
|
||||
try {
|
||||
await this.keyv.delete(flowKey);
|
||||
logger.debug(`[${flowKey}] Flow deleted`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`[${flowKey}] Error deleting flow:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { MCPOAuthTokens, MCPOAuthFlowMetadata } from '~/mcp/oauth';
|
||||
import type { MCPOAuthTokens, MCPOAuthFlowMetadata, OAuthMetadata } from '~/mcp/oauth';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { FlowMetadata } from '~/flow/types';
|
||||
import type * as t from './types';
|
||||
|
@ -186,6 +186,7 @@ export class MCPConnectionFactory {
|
|||
updateToken: this.tokenMethods.updateToken,
|
||||
findToken: this.tokenMethods.findToken,
|
||||
clientInfo: result.clientInfo,
|
||||
metadata: result.metadata,
|
||||
});
|
||||
logger.info(`${this.logPrefix} OAuth tokens saved to storage`);
|
||||
} catch (error) {
|
||||
|
@ -284,6 +285,7 @@ export class MCPConnectionFactory {
|
|||
protected async handleOAuthRequired(): Promise<{
|
||||
tokens: MCPOAuthTokens | null;
|
||||
clientInfo?: OAuthClientInformation;
|
||||
metadata?: OAuthMetadata;
|
||||
} | null> {
|
||||
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
|
||||
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
|
||||
|
@ -359,8 +361,13 @@ export class MCPConnectionFactory {
|
|||
|
||||
/** Client information from the flow metadata */
|
||||
const clientInfo = flowMetadata?.clientInfo;
|
||||
const metadata = flowMetadata?.metadata;
|
||||
|
||||
return { tokens, clientInfo };
|
||||
return {
|
||||
tokens,
|
||||
clientInfo,
|
||||
metadata,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error);
|
||||
return null;
|
||||
|
|
|
@ -187,4 +187,195 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('revokeOAuthToken', () => {
|
||||
const mockServerName = 'test-server';
|
||||
const mockToken = 'test-token-12345';
|
||||
|
||||
const originalFetch = global.fetch;
|
||||
const mockFetch = jest.fn() as unknown as jest.MockedFunction<typeof fetch>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
global.fetch = mockFetch;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
mockFetch.mockClear();
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
global.fetch = originalFetch;
|
||||
});
|
||||
|
||||
it('should successfully revoke an access token with client_secret_basic auth', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
revocationEndpoint: 'https://auth.example.com/oauth/revoke',
|
||||
revocationEndpointAuthMethodsSupported: ['client_secret_basic'],
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'access', metadata);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(new URL('https://auth.example.com/oauth/revoke'), {
|
||||
method: 'POST',
|
||||
body: 'token=test-token-12345&token_type_hint=access_token',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Authorization: `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully revoke a refresh token with client_secret_basic auth', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
revocationEndpoint: 'https://auth.example.com/oauth/revoke',
|
||||
revocationEndpointAuthMethodsSupported: ['client_secret_basic'],
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'refresh', metadata);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(new URL('https://auth.example.com/oauth/revoke'), {
|
||||
method: 'POST',
|
||||
body: 'token=test-token-12345&token_type_hint=refresh_token',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
Authorization: `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully revoke an access token with client_secret_post auth', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
revocationEndpoint: 'https://auth.example.com/oauth/revoke',
|
||||
revocationEndpointAuthMethodsSupported: ['client_secret_post'],
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'access', metadata);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(new URL('https://auth.example.com/oauth/revoke'), {
|
||||
method: 'POST',
|
||||
body: 'token=test-token-12345&token_type_hint=access_token&client_secret=test-client-secret&client_id=test-client-id',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should fallback to /revoke endpoint when revocationEndpoint is not provided', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'refresh', metadata);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
new URL('https://auth.example.com/revoke'),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should default to client_secret_basic auth when revocationEndpointAuthMethodsSupported is not provided', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
revocationEndpoint: 'https://auth.example.com/oauth/revoke',
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'refresh', metadata);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
expect.any(URL),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: expect.stringMatching(/^Basic /),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error when the revocation request fails', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
revocationEndpoint: 'https://auth.example.com/oauth/revoke',
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 418,
|
||||
} as Response);
|
||||
|
||||
await expect(
|
||||
MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'refresh', metadata),
|
||||
).rejects.toThrow('Token revocation failed: HTTP 418');
|
||||
});
|
||||
|
||||
it('should prioritize client_secret_basic over other auth methods', async () => {
|
||||
const metadata = {
|
||||
serverUrl: 'https://auth.example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
revocationEndpoint: 'https://auth.example.com/oauth/revoke',
|
||||
revocationEndpointAuthMethodsSupported: [
|
||||
'client_secret_post',
|
||||
'client_secret_basic',
|
||||
'some_other_method',
|
||||
],
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.revokeOAuthToken(mockServerName, mockToken, 'refresh', metadata);
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
expect.any(URL),
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: expect.stringMatching(/^Basic /),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
193
packages/api/src/mcp/__tests__/tokens.test.ts
Normal file
193
packages/api/src/mcp/__tests__/tokens.test.ts
Normal file
|
@ -0,0 +1,193 @@
|
|||
import { MCPTokenStorage } from '~/mcp/oauth/tokens';
|
||||
import { decryptV2 } from '~/crypto';
|
||||
import type { TokenMethods, IToken } from '@librechat/data-schemas';
|
||||
import { Types } from 'mongoose';
|
||||
|
||||
jest.mock('~/crypto', () => ({
|
||||
decryptV2: jest.fn(),
|
||||
}));
|
||||
|
||||
const mockDecryptV2 = decryptV2 as jest.MockedFunction<typeof decryptV2>;
|
||||
|
||||
describe('MCPTokenStorage', () => {
|
||||
afterAll(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('deleteUserTokens', () => {
|
||||
const userId = '000000001111111122222222';
|
||||
const serverName = 'test-server';
|
||||
let mockDeleteToken: jest.MockedFunction<
|
||||
(filter: { userId: string; type: string; identifier: string }) => Promise<void>
|
||||
>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockDeleteToken = jest.fn().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
it('should delete all OAuth-related tokens for a user and server', async () => {
|
||||
await MCPTokenStorage.deleteUserTokens({
|
||||
userId,
|
||||
serverName,
|
||||
deleteToken: mockDeleteToken,
|
||||
});
|
||||
|
||||
// Verify all three token types were deleted with correct identifiers
|
||||
expect(mockDeleteToken).toHaveBeenCalledTimes(3);
|
||||
expect(mockDeleteToken).toHaveBeenCalledWith({
|
||||
userId,
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `mcp:${serverName}:client`,
|
||||
});
|
||||
expect(mockDeleteToken).toHaveBeenCalledWith({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}`,
|
||||
});
|
||||
expect(mockDeleteToken).toHaveBeenCalledWith({
|
||||
userId,
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: `mcp:${serverName}:refresh`,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle deletion errors gracefully', async () => {
|
||||
mockDeleteToken.mockRejectedValueOnce(new Error('Deletion failed'));
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.deleteUserTokens({
|
||||
userId,
|
||||
serverName,
|
||||
deleteToken: mockDeleteToken,
|
||||
}),
|
||||
).rejects.toThrow('Deletion failed');
|
||||
|
||||
expect(mockDeleteToken).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getClientInfoAndMetadata', () => {
|
||||
const userId = '000000001111111122222222';
|
||||
const serverName = 'test-server';
|
||||
const identifier = `mcp:${serverName}`;
|
||||
let mockFindToken: jest.MockedFunction<TokenMethods['findToken']>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockFindToken = jest.fn();
|
||||
});
|
||||
|
||||
it('should return null when no client info token exists', async () => {
|
||||
mockFindToken.mockResolvedValue(null);
|
||||
|
||||
const result = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken: mockFindToken,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(mockFindToken).toHaveBeenCalledWith({
|
||||
userId,
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return client info and metadata when token exists', async () => {
|
||||
const clientInfo = {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-secret',
|
||||
};
|
||||
|
||||
const metadata = new Map([
|
||||
['serverUrl', 'https://test.example.com'],
|
||||
['state', 'test-state'],
|
||||
]);
|
||||
|
||||
const mockToken: IToken = {
|
||||
userId: new Types.ObjectId(userId),
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
token: 'encrypted-token',
|
||||
metadata,
|
||||
} as IToken;
|
||||
|
||||
mockFindToken.mockResolvedValue(mockToken);
|
||||
mockDecryptV2.mockResolvedValue(JSON.stringify(clientInfo));
|
||||
|
||||
const result = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken: mockFindToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.clientInfo).toEqual(clientInfo);
|
||||
expect(result?.clientMetadata).toEqual({
|
||||
serverUrl: 'https://test.example.com',
|
||||
state: 'test-state',
|
||||
});
|
||||
expect(mockDecryptV2).toHaveBeenCalledWith('encrypted-token');
|
||||
});
|
||||
|
||||
it('should handle empty metadata', async () => {
|
||||
const clientInfo = {
|
||||
client_id: 'test-client-id',
|
||||
};
|
||||
|
||||
const mockToken: IToken = {
|
||||
userId: new Types.ObjectId(userId),
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
token: 'encrypted-token',
|
||||
} as IToken;
|
||||
|
||||
mockFindToken.mockResolvedValue(mockToken);
|
||||
mockDecryptV2.mockResolvedValue(JSON.stringify(clientInfo));
|
||||
|
||||
const result = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken: mockFindToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.clientInfo).toEqual(clientInfo);
|
||||
expect(result?.clientMetadata).toEqual({});
|
||||
});
|
||||
|
||||
it('should handle metadata as plain object', async () => {
|
||||
const clientInfo = {
|
||||
client_id: 'test-client-id',
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
serverUrl: 'https://test.example.com',
|
||||
state: 'test-state',
|
||||
};
|
||||
|
||||
const mockToken: IToken = {
|
||||
userId: new Types.ObjectId(userId),
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
token: 'encrypted-token',
|
||||
metadata: metadata as unknown, // runtime check
|
||||
} as IToken;
|
||||
|
||||
mockFindToken.mockResolvedValue(mockToken);
|
||||
mockDecryptV2.mockResolvedValue(JSON.stringify(clientInfo));
|
||||
|
||||
const result = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken: mockFindToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.clientInfo).toEqual(clientInfo);
|
||||
expect(result?.clientMetadata).toEqual(metadata);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -643,4 +643,68 @@ export class MCPOAuthHandler {
|
|||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Revokes OAuth tokens at the authorization server (RFC 7009)
|
||||
*/
|
||||
public static async revokeOAuthToken(
|
||||
serverName: string,
|
||||
token: string,
|
||||
tokenType: 'refresh' | 'access',
|
||||
metadata: {
|
||||
serverUrl: string;
|
||||
clientId: string;
|
||||
clientSecret: string;
|
||||
revocationEndpoint?: string;
|
||||
revocationEndpointAuthMethodsSupported?: string[];
|
||||
},
|
||||
): Promise<void> {
|
||||
// build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided
|
||||
const revokeUrl: URL =
|
||||
metadata.revocationEndpoint != null
|
||||
? new URL(metadata.revocationEndpoint)
|
||||
: new URL('/revoke', metadata.serverUrl);
|
||||
|
||||
// detect auth method to use
|
||||
const authMethods = metadata.revocationEndpointAuthMethodsSupported ?? [
|
||||
'client_secret_basic', // RFC 8414 (https://datatracker.ietf.org/doc/html/rfc8414)
|
||||
];
|
||||
const usesBasicAuth = authMethods.includes('client_secret_basic');
|
||||
const usesClientSecretPost = authMethods.includes('client_secret_post');
|
||||
|
||||
// init the request headers
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
};
|
||||
|
||||
// init the request body
|
||||
const body = new URLSearchParams({ token });
|
||||
body.set('token_type_hint', tokenType === 'refresh' ? 'refresh_token' : 'access_token');
|
||||
|
||||
// process auth method
|
||||
if (usesBasicAuth) {
|
||||
// encode the client id and secret and add to the headers
|
||||
const credentials = Buffer.from(`${metadata.clientId}:${metadata.clientSecret}`).toString(
|
||||
'base64',
|
||||
);
|
||||
headers['Authorization'] = `Basic ${credentials}`;
|
||||
} else if (usesClientSecretPost) {
|
||||
// add the client id and secret to the body
|
||||
body.set('client_secret', metadata.clientSecret);
|
||||
body.set('client_id', metadata.clientId);
|
||||
}
|
||||
|
||||
// perform the revoke request
|
||||
logger.info(`[MCPOAuth] Revoking tokens for ${serverName} via ${revokeUrl.toString()}`);
|
||||
const response = await fetch(revokeUrl, {
|
||||
method: 'POST',
|
||||
body: body.toString(),
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
logger.error(`[MCPOAuth] Token revocation failed for ${serverName}: HTTP ${response.status}`);
|
||||
throw new Error(`Token revocation failed: HTTP ${response.status}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthTokens, OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { TokenMethods, IToken } from '@librechat/data-schemas';
|
||||
import type { MCPOAuthTokens, ExtendedOAuthTokens } from './types';
|
||||
import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types';
|
||||
import { encryptV2, decryptV2 } from '~/crypto';
|
||||
import { isSystemUserId } from '~/mcp/enum';
|
||||
|
||||
|
@ -13,6 +13,7 @@ interface StoreTokensParams {
|
|||
updateToken?: TokenMethods['updateToken'];
|
||||
findToken?: TokenMethods['findToken'];
|
||||
clientInfo?: OAuthClientInformation;
|
||||
metadata?: OAuthMetadata;
|
||||
/** Optional: Pass existing token state to avoid duplicate DB calls */
|
||||
existingTokens?: {
|
||||
accessToken?: IToken | null;
|
||||
|
@ -55,6 +56,7 @@ export class MCPTokenStorage {
|
|||
findToken,
|
||||
clientInfo,
|
||||
existingTokens,
|
||||
metadata,
|
||||
}: StoreTokensParams): Promise<void> {
|
||||
const logPrefix = this.getLogPrefix(userId, serverName);
|
||||
|
||||
|
@ -188,6 +190,7 @@ export class MCPTokenStorage {
|
|||
identifier: `${identifier}:client`,
|
||||
token: encryptedClientInfo,
|
||||
expiresIn: 365 * 24 * 60 * 60,
|
||||
metadata,
|
||||
};
|
||||
|
||||
// Check if client info already exists and update if it does
|
||||
|
@ -379,4 +382,86 @@ export class MCPTokenStorage {
|
|||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
static async getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
}: {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
findToken: TokenMethods['findToken'];
|
||||
}): Promise<{
|
||||
clientInfo: OAuthClientInformation;
|
||||
clientMetadata: Record<string, unknown>;
|
||||
} | null> {
|
||||
const identifier = `mcp:${serverName}`;
|
||||
|
||||
const clientInfoData: IToken | null = await findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
});
|
||||
if (clientInfoData == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const tokenData = await decryptV2(clientInfoData.token);
|
||||
const clientInfo = JSON.parse(tokenData);
|
||||
|
||||
// get metadata from the token as a plain object. While it's defined as a Map in the database type, it's a plain object at runtime.
|
||||
function getMetadata(
|
||||
metadata: Map<string, unknown> | Record<string, unknown> | null,
|
||||
): Record<string, unknown> {
|
||||
if (metadata == null) {
|
||||
return {};
|
||||
}
|
||||
if (metadata instanceof Map) {
|
||||
return Object.fromEntries(metadata);
|
||||
}
|
||||
return { ...(metadata as Record<string, unknown>) };
|
||||
}
|
||||
const clientMetadata = getMetadata(clientInfoData.metadata ?? null);
|
||||
|
||||
return {
|
||||
clientInfo,
|
||||
clientMetadata,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all OAuth-related tokens for a specific user and server
|
||||
*/
|
||||
static async deleteUserTokens({
|
||||
userId,
|
||||
serverName,
|
||||
deleteToken,
|
||||
}: {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
deleteToken: (filter: { userId: string; type: string; identifier: string }) => Promise<void>;
|
||||
}): Promise<void> {
|
||||
const identifier = `mcp:${serverName}`;
|
||||
|
||||
// delete client info token
|
||||
await deleteToken({
|
||||
userId,
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: `${identifier}:client`,
|
||||
});
|
||||
|
||||
// delete access token
|
||||
await deleteToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier,
|
||||
});
|
||||
|
||||
// delete refresh token
|
||||
await deleteToken({
|
||||
userId,
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: `${identifier}:refresh`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,10 @@ export interface OAuthMetadata {
|
|||
token_endpoint_auth_methods_supported?: string[];
|
||||
/** Code challenge methods supported */
|
||||
code_challenge_methods_supported?: string[];
|
||||
/** Revocation endpoint */
|
||||
revocation_endpoint?: string;
|
||||
/** Revocation endpoint auth methods supported */
|
||||
revocation_endpoint_auth_methods_supported?: string[];
|
||||
}
|
||||
|
||||
export interface OAuthProtectedResourceMetadata {
|
||||
|
|
|
@ -56,6 +56,10 @@ const BaseOptionsSchema = z.object({
|
|||
response_types_supported: z.array(z.string()).optional(),
|
||||
/** Supported code challenge methods (defaults to ['S256', 'plain']) */
|
||||
code_challenge_methods_supported: z.array(z.string()).optional(),
|
||||
/** OAuth revocation endpoint (optional - can be auto-discovered) */
|
||||
revocation_endpoint: z.string().url().optional(),
|
||||
/** OAuth revocation endpoint authentication methods supported (optional - can be auto-discovered) */
|
||||
revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(),
|
||||
})
|
||||
.optional(),
|
||||
customUserVars: z
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue