🔌 feat: MCP OAuth Integration in Chat UI

- **Real-Time Connection Status**: New backend APIs and React Query hooks provide live MCP server connection monitoring with automatic UI updates
- **OAuth Flow Components**: Complete MCPConfigDialog, ServerInitializationSection, and CustomUserVarsSection with OAuth URL handling and polling-based completion
- **Enhanced Server Selection**: MCPSelect component with connection-aware filtering, visual status indicators, and better credential management UX

(still needs a lot of refinement since there is bloat/unused vars and functions leftover from the ideation phase on how to approach OAuth and connection statuses)
This commit is contained in:
Dustin Healy 2025-07-21 01:29:33 -07:00
parent b39b60c012
commit 63140237a6
27 changed files with 1760 additions and 286 deletions

View file

@ -69,6 +69,7 @@ export class MCPConnection extends EventEmitter {
private lastPingTime: number;
private oauthTokens?: MCPOAuthTokens | null;
private oauthRequired = false;
private oauthTimeoutId: NodeJS.Timeout | null = null;
iconPath?: string;
timeout?: number;
url?: string;
@ -421,6 +422,7 @@ export class MCPConnection extends EventEmitter {
const cleanup = () => {
if (timeoutId) {
clearTimeout(timeoutId);
this.oauthTimeoutId = null;
}
if (oauthHandledListener) {
this.off('oauthHandled', oauthHandledListener);
@ -448,11 +450,26 @@ export class MCPConnection extends EventEmitter {
reject(new Error(`OAuth handling timeout after ${oauthTimeout}ms`));
}, oauthTimeout);
// Store the timeout ID for potential cancellation
this.oauthTimeoutId = timeoutId;
// Listen for both success and failure events
this.once('oauthHandled', oauthHandledListener);
this.once('oauthFailed', oauthFailedListener);
});
// Check if there are any listeners for oauthRequired event
const hasOAuthListeners = this.listenerCount('oauthRequired') > 0;
if (!hasOAuthListeners) {
// No OAuth handler available (like during startup), immediately fail
logger.warn(
`${this.getLogPrefix()} OAuth required but no handler available, failing immediately`,
);
this.oauthRequired = false;
throw error;
}
// Emit the event
this.emit('oauthRequired', {
serverName: this.serverName,
@ -517,7 +534,7 @@ export class MCPConnection extends EventEmitter {
try {
await this.disconnect();
await this.connectClient();
if (!(await this.isConnected())) {
if (!(await this.isConnected()) && !(this.isInitializing && this.oauthTokens)) {
throw new Error('Connection not established');
}
} catch (error) {
@ -545,6 +562,37 @@ export class MCPConnection extends EventEmitter {
public async disconnect(): Promise<void> {
try {
// Cancel any pending OAuth timeout
if (this.oauthTimeoutId) {
clearTimeout(this.oauthTimeoutId);
this.oauthTimeoutId = null;
}
if (this.transport) {
await this.client.close();
this.transport = null;
}
if (this.connectionState === 'disconnected') {
return;
}
this.connectionState = 'disconnected';
this.emit('connectionChange', 'disconnected');
} finally {
this.connectPromise = null;
}
}
public async disconnectAndStopReconnecting(): Promise<void> {
try {
// Stop any reconnection attempts
this.shouldStopReconnecting = true;
// Cancel any pending OAuth timeout
if (this.oauthTimeoutId) {
clearTimeout(this.oauthTimeoutId);
this.oauthTimeoutId = null;
}
if (this.transport) {
await this.client.close();
this.transport = null;
@ -650,6 +698,11 @@ export class MCPConnection extends EventEmitter {
this.oauthTokens = tokens;
}
/** Get the current connection state */
public getConnectionState(): t.ConnectionState {
return this.connectionState;
}
private isOAuthError(error: unknown): boolean {
if (!error || typeof error !== 'object') {
return false;

View file

@ -15,8 +15,9 @@ import { MCPTokenStorage } from './oauth/tokens';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils/env';
import { EventEmitter } from 'events';
export class MCPManager {
export class MCPManager extends EventEmitter {
private static instance: MCPManager | null = null;
/** App-level connections initialized at startup */
private connections: Map<string, MCPConnection> = new Map();
@ -29,6 +30,10 @@ export class MCPManager {
/** Store MCP server instructions */
private serverInstructions: Map<string, string> = new Map();
constructor() {
super();
}
public static getInstance(): MCPManager {
if (!MCPManager.instance) {
MCPManager.instance = new MCPManager();
@ -47,7 +52,7 @@ export class MCPManager {
mcpServers: t.MCPServers;
flowManager: FlowStateManager<MCPOAuthTokens | null>;
tokenMethods?: TokenMethods;
}): Promise<void> {
}): Promise<Record<string, boolean>> {
this.mcpConfigs = mcpServers;
if (!flowManager) {
@ -59,6 +64,7 @@ export class MCPManager {
}
const entries = Object.entries(mcpServers);
const initializedServers = new Set();
const oauthSkippedServers = new Set();
const connectionResults = await Promise.allSettled(
entries.map(async ([serverName, config], i) => {
try {
@ -70,19 +76,46 @@ export class MCPManager {
});
initializedServers.add(i);
} catch (error) {
logger.error(`[MCP][${serverName}] Initialization failed`, error);
// Check if this is an OAuth skipped error
if (
error instanceof Error &&
(error as Error & { isOAuthSkipped?: boolean }).isOAuthSkipped
) {
oauthSkippedServers.add(i);
} else {
logger.error(`[MCP][${serverName}] Initialization failed`, error);
// Debug: Log the actual error for filesystem server
if (serverName === 'filesystem') {
logger.error(`[MCP][${serverName}] Error details:`, {
message: error instanceof Error ? error.message : String(error),
stack: error instanceof Error ? error.stack : undefined,
isOAuthError: this.isOAuthError(error),
});
}
}
}
}),
);
const failedConnections = connectionResults.filter(
(result): result is PromiseRejectedResult => result.status === 'rejected',
(result): result is PromiseRejectedResult =>
result.status === 'rejected' &&
!(
result.reason instanceof Error &&
(result.reason as Error & { isOAuthSkipped?: boolean }).isOAuthSkipped
),
);
logger.info(
`[MCP] Initialized ${initializedServers.size}/${entries.length} app-level server(s)`,
);
if (oauthSkippedServers.size > 0) {
logger.info(
`[MCP] ${oauthSkippedServers.size}/${entries.length} app-level server(s) skipped for OAuth`,
);
}
if (failedConnections.length > 0) {
logger.warn(
`[MCP] ${failedConnections.length}/${entries.length} app-level server(s) failed to initialize`,
@ -92,6 +125,8 @@ export class MCPManager {
entries.forEach(([serverName], index) => {
if (initializedServers.has(index)) {
logger.info(`[MCP][${serverName}] ✓ Initialized`);
} else if (oauthSkippedServers.has(index)) {
logger.info(`[MCP][${serverName}] OAuth Required`);
} else {
logger.info(`[MCP][${serverName}] ✗ Failed`);
}
@ -99,9 +134,16 @@ export class MCPManager {
if (initializedServers.size === entries.length) {
logger.info('[MCP] All app-level servers initialized successfully');
} else if (initializedServers.size === 0) {
} else if (initializedServers.size === 0 && oauthSkippedServers.size === 0) {
logger.warn('[MCP] No app-level servers initialized');
}
// Return OAuth requirement map
const oauthRequirementMap: Record<string, boolean> = {};
entries.forEach(([serverName], index) => {
oauthRequirementMap[serverName] = oauthSkippedServers.has(index);
});
return oauthRequirementMap;
}
/** Initializes a single MCP server connection (app-level) */
@ -166,40 +208,18 @@ export class MCPManager {
logger.info(`[MCP][${serverName}] Loaded OAuth tokens`);
}
const connection = new MCPConnection(serverName, processedConfig, undefined, tokens);
logger.info(`[MCP][${serverName}] Setting up OAuth event listener`);
connection.on('oauthRequired', async (data) => {
logger.debug(`[MCP][${serverName}] oauthRequired event received`);
const result = await this.handleOAuthRequired({
...data,
flowManager,
});
if (result?.tokens && tokenMethods?.createToken) {
try {
connection.setOAuthTokens(result.tokens);
await MCPTokenStorage.storeTokens({
userId: CONSTANTS.SYSTEM_USER_ID,
serverName,
tokens: result.tokens,
createToken: tokenMethods.createToken,
updateToken: tokenMethods.updateToken,
findToken: tokenMethods.findToken,
clientInfo: result.clientInfo,
});
logger.info(`[MCP][${serverName}] OAuth tokens saved to storage`);
} catch (error) {
logger.error(`[MCP][${serverName}] Failed to save OAuth tokens to storage`, error);
}
}
// Only emit oauthHandled if we actually got tokens (OAuth succeeded)
if (result?.tokens) {
connection.emit('oauthHandled');
} else {
// OAuth failed, emit oauthFailed to properly reject the promise
logger.warn(`[MCP][${serverName}] OAuth failed, emitting oauthFailed event`);
connection.emit('oauthFailed', new Error('OAuth authentication failed'));
}
// Track OAuth skipped state explicitly
let oauthSkipped = false;
connection.on('oauthRequired', async () => {
logger.debug(`[MCP][${serverName}] oauthRequired event received`);
oauthSkipped = true;
// Emit event to signal that initialization should be skipped
connection.emit('oauthSkipped');
return;
});
try {
const connectTimeout = processedConfig.initTimeout ?? 30000;
const connectionTimeout = new Promise<void>((_, reject) =>
@ -208,13 +228,35 @@ export class MCPManager {
connectTimeout,
),
);
// Listen for oauthSkipped event to stop initialization
const oauthSkippedPromise = new Promise<void>((resolve) => {
connection.once('oauthSkipped', () => {
logger.debug(`[MCP][${serverName}] OAuth skipped, stopping initialization`);
resolve();
});
});
const connectionAttempt = this.initializeServer({
connection,
logPrefix: `[MCP][${serverName}]`,
flowManager,
handleOAuth: false,
});
await Promise.race([connectionAttempt, connectionTimeout]);
// Race between connection attempt, timeout, and oauthSkipped
await Promise.race([connectionAttempt, connectionTimeout, oauthSkippedPromise]);
// Check if OAuth was explicitly skipped
if (oauthSkipped) {
// Throw a special error to signal OAuth was skipped
const oauthSkippedError = new Error(`OAuth required for ${serverName}`) as Error & {
isOAuthSkipped: boolean;
};
oauthSkippedError.isOAuthSkipped = true;
throw oauthSkippedError;
}
if (await connection.isConnected()) {
this.connections.set(serverName, connection);
@ -269,6 +311,17 @@ export class MCPManager {
logger.info(`[MCP][${serverName}] ✗ Failed`);
}
} catch (error) {
// Debug: Log the actual error for filesystem server
if (serverName === 'filesystem') {
logger.error(`[MCP][${serverName}] Error details:`, {
message: error instanceof Error ? error.message : String(error),
stack: error instanceof Error ? error.stack : undefined,
isOAuthError: this.isOAuthError(error),
errorType: error?.constructor?.name,
errorKeys: error && typeof error === 'object' ? Object.keys(error) : [],
oauthSkipped,
});
}
logger.error(`[MCP][${serverName}] Initialization failed`, error);
throw error;
}
@ -340,6 +393,21 @@ export class MCPManager {
return false;
}
// Debug: Log error details for filesystem server
if (error && typeof error === 'object' && 'message' in error) {
const errorMessage = (error as { message?: string }).message;
if (errorMessage && errorMessage.includes('filesystem')) {
logger.debug('[MCP] isOAuthError check for filesystem:', {
message: errorMessage,
hasCode: 'code' in error,
code: (error as { code?: number }).code,
includes401: errorMessage.includes('401'),
includes403: errorMessage.includes('403'),
includesNon200: errorMessage.includes('Non-200 status code (401)'),
});
}
}
// Check for SSE error with 401 status
if ('message' in error && typeof error.message === 'string') {
return error.message.includes('401') || error.message.includes('Non-200 status code (401)');
@ -578,6 +646,10 @@ export class MCPManager {
this.userConnections.get(userId)?.set(serverName, connection);
logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
// Emit event that connection is established for waiting endpoints
this.emit('connectionEstablished', { userId, serverName, connection });
// Update timestamp on creation
this.updateUserLastActivity(userId);
return connection;
@ -618,7 +690,7 @@ export class MCPManager {
const connection = userMap?.get(serverName);
if (connection) {
logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
await connection.disconnect();
await connection.disconnectAndStopReconnecting();
this.removeUserConnection(userId, serverName);
}
}
@ -657,6 +729,12 @@ export class MCPManager {
return this.connections;
}
/** Returns the user-level connection if it exists (does not create one) */
public getUserConnectionIfExists(userId: string, serverName: string): MCPConnection | undefined {
const userMap = this.userConnections.get(userId);
return userMap?.get(serverName);
}
/** Attempts to reconnect an app-level connection if it's disconnected */
private async isConnectionActive({
serverName,
@ -928,24 +1006,62 @@ export class MCPManager {
/** Disconnects all app-level and user-level connections */
public async disconnectAll(): Promise<void> {
logger.info('[MCP] Disconnecting all app-level and user-level connections...');
const userDisconnectPromises = Array.from(this.userConnections.keys()).map((userId) =>
this.disconnectUserConnections(userId),
);
await Promise.allSettled(userDisconnectPromises);
this.userLastActivity.clear();
logger.info('[MCP] Disconnecting all connections...');
// Disconnect all app-level connections
const appDisconnectPromises = Array.from(this.connections.values()).map((connection) =>
connection.disconnect().catch((error) => {
logger.error(`[MCP][${connection.serverName}] Error during disconnectAll:`, error);
}),
);
await Promise.allSettled(appDisconnectPromises);
const appConnections = Array.from(this.connections.values());
await Promise.allSettled(appConnections.map((connection) => connection.disconnect()));
this.connections.clear();
logger.info('[MCP] All connections processed for disconnection.');
// Disconnect all user-level connections
const userConnections = Array.from(this.userConnections.values()).flatMap((userMap) =>
Array.from(userMap.values()),
);
await Promise.allSettled(userConnections.map((connection) => connection.disconnect()));
this.userConnections.clear();
// Clear activity timestamps
this.userLastActivity.clear();
logger.info('[MCP] All connections disconnected');
}
/**
* Get connection status for a specific user and server
*/
public async getUserConnectionStatus(
userId: string,
serverName: string,
): Promise<{
connected: boolean;
hasConnection: boolean;
}> {
const userConnections = this.userConnections.get(userId);
const connection = userConnections?.get(serverName);
if (!connection) {
return {
connected: false,
hasConnection: false,
};
}
try {
const isConnected = await connection.isConnected();
return {
connected: isConnected,
hasConnection: true,
};
} catch (error) {
logger.error(
`[MCP] Error checking connection status for user ${userId}, server ${serverName}:`,
error,
);
return {
connected: false,
hasConnection: true,
};
}
}
/** Destroys the singleton instance and disconnects all connections */