// LibreChat/packages/api/src/mcp/connection.ts
//
// 743 lines
// 24 KiB
// TypeScript

import { EventEmitter } from 'events';
import { fetch as undiciFetch, Agent } from 'undici';
import {
StdioClientTransport,
getDefaultEnvironment,
} from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { logger } from '@librechat/data-schemas';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type {
RequestInit as UndiciRequestInit,
RequestInfo as UndiciRequestInfo,
Response as UndiciResponse,
} from 'undici';
import type { MCPOAuthTokens } from './oauth/types';
import { mcpConfig } from './mcpConfig';
import type * as t from './types';
/** Signature of the WHATWG-style `fetch` handed to the SSE and streamable-HTTP transports via their `fetch` option. */
type FetchLike = (input: URL | string, requestInit?: RequestInit) => Promise<Response>;
/**
 * Treats any config that carries a `command` property as a stdio
 * (spawned-process) config.
 */
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
  const hasCommand = 'command' in options;
  return hasCommand;
}
/**
 * A config selects the WebSocket transport when it has a `url` whose scheme is
 * `ws:` or `wss:`. Throws (via `new URL`) if `url` is not a valid URL.
 */
function isWebSocketOptions(options: t.MCPOptions): options is t.WebSocketOptions {
  if (!('url' in options)) {
    return false;
  }
  const { protocol } = new URL(options.url);
  return ['ws:', 'wss:'].includes(protocol);
}
/**
 * Any config with a `url` that is not a WebSocket (`ws:`/`wss:`) URL is treated
 * as SSE-capable. Callers test the streamable-HTTP guard first, so this acts as
 * the fallback for HTTP(S) URLs.
 */
function isSSEOptions(options: t.MCPOptions): options is t.SSEOptions {
  return 'url' in options ? !/^wss?:$/.test(new URL(options.url).protocol) : false;
}
/**
 * Reports whether the options describe a Streamable HTTP transport.
 *
 * Requires both a `url` and an explicit `type` of `'streamable-http'` (or its
 * alias `'http'`). A `ws:`/`wss:` URL disqualifies the options even then.
 * Streamable HTTP sends messages with HTTP POST and can stream responses back.
 *
 * @param options MCP connection options to check
 * @returns `true` when the options select the streamable HTTP transport
 */
function isStreamableHTTPOptions(options: t.MCPOptions): options is t.StreamableHTTPOptions {
  if (!('url' in options && 'type' in options)) {
    return false;
  }
  const declaredType = options.type as string;
  if (declaredType !== 'streamable-http' && declaredType !== 'http') {
    return false;
  }
  const { protocol } = new URL(options.url);
  return !(protocol === 'ws:' || protocol === 'wss:');
}
/**
 * Window in ms used by the transport `send` wrapper in `MCPConnection`. An
 * empty-result message sent within this window of the last accepted one is rejected.
 */
const FIVE_MINUTES = 1000 * 60 * 5;
/** Fallback undici body/headers timeout in ms, used when the server config has no `timeout`. */
const DEFAULT_TIMEOUT = 60 * 1000;
/** Constructor arguments for `MCPConnection`. */
interface MCPConnectionParams {
  /** Server name; used in log prefixes and in the `oauthRequired` event payload. */
  serverName: string;
  /** Optional user id; included in log prefixes and in the `oauthRequired` event payload. */
  userId?: string;
  /** Transport/connection configuration for the server. */
  serverConfig: t.MCPOptions;
  /** Previously obtained tokens; the access token is sent as a Bearer header by SSE/HTTP transports. */
  oauthTokens?: MCPOAuthTokens | null;
}
/**
 * Client-side connection to a single MCP server.
 *
 * Wraps the MCP SDK `Client` and builds the transport implied by the server
 * config (stdio, WebSocket, SSE or streamable HTTP). It tracks connection state
 * and attempts bounded reconnection with exponential backoff after a connection
 * error. It also hands OAuth off to its owner: when connecting fails with an
 * auth-looking error, it emits `oauthRequired` and waits for `oauthHandled` or
 * `oauthFailed`.
 *
 * Emits: `connectionChange` (a `t.ConnectionState`), `resourcesChanged`,
 * `oauthRequired`, `oauthError`.
 */
