fix: types for step handler

This commit is contained in:
Danny Avila 2024-09-03 07:00:18 -04:00
parent fbead7de40
commit e478e10c96
No known key found for this signature in database
GPG key ID: 2DD9CC89B9B50364
3 changed files with 60 additions and 38 deletions

View file

@ -1,6 +1,6 @@
import { useCallback, useRef } from 'react'; import { useCallback, useRef } from 'react';
import { StepTypes, ContentTypes, ToolCallTypes } from 'librechat-data-provider'; import { StepTypes, ContentTypes, ToolCallTypes } from 'librechat-data-provider';
import type { Agents, PartMetadata, TMessage } from 'librechat-data-provider'; import type { Agents, PartMetadata, TMessage, TMessageContentParts } from 'librechat-data-provider';
import { getNonEmptyValue } from 'librechat-data-provider'; import { getNonEmptyValue } from 'librechat-data-provider';
type TUseStepHandler = { type TUseStepHandler = {
@ -13,8 +13,15 @@ type TStepEvent = {
data: Agents.MessageDeltaEvent | Agents.RunStep | Agents.ToolEndEvent; data: Agents.MessageDeltaEvent | Agents.RunStep | Agents.ToolEndEvent;
}; };
type AllContentTypes =
| ContentTypes.TEXT
| ContentTypes.TOOL_CALL
| ContentTypes.IMAGE_FILE
| ContentTypes.IMAGE_URL
| ContentTypes.ERROR;
export default function useStepHandler({ setMessages, getMessages }: TUseStepHandler) { export default function useStepHandler({ setMessages, getMessages }: TUseStepHandler) {
const toolCallIdMap = useRef(new Map<string, string>()); const toolCallIdMap = useRef(new Map<string, string | undefined>());
const messageMap = useRef(new Map<string, TMessage>()); const messageMap = useRef(new Map<string, TMessage>());
const stepMap = useRef(new Map<string, Agents.RunStep>()); const stepMap = useRef(new Map<string, Agents.RunStep>());
@ -24,18 +31,21 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
contentPart: Agents.MessageContentComplex, contentPart: Agents.MessageContentComplex,
finalUpdate = false, finalUpdate = false,
) => { ) => {
if (!contentPart.type) { const contentType = contentPart.type ?? '';
if (!contentType) {
console.warn('No content type found in content part'); console.warn('No content type found in content part');
return message; return message;
} }
const updatedContent = [...(message.content || [])]; const updatedContent = [...(message.content || [])] as Array<
Partial<TMessageContentParts> | undefined
>;
if (!updatedContent[index]) { if (!updatedContent[index]) {
updatedContent[index] = { type: contentPart.type }; updatedContent[index] = { type: contentPart.type as AllContentTypes };
} }
if ( if (
contentPart.type.startsWith(ContentTypes.TEXT) && contentType.startsWith(ContentTypes.TEXT) &&
ContentTypes.TEXT in contentPart && ContentTypes.TEXT in contentPart &&
typeof contentPart.text === 'string' typeof contentPart.text === 'string'
) { ) {
@ -44,21 +54,25 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
type: ContentTypes.TEXT, type: ContentTypes.TEXT,
text: (currentContent.text || '') + contentPart.text, text: (currentContent.text || '') + contentPart.text,
}; };
} else if (contentPart.type === 'image_url' && 'image_url' in contentPart) { } else if (contentType === ContentTypes.IMAGE_URL && 'image_url' in contentPart) {
const currentContent = updatedContent[index] as { type: 'image_url'; image_url: string }; const currentContent = updatedContent[index] as {
type: ContentTypes.IMAGE_URL;
image_url: string;
};
updatedContent[index] = { updatedContent[index] = {
...currentContent, ...currentContent,
}; };
} else if (contentPart.type === ContentTypes.TOOL_CALL && 'tool_call' in contentPart) { } else if (contentType === ContentTypes.TOOL_CALL && 'tool_call' in contentPart) {
const existingContent = updatedContent[index] as Agents.ToolCallContent; const existingContent = updatedContent[index] as Agents.ToolCallContent | undefined;
const existingToolCall = existingContent?.tool_call;
const toolCallArgs = (contentPart.tool_call.args as unknown as string | undefined) ?? '';
const args = finalUpdate const args = finalUpdate
? contentPart.tool_call.args ? contentPart.tool_call.args
: (existingContent?.tool_call?.args || '') + (contentPart.tool_call.args || ''); : (existingToolCall?.args ?? '') + toolCallArgs;
const id = getNonEmptyValue([contentPart.tool_call.id, existingContent?.tool_call?.id]) ?? ''; const id = getNonEmptyValue([contentPart.tool_call.id, existingToolCall?.id]) ?? '';
const name = const name = getNonEmptyValue([contentPart.tool_call.name, existingToolCall?.name]) ?? '';
getNonEmptyValue([contentPart.tool_call.name, existingContent?.tool_call?.name]) ?? '';
const newToolCall: Agents.ToolCall & PartMetadata = { const newToolCall: Agents.ToolCall & PartMetadata = {
id, id,
@ -78,7 +92,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
}; };
} }
return { ...message, content: updatedContent }; return { ...message, content: updatedContent as TMessageContentParts[] };
}; };
return useCallback( return useCallback(
@ -87,7 +101,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
if (event === 'on_run_step') { if (event === 'on_run_step') {
const runStep = data as Agents.RunStep; const runStep = data as Agents.RunStep;
const responseMessageId = runStep.runId; const responseMessageId = runStep.runId ?? '';
if (!responseMessageId) { if (!responseMessageId) {
console.warn('No message id found in run step event'); console.warn('No message id found in run step event');
return; return;
@ -98,12 +112,12 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
if (!response) { if (!response) {
const responseMessage = messages[messages.length - 1] as TMessage; const responseMessage = messages[messages.length - 1] as TMessage;
const userMessage = messages[messages.length - 2]; const userMessage = messages[messages.length - 2] as TMessage | null;
response = { response = {
...responseMessage, ...responseMessage,
parentMessageId: userMessage?.messageId, parentMessageId: userMessage?.messageId ?? '',
conversationId: userMessage?.conversationId, conversationId: userMessage?.conversationId ?? '',
messageId: responseMessageId, messageId: responseMessageId,
content: [], content: [],
}; };
@ -115,20 +129,23 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
// Store tool call IDs if present // Store tool call IDs if present
if (runStep.stepDetails.type === StepTypes.TOOL_CALLS) { if (runStep.stepDetails.type === StepTypes.TOOL_CALLS) {
runStep.stepDetails.tool_calls.forEach((toolCall) => { runStep.stepDetails.tool_calls.forEach((toolCall) => {
if ('id' in toolCall && toolCall.id) { const toolCallId = toolCall.id ?? '';
toolCallIdMap.current.set(runStep.id, toolCall.id); if ('id' in toolCall && toolCallId) {
toolCallIdMap.current.set(runStep.id, toolCallId);
} }
}); });
} }
} else if (event === 'on_message_delta') { } else if (event === 'on_message_delta') {
const messageDelta = data as Agents.MessageDeltaEvent; const messageDelta = data as Agents.MessageDeltaEvent;
const runStep = stepMap.current.get(messageDelta.id); const runStep = stepMap.current.get(messageDelta.id);
if (!runStep || !runStep.runId) { const responseMessageId = runStep?.runId ?? '';
if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for message delta event'); console.warn('No run step or runId found for message delta event');
return; return;
} }
const response = messageMap.current.get(runStep.runId); const response = messageMap.current.get(responseMessageId);
if (response && messageDelta.delta.content) { if (response && messageDelta.delta.content) {
const contentPart = Array.isArray(messageDelta.delta.content) const contentPart = Array.isArray(messageDelta.delta.content)
? messageDelta.delta.content[0] ? messageDelta.delta.content[0]
@ -136,19 +153,21 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
const updatedResponse = updateContent(response, runStep.index, contentPart); const updatedResponse = updateContent(response, runStep.index, contentPart);
messageMap.current.set(runStep.runId, updatedResponse); messageMap.current.set(responseMessageId, updatedResponse);
const currentMessages = getMessages() || []; const currentMessages = getMessages() || [];
setMessages([...currentMessages.slice(0, -1), updatedResponse]); setMessages([...currentMessages.slice(0, -1), updatedResponse]);
} }
} else if (event === 'on_run_step_delta') { } else if (event === 'on_run_step_delta') {
const runStepDelta = data as Agents.RunStepDeltaEvent; const runStepDelta = data as Agents.RunStepDeltaEvent;
const runStep = stepMap.current.get(runStepDelta.id); const runStep = stepMap.current.get(runStepDelta.id);
if (!runStep || !runStep.runId) { const responseMessageId = runStep?.runId ?? '';
if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for run step delta event'); console.warn('No run step or runId found for run step delta event');
return; return;
} }
const response = messageMap.current.get(runStep.runId); const response = messageMap.current.get(responseMessageId);
if ( if (
response && response &&
runStepDelta.delta.type === StepTypes.TOOL_CALLS && runStepDelta.delta.type === StepTypes.TOOL_CALLS &&
@ -157,13 +176,13 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
let updatedResponse = { ...response }; let updatedResponse = { ...response };
runStepDelta.delta.tool_calls.forEach((toolCallDelta) => { runStepDelta.delta.tool_calls.forEach((toolCallDelta) => {
const toolCallId = toolCallIdMap.current.get(runStepDelta.id) || ''; const toolCallId = toolCallIdMap.current.get(runStepDelta.id) ?? '';
const contentPart: Agents.MessageContentComplex = { const contentPart: Agents.MessageContentComplex = {
type: ContentTypes.TOOL_CALL, type: ContentTypes.TOOL_CALL,
tool_call: { tool_call: {
name: toolCallDelta.name ?? '', name: toolCallDelta.name ?? '',
args: toolCallDelta.args || '', args: toolCallDelta.args ?? '',
id: toolCallId, id: toolCallId,
}, },
}; };
@ -171,7 +190,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
updatedResponse = updateContent(updatedResponse, runStep.index, contentPart); updatedResponse = updateContent(updatedResponse, runStep.index, contentPart);
}); });
messageMap.current.set(runStep.runId, updatedResponse); messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) => const updatedMessages = messages.map((msg) =>
msg.messageId === runStep.runId ? updatedResponse : msg, msg.messageId === runStep.runId ? updatedResponse : msg,
); );
@ -184,12 +203,14 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
const { id: stepId } = result; const { id: stepId } = result;
const runStep = stepMap.current.get(stepId); const runStep = stepMap.current.get(stepId);
if (!runStep || !runStep.runId) { const responseMessageId = runStep?.runId ?? '';
if (!runStep || !responseMessageId) {
console.warn('No run step or runId found for completed tool call event'); console.warn('No run step or runId found for completed tool call event');
return; return;
} }
const response = messageMap.current.get(runStep.runId); const response = messageMap.current.get(responseMessageId);
if (response) { if (response) {
let updatedResponse = { ...response }; let updatedResponse = { ...response };
@ -200,7 +221,7 @@ export default function useStepHandler({ setMessages, getMessages }: TUseStepHan
updatedResponse = updateContent(updatedResponse, runStep.index, contentPart, true); updatedResponse = updateContent(updatedResponse, runStep.index, contentPart, true);
messageMap.current.set(runStep.runId, updatedResponse); messageMap.current.set(responseMessageId, updatedResponse);
const updatedMessages = messages.map((msg) => const updatedMessages = messages.map((msg) =>
msg.messageId === runStep.runId ? updatedResponse : msg, msg.messageId === runStep.runId ? updatedResponse : msg,
); );

View file

@ -13,7 +13,7 @@ export namespace Agents {
}; };
export type MessageContentImageUrl = { export type MessageContentImageUrl = {
type: 'image_url'; type: ContentTypes.IMAGE_URL;
image_url: string | { url: string; detail?: ImageDetail }; image_url: string | { url: string; detail?: ImageDetail };
}; };
@ -21,7 +21,7 @@ export namespace Agents {
| MessageContentText | MessageContentText
| MessageContentImageUrl | MessageContentImageUrl
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
| (Record<string, any> & { type?: ContentTypes | 'image_url' | 'text_delta' | string }) | (Record<string, any> & { type?: ContentTypes | string })
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
| (Record<string, any> & { type?: never }); | (Record<string, any> & { type?: never });
@ -38,7 +38,7 @@ export namespace Agents {
/** The arguments to the tool call */ /** The arguments to the tool call */
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
args: string | Record<string, any>; args?: string | Record<string, any>;
/** If provided, an identifier associated with the tool call */ /** If provided, an identifier associated with the tool call */
id?: string; id?: string;
@ -50,14 +50,14 @@ export namespace Agents {
/** The Step Id of the Tool Call */ /** The Step Id of the Tool Call */
id: string; id: string;
/** The Completed Tool Call */ /** The Completed Tool Call */
tool_call: ToolCall; tool_call?: ToolCall;
/** The content index of the tool call */ /** The content index of the tool call */
index: number; index: number;
}; };
export type ToolCallContent = { export type ToolCallContent = {
type: ContentTypes.TOOL_CALL; type: ContentTypes.TOOL_CALL;
tool_call: ToolCall; tool_call?: ToolCall;
}; };
/** /**
@ -215,5 +215,5 @@ export namespace Agents {
*/ */
content?: MessageContentComplex[]; content?: MessageContentComplex[];
} }
export type ContentType = ContentTypes.TEXT | 'image_url' | string; export type ContentType = ContentTypes.TEXT | ContentTypes.IMAGE_URL | string;
} }

View file

@ -1,5 +1,6 @@
export enum ContentTypes { export enum ContentTypes {
TEXT = 'text', TEXT = 'text',
TEXT_DELTA = 'text_delta',
TOOL_CALL = 'tool_call', TOOL_CALL = 'tool_call',
IMAGE_FILE = 'image_file', IMAGE_FILE = 'image_file',
IMAGE_URL = 'image_url', IMAGE_URL = 'image_url',