feat: backend content aggregation for agents/bedrock

This commit is contained in:
Danny Avila 2024-09-02 13:02:47 -04:00
parent 16ba3ed243
commit 757f544a9b
No known key found for this signature in database
GPG key ID: 2DD9CC89B9B50364
4 changed files with 32 additions and 17 deletions

View file

@ -2,6 +2,7 @@ const { GraphEvents, ToolEndHandler, ChatModelStreamHandler } = require('@librec
/** @typedef {import('@librechat/agents').EventHandler} EventHandler */ /** @typedef {import('@librechat/agents').EventHandler} EventHandler */
/** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */ /** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */
/** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */
/** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */ /** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */
/** /**
@ -20,13 +21,17 @@ const sendEvent = (res, event) => {
/** /**
* Get default handlers for stream events. * Get default handlers for stream events.
* @param {{ res?: ServerResponse }} options - The options object. * @param {Object} options - The options object.
* @param {ServerResponse} options.res - The options object.
* @param {ContentAggregator} options.aggregateContent - The options object.
* @returns {Record<string, t.EventHandler>} The default handlers. * @returns {Record<string, t.EventHandler>} The default handlers.
* @throws {Error} If the request is not found. * @throws {Error} If the request is not found.
*/ */
function getDefaultHandlers({ res }) { function getDefaultHandlers({ res, aggregateContent }) {
if (!res) { if (!res || !aggregateContent) {
throw new Error('Request not found'); throw new Error(
`[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`,
);
} }
const handlers = { const handlers = {
// [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(), // [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(),
@ -40,6 +45,7 @@ function getDefaultHandlers({ res }) {
*/ */
handle: (event, data) => { handle: (event, data) => {
sendEvent(res, { event, data }); sendEvent(res, { event, data });
aggregateContent({ event, data });
}, },
}, },
[GraphEvents.ON_RUN_STEP_DELTA]: { [GraphEvents.ON_RUN_STEP_DELTA]: {
@ -50,6 +56,7 @@ function getDefaultHandlers({ res }) {
*/ */
handle: (event, data) => { handle: (event, data) => {
sendEvent(res, { event, data }); sendEvent(res, { event, data });
aggregateContent({ event, data });
}, },
}, },
[GraphEvents.ON_RUN_STEP_COMPLETED]: { [GraphEvents.ON_RUN_STEP_COMPLETED]: {
@ -60,6 +67,7 @@ function getDefaultHandlers({ res }) {
*/ */
handle: (event, data) => { handle: (event, data) => {
sendEvent(res, { event, data }); sendEvent(res, { event, data });
aggregateContent({ event, data });
}, },
}, },
[GraphEvents.ON_MESSAGE_DELTA]: { [GraphEvents.ON_MESSAGE_DELTA]: {
@ -70,6 +78,7 @@ function getDefaultHandlers({ res }) {
*/ */
handle: (event, data) => { handle: (event, data) => {
sendEvent(res, { event, data }); sendEvent(res, { event, data });
aggregateContent({ event, data });
}, },
}, },
}; };

View file

@ -8,11 +8,7 @@
// mapModelToAzureConfig, // mapModelToAzureConfig,
// } = require('librechat-data-provider'); // } = require('librechat-data-provider');
const { Callback } = require('@librechat/agents'); const { Callback } = require('@librechat/agents');
const { const { providerEndpointMap, removeNullishValues } = require('librechat-data-provider');
EModelEndpoint,
providerEndpointMap,
removeNullishValues,
} = require('librechat-data-provider');
const { const {
extractBaseURL, extractBaseURL,
// constructAzureURL, // constructAzureURL,
@ -29,6 +25,8 @@ const BaseClient = require('~/app/clients/BaseClient');
const { createRun } = require('./run'); const { createRun } = require('./run');
const { logger } = require('~/config'); const { logger } = require('~/config');
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
class AgentClient extends BaseClient { class AgentClient extends BaseClient {
constructor(options = {}) { constructor(options = {}) {
super(options); super(options);
@ -43,7 +41,9 @@ class AgentClient extends BaseClient {
this.modelOptions = modelOptions; this.modelOptions = modelOptions;
this.maxContextTokens = maxContextTokens; this.maxContextTokens = maxContextTokens;
this.options = Object.assign({ endpoint: EModelEndpoint.agents }, clientOptions); /** @type {MessageContentComplex[]} */
this.contentParts = options.contentParts;
this.options = Object.assign({ endpoint: options.endpoint }, clientOptions);
} }
setOptions(options) { setOptions(options) {
@ -270,11 +270,12 @@ class AgentClient extends BaseClient {
/** @type {sendCompletion} */ /** @type {sendCompletion} */
async sendCompletion(payload, opts = {}) { async sendCompletion(payload, opts = {}) {
this.modelOptions.user = this.user; this.modelOptions.user = this.user;
return await this.chatCompletion({ await this.chatCompletion({
payload, payload,
onProgress: opts.onProgress, onProgress: opts.onProgress,
abortController: opts.abortController, abortController: opts.abortController,
}); });
return this.contentParts;
} }
// async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) { // async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) {
@ -415,6 +416,7 @@ class AgentClient extends BaseClient {
thread_id: this.conversationId, thread_id: this.conversationId,
}, },
run_id: this.responseMessageId, run_id: this.responseMessageId,
signal: abortController.signal,
streamMode: 'values', streamMode: 'values',
version: 'v2', version: 'v2',
}; };
@ -424,7 +426,7 @@ class AgentClient extends BaseClient {
} }
const messages = formatAgentMessages(payload); const messages = formatAgentMessages(payload);
const runMessages = await run.processStream({ messages }, config, { await run.processStream({ messages }, config, {
[Callback.TOOL_ERROR]: (graph, error, toolId) => { [Callback.TOOL_ERROR]: (graph, error, toolId) => {
logger.error( logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error', '[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
@ -433,8 +435,7 @@ class AgentClient extends BaseClient {
); );
}, },
}); });
// console.dir(runMessages, { depth: null }); logger.info(this.contentParts, { depth: null });
return runMessages;
} catch (err) { } catch (err) {
logger.error( logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Unhandled error type', '[api/server/controllers/agents/client.js #chatCompletion] Unhandled error type',

View file

@ -11,6 +11,7 @@
const { z } = require('zod'); const { z } = require('zod');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { createContentAggregator } = require('@librechat/agents');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// for testing purposes // for testing purposes
@ -53,7 +54,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
} }
// TODO: use endpointOption to determine options/modelOptions // TODO: use endpointOption to determine options/modelOptions
const eventHandlers = getDefaultHandlers({ res }); const { contentParts, aggregateContent } = createContentAggregator();
const eventHandlers = getDefaultHandlers({ res, aggregateContent });
// const tools = [createTavilySearchTool()]; // const tools = [createTavilySearchTool()];
// const tools = [_getWeather]; // const tools = [_getWeather];
@ -106,6 +108,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
agent, agent,
tools, tools,
toolMap, toolMap,
contentParts,
modelOptions, modelOptions,
eventHandlers, eventHandlers,
endpoint: EModelEndpoint.agents, endpoint: EModelEndpoint.agents,

View file

@ -1,3 +1,4 @@
const { createContentAggregator } = require('@librechat/agents');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// const { loadAgentTools } = require('~/server/services/ToolService'); // const { loadAgentTools } = require('~/server/services/ToolService');
@ -10,8 +11,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
throw new Error('Endpoint option not provided'); throw new Error('Endpoint option not provided');
} }
// TODO: use endpointOption to determine options/modelOptions const { contentParts, aggregateContent } = createContentAggregator();
const eventHandlers = getDefaultHandlers({ res }); const eventHandlers = getDefaultHandlers({ res, aggregateContent });
// const tools = [createTavilySearchTool()]; // const tools = [createTavilySearchTool()];
@ -45,6 +46,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
// tools, // tools,
// toolMap, // toolMap,
modelOptions, modelOptions,
contentParts,
eventHandlers, eventHandlers,
maxContextTokens, maxContextTokens,
endpoint: EModelEndpoint.bedrock, endpoint: EModelEndpoint.bedrock,