diff --git a/client/src/hooks/SSE/useStepHandler.ts b/client/src/hooks/SSE/useStepHandler.ts index 2fc987cf9c..11d2cdd839 100644 --- a/client/src/hooks/SSE/useStepHandler.ts +++ b/client/src/hooks/SSE/useStepHandler.ts @@ -1,6 +1,6 @@ import { useCallback, useRef } from 'react'; import { StepTypes, ContentTypes, ToolCallTypes } from 'librechat-data-provider'; -import type { Agents, PartMetadata, TMessage } from 'librechat-data-provider'; +import type { Agents, PartMetadata, TMessage, TMessageContentParts } from 'librechat-data-provider'; import { getNonEmptyValue } from 'librechat-data-provider'; type TUseStepHandler = { @@ -13,8 +13,15 @@ type TStepEvent = { data: Agents.MessageDeltaEvent | Agents.RunStep | Agents.ToolEndEvent; }; +type AllContentTypes = + | ContentTypes.TEXT + | ContentTypes.TOOL_CALL + | ContentTypes.IMAGE_FILE + | ContentTypes.IMAGE_URL + | ContentTypes.ERROR; + export default function useStepHandler({ setMessages, getMessages }: TUseStepHandler) { - const toolCallIdMap = useRef(new Map()); + const toolCallIdMap = useRef(new Map()); const messageMap = useRef(new Map()); const stepMap = useRef(new Map()); @@ -24,18 +31,21 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan contentPart: Agents.MessageContentComplex, finalUpdate = false, ) => { - if (!contentPart.type) { + const contentType = contentPart.type ?? ''; + if (!contentType) { console.warn('No content type found in content part'); return message; } - const updatedContent = [...(message.content || [])]; + const updatedContent = [...(message.content || [])] as Array< + Partial | undefined + >; if (!updatedContent[index]) { - updatedContent[index] = { type: contentPart.type }; + updatedContent[index] = { type: contentPart.type as AllContentTypes }; } if ( - contentPart.type.startsWith(ContentTypes.TEXT) && + contentType.startsWith(ContentTypes.TEXT) && ContentTypes.TEXT in contentPart && typeof contentPart.text === 'string' ) { @@ -44,21 +54,25 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan type: ContentTypes.TEXT, text: (currentContent.text || '') + contentPart.text, }; - } else if (contentPart.type === 'image_url' && 'image_url' in contentPart) { - const currentContent = updatedContent[index] as { type: 'image_url'; image_url: string }; + } else if (contentType === ContentTypes.IMAGE_URL && 'image_url' in contentPart) { + const currentContent = updatedContent[index] as { + type: ContentTypes.IMAGE_URL; + image_url: string; + }; updatedContent[index] = { ...currentContent, }; - } else if (contentPart.type === ContentTypes.TOOL_CALL && 'tool_call' in contentPart) { - const existingContent = updatedContent[index] as Agents.ToolCallContent; + } else if (contentType === ContentTypes.TOOL_CALL && 'tool_call' in contentPart) { + const existingContent = updatedContent[index] as Agents.ToolCallContent | undefined; + const existingToolCall = existingContent?.tool_call; + const toolCallArgs = (contentPart.tool_call.args as unknown as string | undefined) ?? ''; const args = finalUpdate ? contentPart.tool_call.args - : (existingContent?.tool_call?.args || '') + (contentPart.tool_call.args || ''); + : (existingToolCall?.args ?? '') + toolCallArgs; - const id = getNonEmptyValue([contentPart.tool_call.id, existingContent?.tool_call?.id]) ?? ''; - const name = - getNonEmptyValue([contentPart.tool_call.name, existingContent?.tool_call?.name]) ?? ''; + const id = getNonEmptyValue([contentPart.tool_call.id, existingToolCall?.id]) ?? ''; + const name = getNonEmptyValue([contentPart.tool_call.name, existingToolCall?.name]) ?? ''; const newToolCall: Agents.ToolCall & PartMetadata = { id, @@ -78,7 +92,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan }; } - return { ...message, content: updatedContent }; + return { ...message, content: updatedContent as TMessageContentParts[] }; }; return useCallback( @@ -87,7 +101,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan if (event === 'on_run_step') { const runStep = data as Agents.RunStep; - const responseMessageId = runStep.runId; + const responseMessageId = runStep.runId ?? ''; if (!responseMessageId) { console.warn('No message id found in run step event'); return; @@ -98,12 +112,12 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan if (!response) { const responseMessage = messages[messages.length - 1] as TMessage; - const userMessage = messages[messages.length - 2]; + const userMessage = messages[messages.length - 2] as TMessage | null; response = { ...responseMessage, - parentMessageId: userMessage?.messageId, - conversationId: userMessage?.conversationId, + parentMessageId: userMessage?.messageId ?? '', + conversationId: userMessage?.conversationId ?? '', messageId: responseMessageId, content: [], }; @@ -115,20 +129,23 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan // Store tool call IDs if present if (runStep.stepDetails.type === StepTypes.TOOL_CALLS) { runStep.stepDetails.tool_calls.forEach((toolCall) => { - if ('id' in toolCall && toolCall.id) { - toolCallIdMap.current.set(runStep.id, toolCall.id); + const toolCallId = toolCall.id ?? ''; + if ('id' in toolCall && toolCallId) { + toolCallIdMap.current.set(runStep.id, toolCallId); } }); } } else if (event === 'on_message_delta') { const messageDelta = data as Agents.MessageDeltaEvent; const runStep = stepMap.current.get(messageDelta.id); - if (!runStep || !runStep.runId) { + const responseMessageId = runStep?.runId ?? ''; + + if (!runStep || !responseMessageId) { console.warn('No run step or runId found for message delta event'); return; } - const response = messageMap.current.get(runStep.runId); + const response = messageMap.current.get(responseMessageId); if (response && messageDelta.delta.content) { const contentPart = Array.isArray(messageDelta.delta.content) ? messageDelta.delta.content[0] @@ -136,19 +153,21 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan const updatedResponse = updateContent(response, runStep.index, contentPart); - messageMap.current.set(runStep.runId, updatedResponse); + messageMap.current.set(responseMessageId, updatedResponse); const currentMessages = getMessages() || []; setMessages([...currentMessages.slice(0, -1), updatedResponse]); } } else if (event === 'on_run_step_delta') { const runStepDelta = data as Agents.RunStepDeltaEvent; const runStep = stepMap.current.get(runStepDelta.id); - if (!runStep || !runStep.runId) { + const responseMessageId = runStep?.runId ?? ''; + + if (!runStep || !responseMessageId) { console.warn('No run step or runId found for run step delta event'); return; } - const response = messageMap.current.get(runStep.runId); + const response = messageMap.current.get(responseMessageId); if ( response && runStepDelta.delta.type === StepTypes.TOOL_CALLS && @@ -157,13 +176,13 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan let updatedResponse = { ...response }; runStepDelta.delta.tool_calls.forEach((toolCallDelta) => { - const toolCallId = toolCallIdMap.current.get(runStepDelta.id) || ''; + const toolCallId = toolCallIdMap.current.get(runStepDelta.id) ?? ''; const contentPart: Agents.MessageContentComplex = { type: ContentTypes.TOOL_CALL, tool_call: { name: toolCallDelta.name ?? '', - args: toolCallDelta.args || '', + args: toolCallDelta.args ?? '', id: toolCallId, }, }; @@ -171,7 +190,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan updatedResponse = updateContent(updatedResponse, runStep.index, contentPart); }); - messageMap.current.set(runStep.runId, updatedResponse); + messageMap.current.set(responseMessageId, updatedResponse); const updatedMessages = messages.map((msg) => msg.messageId === runStep.runId ? updatedResponse : msg, ); @@ -184,12 +203,14 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan const { id: stepId } = result; const runStep = stepMap.current.get(stepId); - if (!runStep || !runStep.runId) { + const responseMessageId = runStep?.runId ?? ''; + + if (!runStep || !responseMessageId) { console.warn('No run step or runId found for completed tool call event'); return; } - const response = messageMap.current.get(runStep.runId); + const response = messageMap.current.get(responseMessageId); if (response) { let updatedResponse = { ...response }; @@ -200,7 +221,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan updatedResponse = updateContent(updatedResponse, runStep.index, contentPart, true); - messageMap.current.set(runStep.runId, updatedResponse); + messageMap.current.set(responseMessageId, updatedResponse); const updatedMessages = messages.map((msg) => msg.messageId === runStep.runId ? updatedResponse : msg, ); diff --git a/packages/data-provider/src/types/agents.ts b/packages/data-provider/src/types/agents.ts index 04f668be63..6c4d627063 100644 --- a/packages/data-provider/src/types/agents.ts +++ b/packages/data-provider/src/types/agents.ts @@ -13,7 +13,7 @@ export namespace Agents { }; export type MessageContentImageUrl = { - type: 'image_url'; + type: ContentTypes.IMAGE_URL; image_url: string | { url: string; detail?: ImageDetail }; }; @@ -21,7 +21,7 @@ export namespace Agents { | MessageContentText | MessageContentImageUrl // eslint-disable-next-line @typescript-eslint/no-explicit-any - | (Record & { type?: ContentTypes | 'image_url' | 'text_delta' | string }) + | (Record & { type?: ContentTypes | string }) // eslint-disable-next-line @typescript-eslint/no-explicit-any | (Record & { type?: never }); @@ -38,7 +38,7 @@ export namespace Agents { /** The arguments to the tool call */ // eslint-disable-next-line @typescript-eslint/no-explicit-any - args: string | Record; + args?: string | Record; /** If provided, an identifier associated with the tool call */ id?: string; @@ -50,14 +50,14 @@ export namespace Agents { /** The Step Id of the Tool Call */ id: string; /** The Completed Tool Call */ - tool_call: ToolCall; + tool_call?: ToolCall; /** The content index of the tool call */ index: number; }; export type ToolCallContent = { type: ContentTypes.TOOL_CALL; - tool_call: ToolCall; + tool_call?: ToolCall; }; /** @@ -215,5 +215,5 @@ export namespace Agents { */ content?: MessageContentComplex[]; } - export type ContentType = ContentTypes.TEXT | 'image_url' | string; + export type ContentType = ContentTypes.TEXT | ContentTypes.IMAGE_URL | string; } diff --git a/packages/data-provider/src/types/runs.ts b/packages/data-provider/src/types/runs.ts index 4a28a32032..2d66d58695 100644 --- a/packages/data-provider/src/types/runs.ts +++ b/packages/data-provider/src/types/runs.ts @@ -1,5 +1,6 @@ export enum ContentTypes { TEXT = 'text', + TEXT_DELTA = 'text_delta', TOOL_CALL = 'tool_call', IMAGE_FILE = 'image_file', IMAGE_URL = 'image_url',