refactor: Improve OAuth server management and connection status handling

This commit is contained in:
Marco Beretta 2025-07-30 20:36:39 +02:00
parent 8a1a38f346
commit f25407768e
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
3 changed files with 176 additions and 6 deletions

View file

@ -300,6 +300,120 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
}
});
/**
* Revoke OAuth tokens for an MCP server
* This endpoint revokes OAuth tokens and disconnects the server for a fresh start
*/
router.post('/:serverName/oauth/revoke', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
logger.info(`[MCP OAuth Revoke] Revoking OAuth tokens for ${serverName} by user ${user.id}`);
const printConfig = false;
const config = await loadCustomConfig(printConfig);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
});
}
const mcpManager = getMCPManager(user.id);
// Delete OAuth access and refresh tokens
const baseIdentifier = `mcp:${serverName}`;
try {
await deleteTokens({ identifier: baseIdentifier });
await deleteTokens({ identifier: `${baseIdentifier}:refresh` });
logger.info(`[MCP OAuth Revoke] Successfully cleared OAuth tokens for ${serverName}`);
} catch (error) {
logger.warn(`[MCP OAuth Revoke] Failed to clear OAuth tokens for ${serverName}:`, error);
}
// Disconnect the server and clear all connection state
try {
await mcpManager.disconnectServer(serverName);
logger.info(`[MCP OAuth Revoke] Disconnected server: ${serverName}`);
// Clear the server from OAuth servers set to prevent it from showing as requiring OAuth
if (mcpManager.removeOAuthServer) {
mcpManager.removeOAuthServer(serverName);
logger.info(`[MCP OAuth Revoke] Removed ${serverName} from OAuth servers set`);
} else {
// Fallback: clear the OAuth servers set entry manually (should not be needed now)
const oauthServers = mcpManager.getOAuthServers();
if (oauthServers && oauthServers.has(serverName)) {
oauthServers.delete(serverName);
logger.info(`[MCP OAuth Revoke] Manually removed ${serverName} from OAuth servers set`);
}
}
// Clear connection state from both app and user connection maps
const appConnections = mcpManager.getAllConnections();
const userConnections = mcpManager.getUserConnections(user.id);
if (appConnections && appConnections.has(serverName)) {
appConnections.delete(serverName);
logger.info(`[MCP OAuth Revoke] Cleared ${serverName} from app connections`);
}
if (userConnections && userConnections.has(serverName)) {
userConnections.delete(serverName);
logger.info(`[MCP OAuth Revoke] Cleared ${serverName} from user connections`);
}
} catch (error) {
logger.warn(`[MCP OAuth Revoke] Failed to disconnect server ${serverName}:`, error);
}
// Clear cached tools for this server
try {
const userTools = (await getCachedTools({ userId: user.id })) || {};
const mcpDelimiter = Constants.mcp_delimiter;
for (const key of Object.keys(userTools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete userTools[key];
}
}
await setCachedTools(userTools, { userId: user.id });
logger.info(`[MCP OAuth Revoke] Cleared cached tools for ${serverName}`);
} catch (error) {
logger.warn(`[MCP OAuth Revoke] Failed to clear cached tools for ${serverName}:`, error);
}
// Cancel any pending OAuth flows
try {
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (flowState && flowState.status === 'PENDING') {
await flowManager.failFlow(flowId, 'mcp_oauth', 'OAuth tokens revoked by user');
logger.info(`[MCP OAuth Revoke] Cancelled pending OAuth flow for ${serverName}`);
}
} catch (error) {
logger.warn(`[MCP OAuth Revoke] Failed to cancel OAuth flow for ${serverName}:`, error);
}
res.json({
success: true,
message: `OAuth tokens revoked for ${serverName}. Server can now be re-authenticated.`,
serverName,
});
} catch (error) {
logger.error('[MCP OAuth Revoke] Unexpected error', error);
res.status(500).json({ error: 'Failed to revoke OAuth tokens' });
}
});
/**
* Reinitialize MCP server
* This endpoint allows reinitializing a specific MCP server

View file

@ -339,7 +339,7 @@ async function getServerConnectionStatus(
serverName,
appConnections,
userConnections,
oauthServers,
_oauthServers,
) {
const getConnectionState = () =>
appConnections.get(serverName)?.connectionState ??
@ -349,7 +349,36 @@ async function getServerConnectionStatus(
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
const mcpManager = getMCPManager(userId);
const hasOAuthConfig = mcpManager.serverRequiresOAuth(serverName);
if (hasOAuthConfig) {
const baseIdentifier = `mcp:${serverName}`;
try {
const accessToken = await findToken({ identifier: baseIdentifier });
if (!accessToken) {
// No tokens found, server should be considered disconnected regardless of in-memory state
finalConnectionState = 'disconnected';
logger.debug(
`[Connection Status] No OAuth tokens found for ${serverName}, marking as disconnected`,
);
} else if (baseConnectionState === 'disconnected') {
// Tokens exist but connection shows disconnected, check OAuth flow status
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}
} catch (error) {
logger.warn(`[Connection Status] Error checking tokens for ${serverName}:`, error);
finalConnectionState = 'disconnected';
}
} else if (baseConnectionState === 'disconnected') {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
@ -359,10 +388,21 @@ async function getServerConnectionStatus(
}
}
return {
requiresOAuth: oauthServers.has(serverName),
connectionState: finalConnectionState,
};
let requiresOAuth = hasOAuthConfig;
if (hasOAuthConfig) {
try {
const baseIdentifier = `mcp:${serverName}`;
const accessToken = await findToken({ identifier: baseIdentifier });
requiresOAuth = !accessToken;
} catch (_error) {
requiresOAuth = true;
}
}
// return {
// requiresOAuth: oauthServers.has(serverName),
// connectionState: finalConnectionState,
// };
}
module.exports = {

View file

@ -494,6 +494,9 @@ export class MCPManager {
connection.on('oauthRequired', async (data) => {
logger.info(`[MCP][User: ${userId}][${serverName}] oauthRequired event received`);
// Add server to OAuth servers set
this.oauthServers.add(serverName);
// If we just want to initiate OAuth and return, handle it differently
if (returnOnOAuth) {
try {
@ -1140,4 +1143,17 @@ ${logPrefix} Flow ID: ${newFlowId}
public getOAuthServers(): Set<string> {
return this.oauthServers;
}
/** Remove a server from OAuth servers set */
public removeOAuthServer(serverName: string): void {
this.oauthServers.delete(serverName);
}
/**
* Check if a server requires OAuth based on configuration
*/
public serverRequiresOAuth(serverName: string): boolean {
const config = this.mcpConfigs[serverName];
return !!config?.oauth;
}
}