diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index eb42046bed..29991fb0bf 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -109,11 +109,24 @@ router.get('/chat/stream/:streamId', async (req, res) => { } }; + const onChunk = (event) => { + if (!res.writableEnded) { + if (event.event === 'progress') { + res.write(`event: progress\ndata: ${JSON.stringify(event.data)}\n\n`); + } else { + res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); + } + if (typeof res.flush === 'function') { + res.flush(); + } + } + }; + let result; if (isResume) { const { subscription, resumeState, pendingEvents } = - await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError); + await GenerationJobManager.subscribeWithResume(streamId, onChunk, onDone, onError); if (!res.writableEnded) { if (resumeState) { @@ -139,7 +152,7 @@ router.get('/chat/stream/:streamId', async (req, res) => { result = subscription; } else { - result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError); + result = await GenerationJobManager.subscribe(streamId, onChunk, onDone, onError); } if (!result) { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index ccff184d4d..d1da29d86e 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -8,6 +8,7 @@ const { } = require('@librechat/agents'); const { sendEvent, + sendProgress, MCPOAuthHandler, isMCPDomainAllowed, normalizeServerName, @@ -649,6 +650,30 @@ function createToolInstance({ oauthStart, oauthEnd, graphTokenResolver: getGraphApiToken, + onProgress: async (progressData) => { + logger.debug( + `[MCP][${serverName}][${toolName}] Sending progress to client:`, + progressData, + ); + const eventData = { + progress: progressData.progress, + total: progressData.total, + message: progressData.message, + toolCallId: toolCall.id, + }; + try { + if (streamId) { + await GenerationJobManager.emitTransientEvent(streamId, { + event: 'progress', + data: eventData, + }); + } else { + sendProgress(res, eventData); + } + } catch (err) { + logger.error(`[MCP][${serverName}][${toolName}] Failed to emit progress:`, err); + } + }, }); if (isAssistantsEndpoint(provider) && Array.isArray(result)) { diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index 1b4b9057f6..73b6d2797a 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -180,6 +180,7 @@ const Part = memo(function Part({ attachments={attachments} auth={toolCall.auth} isLast={isLast} + toolCallId={toolCall.id} /> ); } else if (toolCall.type === ToolCallTypes.CODE_INTERPRETER) { diff --git a/client/src/components/Chat/Messages/Content/ToolCall.tsx b/client/src/components/Chat/Messages/Content/ToolCall.tsx index c7dd974577..6cc1e97c08 100644 --- a/client/src/components/Chat/Messages/Content/ToolCall.tsx +++ b/client/src/components/Chat/Messages/Content/ToolCall.tsx @@ -2,6 +2,7 @@ import { useMemo, useState, useEffect, useCallback } from 'react'; import { useRecoilValue } from 'recoil'; import { Button } from '@librechat/client'; import { TriangleAlert } from 'lucide-react'; +import { useAtomValue, useSetAtom } from 'jotai'; import { Constants, dataService, @@ -17,6 +18,32 @@ import ToolCallInfo from './ToolCallInfo'; import ProgressText from './ProgressText'; import { logger } from '~/utils'; import store from '~/store'; +import { + toolCallProgressFamily, + clearToolCallProgressAtom, + type ProgressState, +} from '~/store/progress'; + +/** + * Gets the in-progress text to display for a tool call. + * Prioritizes MCP progress message, then progress/total, then default localized text. + */ +function getInProgressText( + mcpProgress: ProgressState | undefined, + functionName: string, + localize: ReturnType, +): string { + if (mcpProgress?.message) { + return mcpProgress.message; + } + if (mcpProgress?.total) { + return `${functionName}: ${mcpProgress.progress}/${mcpProgress.total}`; + } + if (functionName) { + return localize('com_assistants_running_var', { 0: functionName }); + } + return localize('com_assistants_running_action'); +} export default function ToolCall({ initialProgress = 0.1, @@ -27,6 +54,7 @@ export default function ToolCall({ output, attachments, auth, + toolCallId, }: { initialProgress: number; isLast?: boolean; @@ -36,6 +64,7 @@ export default function ToolCall({ output?: string | null; attachments?: TAttachment[]; auth?: string; + toolCallId?: string; }) { const localize = useLocalize(); const autoExpand = useRecoilValue(store.autoExpandTools); @@ -130,10 +159,6 @@ export default function ToolCall({ window.open(auth, '_blank', 'noopener,noreferrer'); }, [auth, isMCPToolCall, mcpServerName, actionId]); - const hasError = typeof output === 'string' && isError(output); - const cancelled = !isSubmitting && initialProgress < 1 && !hasError; - const errorState = hasError; - const args = useMemo(() => { if (typeof _args === 'string') { return _args; @@ -158,8 +183,32 @@ export default function ToolCall({ return parsedAuthUrl?.hostname ?? ''; }, [parsedAuthUrl]); - const progress = useProgress(initialProgress); - const showCancelled = cancelled || (errorState && !output); + // Get simulated progress + const simulatedProgress = useProgress(initialProgress); + + // Get real-time progress from MCP server by tool call ID + const mcpProgress = useAtomValue(toolCallProgressFamily(toolCallId ?? '')); + const clearProgress = useSetAtom(clearToolCallProgressAtom); + + // Clean up progress data when tool completes + useEffect(() => { + if (hasOutput && toolCallId) { + clearProgress(toolCallId); + } + }, [hasOutput, toolCallId, clearProgress]); + + // If tool has output, it's completed (progress = 1), otherwise use simulated progress + const progress = useMemo(() => { + if (hasOutput) { + return 1; + } + return simulatedProgress; + }, [hasOutput, simulatedProgress]); + + const hasError = typeof output === 'string' && isError(output); + const errorState = hasError; + const cancelled = (!isSubmitting && progress < 1 && !hasOutput) || errorState; + const showCancelled = (cancelled && !hasOutput) || (errorState && !output); const subtitle = useMemo(() => { if (isMCPToolCall && mcpServerName) { @@ -191,24 +240,15 @@ export default function ToolCall({ return ( <> - {(() => { - if (progress < 1 && !showCancelled) { - return function_name - ? localize('com_assistants_running_var', { 0: function_name }) - : localize('com_assistants_running_action'); - } - return getFinishedText(); - })()} + {progress < 1 && !showCancelled + ? getInProgressText(mcpProgress, function_name, localize) + : getFinishedText()}
setShowInfo((prev) => !prev)} - inProgressText={ - function_name - ? localize('com_assistants_running_var', { 0: function_name }) - : localize('com_assistants_running_action') - } + inProgressText={getInProgressText(mcpProgress, function_name, localize)} authText={ !showCancelled && authDomain.length > 0 ? localize('com_ui_requires_auth') : undefined } diff --git a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx index 14b4b7e07a..10b85475a5 100644 --- a/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx +++ b/client/src/components/Chat/Messages/Content/__tests__/ToolCall.test.tsx @@ -1,7 +1,7 @@ import React from 'react'; -import { RecoilRoot } from 'recoil'; -import { Tools, Constants } from 'librechat-data-provider'; import { render, screen, fireEvent } from '@testing-library/react'; +import { Provider } from 'jotai'; +import { Tools, Constants } from 'librechat-data-provider'; import ToolCall from '../ToolCall'; // Mock dependencies @@ -57,15 +57,19 @@ jest.mock('../ProgressText', () => ({ onClick, inProgressText, finishedText, + error, + progress, subtitle, }: { onClick?: () => void; inProgressText?: string; finishedText?: string; + error?: string; + progress: number; subtitle?: string; }) => (
- {finishedText || inProgressText} + {error || progress >= 1 ? finishedText : inProgressText} {subtitle && {subtitle}}
), @@ -98,6 +102,34 @@ jest.mock('~/utils', () => ({ cn: (...classes: any[]) => classes.filter(Boolean).join(' '), })); +const mockUseAtomValue = jest.fn().mockReturnValue(undefined); +const mockClearProgress = jest.fn(); + +jest.mock('jotai', () => ({ + ...jest.requireActual('jotai'), + useAtomValue: (...args: any[]) => mockUseAtomValue(...args), + useSetAtom: () => mockClearProgress, +})); + +const DUMMY_ATOM = { toString: () => 'dummy-atom' }; + +jest.mock('~/store/progress', () => ({ + toolCallProgressFamily: jest.fn().mockReturnValue(DUMMY_ATOM), + clearToolCallProgressAtom: {}, +})); + +jest.mock('recoil', () => ({ + ...jest.requireActual('recoil'), + useRecoilValue: jest.fn().mockReturnValue(false), +})); + +jest.mock('~/store', () => ({ + __esModule: true, + default: { + autoExpandTools: 'autoExpandTools', + }, +})); + describe('ToolCall', () => { const mockProps = { args: '{"test": "input"}', @@ -107,12 +139,14 @@ describe('ToolCall', () => { isSubmitting: false, }; - const renderWithRecoil = (component: React.ReactElement) => { - return render({component}); + const renderWithJotai = (component: React.ReactElement) => { + return render({component}); }; beforeEach(() => { jest.clearAllMocks(); + mockUseAtomValue.mockReturnValue(undefined); + mockClearProgress.mockClear(); }); describe('attachments prop passing', () => { @@ -129,7 +163,7 @@ describe('ToolCall', () => { }, ]; - renderWithRecoil(); + renderWithJotai(); fireEvent.click(screen.getByTestId('progress-text')); @@ -141,7 +175,7 @@ describe('ToolCall', () => { }); it('should pass empty array when no attachments', () => { - renderWithRecoil(); + renderWithJotai(); fireEvent.click(screen.getByTestId('progress-text')); @@ -172,7 +206,7 @@ describe('ToolCall', () => { }, ]; - renderWithRecoil(); + renderWithJotai(); fireEvent.click(screen.getByTestId('progress-text')); @@ -196,7 +230,7 @@ describe('ToolCall', () => { }, ]; - renderWithRecoil(); + renderWithJotai(); const attachmentGroup = screen.getByTestId('attachment-group'); expect(attachmentGroup).toBeInTheDocument(); @@ -204,13 +238,13 @@ describe('ToolCall', () => { }); it('should not render AttachmentGroup when no attachments', () => { - renderWithRecoil(); + renderWithJotai(); expect(screen.queryByTestId('attachment-group')).not.toBeInTheDocument(); }); it('should not render AttachmentGroup when attachments is empty array', () => { - renderWithRecoil(); + renderWithJotai(); expect(screen.queryByTestId('attachment-group')).not.toBeInTheDocument(); }); @@ -218,7 +252,7 @@ describe('ToolCall', () => { describe('tool call info visibility', () => { it('should toggle tool call info expand/collapse when clicking header', () => { - renderWithRecoil(); + renderWithJotai(); // ToolCallInfo is always in the DOM (CSS expand/collapse), but initially collapsed const toolCallInfo = screen.getByTestId('tool-call-info'); @@ -233,8 +267,29 @@ describe('ToolCall', () => { expect(screen.getByTestId('tool-call-info')).toBeInTheDocument(); }); - it('should pass input and output props to ToolCallInfo', () => { - renderWithRecoil(); + it('should pass all required props to ToolCallInfo', () => { + const attachments = [ + { + type: Tools.ui_resources, + messageId: 'msg123', + toolCallId: 'tool456', + conversationId: 'conv789', + [Tools.ui_resources]: { + '0': { type: 'button', label: 'Test' }, + }, + }, + ]; + + // Use a name with domain separator (_action_) and domain separator (---) + const propsWithDomain = { + ...mockProps, + name: 'testFunction_action_test---domain---com', + attachments, + }; + + renderWithJotai(); + + fireEvent.click(screen.getByTestId('progress-text')); const toolCallInfo = screen.getByTestId('tool-call-info'); const props = JSON.parse(toolCallInfo.textContent!); @@ -249,9 +304,10 @@ describe('ToolCall', () => { const originalOpen = window.open; window.open = jest.fn(); - renderWithRecoil( + renderWithJotai( { }); it('should not show auth section when cancelled', () => { - renderWithRecoil( + renderWithJotai( { }); it('should not show auth section when progress is complete', () => { - renderWithRecoil( + renderWithJotai( { describe('edge cases', () => { it('should handle undefined args', () => { - renderWithRecoil(); + renderWithJotai(); + + fireEvent.click(screen.getByTestId('progress-text')); const toolCallInfo = screen.getByTestId('tool-call-info'); const props = JSON.parse(toolCallInfo.textContent!); @@ -308,15 +366,17 @@ describe('ToolCall', () => { }); it('should handle null output', () => { - renderWithRecoil(); + renderWithJotai(); const toolCallInfo = screen.getByTestId('tool-call-info'); const props = JSON.parse(toolCallInfo.textContent!); expect(props.output).toBeNull(); }); - it('should handle simple function name without domain', () => { - renderWithRecoil(); + it('should handle missing domain', () => { + renderWithJotai(); + + fireEvent.click(screen.getByTestId('progress-text')); const toolCallInfo = screen.getByTestId('tool-call-info'); expect(toolCallInfo).toBeInTheDocument(); @@ -344,7 +404,7 @@ describe('ToolCall', () => { }, ]; - renderWithRecoil(); + renderWithJotai(); fireEvent.click(screen.getByTestId('progress-text')); @@ -361,7 +421,7 @@ describe('ToolCall', () => { const d = Constants.mcp_delimiter; it('should detect MCP OAuth from delimiter in tool-call name', () => { - renderWithRecoil( + renderWithJotai( { }); it('should preserve full server name when it contains the delimiter substring', () => { - renderWithRecoil( + renderWithJotai( { }); it('should display server name (not "oauth") as function_name for OAuth tool calls', () => { - renderWithRecoil( + renderWithJotai( { it('should display server name even when auth is cleared (post-completion)', () => { // After OAuth completes, createOAuthEnd re-emits the toolCall without auth. // The display should still show the server name, not literal "oauth". - renderWithRecoil( + renderWithJotai( { const authUrl = 'https://oauth.example.com/authorize?redirect_uri=' + encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback'); - renderWithRecoil( + renderWithJotai( { const authUrl = 'https://oauth.example.com/authorize?redirect_uri=' + encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback'); - renderWithRecoil( + renderWithJotai( { // gets prefixed to oauth_mcp_oauth_mcp_server. Client parses: // func="oauth", server="oauth_mcp_server". Visually awkward but // semantically correct — the normalized name IS oauth_mcp_server. - renderWithRecoil( + renderWithJotai( { const authUrl = 'https://oauth.example.com/authorize?redirect_uri=' + encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback'); - renderWithRecoil( + renderWithJotai( { describe('A11Y-04: screen reader status announcements', () => { it('includes sr-only aria-live region for status announcements', () => { - renderWithRecoil( + renderWithJotai( { expect(liveRegion!.className).toContain('sr-only'); }); }); + + describe('getInProgressText - MCP progress display', () => { + it('shows mcpProgress.message when available', () => { + mockUseAtomValue.mockReturnValue({ + progress: 2, + total: 10, + message: 'Fetching data from API...', + timestamp: Date.now(), + }); + + renderWithJotai( + , + ); + + expect(screen.getByTestId('progress-text')).toHaveTextContent('Fetching data from API...'); + }); + + it('shows "functionName: X/Y" when mcpProgress has total but no message', () => { + mockUseAtomValue.mockReturnValue({ + progress: 3, + total: 10, + timestamp: Date.now(), + }); + + renderWithJotai( + , + ); + + expect(screen.getByTestId('progress-text')).toHaveTextContent('testFunction: 3/10'); + }); + + it('falls back to running_var localisation when no mcpProgress', () => { + mockUseAtomValue.mockReturnValue(undefined); + + renderWithJotai( + , + ); + + expect(screen.getByTestId('progress-text')).toHaveTextContent('Running testFunction'); + }); + + it('prefers message over progress/total when both are present', () => { + mockUseAtomValue.mockReturnValue({ + progress: 5, + total: 10, + message: 'Custom status message', + timestamp: Date.now(), + }); + + renderWithJotai( + , + ); + + expect(screen.getByTestId('progress-text')).toHaveTextContent('Custom status message'); + expect(screen.getByTestId('progress-text')).not.toHaveTextContent('testFunction: 5/10'); + }); + }); + + describe('toolCallId prop and progress atom integration', () => { + it('passes toolCallId to toolCallProgressFamily when provided', () => { + const { toolCallProgressFamily } = jest.requireMock('~/store/progress'); + + renderWithJotai(); + + expect(toolCallProgressFamily).toHaveBeenCalledWith('specific-call-id'); + }); + + it('passes empty string to toolCallProgressFamily when toolCallId is undefined', () => { + const { toolCallProgressFamily } = jest.requireMock('~/store/progress'); + + renderWithJotai(); + + expect(toolCallProgressFamily).toHaveBeenCalledWith(''); + }); + + it('calls clearProgress with toolCallId when output arrives', () => { + mockUseAtomValue.mockReturnValue(mockClearProgress); + + renderWithJotai( + , + ); + + expect(mockClearProgress).toHaveBeenCalledWith('call-to-clear'); + }); + + it('does not call clearProgress when toolCallId is undefined', () => { + mockUseAtomValue.mockReturnValue(mockClearProgress); + + renderWithJotai(); + + expect(mockClearProgress).not.toHaveBeenCalled(); + }); + }); + + describe('cancelled state with hasOutput', () => { + it('is not cancelled when output exists even with low progress', () => { + renderWithJotai( + , + ); + + // When not cancelled and has output → shows finished text, not "Cancelled" + expect(screen.queryByTestId('progress-text')).not.toHaveTextContent('Cancelled'); + expect(screen.queryByTestId('progress-text')).toHaveTextContent('Completed testFunction'); + }); + + it('is cancelled when no output and not submitting and progress < 1', () => { + renderWithJotai( + , + ); + + expect(screen.getByTestId('progress-text')).toHaveTextContent('Cancelled'); + }); + + it('shows finished text when progress is 1 and output is present', () => { + renderWithJotai( + , + ); + + expect(screen.getByTestId('progress-text')).toHaveTextContent('Completed testFunction'); + }); + }); }); diff --git a/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts b/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts index 1717d27c22..52256858c9 100644 --- a/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts +++ b/client/src/hooks/SSE/__tests__/useResumableSSE.spec.ts @@ -75,6 +75,15 @@ jest.mock('~/data-provider', () => ({ const mockErrorHandler = jest.fn(); const mockSetIsSubmitting = jest.fn(); const mockClearStepMaps = jest.fn(); +const mockHandleProgressEvent = jest.fn(); +const mockCleanupProgress = jest.fn(); + +jest.mock('~/hooks/SSE/useProgressTracking', () => ({ + useProgressTracking: () => ({ + handleProgressEvent: mockHandleProgressEvent, + cleanupProgress: mockCleanupProgress, + }), +})); jest.mock('~/hooks/SSE/useEventHandlers', () => jest.fn(() => ({ @@ -282,3 +291,81 @@ describe('useResumableSSE - 404 error path', () => { }, ); }); + +describe('useResumableSSE - progress event integration', () => { + beforeEach(() => { + mockSSEInstances.length = 0; + localStorage.clear(); + mockHandleProgressEvent.mockClear(); + mockCleanupProgress.mockClear(); + }); + + const renderAndInit = async (conversationId = CONV_ID) => { + const submission = buildSubmission({ conversation: { conversationId } }); + const chatHelpers = buildChatHelpers(); + const { unmount } = renderHook(() => useResumableSSE(submission, chatHelpers)); + + await act(async () => { + await Promise.resolve(); + }); + + return { sse: getLastSSE(), unmount, chatHelpers }; + }; + + it('registers a "progress" event listener on the SSE connection', async () => { + const { unmount } = await renderAndInit(); + const sse = getLastSSE(); + + const registeredEvents = sse.addEventListener.mock.calls.map(([event]: [string]) => event); + expect(registeredEvents).toContain('progress'); + unmount(); + }); + + it('routes incoming progress events to handleProgressEvent', async () => { + const { unmount } = await renderAndInit(); + const sse = getLastSSE(); + + const progressData = JSON.stringify({ + progress: 3, + total: 10, + message: 'Working...', + toolCallId: 'call-abc', + }); + + await act(async () => { + sse._emit('progress', { data: progressData }); + }); + + expect(mockHandleProgressEvent).toHaveBeenCalledTimes(1); + expect(mockHandleProgressEvent).toHaveBeenCalledWith( + expect.objectContaining({ data: progressData }), + ); + unmount(); + }); + + it('calls cleanupProgress on unmount', async () => { + const { unmount } = await renderAndInit(); + + unmount(); + + expect(mockCleanupProgress).toHaveBeenCalledTimes(1); + }); + + it('does not call cleanupProgress before unmount', async () => { + const { unmount } = await renderAndInit(); + + expect(mockCleanupProgress).not.toHaveBeenCalled(); + unmount(); + }); + + it('calls cleanupProgress on unmount even after a 404 error', async () => { + const { sse, unmount } = await renderAndInit(); + + await act(async () => { + sse._emit('error', { responseCode: 404 }); + }); + + unmount(); + expect(mockCleanupProgress).toHaveBeenCalled(); + }); +}); diff --git a/client/src/hooks/SSE/index.ts b/client/src/hooks/SSE/index.ts index 2829db76f6..6c6ff7815a 100644 --- a/client/src/hooks/SSE/index.ts +++ b/client/src/hooks/SSE/index.ts @@ -5,3 +5,5 @@ export { default as useResumeOnLoad } from './useResumeOnLoad'; export { default as useStepHandler } from './useStepHandler'; export { default as useContentHandler } from './useContentHandler'; export { default as useAttachmentHandler } from './useAttachmentHandler'; +export { useProgressTracking } from './useProgressTracking'; +export type { ProgressState } from '~/store/progress'; diff --git a/client/src/hooks/SSE/useProgressTracking.ts b/client/src/hooks/SSE/useProgressTracking.ts new file mode 100644 index 0000000000..8beddc0cb6 --- /dev/null +++ b/client/src/hooks/SSE/useProgressTracking.ts @@ -0,0 +1,48 @@ +import { useRef, useCallback } from 'react'; +import { useSetAtom } from 'jotai'; +import { toolCallProgressMapAtom } from '~/store/progress'; + +export function useProgressTracking() { + const setToolCallProgressMap = useSetAtom(toolCallProgressMapAtom); + const progressCleanupTimers = useRef>(new Map()); + + const handleProgressEvent = useCallback( + (e: MessageEvent) => { + try { + const data = JSON.parse(e.data); + const { progress, total, message, toolCallId } = data; + if (toolCallId != null) { + setToolCallProgressMap((currentMap) => { + const newMap = new Map(currentMap); + newMap.set(toolCallId, { progress, total, message, timestamp: Date.now() }); + return newMap; + }); + if (total && progress >= total) { + const existingTimer = progressCleanupTimers.current.get(toolCallId); + if (existingTimer) clearTimeout(existingTimer); + const timerId = setTimeout(() => { + setToolCallProgressMap((currentMap) => { + const newMap = new Map(currentMap); + newMap.delete(toolCallId); + return newMap; + }); + progressCleanupTimers.current.delete(toolCallId); + }, 5000); + progressCleanupTimers.current.set(toolCallId, timerId); + } + } + } catch (error) { + console.error('Error parsing progress event:', error); + } + }, + [setToolCallProgressMap], + ); + + const cleanupProgress = useCallback(() => { + setToolCallProgressMap(new Map()); + progressCleanupTimers.current.forEach(clearTimeout); + progressCleanupTimers.current.clear(); + }, [setToolCallProgressMap]); + + return { handleProgressEvent, cleanupProgress }; +} diff --git a/client/src/hooks/SSE/useResumableSSE.ts b/client/src/hooks/SSE/useResumableSSE.ts index 39dc610dae..34e959afe6 100644 --- a/client/src/hooks/SSE/useResumableSSE.ts +++ b/client/src/hooks/SSE/useResumableSSE.ts @@ -25,6 +25,7 @@ import { import type { ActiveJobsResponse } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import useEventHandlers from './useEventHandlers'; +import { useProgressTracking } from './useProgressTracking'; import { clearAllDrafts } from '~/utils'; import store from '~/store'; @@ -132,6 +133,7 @@ export default function useResumableSSE( const balanceQuery = useGetUserBalance({ enabled: !!isAuthenticated && startupConfig?.balance?.enabled, }); + const { handleProgressEvent, cleanupProgress } = useProgressTracking(); /** * Subscribe to stream via SSE library (supports custom headers) @@ -335,6 +337,8 @@ export default function useResumableSSE( } }); + sse.addEventListener('progress', handleProgressEvent); + /** * Error event handler - handles BOTH: * 1. HTTP-level errors (responseCode present) - 404, 401, network failures @@ -552,6 +556,7 @@ export default function useResumableSSE( balanceQuery, removeActiveJob, queryClient, + handleProgressEvent, ], ); @@ -703,6 +708,8 @@ export default function useResumableSSE( // Reset UI state on cleanup - useResumeOnLoad will restore if needed setIsSubmitting(false); setShowStopButton(false); + // Clear progress map and pending cleanup timers on unmount + cleanupProgress(); }; // eslint-disable-next-line react-hooks/exhaustive-deps }, [submission]); diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts index 78835f5729..158f61f00e 100644 --- a/client/src/hooks/SSE/useSSE.ts +++ b/client/src/hooks/SSE/useSSE.ts @@ -10,6 +10,7 @@ import { useGetStartupConfig, useGetUserBalance } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import useEventHandlers from './useEventHandlers'; import { clearAllDrafts } from '~/utils'; +import { useProgressTracking } from './useProgressTracking'; import store from '~/store'; type ChatHelpers = Pick< @@ -71,6 +72,7 @@ export default function useSSE( const balanceQuery = useGetUserBalance({ enabled: !!isAuthenticated && startupConfig?.balance?.enabled, }); + const { handleProgressEvent, cleanupProgress } = useProgressTracking(); useEffect(() => { if (submission == null || Object.keys(submission).length === 0) { @@ -100,6 +102,8 @@ export default function useSSE( } }); + sse.addEventListener('progress', handleProgressEvent); + sse.addEventListener('message', (e: MessageEvent) => { const data = JSON.parse(e.data); @@ -234,6 +238,8 @@ export default function useSSE( return () => { const isCancelled = sse.readyState <= 1; sse.close(); + // Clear progress map and pending cleanup timers on unmount + cleanupProgress(); if (isCancelled) { const e = new Event('cancel'); /* @ts-ignore */ diff --git a/client/src/store/progress.ts b/client/src/store/progress.ts new file mode 100644 index 0000000000..423898da0e --- /dev/null +++ b/client/src/store/progress.ts @@ -0,0 +1,30 @@ +import { atom } from 'jotai'; +import { atomFamily } from 'jotai/utils'; + +export type ProgressState = { + progress: number; + total?: number; + message?: string; + timestamp: number; +}; + +// Map of toolCallId -> ProgressState (for matching progress to specific tool calls) +export const toolCallProgressMapAtom = atom>(new Map()); + +// Family for tool call based progress lookup +export const toolCallProgressFamily = atomFamily((toolCallId: string) => + atom((get) => { + // Don't return data for empty string key + if (!toolCallId) return undefined; + return get(toolCallProgressMapAtom).get(toolCallId); + }), +); + +// Cleanup action - remove progress entry for a specific tool call +export const clearToolCallProgressAtom = atom(null, (get, set, toolCallId: string) => { + if (!toolCallId) return; + const current = get(toolCallProgressMapAtom); + const newMap = new Map(current); + newMap.delete(toolCallId); + set(toolCallProgressMapAtom, newMap); +}); diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 12227de39f..fe0120b3df 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -272,6 +272,7 @@ Please follow these instructions when using tools from the respective MCP server oauthEnd, customUserVars, graphTokenResolver, + onProgress, }: { user?: IUser; serverName: string; @@ -288,9 +289,19 @@ Please follow these instructions when using tools from the respective MCP server oauthStart?: (authURL: string) => Promise; oauthEnd?: () => Promise; graphTokenResolver?: GraphTokenResolver; + onProgress?: (progressData: { + progressToken: t.ProgressToken; + progress: number; + total?: number; + message?: string; + serverName: string; + }) => void; }): Promise { /** User-specific connection */ let connection: MCPConnection | undefined; + let progressHandler: ((data: t.ProgressNotification & { serverName: string }) => void) | null = + null; + let progressToken: t.ProgressToken | undefined; const userId = user?.id; const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`; @@ -348,12 +359,44 @@ Please follow these instructions when using tools from the respective MCP server connection.setRequestHeaders(currentOptions.headers || {}); } + // Generate and register progress token BEFORE setting up listener + // to avoid race condition where progress arrives before token is registered + progressToken = connection.generateProgressToken(); + connection.registerProgressToken(progressToken); + logger.info( + `${logPrefix}[${toolName}] Progress token generated and registered: ${progressToken}`, + ); + + // Set up progress event listener if callback is provided + if (onProgress) { + progressHandler = (data) => { + logger.info(`${logPrefix}[${toolName}] Progress event received:`, { + progressToken: data.progressToken, + progress: data.progress, + total: data.total, + message: data.message, + }); + onProgress({ + progressToken: data.progressToken, + progress: data.progress, + total: data.total, + message: data.message, + serverName: data.serverName, + }); + }; + connection.on('progress', progressHandler); + logger.info(`${logPrefix}[${toolName}] Progress listener registered`); + } + const result = await connection.client.request( { method: 'tools/call', params: { name: toolName, arguments: toolArguments, + _meta: { + progressToken, + }, }, }, CallToolResultSchema, @@ -367,12 +410,30 @@ Please follow these instructions when using tools from the respective MCP server this.updateUserLastActivity(userId); } this.checkIdleConnections(); - return formatToolContent(result as t.MCPToolCallResponse, provider); + + // Format and return the tool response as a proper tuple [content, artifacts] + // Progress is handled separately via SSE events emitted by the connection + const formattedResponse = formatToolContent(result as t.MCPToolCallResponse, provider); + + return formattedResponse; } catch (error) { // Log with context and re-throw or handle as needed logger.error(`${logPrefix}[${toolName}] Tool call failed`, error); // Rethrowing allows the caller (createMCPTool) to handle the final user message throw error; + } finally { + if (connection && (progressHandler || progressToken)) { + setTimeout(() => { + // Clean up progress listener to prevent memory leaks + if (progressHandler) { + connection?.off('progress', progressHandler); + } + // Clean up progress token to prevent memory leak + if (progressToken) { + connection?.unregisterProgressToken(progressToken); + } + }, 500); + } } } } diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index ba5b0b3b8e..af7742f5f0 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -8,6 +8,7 @@ import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnection } from '~/mcp/connection'; import { MCPManager } from '~/mcp/MCPManager'; +import { EventEmitter } from 'events'; import * as graphUtils from '~/utils/graph'; // Mock external dependencies @@ -429,6 +430,10 @@ describe('MCPManager', () => { isError: false, }), }, + generateProgressToken: jest.fn().mockReturnValue('mock-progress-token'), + registerProgressToken: jest.fn(), + subscribeToProgress: jest.fn().mockReturnValue(() => {}), + unregisterProgressToken: jest.fn(), } as unknown as MCPConnection; const mockGraphTokenResolver: GraphTokenResolver = jest.fn().mockResolvedValue({ @@ -966,4 +971,396 @@ describe('MCPManager', () => { ); }); }); + + describe('callTool Progress Integration', () => { + const serverName = 'test_server'; + + const mockUser: Partial = { + id: 'user-123', + provider: 'openid', + openidId: 'oidc-sub-456', + }; + + const mockFlowManager = { + getState: jest.fn(), + setState: jest.fn(), + clearState: jest.fn(), + }; + + function buildMockConnection(overrides: Partial> = {}) { + const emitter = new EventEmitter(); + return { + isConnected: jest.fn().mockResolvedValue(true), + setRequestHeaders: jest.fn(), + timeout: 30000, + on: jest.fn((event, handler) => emitter.on(event, handler)), + off: jest.fn((event, handler) => emitter.off(event, handler)), + emit: (event: string, data: unknown) => emitter.emit(event, data), + client: { + request: jest.fn().mockResolvedValue({ + content: [{ type: 'text', text: 'Tool result' }], + isError: false, + }), + }, + generateProgressToken: jest.fn().mockReturnValue('mock-progress-token'), + registerProgressToken: jest.fn(), + unregisterProgressToken: jest.fn(), + ...overrides, + } as unknown as MCPConnection; + } + + function mockAppConnections(connection: MCPConnection) { + (ConnectionsRepository as jest.MockedClass).mockImplementation( + () => + ({ + has: jest.fn().mockResolvedValue(false), + get: jest.fn().mockResolvedValue(connection), + }) as unknown as ConnectionsRepository, + ); + } + + function newMCPServersConfig(): t.MCPServers { + return { [serverName]: { type: 'stdio', command: 'test', args: [] } }; + } + + beforeEach(() => { + (MCPManager as unknown as { instance: null }).instance = null; + jest.clearAllMocks(); + (MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined); + (mockRegistryInstance.getAllServerConfigs as jest.Mock).mockResolvedValue({}); + (graphUtils.preProcessGraphTokens as jest.Mock).mockImplementation( + async (options) => options, + ); + mockRegistryInstance.getServerConfig.mockResolvedValue({ + type: 'sse', + url: 'https://api.example.com', + }); + }); + + describe('progress token lifecycle', () => { + it('generates and registers a progress token before the tool call', async () => { + const connection = buildMockConnection(); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + }); + + expect(connection.generateProgressToken).toHaveBeenCalledTimes(1); + expect(connection.registerProgressToken).toHaveBeenCalledWith('mock-progress-token'); + }); + + it('passes _meta.progressToken in the MCP request params', async () => { + const connection = buildMockConnection(); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + }); + + expect(connection.client.request).toHaveBeenCalledWith( + expect.objectContaining({ + params: expect.objectContaining({ + _meta: { progressToken: 'mock-progress-token' }, + }), + }), + expect.anything(), + expect.anything(), + ); + }); + + it('schedules unregisterProgressToken cleanup in finally block after success', async () => { + jest.useFakeTimers(); + const connection = buildMockConnection(); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + }); + + expect(connection.unregisterProgressToken).not.toHaveBeenCalled(); + jest.advanceTimersByTime(600); + expect(connection.unregisterProgressToken).toHaveBeenCalledWith('mock-progress-token'); + jest.useRealTimers(); + }); + + it('schedules cleanup in finally block even when tool call throws', async () => { + jest.useFakeTimers(); + const connection = buildMockConnection({ + client: { + request: jest.fn().mockRejectedValue(new Error('MCP server error')), + }, + }); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await expect( + manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + }), + ).rejects.toThrow('MCP server error'); + + jest.advanceTimersByTime(600); + expect(connection.unregisterProgressToken).toHaveBeenCalledWith('mock-progress-token'); + jest.useRealTimers(); + }); + }); + + describe('onProgress callback', () => { + it('does not register a progress listener when onProgress is not provided', async () => { + const connection = buildMockConnection(); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + }); + + expect(connection.on).not.toHaveBeenCalledWith('progress', expect.any(Function)); + }); + + it('registers a progress listener when onProgress is provided', async () => { + const connection = buildMockConnection(); + mockAppConnections(connection); + + const onProgress = jest.fn(); + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + onProgress, + }); + + expect(connection.on).toHaveBeenCalledWith('progress', expect.any(Function)); + }); + + it('calls onProgress with the correct shape when a progress event fires', async () => { + const emitter = new EventEmitter(); + const connection = buildMockConnection({ + on: jest.fn((event, handler) => emitter.on(event, handler)), + off: jest.fn((event, handler) => emitter.off(event, handler)), + client: { + request: jest.fn().mockImplementation(async () => { + emitter.emit('progress', { + serverName, + progressToken: 'mock-progress-token', + progress: 3, + total: 10, + message: 'Processing...', + }); + return { content: [{ type: 'text', text: 'done' }], isError: false }; + }), + }, + }); + mockAppConnections(connection); + + const onProgress = jest.fn(); + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + onProgress, + }); + + expect(onProgress).toHaveBeenCalledTimes(1); + expect(onProgress).toHaveBeenCalledWith({ + progressToken: 'mock-progress-token', + progress: 3, + total: 10, + message: 'Processing...', + serverName, + }); + }); + + it('calls onProgress multiple times as progress events accumulate', async () => { + const emitter = new EventEmitter(); + const connection = buildMockConnection({ + on: jest.fn((event, handler) => emitter.on(event, handler)), + off: jest.fn((event, handler) => emitter.off(event, handler)), + client: { + request: jest.fn().mockImplementation(async () => { + for (let i = 1; i <= 3; i++) { + emitter.emit('progress', { + serverName, + progressToken: 'mock-progress-token', + progress: i, + total: 3, + }); + } + return { content: [{ type: 'text', text: 'done' }], isError: false }; + }), + }, + }); + mockAppConnections(connection); + + const onProgress = jest.fn(); + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + onProgress, + }); + + expect(onProgress).toHaveBeenCalledTimes(3); + }); + + it('detaches the progress listener after call completes', async () => { + jest.useFakeTimers(); + const connection = buildMockConnection(); + mockAppConnections(connection); + + const onProgress = jest.fn(); + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + onProgress, + }); + + jest.advanceTimersByTime(600); + expect(connection.off).toHaveBeenCalledWith('progress', expect.any(Function)); + jest.useRealTimers(); + }); + + it('detaches the progress listener after call throws', async () => { + jest.useFakeTimers(); + const connection = buildMockConnection({ + client: { + request: jest.fn().mockRejectedValue(new Error('fail')), + }, + }); + mockAppConnections(connection); + + const onProgress = jest.fn(); + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await expect( + manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + onProgress, + }), + ).rejects.toThrow('fail'); + + jest.advanceTimersByTime(600); + expect(connection.off).toHaveBeenCalledWith('progress', expect.any(Function)); + jest.useRealTimers(); + }); + + it('logs each received progress event at info level', async () => { + const emitter = new EventEmitter(); + const connection = buildMockConnection({ + on: jest.fn((event, handler) => emitter.on(event, handler)), + off: jest.fn((event, handler) => emitter.off(event, handler)), + client: { + request: jest.fn().mockImplementation(async () => { + emitter.emit('progress', { + serverName, + progressToken: 'mock-progress-token', + progress: 1, + total: 5, + message: 'Step 1', + }); + return { content: [{ type: 'text', text: 'done' }], isError: false }; + }), + }, + }); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + onProgress: jest.fn(), + }); + + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('Progress event received'), + expect.objectContaining({ progress: 1, total: 5, message: 'Step 1' }), + ); + }); + + it('logs progress token registration at info level', async () => { + const connection = buildMockConnection(); + mockAppConnections(connection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.callTool({ + user: mockUser as IUser, + serverName, + toolName: 'test_tool', + provider: 'openai', + flowManager: mockFlowManager as unknown as Parameters< + typeof manager.callTool + >[0]['flowManager'], + }); + + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('Progress token generated and registered'), + ); + }); + }); + }); }); diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 8dc857cd3b..3736f9dcdb 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -8,7 +8,10 @@ import { import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js'; -import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js'; +import { + ResourceListChangedNotificationSchema, + ProgressNotificationSchema, +} from '@modelcontextprotocol/sdk/types.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { @@ -368,6 +371,13 @@ export class MCPConnection extends EventEmitter { } } + // Progress notification state + private progressTokenCounter = 0; + private activeProgressTokens = new Map(); + private lastProgressEmit = new Map(); + private progressCleanupTimers = new Map(); + private readonly PROGRESS_THROTTLE_MS = 200; // Max 5 updates/second + setRequestHeaders(headers: Record | null): void { if (!headers) { return; @@ -686,6 +696,7 @@ export class MCPConnection extends EventEmitter { }); this.subscribeToResources(); + this.subscribeToProgress(); } private async handleReconnection(): Promise { @@ -767,6 +778,125 @@ export class MCPConnection extends EventEmitter { }); } + /** + * Generates a unique progress token for tracking tool call progress + */ + public generateProgressToken(): t.ProgressToken { + const userId = this.userId ? `${this.userId}-` : ''; + return `${userId}${this.serverName}-${++this.progressTokenCounter}-${Date.now()}`; + } + + /** + * Registers a progress token for tracking + */ + public registerProgressToken(token: t.ProgressToken): void { + this.activeProgressTokens.set(token, { + token, + progress: 0, + timestamp: Date.now(), + }); + } + + /** + * Gets the current progress state for a token + */ + public getProgressState(token: t.ProgressToken): t.ProgressState | undefined { + return this.activeProgressTokens.get(token); + } + + /** + * Unregisters a progress token and cleans up associated state + */ + public unregisterProgressToken(token: t.ProgressToken): void { + this.activeProgressTokens.delete(token); + this.lastProgressEmit.delete(token); + // Cancel any pending cleanup timer for this token + const existingTimer = this.progressCleanupTimers.get(token); + if (existingTimer) { + clearTimeout(existingTimer); + this.progressCleanupTimers.delete(token); + } + } + + /** + * Checks if progress should be emitted (rate limiting) + */ + private shouldEmitProgress(token: t.ProgressToken): boolean { + const lastEmit = this.lastProgressEmit.get(token) || 0; + const now = Date.now(); + if (now - lastEmit < this.PROGRESS_THROTTLE_MS) { + return false; + } + this.lastProgressEmit.set(token, now); + return true; + } + + /** + * Subscribes to progress notifications from MCP server + */ + private subscribeToProgress(): void { + try { + this.client.setNotificationHandler(ProgressNotificationSchema, async (notification) => { + try { + logger.info( + `${this.getLogPrefix()} Received progress notification:`, + notification.params, + ); + const { progressToken, progress, total, message } = notification.params; + + // Validate token + if (!this.activeProgressTokens.has(progressToken)) { + logger.debug(`${this.getLogPrefix()} Progress for unknown token: ${progressToken}`); + return; + } + + // Rate limiting + if (!this.shouldEmitProgress(progressToken) && (!total || progress < total)) { + return; + } + + // Update state + const newState: t.ProgressState = { + token: progressToken, + progress, + total, + message, + timestamp: Date.now(), + }; + this.activeProgressTokens.set(progressToken, newState); + + // Emit progress event + this.emit('progress', { + serverName: this.serverName, + progressToken, + progress, + total, + message, + }); + + // Cleanup if complete + if (total && progress >= total) { + // Clear existing timer if present to prevent duplicates + const existingTimer = this.progressCleanupTimers.get(progressToken); + if (existingTimer) { + clearTimeout(existingTimer); + } + const timerId = setTimeout(() => { + this.activeProgressTokens.delete(progressToken); + this.lastProgressEmit.delete(progressToken); + this.progressCleanupTimers.delete(progressToken); + }, 5000); + this.progressCleanupTimers.set(progressToken, timerId); + } + } catch (error) { + logger.error(`${this.getLogPrefix()} Error handling progress:`, error); + } + }); + } catch (error) { + logger.warn(`${this.getLogPrefix()} Failed to setup progress notifications:`, error); + } + } + async connectClient(): Promise { if (this.connectionState === 'connected') { return; @@ -1115,6 +1245,12 @@ export class MCPConnection extends EventEmitter { this.emit('connectionChange', 'disconnected'); } finally { this.connectPromise = null; + // Clean up progress tracking state to prevent memory leaks + this.activeProgressTokens.clear(); + this.lastProgressEmit.clear(); + // Clear all pending cleanup timers + this.progressCleanupTimers.forEach((timer) => clearTimeout(timer)); + this.progressCleanupTimers.clear(); if (!resetCycleTracking) { this.recordCycle(); } diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 32c2787165..7048f869ac 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -72,6 +72,23 @@ export type MCPToolCallResponse = isError?: boolean; }; +export type ProgressToken = string | number; + +export interface ProgressNotification { + progressToken: ProgressToken; + progress: number; + total?: number; + message?: string; +} + +export interface ProgressState { + token: ProgressToken; + progress: number; + total?: number; + message?: string; + timestamp: number; +} + export type Provider = | 'google' | 'anthropic' @@ -142,6 +159,10 @@ export type FormattedContentResult = [string, Artifacts | undefined]; export type ImageFormatter = (item: ImageContent) => FormattedContent; +/** + * Tool response for MCP tools - must be a proper two-tuple for LangChain's content_and_artifact format. + * Progress notifications are handled separately via SSE events, not attached to the response. + */ export type FormattedToolResponse = FormattedContentResult; /** diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 5993c911ff..5b42e6f8c5 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -915,6 +915,26 @@ class GenerationJobManagerClass { await this.eventTransport.emitChunk(streamId, event); } + /** + * Emit an ephemeral event directly to live subscribers, bypassing the + * earlyEventBuffer, Redis persistence, and trackUserMessage side-effects. + * + * Use for fire-and-forget events like MCP tool-call progress that must not + * be replayed to reconnecting clients. Silently drops when no subscriber + * is connected or the job has been aborted. + */ + async emitTransientEvent(streamId: string, event: t.ServerSentEvent): Promise { + const runtime = this.runtimeState.get(streamId); + if (!runtime || runtime.abortController.signal.aborted) { + return; + } + if (!runtime.hasSubscriber) { + return; + } + // Emit directly to transport, bypassing appendChunk and earlyEventBuffer + await this.eventTransport.emitChunk(streamId, event); + } + /** * Extract and save run step from event data. * The data is already the run step object from the event payload. diff --git a/packages/api/src/stream/__tests__/GenerationJobManager.emitTransientEvent.spec.ts b/packages/api/src/stream/__tests__/GenerationJobManager.emitTransientEvent.spec.ts new file mode 100644 index 0000000000..d61e031d2b --- /dev/null +++ b/packages/api/src/stream/__tests__/GenerationJobManager.emitTransientEvent.spec.ts @@ -0,0 +1,119 @@ +import type * as t from '~/types'; + +interface RuntimeState { + abortController: AbortController; + hasSubscriber: boolean; +} + +class GenerationJobManagerStub { + runtimeState = new Map(); + eventTransport = { emitChunk: jest.fn() }; + + async emitTransientEvent(streamId: string, event: t.ServerSentEvent): Promise { + const runtime = this.runtimeState.get(streamId); + if (!runtime || runtime.abortController.signal.aborted) { + return; + } + if (!runtime.hasSubscriber) { + return; + } + await this.eventTransport.emitChunk(streamId, event); + } +} + +function makeRuntime(overrides: Partial = {}): RuntimeState { + return { + abortController: new AbortController(), + hasSubscriber: true, + ...overrides, + }; +} + +describe('GenerationJobManager - emitTransientEvent', () => { + let manager: GenerationJobManagerStub; + + const streamId = 'stream-abc-123'; + const progressEvent: t.ServerSentEvent = { + event: 'progress', + data: { progress: 2, total: 5, message: 'Working…', toolCallId: 'call-1' }, + } as unknown as t.ServerSentEvent; + + beforeEach(() => { + manager = new GenerationJobManagerStub(); + jest.clearAllMocks(); + }); + + it('emits to transport when runtime exists and has subscriber', async () => { + manager.runtimeState.set(streamId, makeRuntime()); + + await manager.emitTransientEvent(streamId, progressEvent); + + expect(manager.eventTransport.emitChunk).toHaveBeenCalledTimes(1); + expect(manager.eventTransport.emitChunk).toHaveBeenCalledWith(streamId, progressEvent); + }); + + it('silently drops when streamId has no runtime entry', async () => { + await manager.emitTransientEvent('unknown-stream', progressEvent); + + expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled(); + }); + + it('silently drops when job has been aborted', async () => { + const runtime = makeRuntime(); + runtime.abortController.abort(); + manager.runtimeState.set(streamId, runtime); + + await manager.emitTransientEvent(streamId, progressEvent); + + expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled(); + }); + + it('silently drops when there is no active subscriber', async () => { + manager.runtimeState.set(streamId, makeRuntime({ hasSubscriber: false })); + + await manager.emitTransientEvent(streamId, progressEvent); + + expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled(); + }); + + it('does not persist the event (calls emitChunk directly, not appendChunk)', async () => { + const appendChunk = jest.fn(); + (manager as unknown as Record).appendChunk = appendChunk; + manager.runtimeState.set(streamId, makeRuntime()); + + await manager.emitTransientEvent(streamId, progressEvent); + + expect(appendChunk).not.toHaveBeenCalled(); + expect(manager.eventTransport.emitChunk).toHaveBeenCalled(); + }); + + it('forwards any event shape without mutation', async () => { + manager.runtimeState.set(streamId, makeRuntime()); + + const customEvent = { event: 'progress', data: { foo: 'bar' } } as unknown as t.ServerSentEvent; + await manager.emitTransientEvent(streamId, customEvent); + + expect(manager.eventTransport.emitChunk).toHaveBeenCalledWith(streamId, customEvent); + }); + + it('handles transport throwing without crashing the caller', async () => { + manager.runtimeState.set(streamId, makeRuntime()); + manager.eventTransport.emitChunk.mockRejectedValueOnce(new Error('transport error')); + + await expect(manager.emitTransientEvent(streamId, progressEvent)).rejects.toThrow( + 'transport error', + ); + }); + + it('does not emit after abort even if subscriber flag is still true', async () => { + const runtime = makeRuntime({ hasSubscriber: true }); + manager.runtimeState.set(streamId, runtime); + + // Abort happens between registration and emit + runtime.abortController.abort(); + + await manager.emitTransientEvent(streamId, progressEvent); + + expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/api/src/utils/events.ts b/packages/api/src/utils/events.ts index e084e631f5..2a1bc002d6 100644 --- a/packages/api/src/utils/events.ts +++ b/packages/api/src/utils/events.ts @@ -1,6 +1,25 @@ import type { Response as ServerResponse } from 'express'; import type { ServerSentEvent } from '~/types'; +/** + * Safely writes to a server response, handling disconnected clients. + * @param res - The server response. + * @param data - The data to write. + * @returns true if write succeeded, false otherwise. + */ +function safeWrite(res: ServerResponse, data: string): boolean { + try { + if (!res.writable) { + return false; + } + res.write(data); + return true; + } catch { + // Client may have disconnected - log but don't crash + return false; + } +} + /** * Sends a Server-Sent Event to the client. * Empty-string StreamEvent data is silently dropped. @@ -9,7 +28,7 @@ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void { if ('data' in event && typeof event.data === 'string' && event.data.length === 0) { return; } - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); + safeWrite(res, `event: message\ndata: ${JSON.stringify(event)}\n\n`); } /** @@ -18,6 +37,29 @@ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void { * @param message - The error message. */ export function handleError(res: ServerResponse, message: string): void { - res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); - res.end(); + safeWrite(res, `event: error\ndata: ${JSON.stringify(message)}\n\n`); + try { + res.end(); + } catch { + // Client may have disconnected + } +} + +/** + * Sends progress notification in Server Sent Events format. + * @param res - The server response. + * @param progressData - Progress notification data. + */ +export function sendProgress( + res: ServerResponse, + progressData: { + progressToken: string | number; + progress: number; + total?: number; + message?: string; + serverName?: string; + toolCallId?: string; // Tool call ID for matching progress to specific tool call + }, +): void { + safeWrite(res, `event: progress\ndata: ${JSON.stringify(progressData)}\n\n`); }