mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-10 04:28:50 +01:00
refactor: Improve OAuth server management and connection status handling
This commit is contained in:
parent
8a1a38f346
commit
f25407768e
3 changed files with 176 additions and 6 deletions
|
|
@ -300,6 +300,120 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Revoke OAuth tokens for an MCP server
|
||||
* This endpoint revokes OAuth tokens and disconnects the server for a fresh start
|
||||
*/
|
||||
router.post('/:serverName/oauth/revoke', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
logger.info(`[MCP OAuth Revoke] Revoking OAuth tokens for ${serverName} by user ${user.id}`);
|
||||
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
});
|
||||
}
|
||||
|
||||
const mcpManager = getMCPManager(user.id);
|
||||
|
||||
// Delete OAuth access and refresh tokens
|
||||
const baseIdentifier = `mcp:${serverName}`;
|
||||
|
||||
try {
|
||||
await deleteTokens({ identifier: baseIdentifier });
|
||||
await deleteTokens({ identifier: `${baseIdentifier}:refresh` });
|
||||
logger.info(`[MCP OAuth Revoke] Successfully cleared OAuth tokens for ${serverName}`);
|
||||
} catch (error) {
|
||||
logger.warn(`[MCP OAuth Revoke] Failed to clear OAuth tokens for ${serverName}:`, error);
|
||||
}
|
||||
|
||||
// Disconnect the server and clear all connection state
|
||||
try {
|
||||
await mcpManager.disconnectServer(serverName);
|
||||
logger.info(`[MCP OAuth Revoke] Disconnected server: ${serverName}`);
|
||||
|
||||
// Clear the server from OAuth servers set to prevent it from showing as requiring OAuth
|
||||
if (mcpManager.removeOAuthServer) {
|
||||
mcpManager.removeOAuthServer(serverName);
|
||||
logger.info(`[MCP OAuth Revoke] Removed ${serverName} from OAuth servers set`);
|
||||
} else {
|
||||
// Fallback: clear the OAuth servers set entry manually (should not be needed now)
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
if (oauthServers && oauthServers.has(serverName)) {
|
||||
oauthServers.delete(serverName);
|
||||
logger.info(`[MCP OAuth Revoke] Manually removed ${serverName} from OAuth servers set`);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear connection state from both app and user connection maps
|
||||
const appConnections = mcpManager.getAllConnections();
|
||||
const userConnections = mcpManager.getUserConnections(user.id);
|
||||
|
||||
if (appConnections && appConnections.has(serverName)) {
|
||||
appConnections.delete(serverName);
|
||||
logger.info(`[MCP OAuth Revoke] Cleared ${serverName} from app connections`);
|
||||
}
|
||||
|
||||
if (userConnections && userConnections.has(serverName)) {
|
||||
userConnections.delete(serverName);
|
||||
logger.info(`[MCP OAuth Revoke] Cleared ${serverName} from user connections`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[MCP OAuth Revoke] Failed to disconnect server ${serverName}:`, error);
|
||||
}
|
||||
|
||||
// Clear cached tools for this server
|
||||
try {
|
||||
const userTools = (await getCachedTools({ userId: user.id })) || {};
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
await setCachedTools(userTools, { userId: user.id });
|
||||
logger.info(`[MCP OAuth Revoke] Cleared cached tools for ${serverName}`);
|
||||
} catch (error) {
|
||||
logger.warn(`[MCP OAuth Revoke] Failed to clear cached tools for ${serverName}:`, error);
|
||||
}
|
||||
|
||||
// Cancel any pending OAuth flows
|
||||
try {
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (flowState && flowState.status === 'PENDING') {
|
||||
await flowManager.failFlow(flowId, 'mcp_oauth', 'OAuth tokens revoked by user');
|
||||
logger.info(`[MCP OAuth Revoke] Cancelled pending OAuth flow for ${serverName}`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[MCP OAuth Revoke] Failed to cancel OAuth flow for ${serverName}:`, error);
|
||||
}
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
message: `OAuth tokens revoked for ${serverName}. Server can now be re-authenticated.`,
|
||||
serverName,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth Revoke] Unexpected error', error);
|
||||
res.status(500).json({ error: 'Failed to revoke OAuth tokens' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Reinitialize MCP server
|
||||
* This endpoint allows reinitializing a specific MCP server
|
||||
|
|
|
|||
|
|
@ -339,7 +339,7 @@ async function getServerConnectionStatus(
|
|||
serverName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
_oauthServers,
|
||||
) {
|
||||
const getConnectionState = () =>
|
||||
appConnections.get(serverName)?.connectionState ??
|
||||
|
|
@ -349,7 +349,36 @@ async function getServerConnectionStatus(
|
|||
const baseConnectionState = getConnectionState();
|
||||
let finalConnectionState = baseConnectionState;
|
||||
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
const mcpManager = getMCPManager(userId);
|
||||
|
||||
const hasOAuthConfig = mcpManager.serverRequiresOAuth(serverName);
|
||||
|
||||
if (hasOAuthConfig) {
|
||||
const baseIdentifier = `mcp:${serverName}`;
|
||||
try {
|
||||
const accessToken = await findToken({ identifier: baseIdentifier });
|
||||
|
||||
if (!accessToken) {
|
||||
// No tokens found, server should be considered disconnected regardless of in-memory state
|
||||
finalConnectionState = 'disconnected';
|
||||
logger.debug(
|
||||
`[Connection Status] No OAuth tokens found for ${serverName}, marking as disconnected`,
|
||||
);
|
||||
} else if (baseConnectionState === 'disconnected') {
|
||||
// Tokens exist but connection shows disconnected, check OAuth flow status
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[Connection Status] Error checking tokens for ${serverName}:`, error);
|
||||
finalConnectionState = 'disconnected';
|
||||
}
|
||||
} else if (baseConnectionState === 'disconnected') {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
|
|
@ -359,10 +388,21 @@ async function getServerConnectionStatus(
|
|||
}
|
||||
}
|
||||
|
||||
return {
|
||||
requiresOAuth: oauthServers.has(serverName),
|
||||
connectionState: finalConnectionState,
|
||||
};
|
||||
let requiresOAuth = hasOAuthConfig;
|
||||
if (hasOAuthConfig) {
|
||||
try {
|
||||
const baseIdentifier = `mcp:${serverName}`;
|
||||
const accessToken = await findToken({ identifier: baseIdentifier });
|
||||
requiresOAuth = !accessToken;
|
||||
} catch (_error) {
|
||||
requiresOAuth = true;
|
||||
}
|
||||
}
|
||||
|
||||
// return {
|
||||
// requiresOAuth: oauthServers.has(serverName),
|
||||
// connectionState: finalConnectionState,
|
||||
// };
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
|
|
|
|||
|
|
@ -494,6 +494,9 @@ export class MCPManager {
|
|||
connection.on('oauthRequired', async (data) => {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}] oauthRequired event received`);
|
||||
|
||||
// Add server to OAuth servers set
|
||||
this.oauthServers.add(serverName);
|
||||
|
||||
// If we just want to initiate OAuth and return, handle it differently
|
||||
if (returnOnOAuth) {
|
||||
try {
|
||||
|
|
@ -1140,4 +1143,17 @@ ${logPrefix} Flow ID: ${newFlowId}
|
|||
public getOAuthServers(): Set<string> {
|
||||
return this.oauthServers;
|
||||
}
|
||||
|
||||
/** Remove a server from OAuth servers set */
|
||||
public removeOAuthServer(serverName: string): void {
|
||||
this.oauthServers.delete(serverName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a server requires OAuth based on configuration
|
||||
*/
|
||||
public serverRequiresOAuth(serverName: string): boolean {
|
||||
const config = this.mcpConfigs[serverName];
|
||||
return !!config?.oauth;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue