diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 7640388060..05ea5c0149 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -217,6 +217,7 @@ class BaseClient { userMessage, conversationId, responseMessageId, + sender: this.sender, }); } diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 8932ba8fd9..fd091ead20 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -38,7 +38,7 @@ const providerSchemas = { class AgentClient extends BaseClient { constructor(options = {}) { - super(options); + super(null, options); /** @type {'discard' | 'summarize'} */ this.contextStrategy = 'discard'; diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index c7927e6a60..bddb2befc8 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -1,4 +1,4 @@ -const { Constants, getResponseSender } = require('librechat-data-provider'); +const { Constants } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage } = require('~/server/utils'); const { saveMessage } = require('~/models'); @@ -9,22 +9,17 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { text, endpointOption, conversationId, - modelDisplayLabel, parentMessageId = null, overrideParentMessageId = null, } = req.body; + let sender; let userMessage; - let userMessagePromise; let promptTokens; let userMessageId; let responseMessageId; + let userMessagePromise; - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.model_parameters.model, - modelDisplayLabel, - }); const newConvo = !conversationId; const user = req.user.id; @@ -39,6 +34,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { responseMessageId = data[key]; } else if (key === 'promptTokens') { promptTokens = data[key]; + } else if (key === 'sender') { + sender = data[key]; } else if (!conversationId && key === 'conversationId') { conversationId = data[key]; } diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index b5b006848f..790be90674 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -12,7 +12,11 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); const { createContentAggregator } = require('@librechat/agents'); -const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); +const { + EModelEndpoint, + providerEndpointMap, + getResponseSender, +} = require('librechat-data-provider'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); // for testing purposes // const createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch'); @@ -103,10 +107,16 @@ const initializeClient = async ({ req, res, endpointOption }) => { }); modelOptions = Object.assign(modelOptions, options.llmConfig); + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.model_parameters.model, + }); + const client = new AgentClient({ req, agent, tools, + sender, toolMap, contentParts, modelOptions, diff --git a/api/server/services/Endpoints/bedrock/initialize.js b/api/server/services/Endpoints/bedrock/initialize.js index b7fcf40863..be392e8130 100644 --- a/api/server/services/Endpoints/bedrock/initialize.js +++ b/api/server/services/Endpoints/bedrock/initialize.js @@ -1,5 +1,9 @@ const { createContentAggregator } = require('@librechat/agents'); -const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); +const { + EModelEndpoint, + providerEndpointMap, + getResponseSender, +} = require('librechat-data-provider'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); // const { loadAgentTools } = require('~/server/services/ToolService'); const getOptions = require('~/server/services/Endpoints/bedrock/options'); @@ -40,9 +44,15 @@ const initializeClient = async ({ req, res, endpointOption }) => { agent.max_context_tokens ?? getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]); + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.model_parameters.model, + }); + const client = new AgentClient({ req, agent, + sender, // tools, // toolMap, modelOptions, diff --git a/packages/data-provider/src/parsers.ts b/packages/data-provider/src/parsers.ts index 253835ec3d..3d21b7217e 100644 --- a/packages/data-provider/src/parsers.ts +++ b/packages/data-provider/src/parsers.ts @@ -232,8 +232,9 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string => if ( [ EModelEndpoint.openAI, - EModelEndpoint.azureOpenAI, + EModelEndpoint.bedrock, EModelEndpoint.gptPlugins, + EModelEndpoint.azureOpenAI, EModelEndpoint.chatGPTBrowser, ].includes(endpoint) ) {