From 301ba801f431997314bbd978b4ccc65e455b8825 Mon Sep 17 00:00:00 2001 From: Aron Gates Date: Mon, 9 Mar 2026 11:26:01 +0000 Subject: [PATCH] feat: implement tool approval checks for agent tool calls Ports the tool approval feature from aron/tool-approval branch onto the latest codebase. Adds manual user approval flow for tool calls before execution, configurable via librechat.yaml toolApproval config. Key changes: - Add TToolApproval schema to data-provider config (required/excluded patterns) - Add approval.ts utilities (requiresApproval, matchesPattern, getToolServerName) - Add MCPToolCallValidationHandler for flow-based approval via FlowStateManager - Wrap non-MCP tools with approval in ToolService.loadAgentTools - Add MCP tool validation in MCP.js createToolInstance - Handle native Anthropic web search approval in callbacks.js - Disable native web_search when approval required (OpenAI initialize) - Add validation SSE delta handling in useStepHandler - Add approve/reject UI in ToolCall.tsx with confirm/reject API calls - Add validation routes: POST /api/mcp/validation/confirm|reject/:id - Add i18n keys for approval UI - Add toolApproval example config in librechat.example.yaml Co-Authored-By: Claude Opus 4.6 (1M context) --- api/server/controllers/agents/callbacks.js | 98 +++++++++++++- api/server/routes/mcp.js | 85 +++++++++++++ .../services/Endpoints/agents/initialize.js | 2 + api/server/services/MCP.js | 68 ++++++++++ api/server/services/ToolService.js | 102 ++++++++++++++- .../components/Chat/Messages/Content/Part.tsx | 2 + .../Chat/Messages/Content/ToolCall.tsx | 120 +++++++++++++++++- client/src/hooks/SSE/useStepHandler.ts | 6 + client/src/locales/en/translation.json | 7 + librechat.example.yaml | 6 + .../api/src/endpoints/openai/initialize.ts | 26 ++++ packages/api/src/index.ts | 2 + packages/api/src/mcp/validation/handler.ts | 102 +++++++++++++++ packages/api/src/mcp/validation/index.ts | 1 + packages/api/src/tools/approval.ts | 83 ++++++++++++ packages/api/src/tools/index.ts | 1 + packages/data-provider/src/config.ts | 10 ++ packages/data-provider/src/types/agents.ts | 3 + .../data-provider/src/types/assistants.ts | 1 + 19 files changed, 720 insertions(+), 5 deletions(-) create mode 100644 packages/api/src/mcp/validation/handler.ts create mode 100644 packages/api/src/mcp/validation/index.ts create mode 100644 packages/api/src/tools/approval.ts diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 40fdf74212..def76ad3e4 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,6 +1,6 @@ const { nanoid } = require('nanoid'); const { logger } = require('@librechat/data-schemas'); -const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider'); +const { Tools, StepTypes, FileContext, ErrorTypes, EModelEndpoint } = require('librechat-data-provider'); const { EnvVar, Constants, @@ -10,14 +10,18 @@ const { } = require('@librechat/agents'); const { sendEvent, + requiresApproval, GenerationJobManager, writeAttachmentEvent, createToolExecuteHandler, + MCPToolCallValidationHandler, } = require('@librechat/api'); const { processFileCitations } = require('~/server/services/Files/Citations'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { saveBase64Image } = require('~/server/services/Files/process'); +const { getFlowStateManager } = require('~/config'); +const { getLogStores } = require('~/cache'); class ModelEndHandler { /** @@ -116,6 +120,78 @@ async function emitEvent(res, streamId, eventData) { } } +/** + * Checks if a tool call is a native Anthropic web search (server_tool_use). + * @param {Object} toolCall - The tool call object + * @returns {boolean} + */ +function isNativeWebSearch(toolCall) { + return toolCall?.name === Tools.web_search && toolCall?.id?.startsWith('srvtoolu_'); +} + +/** + * Handles approval flow for native web search tool calls. + * @param {Object} params + * @param {Object} params.toolCall - The tool call requiring approval + * @param {Object} params.toolApprovalConfig - The tool approval configuration + * @param {ServerResponse} params.res - The response object for SSE + * @param {string | null} params.streamId - The stream ID for resumable mode + * @param {string} params.stepId - The step ID + * @param {string} params.userId - The user ID + * @param {AbortSignal} [params.signal] - Optional abort signal + * @returns {Promise} + */ +async function handleNativeWebSearchApproval({ + toolCall, + toolApprovalConfig, + res, + streamId, + stepId, + userId, + signal, +}) { + const needsApproval = requiresApproval(Tools.web_search, toolApprovalConfig); + if (!needsApproval) { + return; + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const derivedSignal = signal ? AbortSignal.any([signal]) : undefined; + + const toolArgs = toolCall.args || {}; + const { validationId, flowMetadata } = + await MCPToolCallValidationHandler.initiateValidationFlow( + userId, + 'anthropic', + Tools.web_search, + typeof toolArgs === 'string' ? { input: toolArgs } : toolArgs, + ); + + const validationData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ id: toolCall.id, name: toolCall.name, args: '' }], + validation: validationId, + expires_at: Date.now() + Time.TEN_MINUTES, + }, + }; + await emitEvent(res, streamId, { event: GraphEvents.ON_RUN_STEP_DELTA, data: validationData }); + + const validationFlowType = MCPToolCallValidationHandler.getFlowType(); + await flowManager.createFlow(validationId, validationFlowType, flowMetadata, derivedSignal); + + const successData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ id: toolCall.id, name: toolCall.name }], + }, + }; + await emitEvent(res, streamId, { event: GraphEvents.ON_RUN_STEP_DELTA, data: successData }); +} + /** * @typedef {Object} ToolExecuteOptions * @property {(toolNames: string[]) => Promise<{loadedTools: StructuredTool[]}>} loadTools - Function to load tools by name @@ -142,12 +218,14 @@ function getDefaultHandlers({ streamId = null, toolExecuteOptions = null, summarizationOptions = null, + toolApprovalConfig, }) { if (!res || !aggregateContent) { throw new Error( `[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`, ); } + const pendingNativeWebSearchApprovals = new Set(); const handlers = { [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(collectedUsage), [GraphEvents.TOOL_END]: new ToolEndHandler(toolEndCallback, logger), @@ -159,6 +237,24 @@ function getDefaultHandlers({ * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ handle: async (event, data, metadata) => { + if (data?.stepDetails?.type === StepTypes.TOOL_CALLS && toolApprovalConfig) { + const toolCalls = data.stepDetails.tool_calls || []; + for (const toolCall of toolCalls) { + if (isNativeWebSearch(toolCall) && !pendingNativeWebSearchApprovals.has(toolCall.id)) { + pendingNativeWebSearchApprovals.add(toolCall.id); + await handleNativeWebSearchApproval({ + toolCall, + toolApprovalConfig, + res, + streamId, + stepId: data.id, + userId: metadata?.user_id || metadata?.user?.id, + signal: metadata?.signal, + }); + } + } + } + aggregateContent({ event, data }); if (data?.stepDetails.type === StepTypes.TOOL_CALLS) { await emitEvent(res, streamId, { event, data }); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index c6496ad4b4..08585404bb 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -21,6 +21,7 @@ const { generateCheckAccess, validateOAuthSession, OAUTH_SESSION_COOKIE, + MCPToolCallValidationHandler, } = require('@librechat/api'); const { createMCPServerController, @@ -732,6 +733,90 @@ async function getOAuthHeaders(serverName, userId, configServers) { return serverConfig?.oauth_headers ?? {}; } +/** + * Tool Call Validation Routes + */ + +router.post('/validation/confirm/:validationId', requireJwtAuth, async (req, res) => { + try { + const { validationId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!validationId.startsWith(`${user.id}:`)) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + await MCPToolCallValidationHandler.completeValidationFlow(validationId, flowManager); + + res.json({ success: true, message: 'Tool call validation confirmed' }); + } catch (error) { + logger.error('[MCP Validation] Failed to confirm validation', error); + res.status(500).json({ error: error.message || 'Failed to confirm validation' }); + } +}); + +router.post('/validation/reject/:validationId', requireJwtAuth, async (req, res) => { + try { + const { validationId } = req.params; + const { reason } = req.body; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!validationId.startsWith(`${user.id}:`)) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + await MCPToolCallValidationHandler.rejectValidationFlow(validationId, flowManager, reason); + + res.json({ success: true, message: 'Tool call validation rejected' }); + } catch (error) { + logger.error('[MCP Validation] Failed to reject validation', error); + res.status(500).json({ error: error.message || 'Failed to reject validation' }); + } +}); + +router.get('/validation/status/:validationId', requireJwtAuth, async (req, res) => { + try { + const { validationId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!validationId.startsWith(`${user.id}:`)) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowState = await MCPToolCallValidationHandler.getFlowState(validationId, flowManager); + + if (!flowState) { + return res.status(404).json({ error: 'Validation flow not found' }); + } + + res.json({ success: true, validationId, metadata: flowState }); + } catch (error) { + logger.error('[MCP Validation] Failed to get validation status', error); + res.status(500).json({ error: 'Failed to get validation status' }); + } +}); + /** MCP Server CRUD Routes (User-Managed MCP Servers) */ diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 69767e191c..82260585db 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -147,6 +147,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const summarizationOptions = appConfig?.summarization?.enabled === false ? { enabled: false } : { enabled: true }; + const toolApprovalConfig = appConfig?.endpoints?.[EModelEndpoint.agents]?.toolApproval; const eventHandlers = getDefaultHandlers({ res, toolExecuteOptions, @@ -155,6 +156,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolEndCallback, collectedUsage, streamId, + toolApprovalConfig, }); if (!endpointOption.agent) { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index dbb44740a9..484cfc3a36 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -15,6 +15,7 @@ const { GenerationJobManager, resolveJsonSchemaRefs, buildOAuthToolCallName, + MCPToolCallValidationHandler, } = require('@librechat/api'); const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider'); const { @@ -624,6 +625,61 @@ function createToolInstance({ derivedSignal.addEventListener('abort', abortHandler, { once: true }); } + // Tool call validation flow - requires user approval before executing + const validationFlowType = MCPToolCallValidationHandler.getFlowType(); + const { validationId, flowMetadata } = + await MCPToolCallValidationHandler.initiateValidationFlow( + userId, + serverName, + toolName, + typeof toolArguments === 'string' ? { input: toolArguments } : toolArguments, + ); + + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const validationData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + validation: validationId, + expires_at: Date.now() + Time.TEN_MINUTES, + }, + }; + + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: validationData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: validationData }); + } + + try { + await flowManager.createFlow(validationId, validationFlowType, flowMetadata, derivedSignal); + + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const successData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall }], + }, + }; + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: successData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: successData }); + } + } catch (validationError) { + throw new Error( + `Tool call validation required for ${serverName}/${toolName}. User rejected or validation timed out.`, + ); + } + const customUserVars = config?.configurable?.userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; @@ -660,6 +716,18 @@ function createToolInstance({ error, ); + /** Validation error - user rejected or timeout */ + const isValidationError = + error.message?.includes('validation required') || + error.message?.includes('User rejected') || + error.message?.includes('mcp_tool_validation'); + + if (isValidationError) { + throw new Error( + `Tool call for ${serverName}/${toolName} was not approved by the user.`, + ); + } + /** OAuth error, provide a helpful message */ const isOAuthError = error.message?.includes('401') || diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index b4d948eda4..4e31ad6733 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -12,6 +12,8 @@ const { const { sendEvent, getToolkitKey, + requiresApproval, + getToolServerName, getUserMCPAuthMap, loadToolDefinitions, GenerationJobManager, @@ -20,6 +22,7 @@ const { buildImageToolContext, buildToolClassification, buildOAuthToolCallName, + MCPToolCallValidationHandler, } = require('@librechat/api'); const { Time, @@ -810,6 +813,94 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to * @param {ServerRequest} params.req - The request object * @param {ServerResponse} params.res - The response object * @param {Object} params.agent - The agent configuration +/** + * Wraps a tool with approval validation flow. + * The wrapped tool sends an SSE event and waits for user approval before executing. + * @param {Object} params + * @param {Object} params.tool - The tool to wrap + * @param {ServerResponse} params.res - The response object for SSE + * @param {string|null} params.streamId - Stream ID for resumable mode + * @returns {Object} The wrapped tool + */ +function wrapToolWithApproval({ tool, res, streamId }) { + const originalCall = tool._call.bind(tool); + const toolName = tool.name; + const serverName = getToolServerName(toolName); + + tool._call = async (toolArguments, runManager, parentConfig) => { + const config = parentConfig; + const userId = config?.configurable?.user?.id || config?.configurable?.user_id; + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; + + const { args: _args, stepId, ...toolCall } = config?.toolCall ?? {}; + + const { validationId, flowMetadata } = + await MCPToolCallValidationHandler.initiateValidationFlow( + userId, + serverName, + toolName, + typeof toolArguments === 'string' ? { input: toolArguments } : toolArguments, + ); + + const validationData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + validation: validationId, + expires_at: Date.now() + Time.TEN_MINUTES, + }, + }; + + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: validationData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: validationData }); + } + + const validationFlowType = MCPToolCallValidationHandler.getFlowType(); + try { + await flowManager.createFlow(validationId, validationFlowType, flowMetadata, derivedSignal); + + const successData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall }], + }, + }; + if (streamId) { + await GenerationJobManager.emitChunk(streamId, { + event: GraphEvents.ON_RUN_STEP_DELTA, + data: successData, + }); + } else { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data: successData }); + } + } catch (validationError) { + throw new Error( + `Tool call validation required for ${toolName}. User rejected or validation timed out.`, + ); + } + + return await originalCall(toolArguments, runManager, parentConfig); + }; + + tool.requiresApproval = true; + return tool; +} + +/** + * @param {Object} params - Run params containing user and request information. + * @param {ServerRequest} params.req - The server request + * @param {ServerResponse} params.res - The server response + * @param {Agent} params.agent - The Agent * @param {AbortSignal} [params.signal] - Abort signal * @param {Object} [params.tool_resources] - Tool resources * @param {string} [params.openAIApiKey] - OpenAI API key @@ -933,9 +1024,16 @@ async function loadAgentTools({ loadAuthValues, }); + const toolApprovalConfig = appConfig.endpoints?.[EModelEndpoint.agents]?.toolApproval; const agentTools = []; for (let i = 0; i < loadedTools.length; i++) { - const tool = loadedTools[i]; + let tool = loadedTools[i]; + + const needsApproval = requiresApproval(tool.name, toolApprovalConfig); + if (res && needsApproval && tool.mcp !== true) { + tool = wrapToolWithApproval({ tool, res, streamId }); + } + if (tool.name && (tool.name === Tools.execute_code || tool.name === Tools.file_search)) { agentTools.push(tool); continue; @@ -945,7 +1043,7 @@ async function loadAgentTools({ continue; } - if (tool.mcp === true) { + if (tool.mcp === true || tool.requiresApproval === true) { agentTools.push(tool); continue; } diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index 1b4b9057f6..715e5c6f9f 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -179,6 +179,8 @@ const Part = memo(function Part({ isSubmitting={isSubmitting} attachments={attachments} auth={toolCall.auth} + validation={toolCall.validation} + expires_at={toolCall.expires_at} isLast={isLast} /> ); diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index c7dd974577..de9de6ddcd 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -1,7 +1,7 @@ import { useMemo, useState, useEffect, useCallback } from 'react'; import { useRecoilValue } from 'recoil'; import { Button } from '@librechat/client'; -import { TriangleAlert } from 'lucide-react'; +import { TriangleAlert, CheckCircle, XCircle } from 'lucide-react'; import { Constants, dataService, @@ -9,7 +9,7 @@ import { actionDomainSeparator, } from 'librechat-data-provider'; import type { TAttachment } from 'librechat-data-provider'; -import { useLocalize, useProgress, useExpandCollapse } from '~/hooks'; +import { useLocalize, useProgress, useExpandCollapse, useAuthContext } from '~/hooks'; import { ToolIcon, getToolIconType, isError } from './ToolOutput'; import { useMCPIconMap } from '~/hooks/MCP'; import { AttachmentGroup } from './Parts'; @@ -27,6 +27,7 @@ export default function ToolCall({ output, attachments, auth, + validation, }: { initialProgress: number; isLast?: boolean; @@ -36,6 +37,8 @@ export default function ToolCall({ output?: string | null; attachments?: TAttachment[]; auth?: string; + validation?: string; + expires_at?: number; }) { const localize = useLocalize(); const autoExpand = useRecoilValue(store.autoExpandTools); @@ -130,6 +133,66 @@ export default function ToolCall({ window.open(auth, '_blank', 'noopener,noreferrer'); }, [auth, isMCPToolCall, mcpServerName, actionId]); + const [validationConfirmed, setValidationConfirmed] = useState(false); + const [validationRejected, setValidationRejected] = useState(false); + const [validationError, setValidationError] = useState(null); + const [isConfirming, setIsConfirming] = useState(false); + const [isRejecting, setIsRejecting] = useState(false); + const { token } = useAuthContext(); + + const handleValidationConfirm = useCallback(async () => { + if (!validation || validationConfirmed || validationRejected) { + return; + } + setIsConfirming(true); + setValidationError(null); + try { + const response = await fetch(`/api/mcp/validation/confirm/${validation}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + }); + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to confirm validation'); + } + setValidationConfirmed(true); + } catch (err) { + setValidationError(err instanceof Error ? err.message : 'Unknown error'); + } finally { + setIsConfirming(false); + } + }, [validation, validationConfirmed, validationRejected, token]); + + const handleValidationReject = useCallback(async () => { + if (!validation || validationConfirmed || validationRejected) { + return; + } + setIsRejecting(true); + setValidationError(null); + try { + const response = await fetch(`/api/mcp/validation/reject/${validation}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ reason: 'User rejected tool call' }), + }); + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to reject validation'); + } + setValidationRejected(true); + } catch (err) { + setValidationError(err instanceof Error ? err.message : 'Unknown error'); + } finally { + setIsRejecting(false); + } + }, [validation, validationConfirmed, validationRejected, token]); + const hasError = typeof output === 'string' && isError(output); const cancelled = !isSubmitting && initialProgress < 1 && !hasError; const errorState = hasError; @@ -254,6 +317,59 @@ export default function ToolCall({

)} + {validation != null && + validation && + progress < 1 && + !cancelled && + !validationConfirmed && + !validationRejected && ( +
+
+ + +
+ {validationError && ( +

+

+ )} +

+

+
+ )} + {validation != null && validationConfirmed && ( +

+

+ )} + {validation != null && validationRejected && ( +

+

+ )} {attachments && attachments.length > 0 && } ); diff --git a/client/src/hooks/SSE/useStepHandler.ts b/client/src/hooks/SSE/useStepHandler.ts index 1f28d97433..b7f28b68d0 100644 --- a/client/src/hooks/SSE/useStepHandler.ts +++ b/client/src/hooks/SSE/useStepHandler.ts @@ -206,6 +206,7 @@ export default function useStepHandler({ args, type: ToolCallTypes.TOOL_CALL, auth: contentPart.tool_call.auth, + validation: contentPart.tool_call.validation, expires_at: contentPart.tool_call.expires_at, }; @@ -551,6 +552,11 @@ export default function useStepHandler({ contentPart.tool_call.expires_at = runStepDelta.delta.expires_at; } + if (runStepDelta.delta.validation != null) { + contentPart.tool_call.validation = runStepDelta.delta.validation; + contentPart.tool_call.expires_at = runStepDelta.delta.expires_at; + } + // Use server's index, offset by initialContent for edit scenarios const currentIndex = runStep.index + initialContent.length; updatedResponse = updateContent( diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 3d19f65ad6..12f02bfa88 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -845,6 +845,8 @@ "com_ui_confirm_action": "Confirm Action", "com_ui_confirm_admin_use_change": "Changing this setting will block access for admins, including yourself. Are you sure you want to proceed?", "com_ui_confirm_change": "Confirm Change", + "com_ui_confirm_tool_call": "Approve", + "com_ui_confirming": "Approving...", "com_ui_connecting": "Connecting", "com_ui_contact_admin_if_issue_persists": "Contact the Admin if the issue persists", "com_ui_context": "Context", @@ -1314,6 +1316,8 @@ "com_ui_regenerating": "Regenerating...", "com_ui_region": "Region", "com_ui_reinitialize": "Reinitialize", + "com_ui_reject_tool_call": "Reject", + "com_ui_rejecting": "Rejecting...", "com_ui_relevance": "Relevance", "com_ui_remote_access": "Remote Access", "com_ui_remote_agent_role_editor": "Editor", @@ -1477,6 +1481,9 @@ "com_ui_tool_collection_prefix": "A collection of tools from", "com_ui_tool_failed": "failed", "com_ui_tool_list_collapse": "Collapse {{serverName}} tool list", + "com_ui_tool_call_approved": "Tool call approved", + "com_ui_tool_call_rejected": "Tool call rejected", + "com_ui_tool_call_requires_approval": "This tool call requires your approval before it can be executed", "com_ui_tool_list_expand": "Expand {{serverName}} tool list", "com_ui_tool_name_code": "Code", "com_ui_tool_name_code_analysis": "Code Analysis", diff --git a/librechat.example.yaml b/librechat.example.yaml index 03bb5f5bc2..904078f754 100644 --- a/librechat.example.yaml +++ b/librechat.example.yaml @@ -274,6 +274,12 @@ endpoints: # minRelevanceScore: 0.45 # # (optional) Agent Capabilities available to all users. Omit the ones you wish to exclude. Defaults to list below. # capabilities: ["deferred_tools", "execute_code", "file_search", "actions", "tools"] + # # (optional) Tool Approval - require user approval before tool calls execute + # # toolApproval: + # # # Set to true to require approval for all tools, or provide an array of tool patterns + # # required: true # or: ["web_search", "mcp:*", "image_*"] + # # # (optional) Exclude specific tools from approval requirement + # # excluded: ["calculator", "google"] # Anthropic endpoint configuration with Vertex AI support # Use this to run Anthropic Claude models through Google Cloud Vertex AI diff --git a/packages/api/src/endpoints/openai/initialize.ts b/packages/api/src/endpoints/openai/initialize.ts index a6ad6df895..38f91a4198 100644 --- a/packages/api/src/endpoints/openai/initialize.ts +++ b/packages/api/src/endpoints/openai/initialize.ts @@ -1,4 +1,5 @@ import { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } from 'librechat-data-provider'; +import type { TToolApproval } from 'librechat-data-provider'; import type { BaseInitializeParams, InitializeResultBase, @@ -9,6 +10,23 @@ import { getAzureCredentials, resolveHeaders, isUserProvided, checkUserKeyExpiry import { validateEndpointURL } from '~/auth'; import { getOpenAIConfig } from './config'; +function shouldDisableNativeWebSearch(toolApproval: TToolApproval | undefined): boolean { + if (!toolApproval) { + return false; + } + const { required, excluded } = toolApproval; + if (excluded?.includes('web_search')) { + return false; + } + if (required === true) { + return true; + } + if (Array.isArray(required) && required.includes('web_search')) { + return true; + } + return false; +} + /** * Initializes OpenAI options for agent usage. This function always returns configuration * options and never creates a client instance (equivalent to optionsOnly=true behavior). @@ -133,6 +151,14 @@ export async function initializeOpenAI({ user: req.user?.id, }; + const toolApproval = appConfig?.endpoints?.[EModelEndpoint.agents]?.toolApproval; + if (shouldDisableNativeWebSearch(toolApproval)) { + clientOptions.dropParams = clientOptions.dropParams ?? []; + if (!clientOptions.dropParams.includes('web_search')) { + clientOptions.dropParams.push('web_search'); + } + } + const finalClientOptions: OpenAIConfigOptions = { ...clientOptions, modelOptions, diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index d4b6ac9542..4e7511c0b1 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -16,6 +16,8 @@ export * from './mcp/zod'; export * from './mcp/errors'; export * from './mcp/cache'; export * from './mcp/tools'; +/* MCP Validation */ +export * from './mcp/validation'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; diff --git a/packages/api/src/mcp/validation/handler.ts b/packages/api/src/mcp/validation/handler.ts new file mode 100644 index 0000000000..8b4ca741c9 --- /dev/null +++ b/packages/api/src/mcp/validation/handler.ts @@ -0,0 +1,102 @@ +import { randomBytes } from 'crypto'; +import { logger } from '@librechat/data-schemas'; +import type { FlowStateManager } from '~/flow/manager'; +import type { FlowMetadata } from '~/flow/types'; + +export class MCPToolCallValidationHandler { + private static readonly FLOW_TYPE = 'mcp_tool_validation'; + private static readonly FLOW_TTL = 10 * 60 * 1000; + + static async initiateValidationFlow( + userId: string, + serverName: string, + toolName: string, + toolArguments: Record, + ): Promise<{ validationId: string; flowMetadata: FlowMetadata }> { + const validationId = this.generateValidationId(userId, serverName, toolName); + const state = this.generateState(); + + const flowMetadata: FlowMetadata = { + userId, + serverName, + toolName, + toolArguments, + state, + timestamp: Date.now(), + }; + + return { validationId, flowMetadata }; + } + + static async completeValidationFlow( + validationId: string, + flowManager: FlowStateManager, + ): Promise { + try { + const flowState = await flowManager.getFlowState(validationId, this.FLOW_TYPE); + if (!flowState) { + throw new Error('Validation flow not found'); + } + + await flowManager.completeFlow(validationId, this.FLOW_TYPE, true); + logger.info(`[MCPValidation] Validation flow completed successfully: ${validationId}`); + return true; + } catch (error) { + logger.error('[MCPValidation] Failed to complete validation flow', { error, validationId }); + await flowManager.failFlow(validationId, this.FLOW_TYPE, error as Error); + throw error; + } + } + + static async rejectValidationFlow( + validationId: string, + flowManager: FlowStateManager, + reason?: string, + ): Promise { + try { + const flowState = await flowManager.getFlowState(validationId, this.FLOW_TYPE); + if (!flowState) { + throw new Error('Validation flow not found'); + } + + const errorMessage = reason || 'User rejected tool call'; + await flowManager.failFlow(validationId, this.FLOW_TYPE, new Error(errorMessage)); + logger.info(`[MCPValidation] Validation flow rejected: ${validationId}`); + return true; + } catch (error) { + logger.error('[MCPValidation] Failed to reject validation flow', { error, validationId }); + throw error; + } + } + + static async getFlowState( + validationId: string, + flowManager: FlowStateManager, + ): Promise { + const flowState = await flowManager.getFlowState(validationId, this.FLOW_TYPE); + if (!flowState) { + return null; + } + return flowState.metadata as FlowMetadata; + } + + public static generateValidationId( + userId: string, + serverName: string, + toolName: string, + ): string { + return `${userId}:${serverName}:${toolName}:${Date.now()}`; + } + + public static getFlowType(): string { + return this.FLOW_TYPE; + } + + public static getFlowTTL(): number { + return this.FLOW_TTL; + } + + private static generateState(): string { + return randomBytes(32).toString('base64url'); + } +} diff --git a/packages/api/src/mcp/validation/index.ts b/packages/api/src/mcp/validation/index.ts new file mode 100644 index 0000000000..146faa1ca8 --- /dev/null +++ b/packages/api/src/mcp/validation/index.ts @@ -0,0 +1 @@ +export { MCPToolCallValidationHandler } from './handler'; diff --git a/packages/api/src/tools/approval.ts b/packages/api/src/tools/approval.ts new file mode 100644 index 0000000000..033539f2ba --- /dev/null +++ b/packages/api/src/tools/approval.ts @@ -0,0 +1,83 @@ +import type { TToolApproval } from 'librechat-data-provider'; + +export function requiresApproval( + toolName: string, + toolApproval: TToolApproval | undefined, +): boolean { + if (!toolApproval) { + return false; + } + + const { required, excluded } = toolApproval; + + if (required === undefined || required === false) { + return false; + } + + if (excluded && excluded.length > 0) { + for (const pattern of excluded) { + if (matchesPattern(toolName, pattern)) { + return false; + } + } + } + + if (required === true) { + return true; + } + + if (Array.isArray(required)) { + for (const pattern of required) { + if (matchesPattern(toolName, pattern)) { + return true; + } + } + } + + return false; +} + +export function matchesPattern(toolName: string, pattern: string): boolean { + if (pattern === toolName) { + return true; + } + + if (pattern === 'all') { + return true; + } + + if (pattern === 'mcp:*' || pattern === 'mcp_*') { + return toolName.includes(':::mcp:::') || /_mcp_/.test(toolName); + } + + if (pattern.endsWith('*')) { + const prefix = pattern.slice(0, -1); + return toolName.startsWith(prefix); + } + + return false; +} + +export function getToolServerName(toolName: string): string { + if (toolName.includes(':::mcp:::')) { + const parts = toolName.split(':::mcp:::'); + return parts[1] || 'mcp'; + } + const mcpMatch = toolName.match(/_mcp_([^_]+)$/); + if (mcpMatch) { + return mcpMatch[1]; + } + return 'builtin'; +} + +export function getBaseToolName(toolName: string): string { + if (toolName.includes(':::mcp:::')) { + const parts = toolName.split(':::mcp:::'); + return parts[0] || toolName; + } + const mcpMatch = toolName.match(/^(.+)_mcp_[^_]+$/); + if (mcpMatch) { + return mcpMatch[1]; + } + return toolName; +} diff --git a/packages/api/src/tools/index.ts b/packages/api/src/tools/index.ts index 8695d06707..5f37a13f2b 100644 --- a/packages/api/src/tools/index.ts +++ b/packages/api/src/tools/index.ts @@ -1,3 +1,4 @@ +export * from './approval'; export * from './format'; export * from './registry'; export * from './toolkits'; diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index ae3f5b9560..a61888ec4a 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -279,6 +279,15 @@ export const defaultAgentCapabilities = [ AgentCapabilities.ocr, ]; +export const toolApprovalSchema = z + .object({ + required: z.union([z.boolean(), z.array(z.string())]).optional(), + excluded: z.array(z.string()).optional(), + }) + .optional(); + +export type TToolApproval = z.infer; + export const agentsEndpointSchema = baseEndpointSchema .omit({ baseURL: true }) .merge( @@ -295,6 +304,7 @@ export const agentsEndpointSchema = baseEndpointSchema .array(z.nativeEnum(AgentCapabilities)) .optional() .default(defaultAgentCapabilities), + toolApproval: toolApprovalSchema, }), ) .default({ diff --git a/packages/data-provider/src/types/agents.ts b/packages/data-provider/src/types/agents.ts index db70de8c9d..84909d03af 100644 --- a/packages/data-provider/src/types/agents.ts +++ b/packages/data-provider/src/types/agents.ts @@ -81,6 +81,8 @@ export namespace Agents { output?: string; /** Auth URL */ auth?: string; + /** Validation ID for tool call approval flow */ + validation?: string; /** Expiration time */ expires_at?: number; }; @@ -247,6 +249,7 @@ export namespace Agents { type: StepTypes.TOOL_CALLS | string; tool_calls?: ToolCallChunk[]; auth?: string; + validation?: string; expires_at?: number; }; export type AgentToolCall = FunctionToolCall | ToolCall; diff --git a/packages/data-provider/src/types/assistants.ts b/packages/data-provider/src/types/assistants.ts index 2ee30490c3..c9200d549a 100644 --- a/packages/data-provider/src/types/assistants.ts +++ b/packages/data-provider/src/types/assistants.ts @@ -496,6 +496,7 @@ export type PartMetadata = { status?: string; action?: boolean; auth?: string; + validation?: string; expires_at?: number; /** Index indicating parallel sibling content (same stepIndex in multi-agent runs) */ siblingIndex?: number;