refactor: Integrate streamId handling to improve resumable attachment functionality

- Added streamId parameter to various functions to support resumable mode in tool loading and memory processing.
- Updated related methods to route attachments and responses through the job emitter when a streamId is present, and to write SSE events directly otherwise (sketched below).
- Improved logging and attachment management to accommodate both standard and resumable modes.
Danny Avila 2025-12-12 20:06:55 -05:00
parent 3dee970d22
commit db9399aae6
6 changed files with 56 additions and 19 deletions
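
At its core, the change is a single dispatch: when a resumable streamId is present, attachment events are emitted through the GenerationJobManager so a reconnecting client can replay them; otherwise they are written to the response as plain SSE frames. A minimal TypeScript sketch of that pattern follows; the JobManager interface is a hypothetical stand-in for @librechat/api's GenerationJobManager, of which only emitChunk (as used in the diffs below) is assumed.

import type { ServerResponse } from 'node:http';

// Hypothetical stand-in for the GenerationJobManager from @librechat/api;
// only emitChunk, as it appears in the diffs below, is assumed here.
interface JobManager {
  emitChunk(streamId: string, chunk: { event: string; data: unknown }): void;
}

function writeAttachment(
  res: ServerResponse,
  streamId: string | null,
  attachment: Record<string, unknown>,
  jobs: JobManager,
): void {
  if (streamId) {
    // Resumable mode: emit via the job manager so a client that
    // reconnects mid-generation still receives the attachment.
    jobs.emitChunk(streamId, { event: 'attachment', data: attachment });
  } else {
    // Standard mode: write the SSE frame straight to the open response.
    res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
  }
}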

[File 1/6]

@@ -594,10 +594,12 @@ class AgentClient extends BaseClient {
     const userId = this.options.req.user.id + '';
     const messageId = this.responseMessageId + '';
     const conversationId = this.conversationId + '';
+    const streamId = this.options.req?._resumableStreamId || null;
     const [withoutKeys, processMemory] = await createMemoryProcessor({
       userId,
       config,
       messageId,
+      streamId,
       conversationId,
       memoryMethods: {
         setMemory: db.setMemory,

[File 2/6]

@@ -25,9 +25,11 @@ const { logViolation } = require('~/cache');
 const db = require('~/models');
 
 /**
- * @param {AbortSignal} signal
+ * Creates a tool loader function for the agent.
+ * @param {AbortSignal} signal - The abort signal
+ * @param {string | null} [streamId] - The stream ID for resumable mode
  */
-function createToolLoader(signal) {
+function createToolLoader(signal, streamId = null) {
   /**
    * @param {object} params
    * @param {ServerRequest} params.req
@@ -52,6 +54,7 @@
         agent,
         signal,
         tool_resources,
+        streamId,
       });
     } catch (error) {
       logger.error('Error loading tools for agent ' + agentId, error);
@@ -108,7 +111,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
   const agentConfigs = new Map();
   const allowedProviders = new Set(appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders);
-  const loadTools = createToolLoader(signal);
+  const loadTools = createToolLoader(signal, streamId);
 
   /** @type {Array<MongoFile>} */
   const requestFiles = req.body.files ?? [];
   /** @type {string} */

[File 3/6]

@@ -369,7 +369,15 @@ async function processRequiredActions(client, requiredActions) {
  * @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
  * @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
  */
-async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
+async function loadAgentTools({
+  req,
+  res,
+  agent,
+  signal,
+  tool_resources,
+  openAIApiKey,
+  streamId = null,
+}) {
   if (!agent.tools || agent.tools.length === 0) {
     return {};
   } else if (
@@ -422,7 +430,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
   /** @type {ReturnType<typeof createOnSearchResults>} */
   let webSearchCallbacks;
   if (includesWebSearch) {
-    webSearchCallbacks = createOnSearchResults(res);
+    webSearchCallbacks = createOnSearchResults(res, streamId);
   }
 
   /** @type {Record<string, Record<string, string>>} */

[File 4/6]

@@ -1,13 +1,29 @@
 const { nanoid } = require('nanoid');
 const { Tools } = require('librechat-data-provider');
 const { logger } = require('@librechat/data-schemas');
+const { GenerationJobManager } = require('@librechat/api');
+
+/**
+ * Helper to write attachment events either to res or to job emitter.
+ * @param {import('http').ServerResponse} res - The server response object
+ * @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
+ * @param {Object} attachment - The attachment data
+ */
+function writeAttachment(res, streamId, attachment) {
+  if (streamId) {
+    GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
+  } else {
+    res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
+  }
+}
 
 /**
  * Creates a function to handle search results and stream them as attachments
  * @param {import('http').ServerResponse} res - The HTTP server response object
+ * @param {string | null} [streamId] - The stream ID for resumable mode, or null for standard mode
  * @returns {{ onSearchResults: function(SearchResult, GraphRunnableConfig): void; onGetHighlights: function(string): void}} - Function that takes search results and returns or streams an attachment
  */
-function createOnSearchResults(res) {
+function createOnSearchResults(res, streamId = null) {
   const context = {
     sourceMap: new Map(),
     searchResultData: undefined,
@@ -70,7 +86,7 @@ function createOnSearchResults(res) {
     if (!res.headersSent) {
       return attachment;
     }
-    res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
+    writeAttachment(res, streamId, attachment);
   }
 
   /**
@@ -92,7 +108,7 @@ function createOnSearchResults(res) {
     }
 
     const attachment = buildAttachment(context);
-    res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
+    writeAttachment(res, streamId, attachment);
   }
 
   return {

[File 5/6]

@@ -189,12 +189,8 @@ export default function useResumableSSE(
       }
 
       if (data.sync != null) {
-        const textPart = data.resumeState?.aggregatedContent?.find(
-          (p: { type: string }) => p.type === 'text',
-        );
         console.log('[ResumableSSE] SYNC received', {
           runSteps: data.resumeState?.runSteps?.length ?? 0,
-          contentLength: textPart?.text?.length ?? 0,
         });
 
         const runId = v4();
@@ -231,9 +227,6 @@
           );
         }
 
-        const textPart = data.resumeState.aggregatedContent?.find(
-          (p: { type: string }) => p.type === 'text',
-        );
         console.log('[ResumableSSE] SYNC update', {
           userMsgId,
           serverResponseId,
@@ -241,7 +234,6 @@
           foundMessageId: responseIdx >= 0 ? messages[responseIdx]?.messageId : null,
           messagesCount: messages.length,
           aggregatedContentLength: data.resumeState.aggregatedContent?.length,
-          textContentLength: textPart?.text?.length ?? 0,
         });
 
         if (responseIdx >= 0) {

[File 6/6]

@@ -17,6 +17,7 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
 import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
 import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
 import type { Response as ServerResponse } from 'express';
+import { GenerationJobManager } from '~/stream/GenerationJobManager';
 import { Tokenizer } from '~/utils';
 
 type RequiredMemoryMethods = Pick<
@@ -250,6 +251,7 @@ export class BasicToolEndHandler implements EventHandler {
   constructor(callback?: ToolEndCallback) {
     this.callback = callback;
   }
+
   handle(
     event: string,
     data: StreamEventData | undefined,
@@ -282,6 +284,7 @@ export async function processMemory({
   llmConfig,
   tokenLimit,
   totalTokens = 0,
+  streamId = null,
 }: {
   res: ServerResponse;
   setMemory: MemoryMethods['setMemory'];
@@ -296,6 +299,7 @@
   tokenLimit?: number;
   totalTokens?: number;
   llmConfig?: Partial<LLMConfig>;
+  streamId?: string | null;
 }): Promise<(TAttachment | null)[] | undefined> {
   try {
     const memoryTool = createMemoryTool({
@@ -363,7 +367,7 @@ ${memory ?? 'No existing memories'}`;
   }
 
   const artifactPromises: Promise<TAttachment | null>[] = [];
-  const memoryCallback = createMemoryCallback({ res, artifactPromises });
+  const memoryCallback = createMemoryCallback({ res, artifactPromises, streamId });
   const customHandlers = {
     [GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
   };
@@ -416,6 +420,7 @@ export async function createMemoryProcessor({
   memoryMethods,
   conversationId,
   config = {},
+  streamId = null,
 }: {
   res: ServerResponse;
   messageId: string;
@@ -423,6 +428,7 @@
   userId: string | ObjectId;
   memoryMethods: RequiredMemoryMethods;
   config?: MemoryConfig;
+  streamId?: string | null;
 }): Promise<[string, (messages: BaseMessage[]) => Promise<(TAttachment | null)[] | undefined>]> {
   const { validKeys, instructions, llmConfig, tokenLimit } = config;
   const finalInstructions = instructions || getDefaultInstructions(validKeys, tokenLimit);
@@ -443,6 +449,7 @@
     llmConfig,
     messageId,
     tokenLimit,
+    streamId,
     conversationId,
     memory: withKeys,
     totalTokens: totalTokens || 0,
@@ -461,10 +468,12 @@ async function handleMemoryArtifact({
   res,
   data,
   metadata,
+  streamId = null,
 }: {
   res: ServerResponse;
   data: ToolEndData;
   metadata?: ToolEndMetadata;
+  streamId?: string | null;
 }) {
   const output = data?.output as ToolMessage | undefined;
   if (!output) {
@@ -490,7 +499,11 @@
   if (!res.headersSent) {
     return attachment;
   }
+  if (streamId) {
+    GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
+  } else {
     res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
+  }
   return attachment;
 }
 
@@ -499,14 +512,17 @@ async function handleMemoryArtifact({
  * @param params - The parameters object
  * @param params.res - The server response object
  * @param params.artifactPromises - Array to collect artifact promises
+ * @param params.streamId - The stream ID for resumable mode, or null for standard mode
  * @returns The memory callback function
  */
 export function createMemoryCallback({
   res,
   artifactPromises,
+  streamId = null,
 }: {
   res: ServerResponse;
   artifactPromises: Promise<Partial<TAttachment> | null>[];
+  streamId?: string | null;
 }): ToolEndCallback {
   return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
     const output = data?.output as ToolMessage | undefined;
@@ -515,7 +531,7 @@ export function createMemoryCallback({
       return;
     }
     artifactPromises.push(
-      handleMemoryArtifact({ res, data, metadata }).catch((error) => {
+      handleMemoryArtifact({ res, data, metadata, streamId }).catch((error) => {
         logger.error('Error processing memory artifact content:', error);
         return null;
       }),
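
For context, a client consuming these frames could look like the sketch below. It uses the browser's standard EventSource API; the stream URL is hypothetical (the actual route is not shown in this commit), and the useResumableSSE hook changed above uses its own SSE plumbing rather than a raw EventSource.

declare const streamId: string; // obtained when the generation job starts

// Hypothetical route; substitute the server's actual resumable-stream endpoint.
const source = new EventSource(`/api/stream/${streamId}`);
source.addEventListener('attachment', (event) => {
  // Frames written as `event: attachment\ndata: {...}` arrive here, whether
  // emitted live via res.write or replayed through GenerationJobManager.
  const attachment = JSON.parse((event as MessageEvent).data);
  console.log('attachment received', attachment);
});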