diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 40fdf74212..5ea5a00975 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,13 +1,6 @@ const { nanoid } = require('nanoid'); const { logger } = require('@librechat/data-schemas'); const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider'); -const { - EnvVar, - Constants, - GraphEvents, - GraphNodeKeys, - ToolEndHandler, -} = require('@librechat/agents'); const { sendEvent, GenerationJobManager, diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index b747e6f5ed..f67b7e08d3 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -21,6 +21,7 @@ const { generateCheckAccess, validateOAuthSession, OAUTH_SESSION_COOKIE, + MCPToolCallValidationHandler, } = require('@librechat/api'); const { createMCPServerController, @@ -755,6 +756,90 @@ async function getOAuthHeaders(serverName, userId, configServers) { return serverConfig?.oauth_headers ?? {}; } +/** + * Tool Call Validation Routes + */ + +router.post('/validation/confirm/:validationId', requireJwtAuth, async (req, res) => { + try { + const { validationId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!validationId.startsWith(`${user.id}:`)) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + await MCPToolCallValidationHandler.completeValidationFlow(validationId, flowManager); + + res.json({ success: true, message: 'Tool call validation confirmed' }); + } catch (error) { + logger.error('[MCP Validation] Failed to confirm validation', error); + res.status(500).json({ error: error.message || 'Failed to confirm validation' }); + } +}); + +router.post('/validation/reject/:validationId', requireJwtAuth, async (req, res) => { + try { + const { validationId } = req.params; + const { reason } = req.body; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!validationId.startsWith(`${user.id}:`)) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + await MCPToolCallValidationHandler.rejectValidationFlow(validationId, flowManager, reason); + + res.json({ success: true, message: 'Tool call validation rejected' }); + } catch (error) { + logger.error('[MCP Validation] Failed to reject validation', error); + res.status(500).json({ error: error.message || 'Failed to reject validation' }); + } +}); + +router.get('/validation/status/:validationId', requireJwtAuth, async (req, res) => { + try { + const { validationId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!validationId.startsWith(`${user.id}:`)) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowState = await MCPToolCallValidationHandler.getFlowState(validationId, flowManager); + + if (!flowState) { + return res.status(404).json({ error: 'Validation flow not found' }); + } + + res.json({ success: true, validationId, metadata: flowState }); + } catch (error) { + logger.error('[MCP Validation] Failed to get validation status', error); + res.status(500).json({ error: 'Failed to get validation status' }); + } +}); + /** MCP Server CRUD Routes (User-Managed MCP Servers) */ diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index ccff184d4d..a9de88a4fe 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -9,14 +9,23 @@ const { const { sendEvent, MCPOAuthHandler, + requiresApproval, isMCPDomainAllowed, normalizeServerName, normalizeJsonSchema, GenerationJobManager, resolveJsonSchemaRefs, buildOAuthToolCallName, + MCPToolCallValidationHandler, } = require('@librechat/api'); -const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider'); +const { + Time, + CacheKeys, + Constants, + ContentTypes, + EModelEndpoint, + isAssistantsEndpoint, +} = require('librechat-data-provider'); const { getOAuthReconnectionManager, getMCPServersRegistry, @@ -624,6 +633,73 @@ function createToolInstance({ derivedSignal.addEventListener('abort', abortHandler, { once: true }); } + // Tool call validation flow - only if tool requires approval + const appConfig = await getAppConfig({ role: config?.configurable?.user?.role }); + const toolApprovalConfig = appConfig?.endpoints?.[EModelEndpoint.agents]?.toolApproval; + const toolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; + const needsApproval = requiresApproval(toolKey, toolApprovalConfig); + + if (needsApproval) { + const validationFlowType = MCPToolCallValidationHandler.getFlowType(); + const { validationId, flowMetadata } = + await MCPToolCallValidationHandler.initiateValidationFlow( + userId, + serverName, + toolName, + typeof toolArguments === 'string' ? { input: toolArguments } : toolArguments, + ); + + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const validationData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + validation: validationId, + expires_at: Date.now() + Time.TEN_MINUTES, + }, + }; + + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: validationData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: validationData }); + } + + try { + await flowManager.createFlow( + validationId, + validationFlowType, + flowMetadata, + derivedSignal, + ); + + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const successData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall }], + }, + }; + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: successData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: successData }); + } + } catch (_validationError) { + throw new Error( + `Tool call validation required for ${serverName}/${toolName}. User rejected or validation timed out.`, + ); + } + } + const customUserVars = config?.configurable?.userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; @@ -661,6 +737,18 @@ function createToolInstance({ error, ); + /** Validation error - user rejected or timeout */ + const isValidationError = + error.message?.includes('validation required') || + error.message?.includes('User rejected') || + error.message?.includes('mcp_tool_validation'); + + if (isValidationError) { + throw new Error( + `Tool call for ${serverName}/${toolName} was not approved by the user. Wait for next instructions.`, + ); + } + /** OAuth error, provide a helpful message */ const isOAuthError = error.message?.includes('401') || diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index b4d948eda4..6f253ae284 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -12,6 +12,8 @@ const { const { sendEvent, getToolkitKey, + requiresApproval, + getToolServerName, getUserMCPAuthMap, loadToolDefinitions, GenerationJobManager, @@ -20,6 +22,7 @@ const { buildImageToolContext, buildToolClassification, buildOAuthToolCallName, + MCPToolCallValidationHandler, } = require('@librechat/api'); const { Time, @@ -810,6 +813,94 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to * @param {ServerRequest} params.req - The request object * @param {ServerResponse} params.res - The response object * @param {Object} params.agent - The agent configuration +/** + * Wraps a tool with approval validation flow. + * The wrapped tool sends an SSE event and waits for user approval before executing. + * @param {Object} params + * @param {Object} params.tool - The tool to wrap + * @param {ServerResponse} params.res - The response object for SSE + * @param {string|null} params.streamId - Stream ID for resumable mode + * @returns {Object} The wrapped tool + */ +function wrapToolWithApproval({ tool, res, streamId }) { + const originalCall = tool._call.bind(tool); + const toolName = tool.name; + const serverName = getToolServerName(toolName); + + tool._call = async (toolArguments, runManager, parentConfig) => { + const config = parentConfig; + const userId = config?.configurable?.user?.id || config?.configurable?.user_id; + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; + + const { args: _args, stepId, ...toolCall } = config?.toolCall ?? {}; + + const { validationId, flowMetadata } = + await MCPToolCallValidationHandler.initiateValidationFlow( + userId, + serverName, + toolName, + typeof toolArguments === 'string' ? { input: toolArguments } : toolArguments, + ); + + const validationData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + validation: validationId, + expires_at: Date.now() + Time.TEN_MINUTES, + }, + }; + + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: validationData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: validationData }); + } + + const validationFlowType = MCPToolCallValidationHandler.getFlowType(); + try { + await flowManager.createFlow(validationId, validationFlowType, flowMetadata, derivedSignal); + + const successData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall }], + }, + }; + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: successData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: successData }); + } + } catch (_validationError) { + throw new Error( + `Tool call validation required for ${toolName}. User rejected or validation timed out.`, + ); + } + + return await originalCall(toolArguments, runManager, parentConfig); + }; + + tool.requiresApproval = true; + return tool; +} + +/** + * @param {Object} params - Run params containing user and request information. + * @param {ServerRequest} params.req - The server request + * @param {ServerResponse} params.res - The server response + * @param {Agent} params.agent - The Agent * @param {AbortSignal} [params.signal] - Abort signal * @param {Object} [params.tool_resources] - Tool resources * @param {string} [params.openAIApiKey] - OpenAI API key @@ -933,9 +1024,16 @@ async function loadAgentTools({ loadAuthValues, }); + const toolApprovalConfig = appConfig.endpoints?.[EModelEndpoint.agents]?.toolApproval; const agentTools = []; for (let i = 0; i < loadedTools.length; i++) { - const tool = loadedTools[i]; + let tool = loadedTools[i]; + + const needsApproval = requiresApproval(tool.name, toolApprovalConfig); + if (res && needsApproval && tool.mcp !== true) { + tool = wrapToolWithApproval({ tool, res, streamId }); + } + if (tool.name && (tool.name === Tools.execute_code || tool.name === Tools.file_search)) { agentTools.push(tool); continue; @@ -945,7 +1043,7 @@ async function loadAgentTools({ continue; } - if (tool.mcp === true) { + if (tool.mcp === true || tool.requiresApproval === true) { agentTools.push(tool); continue; } diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index 1b4b9057f6..715e5c6f9f 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -179,6 +179,8 @@ const Part = memo(function Part({ isSubmitting={isSubmitting} attachments={attachments} auth={toolCall.auth} + validation={toolCall.validation} + expires_at={toolCall.expires_at} isLast={isLast} /> ); diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index c7dd974577..de9de6ddcd 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -1,7 +1,7 @@ import { useMemo, useState, useEffect, useCallback } from 'react'; import { useRecoilValue } from 'recoil'; import { Button } from '@librechat/client'; -import { TriangleAlert } from 'lucide-react'; +import { TriangleAlert, CheckCircle, XCircle } from 'lucide-react'; import { Constants, dataService, @@ -9,7 +9,7 @@ import { actionDomainSeparator, } from 'librechat-data-provider'; import type { TAttachment } from 'librechat-data-provider'; -import { useLocalize, useProgress, useExpandCollapse } from '~/hooks'; +import { useLocalize, useProgress, useExpandCollapse, useAuthContext } from '~/hooks'; import { ToolIcon, getToolIconType, isError } from './ToolOutput'; import { useMCPIconMap } from '~/hooks/MCP'; import { AttachmentGroup } from './Parts'; @@ -27,6 +27,7 @@ export default function ToolCall({ output, attachments, auth, + validation, }: { initialProgress: number; isLast?: boolean; @@ -36,6 +37,8 @@ export default function ToolCall({ output?: string | null; attachments?: TAttachment[]; auth?: string; + validation?: string; + expires_at?: number; }) { const localize = useLocalize(); const autoExpand = useRecoilValue(store.autoExpandTools); @@ -130,6 +133,66 @@ export default function ToolCall({ window.open(auth, '_blank', 'noopener,noreferrer'); }, [auth, isMCPToolCall, mcpServerName, actionId]); + const [validationConfirmed, setValidationConfirmed] = useState(false); + const [validationRejected, setValidationRejected] = useState(false); + const [validationError, setValidationError] = useState(null); + const [isConfirming, setIsConfirming] = useState(false); + const [isRejecting, setIsRejecting] = useState(false); + const { token } = useAuthContext(); + + const handleValidationConfirm = useCallback(async () => { + if (!validation || validationConfirmed || validationRejected) { + return; + } + setIsConfirming(true); + setValidationError(null); + try { + const response = await fetch(`/api/mcp/validation/confirm/${validation}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + }); + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to confirm validation'); + } + setValidationConfirmed(true); + } catch (err) { + setValidationError(err instanceof Error ? err.message : 'Unknown error'); + } finally { + setIsConfirming(false); + } + }, [validation, validationConfirmed, validationRejected, token]); + + const handleValidationReject = useCallback(async () => { + if (!validation || validationConfirmed || validationRejected) { + return; + } + setIsRejecting(true); + setValidationError(null); + try { + const response = await fetch(`/api/mcp/validation/reject/${validation}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ reason: 'User rejected tool call' }), + }); + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to reject validation'); + } + setValidationRejected(true); + } catch (err) { + setValidationError(err instanceof Error ? err.message : 'Unknown error'); + } finally { + setIsRejecting(false); + } + }, [validation, validationConfirmed, validationRejected, token]); + const hasError = typeof output === 'string' && isError(output); const cancelled = !isSubmitting && initialProgress < 1 && !hasError; const errorState = hasError; @@ -254,6 +317,59 @@ export default function ToolCall({

)} + {validation != null && + validation && + progress < 1 && + !cancelled && + !validationConfirmed && + !validationRejected && ( +
+
+ + +
+ {validationError && ( +

+

+ )} +

+

+
+ )} + {validation != null && validationConfirmed && ( +

+

+ )} + {validation != null && validationRejected && ( +

+

+ )} {attachments && attachments.length > 0 && } ); diff --git a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx index 14b4b7e07a..2f58ad6e3c 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx @@ -31,6 +31,7 @@ jest.mock('~/hooks', () => ({ }, ref: { current: null }, }), + useAuthContext: () => ({ token: 'mock-token' }), })); jest.mock('~/hooks/MCP', () => ({ @@ -89,6 +90,8 @@ jest.mock('lucide-react', () => ({ ChevronDown: () => {'ChevronDown'}, ChevronUp: () => {'ChevronUp'}, TriangleAlert: () => {'TriangleAlert'}, + CheckCircle: () => {'CheckCircle'}, + XCircle: () => {'XCircle'}, })); jest.mock('~/utils', () => ({ diff --git a/client/src/hooks/SSE/useStepHandler.ts b/client/src/hooks/SSE/useStepHandler.ts index 1f28d97433..b7f28b68d0 100644 --- a/client/src/hooks/SSE/useStepHandler.ts +++ b/client/src/hooks/SSE/useStepHandler.ts @@ -206,6 +206,7 @@ export default function useStepHandler({ args, type: ToolCallTypes.TOOL_CALL, auth: contentPart.tool_call.auth, + validation: contentPart.tool_call.validation, expires_at: contentPart.tool_call.expires_at, }; @@ -551,6 +552,11 @@ export default function useStepHandler({ contentPart.tool_call.expires_at = runStepDelta.delta.expires_at; } + if (runStepDelta.delta.validation != null) { + contentPart.tool_call.validation = runStepDelta.delta.validation; + contentPart.tool_call.expires_at = runStepDelta.delta.expires_at; + } + // Use server's index, offset by initialContent for edit scenarios const currentIndex = runStep.index + initialContent.length; updatedResponse = updateContent( diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 3d19f65ad6..12f02bfa88 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -845,6 +845,8 @@ "com_ui_confirm_action": "Confirm Action", "com_ui_confirm_admin_use_change": "Changing this setting will block access for admins, including yourself. Are you sure you want to proceed?", "com_ui_confirm_change": "Confirm Change", + "com_ui_confirm_tool_call": "Approve", + "com_ui_confirming": "Approving...", "com_ui_connecting": "Connecting", "com_ui_contact_admin_if_issue_persists": "Contact the Admin if the issue persists", "com_ui_context": "Context", @@ -1314,6 +1316,8 @@ "com_ui_regenerating": "Regenerating...", "com_ui_region": "Region", "com_ui_reinitialize": "Reinitialize", + "com_ui_reject_tool_call": "Reject", + "com_ui_rejecting": "Rejecting...", "com_ui_relevance": "Relevance", "com_ui_remote_access": "Remote Access", "com_ui_remote_agent_role_editor": "Editor", @@ -1477,6 +1481,9 @@ "com_ui_tool_collection_prefix": "A collection of tools from", "com_ui_tool_failed": "failed", "com_ui_tool_list_collapse": "Collapse {{serverName}} tool list", + "com_ui_tool_call_approved": "Tool call approved", + "com_ui_tool_call_rejected": "Tool call rejected", + "com_ui_tool_call_requires_approval": "This tool call requires your approval before it can be executed", "com_ui_tool_list_expand": "Expand {{serverName}} tool list", "com_ui_tool_name_code": "Code", "com_ui_tool_name_code_analysis": "Code Analysis", diff --git a/librechat.example.yaml b/librechat.example.yaml index 92206c4b6e..35153151d6 100644 --- a/librechat.example.yaml +++ b/librechat.example.yaml @@ -272,6 +272,12 @@ endpoints: # minRelevanceScore: 0.45 # # (optional) Agent Capabilities available to all users. Omit the ones you wish to exclude. Defaults to list below. # capabilities: ["deferred_tools", "execute_code", "file_search", "actions", "tools"] + # # (optional) Tool Approval - require user approval before tool calls execute + # # toolApproval: + # # # Set to true to require approval for all tools, or provide an array of tool patterns + # # required: true # or: ["web_search", "mcp:*", "image_*"] + # # # (optional) Exclude specific tools from approval requirement + # # excluded: ["calculator", "google"] # Anthropic endpoint configuration with Vertex AI support # Use this to run Anthropic Claude models through Google Cloud Vertex AI diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index d4b6ac9542..4e7511c0b1 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -16,6 +16,8 @@ export * from './mcp/zod'; export * from './mcp/errors'; export * from './mcp/cache'; export * from './mcp/tools'; +/* MCP Validation */ +export * from './mcp/validation'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; diff --git a/packages/api/src/mcp/validation/handler.ts b/packages/api/src/mcp/validation/handler.ts new file mode 100644 index 0000000000..8b4ca741c9 --- /dev/null +++ b/packages/api/src/mcp/validation/handler.ts @@ -0,0 +1,102 @@ +import { randomBytes } from 'crypto'; +import { logger } from '@librechat/data-schemas'; +import type { FlowStateManager } from '~/flow/manager'; +import type { FlowMetadata } from '~/flow/types'; + +export class MCPToolCallValidationHandler { + private static readonly FLOW_TYPE = 'mcp_tool_validation'; + private static readonly FLOW_TTL = 10 * 60 * 1000; + + static async initiateValidationFlow( + userId: string, + serverName: string, + toolName: string, + toolArguments: Record, + ): Promise<{ validationId: string; flowMetadata: FlowMetadata }> { + const validationId = this.generateValidationId(userId, serverName, toolName); + const state = this.generateState(); + + const flowMetadata: FlowMetadata = { + userId, + serverName, + toolName, + toolArguments, + state, + timestamp: Date.now(), + }; + + return { validationId, flowMetadata }; + } + + static async completeValidationFlow( + validationId: string, + flowManager: FlowStateManager, + ): Promise { + try { + const flowState = await flowManager.getFlowState(validationId, this.FLOW_TYPE); + if (!flowState) { + throw new Error('Validation flow not found'); + } + + await flowManager.completeFlow(validationId, this.FLOW_TYPE, true); + logger.info(`[MCPValidation] Validation flow completed successfully: ${validationId}`); + return true; + } catch (error) { + logger.error('[MCPValidation] Failed to complete validation flow', { error, validationId }); + await flowManager.failFlow(validationId, this.FLOW_TYPE, error as Error); + throw error; + } + } + + static async rejectValidationFlow( + validationId: string, + flowManager: FlowStateManager, + reason?: string, + ): Promise { + try { + const flowState = await flowManager.getFlowState(validationId, this.FLOW_TYPE); + if (!flowState) { + throw new Error('Validation flow not found'); + } + + const errorMessage = reason || 'User rejected tool call'; + await flowManager.failFlow(validationId, this.FLOW_TYPE, new Error(errorMessage)); + logger.info(`[MCPValidation] Validation flow rejected: ${validationId}`); + return true; + } catch (error) { + logger.error('[MCPValidation] Failed to reject validation flow', { error, validationId }); + throw error; + } + } + + static async getFlowState( + validationId: string, + flowManager: FlowStateManager, + ): Promise { + const flowState = await flowManager.getFlowState(validationId, this.FLOW_TYPE); + if (!flowState) { + return null; + } + return flowState.metadata as FlowMetadata; + } + + public static generateValidationId( + userId: string, + serverName: string, + toolName: string, + ): string { + return `${userId}:${serverName}:${toolName}:${Date.now()}`; + } + + public static getFlowType(): string { + return this.FLOW_TYPE; + } + + public static getFlowTTL(): number { + return this.FLOW_TTL; + } + + private static generateState(): string { + return randomBytes(32).toString('base64url'); + } +} diff --git a/packages/api/src/mcp/validation/index.ts b/packages/api/src/mcp/validation/index.ts new file mode 100644 index 0000000000..146faa1ca8 --- /dev/null +++ b/packages/api/src/mcp/validation/index.ts @@ -0,0 +1 @@ +export { MCPToolCallValidationHandler } from './handler'; diff --git a/packages/api/src/tools/approval.ts b/packages/api/src/tools/approval.ts new file mode 100644 index 0000000000..033539f2ba --- /dev/null +++ b/packages/api/src/tools/approval.ts @@ -0,0 +1,83 @@ +import type { TToolApproval } from 'librechat-data-provider'; + +export function requiresApproval( + toolName: string, + toolApproval: TToolApproval | undefined, +): boolean { + if (!toolApproval) { + return false; + } + + const { required, excluded } = toolApproval; + + if (required === undefined || required === false) { + return false; + } + + if (excluded && excluded.length > 0) { + for (const pattern of excluded) { + if (matchesPattern(toolName, pattern)) { + return false; + } + } + } + + if (required === true) { + return true; + } + + if (Array.isArray(required)) { + for (const pattern of required) { + if (matchesPattern(toolName, pattern)) { + return true; + } + } + } + + return false; +} + +export function matchesPattern(toolName: string, pattern: string): boolean { + if (pattern === toolName) { + return true; + } + + if (pattern === 'all') { + return true; + } + + if (pattern === 'mcp:*' || pattern === 'mcp_*') { + return toolName.includes(':::mcp:::') || /_mcp_/.test(toolName); + } + + if (pattern.endsWith('*')) { + const prefix = pattern.slice(0, -1); + return toolName.startsWith(prefix); + } + + return false; +} + +export function getToolServerName(toolName: string): string { + if (toolName.includes(':::mcp:::')) { + const parts = toolName.split(':::mcp:::'); + return parts[1] || 'mcp'; + } + const mcpMatch = toolName.match(/_mcp_([^_]+)$/); + if (mcpMatch) { + return mcpMatch[1]; + } + return 'builtin'; +} + +export function getBaseToolName(toolName: string): string { + if (toolName.includes(':::mcp:::')) { + const parts = toolName.split(':::mcp:::'); + return parts[0] || toolName; + } + const mcpMatch = toolName.match(/^(.+)_mcp_[^_]+$/); + if (mcpMatch) { + return mcpMatch[1]; + } + return toolName; +} diff --git a/packages/api/src/tools/index.ts b/packages/api/src/tools/index.ts index 8695d06707..5f37a13f2b 100644 --- a/packages/api/src/tools/index.ts +++ b/packages/api/src/tools/index.ts @@ -1,3 +1,4 @@ +export * from './approval'; export * from './format'; export * from './registry'; export * from './toolkits'; diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index ca40ec2c8c..72dc90857a 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -318,6 +318,15 @@ export const defaultAgentCapabilities = [ AgentCapabilities.ocr, ]; +export const toolApprovalSchema = z + .object({ + required: z.union([z.boolean(), z.array(z.string())]).optional(), + excluded: z.array(z.string()).optional(), + }) + .optional(); + +export type TToolApproval = z.infer; + export const agentsEndpointSchema = baseEndpointSchema .omit({ baseURL: true }) .merge( @@ -334,6 +343,7 @@ export const agentsEndpointSchema = baseEndpointSchema .array(z.nativeEnum(AgentCapabilities)) .optional() .default(defaultAgentCapabilities), + toolApproval: toolApprovalSchema, }), ) .default({ diff --git a/packages/data-provider/src/types/agents.ts b/packages/data-provider/src/types/agents.ts index db70de8c9d..84909d03af 100644 --- a/packages/data-provider/src/types/agents.ts +++ b/packages/data-provider/src/types/agents.ts @@ -81,6 +81,8 @@ export namespace Agents { output?: string; /** Auth URL */ auth?: string; + /** Validation ID for tool call approval flow */ + validation?: string; /** Expiration time */ expires_at?: number; }; @@ -247,6 +249,7 @@ export namespace Agents { type: StepTypes.TOOL_CALLS | string; tool_calls?: ToolCallChunk[]; auth?: string; + validation?: string; expires_at?: number; }; export type AgentToolCall = FunctionToolCall | ToolCall; diff --git a/packages/data-provider/src/types/assistants.ts b/packages/data-provider/src/types/assistants.ts index 2ee30490c3..c9200d549a 100644 --- a/packages/data-provider/src/types/assistants.ts +++ b/packages/data-provider/src/types/assistants.ts @@ -496,6 +496,7 @@ export type PartMetadata = { status?: string; action?: boolean; auth?: string; + validation?: string; expires_at?: number; /** Index indicating parallel sibling content (same stepIndex in multi-agent runs) */ siblingIndex?: number;