mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
add progress notification feature
This commit is contained in:
parent
b4d97bd888
commit
be0619755d
18 changed files with 1311 additions and 57 deletions
|
|
@ -109,11 +109,24 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
|||
}
|
||||
};
|
||||
|
||||
const onChunk = (event) => {
|
||||
if (!res.writableEnded) {
|
||||
if (event.event === 'progress') {
|
||||
res.write(`event: progress\ndata: ${JSON.stringify(event.data)}\n\n`);
|
||||
} else {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
}
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let result;
|
||||
|
||||
if (isResume) {
|
||||
const { subscription, resumeState, pendingEvents } =
|
||||
await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError);
|
||||
await GenerationJobManager.subscribeWithResume(streamId, onChunk, onDone, onError);
|
||||
|
||||
if (!res.writableEnded) {
|
||||
if (resumeState) {
|
||||
|
|
@ -139,7 +152,7 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
|||
|
||||
result = subscription;
|
||||
} else {
|
||||
result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError);
|
||||
result = await GenerationJobManager.subscribe(streamId, onChunk, onDone, onError);
|
||||
}
|
||||
|
||||
if (!result) {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ const {
|
|||
} = require('@librechat/agents');
|
||||
const {
|
||||
sendEvent,
|
||||
sendProgress,
|
||||
MCPOAuthHandler,
|
||||
isMCPDomainAllowed,
|
||||
normalizeServerName,
|
||||
|
|
@ -648,6 +649,30 @@ function createToolInstance({
|
|||
oauthStart,
|
||||
oauthEnd,
|
||||
graphTokenResolver: getGraphApiToken,
|
||||
onProgress: async (progressData) => {
|
||||
logger.debug(
|
||||
`[MCP][${serverName}][${toolName}] Sending progress to client:`,
|
||||
progressData,
|
||||
);
|
||||
const eventData = {
|
||||
progress: progressData.progress,
|
||||
total: progressData.total,
|
||||
message: progressData.message,
|
||||
toolCallId: toolCall.id,
|
||||
};
|
||||
try {
|
||||
if (streamId) {
|
||||
await GenerationJobManager.emitTransientEvent(streamId, {
|
||||
event: 'progress',
|
||||
data: eventData,
|
||||
});
|
||||
} else {
|
||||
sendProgress(res, eventData);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`[MCP][${serverName}][${toolName}] Failed to emit progress:`, err);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
|
||||
|
|
|
|||
|
|
@ -180,6 +180,7 @@ const Part = memo(function Part({
|
|||
attachments={attachments}
|
||||
auth={toolCall.auth}
|
||||
isLast={isLast}
|
||||
toolCallId={toolCall.id}
|
||||
/>
|
||||
);
|
||||
} else if (toolCall.type === ToolCallTypes.CODE_INTERPRETER) {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import { useMemo, useState, useEffect, useCallback } from 'react';
|
|||
import { useRecoilValue } from 'recoil';
|
||||
import { Button } from '@librechat/client';
|
||||
import { TriangleAlert } from 'lucide-react';
|
||||
import { useAtomValue, useSetAtom } from 'jotai';
|
||||
import {
|
||||
Constants,
|
||||
dataService,
|
||||
|
|
@ -17,6 +18,32 @@ import ToolCallInfo from './ToolCallInfo';
|
|||
import ProgressText from './ProgressText';
|
||||
import { logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
import {
|
||||
toolCallProgressFamily,
|
||||
clearToolCallProgressAtom,
|
||||
type ProgressState,
|
||||
} from '~/store/progress';
|
||||
|
||||
/**
|
||||
* Gets the in-progress text to display for a tool call.
|
||||
* Prioritizes MCP progress message, then progress/total, then default localized text.
|
||||
*/
|
||||
function getInProgressText(
|
||||
mcpProgress: ProgressState | undefined,
|
||||
functionName: string,
|
||||
localize: ReturnType<typeof useLocalize>,
|
||||
): string {
|
||||
if (mcpProgress?.message) {
|
||||
return mcpProgress.message;
|
||||
}
|
||||
if (mcpProgress?.total) {
|
||||
return `${functionName}: ${mcpProgress.progress}/${mcpProgress.total}`;
|
||||
}
|
||||
if (functionName) {
|
||||
return localize('com_assistants_running_var', { 0: functionName });
|
||||
}
|
||||
return localize('com_assistants_running_action');
|
||||
}
|
||||
|
||||
export default function ToolCall({
|
||||
initialProgress = 0.1,
|
||||
|
|
@ -27,6 +54,7 @@ export default function ToolCall({
|
|||
output,
|
||||
attachments,
|
||||
auth,
|
||||
toolCallId,
|
||||
}: {
|
||||
initialProgress: number;
|
||||
isLast?: boolean;
|
||||
|
|
@ -36,6 +64,7 @@ export default function ToolCall({
|
|||
output?: string | null;
|
||||
attachments?: TAttachment[];
|
||||
auth?: string;
|
||||
toolCallId?: string;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const autoExpand = useRecoilValue(store.autoExpandTools);
|
||||
|
|
@ -130,10 +159,6 @@ export default function ToolCall({
|
|||
window.open(auth, '_blank', 'noopener,noreferrer');
|
||||
}, [auth, isMCPToolCall, mcpServerName, actionId]);
|
||||
|
||||
const hasError = typeof output === 'string' && isError(output);
|
||||
const cancelled = !isSubmitting && initialProgress < 1 && !hasError;
|
||||
const errorState = hasError;
|
||||
|
||||
const args = useMemo(() => {
|
||||
if (typeof _args === 'string') {
|
||||
return _args;
|
||||
|
|
@ -158,8 +183,32 @@ export default function ToolCall({
|
|||
return parsedAuthUrl?.hostname ?? '';
|
||||
}, [parsedAuthUrl]);
|
||||
|
||||
const progress = useProgress(initialProgress);
|
||||
const showCancelled = cancelled || (errorState && !output);
|
||||
// Get simulated progress
|
||||
const simulatedProgress = useProgress(initialProgress);
|
||||
|
||||
// Get real-time progress from MCP server by tool call ID
|
||||
const mcpProgress = useAtomValue(toolCallProgressFamily(toolCallId ?? ''));
|
||||
const clearProgress = useSetAtom(clearToolCallProgressAtom);
|
||||
|
||||
// Clean up progress data when tool completes
|
||||
useEffect(() => {
|
||||
if (hasOutput && toolCallId) {
|
||||
clearProgress(toolCallId);
|
||||
}
|
||||
}, [hasOutput, toolCallId, clearProgress]);
|
||||
|
||||
// If tool has output, it's completed (progress = 1), otherwise use simulated progress
|
||||
const progress = useMemo(() => {
|
||||
if (hasOutput) {
|
||||
return 1;
|
||||
}
|
||||
return simulatedProgress;
|
||||
}, [hasOutput, simulatedProgress]);
|
||||
|
||||
const hasError = typeof output === 'string' && isError(output);
|
||||
const errorState = hasError;
|
||||
const cancelled = (!isSubmitting && progress < 1 && !hasOutput) || errorState;
|
||||
const showCancelled = (cancelled && !hasOutput) || (errorState && !output);
|
||||
|
||||
const subtitle = useMemo(() => {
|
||||
if (isMCPToolCall && mcpServerName) {
|
||||
|
|
@ -191,24 +240,15 @@ export default function ToolCall({
|
|||
return (
|
||||
<>
|
||||
<span className="sr-only" aria-live="polite" aria-atomic="true">
|
||||
{(() => {
|
||||
if (progress < 1 && !showCancelled) {
|
||||
return function_name
|
||||
? localize('com_assistants_running_var', { 0: function_name })
|
||||
: localize('com_assistants_running_action');
|
||||
}
|
||||
return getFinishedText();
|
||||
})()}
|
||||
{progress < 1 && !showCancelled
|
||||
? getInProgressText(mcpProgress, function_name, localize)
|
||||
: getFinishedText()}
|
||||
</span>
|
||||
<div className="relative my-1.5 flex h-5 shrink-0 items-center gap-2.5">
|
||||
<ProgressText
|
||||
progress={progress}
|
||||
onClick={() => setShowInfo((prev) => !prev)}
|
||||
inProgressText={
|
||||
function_name
|
||||
? localize('com_assistants_running_var', { 0: function_name })
|
||||
: localize('com_assistants_running_action')
|
||||
}
|
||||
inProgressText={getInProgressText(mcpProgress, function_name, localize)}
|
||||
authText={
|
||||
!showCancelled && authDomain.length > 0 ? localize('com_ui_requires_auth') : undefined
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import React from 'react';
|
||||
import { RecoilRoot } from 'recoil';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import { Provider } from 'jotai';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import ToolCall from '../ToolCall';
|
||||
|
||||
// Mock dependencies
|
||||
|
|
@ -57,15 +57,19 @@ jest.mock('../ProgressText', () => ({
|
|||
onClick,
|
||||
inProgressText,
|
||||
finishedText,
|
||||
error,
|
||||
progress,
|
||||
subtitle,
|
||||
}: {
|
||||
onClick?: () => void;
|
||||
inProgressText?: string;
|
||||
finishedText?: string;
|
||||
error?: string;
|
||||
progress: number;
|
||||
subtitle?: string;
|
||||
}) => (
|
||||
<div data-testid="progress-text" onClick={onClick}>
|
||||
{finishedText || inProgressText}
|
||||
{error || progress >= 1 ? finishedText : inProgressText}
|
||||
{subtitle && <span data-testid="subtitle">{subtitle}</span>}
|
||||
</div>
|
||||
),
|
||||
|
|
@ -98,6 +102,34 @@ jest.mock('~/utils', () => ({
|
|||
cn: (...classes: any[]) => classes.filter(Boolean).join(' '),
|
||||
}));
|
||||
|
||||
const mockUseAtomValue = jest.fn().mockReturnValue(undefined);
|
||||
const mockClearProgress = jest.fn();
|
||||
|
||||
jest.mock('jotai', () => ({
|
||||
...jest.requireActual('jotai'),
|
||||
useAtomValue: (...args: any[]) => mockUseAtomValue(...args),
|
||||
useSetAtom: () => mockClearProgress,
|
||||
}));
|
||||
|
||||
const DUMMY_ATOM = { toString: () => 'dummy-atom' };
|
||||
|
||||
jest.mock('~/store/progress', () => ({
|
||||
toolCallProgressFamily: jest.fn().mockReturnValue(DUMMY_ATOM),
|
||||
clearToolCallProgressAtom: {},
|
||||
}));
|
||||
|
||||
jest.mock('recoil', () => ({
|
||||
...jest.requireActual('recoil'),
|
||||
useRecoilValue: jest.fn().mockReturnValue(false),
|
||||
}));
|
||||
|
||||
jest.mock('~/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
autoExpandTools: 'autoExpandTools',
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ToolCall', () => {
|
||||
const mockProps = {
|
||||
args: '{"test": "input"}',
|
||||
|
|
@ -107,12 +139,14 @@ describe('ToolCall', () => {
|
|||
isSubmitting: false,
|
||||
};
|
||||
|
||||
const renderWithRecoil = (component: React.ReactElement) => {
|
||||
return render(<RecoilRoot>{component}</RecoilRoot>);
|
||||
const renderWithJotai = (component: React.ReactElement) => {
|
||||
return render(<Provider>{component}</Provider>);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockUseAtomValue.mockReturnValue(undefined);
|
||||
mockClearProgress.mockClear();
|
||||
});
|
||||
|
||||
describe('attachments prop passing', () => {
|
||||
|
|
@ -129,7 +163,7 @@ describe('ToolCall', () => {
|
|||
},
|
||||
];
|
||||
|
||||
renderWithRecoil(<ToolCall {...mockProps} attachments={attachments as any} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} attachments={attachments} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
|
|
@ -141,7 +175,7 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should pass empty array when no attachments', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
|
|
@ -172,7 +206,7 @@ describe('ToolCall', () => {
|
|||
},
|
||||
];
|
||||
|
||||
renderWithRecoil(<ToolCall {...mockProps} attachments={attachments as any} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} attachments={attachments} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
|
|
@ -196,7 +230,7 @@ describe('ToolCall', () => {
|
|||
},
|
||||
];
|
||||
|
||||
renderWithRecoil(<ToolCall {...mockProps} attachments={attachments as any} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} attachments={attachments} />);
|
||||
|
||||
const attachmentGroup = screen.getByTestId('attachment-group');
|
||||
expect(attachmentGroup).toBeInTheDocument();
|
||||
|
|
@ -204,13 +238,13 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should not render AttachmentGroup when no attachments', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} />);
|
||||
|
||||
expect(screen.queryByTestId('attachment-group')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should not render AttachmentGroup when attachments is empty array', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} attachments={[]} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} attachments={[]} />);
|
||||
|
||||
expect(screen.queryByTestId('attachment-group')).not.toBeInTheDocument();
|
||||
});
|
||||
|
|
@ -218,7 +252,7 @@ describe('ToolCall', () => {
|
|||
|
||||
describe('tool call info visibility', () => {
|
||||
it('should toggle tool call info expand/collapse when clicking header', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} />);
|
||||
|
||||
// ToolCallInfo is always in the DOM (CSS expand/collapse), but initially collapsed
|
||||
const toolCallInfo = screen.getByTestId('tool-call-info');
|
||||
|
|
@ -233,8 +267,29 @@ describe('ToolCall', () => {
|
|||
expect(screen.getByTestId('tool-call-info')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should pass input and output props to ToolCallInfo', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} />);
|
||||
it('should pass all required props to ToolCallInfo', () => {
|
||||
const attachments = [
|
||||
{
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'msg123',
|
||||
toolCallId: 'tool456',
|
||||
conversationId: 'conv789',
|
||||
[Tools.ui_resources]: {
|
||||
'0': { type: 'button', label: 'Test' },
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
// Use a name with domain separator (_action_) and domain separator (---)
|
||||
const propsWithDomain = {
|
||||
...mockProps,
|
||||
name: 'testFunction_action_test---domain---com',
|
||||
attachments,
|
||||
};
|
||||
|
||||
renderWithJotai(<ToolCall {...propsWithDomain} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
const toolCallInfo = screen.getByTestId('tool-call-info');
|
||||
const props = JSON.parse(toolCallInfo.textContent!);
|
||||
|
|
@ -249,9 +304,10 @@ describe('ToolCall', () => {
|
|||
const originalOpen = window.open;
|
||||
window.open = jest.fn();
|
||||
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
output={undefined}
|
||||
initialProgress={0.5} // Less than 1 so it's not complete
|
||||
auth="https://auth.example.com"
|
||||
isSubmitting={true} // Should be submitting for auth to show
|
||||
|
|
@ -272,7 +328,7 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should not show auth section when cancelled', () => {
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
auth="https://auth.example.com"
|
||||
|
|
@ -285,7 +341,7 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should not show auth section when progress is complete', () => {
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
auth="https://auth.example.com"
|
||||
|
|
@ -300,7 +356,9 @@ describe('ToolCall', () => {
|
|||
|
||||
describe('edge cases', () => {
|
||||
it('should handle undefined args', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} args={undefined as any} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} args={undefined} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
const toolCallInfo = screen.getByTestId('tool-call-info');
|
||||
const props = JSON.parse(toolCallInfo.textContent!);
|
||||
|
|
@ -308,15 +366,17 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should handle null output', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} output={null} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} output={null} />);
|
||||
|
||||
const toolCallInfo = screen.getByTestId('tool-call-info');
|
||||
const props = JSON.parse(toolCallInfo.textContent!);
|
||||
expect(props.output).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle simple function name without domain', () => {
|
||||
renderWithRecoil(<ToolCall {...mockProps} name="simpleName" />);
|
||||
it('should handle missing domain', () => {
|
||||
renderWithJotai(<ToolCall {...mockProps} domain={undefined} authDomain={undefined} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
const toolCallInfo = screen.getByTestId('tool-call-info');
|
||||
expect(toolCallInfo).toBeInTheDocument();
|
||||
|
|
@ -344,7 +404,7 @@ describe('ToolCall', () => {
|
|||
},
|
||||
];
|
||||
|
||||
renderWithRecoil(<ToolCall {...mockProps} attachments={complexAttachments as any} />);
|
||||
renderWithJotai(<ToolCall {...mockProps} attachments={complexAttachments} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('progress-text'));
|
||||
|
||||
|
|
@ -361,7 +421,7 @@ describe('ToolCall', () => {
|
|||
const d = Constants.mcp_delimiter;
|
||||
|
||||
it('should detect MCP OAuth from delimiter in tool-call name', () => {
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}my-server`}
|
||||
|
|
@ -375,7 +435,7 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should preserve full server name when it contains the delimiter substring', () => {
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}foo${d}bar`}
|
||||
|
|
@ -389,7 +449,7 @@ describe('ToolCall', () => {
|
|||
});
|
||||
|
||||
it('should display server name (not "oauth") as function_name for OAuth tool calls', () => {
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}my-server`}
|
||||
|
|
@ -407,7 +467,7 @@ describe('ToolCall', () => {
|
|||
it('should display server name even when auth is cleared (post-completion)', () => {
|
||||
// After OAuth completes, createOAuthEnd re-emits the toolCall without auth.
|
||||
// The display should still show the server name, not literal "oauth".
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}my-server`}
|
||||
|
|
@ -425,7 +485,7 @@ describe('ToolCall', () => {
|
|||
const authUrl =
|
||||
'https://oauth.example.com/authorize?redirect_uri=' +
|
||||
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name="bare_name"
|
||||
|
|
@ -442,7 +502,7 @@ describe('ToolCall', () => {
|
|||
const authUrl =
|
||||
'https://oauth.example.com/authorize?redirect_uri=' +
|
||||
encodeURIComponent('https://app.example.com/api/mcp/my-server/oauth/callback');
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name="bare_name"
|
||||
|
|
@ -462,7 +522,7 @@ describe('ToolCall', () => {
|
|||
// gets prefixed to oauth_mcp_oauth_mcp_server. Client parses:
|
||||
// func="oauth", server="oauth_mcp_server". Visually awkward but
|
||||
// semantically correct — the normalized name IS oauth_mcp_server.
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name={`oauth${d}oauth${d}server`}
|
||||
|
|
@ -479,7 +539,7 @@ describe('ToolCall', () => {
|
|||
const authUrl =
|
||||
'https://oauth.example.com/authorize?redirect_uri=' +
|
||||
encodeURIComponent('https://app.example.com/api/actions/xyz/oauth/callback');
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
name="action_name"
|
||||
|
|
@ -494,7 +554,7 @@ describe('ToolCall', () => {
|
|||
|
||||
describe('A11Y-04: screen reader status announcements', () => {
|
||||
it('includes sr-only aria-live region for status announcements', () => {
|
||||
renderWithRecoil(
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
initialProgress={1}
|
||||
|
|
@ -509,4 +569,143 @@ describe('ToolCall', () => {
|
|||
expect(liveRegion!.className).toContain('sr-only');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getInProgressText - MCP progress display', () => {
|
||||
it('shows mcpProgress.message when available', () => {
|
||||
mockUseAtomValue.mockReturnValue({
|
||||
progress: 2,
|
||||
total: 10,
|
||||
message: 'Fetching data from API...',
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
output={undefined}
|
||||
initialProgress={0.1}
|
||||
isSubmitting={true}
|
||||
toolCallId="call-123"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('progress-text')).toHaveTextContent('Fetching data from API...');
|
||||
});
|
||||
|
||||
it('shows "functionName: X/Y" when mcpProgress has total but no message', () => {
|
||||
mockUseAtomValue.mockReturnValue({
|
||||
progress: 3,
|
||||
total: 10,
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
output={undefined}
|
||||
initialProgress={0.1}
|
||||
isSubmitting={true}
|
||||
toolCallId="call-123"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('progress-text')).toHaveTextContent('testFunction: 3/10');
|
||||
});
|
||||
|
||||
it('falls back to running_var localisation when no mcpProgress', () => {
|
||||
mockUseAtomValue.mockReturnValue(undefined);
|
||||
|
||||
renderWithJotai(
|
||||
<ToolCall {...mockProps} output={undefined} initialProgress={0.1} isSubmitting={true} />,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('progress-text')).toHaveTextContent('Running testFunction');
|
||||
});
|
||||
|
||||
it('prefers message over progress/total when both are present', () => {
|
||||
mockUseAtomValue.mockReturnValue({
|
||||
progress: 5,
|
||||
total: 10,
|
||||
message: 'Custom status message',
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
renderWithJotai(
|
||||
<ToolCall
|
||||
{...mockProps}
|
||||
output={undefined}
|
||||
initialProgress={0.1}
|
||||
isSubmitting={true}
|
||||
toolCallId="call-123"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('progress-text')).toHaveTextContent('Custom status message');
|
||||
expect(screen.getByTestId('progress-text')).not.toHaveTextContent('testFunction: 5/10');
|
||||
});
|
||||
});
|
||||
|
||||
describe('toolCallId prop and progress atom integration', () => {
|
||||
it('passes toolCallId to toolCallProgressFamily when provided', () => {
|
||||
const { toolCallProgressFamily } = jest.requireMock('~/store/progress');
|
||||
|
||||
renderWithJotai(<ToolCall {...mockProps} toolCallId="specific-call-id" />);
|
||||
|
||||
expect(toolCallProgressFamily).toHaveBeenCalledWith('specific-call-id');
|
||||
});
|
||||
|
||||
it('passes empty string to toolCallProgressFamily when toolCallId is undefined', () => {
|
||||
const { toolCallProgressFamily } = jest.requireMock('~/store/progress');
|
||||
|
||||
renderWithJotai(<ToolCall {...mockProps} />);
|
||||
|
||||
expect(toolCallProgressFamily).toHaveBeenCalledWith('');
|
||||
});
|
||||
|
||||
it('calls clearProgress with toolCallId when output arrives', () => {
|
||||
mockUseAtomValue.mockReturnValue(mockClearProgress);
|
||||
|
||||
renderWithJotai(
|
||||
<ToolCall {...mockProps} output="Tool completed" toolCallId="call-to-clear" />,
|
||||
);
|
||||
|
||||
expect(mockClearProgress).toHaveBeenCalledWith('call-to-clear');
|
||||
});
|
||||
|
||||
it('does not call clearProgress when toolCallId is undefined', () => {
|
||||
mockUseAtomValue.mockReturnValue(mockClearProgress);
|
||||
|
||||
renderWithJotai(<ToolCall {...mockProps} output="Tool completed" />);
|
||||
|
||||
expect(mockClearProgress).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('cancelled state with hasOutput', () => {
|
||||
it('is not cancelled when output exists even with low progress', () => {
|
||||
renderWithJotai(
|
||||
<ToolCall {...mockProps} output="Result" initialProgress={0.1} isSubmitting={false} />,
|
||||
);
|
||||
|
||||
// When not cancelled and has output → shows finished text, not "Cancelled"
|
||||
expect(screen.queryByTestId('progress-text')).not.toHaveTextContent('Cancelled');
|
||||
expect(screen.queryByTestId('progress-text')).toHaveTextContent('Completed testFunction');
|
||||
});
|
||||
|
||||
it('is cancelled when no output and not submitting and progress < 1', () => {
|
||||
renderWithJotai(
|
||||
<ToolCall {...mockProps} output={undefined} initialProgress={0.1} isSubmitting={false} />,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('progress-text')).toHaveTextContent('Cancelled');
|
||||
});
|
||||
|
||||
it('shows finished text when progress is 1 and output is present', () => {
|
||||
renderWithJotai(
|
||||
<ToolCall {...mockProps} output="Done" initialProgress={1} isSubmitting={false} />,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('progress-text')).toHaveTextContent('Completed testFunction');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -75,6 +75,15 @@ jest.mock('~/data-provider', () => ({
|
|||
const mockErrorHandler = jest.fn();
|
||||
const mockSetIsSubmitting = jest.fn();
|
||||
const mockClearStepMaps = jest.fn();
|
||||
const mockHandleProgressEvent = jest.fn();
|
||||
const mockCleanupProgress = jest.fn();
|
||||
|
||||
jest.mock('~/hooks/SSE/useProgressTracking', () => ({
|
||||
useProgressTracking: () => ({
|
||||
handleProgressEvent: mockHandleProgressEvent,
|
||||
cleanupProgress: mockCleanupProgress,
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/hooks/SSE/useEventHandlers', () =>
|
||||
jest.fn(() => ({
|
||||
|
|
@ -282,3 +291,81 @@ describe('useResumableSSE - 404 error path', () => {
|
|||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe('useResumableSSE - progress event integration', () => {
|
||||
beforeEach(() => {
|
||||
mockSSEInstances.length = 0;
|
||||
localStorage.clear();
|
||||
mockHandleProgressEvent.mockClear();
|
||||
mockCleanupProgress.mockClear();
|
||||
});
|
||||
|
||||
const renderAndInit = async (conversationId = CONV_ID) => {
|
||||
const submission = buildSubmission({ conversation: { conversationId } });
|
||||
const chatHelpers = buildChatHelpers();
|
||||
const { unmount } = renderHook(() => useResumableSSE(submission, chatHelpers));
|
||||
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
return { sse: getLastSSE(), unmount, chatHelpers };
|
||||
};
|
||||
|
||||
it('registers a "progress" event listener on the SSE connection', async () => {
|
||||
const { unmount } = await renderAndInit();
|
||||
const sse = getLastSSE();
|
||||
|
||||
const registeredEvents = sse.addEventListener.mock.calls.map(([event]: [string]) => event);
|
||||
expect(registeredEvents).toContain('progress');
|
||||
unmount();
|
||||
});
|
||||
|
||||
it('routes incoming progress events to handleProgressEvent', async () => {
|
||||
const { unmount } = await renderAndInit();
|
||||
const sse = getLastSSE();
|
||||
|
||||
const progressData = JSON.stringify({
|
||||
progress: 3,
|
||||
total: 10,
|
||||
message: 'Working...',
|
||||
toolCallId: 'call-abc',
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
sse._emit('progress', { data: progressData });
|
||||
});
|
||||
|
||||
expect(mockHandleProgressEvent).toHaveBeenCalledTimes(1);
|
||||
expect(mockHandleProgressEvent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ data: progressData }),
|
||||
);
|
||||
unmount();
|
||||
});
|
||||
|
||||
it('calls cleanupProgress on unmount', async () => {
|
||||
const { unmount } = await renderAndInit();
|
||||
|
||||
unmount();
|
||||
|
||||
expect(mockCleanupProgress).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('does not call cleanupProgress before unmount', async () => {
|
||||
const { unmount } = await renderAndInit();
|
||||
|
||||
expect(mockCleanupProgress).not.toHaveBeenCalled();
|
||||
unmount();
|
||||
});
|
||||
|
||||
it('calls cleanupProgress on unmount even after a 404 error', async () => {
|
||||
const { sse, unmount } = await renderAndInit();
|
||||
|
||||
await act(async () => {
|
||||
sse._emit('error', { responseCode: 404 });
|
||||
});
|
||||
|
||||
unmount();
|
||||
expect(mockCleanupProgress).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -5,3 +5,5 @@ export { default as useResumeOnLoad } from './useResumeOnLoad';
|
|||
export { default as useStepHandler } from './useStepHandler';
|
||||
export { default as useContentHandler } from './useContentHandler';
|
||||
export { default as useAttachmentHandler } from './useAttachmentHandler';
|
||||
export { useProgressTracking } from './useProgressTracking';
|
||||
export type { ProgressState } from '~/store/progress';
|
||||
|
|
|
|||
48
client/src/hooks/SSE/useProgressTracking.ts
Normal file
48
client/src/hooks/SSE/useProgressTracking.ts
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import { useRef, useCallback } from 'react';
|
||||
import { useSetAtom } from 'jotai';
|
||||
import { toolCallProgressMapAtom } from '~/store/progress';
|
||||
|
||||
export function useProgressTracking() {
|
||||
const setToolCallProgressMap = useSetAtom(toolCallProgressMapAtom);
|
||||
const progressCleanupTimers = useRef<Map<string, NodeJS.Timeout>>(new Map());
|
||||
|
||||
const handleProgressEvent = useCallback(
|
||||
(e: MessageEvent) => {
|
||||
try {
|
||||
const data = JSON.parse(e.data);
|
||||
const { progress, total, message, toolCallId } = data;
|
||||
if (toolCallId != null) {
|
||||
setToolCallProgressMap((currentMap) => {
|
||||
const newMap = new Map(currentMap);
|
||||
newMap.set(toolCallId, { progress, total, message, timestamp: Date.now() });
|
||||
return newMap;
|
||||
});
|
||||
if (total && progress >= total) {
|
||||
const existingTimer = progressCleanupTimers.current.get(toolCallId);
|
||||
if (existingTimer) clearTimeout(existingTimer);
|
||||
const timerId = setTimeout(() => {
|
||||
setToolCallProgressMap((currentMap) => {
|
||||
const newMap = new Map(currentMap);
|
||||
newMap.delete(toolCallId);
|
||||
return newMap;
|
||||
});
|
||||
progressCleanupTimers.current.delete(toolCallId);
|
||||
}, 5000);
|
||||
progressCleanupTimers.current.set(toolCallId, timerId);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error parsing progress event:', error);
|
||||
}
|
||||
},
|
||||
[setToolCallProgressMap],
|
||||
);
|
||||
|
||||
const cleanupProgress = useCallback(() => {
|
||||
setToolCallProgressMap(new Map());
|
||||
progressCleanupTimers.current.forEach(clearTimeout);
|
||||
progressCleanupTimers.current.clear();
|
||||
}, [setToolCallProgressMap]);
|
||||
|
||||
return { handleProgressEvent, cleanupProgress };
|
||||
}
|
||||
|
|
@ -25,6 +25,7 @@ import {
|
|||
import type { ActiveJobsResponse } from '~/data-provider';
|
||||
import { useAuthContext } from '~/hooks/AuthContext';
|
||||
import useEventHandlers from './useEventHandlers';
|
||||
import { useProgressTracking } from './useProgressTracking';
|
||||
import { clearAllDrafts } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
|
|
@ -132,6 +133,7 @@ export default function useResumableSSE(
|
|||
const balanceQuery = useGetUserBalance({
|
||||
enabled: !!isAuthenticated && startupConfig?.balance?.enabled,
|
||||
});
|
||||
const { handleProgressEvent, cleanupProgress } = useProgressTracking();
|
||||
|
||||
/**
|
||||
* Subscribe to stream via SSE library (supports custom headers)
|
||||
|
|
@ -335,6 +337,8 @@ export default function useResumableSSE(
|
|||
}
|
||||
});
|
||||
|
||||
sse.addEventListener('progress', handleProgressEvent);
|
||||
|
||||
/**
|
||||
* Error event handler - handles BOTH:
|
||||
* 1. HTTP-level errors (responseCode present) - 404, 401, network failures
|
||||
|
|
@ -552,6 +556,7 @@ export default function useResumableSSE(
|
|||
balanceQuery,
|
||||
removeActiveJob,
|
||||
queryClient,
|
||||
handleProgressEvent,
|
||||
],
|
||||
);
|
||||
|
||||
|
|
@ -703,6 +708,8 @@ export default function useResumableSSE(
|
|||
// Reset UI state on cleanup - useResumeOnLoad will restore if needed
|
||||
setIsSubmitting(false);
|
||||
setShowStopButton(false);
|
||||
// Clear progress map and pending cleanup timers on unmount
|
||||
cleanupProgress();
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [submission]);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import { useGetStartupConfig, useGetUserBalance } from '~/data-provider';
|
|||
import { useAuthContext } from '~/hooks/AuthContext';
|
||||
import useEventHandlers from './useEventHandlers';
|
||||
import { clearAllDrafts } from '~/utils';
|
||||
import { useProgressTracking } from './useProgressTracking';
|
||||
import store from '~/store';
|
||||
|
||||
type ChatHelpers = Pick<
|
||||
|
|
@ -71,6 +72,7 @@ export default function useSSE(
|
|||
const balanceQuery = useGetUserBalance({
|
||||
enabled: !!isAuthenticated && startupConfig?.balance?.enabled,
|
||||
});
|
||||
const { handleProgressEvent, cleanupProgress } = useProgressTracking();
|
||||
|
||||
useEffect(() => {
|
||||
if (submission == null || Object.keys(submission).length === 0) {
|
||||
|
|
@ -100,6 +102,8 @@ export default function useSSE(
|
|||
}
|
||||
});
|
||||
|
||||
sse.addEventListener('progress', handleProgressEvent);
|
||||
|
||||
sse.addEventListener('message', (e: MessageEvent) => {
|
||||
const data = JSON.parse(e.data);
|
||||
|
||||
|
|
@ -234,6 +238,8 @@ export default function useSSE(
|
|||
return () => {
|
||||
const isCancelled = sse.readyState <= 1;
|
||||
sse.close();
|
||||
// Clear progress map and pending cleanup timers on unmount
|
||||
cleanupProgress();
|
||||
if (isCancelled) {
|
||||
const e = new Event('cancel');
|
||||
/* @ts-ignore */
|
||||
|
|
|
|||
30
client/src/store/progress.ts
Normal file
30
client/src/store/progress.ts
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import { atom } from 'jotai';
|
||||
import { atomFamily } from 'jotai/utils';
|
||||
|
||||
export type ProgressState = {
|
||||
progress: number;
|
||||
total?: number;
|
||||
message?: string;
|
||||
timestamp: number;
|
||||
};
|
||||
|
||||
// Map of toolCallId -> ProgressState (for matching progress to specific tool calls)
|
||||
export const toolCallProgressMapAtom = atom<Map<string, ProgressState>>(new Map());
|
||||
|
||||
// Family for tool call based progress lookup
|
||||
export const toolCallProgressFamily = atomFamily((toolCallId: string) =>
|
||||
atom((get) => {
|
||||
// Don't return data for empty string key
|
||||
if (!toolCallId) return undefined;
|
||||
return get(toolCallProgressMapAtom).get(toolCallId);
|
||||
}),
|
||||
);
|
||||
|
||||
// Cleanup action - remove progress entry for a specific tool call
|
||||
export const clearToolCallProgressAtom = atom(null, (get, set, toolCallId: string) => {
|
||||
if (!toolCallId) return;
|
||||
const current = get(toolCallProgressMapAtom);
|
||||
const newMap = new Map(current);
|
||||
newMap.delete(toolCallId);
|
||||
set(toolCallProgressMapAtom, newMap);
|
||||
});
|
||||
|
|
@ -272,6 +272,7 @@ Please follow these instructions when using tools from the respective MCP server
|
|||
oauthEnd,
|
||||
customUserVars,
|
||||
graphTokenResolver,
|
||||
onProgress,
|
||||
}: {
|
||||
user?: IUser;
|
||||
serverName: string;
|
||||
|
|
@ -288,9 +289,19 @@ Please follow these instructions when using tools from the respective MCP server
|
|||
oauthStart?: (authURL: string) => Promise<void>;
|
||||
oauthEnd?: () => Promise<void>;
|
||||
graphTokenResolver?: GraphTokenResolver;
|
||||
onProgress?: (progressData: {
|
||||
progressToken: t.ProgressToken;
|
||||
progress: number;
|
||||
total?: number;
|
||||
message?: string;
|
||||
serverName: string;
|
||||
}) => void;
|
||||
}): Promise<t.FormattedToolResponse> {
|
||||
/** User-specific connection */
|
||||
let connection: MCPConnection | undefined;
|
||||
let progressHandler: ((data: t.ProgressNotification & { serverName: string }) => void) | null =
|
||||
null;
|
||||
let progressToken: t.ProgressToken | undefined;
|
||||
const userId = user?.id;
|
||||
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
|
||||
|
||||
|
|
@ -348,12 +359,44 @@ Please follow these instructions when using tools from the respective MCP server
|
|||
connection.setRequestHeaders(currentOptions.headers || {});
|
||||
}
|
||||
|
||||
// Generate and register progress token BEFORE setting up listener
|
||||
// to avoid race condition where progress arrives before token is registered
|
||||
progressToken = connection.generateProgressToken();
|
||||
connection.registerProgressToken(progressToken);
|
||||
logger.info(
|
||||
`${logPrefix}[${toolName}] Progress token generated and registered: ${progressToken}`,
|
||||
);
|
||||
|
||||
// Set up progress event listener if callback is provided
|
||||
if (onProgress) {
|
||||
progressHandler = (data) => {
|
||||
logger.info(`${logPrefix}[${toolName}] Progress event received:`, {
|
||||
progressToken: data.progressToken,
|
||||
progress: data.progress,
|
||||
total: data.total,
|
||||
message: data.message,
|
||||
});
|
||||
onProgress({
|
||||
progressToken: data.progressToken,
|
||||
progress: data.progress,
|
||||
total: data.total,
|
||||
message: data.message,
|
||||
serverName: data.serverName,
|
||||
});
|
||||
};
|
||||
connection.on('progress', progressHandler);
|
||||
logger.info(`${logPrefix}[${toolName}] Progress listener registered`);
|
||||
}
|
||||
|
||||
const result = await connection.client.request(
|
||||
{
|
||||
method: 'tools/call',
|
||||
params: {
|
||||
name: toolName,
|
||||
arguments: toolArguments,
|
||||
_meta: {
|
||||
progressToken,
|
||||
},
|
||||
},
|
||||
},
|
||||
CallToolResultSchema,
|
||||
|
|
@ -367,12 +410,30 @@ Please follow these instructions when using tools from the respective MCP server
|
|||
this.updateUserLastActivity(userId);
|
||||
}
|
||||
this.checkIdleConnections();
|
||||
return formatToolContent(result as t.MCPToolCallResponse, provider);
|
||||
|
||||
// Format and return the tool response as a proper tuple [content, artifacts]
|
||||
// Progress is handled separately via SSE events emitted by the connection
|
||||
const formattedResponse = formatToolContent(result as t.MCPToolCallResponse, provider);
|
||||
|
||||
return formattedResponse;
|
||||
} catch (error) {
|
||||
// Log with context and re-throw or handle as needed
|
||||
logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
|
||||
// Rethrowing allows the caller (createMCPTool) to handle the final user message
|
||||
throw error;
|
||||
} finally {
|
||||
if (connection && (progressHandler || progressToken)) {
|
||||
setTimeout(() => {
|
||||
// Clean up progress listener to prevent memory leaks
|
||||
if (progressHandler) {
|
||||
connection?.off('progress', progressHandler);
|
||||
}
|
||||
// Clean up progress token to prevent memory leak
|
||||
if (progressToken) {
|
||||
connection?.unregisterProgressToken(progressToken);
|
||||
}
|
||||
}, 500);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
|||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
import { EventEmitter } from 'events';
|
||||
import * as graphUtils from '~/utils/graph';
|
||||
|
||||
// Mock external dependencies
|
||||
|
|
@ -429,6 +430,10 @@ describe('MCPManager', () => {
|
|||
isError: false,
|
||||
}),
|
||||
},
|
||||
generateProgressToken: jest.fn().mockReturnValue('mock-progress-token'),
|
||||
registerProgressToken: jest.fn(),
|
||||
subscribeToProgress: jest.fn().mockReturnValue(() => {}),
|
||||
unregisterProgressToken: jest.fn(),
|
||||
} as unknown as MCPConnection;
|
||||
|
||||
const mockGraphTokenResolver: GraphTokenResolver = jest.fn().mockResolvedValue({
|
||||
|
|
@ -966,4 +971,396 @@ describe('MCPManager', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('callTool Progress Integration', () => {
|
||||
const serverName = 'test_server';
|
||||
|
||||
const mockUser: Partial<IUser> = {
|
||||
id: 'user-123',
|
||||
provider: 'openid',
|
||||
openidId: 'oidc-sub-456',
|
||||
};
|
||||
|
||||
const mockFlowManager = {
|
||||
getState: jest.fn(),
|
||||
setState: jest.fn(),
|
||||
clearState: jest.fn(),
|
||||
};
|
||||
|
||||
function buildMockConnection(overrides: Partial<Record<string, unknown>> = {}) {
|
||||
const emitter = new EventEmitter();
|
||||
return {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
setRequestHeaders: jest.fn(),
|
||||
timeout: 30000,
|
||||
on: jest.fn((event, handler) => emitter.on(event, handler)),
|
||||
off: jest.fn((event, handler) => emitter.off(event, handler)),
|
||||
emit: (event: string, data: unknown) => emitter.emit(event, data),
|
||||
client: {
|
||||
request: jest.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Tool result' }],
|
||||
isError: false,
|
||||
}),
|
||||
},
|
||||
generateProgressToken: jest.fn().mockReturnValue('mock-progress-token'),
|
||||
registerProgressToken: jest.fn(),
|
||||
unregisterProgressToken: jest.fn(),
|
||||
...overrides,
|
||||
} as unknown as MCPConnection;
|
||||
}
|
||||
|
||||
function mockAppConnections(connection: MCPConnection) {
|
||||
(ConnectionsRepository as jest.MockedClass<typeof ConnectionsRepository>).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
has: jest.fn().mockResolvedValue(false),
|
||||
get: jest.fn().mockResolvedValue(connection),
|
||||
}) as unknown as ConnectionsRepository,
|
||||
);
|
||||
}
|
||||
|
||||
function newMCPServersConfig(): t.MCPServers {
|
||||
return { [serverName]: { type: 'stdio', command: 'test', args: [] } };
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
(MCPManager as unknown as { instance: null }).instance = null;
|
||||
jest.clearAllMocks();
|
||||
(MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined);
|
||||
(mockRegistryInstance.getAllServerConfigs as jest.Mock).mockResolvedValue({});
|
||||
(graphUtils.preProcessGraphTokens as jest.Mock).mockImplementation(
|
||||
async (options) => options,
|
||||
);
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
type: 'sse',
|
||||
url: 'https://api.example.com',
|
||||
});
|
||||
});
|
||||
|
||||
describe('progress token lifecycle', () => {
|
||||
it('generates and registers a progress token before the tool call', async () => {
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
});
|
||||
|
||||
expect(connection.generateProgressToken).toHaveBeenCalledTimes(1);
|
||||
expect(connection.registerProgressToken).toHaveBeenCalledWith('mock-progress-token');
|
||||
});
|
||||
|
||||
it('passes _meta.progressToken in the MCP request params', async () => {
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
});
|
||||
|
||||
expect(connection.client.request).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
params: expect.objectContaining({
|
||||
_meta: { progressToken: 'mock-progress-token' },
|
||||
}),
|
||||
}),
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('schedules unregisterProgressToken cleanup in finally block after success', async () => {
|
||||
jest.useFakeTimers();
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
});
|
||||
|
||||
expect(connection.unregisterProgressToken).not.toHaveBeenCalled();
|
||||
jest.advanceTimersByTime(600);
|
||||
expect(connection.unregisterProgressToken).toHaveBeenCalledWith('mock-progress-token');
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('schedules cleanup in finally block even when tool call throws', async () => {
|
||||
jest.useFakeTimers();
|
||||
const connection = buildMockConnection({
|
||||
client: {
|
||||
request: jest.fn().mockRejectedValue(new Error('MCP server error')),
|
||||
},
|
||||
});
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await expect(
|
||||
manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
}),
|
||||
).rejects.toThrow('MCP server error');
|
||||
|
||||
jest.advanceTimersByTime(600);
|
||||
expect(connection.unregisterProgressToken).toHaveBeenCalledWith('mock-progress-token');
|
||||
jest.useRealTimers();
|
||||
});
|
||||
});
|
||||
|
||||
describe('onProgress callback', () => {
|
||||
it('does not register a progress listener when onProgress is not provided', async () => {
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
});
|
||||
|
||||
expect(connection.on).not.toHaveBeenCalledWith('progress', expect.any(Function));
|
||||
});
|
||||
|
||||
it('registers a progress listener when onProgress is provided', async () => {
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const onProgress = jest.fn();
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
onProgress,
|
||||
});
|
||||
|
||||
expect(connection.on).toHaveBeenCalledWith('progress', expect.any(Function));
|
||||
});
|
||||
|
||||
it('calls onProgress with the correct shape when a progress event fires', async () => {
|
||||
const emitter = new EventEmitter();
|
||||
const connection = buildMockConnection({
|
||||
on: jest.fn((event, handler) => emitter.on(event, handler)),
|
||||
off: jest.fn((event, handler) => emitter.off(event, handler)),
|
||||
client: {
|
||||
request: jest.fn().mockImplementation(async () => {
|
||||
emitter.emit('progress', {
|
||||
serverName,
|
||||
progressToken: 'mock-progress-token',
|
||||
progress: 3,
|
||||
total: 10,
|
||||
message: 'Processing...',
|
||||
});
|
||||
return { content: [{ type: 'text', text: 'done' }], isError: false };
|
||||
}),
|
||||
},
|
||||
});
|
||||
mockAppConnections(connection);
|
||||
|
||||
const onProgress = jest.fn();
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
onProgress,
|
||||
});
|
||||
|
||||
expect(onProgress).toHaveBeenCalledTimes(1);
|
||||
expect(onProgress).toHaveBeenCalledWith({
|
||||
progressToken: 'mock-progress-token',
|
||||
progress: 3,
|
||||
total: 10,
|
||||
message: 'Processing...',
|
||||
serverName,
|
||||
});
|
||||
});
|
||||
|
||||
it('calls onProgress multiple times as progress events accumulate', async () => {
|
||||
const emitter = new EventEmitter();
|
||||
const connection = buildMockConnection({
|
||||
on: jest.fn((event, handler) => emitter.on(event, handler)),
|
||||
off: jest.fn((event, handler) => emitter.off(event, handler)),
|
||||
client: {
|
||||
request: jest.fn().mockImplementation(async () => {
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
emitter.emit('progress', {
|
||||
serverName,
|
||||
progressToken: 'mock-progress-token',
|
||||
progress: i,
|
||||
total: 3,
|
||||
});
|
||||
}
|
||||
return { content: [{ type: 'text', text: 'done' }], isError: false };
|
||||
}),
|
||||
},
|
||||
});
|
||||
mockAppConnections(connection);
|
||||
|
||||
const onProgress = jest.fn();
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
onProgress,
|
||||
});
|
||||
|
||||
expect(onProgress).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('detaches the progress listener after call completes', async () => {
|
||||
jest.useFakeTimers();
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const onProgress = jest.fn();
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
onProgress,
|
||||
});
|
||||
|
||||
jest.advanceTimersByTime(600);
|
||||
expect(connection.off).toHaveBeenCalledWith('progress', expect.any(Function));
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('detaches the progress listener after call throws', async () => {
|
||||
jest.useFakeTimers();
|
||||
const connection = buildMockConnection({
|
||||
client: {
|
||||
request: jest.fn().mockRejectedValue(new Error('fail')),
|
||||
},
|
||||
});
|
||||
mockAppConnections(connection);
|
||||
|
||||
const onProgress = jest.fn();
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await expect(
|
||||
manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
onProgress,
|
||||
}),
|
||||
).rejects.toThrow('fail');
|
||||
|
||||
jest.advanceTimersByTime(600);
|
||||
expect(connection.off).toHaveBeenCalledWith('progress', expect.any(Function));
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('logs each received progress event at info level', async () => {
|
||||
const emitter = new EventEmitter();
|
||||
const connection = buildMockConnection({
|
||||
on: jest.fn((event, handler) => emitter.on(event, handler)),
|
||||
off: jest.fn((event, handler) => emitter.off(event, handler)),
|
||||
client: {
|
||||
request: jest.fn().mockImplementation(async () => {
|
||||
emitter.emit('progress', {
|
||||
serverName,
|
||||
progressToken: 'mock-progress-token',
|
||||
progress: 1,
|
||||
total: 5,
|
||||
message: 'Step 1',
|
||||
});
|
||||
return { content: [{ type: 'text', text: 'done' }], isError: false };
|
||||
}),
|
||||
},
|
||||
});
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
onProgress: jest.fn(),
|
||||
});
|
||||
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Progress event received'),
|
||||
expect.objectContaining({ progress: 1, total: 5, message: 'Step 1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('logs progress token registration at info level', async () => {
|
||||
const connection = buildMockConnection();
|
||||
mockAppConnections(connection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.callTool({
|
||||
user: mockUser as IUser,
|
||||
serverName,
|
||||
toolName: 'test_tool',
|
||||
provider: 'openai',
|
||||
flowManager: mockFlowManager as unknown as Parameters<
|
||||
typeof manager.callTool
|
||||
>[0]['flowManager'],
|
||||
});
|
||||
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Progress token generated and registered'),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ import {
|
|||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import {
|
||||
ResourceListChangedNotificationSchema,
|
||||
ProgressNotificationSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type {
|
||||
|
|
@ -368,6 +371,13 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
}
|
||||
|
||||
// Progress notification state
|
||||
private progressTokenCounter = 0;
|
||||
private activeProgressTokens = new Map<t.ProgressToken, t.ProgressState>();
|
||||
private lastProgressEmit = new Map<t.ProgressToken, number>();
|
||||
private progressCleanupTimers = new Map<t.ProgressToken, NodeJS.Timeout>();
|
||||
private readonly PROGRESS_THROTTLE_MS = 200; // Max 5 updates/second
|
||||
|
||||
setRequestHeaders(headers: Record<string, string> | null): void {
|
||||
if (!headers) {
|
||||
return;
|
||||
|
|
@ -686,6 +696,7 @@ export class MCPConnection extends EventEmitter {
|
|||
});
|
||||
|
||||
this.subscribeToResources();
|
||||
this.subscribeToProgress();
|
||||
}
|
||||
|
||||
private async handleReconnection(): Promise<void> {
|
||||
|
|
@ -767,6 +778,125 @@ export class MCPConnection extends EventEmitter {
|
|||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a unique progress token for tracking tool call progress
|
||||
*/
|
||||
public generateProgressToken(): t.ProgressToken {
|
||||
const userId = this.userId ? `${this.userId}-` : '';
|
||||
return `${userId}${this.serverName}-${++this.progressTokenCounter}-${Date.now()}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a progress token for tracking
|
||||
*/
|
||||
public registerProgressToken(token: t.ProgressToken): void {
|
||||
this.activeProgressTokens.set(token, {
|
||||
token,
|
||||
progress: 0,
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current progress state for a token
|
||||
*/
|
||||
public getProgressState(token: t.ProgressToken): t.ProgressState | undefined {
|
||||
return this.activeProgressTokens.get(token);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregisters a progress token and cleans up associated state
|
||||
*/
|
||||
public unregisterProgressToken(token: t.ProgressToken): void {
|
||||
this.activeProgressTokens.delete(token);
|
||||
this.lastProgressEmit.delete(token);
|
||||
// Cancel any pending cleanup timer for this token
|
||||
const existingTimer = this.progressCleanupTimers.get(token);
|
||||
if (existingTimer) {
|
||||
clearTimeout(existingTimer);
|
||||
this.progressCleanupTimers.delete(token);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if progress should be emitted (rate limiting)
|
||||
*/
|
||||
private shouldEmitProgress(token: t.ProgressToken): boolean {
|
||||
const lastEmit = this.lastProgressEmit.get(token) || 0;
|
||||
const now = Date.now();
|
||||
if (now - lastEmit < this.PROGRESS_THROTTLE_MS) {
|
||||
return false;
|
||||
}
|
||||
this.lastProgressEmit.set(token, now);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribes to progress notifications from MCP server
|
||||
*/
|
||||
private subscribeToProgress(): void {
|
||||
try {
|
||||
this.client.setNotificationHandler(ProgressNotificationSchema, async (notification) => {
|
||||
try {
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Received progress notification:`,
|
||||
notification.params,
|
||||
);
|
||||
const { progressToken, progress, total, message } = notification.params;
|
||||
|
||||
// Validate token
|
||||
if (!this.activeProgressTokens.has(progressToken)) {
|
||||
logger.debug(`${this.getLogPrefix()} Progress for unknown token: ${progressToken}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if (!this.shouldEmitProgress(progressToken) && (!total || progress < total)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update state
|
||||
const newState: t.ProgressState = {
|
||||
token: progressToken,
|
||||
progress,
|
||||
total,
|
||||
message,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
this.activeProgressTokens.set(progressToken, newState);
|
||||
|
||||
// Emit progress event
|
||||
this.emit('progress', {
|
||||
serverName: this.serverName,
|
||||
progressToken,
|
||||
progress,
|
||||
total,
|
||||
message,
|
||||
});
|
||||
|
||||
// Cleanup if complete
|
||||
if (total && progress >= total) {
|
||||
// Clear existing timer if present to prevent duplicates
|
||||
const existingTimer = this.progressCleanupTimers.get(progressToken);
|
||||
if (existingTimer) {
|
||||
clearTimeout(existingTimer);
|
||||
}
|
||||
const timerId = setTimeout(() => {
|
||||
this.activeProgressTokens.delete(progressToken);
|
||||
this.lastProgressEmit.delete(progressToken);
|
||||
this.progressCleanupTimers.delete(progressToken);
|
||||
}, 5000);
|
||||
this.progressCleanupTimers.set(progressToken, timerId);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`${this.getLogPrefix()} Error handling progress:`, error);
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
logger.warn(`${this.getLogPrefix()} Failed to setup progress notifications:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
async connectClient(): Promise<void> {
|
||||
if (this.connectionState === 'connected') {
|
||||
return;
|
||||
|
|
@ -1115,6 +1245,12 @@ export class MCPConnection extends EventEmitter {
|
|||
this.emit('connectionChange', 'disconnected');
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
// Clean up progress tracking state to prevent memory leaks
|
||||
this.activeProgressTokens.clear();
|
||||
this.lastProgressEmit.clear();
|
||||
// Clear all pending cleanup timers
|
||||
this.progressCleanupTimers.forEach((timer) => clearTimeout(timer));
|
||||
this.progressCleanupTimers.clear();
|
||||
if (!resetCycleTracking) {
|
||||
this.recordCycle();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,6 +72,23 @@ export type MCPToolCallResponse =
|
|||
isError?: boolean;
|
||||
};
|
||||
|
||||
export type ProgressToken = string | number;
|
||||
|
||||
export interface ProgressNotification {
|
||||
progressToken: ProgressToken;
|
||||
progress: number;
|
||||
total?: number;
|
||||
message?: string;
|
||||
}
|
||||
|
||||
export interface ProgressState {
|
||||
token: ProgressToken;
|
||||
progress: number;
|
||||
total?: number;
|
||||
message?: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export type Provider =
|
||||
| 'google'
|
||||
| 'anthropic'
|
||||
|
|
@ -142,6 +159,10 @@ export type FormattedContentResult = [string, Artifacts | undefined];
|
|||
|
||||
export type ImageFormatter = (item: ImageContent) => FormattedContent;
|
||||
|
||||
/**
|
||||
* Tool response for MCP tools - must be a proper two-tuple for LangChain's content_and_artifact format.
|
||||
* Progress notifications are handled separately via SSE events, not attached to the response.
|
||||
*/
|
||||
export type FormattedToolResponse = FormattedContentResult;
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -915,6 +915,26 @@ class GenerationJobManagerClass {
|
|||
await this.eventTransport.emitChunk(streamId, event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit an ephemeral event directly to live subscribers, bypassing the
|
||||
* earlyEventBuffer, Redis persistence, and trackUserMessage side-effects.
|
||||
*
|
||||
* Use for fire-and-forget events like MCP tool-call progress that must not
|
||||
* be replayed to reconnecting clients. Silently drops when no subscriber
|
||||
* is connected or the job has been aborted.
|
||||
*/
|
||||
async emitTransientEvent(streamId: string, event: t.ServerSentEvent): Promise<void> {
|
||||
const runtime = this.runtimeState.get(streamId);
|
||||
if (!runtime || runtime.abortController.signal.aborted) {
|
||||
return;
|
||||
}
|
||||
if (!runtime.hasSubscriber) {
|
||||
return;
|
||||
}
|
||||
// Emit directly to transport, bypassing appendChunk and earlyEventBuffer
|
||||
await this.eventTransport.emitChunk(streamId, event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract and save run step from event data.
|
||||
* The data is already the run step object from the event payload.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,119 @@
|
|||
import type * as t from '~/types';
|
||||
|
||||
interface RuntimeState {
|
||||
abortController: AbortController;
|
||||
hasSubscriber: boolean;
|
||||
}
|
||||
|
||||
class GenerationJobManagerStub {
|
||||
runtimeState = new Map<string, RuntimeState>();
|
||||
eventTransport = { emitChunk: jest.fn() };
|
||||
|
||||
async emitTransientEvent(streamId: string, event: t.ServerSentEvent): Promise<void> {
|
||||
const runtime = this.runtimeState.get(streamId);
|
||||
if (!runtime || runtime.abortController.signal.aborted) {
|
||||
return;
|
||||
}
|
||||
if (!runtime.hasSubscriber) {
|
||||
return;
|
||||
}
|
||||
await this.eventTransport.emitChunk(streamId, event);
|
||||
}
|
||||
}
|
||||
|
||||
function makeRuntime(overrides: Partial<RuntimeState> = {}): RuntimeState {
|
||||
return {
|
||||
abortController: new AbortController(),
|
||||
hasSubscriber: true,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe('GenerationJobManager - emitTransientEvent', () => {
|
||||
let manager: GenerationJobManagerStub;
|
||||
|
||||
const streamId = 'stream-abc-123';
|
||||
const progressEvent: t.ServerSentEvent = {
|
||||
event: 'progress',
|
||||
data: { progress: 2, total: 5, message: 'Working…', toolCallId: 'call-1' },
|
||||
} as unknown as t.ServerSentEvent;
|
||||
|
||||
beforeEach(() => {
|
||||
manager = new GenerationJobManagerStub();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('emits to transport when runtime exists and has subscriber', async () => {
|
||||
manager.runtimeState.set(streamId, makeRuntime());
|
||||
|
||||
await manager.emitTransientEvent(streamId, progressEvent);
|
||||
|
||||
expect(manager.eventTransport.emitChunk).toHaveBeenCalledTimes(1);
|
||||
expect(manager.eventTransport.emitChunk).toHaveBeenCalledWith(streamId, progressEvent);
|
||||
});
|
||||
|
||||
it('silently drops when streamId has no runtime entry', async () => {
|
||||
await manager.emitTransientEvent('unknown-stream', progressEvent);
|
||||
|
||||
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('silently drops when job has been aborted', async () => {
|
||||
const runtime = makeRuntime();
|
||||
runtime.abortController.abort();
|
||||
manager.runtimeState.set(streamId, runtime);
|
||||
|
||||
await manager.emitTransientEvent(streamId, progressEvent);
|
||||
|
||||
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('silently drops when there is no active subscriber', async () => {
|
||||
manager.runtimeState.set(streamId, makeRuntime({ hasSubscriber: false }));
|
||||
|
||||
await manager.emitTransientEvent(streamId, progressEvent);
|
||||
|
||||
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not persist the event (calls emitChunk directly, not appendChunk)', async () => {
|
||||
const appendChunk = jest.fn();
|
||||
(manager as unknown as Record<string, unknown>).appendChunk = appendChunk;
|
||||
manager.runtimeState.set(streamId, makeRuntime());
|
||||
|
||||
await manager.emitTransientEvent(streamId, progressEvent);
|
||||
|
||||
expect(appendChunk).not.toHaveBeenCalled();
|
||||
expect(manager.eventTransport.emitChunk).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('forwards any event shape without mutation', async () => {
|
||||
manager.runtimeState.set(streamId, makeRuntime());
|
||||
|
||||
const customEvent = { event: 'progress', data: { foo: 'bar' } } as unknown as t.ServerSentEvent;
|
||||
await manager.emitTransientEvent(streamId, customEvent);
|
||||
|
||||
expect(manager.eventTransport.emitChunk).toHaveBeenCalledWith(streamId, customEvent);
|
||||
});
|
||||
|
||||
it('handles transport throwing without crashing the caller', async () => {
|
||||
manager.runtimeState.set(streamId, makeRuntime());
|
||||
manager.eventTransport.emitChunk.mockRejectedValueOnce(new Error('transport error'));
|
||||
|
||||
await expect(manager.emitTransientEvent(streamId, progressEvent)).rejects.toThrow(
|
||||
'transport error',
|
||||
);
|
||||
});
|
||||
|
||||
it('does not emit after abort even if subscriber flag is still true', async () => {
|
||||
const runtime = makeRuntime({ hasSubscriber: true });
|
||||
manager.runtimeState.set(streamId, runtime);
|
||||
|
||||
// Abort happens between registration and emit
|
||||
runtime.abortController.abort();
|
||||
|
||||
await manager.emitTransientEvent(streamId, progressEvent);
|
||||
|
||||
expect(manager.eventTransport.emitChunk).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,25 @@
|
|||
import type { Response as ServerResponse } from 'express';
|
||||
import type { ServerSentEvent } from '~/types';
|
||||
|
||||
/**
|
||||
* Safely writes to a server response, handling disconnected clients.
|
||||
* @param res - The server response.
|
||||
* @param data - The data to write.
|
||||
* @returns true if write succeeded, false otherwise.
|
||||
*/
|
||||
function safeWrite(res: ServerResponse, data: string): boolean {
|
||||
try {
|
||||
if (!res.writable) {
|
||||
return false;
|
||||
}
|
||||
res.write(data);
|
||||
return true;
|
||||
} catch {
|
||||
// Client may have disconnected - log but don't crash
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a Server-Sent Event to the client.
|
||||
* Empty-string StreamEvent data is silently dropped.
|
||||
|
|
@ -9,7 +28,7 @@ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
|
|||
if ('data' in event && typeof event.data === 'string' && event.data.length === 0) {
|
||||
return;
|
||||
}
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
safeWrite(res, `event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -18,6 +37,29 @@ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
|
|||
* @param message - The error message.
|
||||
*/
|
||||
export function handleError(res: ServerResponse, message: string): void {
|
||||
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
|
||||
res.end();
|
||||
safeWrite(res, `event: error\ndata: ${JSON.stringify(message)}\n\n`);
|
||||
try {
|
||||
res.end();
|
||||
} catch {
|
||||
// Client may have disconnected
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends progress notification in Server Sent Events format.
|
||||
* @param res - The server response.
|
||||
* @param progressData - Progress notification data.
|
||||
*/
|
||||
export function sendProgress(
|
||||
res: ServerResponse,
|
||||
progressData: {
|
||||
progressToken: string | number;
|
||||
progress: number;
|
||||
total?: number;
|
||||
message?: string;
|
||||
serverName?: string;
|
||||
toolCallId?: string; // Tool call ID for matching progress to specific tool call
|
||||
},
|
||||
): void {
|
||||
safeWrite(res, `event: progress\ndata: ${JSON.stringify(progressData)}\n\n`);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue