mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-22 19:30:15 +01:00
🔌 feat: MCP OAuth Integration in Chat UI
- **Real-Time Connection Status**: New backend APIs and React Query hooks provide live MCP server connection monitoring with automatic UI updates - **OAuth Flow Components**: Complete MCPConfigDialog, ServerInitializationSection, and CustomUserVarsSection with OAuth URL handling and polling-based completion - **Enhanced Server Selection**: MCPSelect component with connection-aware filtering, visual status indicators, and better credential management UX (still needs a lot of refinement since there is bloat/unused vars and functions leftover from the ideation phase on how to approach OAuth and connection statuses)
This commit is contained in:
parent
b39b60c012
commit
63140237a6
27 changed files with 1760 additions and 286 deletions
|
|
@ -69,6 +69,7 @@ export class MCPConnection extends EventEmitter {
|
|||
private lastPingTime: number;
|
||||
private oauthTokens?: MCPOAuthTokens | null;
|
||||
private oauthRequired = false;
|
||||
private oauthTimeoutId: NodeJS.Timeout | null = null;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
url?: string;
|
||||
|
|
@ -421,6 +422,7 @@ export class MCPConnection extends EventEmitter {
|
|||
const cleanup = () => {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
this.oauthTimeoutId = null;
|
||||
}
|
||||
if (oauthHandledListener) {
|
||||
this.off('oauthHandled', oauthHandledListener);
|
||||
|
|
@ -448,11 +450,26 @@ export class MCPConnection extends EventEmitter {
|
|||
reject(new Error(`OAuth handling timeout after ${oauthTimeout}ms`));
|
||||
}, oauthTimeout);
|
||||
|
||||
// Store the timeout ID for potential cancellation
|
||||
this.oauthTimeoutId = timeoutId;
|
||||
|
||||
// Listen for both success and failure events
|
||||
this.once('oauthHandled', oauthHandledListener);
|
||||
this.once('oauthFailed', oauthFailedListener);
|
||||
});
|
||||
|
||||
// Check if there are any listeners for oauthRequired event
|
||||
const hasOAuthListeners = this.listenerCount('oauthRequired') > 0;
|
||||
|
||||
if (!hasOAuthListeners) {
|
||||
// No OAuth handler available (like during startup), immediately fail
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} OAuth required but no handler available, failing immediately`,
|
||||
);
|
||||
this.oauthRequired = false;
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Emit the event
|
||||
this.emit('oauthRequired', {
|
||||
serverName: this.serverName,
|
||||
|
|
@ -517,7 +534,7 @@ export class MCPConnection extends EventEmitter {
|
|||
try {
|
||||
await this.disconnect();
|
||||
await this.connectClient();
|
||||
if (!(await this.isConnected())) {
|
||||
if (!(await this.isConnected()) && !(this.isInitializing && this.oauthTokens)) {
|
||||
throw new Error('Connection not established');
|
||||
}
|
||||
} catch (error) {
|
||||
|
|
@ -545,6 +562,37 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
public async disconnect(): Promise<void> {
|
||||
try {
|
||||
// Cancel any pending OAuth timeout
|
||||
if (this.oauthTimeoutId) {
|
||||
clearTimeout(this.oauthTimeoutId);
|
||||
this.oauthTimeoutId = null;
|
||||
}
|
||||
|
||||
if (this.transport) {
|
||||
await this.client.close();
|
||||
this.transport = null;
|
||||
}
|
||||
if (this.connectionState === 'disconnected') {
|
||||
return;
|
||||
}
|
||||
this.connectionState = 'disconnected';
|
||||
this.emit('connectionChange', 'disconnected');
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
}
|
||||
}
|
||||
|
||||
public async disconnectAndStopReconnecting(): Promise<void> {
|
||||
try {
|
||||
// Stop any reconnection attempts
|
||||
this.shouldStopReconnecting = true;
|
||||
|
||||
// Cancel any pending OAuth timeout
|
||||
if (this.oauthTimeoutId) {
|
||||
clearTimeout(this.oauthTimeoutId);
|
||||
this.oauthTimeoutId = null;
|
||||
}
|
||||
|
||||
if (this.transport) {
|
||||
await this.client.close();
|
||||
this.transport = null;
|
||||
|
|
@ -650,6 +698,11 @@ export class MCPConnection extends EventEmitter {
|
|||
this.oauthTokens = tokens;
|
||||
}
|
||||
|
||||
/** Get the current connection state */
|
||||
public getConnectionState(): t.ConnectionState {
|
||||
return this.connectionState;
|
||||
}
|
||||
|
||||
private isOAuthError(error: unknown): boolean {
|
||||
if (!error || typeof error !== 'object') {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ import { MCPTokenStorage } from './oauth/tokens';
|
|||
import { formatToolContent } from './parsers';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils/env';
|
||||
import { EventEmitter } from 'events';
|
||||
|
||||
export class MCPManager {
|
||||
export class MCPManager extends EventEmitter {
|
||||
private static instance: MCPManager | null = null;
|
||||
/** App-level connections initialized at startup */
|
||||
private connections: Map<string, MCPConnection> = new Map();
|
||||
|
|
@ -29,6 +30,10 @@ export class MCPManager {
|
|||
/** Store MCP server instructions */
|
||||
private serverInstructions: Map<string, string> = new Map();
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
}
|
||||
|
||||
public static getInstance(): MCPManager {
|
||||
if (!MCPManager.instance) {
|
||||
MCPManager.instance = new MCPManager();
|
||||
|
|
@ -47,7 +52,7 @@ export class MCPManager {
|
|||
mcpServers: t.MCPServers;
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||
tokenMethods?: TokenMethods;
|
||||
}): Promise<void> {
|
||||
}): Promise<Record<string, boolean>> {
|
||||
this.mcpConfigs = mcpServers;
|
||||
|
||||
if (!flowManager) {
|
||||
|
|
@ -59,6 +64,7 @@ export class MCPManager {
|
|||
}
|
||||
const entries = Object.entries(mcpServers);
|
||||
const initializedServers = new Set();
|
||||
const oauthSkippedServers = new Set();
|
||||
const connectionResults = await Promise.allSettled(
|
||||
entries.map(async ([serverName, config], i) => {
|
||||
try {
|
||||
|
|
@ -70,19 +76,46 @@ export class MCPManager {
|
|||
});
|
||||
initializedServers.add(i);
|
||||
} catch (error) {
|
||||
logger.error(`[MCP][${serverName}] Initialization failed`, error);
|
||||
// Check if this is an OAuth skipped error
|
||||
if (
|
||||
error instanceof Error &&
|
||||
(error as Error & { isOAuthSkipped?: boolean }).isOAuthSkipped
|
||||
) {
|
||||
oauthSkippedServers.add(i);
|
||||
} else {
|
||||
logger.error(`[MCP][${serverName}] Initialization failed`, error);
|
||||
// Debug: Log the actual error for filesystem server
|
||||
if (serverName === 'filesystem') {
|
||||
logger.error(`[MCP][${serverName}] Error details:`, {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
isOAuthError: this.isOAuthError(error),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
const failedConnections = connectionResults.filter(
|
||||
(result): result is PromiseRejectedResult => result.status === 'rejected',
|
||||
(result): result is PromiseRejectedResult =>
|
||||
result.status === 'rejected' &&
|
||||
!(
|
||||
result.reason instanceof Error &&
|
||||
(result.reason as Error & { isOAuthSkipped?: boolean }).isOAuthSkipped
|
||||
),
|
||||
);
|
||||
|
||||
logger.info(
|
||||
`[MCP] Initialized ${initializedServers.size}/${entries.length} app-level server(s)`,
|
||||
);
|
||||
|
||||
if (oauthSkippedServers.size > 0) {
|
||||
logger.info(
|
||||
`[MCP] ${oauthSkippedServers.size}/${entries.length} app-level server(s) skipped for OAuth`,
|
||||
);
|
||||
}
|
||||
|
||||
if (failedConnections.length > 0) {
|
||||
logger.warn(
|
||||
`[MCP] ${failedConnections.length}/${entries.length} app-level server(s) failed to initialize`,
|
||||
|
|
@ -92,6 +125,8 @@ export class MCPManager {
|
|||
entries.forEach(([serverName], index) => {
|
||||
if (initializedServers.has(index)) {
|
||||
logger.info(`[MCP][${serverName}] ✓ Initialized`);
|
||||
} else if (oauthSkippedServers.has(index)) {
|
||||
logger.info(`[MCP][${serverName}] OAuth Required`);
|
||||
} else {
|
||||
logger.info(`[MCP][${serverName}] ✗ Failed`);
|
||||
}
|
||||
|
|
@ -99,9 +134,16 @@ export class MCPManager {
|
|||
|
||||
if (initializedServers.size === entries.length) {
|
||||
logger.info('[MCP] All app-level servers initialized successfully');
|
||||
} else if (initializedServers.size === 0) {
|
||||
} else if (initializedServers.size === 0 && oauthSkippedServers.size === 0) {
|
||||
logger.warn('[MCP] No app-level servers initialized');
|
||||
}
|
||||
|
||||
// Return OAuth requirement map
|
||||
const oauthRequirementMap: Record<string, boolean> = {};
|
||||
entries.forEach(([serverName], index) => {
|
||||
oauthRequirementMap[serverName] = oauthSkippedServers.has(index);
|
||||
});
|
||||
return oauthRequirementMap;
|
||||
}
|
||||
|
||||
/** Initializes a single MCP server connection (app-level) */
|
||||
|
|
@ -166,40 +208,18 @@ export class MCPManager {
|
|||
logger.info(`[MCP][${serverName}] Loaded OAuth tokens`);
|
||||
}
|
||||
const connection = new MCPConnection(serverName, processedConfig, undefined, tokens);
|
||||
logger.info(`[MCP][${serverName}] Setting up OAuth event listener`);
|
||||
connection.on('oauthRequired', async (data) => {
|
||||
logger.debug(`[MCP][${serverName}] oauthRequired event received`);
|
||||
const result = await this.handleOAuthRequired({
|
||||
...data,
|
||||
flowManager,
|
||||
});
|
||||
if (result?.tokens && tokenMethods?.createToken) {
|
||||
try {
|
||||
connection.setOAuthTokens(result.tokens);
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: CONSTANTS.SYSTEM_USER_ID,
|
||||
serverName,
|
||||
tokens: result.tokens,
|
||||
createToken: tokenMethods.createToken,
|
||||
updateToken: tokenMethods.updateToken,
|
||||
findToken: tokenMethods.findToken,
|
||||
clientInfo: result.clientInfo,
|
||||
});
|
||||
logger.info(`[MCP][${serverName}] OAuth tokens saved to storage`);
|
||||
} catch (error) {
|
||||
logger.error(`[MCP][${serverName}] Failed to save OAuth tokens to storage`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit oauthHandled if we actually got tokens (OAuth succeeded)
|
||||
if (result?.tokens) {
|
||||
connection.emit('oauthHandled');
|
||||
} else {
|
||||
// OAuth failed, emit oauthFailed to properly reject the promise
|
||||
logger.warn(`[MCP][${serverName}] OAuth failed, emitting oauthFailed event`);
|
||||
connection.emit('oauthFailed', new Error('OAuth authentication failed'));
|
||||
}
|
||||
// Track OAuth skipped state explicitly
|
||||
let oauthSkipped = false;
|
||||
|
||||
connection.on('oauthRequired', async () => {
|
||||
logger.debug(`[MCP][${serverName}] oauthRequired event received`);
|
||||
oauthSkipped = true;
|
||||
// Emit event to signal that initialization should be skipped
|
||||
connection.emit('oauthSkipped');
|
||||
return;
|
||||
});
|
||||
|
||||
try {
|
||||
const connectTimeout = processedConfig.initTimeout ?? 30000;
|
||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||
|
|
@ -208,13 +228,35 @@ export class MCPManager {
|
|||
connectTimeout,
|
||||
),
|
||||
);
|
||||
|
||||
// Listen for oauthSkipped event to stop initialization
|
||||
const oauthSkippedPromise = new Promise<void>((resolve) => {
|
||||
connection.once('oauthSkipped', () => {
|
||||
logger.debug(`[MCP][${serverName}] OAuth skipped, stopping initialization`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
const connectionAttempt = this.initializeServer({
|
||||
connection,
|
||||
logPrefix: `[MCP][${serverName}]`,
|
||||
flowManager,
|
||||
handleOAuth: false,
|
||||
});
|
||||
await Promise.race([connectionAttempt, connectionTimeout]);
|
||||
|
||||
// Race between connection attempt, timeout, and oauthSkipped
|
||||
await Promise.race([connectionAttempt, connectionTimeout, oauthSkippedPromise]);
|
||||
|
||||
// Check if OAuth was explicitly skipped
|
||||
if (oauthSkipped) {
|
||||
// Throw a special error to signal OAuth was skipped
|
||||
const oauthSkippedError = new Error(`OAuth required for ${serverName}`) as Error & {
|
||||
isOAuthSkipped: boolean;
|
||||
};
|
||||
oauthSkippedError.isOAuthSkipped = true;
|
||||
throw oauthSkippedError;
|
||||
}
|
||||
|
||||
if (await connection.isConnected()) {
|
||||
this.connections.set(serverName, connection);
|
||||
|
||||
|
|
@ -269,6 +311,17 @@ export class MCPManager {
|
|||
logger.info(`[MCP][${serverName}] ✗ Failed`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Debug: Log the actual error for filesystem server
|
||||
if (serverName === 'filesystem') {
|
||||
logger.error(`[MCP][${serverName}] Error details:`, {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
isOAuthError: this.isOAuthError(error),
|
||||
errorType: error?.constructor?.name,
|
||||
errorKeys: error && typeof error === 'object' ? Object.keys(error) : [],
|
||||
oauthSkipped,
|
||||
});
|
||||
}
|
||||
logger.error(`[MCP][${serverName}] Initialization failed`, error);
|
||||
throw error;
|
||||
}
|
||||
|
|
@ -340,6 +393,21 @@ export class MCPManager {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Debug: Log error details for filesystem server
|
||||
if (error && typeof error === 'object' && 'message' in error) {
|
||||
const errorMessage = (error as { message?: string }).message;
|
||||
if (errorMessage && errorMessage.includes('filesystem')) {
|
||||
logger.debug('[MCP] isOAuthError check for filesystem:', {
|
||||
message: errorMessage,
|
||||
hasCode: 'code' in error,
|
||||
code: (error as { code?: number }).code,
|
||||
includes401: errorMessage.includes('401'),
|
||||
includes403: errorMessage.includes('403'),
|
||||
includesNon200: errorMessage.includes('Non-200 status code (401)'),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check for SSE error with 401 status
|
||||
if ('message' in error && typeof error.message === 'string') {
|
||||
return error.message.includes('401') || error.message.includes('Non-200 status code (401)');
|
||||
|
|
@ -578,6 +646,10 @@ export class MCPManager {
|
|||
this.userConnections.get(userId)?.set(serverName, connection);
|
||||
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
|
||||
|
||||
// Emit event that connection is established for waiting endpoints
|
||||
this.emit('connectionEstablished', { userId, serverName, connection });
|
||||
|
||||
// Update timestamp on creation
|
||||
this.updateUserLastActivity(userId);
|
||||
return connection;
|
||||
|
|
@ -618,7 +690,7 @@ export class MCPManager {
|
|||
const connection = userMap?.get(serverName);
|
||||
if (connection) {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
|
||||
await connection.disconnect();
|
||||
await connection.disconnectAndStopReconnecting();
|
||||
this.removeUserConnection(userId, serverName);
|
||||
}
|
||||
}
|
||||
|
|
@ -657,6 +729,12 @@ export class MCPManager {
|
|||
return this.connections;
|
||||
}
|
||||
|
||||
/** Returns the user-level connection if it exists (does not create one) */
|
||||
public getUserConnectionIfExists(userId: string, serverName: string): MCPConnection | undefined {
|
||||
const userMap = this.userConnections.get(userId);
|
||||
return userMap?.get(serverName);
|
||||
}
|
||||
|
||||
/** Attempts to reconnect an app-level connection if it's disconnected */
|
||||
private async isConnectionActive({
|
||||
serverName,
|
||||
|
|
@ -928,24 +1006,62 @@ export class MCPManager {
|
|||
|
||||
/** Disconnects all app-level and user-level connections */
|
||||
public async disconnectAll(): Promise<void> {
|
||||
logger.info('[MCP] Disconnecting all app-level and user-level connections...');
|
||||
|
||||
const userDisconnectPromises = Array.from(this.userConnections.keys()).map((userId) =>
|
||||
this.disconnectUserConnections(userId),
|
||||
);
|
||||
await Promise.allSettled(userDisconnectPromises);
|
||||
this.userLastActivity.clear();
|
||||
logger.info('[MCP] Disconnecting all connections...');
|
||||
|
||||
// Disconnect all app-level connections
|
||||
const appDisconnectPromises = Array.from(this.connections.values()).map((connection) =>
|
||||
connection.disconnect().catch((error) => {
|
||||
logger.error(`[MCP][${connection.serverName}] Error during disconnectAll:`, error);
|
||||
}),
|
||||
);
|
||||
await Promise.allSettled(appDisconnectPromises);
|
||||
const appConnections = Array.from(this.connections.values());
|
||||
await Promise.allSettled(appConnections.map((connection) => connection.disconnect()));
|
||||
this.connections.clear();
|
||||
|
||||
logger.info('[MCP] All connections processed for disconnection.');
|
||||
// Disconnect all user-level connections
|
||||
const userConnections = Array.from(this.userConnections.values()).flatMap((userMap) =>
|
||||
Array.from(userMap.values()),
|
||||
);
|
||||
await Promise.allSettled(userConnections.map((connection) => connection.disconnect()));
|
||||
this.userConnections.clear();
|
||||
|
||||
// Clear activity timestamps
|
||||
this.userLastActivity.clear();
|
||||
|
||||
logger.info('[MCP] All connections disconnected');
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection status for a specific user and server
|
||||
*/
|
||||
public async getUserConnectionStatus(
|
||||
userId: string,
|
||||
serverName: string,
|
||||
): Promise<{
|
||||
connected: boolean;
|
||||
hasConnection: boolean;
|
||||
}> {
|
||||
const userConnections = this.userConnections.get(userId);
|
||||
const connection = userConnections?.get(serverName);
|
||||
|
||||
if (!connection) {
|
||||
return {
|
||||
connected: false,
|
||||
hasConnection: false,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const isConnected = await connection.isConnected();
|
||||
return {
|
||||
connected: isConnected,
|
||||
hasConnection: true,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[MCP] Error checking connection status for user ${userId}, server ${serverName}:`,
|
||||
error,
|
||||
);
|
||||
return {
|
||||
connected: false,
|
||||
hasConnection: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/** Destroys the singleton instance and disconnects all connections */
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue