fix: Bedrock response sender

This commit is contained in:
Danny Avila 2024-09-02 21:40:38 -04:00
parent d66a35887d
commit edcc66685a
No known key found for this signature in database
GPG key ID: 2DD9CC89B9B50364
6 changed files with 31 additions and 12 deletions

View file

@ -217,6 +217,7 @@ class BaseClient {
userMessage, userMessage,
conversationId, conversationId,
responseMessageId, responseMessageId,
sender: this.sender,
}); });
} }

View file

@ -38,7 +38,7 @@ const providerSchemas = {
class AgentClient extends BaseClient { class AgentClient extends BaseClient {
constructor(options = {}) { constructor(options = {}) {
super(options); super(null, options);
/** @type {'discard' | 'summarize'} */ /** @type {'discard' | 'summarize'} */
this.contextStrategy = 'discard'; this.contextStrategy = 'discard';

View file

@ -1,4 +1,4 @@
const { Constants, getResponseSender } = require('librechat-data-provider'); const { Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage } = require('~/server/utils'); const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models'); const { saveMessage } = require('~/models');
@ -9,22 +9,17 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
text, text,
endpointOption, endpointOption,
conversationId, conversationId,
modelDisplayLabel,
parentMessageId = null, parentMessageId = null,
overrideParentMessageId = null, overrideParentMessageId = null,
} = req.body; } = req.body;
let sender;
let userMessage; let userMessage;
let userMessagePromise;
let promptTokens; let promptTokens;
let userMessageId; let userMessageId;
let responseMessageId; let responseMessageId;
let userMessagePromise;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
modelDisplayLabel,
});
const newConvo = !conversationId; const newConvo = !conversationId;
const user = req.user.id; const user = req.user.id;
@ -39,6 +34,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
responseMessageId = data[key]; responseMessageId = data[key];
} else if (key === 'promptTokens') { } else if (key === 'promptTokens') {
promptTokens = data[key]; promptTokens = data[key];
} else if (key === 'sender') {
sender = data[key];
} else if (!conversationId && key === 'conversationId') { } else if (!conversationId && key === 'conversationId') {
conversationId = data[key]; conversationId = data[key];
} }

View file

@ -12,7 +12,11 @@
const { z } = require('zod'); const { z } = require('zod');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { createContentAggregator } = require('@librechat/agents'); const { createContentAggregator } = require('@librechat/agents');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); const {
EModelEndpoint,
providerEndpointMap,
getResponseSender,
} = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// for testing purposes // for testing purposes
// const createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch'); // const createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch');
@ -103,10 +107,16 @@ const initializeClient = async ({ req, res, endpointOption }) => {
}); });
modelOptions = Object.assign(modelOptions, options.llmConfig); modelOptions = Object.assign(modelOptions, options.llmConfig);
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
});
const client = new AgentClient({ const client = new AgentClient({
req, req,
agent, agent,
tools, tools,
sender,
toolMap, toolMap,
contentParts, contentParts,
modelOptions, modelOptions,

View file

@ -1,5 +1,9 @@
const { createContentAggregator } = require('@librechat/agents'); const { createContentAggregator } = require('@librechat/agents');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); const {
EModelEndpoint,
providerEndpointMap,
getResponseSender,
} = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// const { loadAgentTools } = require('~/server/services/ToolService'); // const { loadAgentTools } = require('~/server/services/ToolService');
const getOptions = require('~/server/services/Endpoints/bedrock/options'); const getOptions = require('~/server/services/Endpoints/bedrock/options');
@ -40,9 +44,15 @@ const initializeClient = async ({ req, res, endpointOption }) => {
agent.max_context_tokens ?? agent.max_context_tokens ??
getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]); getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]);
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
});
const client = new AgentClient({ const client = new AgentClient({
req, req,
agent, agent,
sender,
// tools, // tools,
// toolMap, // toolMap,
modelOptions, modelOptions,

View file

@ -232,8 +232,9 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string =>
if ( if (
[ [
EModelEndpoint.openAI, EModelEndpoint.openAI,
EModelEndpoint.azureOpenAI, EModelEndpoint.bedrock,
EModelEndpoint.gptPlugins, EModelEndpoint.gptPlugins,
EModelEndpoint.azureOpenAI,
EModelEndpoint.chatGPTBrowser, EModelEndpoint.chatGPTBrowser,
].includes(endpoint) ].includes(endpoint)
) { ) {