export class MCPConnection extends EventEmitter {
  public client: Client;
  private options: t.MCPOptions;
  /** Current transport; `null` before the first connection attempt and after `disconnect()`. */
  private transport: Transport | null = null;
  private connectionState: t.ConnectionState = 'disconnected';
  /** In-flight connection attempt, shared with concurrent `connectClient()` callers. */
  private connectPromise: Promise<void> | null = null;
  private readonly MAX_RECONNECT_ATTEMPTS = 3;
  public readonly serverName: string;
  private shouldStopReconnecting = false;
  private isReconnecting = false;
  private isInitializing = false;
  private reconnectAttempts = 0;
  private readonly userId?: string;
  /** Time (ms) of the last accepted empty-result message, or of construction. */
  private lastPingTime: number;
  /** Time (ms) of the last live check performed by `isConnected()`. */
  private lastConnectionCheckAt: number = 0;
  private oauthTokens?: MCPOAuthTokens | null;
  /** Extra headers (lower-cased names) merged into requests made through `createFetchFunction`. */
  private requestHeaders?: Record<string, string> | null;
  /** Set while an OAuth flow is pending; suppresses automatic reconnection. */
  private oauthRequired = false;
  iconPath?: string;
  timeout?: number;
  url?: string;

  /**
   * Stores extra headers for HTTP requests made by the SSE and streamable-HTTP
   * transports. Header names are lower-cased, and these headers take precedence
   * over headers the transport passes to fetch. Passing `null` is a no-op: any
   * previously stored headers are kept.
   */
  setRequestHeaders(headers: Record<string, string> | null): void {
    if (!headers) {
      return;
    }
    const normalizedHeaders: Record<string, string> = {};
    for (const [key, value] of Object.entries(headers)) {
      normalizedHeaders[key.toLowerCase()] = value;
    }
    this.requestHeaders = normalizedHeaders;
  }

  /** Returns the headers stored by `setRequestHeaders`, if any. */
  getRequestHeaders(): Record<string, string> | null | undefined {
    return this.requestHeaders;
  }

  constructor(params: MCPConnectionParams) {
    super();
    this.options = params.serverConfig;
    this.serverName = params.serverName;
    this.userId = params.userId;
    this.iconPath = params.serverConfig.iconPath;
    this.timeout = params.serverConfig.timeout;
    this.lastPingTime = Date.now();
    if (params.oauthTokens) {
      this.oauthTokens = params.oauthTokens;
    }
    this.client = new Client(
      {
        name: '@librechat/api-client',
        version: '1.2.3',
      },
      {
        capabilities: {},
      },
    );
    this.setupEventListeners();
  }

  /** Helper to generate consistent log prefixes */
  private getLogPrefix(): string {
    const userPart = this.userId ? `[User: ${this.userId}]` : '';
    return `[MCP]${userPart}[${this.serverName}]`;
  }

  /**
   * Factory function to create fetch functions without capturing the entire `this` context.
   * This helps prevent memory leaks by only passing necessary dependencies.
   *
   * All requests made through the returned fetch share one undici `Agent`
   * (connection pool). The headers from `getHeaders()` are merged over the
   * request's own headers, and win on identical keys.
   *
   * @param getHeaders Function to retrieve request headers
   * @param timeout Body/headers timeout for the agent in milliseconds; falsy values fall back to DEFAULT_TIMEOUT
   * @returns A fetch function that merges headers appropriately
   */
  private createFetchFunction(
    getHeaders: () => Record<string, string> | null | undefined,
    timeout?: number,
  ): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
    const effectiveTimeout = timeout || DEFAULT_TIMEOUT;
    // Create the dispatcher once: building a new Agent on every request allocates a
    // fresh connection pool per call that is never reused or closed.
    const agent = new Agent({
      bodyTimeout: effectiveTimeout,
      headersTimeout: effectiveTimeout,
    });
    return function customFetch(
      input: UndiciRequestInfo,
      init?: UndiciRequestInit,
    ): Promise<UndiciResponse> {
      const requestHeaders = getHeaders();
      if (!requestHeaders) {
        return undiciFetch(input, { ...init, dispatcher: agent });
      }
      let initHeaders: Record<string, string> = {};
      if (init?.headers) {
        if (init.headers instanceof Headers) {
          initHeaders = Object.fromEntries(init.headers.entries());
        } else if (Array.isArray(init.headers)) {
          initHeaders = Object.fromEntries(init.headers);
        } else {
          initHeaders = init.headers as Record<string, string>;
        }
      }
      return undiciFetch(input, {
        ...init,
        headers: {
          ...initHeaders,
          ...requestHeaders,
        },
        dispatcher: agent,
      });
    };
  }

  /** Logs `error` with the given context; `errorContext` should not end with a colon. */
  private emitError(error: unknown, errorContext: string): void {
    const errorMessage = error instanceof Error ? error.message : String(error);
    logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
  }

  /**
   * Builds the transport that matches `options`. The guards are checked in the
   * order stdio, websocket, streamable-http, sse. For URL-based transports this
   * also records `this.url`. For SSE and streamable-HTTP it adds the OAuth
   * Bearer header (when tokens exist), close logging and error handlers.
   *
   * @throws when the transport type cannot be inferred or construction fails (the error is also logged)
   */
  private constructTransport(options: t.MCPOptions): Transport {
    try {
      let type: t.MCPOptions['type'];
      if (isStdioOptions(options)) {
        type = 'stdio';
      } else if (isWebSocketOptions(options)) {
        type = 'websocket';
      } else if (isStreamableHTTPOptions(options)) {
        // Could be either 'streamable-http' or 'http', normalize to 'streamable-http'
        type = 'streamable-http';
      } else if (isSSEOptions(options)) {
        type = 'sse';
      } else {
        throw new Error(
          'Cannot infer transport type: options.type is not provided and cannot be inferred from other properties.',
        );
      }
      switch (type) {
        case 'stdio':
          if (!isStdioOptions(options)) {
            throw new Error('Invalid options for stdio transport.');
          }
          return new StdioClientTransport({
            command: options.command,
            args: options.args,
            // workaround bug of mcp sdk that can't pass env:
            // https://github.com/modelcontextprotocol/typescript-sdk/issues/216
            env: { ...getDefaultEnvironment(), ...(options.env ?? {}) },
          });
        case 'websocket':
          if (!isWebSocketOptions(options)) {
            throw new Error('Invalid options for websocket transport.');
          }
          this.url = options.url;
          return new WebSocketClientTransport(new URL(options.url));
        case 'sse': {
          if (!isSSEOptions(options)) {
            throw new Error('Invalid options for sse transport.');
          }
          this.url = options.url;
          const url = new URL(options.url);
          logger.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
          const abortController = new AbortController();
          /** Add OAuth token to headers if available */
          const headers = { ...options.headers };
          if (this.oauthTokens?.access_token) {
            headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
          }
          const timeoutValue = this.timeout || DEFAULT_TIMEOUT;
          /** One dispatcher for every event-stream request of this transport (not one per request). */
          const eventSourceAgent = new Agent({
            bodyTimeout: timeoutValue,
            headersTimeout: timeoutValue,
          });
          const transport = new SSEClientTransport(url, {
            requestInit: {
              headers,
              signal: abortController.signal,
            },
            eventSourceInit: {
              fetch: (url, init) => {
                const fetchHeaders = new Headers(Object.assign({}, init?.headers, headers));
                return undiciFetch(url, {
                  ...init,
                  dispatcher: eventSourceAgent,
                  headers: fetchHeaders,
                });
              },
            },
            fetch: this.createFetchFunction(
              this.getRequestHeaders.bind(this),
              this.timeout,
            ) as unknown as FetchLike,
          });
          transport.onclose = () => {
            logger.info(`${this.getLogPrefix()} SSE transport closed`);
            this.emit('connectionChange', 'disconnected');
          };
          // Message logging is installed by setupTransportDebugHandlers after construction.
          this.setupTransportErrorHandlers(transport);
          return transport;
        }
        case 'streamable-http': {
          if (!isStreamableHTTPOptions(options)) {
            throw new Error('Invalid options for streamable-http transport.');
          }
          this.url = options.url;
          const url = new URL(options.url);
          logger.info(
            `${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
          );
          const abortController = new AbortController();
          /** Add OAuth token to headers if available */
          const headers = { ...options.headers };
          if (this.oauthTokens?.access_token) {
            headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
          }
          const transport = new StreamableHTTPClientTransport(url, {
            requestInit: {
              headers,
              signal: abortController.signal,
            },
            fetch: this.createFetchFunction(
              this.getRequestHeaders.bind(this),
              this.timeout,
            ) as unknown as FetchLike,
          });
          transport.onclose = () => {
            logger.info(`${this.getLogPrefix()} Streamable-http transport closed`);
            this.emit('connectionChange', 'disconnected');
          };
          // Message logging is installed by setupTransportDebugHandlers after construction.
          this.setupTransportErrorHandlers(transport);
          return transport;
        }
        default: {
          throw new Error(`Unsupported transport type: ${type}`);
        }
      }
    } catch (error) {
      this.emitError(error, 'Failed to construct transport');
      throw error;
    }
  }

  /**
   * Mirrors `connectionChange` events into `connectionState`. Reaching `connected`
   * resets the reconnection bookkeeping. An `error` outside initialization and
   * outside an ongoing reconnection starts `handleReconnection`. Also subscribes
   * to resource-list change notifications.
   */
  private setupEventListeners(): void {
    this.isInitializing = true;
    this.on('connectionChange', (state: t.ConnectionState) => {
      this.connectionState = state;
      if (state === 'connected') {
        this.isReconnecting = false;
        this.isInitializing = false;
        this.shouldStopReconnecting = false;
        this.reconnectAttempts = 0;
        /**
         * // FOR DEBUGGING
         * // this.client.setRequestHandler(PingRequestSchema, async (request, extra) => {
         * // logger.info(`[MCP][${this.serverName}] PingRequest: ${JSON.stringify(request)}`);
         * // if (getEventListeners && extra.signal) {
         * // const listenerCount = getEventListeners(extra.signal, 'abort').length;
         * // logger.debug(`Signal has ${listenerCount} abort listeners`);
         * // }
         * // return {};
         * // });
         */
      } else if (state === 'error' && !this.isReconnecting && !this.isInitializing) {
        this.handleReconnection().catch((error) => {
          logger.error(`${this.getLogPrefix()} Reconnection handler failed:`, error);
        });
      }
    });
    this.subscribeToResources();
  }

  /**
   * Retries `connect()` up to MAX_RECONNECT_ATTEMPTS times with exponential
   * backoff (2s, 4s, 8s, … capped at 30s). It is skipped while already
   * reconnecting, while initializing, after a stop request, or while OAuth is pending.
   */
  private async handleReconnection(): Promise<void> {
    if (
      this.isReconnecting ||
      this.shouldStopReconnecting ||
      this.isInitializing ||
      this.oauthRequired
    ) {
      if (this.oauthRequired) {
        logger.info(`${this.getLogPrefix()} OAuth required, skipping reconnection attempts`);
      }
      return;
    }
    this.isReconnecting = true;
    const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
    try {
      while (
        this.reconnectAttempts < this.MAX_RECONNECT_ATTEMPTS &&
        !(this.shouldStopReconnecting as boolean)
      ) {
        this.reconnectAttempts++;
        const delay = backoffDelay(this.reconnectAttempts);
        logger.info(
          `${this.getLogPrefix()} Reconnecting ${this.reconnectAttempts}/${this.MAX_RECONNECT_ATTEMPTS} (delay: ${delay}ms)`,
        );
        await new Promise((resolve) => setTimeout(resolve, delay));
        try {
          await this.connect();
          this.reconnectAttempts = 0;
          return;
        } catch (error) {
          logger.error(`${this.getLogPrefix()} Reconnection attempt failed:`, error);
          if (
            this.reconnectAttempts === this.MAX_RECONNECT_ATTEMPTS ||
            (this.shouldStopReconnecting as boolean)
          ) {
            logger.error(`${this.getLogPrefix()} Stopping reconnection attempts`);
            return;
          }
        }
      }
    } finally {
      this.isReconnecting = false;
    }
  }

  /** Re-emits server `resources/list_changed` notifications as `resourcesChanged`. */
  private subscribeToResources(): void {
    this.client.setNotificationHandler(ResourceListChangedNotificationSchema, async () => {
      this.emit('resourcesChanged');
    });
  }

  /**
   * Establishes the client connection and de-duplicates concurrent calls.
   *
   * Returns immediately when already connected, and returns the in-flight
   * promise when an attempt is running. Any previous transport is closed before
   * a new one is built. The SDK connect is bounded by `initTimeout`
   * (default 120000 ms).
   *
   * If the failure looks like an OAuth error, emits `oauthRequired` and waits.
   * On `oauthHandled` it resolves WITHOUT connecting, so the caller can retry.
   * On `oauthFailed` or timeout it rethrows the original error. Any other
   * failure sets the state to `error`, emits `connectionChange` and rethrows.
   */
  async connectClient(): Promise<void> {
    if (this.connectionState === 'connected') {
      return;
    }
    if (this.connectPromise) {
      return this.connectPromise;
    }
    if (this.shouldStopReconnecting) {
      return;
    }
    this.emit('connectionChange', 'connecting');
    this.connectPromise = (async () => {
      try {
        if (this.transport) {
          try {
            await this.client.close();
            this.transport = null;
          } catch (error) {
            logger.warn(`${this.getLogPrefix()} Error closing connection:`, error);
          }
        }
        this.transport = this.constructTransport(this.options);
        this.setupTransportDebugHandlers();
        const connectTimeout = this.options.initTimeout ?? 120000;
        let connectTimer: NodeJS.Timeout | undefined;
        try {
          await Promise.race([
            this.client.connect(this.transport),
            new Promise<never>((_resolve, reject) => {
              connectTimer = setTimeout(
                () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
                connectTimeout,
              );
            }),
          ]);
        } finally {
          // Disarm the timeout once the race settles. Otherwise the timer stays
          // pending for up to `connectTimeout` after every successful connect.
          clearTimeout(connectTimer);
        }
        this.connectionState = 'connected';
        this.emit('connectionChange', 'connected');
        this.reconnectAttempts = 0;
      } catch (error) {
        // Check if it's an OAuth authentication error
        if (this.isOAuthError(error)) {
          logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
          this.oauthRequired = true;
          const serverUrl = this.url;
          logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
          const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
          /** Promise that will resolve when OAuth is handled */
          const oauthHandledPromise = new Promise<void>((resolve, reject) => {
            let timeoutId: NodeJS.Timeout | null = null;
            let oauthHandledListener: (() => void) | null = null;
            let oauthFailedListener: ((error: Error) => void) | null = null;
            /** Cleanup function to remove listeners and clear timeout */
            const cleanup = () => {
              if (timeoutId) {
                clearTimeout(timeoutId);
              }
              if (oauthHandledListener) {
                this.off('oauthHandled', oauthHandledListener);
              }
              if (oauthFailedListener) {
                this.off('oauthFailed', oauthFailedListener);
              }
            };
            // Success handler
            oauthHandledListener = () => {
              cleanup();
              resolve();
            };
            // Failure handler
            oauthFailedListener = (error: Error) => {
              cleanup();
              reject(error);
            };
            // Timeout handler
            timeoutId = setTimeout(() => {
              cleanup();
              reject(new Error(`OAuth handling timeout after ${oauthTimeout}ms`));
            }, oauthTimeout);
            // Listen for both success and failure events
            this.once('oauthHandled', oauthHandledListener);
            this.once('oauthFailed', oauthFailedListener);
          });
          // Emit the event
          this.emit('oauthRequired', {
            serverName: this.serverName,
            error,
            serverUrl,
            userId: this.userId,
          });
          try {
            // Wait for OAuth to be handled
            await oauthHandledPromise;
            // Reset the oauthRequired flag
            this.oauthRequired = false;
            // Don't throw the error - just return so connection can be retried
            logger.info(
              `${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
            );
            return;
          } catch (oauthError) {
            // OAuth failed or timed out
            this.oauthRequired = false;
            logger.error(`${this.getLogPrefix()} OAuth handling failed:`, oauthError);
            // Re-throw the original authentication error
            throw error;
          }
        }
        this.connectionState = 'error';
        this.emit('connectionChange', 'error');
        throw error;
      } finally {
        this.connectPromise = null;
      }
    })();
    return this.connectPromise;
  }

  /**
   * Installs debug logging of incoming messages and wraps `send` to log
   * outgoing ones. An empty-result message (a `result` with no keys and no
   * `method`) is rejected with 'Empty result' when sent within FIVE_MINUTES of
   * the last accepted one. NOTE(review): presumably this throttles ping
   * replies; confirm the intent.
   */
  private setupTransportDebugHandlers(): void {
    if (!this.transport) {
      return;
    }
    this.transport.onmessage = (msg: JSONRPCMessage) => {
      logger.debug(`${this.getLogPrefix()} Transport received: ${JSON.stringify(msg)}`);
    };
    const originalSend = this.transport.send.bind(this.transport);
    this.transport.send = async (msg) => {
      if ('result' in msg && !('method' in msg) && Object.keys(msg.result ?? {}).length === 0) {
        if (Date.now() - this.lastPingTime < FIVE_MINUTES) {
          throw new Error('Empty result');
        }
        this.lastPingTime = Date.now();
      }
      logger.debug(`${this.getLogPrefix()} Transport sending: ${JSON.stringify(msg)}`);
      return originalSend(msg);
    };
  }

  /**
   * Disconnects any existing session, connects again and verifies liveness with
   * `isConnected()`.
   *
   * @throws when connecting fails or the connection cannot be verified
   */
  async connect(): Promise<void> {
    try {
      await this.disconnect();
      await this.connectClient();
      if (!(await this.isConnected())) {
        throw new Error('Connection not established');
      }
    } catch (error) {
      logger.error(`${this.getLogPrefix()} Connection failed:`, error);
      throw error;
    }
  }

  /**
   * Logs transport errors and emits `connectionChange: 'error'`. Errors whose
   * `code` is 401/403 also emit `oauthError`.
   */
  private setupTransportErrorHandlers(transport: Transport): void {
    transport.onerror = (error) => {
      logger.error(`${this.getLogPrefix()} Transport error:`, error);
      // Check if it's an OAuth authentication error
      if (error && typeof error === 'object' && 'code' in error) {
        const errorCode = (error as unknown as { code?: number }).code;
        if (errorCode === 401 || errorCode === 403) {
          logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
          this.emit('oauthError', error);
        }
      }
      this.emit('connectionChange', 'error');
    };
  }

  /**
   * Closes the client (when a transport exists) and moves to `disconnected`,
   * emitting `connectionChange` only on an actual state change. Always clears
   * the in-flight connect promise.
   */
  public async disconnect(): Promise<void> {
    try {
      if (this.transport) {
        await this.client.close();
        this.transport = null;
      }
      if (this.connectionState === 'disconnected') {
        return;
      }
      this.connectionState = 'disconnected';
      this.emit('connectionChange', 'disconnected');
    } finally {
      this.connectPromise = null;
    }
  }

  /** Lists server resources; logs and returns `[]` on failure. */
  async fetchResources(): Promise<t.MCPResource[]> {
    try {
      const { resources } = await this.client.listResources();
      return resources;
    } catch (error) {
      this.emitError(error, 'Failed to fetch resources');
      return [];
    }
  }

  /** Lists server tools; logs and returns `[]` on failure. */
  async fetchTools() {
    try {
      const { tools } = await this.client.listTools();
      return tools;
    } catch (error) {
      this.emitError(error, 'Failed to fetch tools');
      return [];
    }
  }

  /** Lists server prompts; logs and returns `[]` on failure. */
  async fetchPrompts(): Promise<t.MCPPrompt[]> {
    try {
      const { prompts } = await this.client.listPrompts();
      return prompts;
    } catch (error) {
      this.emitError(error, 'Failed to fetch prompts');
      return [];
    }
  }

  /**
   * Reports whether the connection is live.
   *
   * The state must be `connected`. A live check runs at most once per
   * `mcpConfig.CONNECTION_CHECK_TTL`. The check is a ping; if the server does
   * not implement ping, it calls a list method for an advertised capability
   * instead.
   */
  public async isConnected(): Promise<boolean> {
    // First check if we're in a connected state
    if (this.connectionState !== 'connected') {
      return false;
    }
    // If we recently checked, skip expensive verification
    const now = Date.now();
    if (now - this.lastConnectionCheckAt < mcpConfig.CONNECTION_CHECK_TTL) {
      return true;
    }
    this.lastConnectionCheckAt = now;
    try {
      // Try ping first as it's the lightest check
      await this.client.ping();
      return this.connectionState === 'connected';
    } catch (error) {
      // Detect "ping not supported" (-32601 is the JSON-RPC "method not found" code).
      // The comparison is case-insensitive so that e.g. "Method not found" also matches.
      const message = error instanceof Error ? error.message.toLowerCase() : '';
      const pingUnsupported =
        message.includes('-32601') ||
        message.includes('invalid method ping') ||
        message.includes('method not found');
      if (!pingUnsupported) {
        logger.error(`${this.getLogPrefix()} Ping failed:`, error);
        return false;
      }
      // Ping is not supported by this server, try an alternative verification
      logger.debug(
        `${this.getLogPrefix()} Server does not support ping method, verifying connection with capabilities`,
      );
      try {
        // Get server capabilities to verify connection is truly active
        const capabilities = this.client.getServerCapabilities();
        // If we have capabilities, try calling a supported method to verify connection
        if (capabilities?.tools) {
          await this.client.listTools();
          return this.connectionState === 'connected';
        } else if (capabilities?.resources) {
          await this.client.listResources();
          return this.connectionState === 'connected';
        } else if (capabilities?.prompts) {
          await this.client.listPrompts();
          return this.connectionState === 'connected';
        } else {
          // No capabilities to test, but we're in connected state and initialization succeeded
          logger.debug(
            `${this.getLogPrefix()} No capabilities to test, assuming connected based on state`,
          );
          return this.connectionState === 'connected';
        }
      } catch (capabilityError) {
        // If capability check fails, the connection is likely broken
        logger.error(`${this.getLogPrefix()} Connection verification failed:`, capabilityError);
        return false;
      }
    }
  }

  /** Replaces the stored OAuth tokens; they are used when the next transport is built. */
  public setOAuthTokens(tokens: MCPOAuthTokens): void {
    this.oauthTokens = tokens;
  }

  /**
   * Heuristically decides whether a connection failure requires OAuth: either
   * the message mentions a 401 (e.g. SSE "Non-200 status code (401)"), or the
   * error carries a numeric `code` of 401 or 403. The message check no longer
   * short-circuits the `code` check, so an `Error` with `code: 403` is
   * recognised, matching `setupTransportErrorHandlers`.
   */
  private isOAuthError(error: unknown): boolean {
    if (!error || typeof error !== 'object') {
      return false;
    }
    // Status embedded in the message, e.g. SSE "Non-200 status code (401)"
    if ('message' in error && typeof error.message === 'string' && error.message.includes('401')) {
      return true;
    }
    // Explicit status code on the error object
    if ('code' in error) {
      const code = (error as { code?: number }).code;
      return code === 401 || code === 403;
    }
    return false;
  }
}