fix: Bedrock response sender

This commit is contained in:
Danny Avila 2024-09-02 21:40:38 -04:00
parent d66a35887d
commit edcc66685a
No known key found for this signature in database
GPG key ID: 2DD9CC89B9B50364
6 changed files with 31 additions and 12 deletions

View file

@ -217,6 +217,7 @@ class BaseClient {
userMessage,
conversationId,
responseMessageId,
sender: this.sender,
});
}

View file

@ -38,7 +38,7 @@ const providerSchemas = {
class AgentClient extends BaseClient {
constructor(options = {}) {
super(options);
super(null, options);
/** @type {'discard' | 'summarize'} */
this.contextStrategy = 'discard';

View file

@ -1,4 +1,4 @@
const { Constants, getResponseSender } = require('librechat-data-provider');
const { Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
@ -9,22 +9,17 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
text,
endpointOption,
conversationId,
modelDisplayLabel,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
let sender;
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
let userMessagePromise;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
modelDisplayLabel,
});
const newConvo = !conversationId;
const user = req.user.id;
@ -39,6 +34,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (key === 'sender') {
sender = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}

View file

@ -12,7 +12,11 @@
const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { createContentAggregator } = require('@librechat/agents');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider');
const {
EModelEndpoint,
providerEndpointMap,
getResponseSender,
} = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// for testing purposes
// const createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch');
@ -103,10 +107,16 @@ const initializeClient = async ({ req, res, endpointOption }) => {
});
modelOptions = Object.assign(modelOptions, options.llmConfig);
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
});
const client = new AgentClient({
req,
agent,
tools,
sender,
toolMap,
contentParts,
modelOptions,

View file

@ -1,5 +1,9 @@
const { createContentAggregator } = require('@librechat/agents');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider');
const {
EModelEndpoint,
providerEndpointMap,
getResponseSender,
} = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// const { loadAgentTools } = require('~/server/services/ToolService');
const getOptions = require('~/server/services/Endpoints/bedrock/options');
@ -40,9 +44,15 @@ const initializeClient = async ({ req, res, endpointOption }) => {
agent.max_context_tokens ??
getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]);
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
});
const client = new AgentClient({
req,
agent,
sender,
// tools,
// toolMap,
modelOptions,

View file

@ -232,8 +232,9 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string =>
if (
[
EModelEndpoint.openAI,
EModelEndpoint.azureOpenAI,
EModelEndpoint.bedrock,
EModelEndpoint.gptPlugins,
EModelEndpoint.azureOpenAI,
EModelEndpoint.chatGPTBrowser,
].includes(endpoint)
) {