mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-28 05:06:13 +01:00
🧠 feat: User Memories for Conversational Context (#7760)
* 🧠 feat: User Memories for Conversational Context
chore: mcp typing, use `t`
WIP: first pass, Memories UI
- Added MemoryViewer component for displaying, editing, and deleting user memories.
- Integrated data provider hooks for fetching, updating, and deleting memories.
- Implemented pagination and loading states for better user experience.
- Created unit tests for MemoryViewer to ensure functionality and interaction with data provider.
- Updated translation files to include new UI strings related to memories.
chore: move mcp-related files to own directory
chore: rename librechat-mcp to librechat-api
WIP: first pass, memory processing and data schemas
chore: linting in fileSearch.js query description
chore: rename librechat-api to @librechat/api across the project
WIP: first pass, functional memory agent
feat: add MemoryEditDialog and MemoryViewer components for managing user memories
- Introduced MemoryEditDialog for editing memory entries with validation and toast notifications.
- Updated MemoryViewer to support editing and deleting memories, including pagination and loading states.
- Enhanced data provider to handle memory updates with optional original key for better management.
- Added new localization strings for memory-related UI elements.
feat: add memory permissions management
- Implemented memory permissions in the backend, allowing roles to have specific permissions for using, creating, updating, and reading memories.
- Added new API endpoints for updating memory permissions associated with roles.
- Created a new AdminSettings component for managing memory permissions in the frontend.
- Integrated memory permissions into the existing roles and permissions schemas.
- Updated the interface to include memory settings and permissions.
- Enhanced the MemoryViewer component to conditionally render admin settings based on user roles.
- Added localization support for memory permissions in the translation files.
feat: move AdminSettings component to a new position in MemoryViewer for better visibility
refactor: clean up commented code in MemoryViewer component
feat: enhance MemoryViewer with search functionality and improve MemoryEditDialog integration
- Added a search input to filter memories in the MemoryViewer component.
- Refactored MemoryEditDialog to accept children for better customization.
- Updated MemoryViewer to utilize the new EditMemoryButton and DeleteMemoryButton components for editing and deleting memories.
- Improved localization support by adding new strings for memory filtering and deletion confirmation.
refactor: optimize memory filtering in MemoryViewer using match-sorter
- Replaced manual filtering logic with match-sorter for improved search functionality.
- Enhanced performance and readability of the filteredMemories computation.
feat: enhance MemoryEditDialog with triggerRef and improve updateMemory mutation handling
feat: implement access control for MemoryEditDialog and MemoryViewer components
refactor: remove commented out code and create runMemory method
refactor: rename role based files
feat: implement access control for memory usage in AgentClient
refactor: simplify checkVisionRequest method in AgentClient by removing commented-out code
refactor: make `agents` dir in api package
refactor: migrate Azure utilities to TypeScript and consolidate imports
refactor: move sanitizeFilename function to a new file and update imports, add related tests
refactor: update LLM configuration types and consolidate Azure options in the API package
chore: linting
chore: import order
refactor: replace getLLMConfig with getOpenAIConfig and remove unused LLM configuration file
chore: update winston-daily-rotate-file to version 5.0.0 and add object-hash dependency in package-lock.json
refactor: move primeResources and optionalChainWithEmptyCheck functions to resources.ts and update imports
refactor: move createRun function to a new run.ts file and update related imports
fix: ensure safeAttachments is correctly typed as an array of TFile
chore: add node-fetch dependency and refactor fetch-related functions into packages/api/utils, removing the old generators file
refactor: enhance TEndpointOption type by using Pick to streamline endpoint fields and add new properties for model parameters and client options
feat: implement initializeOpenAIOptions function and update OpenAI types for enhanced configuration handling
fix: update types due to new TEndpointOption typing
fix: ensure safe access to group parameters in initializeOpenAIOptions function
fix: remove redundant API key validation comment in initializeOpenAIOptions function
refactor: rename initializeOpenAIOptions to initializeOpenAI for consistency and update related documentation
refactor: decouple req.body fields and tool loading from initializeAgentOptions
chore: linting
refactor: adjust column widths in MemoryViewer for improved layout
refactor: simplify agent initialization by creating loadAgent function and removing unused code
feat: add memory configuration loading and validation functions
WIP: first pass, memory processing with config
feat: implement memory callback and artifact handling
feat: implement memory artifacts display and processing updates
feat: add memory configuration options and schema validation for validKeys
fix: update MemoryEditDialog and MemoryViewer to handle memory state and display improvements
refactor: remove padding from BookmarkTable and MemoryViewer headers for consistent styling
WIP: initial tokenLimit config and move Tokenizer to @librechat/api
refactor: update mongoMeili plugin methods to use callback for better error handling
feat: enhance memory management with token tracking and usage metrics
- Added token counting for memory entries to enforce limits and provide usage statistics.
- Updated memory retrieval and update routes to include total token usage and limit.
- Enhanced MemoryEditDialog and MemoryViewer components to display memory usage and token information.
- Refactored memory processing functions to handle token limits and provide feedback on memory capacity.
feat: implement memory artifact handling in attachment handler
- Enhanced useAttachmentHandler to process memory artifacts when receiving updates.
- Introduced handleMemoryArtifact utility to manage memory updates and deletions.
- Updated query client to reflect changes in memory state based on incoming data.
refactor: restructure web search key extraction logic
- Moved the logic for extracting API keys from the webSearchAuth configuration into a dedicated function, getWebSearchKeys.
- Updated webSearchKeys to utilize the new function for improved clarity and maintainability.
- Prevents build time errors
feat: add personalization settings and memory preferences management
- Introduced a new Personalization tab in settings to manage user memory preferences.
- Implemented API endpoints and client-side logic for updating memory preferences.
- Enhanced user interface components to reflect personalization options and memory usage.
- Updated permissions to allow users to opt out of memory features.
- Added localization support for new settings and messages related to personalization.
style: personalization switch class
feat: add PersonalizationIcon and align Side Panel UI
feat: implement memory creation functionality
- Added a new API endpoint for creating memory entries, including validation for key and value.
- Introduced MemoryCreateDialog component for user interface to facilitate memory creation.
- Integrated token limit checks to prevent exceeding user memory capacity.
- Updated MemoryViewer to include a button for opening the memory creation dialog.
- Enhanced localization support for new messages related to memory creation.
feat: enhance message processing with configurable window size
- Updated AgentClient to use a configurable message window size for processing messages.
- Introduced messageWindowSize option in memory configuration schema with a default value of 5.
- Improved logic for selecting messages to process based on the configured window size.
chore: update librechat-data-provider version to 0.7.87 in package.json and package-lock.json
chore: remove OpenAPIPlugin and its associated tests
chore: remove MIGRATION_README.md as migration tasks are completed
ci: fix backend tests
chore: remove unused translation keys from localization file
chore: remove problematic test file and unused var in AgentClient
chore: remove unused import and import directly for JSDoc
* feat: add api package build stage in Dockerfile for improved modularity
* docs: reorder build steps in contributing guide for clarity
This commit is contained in:
parent
cd7dd576c1
commit
29ef91b4dd
170 changed files with 5700 additions and 3632 deletions
583
packages/api/src/mcp/connection.ts
Normal file
583
packages/api/src/mcp/connection.ts
Normal file
|
|
@ -0,0 +1,583 @@
|
|||
import { EventEmitter } from 'events';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import {
|
||||
StdioClientTransport,
|
||||
getDefaultEnvironment,
|
||||
} from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { Logger } from 'winston';
|
||||
import type * as t from './types';
|
||||
|
||||
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
|
||||
return 'command' in options;
|
||||
}
|
||||
|
||||
function isWebSocketOptions(options: t.MCPOptions): options is t.WebSocketOptions {
|
||||
if ('url' in options) {
|
||||
const protocol = new URL(options.url).protocol;
|
||||
return protocol === 'ws:' || protocol === 'wss:';
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function isSSEOptions(options: t.MCPOptions): options is t.SSEOptions {
|
||||
if ('url' in options) {
|
||||
const protocol = new URL(options.url).protocol;
|
||||
return protocol !== 'ws:' && protocol !== 'wss:';
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the provided options are for a Streamable HTTP transport.
|
||||
*
|
||||
* Streamable HTTP is an MCP transport that uses HTTP POST for sending messages
|
||||
* and supports streaming responses. It provides better performance than
|
||||
* SSE transport while maintaining compatibility with most network environments.
|
||||
*
|
||||
* @param options MCP connection options to check
|
||||
* @returns True if options are for a streamable HTTP transport
|
||||
*/
|
||||
function isStreamableHTTPOptions(options: t.MCPOptions): options is t.StreamableHTTPOptions {
|
||||
if ('url' in options && options.type === 'streamable-http') {
|
||||
const protocol = new URL(options.url).protocol;
|
||||
return protocol !== 'ws:' && protocol !== 'wss:';
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const FIVE_MINUTES = 5 * 60 * 1000;
|
||||
export class MCPConnection extends EventEmitter {
|
||||
private static instance: MCPConnection | null = null;
|
||||
public client: Client;
|
||||
private transport: Transport | null = null; // Make this nullable
|
||||
private connectionState: t.ConnectionState = 'disconnected';
|
||||
private connectPromise: Promise<void> | null = null;
|
||||
private lastError: Error | null = null;
|
||||
private lastConfigUpdate = 0;
|
||||
private readonly CONFIG_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
private readonly MAX_RECONNECT_ATTEMPTS = 3;
|
||||
public readonly serverName: string;
|
||||
private shouldStopReconnecting = false;
|
||||
private isReconnecting = false;
|
||||
private isInitializing = false;
|
||||
private reconnectAttempts = 0;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
private readonly userId?: string;
|
||||
private lastPingTime: number;
|
||||
|
||||
constructor(
|
||||
serverName: string,
|
||||
private readonly options: t.MCPOptions,
|
||||
private logger?: Logger,
|
||||
userId?: string,
|
||||
) {
|
||||
super();
|
||||
this.serverName = serverName;
|
||||
this.logger = logger;
|
||||
this.userId = userId;
|
||||
this.iconPath = options.iconPath;
|
||||
this.timeout = options.timeout;
|
||||
this.lastPingTime = Date.now();
|
||||
this.client = new Client(
|
||||
{
|
||||
name: '@librechat/api-client',
|
||||
version: '1.2.2',
|
||||
},
|
||||
{
|
||||
capabilities: {},
|
||||
},
|
||||
);
|
||||
|
||||
this.setupEventListeners();
|
||||
}
|
||||
|
||||
/** Helper to generate consistent log prefixes */
|
||||
private getLogPrefix(): string {
|
||||
const userPart = this.userId ? `[User: ${this.userId}]` : '';
|
||||
return `[MCP]${userPart}[${this.serverName}]`;
|
||||
}
|
||||
|
||||
public static getInstance(
|
||||
serverName: string,
|
||||
options: t.MCPOptions,
|
||||
logger?: Logger,
|
||||
userId?: string,
|
||||
): MCPConnection {
|
||||
if (!MCPConnection.instance) {
|
||||
MCPConnection.instance = new MCPConnection(serverName, options, logger, userId);
|
||||
}
|
||||
return MCPConnection.instance;
|
||||
}
|
||||
|
||||
public static getExistingInstance(): MCPConnection | null {
|
||||
return MCPConnection.instance;
|
||||
}
|
||||
|
||||
public static async destroyInstance(): Promise<void> {
|
||||
if (MCPConnection.instance) {
|
||||
await MCPConnection.instance.disconnect();
|
||||
MCPConnection.instance = null;
|
||||
}
|
||||
}
|
||||
|
||||
private emitError(error: unknown, errorContext: string): void {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
this.logger?.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
||||
this.emit('error', new Error(`${errorContext}: ${errorMessage}`));
|
||||
}
|
||||
|
||||
private constructTransport(options: t.MCPOptions): Transport {
|
||||
try {
|
||||
let type: t.MCPOptions['type'];
|
||||
if (isStdioOptions(options)) {
|
||||
type = 'stdio';
|
||||
} else if (isWebSocketOptions(options)) {
|
||||
type = 'websocket';
|
||||
} else if (isStreamableHTTPOptions(options)) {
|
||||
type = 'streamable-http';
|
||||
} else if (isSSEOptions(options)) {
|
||||
type = 'sse';
|
||||
} else {
|
||||
throw new Error(
|
||||
'Cannot infer transport type: options.type is not provided and cannot be inferred from other properties.',
|
||||
);
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case 'stdio':
|
||||
if (!isStdioOptions(options)) {
|
||||
throw new Error('Invalid options for stdio transport.');
|
||||
}
|
||||
return new StdioClientTransport({
|
||||
command: options.command,
|
||||
args: options.args,
|
||||
// workaround bug of mcp sdk that can't pass env:
|
||||
// https://github.com/modelcontextprotocol/typescript-sdk/issues/216
|
||||
env: { ...getDefaultEnvironment(), ...(options.env ?? {}) },
|
||||
});
|
||||
|
||||
case 'websocket':
|
||||
if (!isWebSocketOptions(options)) {
|
||||
throw new Error('Invalid options for websocket transport.');
|
||||
}
|
||||
return new WebSocketClientTransport(new URL(options.url));
|
||||
|
||||
case 'sse': {
|
||||
if (!isSSEOptions(options)) {
|
||||
throw new Error('Invalid options for sse transport.');
|
||||
}
|
||||
const url = new URL(options.url);
|
||||
this.logger?.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
|
||||
const abortController = new AbortController();
|
||||
const transport = new SSEClientTransport(url, {
|
||||
requestInit: {
|
||||
headers: options.headers,
|
||||
signal: abortController.signal,
|
||||
},
|
||||
eventSourceInit: {
|
||||
fetch: (url, init) => {
|
||||
const headers = new Headers(Object.assign({}, init?.headers, options.headers));
|
||||
return fetch(url, {
|
||||
...init,
|
||||
headers,
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
transport.onclose = () => {
|
||||
this.logger?.info(`${this.getLogPrefix()} SSE transport closed`);
|
||||
this.emit('connectionChange', 'disconnected');
|
||||
};
|
||||
|
||||
transport.onerror = (error) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} SSE transport error:`, error);
|
||||
this.emitError(error, 'SSE transport error:');
|
||||
};
|
||||
|
||||
transport.onmessage = (message) => {
|
||||
this.logger?.info(
|
||||
`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`,
|
||||
);
|
||||
};
|
||||
|
||||
this.setupTransportErrorHandlers(transport);
|
||||
return transport;
|
||||
}
|
||||
|
||||
case 'streamable-http': {
|
||||
if (!isStreamableHTTPOptions(options)) {
|
||||
throw new Error('Invalid options for streamable-http transport.');
|
||||
}
|
||||
const url = new URL(options.url);
|
||||
this.logger?.info(
|
||||
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(url, {
|
||||
requestInit: {
|
||||
headers: options.headers,
|
||||
signal: abortController.signal,
|
||||
},
|
||||
});
|
||||
|
||||
transport.onclose = () => {
|
||||
this.logger?.info(`${this.getLogPrefix()} Streamable-http transport closed`);
|
||||
this.emit('connectionChange', 'disconnected');
|
||||
};
|
||||
|
||||
transport.onerror = (error: Error | unknown) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} Streamable-http transport error:`, error);
|
||||
this.emitError(error, 'Streamable-http transport error:');
|
||||
};
|
||||
|
||||
transport.onmessage = (message: JSONRPCMessage) => {
|
||||
this.logger?.info(
|
||||
`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`,
|
||||
);
|
||||
};
|
||||
|
||||
this.setupTransportErrorHandlers(transport);
|
||||
return transport;
|
||||
}
|
||||
|
||||
default: {
|
||||
throw new Error(`Unsupported transport type: ${type}`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.emitError(error, 'Failed to construct transport:');
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
private setupEventListeners(): void {
|
||||
this.isInitializing = true;
|
||||
this.on('connectionChange', (state: t.ConnectionState) => {
|
||||
this.connectionState = state;
|
||||
if (state === 'connected') {
|
||||
this.isReconnecting = false;
|
||||
this.isInitializing = false;
|
||||
this.shouldStopReconnecting = false;
|
||||
this.reconnectAttempts = 0;
|
||||
/**
|
||||
* // FOR DEBUGGING
|
||||
* // this.client.setRequestHandler(PingRequestSchema, async (request, extra) => {
|
||||
* // this.logger?.info(`[MCP][${this.serverName}] PingRequest: ${JSON.stringify(request)}`);
|
||||
* // if (getEventListeners && extra.signal) {
|
||||
* // const listenerCount = getEventListeners(extra.signal, 'abort').length;
|
||||
* // this.logger?.debug(`Signal has ${listenerCount} abort listeners`);
|
||||
* // }
|
||||
* // return {};
|
||||
* // });
|
||||
*/
|
||||
} else if (state === 'error' && !this.isReconnecting && !this.isInitializing) {
|
||||
this.handleReconnection().catch((error) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} Reconnection handler failed:`, error);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
this.subscribeToResources();
|
||||
}
|
||||
|
||||
private async handleReconnection(): Promise<void> {
|
||||
if (this.isReconnecting || this.shouldStopReconnecting || this.isInitializing) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.isReconnecting = true;
|
||||
const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
|
||||
|
||||
try {
|
||||
while (
|
||||
this.reconnectAttempts < this.MAX_RECONNECT_ATTEMPTS &&
|
||||
!(this.shouldStopReconnecting as boolean)
|
||||
) {
|
||||
this.reconnectAttempts++;
|
||||
const delay = backoffDelay(this.reconnectAttempts);
|
||||
|
||||
this.logger?.info(
|
||||
`${this.getLogPrefix()} Reconnecting ${this.reconnectAttempts}/${this.MAX_RECONNECT_ATTEMPTS} (delay: ${delay}ms)`,
|
||||
);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
|
||||
try {
|
||||
await this.connect();
|
||||
this.reconnectAttempts = 0;
|
||||
return;
|
||||
} catch (error) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Reconnection attempt failed:`, error);
|
||||
|
||||
if (
|
||||
this.reconnectAttempts === this.MAX_RECONNECT_ATTEMPTS ||
|
||||
(this.shouldStopReconnecting as boolean)
|
||||
) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Stopping reconnection attempts`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
this.isReconnecting = false;
|
||||
}
|
||||
}
|
||||
|
||||
private subscribeToResources(): void {
|
||||
this.client.setNotificationHandler(ResourceListChangedNotificationSchema, async () => {
|
||||
this.invalidateCache();
|
||||
this.emit('resourcesChanged');
|
||||
});
|
||||
}
|
||||
|
||||
private invalidateCache(): void {
|
||||
// this.cachedConfig = null;
|
||||
this.lastConfigUpdate = 0;
|
||||
}
|
||||
|
||||
async connectClient(): Promise<void> {
|
||||
if (this.connectionState === 'connected') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.connectPromise) {
|
||||
return this.connectPromise;
|
||||
}
|
||||
|
||||
if (this.shouldStopReconnecting) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.emit('connectionChange', 'connecting');
|
||||
|
||||
this.connectPromise = (async () => {
|
||||
try {
|
||||
if (this.transport) {
|
||||
try {
|
||||
await this.client.close();
|
||||
this.transport = null;
|
||||
} catch (error) {
|
||||
this.logger?.warn(`${this.getLogPrefix()} Error closing connection:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
this.transport = this.constructTransport(this.options);
|
||||
this.setupTransportDebugHandlers();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 10000;
|
||||
await Promise.race([
|
||||
this.client.connect(this.transport),
|
||||
new Promise((_resolve, reject) =>
|
||||
setTimeout(() => reject(new Error('Connection timeout')), connectTimeout),
|
||||
),
|
||||
]);
|
||||
|
||||
this.connectionState = 'connected';
|
||||
this.emit('connectionChange', 'connected');
|
||||
this.reconnectAttempts = 0;
|
||||
} catch (error) {
|
||||
this.connectionState = 'error';
|
||||
this.emit('connectionChange', 'error');
|
||||
this.lastError = error instanceof Error ? error : new Error(String(error));
|
||||
throw error;
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
return this.connectPromise;
|
||||
}
|
||||
|
||||
private setupTransportDebugHandlers(): void {
|
||||
if (!this.transport) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.transport.onmessage = (msg) => {
|
||||
this.logger?.debug(`${this.getLogPrefix()} Transport received: ${JSON.stringify(msg)}`);
|
||||
};
|
||||
|
||||
const originalSend = this.transport.send.bind(this.transport);
|
||||
this.transport.send = async (msg) => {
|
||||
if ('result' in msg && !('method' in msg) && Object.keys(msg.result ?? {}).length === 0) {
|
||||
if (Date.now() - this.lastPingTime < FIVE_MINUTES) {
|
||||
throw new Error('Empty result');
|
||||
}
|
||||
this.lastPingTime = Date.now();
|
||||
}
|
||||
this.logger?.debug(`${this.getLogPrefix()} Transport sending: ${JSON.stringify(msg)}`);
|
||||
return originalSend(msg);
|
||||
};
|
||||
}
|
||||
|
||||
async connect(): Promise<void> {
|
||||
try {
|
||||
await this.disconnect();
|
||||
await this.connectClient();
|
||||
if (!this.isConnected()) {
|
||||
throw new Error('Connection not established');
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Connection failed:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
private setupTransportErrorHandlers(transport: Transport): void {
|
||||
transport.onerror = (error) => {
|
||||
this.logger?.error(`${this.getLogPrefix()} Transport error:`, error);
|
||||
this.emit('connectionChange', 'error');
|
||||
};
|
||||
}
|
||||
|
||||
public async disconnect(): Promise<void> {
|
||||
try {
|
||||
if (this.transport) {
|
||||
await this.client.close();
|
||||
this.transport = null;
|
||||
}
|
||||
if (this.connectionState === 'disconnected') {
|
||||
return;
|
||||
}
|
||||
this.connectionState = 'disconnected';
|
||||
this.emit('connectionChange', 'disconnected');
|
||||
} catch (error) {
|
||||
this.emit('error', error);
|
||||
throw error;
|
||||
} finally {
|
||||
this.invalidateCache();
|
||||
this.connectPromise = null;
|
||||
}
|
||||
}
|
||||
|
||||
async fetchResources(): Promise<t.MCPResource[]> {
|
||||
try {
|
||||
const { resources } = await this.client.listResources();
|
||||
return resources;
|
||||
} catch (error) {
|
||||
this.emitError(error, 'Failed to fetch resources:');
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async fetchTools() {
|
||||
try {
|
||||
const { tools } = await this.client.listTools();
|
||||
return tools;
|
||||
} catch (error) {
|
||||
this.emitError(error, 'Failed to fetch tools:');
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async fetchPrompts(): Promise<t.MCPPrompt[]> {
|
||||
try {
|
||||
const { prompts } = await this.client.listPrompts();
|
||||
return prompts;
|
||||
} catch (error) {
|
||||
this.emitError(error, 'Failed to fetch prompts:');
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
// public async modifyConfig(config: ContinueConfig): Promise<ContinueConfig> {
|
||||
// try {
|
||||
// // Check cache
|
||||
// if (this.cachedConfig && Date.now() - this.lastConfigUpdate < this.CONFIG_TTL) {
|
||||
// return this.cachedConfig;
|
||||
// }
|
||||
|
||||
// await this.connectClient();
|
||||
|
||||
// // Fetch and process resources
|
||||
// const resources = await this.fetchResources();
|
||||
// const submenuItems = resources.map(resource => ({
|
||||
// title: resource.name,
|
||||
// description: resource.description,
|
||||
// id: resource.uri,
|
||||
// }));
|
||||
|
||||
// if (!config.contextProviders) {
|
||||
// config.contextProviders = [];
|
||||
// }
|
||||
|
||||
// config.contextProviders.push(
|
||||
// new MCPContextProvider({
|
||||
// submenuItems,
|
||||
// client: this.client,
|
||||
// }),
|
||||
// );
|
||||
|
||||
// // Fetch and process tools
|
||||
// const tools = await this.fetchTools();
|
||||
// const continueTools: Tool[] = tools.map(tool => ({
|
||||
// displayTitle: tool.name,
|
||||
// function: {
|
||||
// description: tool.description,
|
||||
// name: tool.name,
|
||||
// parameters: tool.inputSchema,
|
||||
// },
|
||||
// readonly: false,
|
||||
// type: 'function',
|
||||
// wouldLikeTo: `use the ${tool.name} tool`,
|
||||
// uri: `mcp://${tool.name}`,
|
||||
// }));
|
||||
|
||||
// config.tools = [...(config.tools || []), ...continueTools];
|
||||
|
||||
// // Fetch and process prompts
|
||||
// const prompts = await this.fetchPrompts();
|
||||
// if (!config.slashCommands) {
|
||||
// config.slashCommands = [];
|
||||
// }
|
||||
|
||||
// const slashCommands: SlashCommand[] = prompts.map(prompt =>
|
||||
// constructMcpSlashCommand(
|
||||
// this.client,
|
||||
// prompt.name,
|
||||
// prompt.description,
|
||||
// prompt.arguments?.map(a => a.name),
|
||||
// ),
|
||||
// );
|
||||
// config.slashCommands.push(...slashCommands);
|
||||
|
||||
// // Update cache
|
||||
// this.cachedConfig = config;
|
||||
// this.lastConfigUpdate = Date.now();
|
||||
|
||||
// return config;
|
||||
// } catch (error) {
|
||||
// this.emit('error', error);
|
||||
// // Return original config if modification fails
|
||||
// return config;
|
||||
// }
|
||||
// }
|
||||
|
||||
// Public getters for state information
|
||||
public getConnectionState(): t.ConnectionState {
|
||||
return this.connectionState;
|
||||
}
|
||||
|
||||
public async isConnected(): Promise<boolean> {
|
||||
try {
|
||||
await this.client.ping();
|
||||
return this.connectionState === 'connected';
|
||||
} catch (error) {
|
||||
this.logger?.error(`${this.getLogPrefix()} Ping failed:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public getLastError(): Error | null {
|
||||
return this.lastError;
|
||||
}
|
||||
}
|
||||
3
packages/api/src/mcp/enum.ts
Normal file
3
packages/api/src/mcp/enum.ts
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
export enum CONSTANTS {
|
||||
mcp_delimiter = '_mcp_',
|
||||
}
|
||||
617
packages/api/src/mcp/manager.ts
Normal file
617
packages/api/src/mcp/manager.ts
Normal file
|
|
@ -0,0 +1,617 @@
|
|||
import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
|
||||
import type { JsonSchemaType, MCPOptions } from 'librechat-data-provider';
|
||||
import type { Logger } from 'winston';
|
||||
import type * as t from './types';
|
||||
import { formatToolContent } from './parsers';
|
||||
import { MCPConnection } from './connection';
|
||||
import { CONSTANTS } from './enum';
|
||||
|
||||
export interface CallToolOptions extends RequestOptions {
|
||||
userId?: string;
|
||||
}
|
||||
|
||||
export class MCPManager {
|
||||
private static instance: MCPManager | null = null;
|
||||
/** App-level connections initialized at startup */
|
||||
private connections: Map<string, MCPConnection> = new Map();
|
||||
/** User-specific connections initialized on demand */
|
||||
private userConnections: Map<string, Map<string, MCPConnection>> = new Map();
|
||||
/** Last activity timestamp for users (not per server) */
|
||||
private userLastActivity: Map<string, number> = new Map();
|
||||
private readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
|
||||
private mcpConfigs: t.MCPServers = {};
|
||||
private processMCPEnv?: (obj: MCPOptions, userId?: string) => MCPOptions; // Store the processing function
|
||||
/** Store MCP server instructions */
|
||||
private serverInstructions: Map<string, string> = new Map();
|
||||
private logger: Logger;
|
||||
|
||||
private static getDefaultLogger(): Logger {
|
||||
return {
|
||||
error: console.error,
|
||||
warn: console.warn,
|
||||
info: console.info,
|
||||
debug: console.debug,
|
||||
} as Logger;
|
||||
}
|
||||
|
||||
private constructor(logger?: Logger) {
|
||||
this.logger = logger || MCPManager.getDefaultLogger();
|
||||
}
|
||||
|
||||
public static getInstance(logger?: Logger): MCPManager {
|
||||
if (!MCPManager.instance) {
|
||||
MCPManager.instance = new MCPManager(logger);
|
||||
}
|
||||
// Check for idle connections when getInstance is called
|
||||
MCPManager.instance.checkIdleConnections();
|
||||
return MCPManager.instance;
|
||||
}
|
||||
|
||||
/** Stores configs and initializes app-level connections */
|
||||
public async initializeMCP(
|
||||
mcpServers: t.MCPServers,
|
||||
processMCPEnv?: (obj: MCPOptions) => MCPOptions,
|
||||
): Promise<void> {
|
||||
this.logger.info('[MCP] Initializing app-level servers');
|
||||
this.processMCPEnv = processMCPEnv; // Store the function
|
||||
this.mcpConfigs = mcpServers;
|
||||
|
||||
const entries = Object.entries(mcpServers);
|
||||
const initializedServers = new Set();
|
||||
const connectionResults = await Promise.allSettled(
|
||||
entries.map(async ([serverName, _config], i) => {
|
||||
/** Process env for app-level connections */
|
||||
const config = this.processMCPEnv ? this.processMCPEnv(_config) : _config;
|
||||
const connection = new MCPConnection(serverName, config, this.logger);
|
||||
|
||||
try {
|
||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Connection timeout')), 30000),
|
||||
);
|
||||
|
||||
const connectionAttempt = this.initializeServer(connection, `[MCP][${serverName}]`);
|
||||
await Promise.race([connectionAttempt, connectionTimeout]);
|
||||
|
||||
if (await connection.isConnected()) {
|
||||
initializedServers.add(i);
|
||||
this.connections.set(serverName, connection); // Store in app-level map
|
||||
|
||||
// Handle unified serverInstructions configuration
|
||||
const configInstructions = config.serverInstructions;
|
||||
|
||||
if (configInstructions !== undefined) {
|
||||
if (typeof configInstructions === 'string') {
|
||||
// Custom instructions provided
|
||||
this.serverInstructions.set(serverName, configInstructions);
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] Custom instructions stored for context inclusion: ${configInstructions}`,
|
||||
);
|
||||
} else if (configInstructions === true) {
|
||||
// Use server-provided instructions
|
||||
const serverInstructions = connection.client.getInstructions();
|
||||
|
||||
if (serverInstructions) {
|
||||
this.serverInstructions.set(serverName, serverInstructions);
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] Server instructions stored for context inclusion: ${serverInstructions}`,
|
||||
);
|
||||
} else {
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] serverInstructions=true but no server instructions available`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// configInstructions is false - explicitly disabled
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] Instructions explicitly disabled (serverInstructions=false)`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] Instructions not included (serverInstructions not configured)`,
|
||||
);
|
||||
}
|
||||
|
||||
const serverCapabilities = connection.client.getServerCapabilities();
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] Capabilities: ${JSON.stringify(serverCapabilities)}`,
|
||||
);
|
||||
|
||||
if (serverCapabilities?.tools) {
|
||||
const tools = await connection.client.listTools();
|
||||
if (tools.tools.length) {
|
||||
this.logger.info(
|
||||
`[MCP][${serverName}] Available tools: ${tools.tools
|
||||
.map((tool) => tool.name)
|
||||
.join(', ')}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.error(`[MCP][${serverName}] Initialization failed`, error);
|
||||
throw error;
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
const failedConnections = connectionResults.filter(
|
||||
(result): result is PromiseRejectedResult => result.status === 'rejected',
|
||||
);
|
||||
|
||||
this.logger.info(
|
||||
`[MCP] Initialized ${initializedServers.size}/${entries.length} app-level server(s)`,
|
||||
);
|
||||
|
||||
if (failedConnections.length > 0) {
|
||||
this.logger.warn(
|
||||
`[MCP] ${failedConnections.length}/${entries.length} app-level server(s) failed to initialize`,
|
||||
);
|
||||
}
|
||||
|
||||
entries.forEach(([serverName], index) => {
|
||||
if (initializedServers.has(index)) {
|
||||
this.logger.info(`[MCP][${serverName}] ✓ Initialized`);
|
||||
} else {
|
||||
this.logger.info(`[MCP][${serverName}] ✗ Failed`);
|
||||
}
|
||||
});
|
||||
|
||||
if (initializedServers.size === entries.length) {
|
||||
this.logger.info('[MCP] All app-level servers initialized successfully');
|
||||
} else if (initializedServers.size === 0) {
|
||||
this.logger.warn('[MCP] No app-level servers initialized');
|
||||
}
|
||||
}
|
||||
|
||||
/** Generic server initialization logic */
|
||||
private async initializeServer(connection: MCPConnection, logPrefix: string): Promise<void> {
|
||||
const maxAttempts = 3;
|
||||
let attempts = 0;
|
||||
|
||||
while (attempts < maxAttempts) {
|
||||
try {
|
||||
await connection.connect();
|
||||
if (await connection.isConnected()) {
|
||||
return;
|
||||
}
|
||||
throw new Error('Connection attempt succeeded but status is not connected');
|
||||
} catch (error) {
|
||||
attempts++;
|
||||
if (attempts === maxAttempts) {
|
||||
this.logger.error(`${logPrefix} Failed to connect after ${maxAttempts} attempts`, error);
|
||||
throw error; // Re-throw the last error
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000 * attempts));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Check for and disconnect idle connections */
|
||||
private checkIdleConnections(currentUserId?: string): void {
|
||||
const now = Date.now();
|
||||
|
||||
// Iterate through all users to check for idle ones
|
||||
for (const [userId, lastActivity] of this.userLastActivity.entries()) {
|
||||
if (currentUserId && currentUserId === userId) {
|
||||
continue;
|
||||
}
|
||||
if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
|
||||
this.logger.info(
|
||||
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`,
|
||||
);
|
||||
// Disconnect all user connections asynchronously (fire and forget)
|
||||
this.disconnectUserConnections(userId).catch((err) =>
|
||||
this.logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Updates the last activity timestamp for a user */
|
||||
private updateUserLastActivity(userId: string): void {
|
||||
const now = Date.now();
|
||||
this.userLastActivity.set(userId, now);
|
||||
this.logger.debug(
|
||||
`[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`,
|
||||
);
|
||||
}
|
||||
|
||||
/** Gets or creates a connection for a specific user */
|
||||
public async getUserConnection(userId: string, serverName: string): Promise<MCPConnection> {
|
||||
const userServerMap = this.userConnections.get(userId);
|
||||
let connection = userServerMap?.get(serverName);
|
||||
const now = Date.now();
|
||||
|
||||
// Check if user is idle
|
||||
const lastActivity = this.userLastActivity.get(userId);
|
||||
if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
|
||||
this.logger.info(
|
||||
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`,
|
||||
);
|
||||
// Disconnect all user connections
|
||||
try {
|
||||
await this.disconnectUserConnections(userId);
|
||||
} catch (err) {
|
||||
this.logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err);
|
||||
}
|
||||
connection = undefined; // Force creation of a new connection
|
||||
} else if (connection) {
|
||||
if (await connection.isConnected()) {
|
||||
this.logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`);
|
||||
// Update timestamp on reuse
|
||||
this.updateUserLastActivity(userId);
|
||||
return connection;
|
||||
} else {
|
||||
// Connection exists but is not connected, attempt to remove potentially stale entry
|
||||
this.logger.warn(
|
||||
`[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`,
|
||||
);
|
||||
this.removeUserConnection(userId, serverName); // Clean up maps
|
||||
connection = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid connection exists, create a new one
|
||||
if (!connection) {
|
||||
this.logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
|
||||
}
|
||||
|
||||
let config = this.mcpConfigs[serverName];
|
||||
if (!config) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
`[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (this.processMCPEnv) {
|
||||
config = { ...(this.processMCPEnv(config, userId) ?? {}) };
|
||||
}
|
||||
|
||||
connection = new MCPConnection(serverName, config, this.logger, userId);
|
||||
|
||||
try {
|
||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Connection timeout')), 30000),
|
||||
);
|
||||
const connectionAttempt = this.initializeServer(
|
||||
connection,
|
||||
`[MCP][User: ${userId}][${serverName}]`,
|
||||
);
|
||||
await Promise.race([connectionAttempt, connectionTimeout]);
|
||||
|
||||
if (!(await connection.isConnected())) {
|
||||
throw new Error('Failed to establish connection after initialization attempt.');
|
||||
}
|
||||
|
||||
if (!this.userConnections.has(userId)) {
|
||||
this.userConnections.set(userId, new Map());
|
||||
}
|
||||
this.userConnections.get(userId)?.set(serverName, connection);
|
||||
this.logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
|
||||
// Update timestamp on creation
|
||||
this.updateUserLastActivity(userId);
|
||||
return connection;
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`[MCP][User: ${userId}][${serverName}] Failed to establish connection`,
|
||||
error,
|
||||
);
|
||||
// Ensure partial connection state is cleaned up if initialization fails
|
||||
await connection.disconnect().catch((disconnectError) => {
|
||||
this.logger.error(
|
||||
`[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`,
|
||||
disconnectError,
|
||||
);
|
||||
});
|
||||
// Ensure cleanup even if connection attempt fails
|
||||
this.removeUserConnection(userId, serverName);
|
||||
throw error; // Re-throw the error to the caller
|
||||
}
|
||||
}
|
||||
|
||||
/** Removes a specific user connection entry */
|
||||
private removeUserConnection(userId: string, serverName: string): void {
|
||||
// Remove connection object
|
||||
const userMap = this.userConnections.get(userId);
|
||||
if (userMap) {
|
||||
userMap.delete(serverName);
|
||||
if (userMap.size === 0) {
|
||||
this.userConnections.delete(userId);
|
||||
// Only remove user activity timestamp if all connections are gone
|
||||
this.userLastActivity.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
this.logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`);
|
||||
}
|
||||
|
||||
/** Disconnects and removes a specific user connection */
|
||||
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
|
||||
const userMap = this.userConnections.get(userId);
|
||||
const connection = userMap?.get(serverName);
|
||||
if (connection) {
|
||||
this.logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
|
||||
await connection.disconnect();
|
||||
this.removeUserConnection(userId, serverName);
|
||||
}
|
||||
}
|
||||
|
||||
/** Disconnects and removes all connections for a specific user */
|
||||
public async disconnectUserConnections(userId: string): Promise<void> {
|
||||
const userMap = this.userConnections.get(userId);
|
||||
const disconnectPromises: Promise<void>[] = [];
|
||||
if (userMap) {
|
||||
this.logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`);
|
||||
const userServers = Array.from(userMap.keys());
|
||||
for (const serverName of userServers) {
|
||||
disconnectPromises.push(
|
||||
this.disconnectUserConnection(userId, serverName).catch((error) => {
|
||||
this.logger.error(
|
||||
`[MCP][User: ${userId}][${serverName}] Error during disconnection:`,
|
||||
error,
|
||||
);
|
||||
}),
|
||||
);
|
||||
}
|
||||
await Promise.allSettled(disconnectPromises);
|
||||
// Ensure user activity timestamp is removed
|
||||
this.userLastActivity.delete(userId);
|
||||
this.logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns the app-level connection (used for mapping tools, etc.) */
|
||||
public getConnection(serverName: string): MCPConnection | undefined {
|
||||
return this.connections.get(serverName);
|
||||
}
|
||||
|
||||
/** Returns all app-level connections */
|
||||
public getAllConnections(): Map<string, MCPConnection> {
|
||||
return this.connections;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps available tools from all app-level connections into the provided object.
|
||||
* The object is modified in place.
|
||||
*/
|
||||
public async mapAvailableTools(availableTools: t.LCAvailableTools): Promise<void> {
|
||||
for (const [serverName, connection] of this.connections.entries()) {
|
||||
try {
|
||||
if ((await connection.isConnected()) !== true) {
|
||||
this.logger.warn(
|
||||
`[MCP][${serverName}] Connection not established. Skipping tool mapping.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const tools = await connection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
|
||||
availableTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema as JsonSchemaType,
|
||||
},
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(`[MCP][${serverName}] Error fetching tools for mapping:`, error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads tools from all app-level connections into the manifest.
|
||||
*/
|
||||
public async loadManifestTools(manifestTools: t.LCToolManifest): Promise<t.LCToolManifest> {
|
||||
const mcpTools: t.LCManifestTool[] = [];
|
||||
|
||||
for (const [serverName, connection] of this.connections.entries()) {
|
||||
try {
|
||||
if ((await connection.isConnected()) !== true) {
|
||||
this.logger.warn(
|
||||
`[MCP][${serverName}] Connection not established. Skipping manifest loading.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const tools = await connection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
|
||||
const manifestTool: t.LCManifestTool = {
|
||||
name: tool.name,
|
||||
pluginKey,
|
||||
description: tool.description ?? '',
|
||||
icon: connection.iconPath,
|
||||
};
|
||||
const config = this.mcpConfigs[serverName];
|
||||
if (config?.chatMenu === false) {
|
||||
manifestTool.chatMenu = false;
|
||||
}
|
||||
mcpTools.push(manifestTool);
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
return [...mcpTools, ...manifestTools];
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls a tool on an MCP server, using either a user-specific connection
|
||||
* (if userId is provided) or an app-level connection. Updates the last activity timestamp
|
||||
* for user-specific connections upon successful call initiation.
|
||||
*/
|
||||
async callTool({
|
||||
serverName,
|
||||
toolName,
|
||||
provider,
|
||||
toolArguments,
|
||||
options,
|
||||
}: {
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
provider: t.Provider;
|
||||
toolArguments?: Record<string, unknown>;
|
||||
options?: CallToolOptions;
|
||||
}): Promise<t.FormattedToolResponse> {
|
||||
let connection: MCPConnection | undefined;
|
||||
const { userId, ...callOptions } = options ?? {};
|
||||
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
|
||||
|
||||
try {
|
||||
if (userId) {
|
||||
this.updateUserLastActivity(userId);
|
||||
// Get or create user-specific connection
|
||||
connection = await this.getUserConnection(userId, serverName);
|
||||
} else {
|
||||
// Use app-level connection
|
||||
connection = this.connections.get(serverName);
|
||||
if (!connection) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
`${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!(await connection.isConnected())) {
|
||||
// This might happen if getUserConnection failed silently or app connection dropped
|
||||
throw new McpError(
|
||||
ErrorCode.InternalError, // Use InternalError for connection issues
|
||||
`${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`,
|
||||
);
|
||||
}
|
||||
|
||||
const result = await connection.client.request(
|
||||
{
|
||||
method: 'tools/call',
|
||||
params: {
|
||||
name: toolName,
|
||||
arguments: toolArguments,
|
||||
},
|
||||
},
|
||||
CallToolResultSchema,
|
||||
{
|
||||
timeout: connection.timeout,
|
||||
...callOptions,
|
||||
},
|
||||
);
|
||||
if (userId) {
|
||||
this.updateUserLastActivity(userId);
|
||||
}
|
||||
this.checkIdleConnections();
|
||||
return formatToolContent(result, provider);
|
||||
} catch (error) {
|
||||
// Log with context and re-throw or handle as needed
|
||||
this.logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
|
||||
// Rethrowing allows the caller (createMCPTool) to handle the final user message
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/** Disconnects a specific app-level server */
|
||||
public async disconnectServer(serverName: string): Promise<void> {
|
||||
const connection = this.connections.get(serverName);
|
||||
if (connection) {
|
||||
this.logger.info(`[MCP][${serverName}] Disconnecting...`);
|
||||
await connection.disconnect();
|
||||
this.connections.delete(serverName);
|
||||
}
|
||||
}
|
||||
|
||||
/** Disconnects all app-level and user-level connections */
|
||||
public async disconnectAll(): Promise<void> {
|
||||
this.logger.info('[MCP] Disconnecting all app-level and user-level connections...');
|
||||
|
||||
const userDisconnectPromises = Array.from(this.userConnections.keys()).map((userId) =>
|
||||
this.disconnectUserConnections(userId),
|
||||
);
|
||||
await Promise.allSettled(userDisconnectPromises);
|
||||
this.userLastActivity.clear();
|
||||
|
||||
// Disconnect all app-level connections
|
||||
const appDisconnectPromises = Array.from(this.connections.values()).map((connection) =>
|
||||
connection.disconnect().catch((error) => {
|
||||
this.logger.error(`[MCP][${connection.serverName}] Error during disconnectAll:`, error);
|
||||
}),
|
||||
);
|
||||
await Promise.allSettled(appDisconnectPromises);
|
||||
this.connections.clear();
|
||||
|
||||
this.logger.info('[MCP] All connections processed for disconnection.');
|
||||
}
|
||||
|
||||
/** Destroys the singleton instance and disconnects all connections */
|
||||
public static async destroyInstance(): Promise<void> {
|
||||
if (MCPManager.instance) {
|
||||
await MCPManager.instance.disconnectAll();
|
||||
MCPManager.instance = null;
|
||||
const logger = MCPManager.getDefaultLogger();
|
||||
logger.info('[MCP] Manager instance destroyed.');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get instructions for MCP servers
|
||||
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
|
||||
* @returns Object mapping server names to their instructions
|
||||
*/
|
||||
public getInstructions(serverNames?: string[]): Record<string, string> {
|
||||
const instructions: Record<string, string> = {};
|
||||
|
||||
if (!serverNames || serverNames.length === 0) {
|
||||
// Return all instructions if no specific servers requested
|
||||
for (const [serverName, serverInstructions] of this.serverInstructions.entries()) {
|
||||
instructions[serverName] = serverInstructions;
|
||||
}
|
||||
} else {
|
||||
// Return instructions for specific servers
|
||||
for (const serverName of serverNames) {
|
||||
const serverInstructions = this.serverInstructions.get(serverName);
|
||||
if (serverInstructions) {
|
||||
instructions[serverName] = serverInstructions;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return instructions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format MCP server instructions for injection into context
|
||||
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
|
||||
* @returns Formatted instructions string ready for context injection
|
||||
*/
|
||||
public formatInstructionsForContext(serverNames?: string[]): string {
|
||||
/** Instructions for specified servers or all stored instructions */
|
||||
const instructionsToInclude = this.getInstructions(serverNames);
|
||||
|
||||
if (Object.keys(instructionsToInclude).length === 0) {
|
||||
return '';
|
||||
}
|
||||
|
||||
// Format instructions for context injection
|
||||
const formattedInstructions = Object.entries(instructionsToInclude)
|
||||
.map(([serverName, instructions]) => {
|
||||
return `## ${serverName} MCP Server Instructions
|
||||
|
||||
${instructions}`;
|
||||
})
|
||||
.join('\n\n');
|
||||
|
||||
return `# MCP Server Instructions
|
||||
|
||||
The following MCP servers are available with their specific instructions:
|
||||
|
||||
${formattedInstructions}
|
||||
|
||||
Please follow these instructions when using tools from the respective MCP servers.`;
|
||||
}
|
||||
}
|
||||
183
packages/api/src/mcp/parsers.ts
Normal file
183
packages/api/src/mcp/parsers.ts
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
import type * as t from './types';
|
||||
const RECOGNIZED_PROVIDERS = new Set([
|
||||
'google',
|
||||
'anthropic',
|
||||
'openai',
|
||||
'openrouter',
|
||||
'xai',
|
||||
'deepseek',
|
||||
'ollama',
|
||||
]);
|
||||
const CONTENT_ARRAY_PROVIDERS = new Set(['google', 'anthropic', 'openai']);
|
||||
|
||||
const imageFormatters: Record<string, undefined | t.ImageFormatter> = {
|
||||
// google: (item) => ({
|
||||
// type: 'image',
|
||||
// inlineData: {
|
||||
// mimeType: item.mimeType,
|
||||
// data: item.data,
|
||||
// },
|
||||
// }),
|
||||
// anthropic: (item) => ({
|
||||
// type: 'image',
|
||||
// source: {
|
||||
// type: 'base64',
|
||||
// media_type: item.mimeType,
|
||||
// data: item.data,
|
||||
// },
|
||||
// }),
|
||||
default: (item) => ({
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: item.data.startsWith('http') ? item.data : `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
function isImageContent(item: t.ToolContentPart): item is t.ImageContent {
|
||||
return item.type === 'image';
|
||||
}
|
||||
|
||||
function parseAsString(result: t.MCPToolCallResponse): string {
|
||||
const content = result?.content ?? [];
|
||||
if (!content.length) {
|
||||
return '(No response)';
|
||||
}
|
||||
|
||||
const text = content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return item.text;
|
||||
}
|
||||
if (item.type === 'resource') {
|
||||
const resourceText = [];
|
||||
if (item.resource.text != null && item.resource.text) {
|
||||
resourceText.push(item.resource.text);
|
||||
}
|
||||
if (item.resource.uri) {
|
||||
resourceText.push(`Resource URI: ${item.resource.uri}`);
|
||||
}
|
||||
if (item.resource.name) {
|
||||
resourceText.push(`Resource: ${item.resource.name}`);
|
||||
}
|
||||
if (item.resource.description) {
|
||||
resourceText.push(`Description: ${item.resource.description}`);
|
||||
}
|
||||
if (item.resource.mimeType != null && item.resource.mimeType) {
|
||||
resourceText.push(`Type: ${item.resource.mimeType}`);
|
||||
}
|
||||
return resourceText.join('\n');
|
||||
}
|
||||
return JSON.stringify(item, null, 2);
|
||||
})
|
||||
.filter(Boolean)
|
||||
.join('\n\n');
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts MCPToolCallResponse content into recognized content block types
|
||||
* Recognized types: "image", "image_url", "text", "json"
|
||||
*
|
||||
* @param {t.MCPToolCallResponse} result - The MCPToolCallResponse object
|
||||
* @param {string} provider - The provider name (google, anthropic, openai)
|
||||
* @returns {Array<Object>} Formatted content blocks
|
||||
*/
|
||||
/**
|
||||
* Converts MCPToolCallResponse content into recognized content block types
|
||||
* First element: string or formatted content (excluding image_url)
|
||||
* Second element: image_url content if any
|
||||
*
|
||||
* @param {t.MCPToolCallResponse} result - The MCPToolCallResponse object
|
||||
* @param {string} provider - The provider name (google, anthropic, openai)
|
||||
* @returns {t.FormattedContentResult} Tuple of content and image_urls
|
||||
*/
|
||||
export function formatToolContent(
|
||||
result: t.MCPToolCallResponse,
|
||||
provider: t.Provider,
|
||||
): t.FormattedContentResult {
|
||||
if (!RECOGNIZED_PROVIDERS.has(provider)) {
|
||||
return [parseAsString(result), undefined];
|
||||
}
|
||||
|
||||
const content = result?.content ?? [];
|
||||
if (!content.length) {
|
||||
return [[{ type: 'text', text: '(No response)' }], undefined];
|
||||
}
|
||||
|
||||
const formattedContent: t.FormattedContent[] = [];
|
||||
const imageUrls: t.FormattedContent[] = [];
|
||||
let currentTextBlock = '';
|
||||
|
||||
type ContentHandler = undefined | ((item: t.ToolContentPart) => void);
|
||||
|
||||
const contentHandlers: {
|
||||
text: (item: Extract<t.ToolContentPart, { type: 'text' }>) => void;
|
||||
image: (item: t.ToolContentPart) => void;
|
||||
resource: (item: Extract<t.ToolContentPart, { type: 'resource' }>) => void;
|
||||
} = {
|
||||
text: (item) => {
|
||||
currentTextBlock += (currentTextBlock ? '\n\n' : '') + item.text;
|
||||
},
|
||||
|
||||
image: (item) => {
|
||||
if (!isImageContent(item)) {
|
||||
return;
|
||||
}
|
||||
if (CONTENT_ARRAY_PROVIDERS.has(provider) && currentTextBlock) {
|
||||
formattedContent.push({ type: 'text', text: currentTextBlock });
|
||||
currentTextBlock = '';
|
||||
}
|
||||
const formatter = imageFormatters.default as t.ImageFormatter;
|
||||
const formattedImage = formatter(item);
|
||||
|
||||
if (formattedImage.type === 'image_url') {
|
||||
imageUrls.push(formattedImage);
|
||||
} else {
|
||||
formattedContent.push(formattedImage);
|
||||
}
|
||||
},
|
||||
|
||||
resource: (item) => {
|
||||
const resourceText = [];
|
||||
if (item.resource.text != null && item.resource.text) {
|
||||
resourceText.push(item.resource.text);
|
||||
}
|
||||
if (item.resource.uri.length) {
|
||||
resourceText.push(`Resource URI: ${item.resource.uri}`);
|
||||
}
|
||||
if (item.resource.name) {
|
||||
resourceText.push(`Resource: ${item.resource.name}`);
|
||||
}
|
||||
if (item.resource.description) {
|
||||
resourceText.push(`Description: ${item.resource.description}`);
|
||||
}
|
||||
if (item.resource.mimeType != null && item.resource.mimeType) {
|
||||
resourceText.push(`Type: ${item.resource.mimeType}`);
|
||||
}
|
||||
currentTextBlock += (currentTextBlock ? '\n\n' : '') + resourceText.join('\n');
|
||||
},
|
||||
};
|
||||
|
||||
for (const item of content) {
|
||||
const handler = contentHandlers[item.type as keyof typeof contentHandlers] as ContentHandler;
|
||||
if (handler) {
|
||||
handler(item as never);
|
||||
} else {
|
||||
const stringified = JSON.stringify(item, null, 2);
|
||||
currentTextBlock += (currentTextBlock ? '\n\n' : '') + stringified;
|
||||
}
|
||||
}
|
||||
|
||||
if (CONTENT_ARRAY_PROVIDERS.has(provider) && currentTextBlock) {
|
||||
formattedContent.push({ type: 'text', text: currentTextBlock });
|
||||
}
|
||||
|
||||
const artifacts = imageUrls.length ? { content: imageUrls } : undefined;
|
||||
if (CONTENT_ARRAY_PROVIDERS.has(provider)) {
|
||||
return [formattedContent, artifacts];
|
||||
}
|
||||
|
||||
return [currentTextBlock, artifacts];
|
||||
}
|
||||
98
packages/api/src/mcp/types/index.ts
Normal file
98
packages/api/src/mcp/types/index.ts
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import { z } from 'zod';
|
||||
import {
|
||||
SSEOptionsSchema,
|
||||
MCPOptionsSchema,
|
||||
MCPServersSchema,
|
||||
StdioOptionsSchema,
|
||||
WebSocketOptionsSchema,
|
||||
StreamableHTTPOptionsSchema,
|
||||
} from 'librechat-data-provider';
|
||||
import type { JsonSchemaType, TPlugin } from 'librechat-data-provider';
|
||||
import type * as t from '@modelcontextprotocol/sdk/types.js';
|
||||
|
||||
export type StdioOptions = z.infer<typeof StdioOptionsSchema>;
|
||||
export type WebSocketOptions = z.infer<typeof WebSocketOptionsSchema>;
|
||||
export type SSEOptions = z.infer<typeof SSEOptionsSchema>;
|
||||
export type StreamableHTTPOptions = z.infer<typeof StreamableHTTPOptionsSchema>;
|
||||
export type MCPOptions = z.infer<typeof MCPOptionsSchema>;
|
||||
export type MCPServers = z.infer<typeof MCPServersSchema>;
|
||||
export interface MCPResource {
|
||||
uri: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
mimeType?: string;
|
||||
}
|
||||
export interface LCTool {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: JsonSchemaType;
|
||||
}
|
||||
|
||||
export interface LCFunctionTool {
|
||||
type: 'function';
|
||||
['function']: LCTool;
|
||||
}
|
||||
|
||||
export type LCAvailableTools = Record<string, LCFunctionTool>;
|
||||
export type LCManifestTool = TPlugin;
|
||||
export type LCToolManifest = TPlugin[];
|
||||
export interface MCPPrompt {
|
||||
name: string;
|
||||
description?: string;
|
||||
arguments?: Array<{ name: string }>;
|
||||
}
|
||||
|
||||
export type ConnectionState = 'disconnected' | 'connecting' | 'connected' | 'error';
|
||||
|
||||
export type MCPTool = z.infer<typeof t.ToolSchema>;
|
||||
export type MCPToolListResponse = z.infer<typeof t.ListToolsResultSchema>;
|
||||
export type ToolContentPart = t.TextContent | t.ImageContent | t.EmbeddedResource | t.AudioContent;
|
||||
export type ImageContent = Extract<ToolContentPart, { type: 'image' }>;
|
||||
export type MCPToolCallResponse =
|
||||
| undefined
|
||||
| {
|
||||
_meta?: Record<string, unknown>;
|
||||
content?: Array<ToolContentPart>;
|
||||
isError?: boolean;
|
||||
};
|
||||
|
||||
export type Provider = 'google' | 'anthropic' | 'openAI';
|
||||
|
||||
export type FormattedContent =
|
||||
| {
|
||||
type: 'text';
|
||||
text: string;
|
||||
}
|
||||
| {
|
||||
type: 'image';
|
||||
inlineData: {
|
||||
mimeType: string;
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: 'image';
|
||||
source: {
|
||||
type: 'base64';
|
||||
media_type: string;
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: 'image_url';
|
||||
image_url: {
|
||||
url: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type FormattedContentResult = [
|
||||
string | FormattedContent[],
|
||||
undefined | { content: FormattedContent[] },
|
||||
];
|
||||
|
||||
export type ImageFormatter = (item: ImageContent) => FormattedContent;
|
||||
|
||||
export type FormattedToolResponse = [
|
||||
string | FormattedContent[],
|
||||
{ content: FormattedContent[] } | undefined,
|
||||
];
|
||||
28
packages/api/src/mcp/utils.test.ts
Normal file
28
packages/api/src/mcp/utils.test.ts
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import { normalizeServerName } from './utils';
|
||||
|
||||
describe('normalizeServerName', () => {
|
||||
it('should not modify server names that already match the pattern', () => {
|
||||
const result = normalizeServerName('valid-server_name.123');
|
||||
expect(result).toBe('valid-server_name.123');
|
||||
});
|
||||
|
||||
it('should normalize server names with non-ASCII characters', () => {
|
||||
const result = normalizeServerName('我的服务');
|
||||
// Should generate a fallback name with a hash
|
||||
expect(result).toMatch(/^server_\d+$/);
|
||||
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
|
||||
});
|
||||
|
||||
it('should normalize server names with special characters', () => {
|
||||
const result = normalizeServerName('server@name!');
|
||||
// The actual result doesn't have the trailing underscore after trimming
|
||||
expect(result).toBe('server_name');
|
||||
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
|
||||
});
|
||||
|
||||
it('should trim leading and trailing underscores', () => {
|
||||
const result = normalizeServerName('!server-name!');
|
||||
expect(result).toBe('server-name');
|
||||
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
|
||||
});
|
||||
});
|
||||
30
packages/api/src/mcp/utils.ts
Normal file
30
packages/api/src/mcp/utils.ts
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$
|
||||
* This is required for Azure OpenAI models with Tool Calling
|
||||
*/
|
||||
export function normalizeServerName(serverName: string): string {
|
||||
// Check if the server name already matches the pattern
|
||||
if (/^[a-zA-Z0-9_.-]+$/.test(serverName)) {
|
||||
return serverName;
|
||||
}
|
||||
|
||||
/** Replace non-matching characters with underscores.
|
||||
This preserves the general structure while ensuring compatibility.
|
||||
Trims leading/trailing underscores
|
||||
*/
|
||||
const normalized = serverName.replace(/[^a-zA-Z0-9_.-]/g, '_').replace(/^_+|_+$/g, '');
|
||||
|
||||
// If the result is empty (e.g., all characters were non-ASCII and got trimmed),
|
||||
// generate a fallback name to ensure we always have a valid function name
|
||||
if (!normalized) {
|
||||
/** Hash of the original name to ensure uniqueness */
|
||||
let hash = 0;
|
||||
for (let i = 0; i < serverName.length; i++) {
|
||||
hash = (hash << 5) - hash + serverName.charCodeAt(i);
|
||||
hash |= 0; // Convert to 32bit integer
|
||||
}
|
||||
return `server_${Math.abs(hash)}`;
|
||||
}
|
||||
|
||||
return normalized;
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue