diff --git a/api/package.json b/api/package.json index 62ffa8d9c3..492dbd02b9 100644 --- a/api/package.json +++ b/api/package.json @@ -45,7 +45,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.27", + "@librechat/agents": "^3.1.29", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index cb5fa79a48..c9d8ff25f4 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -1,22 +1,28 @@ const { sleep, EnvVar, - Constants, + StepTypes, + GraphEvents, createToolSearch, createProgrammaticToolCallingTool, } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); const { tool: toolFn, DynamicStructuredTool } = require('@langchain/core/tools'); const { + sendEvent, getToolkitKey, hasCustomUserVars, getUserMCPAuthMap, loadToolDefinitions, + GenerationJobManager, isActionDomainAllowed, buildToolClassification, } = require('@librechat/api'); const { + Time, Tools, + Constants, + CacheKeys, ErrorTypes, ContentTypes, imageGenTools, @@ -45,6 +51,8 @@ const { getCachedTools, getMCPServerTools, } = require('~/server/services/Config'); +const { getFlowStateManager } = require('~/config'); +const { getLogStores } = require('~/cache'); const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); @@ -409,7 +417,9 @@ const isBuiltInTool = (toolName) => * * @param {Object} params * @param {ServerRequest} params.req - The request object + * @param {ServerResponse} [params.res] - The response object for SSE events * @param {Object} params.agent - The agent configuration + * @param {string|null} [params.streamId] - Stream ID for resumable mode * @returns {Promise<{ * toolDefinitions?: import('@librechat/api').LCTool[]; * toolRegistry?: Map; @@ -417,7 +427,7 @@ const isBuiltInTool = (toolName) => * hasDeferredTools?: boolean; * }>} */ -async function loadToolDefinitionsWrapper({ req, agent }) { +async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null }) { if (!agent.tools || agent.tools.length === 0) { return { toolDefinitions: [] }; } @@ -473,14 +483,72 @@ async function loadToolDefinitionsWrapper({ req, agent }) { }); } + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const pendingOAuthServers = new Set(); + + const createOAuthEmitter = (serverName) => { + return async (authURL) => { + const flowId = `${req.user.id}:${serverName}:${Date.now()}`; + const stepId = 'step_oauth_login_' + serverName; + const toolCall = { + id: flowId, + name: serverName, + type: 'tool_call_chunk', + }; + + const runStepData = { + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + id: stepId, + type: StepTypes.TOOL_CALLS, + index: 0, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [toolCall], + }, + }; + + const runStepDeltaData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + auth: authURL, + expires_at: Date.now() + Time.TWO_MINUTES, + }, + }; + + const runStepEvent = { event: GraphEvents.ON_RUN_STEP, data: runStepData }; + const runStepDeltaEvent = { event: GraphEvents.ON_RUN_STEP_DELTA, data: runStepDeltaData }; + + if (streamId) { + GenerationJobManager.emitChunk(streamId, runStepEvent); + GenerationJobManager.emitChunk(streamId, runStepDeltaEvent); + } else if (res && !res.writableEnded) { + sendEvent(res, runStepEvent); + sendEvent(res, runStepDeltaEvent); + } else { + logger.warn( + `[Tool Definitions] Cannot emit OAuth event for ${serverName}: no streamId and res not available`, + ); + } + }; + }; + const getOrFetchMCPServerTools = async (userId, serverName) => { const cached = await getMCPServerTools(userId, serverName); if (cached) { return cached; } + const oauthStart = async () => { + pendingOAuthServers.add(serverName); + }; + const result = await reinitMCPServer({ user: req.user, + oauthStart, + flowManager, serverName, userMCPAuthMap, }); @@ -535,7 +603,7 @@ async function loadToolDefinitionsWrapper({ req, agent }) { return definitions; }; - const { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions( + let { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions( { userId: req.user.id, agentId: agent.id, @@ -551,6 +619,65 @@ async function loadToolDefinitionsWrapper({ req, agent }) { }, ); + if (pendingOAuthServers.size > 0 && (res || streamId)) { + const serverNames = Array.from(pendingOAuthServers); + logger.info( + `[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. Emitting events and waiting.`, + ); + + const oauthWaitPromises = serverNames.map(async (serverName) => { + try { + const result = await reinitMCPServer({ + user: req.user, + serverName, + userMCPAuthMap, + flowManager, + returnOnOAuth: false, + oauthStart: createOAuthEmitter(serverName), + connectionTimeout: Time.TWO_MINUTES, + }); + + if (result?.availableTools) { + logger.info(`[Tool Definitions] OAuth completed for ${serverName}, tools available`); + return { serverName, success: true }; + } + return { serverName, success: false }; + } catch (error) { + logger.debug(`[Tool Definitions] OAuth wait failed for ${serverName}:`, error?.message); + return { serverName, success: false }; + } + }); + + const results = await Promise.allSettled(oauthWaitPromises); + const successfulServers = results + .filter((r) => r.status === 'fulfilled' && r.value.success) + .map((r) => r.value.serverName); + + if (successfulServers.length > 0) { + logger.info( + `[Tool Definitions] Reloading tools after OAuth for: ${successfulServers.join(', ')}`, + ); + const reloadResult = await loadToolDefinitions( + { + userId: req.user.id, + agentId: agent.id, + tools: filteredTools, + toolOptions: agent.tool_options, + deferredToolsEnabled, + }, + { + isBuiltInTool, + loadAuthValues, + getOrFetchMCPServerTools, + getActionToolDefinitions, + }, + ); + toolDefinitions = reloadResult.toolDefinitions; + toolRegistry = reloadResult.toolRegistry; + hasDeferredTools = reloadResult.hasDeferredTools; + } + } + return { toolRegistry, userMCPAuthMap, @@ -584,7 +711,7 @@ async function loadAgentTools({ definitionsOnly = true, }) { if (definitionsOnly) { - return loadToolDefinitionsWrapper({ req, agent }); + return loadToolDefinitionsWrapper({ req, res, agent, streamId }); } if (!agent.tools || agent.tools.length === 0) { diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 33e67c8238..10f2d71a18 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -1,11 +1,14 @@ const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); const { findToken, createToken, updateToken, deleteTokens } = require('~/models'); -const { getMCPManager, getFlowStateManager } = require('~/config'); const { updateMCPServerTools } = require('~/server/services/Config'); +const { getMCPManager, getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); /** + * Reinitializes an MCP server connection and discovers available tools. + * When OAuth is required, uses discovery mode to list tools without full authentication + * (per MCP spec, tool listing should be possible without auth). * @param {Object} params * @param {IUser} params.user - The user from the request object. * @param {string} params.serverName - The name of the MCP server @@ -14,7 +17,7 @@ const { getLogStores } = require('~/cache'); * @param {boolean} [params.forceNew] * @param {number} [params.connectionTimeout] * @param {FlowStateManager} [params.flowManager] - * @param {(authURL: string) => Promise} [params.oauthStart] + * @param {(authURL: string) => Promise} [params.oauthStart] * @param {Record>} [params.userMCPAuthMap] */ async function reinitMCPServer({ @@ -36,10 +39,12 @@ async function reinitMCPServer({ let tools = null; let oauthRequired = false; let oauthUrl = null; + try { const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS)); const mcpManager = getMCPManager(); + const tokenMethods = { findToken, updateToken, createToken, deleteTokens }; const oauthStart = _oauthStart ?? @@ -57,15 +62,10 @@ async function reinitMCPServer({ oauthStart, serverName, flowManager, + tokenMethods, returnOnOAuth, customUserVars, connectionTimeout, - tokenMethods: { - findToken, - updateToken, - createToken, - deleteTokens, - }, }); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); @@ -84,9 +84,33 @@ async function reinitMCPServer({ if (isOAuthError || oauthRequired || isOAuthFlowInitiated) { logger.info( - `[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`, + `[MCP Reinitialize] OAuth required for ${serverName}, attempting tool discovery without auth`, ); oauthRequired = true; + + try { + const discoveryResult = await mcpManager.discoverServerTools({ + user, + signal, + serverName, + flowManager, + tokenMethods, + oauthStart, + customUserVars, + connectionTimeout, + }); + + if (discoveryResult.tools && discoveryResult.tools.length > 0) { + tools = discoveryResult.tools; + logger.info( + `[MCP Reinitialize] Discovered ${tools.length} tools for ${serverName} without full auth`, + ); + } + } catch (discoveryErr) { + logger.debug( + `[MCP Reinitialize] Tool discovery failed for ${serverName}: ${discoveryErr?.message ?? String(discoveryErr)}`, + ); + } } else { logger.error( `[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`, @@ -97,6 +121,9 @@ async function reinitMCPServer({ if (connection && !oauthRequired) { tools = await connection.fetchTools(); + } + + if (tools && tools.length > 0) { availableTools = await updateMCPServerTools({ userId: user.id, serverName, @@ -109,6 +136,9 @@ async function reinitMCPServer({ ); const getResponseMessage = () => { + if (oauthRequired && tools && tools.length > 0) { + return `MCP server '${serverName}' tools discovered, OAuth required for execution`; + } if (oauthRequired) { return `MCP server '${serverName}' ready for OAuth authentication`; } @@ -120,19 +150,25 @@ async function reinitMCPServer({ const result = { availableTools, - success: Boolean((connection && !oauthRequired) || (oauthRequired && oauthUrl)), + success: Boolean( + (connection && !oauthRequired) || + (oauthRequired && oauthUrl) || + (tools && tools.length > 0), + ), message: getResponseMessage(), oauthRequired, serverName, oauthUrl, tools, }; + logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, { success: result.success, oauthRequired: result.oauthRequired, oauthUrl: result.oauthUrl ? 'present' : null, toolsCount: tools?.length ?? 0, }); + return result; } catch (error) { logger.error( diff --git a/package-lock.json b/package-lock.json index 5b18f47521..f0c2c73a3e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -59,7 +59,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.27", + "@librechat/agents": "^3.1.29", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -11709,9 +11709,9 @@ } }, "node_modules/@librechat/agents": { - "version": "3.1.27", - "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.27.tgz", - "integrity": "sha512-cThf2+OoyjBGf1PoG3H9Au3zm+zFICHF53qHYc6B3/j9mss9NgmGXd30ILRXiXPgsMCfOHqJoqUWidQHFJLiiA==", + "version": "3.1.29", + "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.29.tgz", + "integrity": "sha512-jY2+UVjnJvkUmvcsz7wic4CKJuJVUgOlVv3ICInpd3SZFhsKlUwSNKQl1PbzrZPNFuIyUt9CgGWYw1I022zhaA==", "license": "MIT", "dependencies": { "@aws-sdk/client-bedrock-runtime": "^3.970.0", @@ -30903,12 +30903,6 @@ "tslib": "^1.14.1" } }, - "node_modules/keyv-file/node_modules/tslib": { - "version": "1.14.1", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", - "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", - "license": "0BSD" - }, "node_modules/keyv/node_modules/@keyv/serialize": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@keyv/serialize/-/serialize-1.1.1.tgz", @@ -43020,7 +43014,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.27", + "@librechat/agents": "^3.1.29", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.25.3", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/package.json b/packages/api/package.json index 31ce5856f5..01202704e6 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -87,7 +87,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.27", + "@librechat/agents": "^3.1.29", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.25.3", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 1a97755ec3..bcc63b7500 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -1,5 +1,6 @@ import { logger } from '@librechat/data-schemas'; import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; +import type { Tool } from '@modelcontextprotocol/sdk/types.js'; import type { TokenMethods } from '@librechat/data-schemas'; import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth'; import type { FlowStateManager } from '~/flow/manager'; @@ -11,6 +12,13 @@ import { withTimeout } from '~/utils/promise'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils'; +export interface ToolDiscoveryResult { + tools: Tool[] | null; + connection: MCPConnection | null; + oauthRequired: boolean; + oauthUrl: string | null; +} + /** * Factory for creating MCP connections with optional OAuth authentication. * Handles OAuth flows, token management, and connection retry logic. @@ -41,6 +49,137 @@ export class MCPConnectionFactory { return factory.createConnection(); } + /** + * Discovers tools from an MCP server, even when OAuth is required. + * Per MCP spec, tool listing should be possible without authentication. + * Returns tools if discoverable, plus OAuth status for tool execution. + */ + static async discoverTools( + basic: t.BasicConnectionOptions, + oauth?: Omit, + ): Promise { + const factory = new this(basic, oauth ? { ...oauth, returnOnOAuth: true } : undefined); + return factory.discoverToolsInternal(); + } + + protected async discoverToolsInternal(): Promise { + const oauthUrl: string | null = null; + let oauthRequired = false; + + const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null; + const connection = new MCPConnection({ + serverName: this.serverName, + serverConfig: this.serverConfig, + userId: this.userId, + oauthTokens, + }); + + const oauthHandler = async () => { + logger.info( + `${this.logPrefix} [Discovery] OAuth required; skipping URL generation in discovery mode`, + ); + oauthRequired = true; + connection.emit('oauthFailed', new Error('OAuth required during tool discovery')); + }; + + if (this.useOAuth) { + connection.on('oauthRequired', oauthHandler); + } + + try { + const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000; + await withTimeout( + connection.connect(), + connectTimeout, + `Connection timeout after ${connectTimeout}ms`, + ); + + if (await connection.isConnected()) { + const tools = await connection.fetchTools(); + if (this.useOAuth) { + connection.removeListener('oauthRequired', oauthHandler); + } + return { tools, connection, oauthRequired: false, oauthUrl: null }; + } + } catch { + logger.debug( + `${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`, + ); + } + + try { + const tools = await this.attemptUnauthenticatedToolListing(); + if (this.useOAuth) { + connection.removeListener('oauthRequired', oauthHandler); + } + if (tools && tools.length > 0) { + logger.info( + `${this.logPrefix} [Discovery] Successfully discovered ${tools.length} tools without auth`, + ); + try { + await connection.disconnect(); + } catch { + // Ignore cleanup errors + } + return { tools, connection: null, oauthRequired, oauthUrl }; + } + } catch (listError) { + logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError); + } + + if (this.useOAuth) { + connection.removeListener('oauthRequired', oauthHandler); + } + + try { + await connection.disconnect(); + } catch { + // Ignore cleanup errors + } + + return { tools: null, connection: null, oauthRequired, oauthUrl }; + } + + protected async attemptUnauthenticatedToolListing(): Promise { + const unauthConnection = new MCPConnection({ + serverName: this.serverName, + serverConfig: this.serverConfig, + userId: this.userId, + oauthTokens: null, + }); + + unauthConnection.on('oauthRequired', () => { + logger.debug( + `${this.logPrefix} [Discovery] Unauthenticated connection requires OAuth, failing fast`, + ); + unauthConnection.emit( + 'oauthFailed', + new Error('OAuth not supported in unauthenticated discovery'), + ); + }); + + try { + const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 15000; + await withTimeout(unauthConnection.connect(), connectTimeout, `Unauth connection timeout`); + + if (await unauthConnection.isConnected()) { + const tools = await unauthConnection.fetchTools(); + await unauthConnection.disconnect(); + return tools; + } + } catch { + logger.debug(`${this.logPrefix} [Discovery] Unauthenticated connection attempt failed`); + } + + try { + await unauthConnection.disconnect(); + } catch { + // Ignore cleanup errors + } + + return null; + } + protected constructor(basic: t.BasicConnectionOptions, oauth?: t.OAuthConnectionOptions) { this.serverConfig = processMCPEnv({ options: basic.serverConfig, @@ -56,7 +195,7 @@ export class MCPConnectionFactory { : `[MCP][${basic.serverName}]`; if (oauth?.useOAuth) { - this.userId = oauth.user.id; + this.userId = oauth.user?.id; this.flowManager = oauth.flowManager; this.tokenMethods = oauth.tokenMethods; this.signal = oauth.signal; diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 0b9dce7061..211382c032 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -8,11 +8,12 @@ import type { FlowStateManager } from '~/flow/manager'; import type { MCPOAuthTokens } from './oauth'; import type { RequestBody } from '~/types'; import type * as t from './types'; +import { MCPServersInitializer } from './registry/MCPServersInitializer'; +import { MCPServerInspector } from './registry/MCPServerInspector'; +import { MCPServersRegistry } from './registry/MCPServersRegistry'; import { UserConnectionManager } from './UserConnectionManager'; import { ConnectionsRepository } from './ConnectionsRepository'; -import { MCPServerInspector } from './registry/MCPServerInspector'; -import { MCPServersInitializer } from './registry/MCPServersInitializer'; -import { MCPServersRegistry } from './registry/MCPServersRegistry'; +import { MCPConnectionFactory } from './MCPConnectionFactory'; import { preProcessGraphTokens } from '~/utils/graph'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; @@ -68,6 +69,70 @@ export class MCPManager extends UserConnectionManager { } } + /** + * Discovers tools from an MCP server, even when OAuth is required. + * Per MCP spec, tool listing should be possible without authentication. + * Use this for agent initialization to get tool schemas before OAuth flow. + */ + public async discoverServerTools(args: t.ToolDiscoveryOptions): Promise { + const { serverName, user } = args; + const logPrefix = user?.id ? `[MCP][User: ${user.id}][${serverName}]` : `[MCP][${serverName}]`; + + try { + const existingAppConnection = await this.appConnections?.get(serverName); + if (existingAppConnection && (await existingAppConnection.isConnected())) { + const tools = await existingAppConnection.fetchTools(); + return { tools, oauthRequired: false, oauthUrl: null }; + } + } catch { + logger.debug(`${logPrefix} [Discovery] App connection not available, trying discovery mode`); + } + + const serverConfig = (await MCPServersRegistry.getInstance().getServerConfig( + serverName, + user?.id, + )) as t.MCPOptions | null; + + if (!serverConfig) { + logger.warn(`${logPrefix} [Discovery] Server config not found`); + return { tools: null, oauthRequired: false, oauthUrl: null }; + } + + const useOAuth = Boolean( + serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata, + ); + + const basic: t.BasicConnectionOptions = { serverName, serverConfig }; + + if (!useOAuth) { + const result = await MCPConnectionFactory.discoverTools(basic); + return { + tools: result.tools, + oauthRequired: result.oauthRequired, + oauthUrl: result.oauthUrl, + }; + } + + if (!user || !args.flowManager) { + logger.warn(`${logPrefix} [Discovery] OAuth server requires user and flowManager`); + return { tools: null, oauthRequired: true, oauthUrl: null }; + } + + const result = await MCPConnectionFactory.discoverTools(basic, { + user, + useOAuth: true, + flowManager: args.flowManager, + tokenMethods: args.tokenMethods, + signal: args.signal, + oauthStart: args.oauthStart, + customUserVars: args.customUserVars, + requestBody: args.requestBody, + connectionTimeout: args.connectionTimeout, + }); + + return { tools: result.tools, oauthRequired: result.oauthRequired, oauthUrl: result.oauthUrl }; + } + /** Returns all available tool functions from app-level connections */ public async getAppToolFunctions(): Promise { const toolFunctions: t.LCAvailableTools = {}; diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 1b85b69eac..25fc753d6b 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -49,7 +49,7 @@ export abstract class UserConnectionManager { serverName: string; forceNew?: boolean; } & Omit): Promise { - const userId = user.id; + const userId = user?.id; if (!userId) { throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); } diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index 528e635204..0986188e04 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -1,6 +1,5 @@ import { logger } from '@librechat/data-schemas'; -import type { TokenMethods } from '@librechat/data-schemas'; -import type { TUser } from 'librechat-data-provider'; +import type { TokenMethods, IUser } from '@librechat/data-schemas'; import type { FlowStateManager } from '~/flow/manager'; import type { MCPOAuthTokens } from '~/mcp/oauth'; import type * as t from '~/mcp/types'; @@ -27,7 +26,7 @@ const mockMCPConnection = MCPConnection as jest.MockedClass; describe('MCPConnectionFactory', () => { - let mockUser: TUser; + let mockUser: IUser | undefined; let mockServerConfig: t.MCPOptions; let mockFlowManager: jest.Mocked>; let mockConnectionInstance: jest.Mocked; @@ -37,7 +36,7 @@ describe('MCPConnectionFactory', () => { mockUser = { id: 'user123', email: 'test@example.com', - } as TUser; + } as IUser; mockServerConfig = { command: 'node', @@ -275,7 +274,7 @@ describe('MCPConnectionFactory', () => { user: mockUser, }; - const oauthOptions = { + const oauthOptions: t.OAuthConnectionOptions = { user: mockUser, useOAuth: true, returnOnOAuth: true, @@ -424,4 +423,116 @@ describe('MCPConnectionFactory', () => { ); }); }); + + describe('discoverTools static method', () => { + const mockTools = [ + { name: 'tool1', description: 'First tool', inputSchema: { type: 'object' } }, + { name: 'tool2', description: 'Second tool', inputSchema: { type: 'object' } }, + ]; + + it('should discover tools from a successfully connected server', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + mockConnectionInstance.connect.mockResolvedValue(undefined); + mockConnectionInstance.isConnected.mockResolvedValue(true); + mockConnectionInstance.fetchTools = jest.fn().mockResolvedValue(mockTools); + + const result = await MCPConnectionFactory.discoverTools(basicOptions); + + expect(result.tools).toEqual(mockTools); + expect(result.oauthRequired).toBe(false); + expect(result.oauthUrl).toBeNull(); + expect(result.connection).toBe(mockConnectionInstance); + }); + + it('should detect OAuth required without generating URL in discovery mode', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: { + ...mockServerConfig, + url: 'https://api.example.com', + type: 'sse' as const, + } as t.SSEOptions, + }; + + const mockOAuthStart = jest.fn().mockResolvedValue(undefined); + + const oauthOptions = { + useOAuth: true as const, + user: mockUser as unknown as IUser, + flowManager: mockFlowManager, + oauthStart: mockOAuthStart, + tokenMethods: { + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), + deleteTokens: jest.fn(), + }, + }; + + mockConnectionInstance.isConnected.mockResolvedValue(false); + mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined); + + let oauthHandler: (() => Promise) | undefined; + mockConnectionInstance.on.mockImplementation((event, handler) => { + if (event === 'oauthRequired') { + oauthHandler = handler as () => Promise; + } + return mockConnectionInstance; + }); + + mockConnectionInstance.connect.mockImplementation(async () => { + if (oauthHandler) { + await oauthHandler(); + } + throw new Error('OAuth required'); + }); + + const result = await MCPConnectionFactory.discoverTools(basicOptions, oauthOptions); + + expect(result.connection).toBeNull(); + expect(result.tools).toBeNull(); + expect(result.oauthRequired).toBe(true); + expect(result.oauthUrl).toBeNull(); + expect(mockOAuthStart).not.toHaveBeenCalled(); + }); + + it('should return null tools when discovery fails completely', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + mockConnectionInstance.disconnect = jest.fn().mockResolvedValue(undefined); + + const result = await MCPConnectionFactory.discoverTools(basicOptions); + + expect(result.tools).toBeNull(); + expect(result.connection).toBeNull(); + expect(result.oauthRequired).toBe(false); + }); + + it('should handle disconnect errors gracefully during cleanup', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + mockConnectionInstance.connect.mockRejectedValue(new Error('Connection failed')); + mockConnectionInstance.isConnected.mockResolvedValue(false); + mockConnectionInstance.disconnect = jest + .fn() + .mockRejectedValue(new Error('Disconnect failed')); + + const result = await MCPConnectionFactory.discoverTools(basicOptions); + + expect(result.tools).toBeNull(); + expect(mockLogger.debug).toHaveBeenCalled(); + }); + }); }); diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index f210fcb63a..caeb9176d3 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -4,6 +4,7 @@ import type { GraphTokenResolver } from '~/utils/graph'; import type * as t from '~/mcp/types'; import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer'; import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnection } from '~/mcp/connection'; import { MCPManager } from '~/mcp/MCPManager'; @@ -48,6 +49,7 @@ jest.mock('~/mcp/registry/MCPServersInitializer', () => ({ jest.mock('~/mcp/registry/MCPServerInspector'); jest.mock('~/mcp/ConnectionsRepository'); +jest.mock('~/mcp/MCPConnectionFactory'); const mockLogger = logger as jest.Mocked; @@ -787,4 +789,139 @@ describe('MCPManager', () => { ); }); }); + + describe('discoverServerTools', () => { + const mockTools = [ + { name: 'tool1', description: 'First tool', inputSchema: { type: 'object' } }, + { name: 'tool2', description: 'Second tool', inputSchema: { type: 'object' } }, + ]; + + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + fetchTools: jest.fn().mockResolvedValue(mockTools), + } as unknown as MCPConnection; + + beforeEach(() => { + (MCPConnectionFactory.discoverTools as jest.Mock) = jest.fn(); + }); + + it('should return tools from existing app connection when available', async () => { + mockAppConnections({ + get: jest.fn().mockResolvedValue(mockConnection), + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.discoverServerTools({ serverName }); + + expect(result.tools).toEqual(mockTools); + expect(result.oauthRequired).toBe(false); + expect(result.oauthUrl).toBeNull(); + expect(MCPConnectionFactory.discoverTools).not.toHaveBeenCalled(); + }); + + it('should use MCPConnectionFactory.discoverTools when no app connection available', async () => { + mockAppConnections({ + get: jest.fn().mockResolvedValue(null), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'stdio', + command: 'test', + args: [], + }); + + (MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({ + tools: mockTools, + connection: null, + oauthRequired: false, + oauthUrl: null, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.discoverServerTools({ serverName }); + + expect(result.tools).toEqual(mockTools); + expect(result.oauthRequired).toBe(false); + expect(MCPConnectionFactory.discoverTools).toHaveBeenCalled(); + }); + + it('should return null tools when server config not found', async () => { + mockAppConnections({ + get: jest.fn().mockResolvedValue(null), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(null); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.discoverServerTools({ serverName }); + + expect(result.tools).toBeNull(); + expect(result.oauthRequired).toBe(false); + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('Server config not found'), + ); + }); + + it('should return OAuth info when server requires OAuth but no user provided', async () => { + mockAppConnections({ + get: jest.fn().mockResolvedValue(null), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'sse', + url: 'https://api.example.com', + requiresOAuth: true, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.discoverServerTools({ serverName }); + + expect(result.tools).toBeNull(); + expect(result.oauthRequired).toBe(true); + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('OAuth server requires user and flowManager'), + ); + }); + + it('should discover tools with OAuth when user and flowManager provided', async () => { + const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser; + const mockFlowManager = { + createFlow: jest.fn(), + getFlowState: jest.fn(), + deleteFlow: jest.fn(), + }; + + mockAppConnections({ + get: jest.fn().mockResolvedValue(null), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'sse', + url: 'https://api.example.com', + requiresOAuth: true, + }); + + (MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({ + tools: mockTools, + connection: null, + oauthRequired: true, + oauthUrl: 'https://auth.example.com/authorize', + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.discoverServerTools({ + serverName, + user: mockUser, + flowManager: mockFlowManager as unknown as t.ToolDiscoveryOptions['flowManager'], + }); + + expect(result.tools).toEqual(mockTools); + expect(result.oauthRequired).toBe(true); + expect(result.oauthUrl).toBe('https://auth.example.com/authorize'); + expect(MCPConnectionFactory.discoverTools).toHaveBeenCalledWith( + expect.objectContaining({ serverName }), + expect.objectContaining({ user: mockUser, useOAuth: true }), + ); + }); + }); }); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 3b0d31b83b..46447c6687 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -169,7 +169,7 @@ export interface BasicConnectionOptions { } export interface OAuthConnectionOptions { - user: IUser; + user?: IUser; useOAuth: true; requestBody?: RequestBody; customUserVars?: Record; @@ -181,3 +181,21 @@ export interface OAuthConnectionOptions { returnOnOAuth?: boolean; connectionTimeout?: number; } + +export interface ToolDiscoveryOptions { + serverName: string; + user?: IUser; + flowManager?: FlowStateManager; + tokenMethods?: TokenMethods; + signal?: AbortSignal; + oauthStart?: (authURL: string) => Promise; + customUserVars?: Record; + requestBody?: RequestBody; + connectionTimeout?: number; +} + +export interface ToolDiscoveryResult { + tools: Tool[] | null; + oauthRequired: boolean; + oauthUrl: string | null; +} diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 26c2ef73a6..d4b9b97eda 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -238,6 +238,7 @@ class GenerationJobManagerClass { const currentRuntime = this.runtimeState.get(streamId); if (currentRuntime) { currentRuntime.syncSent = false; + currentRuntime.hasSubscriber = false; // Persist syncSent=false to Redis for cross-replica consistency this.jobStore.updateJob(streamId, { syncSent: false }).catch((err) => { logger.error(`[GenerationJobManager] Failed to persist syncSent=false:`, err); @@ -435,6 +436,7 @@ class GenerationJobManagerClass { const currentRuntime = this.runtimeState.get(streamId); if (currentRuntime) { currentRuntime.syncSent = false; + currentRuntime.hasSubscriber = false; // Persist syncSent=false to Redis this.jobStore.updateJob(streamId, { syncSent: false }).catch((err) => { logger.error(`[GenerationJobManager] Failed to persist syncSent=false:`, err); @@ -767,7 +769,6 @@ class GenerationJobManagerClass { for (const bufferedEvent of runtime.earlyEventBuffer) { onChunk(bufferedEvent); } - // Clear buffer after replay runtime.earlyEventBuffer = []; } } @@ -822,7 +823,6 @@ class GenerationJobManagerClass { // Buffer early events if no subscriber yet (replay when first subscriber connects) if (!runtime.hasSubscriber) { runtime.earlyEventBuffer.push(event); - // Also emit to transport in case subscriber connects mid-flight } this.eventTransport.emitChunk(streamId, event);