This commit is contained in:
Artyom Bogachenko 2026-04-04 06:50:00 +02:00 committed by GitHub
commit 16f03a7593
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1311 additions and 57 deletions

View file

@ -109,11 +109,24 @@ router.get('/chat/stream/:streamId', async (req, res) => {
}
};
const onChunk = (event) => {
if (!res.writableEnded) {
if (event.event === 'progress') {
res.write(`event: progress\ndata: ${JSON.stringify(event.data)}\n\n`);
} else {
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
}
if (typeof res.flush === 'function') {
res.flush();
}
}
};
let result;
if (isResume) {
const { subscription, resumeState, pendingEvents } =
await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError);
await GenerationJobManager.subscribeWithResume(streamId, onChunk, onDone, onError);
if (!res.writableEnded) {
if (resumeState) {
@ -139,7 +152,7 @@ router.get('/chat/stream/:streamId', async (req, res) => {
result = subscription;
} else {
result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError);
result = await GenerationJobManager.subscribe(streamId, onChunk, onDone, onError);
}
if (!result) {

View file

@ -8,6 +8,7 @@ const {
} = require('@librechat/agents');
const {
sendEvent,
sendProgress,
MCPOAuthHandler,
isMCPDomainAllowed,
normalizeServerName,
@ -649,6 +650,30 @@ function createToolInstance({
oauthStart,
oauthEnd,
graphTokenResolver: getGraphApiToken,
onProgress: async (progressData) => {
logger.debug(
`[MCP][${serverName}][${toolName}] Sending progress to client:`,
progressData,
);
const eventData = {
progress: progressData.progress,
total: progressData.total,
message: progressData.message,
toolCallId: toolCall.id,
};
try {
if (streamId) {
await GenerationJobManager.emitTransientEvent(streamId, {
event: 'progress',
data: eventData,
});
} else {
sendProgress(res, eventData);
}
} catch (err) {
logger.error(`[MCP][${serverName}][${toolName}] Failed to emit progress:`, err);
}
},
});
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {

View file

@ -180,6 +180,7 @@ const Part = memo(function Part({
attachments={attachments}
auth={toolCall.auth}
isLast={isLast}
toolCallId={toolCall.id}
/>
);
} else if (toolCall.type === ToolCallTypes.CODE_INTERPRETER) {

View file

@ -2,6 +2,7 @@ import { useMemo, useState, useEffect, useCallback } from 'react';
import { useRecoilValue } from 'recoil';
import { Button } from '@librechat/client';
import { TriangleAlert } from 'lucide-react';
import { useAtomValue, useSetAtom } from 'jotai';
import {
Constants,
dataService,
@ -17,6 +18,32 @@ import ToolCallInfo from './ToolCallInfo';
import ProgressText from './ProgressText';
import { logger } from '~/utils';
import store from '~/store';
import {
toolCallProgressFamily,
clearToolCallProgressAtom,
type ProgressState,
} from '~/store/progress';
/**
* Gets the in-progress text to display for a tool call.
* Prioritizes MCP progress message, then progress/total, then default localized text.
*/
function getInProgressText(
mcpProgress: ProgressState | undefined,
functionName: string,
localize: ReturnType<typeof useLocalize>,
): string {
if (mcpProgress?.message) {
return mcpProgress.message;
}
if (mcpProgress?.total) {
return `${functionName}: ${mcpProgress.progress}/${mcpProgress.total}`;
}
if (functionName) {
return localize('com_assistants_running_var', { 0: functionName });
}
return localize('com_assistants_running_action');
}
export default function ToolCall({
initialProgress = 0.1,
@ -27,6 +54,7 @@ export default function ToolCall({
output,
attachments,
auth,
toolCallId,
}: {
initialProgress: number;
isLast?: boolean;
@ -36,6 +64,7 @@ export default function ToolCall({
output?: string | null;
attachments?: TAttachment[];
auth?: string;
toolCallId?: string;
}) {
const localize = useLocalize();
const autoExpand = useRecoilValue(store.autoExpandTools);
@ -130,10 +159,6 @@ export default function ToolCall({
window.open(auth, '_blank', 'noopener,noreferrer');
}, [auth, isMCPToolCall, mcpServerName, actionId]);
const hasError = typeof output === 'string' && isError(output);
const cancelled = !isSubmitting && initialProgress < 1 && !hasError;
const errorState = hasError;
const args = useMemo(() => {
if (typeof _args === 'string') {
return _args;
@ -158,8 +183,32 @@ export default function ToolCall({
return parsedAuthUrl?.hostname ?? '';
}, [parsedAuthUrl]);
const progress = useProgress(initialProgress);
const showCancelled = cancelled || (errorState && !output);
// Get simulated progress
const simulatedProgress = useProgress(initialProgress);
// Get real-time progress from MCP server by tool call ID
const mcpProgress = useAtomValue(toolCallProgressFamily(toolCallId ?? ''));
const clearProgress = useSetAtom(clearToolCallProgressAtom);
// Clean up progress data when tool completes
useEffect(() => {
if (hasOutput && toolCallId) {
clearProgress(toolCallId);
}
}, [hasOutput, toolCallId, clearProgress]);
// If tool has output, it's completed (progress = 1), otherwise use simulated progress
const progress = useMemo(() => {
if (hasOutput) {
return 1;
}
return simulatedProgress;
}, [hasOutput, simulatedProgress]);
const hasError = typeof output === 'string' && isError(output);
const errorState = hasError;
const cancelled = (!isSubmitting && progress < 1 && !hasOutput) || errorState;
const showCancelled = (cancelled && !hasOutput) || (errorState && !output);
const subtitle = useMemo(() => {
if (isMCPToolCall && mcpServerName) {
@ -191,24 +240,15 @@ export default function ToolCall({
return (
<>
<span className="sr-only" aria-live="polite" aria-atomic="true">
{(() => {
if (progress < 1 && !showCancelled) {
return function_name
? localize('com_assistants_running_var', { 0: function_name })
: localize('com_assistants_running_action');
}
return getFinishedText();
})()}
{progress < 1 && !showCancelled
? getInProgressText(mcpProgress, function_name, localize)
: getFinishedText()}
</span>
<div className="relative my-1.5 flex h-5 shrink-0 items-center gap-2.5">
<ProgressText
progress={progress}
onClick={() => setShowInfo((prev) => !prev)}
inProgressText={
function_name
? localize('com_assistants_running_var', { 0: function_name })
: localize('com_assistants_running_action')
}
inProgressText={getInProgressText(mcpProgress, function_name, localize)}
authText={
!showCancelled && authDomain.length > 0 ? localize('com_ui_requires_auth') : undefined
}

View file

@ -1,7 +1,7 @@
import React from 'react';
import { RecoilRoot } from 'recoil';
import { Tools, Constants } from 'librechat-data-provider';
import { render, screen, fireEvent } from '@testing-library/react';
import { Provider } from 'jotai';
import { Tools, Constants } from 'librechat-data-provider';
import ToolCall from '../ToolCall';
// Mock dependencies
@ -57,15 +57,19 @@ jest.mock('../ProgressText', () => ({
onClick,
inProgressText,
finishedText,
error,
progress,
subtitle,
}: {
onClick?: () => void;
inProgressText?: string;
finishedText?: string;
error?: string;
progress: number;
subtitle?: string;
}) => (
<div data-testid="progress-text" onClick={onClick}>
{finishedText || inProgressText}
{error || progress >= 1 ? finishedText : inProgressText}
{subtitle && <span data-testid="subtitle">{subtitle}</span>}
</div>
),
@ -98,6 +102,34 @@ jest.mock('~/utils', () => ({
cn: (...classes: any[]) => classes.filter(Boolean).join(' '),
}));
const mockUseAtomValue = jest.fn().mockReturnValue(undefined);
const mockClearProgress = jest.fn();
jest.mock('jotai', () => ({
...jest.requireActual('jotai'),
useAtomValue: (...args: any[]) => mockUseAtomValue(...args),
useSetAtom: () => mockClearProgress,
}));
const DUMMY_ATOM = { toString: () => 'dummy-atom' };
jest.mock('~/store/progress', () => ({
toolCallProgressFamily: jest.fn().mockReturnValue(DUMMY_ATOM),
clearToolCallProgressAtom: {},
}));
jest.mock('recoil', () => ({
...jest.requireActual('recoil'),
useRecoilValue: jest.fn().mockReturnValue(false),
}));
jest.mock('~/store', () => ({
__esModule: true,
default: {
autoExpandTools: 'autoExpandTools',
},
}));
describe('ToolCall', () => {
const mockProps = {
args: '{"test": "input"}',
@ -107,12 +139,14 @@ describe('ToolCall', () => {
isSubmitting: false,
};
const renderWithRecoil = (component: React.ReactElement) => {
return render(<RecoilRoot>{component}</RecoilRoot>);
const renderWithJotai = (component: React.ReactElement) => {
return render(<Provider>{component}</Provider>);
};
beforeEach(() => {
jest.clearAllMocks();
mockUseAtomValue.mockReturnValue(undefined);
mockClearProgress.mockClear();
});
describe('attachments prop passing', () => {
@ -129,7 +163,7 @@ describe('ToolCall', () => {
},
];
renderWithRecoil(<ToolCall {...mockProps} attachments={attachments as any} />);
renderWithJotai(<ToolCall {...mockProps} attachments={attachments} />);
fireEvent.click(screen.getByTestId('progress-text'));
@ -141,7 +175,7 @@ describe('ToolCall', () => {
});
it('should pass empty array when no attachments', () => {
renderWithRecoil(<ToolCall {...mockProps} />);
renderWithJotai(<ToolCall {...mockProps} />);
fireEvent.click(screen.getByTestId('progress-text'));
@ -172,7 +206,7 @@ describe('ToolCall', () => {
},
];
renderWithRecoil(<ToolCall {...mockProps} attachments={attachments as any} />);
renderWithJotai(<ToolCall {...mockProps} attachments={attachments} />);
fireEvent.click(screen.getByTestId('progress-text'));
@ -196,7 +230,7 @@ describe('ToolCall', () => {
},
];
renderWithRecoil(<ToolCall {...mockProps} attachments={attachments as any} />);
renderWithJotai(<ToolCall {...mockProps} attachments={attachments} />);
const attachmentGroup = screen.getByTestId('attachment-group');
expect(attachmentGroup).toBeInTheDocument();
@ -204,13 +238,13 @@ describe('ToolCall', () => {
});
it('should not render AttachmentGroup when no attachments', () => {
renderWithRecoil(<ToolCall {...mockProps} />);
renderWithJotai(<ToolCall {...mockProps} />);
expect(screen.queryByTestId('attachment-group')).not.toBeInTheDocument();
});
it('should not render AttachmentGroup when attachments is empty array', () => {
renderWithRecoil(<ToolCall {...mockProps} attachments={[]} />);
renderWithJotai(<ToolCall {...mockProps} attachments={[]} />);
expect(screen.queryByTestId('attachment-group')).not.toBeInTheDocument();
});
@ -218,7 +252,7 @@ describe('ToolCall', () => {
describe('tool call info visibility', () => {
it('should toggle tool call info expand/collapse when clicking header', () => {
renderWithRecoil(<ToolCall {...mockProps} />);
renderWithJotai(<ToolCall {...mockProps} />);
// ToolCallInfo is always in the DOM (CSS expand/collapse), but initially collapsed
const toolCallInfo = screen.getByTestId('tool-call-info');
@ -233,8 +267,29 @@ describe('ToolCall', () => {
expect(screen.getByTestId('tool-call-info')).toBeInTheDocument();
});
it('should pass input and output props to ToolCallInfo', () => {
renderWithRecoil(<ToolCall {...mockProps} />);
it('should pass all required props to ToolCallInfo', () => {
const attachments = [
{
type: Tools.ui_resources,
messageId: 'msg123',
toolCallId: 'tool456',
conversationId: 'conv789',
[Tools.ui_resources]: {
'0': { type: 'button', label: 'Test' },
},
},
];
// Use a name with domain separator (_action_) and domain separator (---)
const propsWithDomain = {
...mockProps,
name: 'testFunction_action_test---domain---com',
attachments,
};
renderWithJotai(<ToolCall {...propsWithDomain} />);
fireEvent.click(screen.getByTestId('progress-text'));
const toolCallInfo = screen.getByTestId('tool-call-info');
const props = JSON.parse(toolCallInfo.textContent!);
@ -249,9 +304,10 @@ describe('ToolCall', () => {
const originalOpen = window.open;
window.open = jest.fn();
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
output={undefined}
initialProgress={0.5} // Less than 1 so it's not complete
auth="https://auth.example.com"
isSubmitting={true} // Should be submitting for auth to show
@ -272,7 +328,7 @@ describe('ToolCall', () => {
});
it('should not show auth section when cancelled', () => {
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
auth="https://auth.example.com"
@ -285,7 +341,7 @@ describe('ToolCall', () => {
});
it('should not show auth section when progress is complete', () => {
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
auth="https://auth.example.com"
@ -300,7 +356,9 @@ describe('ToolCall', () => {
describe('edge cases', () => {
it('should handle undefined args', () => {
renderWithRecoil(<ToolCall {...mockProps} args={undefined as any} />);
renderWithJotai(<ToolCall {...mockProps} args={undefined} />);
fireEvent.click(screen.getByTestId('progress-text'));
const toolCallInfo = screen.getByTestId('tool-call-info');
const props = JSON.parse(toolCallInfo.textContent!);
@ -308,15 +366,17 @@ describe('ToolCall', () => {
});
it('should handle null output', () => {
renderWithRecoil(<ToolCall {...mockProps} output={null} />);
renderWithJotai(<ToolCall {...mockProps} output={null} />);
const toolCallInfo = screen.getByTestId('tool-call-info');
const props = JSON.parse(toolCallInfo.textContent!);
expect(props.output).toBeNull();
});
it('should handle simple function name without domain', () => {
renderWithRecoil(<ToolCall {...mockProps} name="simpleName" />);
it('should handle missing domain', () => {
renderWithJotai(<ToolCall {...mockProps} domain={undefined} authDomain={undefined} />);
fireEvent.click(screen.getByTestId('progress-text'));
const toolCallInfo = screen.getByTestId('tool-call-info');
expect(toolCallInfo).toBeInTheDocument();
@ -344,7 +404,7 @@ describe('ToolCall', () => {
},
];
renderWithRecoil(<ToolCall {...mockProps} attachments={complexAttachments as any} />);
renderWithJotai(<ToolCall {...mockProps} attachments={complexAttachments} />);
fireEvent.click(screen.getByTestId('progress-text'));
@ -361,7 +421,7 @@ describe('ToolCall', () => {
const d = Constants.mcp_delimiter;
it('should detect MCP OAuth from delimiter in tool-call name', () => {
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name={`oauth${d}my-server`}
@ -375,7 +435,7 @@ describe('ToolCall', () => {
});
it('should preserve full server name when it contains the delimiter substring', () => {
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name={`oauth${d}foo${d}bar`}
@ -389,7 +449,7 @@ describe('ToolCall', () => {
});
it('should display server name (not "oauth") as function_name for OAuth tool calls', () => {
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name={`oauth${d}my-server`}
@ -407,7 +467,7 @@ describe('ToolCall', () => {
it('should display server name even when auth is cleared (post-completion)', () => {
// After OAuth completes, createOAuthEnd re-emits the toolCall without auth.
// The display should still show the server name, not literal "oauth".
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name={`oauth${d}my-server`}
@ -425,7 +485,7 @@ describe('ToolCall', () => {
const authUrl =
'https://oauth.example.com/authorize?redirect_uri=' +
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name="bare_name"
@ -442,7 +502,7 @@ describe('ToolCall', () => {
const authUrl =
'https://oauth.example.com/authorize?redirect_uri=' +
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name="bare_name"
@ -462,7 +522,7 @@ describe('ToolCall', () => {
// gets prefixed to oauth_mcp_oauth_mcp_server. Client parses:
// func="oauth", server="oauth_mcp_server". Visually awkward but
// semantically correct — the normalized name IS oauth_mcp_server.
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name={`oauth${d}oauth${d}server`}
@ -479,7 +539,7 @@ describe('ToolCall', () => {
const authUrl =
'https://oauth.example.com/authorize?redirect_uri=' +
encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback');
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
name="action_name"
@ -494,7 +554,7 @@ describe('ToolCall', () => {
describe('A11Y-04: screen reader status announcements', () => {
it('includes sr-only aria-live region for status announcements', () => {
renderWithRecoil(
renderWithJotai(
<ToolCall
{...mockProps}
initialProgress={1}
@ -509,4 +569,143 @@ describe('ToolCall', () => {
expect(liveRegion!.className).toContain('sr-only');
});
});
describe('getInProgressText - MCP progress display', () => {
it('shows mcpProgress.message when available', () => {
mockUseAtomValue.mockReturnValue({
progress: 2,
total: 10,
message: 'Fetching data from API...',
timestamp: Date.now(),
});
renderWithJotai(
<ToolCall
{...mockProps}
output={undefined}
initialProgress={0.1}
isSubmitting={true}
toolCallId="call-123"
/>,
);
expect(screen.getByTestId('progress-text')).toHaveTextContent('Fetching data from API...');
});
it('shows "functionName: X/Y" when mcpProgress has total but no message', () => {
mockUseAtomValue.mockReturnValue({
progress: 3,
total: 10,
timestamp: Date.now(),
});
renderWithJotai(
<ToolCall
{...mockProps}
output={undefined}
initialProgress={0.1}
isSubmitting={true}
toolCallId="call-123"
/>,
);
expect(screen.getByTestId('progress-text')).toHaveTextContent('testFunction: 3/10');
});
it('falls back to running_var localisation when no mcpProgress', () => {
mockUseAtomValue.mockReturnValue(undefined);
renderWithJotai(
<ToolCall {...mockProps} output={undefined} initialProgress={0.1} isSubmitting={true} />,
);
expect(screen.getByTestId('progress-text')).toHaveTextContent('Running testFunction');
});
it('prefers message over progress/total when both are present', () => {
mockUseAtomValue.mockReturnValue({
progress: 5,
total: 10,
message: 'Custom status message',
timestamp: Date.now(),
});
renderWithJotai(
<ToolCall
{...mockProps}
output={undefined}
initialProgress={0.1}
isSubmitting={true}
toolCallId="call-123"
/>,
);
expect(screen.getByTestId('progress-text')).toHaveTextContent('Custom status message');
expect(screen.getByTestId('progress-text')).not.toHaveTextContent('testFunction: 5/10');
});
});
describe('toolCallId prop and progress atom integration', () => {
it('passes toolCallId to toolCallProgressFamily when provided', () => {
const { toolCallProgressFamily } = jest.requireMock('~/store/progress');
renderWithJotai(<ToolCall {...mockProps} toolCallId="specific-call-id" />);
expect(toolCallProgressFamily).toHaveBeenCalledWith('specific-call-id');
});
it('passes empty string to toolCallProgressFamily when toolCallId is undefined', () => {
const { toolCallProgressFamily } = jest.requireMock('~/store/progress');
renderWithJotai(<ToolCall {...mockProps} />);
expect(toolCallProgressFamily).toHaveBeenCalledWith('');
});
it('calls clearProgress with toolCallId when output arrives', () => {
mockUseAtomValue.mockReturnValue(mockClearProgress);
renderWithJotai(
<ToolCall {...mockProps} output="Tool completed" toolCallId="call-to-clear" />,
);
expect(mockClearProgress).toHaveBeenCalledWith('call-to-clear');
});
it('does not call clearProgress when toolCallId is undefined', () => {
mockUseAtomValue.mockReturnValue(mockClearProgress);
renderWithJotai(<ToolCall {...mockProps} output="Tool completed" />);
expect(mockClearProgress).not.toHaveBeenCalled();
});
});
describe('cancelled state with hasOutput', () => {
it('is not cancelled when output exists even with low progress', () => {
renderWithJotai(
<ToolCall {...mockProps} output="Result" initialProgress={0.1} isSubmitting={false} />,
);
// When not cancelled and has output → shows finished text, not "Cancelled"
expect(screen.queryByTestId('progress-text')).not.toHaveTextContent('Cancelled');
expect(screen.queryByTestId('progress-text')).toHaveTextContent('Completed testFunction');
});
it('is cancelled when no output and not submitting and progress < 1', () => {
renderWithJotai(
<ToolCall {...mockProps} output={undefined} initialProgress={0.1} isSubmitting={false} />,
);
expect(screen.getByTestId('progress-text')).toHaveTextContent('Cancelled');
});
it('shows finished text when progress is 1 and output is present', () => {
renderWithJotai(
<ToolCall {...mockProps} output="Done" initialProgress={1} isSubmitting={false} />,
);
expect(screen.getByTestId('progress-text')).toHaveTextContent('Completed testFunction');
});
});
});

View file

@ -75,6 +75,15 @@ jest.mock('~/data-provider', () => ({
const mockErrorHandler = jest.fn();
const mockSetIsSubmitting = jest.fn();
const mockClearStepMaps = jest.fn();
const mockHandleProgressEvent = jest.fn();
const mockCleanupProgress = jest.fn();
jest.mock('~/hooks/SSE/useProgressTracking', () => ({
useProgressTracking: () => ({
handleProgressEvent: mockHandleProgressEvent,
cleanupProgress: mockCleanupProgress,
}),
}));
jest.mock('~/hooks/SSE/useEventHandlers', () =>
jest.fn(() => ({
@ -282,3 +291,81 @@ describe('useResumableSSE - 404 error path', () => {
},
);
});
describe('useResumableSSE - progress event integration', () => {
beforeEach(() => {
mockSSEInstances.length = 0;
localStorage.clear();
mockHandleProgressEvent.mockClear();
mockCleanupProgress.mockClear();
});
const renderAndInit = async (conversationId = CONV_ID) => {
const submission = buildSubmission({ conversation: { conversationId } });
const chatHelpers = buildChatHelpers();
const { unmount } = renderHook(() => useResumableSSE(submission, chatHelpers));
await act(async () => {
await Promise.resolve();
});
return { sse: getLastSSE(), unmount, chatHelpers };
};
it('registers a "progress" event listener on the SSE connection', async () => {
const { unmount } = await renderAndInit();
const sse = getLastSSE();
const registeredEvents = sse.addEventListener.mock.calls.map(([event]: [string]) => event);
expect(registeredEvents).toContain('progress');
unmount();
});
it('routes incoming progress events to handleProgressEvent', async () => {
const { unmount } = await renderAndInit();
const sse = getLastSSE();
const progressData = JSON.stringify({
progress: 3,
total: 10,
message: 'Working...',
toolCallId: 'call-abc',
});
await act(async () => {
sse._emit('progress', { data: progressData });
});
expect(mockHandleProgressEvent).toHaveBeenCalledTimes(1);
expect(mockHandleProgressEvent).toHaveBeenCalledWith(
expect.objectContaining({ data: progressData }),
);
unmount();
});
it('calls cleanupProgress on unmount', async () => {
const { unmount } = await renderAndInit();
unmount();
expect(mockCleanupProgress).toHaveBeenCalledTimes(1);
});
it('does not call cleanupProgress before unmount', async () => {
const { unmount } = await renderAndInit();
expect(mockCleanupProgress).not.toHaveBeenCalled();
unmount();
});
it('calls cleanupProgress on unmount even after a 404 error', async () => {
const { sse, unmount } = await renderAndInit();
await act(async () => {
sse._emit('error', { responseCode: 404 });
});
unmount();
expect(mockCleanupProgress).toHaveBeenCalled();
});
});

View file

@ -5,3 +5,5 @@ export { default as useResumeOnLoad } from './useResumeOnLoad';
export { default as useStepHandler } from './useStepHandler';
export { default as useContentHandler } from './useContentHandler';
export { default as useAttachmentHandler } from './useAttachmentHandler';
export { useProgressTracking } from './useProgressTracking';
export type { ProgressState } from '~/store/progress';

View file

@ -0,0 +1,48 @@
import { useRef, useCallback } from 'react';
import { useSetAtom } from 'jotai';
import { toolCallProgressMapAtom } from '~/store/progress';
export function useProgressTracking() {
const setToolCallProgressMap = useSetAtom(toolCallProgressMapAtom);
const progressCleanupTimers = useRef<Map<string, NodeJS.Timeout>>(new Map());
const handleProgressEvent = useCallback(
(e: MessageEvent) => {
try {
const data = JSON.parse(e.data);
const { progress, total, message, toolCallId } = data;
if (toolCallId != null) {
setToolCallProgressMap((currentMap) => {
const newMap = new Map(currentMap);
newMap.set(toolCallId, { progress, total, message, timestamp: Date.now() });
return newMap;
});
if (total && progress >= total) {
const existingTimer = progressCleanupTimers.current.get(toolCallId);
if (existingTimer) clearTimeout(existingTimer);
const timerId = setTimeout(() => {
setToolCallProgressMap((currentMap) => {
const newMap = new Map(currentMap);
newMap.delete(toolCallId);
return newMap;
});
progressCleanupTimers.current.delete(toolCallId);
}, 5000);
progressCleanupTimers.current.set(toolCallId, timerId);
}
}
} catch (error) {
console.error('Error parsing progress event:', error);
}
},
[setToolCallProgressMap],
);
const cleanupProgress = useCallback(() => {
setToolCallProgressMap(new Map());
progressCleanupTimers.current.forEach(clearTimeout);
progressCleanupTimers.current.clear();
}, [setToolCallProgressMap]);
return { handleProgressEvent, cleanupProgress };
}

View file

@ -25,6 +25,7 @@ import {
import type { ActiveJobsResponse } from '~/data-provider';
import { useAuthContext } from '~/hooks/AuthContext';
import useEventHandlers from './useEventHandlers';
import { useProgressTracking } from './useProgressTracking';
import { clearAllDrafts } from '~/utils';
import store from '~/store';
@ -132,6 +133,7 @@ export default function useResumableSSE(
const balanceQuery = useGetUserBalance({
enabled: !!isAuthenticated && startupConfig?.balance?.enabled,
});
const { handleProgressEvent, cleanupProgress } = useProgressTracking();
/**
* Subscribe to stream via SSE library (supports custom headers)
@ -335,6 +337,8 @@ export default function useResumableSSE(
}
});
sse.addEventListener('progress', handleProgressEvent);
/**
* Error event handler - handles BOTH:
* 1. HTTP-level errors (responseCode present) - 404, 401, network failures
@ -552,6 +556,7 @@ export default function useResumableSSE(
balanceQuery,
removeActiveJob,
queryClient,
handleProgressEvent,
],
);
@ -703,6 +708,8 @@ export default function useResumableSSE(
// Reset UI state on cleanup - useResumeOnLoad will restore if needed
setIsSubmitting(false);
setShowStopButton(false);
// Clear progress map and pending cleanup timers on unmount
cleanupProgress();
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [submission]);

View file

@ -10,6 +10,7 @@ import { useGetStartupConfig, useGetUserBalance } from '~/data-provider';
import { useAuthContext } from '~/hooks/AuthContext';
import useEventHandlers from './useEventHandlers';
import { clearAllDrafts } from '~/utils';
import { useProgressTracking } from './useProgressTracking';
import store from '~/store';
type ChatHelpers = Pick<
@ -71,6 +72,7 @@ export default function useSSE(
const balanceQuery = useGetUserBalance({
enabled: !!isAuthenticated && startupConfig?.balance?.enabled,
});
const { handleProgressEvent, cleanupProgress } = useProgressTracking();
useEffect(() => {
if (submission == null || Object.keys(submission).length === 0) {
@ -100,6 +102,8 @@ export default function useSSE(
}
});
sse.addEventListener('progress', handleProgressEvent);
sse.addEventListener('message', (e: MessageEvent) => {
const data = JSON.parse(e.data);
@ -234,6 +238,8 @@ export default function useSSE(
return () => {
const isCancelled = sse.readyState <= 1;
sse.close();
// Clear progress map and pending cleanup timers on unmount
cleanupProgress();
if (isCancelled) {
const e = new Event('cancel');
/* @ts-ignore */

View file

@ -0,0 +1,30 @@
import { atom } from 'jotai';
import { atomFamily } from 'jotai/utils';
export type ProgressState = {
progress: number;
total?: number;
message?: string;
timestamp: number;
};
// Map of toolCallId -> ProgressState (for matching progress to specific tool calls)
export const toolCallProgressMapAtom = atom<Map<string, ProgressState>>(new Map());
// Family for tool call based progress lookup
export const toolCallProgressFamily = atomFamily((toolCallId: string) =>
atom((get) => {
// Don't return data for empty string key
if (!toolCallId) return undefined;
return get(toolCallProgressMapAtom).get(toolCallId);
}),
);
// Cleanup action - remove progress entry for a specific tool call
export const clearToolCallProgressAtom = atom(null, (get, set, toolCallId: string) => {
if (!toolCallId) return;
const current = get(toolCallProgressMapAtom);
const newMap = new Map(current);
newMap.delete(toolCallId);
set(toolCallProgressMapAtom, newMap);
});

View file

@ -272,6 +272,7 @@ Please follow these instructions when using tools from the respective MCP server
oauthEnd,
customUserVars,
graphTokenResolver,
onProgress,
}: {
user?: IUser;
serverName: string;
@ -288,9 +289,19 @@ Please follow these instructions when using tools from the respective MCP server
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
graphTokenResolver?: GraphTokenResolver;
onProgress?: (progressData: {
progressToken: t.ProgressToken;
progress: number;
total?: number;
message?: string;
serverName: string;
}) => void;
}): Promise<t.FormattedToolResponse> {
/** User-specific connection */
let connection: MCPConnection | undefined;
let progressHandler: ((data: t.ProgressNotification & { serverName: string }) => void) | null =
null;
let progressToken: t.ProgressToken | undefined;
const userId = user?.id;
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
@ -348,12 +359,44 @@ Please follow these instructions when using tools from the respective MCP server
connection.setRequestHeaders(currentOptions.headers || {});
}
// Generate and register progress token BEFORE setting up listener
// to avoid race condition where progress arrives before token is registered
progressToken = connection.generateProgressToken();
connection.registerProgressToken(progressToken);
logger.info(
`${logPrefix}[${toolName}] Progress token generated and registered: ${progressToken}`,
);
// Set up progress event listener if callback is provided
if (onProgress) {
progressHandler = (data) => {
logger.info(`${logPrefix}[${toolName}] Progress event received:`, {
progressToken: data.progressToken,
progress: data.progress,
total: data.total,
message: data.message,
});
onProgress({
progressToken: data.progressToken,
progress: data.progress,
total: data.total,
message: data.message,
serverName: data.serverName,
});
};
connection.on('progress', progressHandler);
logger.info(`${logPrefix}[${toolName}] Progress listener registered`);
}
const result = await connection.client.request(
{
method: 'tools/call',
params: {
name: toolName,
arguments: toolArguments,
_meta: {
progressToken,
},
},
},
CallToolResultSchema,
@ -367,12 +410,30 @@ Please follow these instructions when using tools from the respective MCP server
this.updateUserLastActivity(userId);
}
this.checkIdleConnections();
return formatToolContent(result as t.MCPToolCallResponse, provider);
// Format and return the tool response as a proper tuple [content, artifacts]
// Progress is handled separately via SSE events emitted by the connection
const formattedResponse = formatToolContent(result as t.MCPToolCallResponse, provider);
return formattedResponse;
} catch (error) {
// Log with context and re-throw or handle as needed
logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
// Rethrowing allows the caller (createMCPTool) to handle the final user message
throw error;
} finally {
if (connection && (progressHandler || progressToken)) {
setTimeout(() => {
// Clean up progress listener to prevent memory leaks
if (progressHandler) {
connection?.off('progress', progressHandler);
}
// Clean up progress token to prevent memory leak
if (progressToken) {
connection?.unregisterProgressToken(progressToken);
}
}, 500);
}
}
}
}

View file

@ -8,6 +8,7 @@ import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPConnection } from '~/mcp/connection';
import { MCPManager } from '~/mcp/MCPManager';
import { EventEmitter } from 'events';
import * as graphUtils from '~/utils/graph';
// Mock external dependencies
@ -429,6 +430,10 @@ describe('MCPManager', () => {
isError: false,
}),
},
generateProgressToken: jest.fn().mockReturnValue('mock-progress-token'),
registerProgressToken: jest.fn(),
subscribeToProgress: jest.fn().mockReturnValue(() => {}),
unregisterProgressToken: jest.fn(),
} as unknown as MCPConnection;
const mockGraphTokenResolver: GraphTokenResolver = jest.fn().mockResolvedValue({
@ -966,4 +971,396 @@ describe('MCPManager', () => {
);
});
});
describe('callTool Progress Integration', () => {
const serverName = 'test_server';
const mockUser: Partial<IUser> = {
id: 'user-123',
provider: 'openid',
openidId: 'oidc-sub-456',
};
const mockFlowManager = {
getState: jest.fn(),
setState: jest.fn(),
clearState: jest.fn(),
};
function buildMockConnection(overrides: Partial<Record<string, unknown>> = {}) {
const emitter = new EventEmitter();
return {
isConnected: jest.fn().mockResolvedValue(true),
setRequestHeaders: jest.fn(),
timeout: 30000,
on: jest.fn((event, handler) => emitter.on(event, handler)),
off: jest.fn((event, handler) => emitter.off(event, handler)),
emit: (event: string, data: unknown) => emitter.emit(event, data),
client: {
request: jest.fn().mockResolvedValue({
content: [{ type: 'text', text: 'Tool result' }],
isError: false,
}),
},
generateProgressToken: jest.fn().mockReturnValue('mock-progress-token'),
registerProgressToken: jest.fn(),
unregisterProgressToken: jest.fn(),
...overrides,
} as unknown as MCPConnection;
}
function mockAppConnections(connection: MCPConnection) {
(ConnectionsRepository as jest.MockedClass<typeof ConnectionsRepository>).mockImplementation(
() =>
({
has: jest.fn().mockResolvedValue(false),
get: jest.fn().mockResolvedValue(connection),
}) as unknown as ConnectionsRepository,
);
}
function newMCPServersConfig(): t.MCPServers {
return { [serverName]: { type: 'stdio', command: 'test', args: [] } };
}
beforeEach(() => {
(MCPManager as unknown as { instance: null }).instance = null;
jest.clearAllMocks();
(MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined);
(mockRegistryInstance.getAllServerConfigs as jest.Mock).mockResolvedValue({});
(graphUtils.preProcessGraphTokens as jest.Mock).mockImplementation(
async (options) => options,
);
mockRegistryInstance.getServerConfig.mockResolvedValue({
type: 'sse',
url: 'https://api.example.com',
});
});
describe('progress token lifecycle', () => {
it('generates and registers a progress token before the tool call', async () => {
const connection = buildMockConnection();
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
});
expect(connection.generateProgressToken).toHaveBeenCalledTimes(1);
expect(connection.registerProgressToken).toHaveBeenCalledWith('mock-progress-token');
});
it('passes _meta.progressToken in the MCP request params', async () => {
const connection = buildMockConnection();
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
});
expect(connection.client.request).toHaveBeenCalledWith(
expect.objectContaining({
params: expect.objectContaining({
_meta: { progressToken: 'mock-progress-token' },
}),
}),
expect.anything(),
expect.anything(),
);
});
it('schedules unregisterProgressToken cleanup in finally block after success', async () => {
jest.useFakeTimers();
const connection = buildMockConnection();
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
});
expect(connection.unregisterProgressToken).not.toHaveBeenCalled();
jest.advanceTimersByTime(600);
expect(connection.unregisterProgressToken).toHaveBeenCalledWith('mock-progress-token');
jest.useRealTimers();
});
it('schedules cleanup in finally block even when tool call throws', async () => {
jest.useFakeTimers();
const connection = buildMockConnection({
client: {
request: jest.fn().mockRejectedValue(new Error('MCP server error')),
},
});
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await expect(
manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
}),
).rejects.toThrow('MCP server error');
jest.advanceTimersByTime(600);
expect(connection.unregisterProgressToken).toHaveBeenCalledWith('mock-progress-token');
jest.useRealTimers();
});
});
describe('onProgress callback', () => {
it('does not register a progress listener when onProgress is not provided', async () => {
const connection = buildMockConnection();
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
});
expect(connection.on).not.toHaveBeenCalledWith('progress', expect.any(Function));
});
it('registers a progress listener when onProgress is provided', async () => {
const connection = buildMockConnection();
mockAppConnections(connection);
const onProgress = jest.fn();
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
onProgress,
});
expect(connection.on).toHaveBeenCalledWith('progress', expect.any(Function));
});
it('calls onProgress with the correct shape when a progress event fires', async () => {
const emitter = new EventEmitter();
const connection = buildMockConnection({
on: jest.fn((event, handler) => emitter.on(event, handler)),
off: jest.fn((event, handler) => emitter.off(event, handler)),
client: {
request: jest.fn().mockImplementation(async () => {
emitter.emit('progress', {
serverName,
progressToken: 'mock-progress-token',
progress: 3,
total: 10,
message: 'Processing...',
});
return { content: [{ type: 'text', text: 'done' }], isError: false };
}),
},
});
mockAppConnections(connection);
const onProgress = jest.fn();
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
onProgress,
});
expect(onProgress).toHaveBeenCalledTimes(1);
expect(onProgress).toHaveBeenCalledWith({
progressToken: 'mock-progress-token',
progress: 3,
total: 10,
message: 'Processing...',
serverName,
});
});
it('calls onProgress multiple times as progress events accumulate', async () => {
const emitter = new EventEmitter();
const connection = buildMockConnection({
on: jest.fn((event, handler) => emitter.on(event, handler)),
off: jest.fn((event, handler) => emitter.off(event, handler)),
client: {
request: jest.fn().mockImplementation(async () => {
for (let i = 1; i <= 3; i++) {
emitter.emit('progress', {
serverName,
progressToken: 'mock-progress-token',
progress: i,
total: 3,
});
}
return { content: [{ type: 'text', text: 'done' }], isError: false };
}),
},
});
mockAppConnections(connection);
const onProgress = jest.fn();
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
onProgress,
});
expect(onProgress).toHaveBeenCalledTimes(3);
});
it('detaches the progress listener after call completes', async () => {
jest.useFakeTimers();
const connection = buildMockConnection();
mockAppConnections(connection);
const onProgress = jest.fn();
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
onProgress,
});
jest.advanceTimersByTime(600);
expect(connection.off).toHaveBeenCalledWith('progress', expect.any(Function));
jest.useRealTimers();
});
it('detaches the progress listener after call throws', async () => {
jest.useFakeTimers();
const connection = buildMockConnection({
client: {
request: jest.fn().mockRejectedValue(new Error('fail')),
},
});
mockAppConnections(connection);
const onProgress = jest.fn();
const manager = await MCPManager.createInstance(newMCPServersConfig());
await expect(
manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
onProgress,
}),
).rejects.toThrow('fail');
jest.advanceTimersByTime(600);
expect(connection.off).toHaveBeenCalledWith('progress', expect.any(Function));
jest.useRealTimers();
});
it('logs each received progress event at info level', async () => {
const emitter = new EventEmitter();
const connection = buildMockConnection({
on: jest.fn((event, handler) => emitter.on(event, handler)),
off: jest.fn((event, handler) => emitter.off(event, handler)),
client: {
request: jest.fn().mockImplementation(async () => {
emitter.emit('progress', {
serverName,
progressToken: 'mock-progress-token',
progress: 1,
total: 5,
message: 'Step 1',
});
return { content: [{ type: 'text', text: 'done' }], isError: false };
}),
},
});
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
onProgress: jest.fn(),
});
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('Progress event received'),
expect.objectContaining({ progress: 1, total: 5, message: 'Step 1' }),
);
});
it('logs progress token registration at info level', async () => {
const connection = buildMockConnection();
mockAppConnections(connection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.callTool({
user: mockUser as IUser,
serverName,
toolName: 'test_tool',
provider: 'openai',
flowManager: mockFlowManager as unknown as Parameters<
typeof manager.callTool
>[0]['flowManager'],
});
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('Progress token generated and registered'),
);
});
});
});
});

View file

@ -8,7 +8,10 @@ import {
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import {
ResourceListChangedNotificationSchema,
ProgressNotificationSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type {
@ -368,6 +371,13 @@ export class MCPConnection extends EventEmitter {
}
}
// Progress notification state
private progressTokenCounter = 0;
private activeProgressTokens = new Map<t.ProgressToken, t.ProgressState>();
private lastProgressEmit = new Map<t.ProgressToken, number>();
private progressCleanupTimers = new Map<t.ProgressToken, NodeJS.Timeout>();
private readonly PROGRESS_THROTTLE_MS = 200; // Max 5 updates/second
setRequestHeaders(headers: Record<string, string> | null): void {
if (!headers) {
return;
@ -686,6 +696,7 @@ export class MCPConnection extends EventEmitter {
});
this.subscribeToResources();
this.subscribeToProgress();
}
private async handleReconnection(): Promise<void> {
@ -767,6 +778,125 @@ export class MCPConnection extends EventEmitter {
});
}
/**
* Generates a unique progress token for tracking tool call progress
*/
public generateProgressToken(): t.ProgressToken {
const userId = this.userId ? `${this.userId}-` : '';
return `${userId}${this.serverName}-${++this.progressTokenCounter}-${Date.now()}`;
}
/**
* Registers a progress token for tracking
*/
public registerProgressToken(token: t.ProgressToken): void {
this.activeProgressTokens.set(token, {
token,
progress: 0,
timestamp: Date.now(),
});
}
/**
* Gets the current progress state for a token
*/
public getProgressState(token: t.ProgressToken): t.ProgressState | undefined {
return this.activeProgressTokens.get(token);
}
/**
* Unregisters a progress token and cleans up associated state
*/
public unregisterProgressToken(token: t.ProgressToken): void {
this.activeProgressTokens.delete(token);
this.lastProgressEmit.delete(token);
// Cancel any pending cleanup timer for this token
const existingTimer = this.progressCleanupTimers.get(token);
if (existingTimer) {
clearTimeout(existingTimer);
this.progressCleanupTimers.delete(token);
}
}
/**
* Checks if progress should be emitted (rate limiting)
*/
private shouldEmitProgress(token: t.ProgressToken): boolean {
const lastEmit = this.lastProgressEmit.get(token) || 0;
const now = Date.now();
if (now - lastEmit < this.PROGRESS_THROTTLE_MS) {
return false;
}
this.lastProgressEmit.set(token, now);
return true;
}
/**
* Subscribes to progress notifications from MCP server
*/
private subscribeToProgress(): void {
try {
this.client.setNotificationHandler(ProgressNotificationSchema, async (notification) => {
try {
logger.info(
`${this.getLogPrefix()} Received progress notification:`,
notification.params,
);
const { progressToken, progress, total, message } = notification.params;
// Validate token
if (!this.activeProgressTokens.has(progressToken)) {
logger.debug(`${this.getLogPrefix()} Progress for unknown token: ${progressToken}`);
return;
}
// Rate limiting
if (!this.shouldEmitProgress(progressToken) && (!total || progress < total)) {
return;
}
// Update state
const newState: t.ProgressState = {
token: progressToken,
progress,
total,
message,
timestamp: Date.now(),
};
this.activeProgressTokens.set(progressToken, newState);
// Emit progress event
this.emit('progress', {
serverName: this.serverName,
progressToken,
progress,
total,
message,
});
// Cleanup if complete
if (total && progress >= total) {
// Clear existing timer if present to prevent duplicates
const existingTimer = this.progressCleanupTimers.get(progressToken);
if (existingTimer) {
clearTimeout(existingTimer);
}
const timerId = setTimeout(() => {
this.activeProgressTokens.delete(progressToken);
this.lastProgressEmit.delete(progressToken);
this.progressCleanupTimers.delete(progressToken);
}, 5000);
this.progressCleanupTimers.set(progressToken, timerId);
}
} catch (error) {
logger.error(`${this.getLogPrefix()} Error handling progress:`, error);
}
});
} catch (error) {
logger.warn(`${this.getLogPrefix()} Failed to setup progress notifications:`, error);
}
}
async connectClient(): Promise<void> {
if (this.connectionState === 'connected') {
return;
@ -1115,6 +1245,12 @@ export class MCPConnection extends EventEmitter {
this.emit('connectionChange', 'disconnected');
} finally {
this.connectPromise = null;
// Clean up progress tracking state to prevent memory leaks
this.activeProgressTokens.clear();
this.lastProgressEmit.clear();
// Clear all pending cleanup timers
this.progressCleanupTimers.forEach((timer) => clearTimeout(timer));
this.progressCleanupTimers.clear();
if (!resetCycleTracking) {
this.recordCycle();
}

View file

@ -72,6 +72,23 @@ export type MCPToolCallResponse =
isError?: boolean;
};
export type ProgressToken = string | number;
export interface ProgressNotification {
progressToken: ProgressToken;
progress: number;
total?: number;
message?: string;
}
export interface ProgressState {
token: ProgressToken;
progress: number;
total?: number;
message?: string;
timestamp: number;
}
export type Provider =
| 'google'
| 'anthropic'
@ -142,6 +159,10 @@ export type FormattedContentResult = [string, Artifacts | undefined];
export type ImageFormatter = (item: ImageContent) => FormattedContent;
/**
* Tool response for MCP tools - must be a proper two-tuple for LangChain's content_and_artifact format.
* Progress notifications are handled separately via SSE events, not attached to the response.
*/
export type FormattedToolResponse = FormattedContentResult;
/**

View file

@ -915,6 +915,26 @@ class GenerationJobManagerClass {
await this.eventTransport.emitChunk(streamId, event);
}
/**
* Emit an ephemeral event directly to live subscribers, bypassing the
* earlyEventBuffer, Redis persistence, and trackUserMessage side-effects.
*
* Use for fire-and-forget events like MCP tool-call progress that must not
* be replayed to reconnecting clients. Silently drops when no subscriber
* is connected or the job has been aborted.
*/
async emitTransientEvent(streamId: string, event: t.ServerSentEvent): Promise<void> {
const runtime = this.runtimeState.get(streamId);
if (!runtime || runtime.abortController.signal.aborted) {
return;
}
if (!runtime.hasSubscriber) {
return;
}
// Emit directly to transport, bypassing appendChunk and earlyEventBuffer
await this.eventTransport.emitChunk(streamId, event);
}
/**
* Extract and save run step from event data.
* The data is already the run step object from the event payload.

View file

@ -0,0 +1,119 @@
import type * as t from '~/types';
interface RuntimeState {
abortController: AbortController;
hasSubscriber: boolean;
}
class GenerationJobManagerStub {
runtimeState = new Map<string, RuntimeState>();
eventTransport = { emitChunk: jest.fn() };
async emitTransientEvent(streamId: string, event: t.ServerSentEvent): Promise<void> {
const runtime = this.runtimeState.get(streamId);
if (!runtime || runtime.abortController.signal.aborted) {
return;
}
if (!runtime.hasSubscriber) {
return;
}
await this.eventTransport.emitChunk(streamId, event);
}
}
function makeRuntime(overrides: Partial<RuntimeState> = {}): RuntimeState {
return {
abortController: new AbortController(),
hasSubscriber: true,
...overrides,
};
}
describe('GenerationJobManager - emitTransientEvent', () => {
let manager: GenerationJobManagerStub;
const streamId = 'stream-abc-123';
const progressEvent: t.ServerSentEvent = {
event: 'progress',
data: { progress: 2, total: 5, message: 'Working…', toolCallId: 'call-1' },
} as unknown as t.ServerSentEvent;
beforeEach(() => {
manager = new GenerationJobManagerStub();
jest.clearAllMocks();
});
it('emits to transport when runtime exists and has subscriber', async () => {
manager.runtimeState.set(streamId, makeRuntime());
await manager.emitTransientEvent(streamId, progressEvent);
expect(manager.eventTransport.emitChunk).toHaveBeenCalledTimes(1);
expect(manager.eventTransport.emitChunk).toHaveBeenCalledWith(streamId, progressEvent);
});
it('silently drops when streamId has no runtime entry', async () => {
await manager.emitTransientEvent('unknown-stream', progressEvent);
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
});
it('silently drops when job has been aborted', async () => {
const runtime = makeRuntime();
runtime.abortController.abort();
manager.runtimeState.set(streamId, runtime);
await manager.emitTransientEvent(streamId, progressEvent);
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
});
it('silently drops when there is no active subscriber', async () => {
manager.runtimeState.set(streamId, makeRuntime({ hasSubscriber: false }));
await manager.emitTransientEvent(streamId, progressEvent);
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
});
it('does not persist the event (calls emitChunk directly, not appendChunk)', async () => {
const appendChunk = jest.fn();
(manager as unknown as Record<string, unknown>).appendChunk = appendChunk;
manager.runtimeState.set(streamId, makeRuntime());
await manager.emitTransientEvent(streamId, progressEvent);
expect(appendChunk).not.toHaveBeenCalled();
expect(manager.eventTransport.emitChunk).toHaveBeenCalled();
});
it('forwards any event shape without mutation', async () => {
manager.runtimeState.set(streamId, makeRuntime());
const customEvent = { event: 'progress', data: { foo: 'bar' } } as unknown as t.ServerSentEvent;
await manager.emitTransientEvent(streamId, customEvent);
expect(manager.eventTransport.emitChunk).toHaveBeenCalledWith(streamId, customEvent);
});
it('handles transport throwing without crashing the caller', async () => {
manager.runtimeState.set(streamId, makeRuntime());
manager.eventTransport.emitChunk.mockRejectedValueOnce(new Error('transport error'));
await expect(manager.emitTransientEvent(streamId, progressEvent)).rejects.toThrow(
'transport error',
);
});
it('does not emit after abort even if subscriber flag is still true', async () => {
const runtime = makeRuntime({ hasSubscriber: true });
manager.runtimeState.set(streamId, runtime);
// Abort happens between registration and emit
runtime.abortController.abort();
await manager.emitTransientEvent(streamId, progressEvent);
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
});
});

View file

@ -1,6 +1,25 @@
import type { Response as ServerResponse } from 'express';
import type { ServerSentEvent } from '~/types';
/**
* Safely writes to a server response, handling disconnected clients.
* @param res - The server response.
* @param data - The data to write.
* @returns true if write succeeded, false otherwise.
*/
function safeWrite(res: ServerResponse, data: string): boolean {
try {
if (!res.writable) {
return false;
}
res.write(data);
return true;
} catch {
// Client may have disconnected - log but don't crash
return false;
}
}
/**
* Sends a Server-Sent Event to the client.
* Empty-string StreamEvent data is silently dropped.
@ -9,7 +28,7 @@ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
if ('data' in event && typeof event.data === 'string' && event.data.length === 0) {
return;
}
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
safeWrite(res, `event: message\ndata: ${JSON.stringify(event)}\n\n`);
}
/**
@ -18,6 +37,29 @@ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
* @param message - The error message.
*/
export function handleError(res: ServerResponse, message: string): void {
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
res.end();
safeWrite(res, `event: error\ndata: ${JSON.stringify(message)}\n\n`);
try {
res.end();
} catch {
// Client may have disconnected
}
}
/**
* Sends progress notification in Server Sent Events format.
* @param res - The server response.
* @param progressData - Progress notification data.
*/
export function sendProgress(
res: ServerResponse,
progressData: {
progressToken: string | number;
progress: number;
total?: number;
message?: string;
serverName?: string;
toolCallId?: string; // Tool call ID for matching progress to specific tool call
},
): void {
safeWrite(res, `event: progress\ndata: ${JSON.stringify(progressData)}\n\n`);
}