From f25407768ed57ad89d6636394e5dfb9fcc95a181 Mon Sep 17 00:00:00 2001 From: Marco Beretta <81851188+berry-13@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:36:39 +0200 Subject: [PATCH] refactor: Improve OAuth server management and connection status handling --- api/server/routes/mcp.js | 114 ++++++++++++++++++++++++++++++++ api/server/services/MCP.js | 52 +++++++++++++-- packages/api/src/mcp/manager.ts | 16 +++++ 3 files changed, 176 insertions(+), 6 deletions(-) diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 4e35052cd7..4454c68095 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -300,6 +300,120 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { } }); +/** + * Revoke OAuth tokens for an MCP server + * This endpoint revokes OAuth tokens and disconnects the server for a fresh start + */ +router.post('/:serverName/oauth/revoke', requireJwtAuth, async (req, res) => { + try { + const { serverName } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + logger.info(`[MCP OAuth Revoke] Revoking OAuth tokens for ${serverName} by user ${user.id}`); + + const printConfig = false; + const config = await loadCustomConfig(printConfig); + if (!config || !config.mcpServers || !config.mcpServers[serverName]) { + return res.status(404).json({ + error: `MCP server '${serverName}' not found in configuration`, + }); + } + + const mcpManager = getMCPManager(user.id); + + // Delete OAuth access and refresh tokens + const baseIdentifier = `mcp:${serverName}`; + + try { + await deleteTokens({ identifier: baseIdentifier }); + await deleteTokens({ identifier: `${baseIdentifier}:refresh` }); + logger.info(`[MCP OAuth Revoke] Successfully cleared OAuth tokens for ${serverName}`); + } catch (error) { + logger.warn(`[MCP OAuth Revoke] Failed to clear OAuth tokens for ${serverName}:`, error); + } + + // Disconnect the server and clear all connection state + try { + await mcpManager.disconnectServer(serverName); + logger.info(`[MCP OAuth Revoke] Disconnected server: ${serverName}`); + + // Clear the server from OAuth servers set to prevent it from showing as requiring OAuth + if (mcpManager.removeOAuthServer) { + mcpManager.removeOAuthServer(serverName); + logger.info(`[MCP OAuth Revoke] Removed ${serverName} from OAuth servers set`); + } else { + // Fallback: clear the OAuth servers set entry manually (should not be needed now) + const oauthServers = mcpManager.getOAuthServers(); + if (oauthServers && oauthServers.has(serverName)) { + oauthServers.delete(serverName); + logger.info(`[MCP OAuth Revoke] Manually removed ${serverName} from OAuth servers set`); + } + } + + // Clear connection state from both app and user connection maps + const appConnections = mcpManager.getAllConnections(); + const userConnections = mcpManager.getUserConnections(user.id); + + if (appConnections && appConnections.has(serverName)) { + appConnections.delete(serverName); + logger.info(`[MCP OAuth Revoke] Cleared ${serverName} from app connections`); + } + + if (userConnections && userConnections.has(serverName)) { + userConnections.delete(serverName); + logger.info(`[MCP OAuth Revoke] Cleared ${serverName} from user connections`); + } + } catch (error) { + logger.warn(`[MCP OAuth Revoke] Failed to disconnect server ${serverName}:`, error); + } + + // Clear cached tools for this server + try { + const userTools = (await getCachedTools({ userId: user.id })) || {}; + const mcpDelimiter = Constants.mcp_delimiter; + + for (const key of Object.keys(userTools)) { + if (key.endsWith(`${mcpDelimiter}${serverName}`)) { + delete userTools[key]; + } + } + + await setCachedTools(userTools, { userId: user.id }); + logger.info(`[MCP OAuth Revoke] Cleared cached tools for ${serverName}`); + } catch (error) { + logger.warn(`[MCP OAuth Revoke] Failed to clear cached tools for ${serverName}:`, error); + } + + // Cancel any pending OAuth flows + try { + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + + if (flowState && flowState.status === 'PENDING') { + await flowManager.failFlow(flowId, 'mcp_oauth', 'OAuth tokens revoked by user'); + logger.info(`[MCP OAuth Revoke] Cancelled pending OAuth flow for ${serverName}`); + } + } catch (error) { + logger.warn(`[MCP OAuth Revoke] Failed to cancel OAuth flow for ${serverName}:`, error); + } + + res.json({ + success: true, + message: `OAuth tokens revoked for ${serverName}. Server can now be re-authenticated.`, + serverName, + }); + } catch (error) { + logger.error('[MCP OAuth Revoke] Unexpected error', error); + res.status(500).json({ error: 'Failed to revoke OAuth tokens' }); + } +}); + /** * Reinitialize MCP server * This endpoint allows reinitializing a specific MCP server diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 147def1bbf..7c26763164 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -339,7 +339,7 @@ async function getServerConnectionStatus( serverName, appConnections, userConnections, - oauthServers, + _oauthServers, ) { const getConnectionState = () => appConnections.get(serverName)?.connectionState ?? @@ -349,7 +349,36 @@ async function getServerConnectionStatus( const baseConnectionState = getConnectionState(); let finalConnectionState = baseConnectionState; - if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) { + const mcpManager = getMCPManager(userId); + + const hasOAuthConfig = mcpManager.serverRequiresOAuth(serverName); + + if (hasOAuthConfig) { + const baseIdentifier = `mcp:${serverName}`; + try { + const accessToken = await findToken({ identifier: baseIdentifier }); + + if (!accessToken) { + // No tokens found, server should be considered disconnected regardless of in-memory state + finalConnectionState = 'disconnected'; + logger.debug( + `[Connection Status] No OAuth tokens found for ${serverName}, marking as disconnected`, + ); + } else if (baseConnectionState === 'disconnected') { + // Tokens exist but connection shows disconnected, check OAuth flow status + const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName); + + if (hasFailedFlow) { + finalConnectionState = 'error'; + } else if (hasActiveFlow) { + finalConnectionState = 'connecting'; + } + } + } catch (error) { + logger.warn(`[Connection Status] Error checking tokens for ${serverName}:`, error); + finalConnectionState = 'disconnected'; + } + } else if (baseConnectionState === 'disconnected') { const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName); if (hasFailedFlow) { @@ -359,10 +388,21 @@ async function getServerConnectionStatus( } } - return { - requiresOAuth: oauthServers.has(serverName), - connectionState: finalConnectionState, - }; + let requiresOAuth = hasOAuthConfig; + if (hasOAuthConfig) { + try { + const baseIdentifier = `mcp:${serverName}`; + const accessToken = await findToken({ identifier: baseIdentifier }); + requiresOAuth = !accessToken; + } catch (_error) { + requiresOAuth = true; + } + } + + // return { + // requiresOAuth: oauthServers.has(serverName), + // connectionState: finalConnectionState, + // }; } module.exports = { diff --git a/packages/api/src/mcp/manager.ts b/packages/api/src/mcp/manager.ts index 6db52ec9ec..6d6b1f7a7a 100644 --- a/packages/api/src/mcp/manager.ts +++ b/packages/api/src/mcp/manager.ts @@ -494,6 +494,9 @@ export class MCPManager { connection.on('oauthRequired', async (data) => { logger.info(`[MCP][User: ${userId}][${serverName}] oauthRequired event received`); + // Add server to OAuth servers set + this.oauthServers.add(serverName); + // If we just want to initiate OAuth and return, handle it differently if (returnOnOAuth) { try { @@ -1140,4 +1143,17 @@ ${logPrefix} Flow ID: ${newFlowId} public getOAuthServers(): Set { return this.oauthServers; } + + /** Remove a server from OAuth servers set */ + public removeOAuthServer(serverName: string): void { + this.oauthServers.delete(serverName); + } + + /** + * Check if a server requires OAuth based on configuration + */ + public serverRequiresOAuth(serverName: string): boolean { + const config = this.mcpConfigs[serverName]; + return !!config?.oauth; + } }