🪨 feat: AWS Bedrock support (#3935)

* feat: Add BedrockIcon component to SVG library

* feat: EModelEndpoint.bedrock

* feat: first pass, bedrock chat. note: AgentClient is returning `agents` as conversation.endpoint

* fix: declare endpoint in initialization step

* chore: Update @librechat/agents dependency to version 1.4.5

* feat: backend content aggregation for agents/bedrock

* feat: abort agent requests

* feat: AWS Bedrock icons

* WIP: agent provider schema parsing

* chore: Update EditIcon props type

* refactor(useGenerationsByLatest): make agents and bedrock editable

* refactor: non-assistant message content, parts

* fix: Bedrock response `sender`

* fix: use endpointOption.model_parameters not endpointOption.modelOptions

* fix: types for step handler

* refactor: Update Agents.ToolCallDelta type

* refactor: Remove unnecessary assignment of parentMessageId in AskController

* refactor: remove unnecessary assignment of parentMessageId (agent request handler)

* fix(bedrock/agents): message regeneration

* refactor: dynamic form elements using react-hook-form Controllers

* fix: agent icons/labels for messages

* fix: agent actions

* fix: use of new dynamic tags causing application crash

* refactor: dynamic settings touch-ups

* refactor: update Slider component to allow custom track class name

* refactor: update DynamicSlider component styles

* refactor: use Constants value for GLOBAL_PROJECT_NAME (enum)

* feat: agent share global methods/controllers

* fix: agents query

* fix: `getResponseModel`

* fix: share prompt a11y issue

* refactor: update SharePrompt dialog theme styles

* refactor: explicit typing for SharePrompt

* feat: add agent roles/permissions

* chore: update @librechat/agents dependency to version 1.4.7 for tool_call_ids edge case

* fix(Anthropic): messages.X.content.Y.tool_use.input: Input should be a valid dictionary

* fix: handle text parts with tool_call_ids and empty text

* fix: role initialization

* refactor: don't make instructions required

* refactor: improve typing of Text part

* fix: setShowStopButton for agents route

* chore: remove params for now

* fix: add streamBuffer and streamRate to help prevent 'Overloaded' errors from Anthropic API

* refactor: remove console.log statement in ContentRender component

* chore: typing, rename Context to Delete Button

* chore(DeleteButton): logging

* refactor(Action): make accessible

* style(Action): improve a11y again

* refactor: remove use/mention of mongoose sessions

* feat: first pass, sharing agents

* feat: visual indicator for global agent, remove author when serving to non-author

* wip: params

* chore: fix typing issues

* fix(schemas): typing

* refactor: improve accessibility of ListCard component and fix console React warning

* wip: reset templates for non-legacy new convos

* Revert "wip: params"

This reverts commit f8067e91d4.

* Revert "refactor: dynamic form elements using react-hook-form Controllers"

This reverts commit 2150c4815d.

* fix(Parameters): types and parameter effect update to only update local state to parameters

* refactor: optimize useDebouncedInput hook for better performance

* feat: first pass, anthropic bedrock params

* chore: paramEndpoints check for endpointType too

* fix: maxTokens to use coerceNumber.optional()

* feat: extra chat model params

* chore: reduce code repetition

* refactor: improve preset title handling in SaveAsPresetDialog component

* refactor: improve preset handling in HeaderOptions component

* chore: improve typing, replace legacy dialog for SaveAsPresetDialog

* feat: save as preset from parameters panel

* fix: multi-search in select dropdown when using Option type

* refactor: update default showDefault value to false in Dynamic components

* feat: Bedrock presets settings

* chore: config, fix agents schema, update config version

* refactor: update AWS region variable name in bedrock options endpoint to BEDROCK_AWS_DEFAULT_REGION

* refactor: update baseEndpointSchema in config.ts to include baseURL property

* refactor: update createRun function to include req parameter and set streamRate based on provider

* feat: availableRegions via config

* refactor: remove unused demo agent controller file

* WIP: title

* Update @librechat/agents to version 1.5.0

* chore: addTitle.js to handle empty responseText

* feat: support images and titles

* feat: context token updates

* Refactor BaseClient test to use expect.objectContaining

* refactor: add model select, remove header options params, move side panel params below prompts

* chore: update models list, catch title error

* feat: model service for bedrock models (env)

* chore: Remove verbose debug log in AgentClient class following stream

* feat(bedrock): track token spend; fix: token rates, value key mapping for AWS models

* refactor: handle streamRate in `handleLLMNewToken` callback

* chore: AWS Bedrock example config in `.env.example`

* refactor: Rename bedrockMeta to bedrockGeneral in settings.ts and use for AI21 and Amazon Bedrock providers

* refactor: Update `.env.example` with AWS Bedrock model IDs URL and additional notes

* feat: titleModel support for bedrock

* refactor: Update `.env.example` with additional notes for AWS Bedrock model IDs
This commit is contained in:
Danny Avila 2024-09-09 12:06:59 -04:00 committed by GitHub
parent 8c14360263
commit d59b62174f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
134 changed files with 3684 additions and 1213 deletions

View file

@ -123,11 +123,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
};
let response = await client.sendMessage(text, messageOptions);
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise;

View file

@ -44,6 +44,14 @@ async function endpointController(req, res) {
};
}
if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) {
const { availableRegions } = req.app.locals[EModelEndpoint.bedrock];
mergedConfig[EModelEndpoint.bedrock] = {
...mergedConfig[EModelEndpoint.bedrock],
availableRegions,
};
}
const endpointsConfig = orderEndpointsConfig(mergedConfig);
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);

View file

@ -1,7 +1,10 @@
const { GraphEvents, ToolEndHandler, ChatModelStreamHandler } = require('@librechat/agents');
/** @typedef {import('@librechat/agents').Graph} Graph */
/** @typedef {import('@librechat/agents').EventHandler} EventHandler */
/** @typedef {import('@librechat/agents').ModelEndData} ModelEndData */
/** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */
/** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */
/** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */
/**
@ -18,18 +21,55 @@ const sendEvent = (res, event) => {
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};
class ModelEndHandler {
/**
* @param {Array<UsageMetadata>} collectedUsage
*/
constructor(collectedUsage) {
if (!Array.isArray(collectedUsage)) {
throw new Error('collectedUsage must be an array');
}
this.collectedUsage = collectedUsage;
}
/**
* @param {string} event
* @param {ModelEndData | undefined} data
* @param {Record<string, unknown> | undefined} metadata
* @param {Graph} graph
* @returns
*/
handle(event, data, metadata, graph) {
if (!graph || !metadata) {
console.warn(`Graph or metadata not found in ${event} event`);
return;
}
const usage = data?.output?.usage_metadata;
if (usage) {
this.collectedUsage.push(usage);
}
}
}
/**
* Get default handlers for stream events.
* @param {{ res?: ServerResponse }} options - The options object.
* @param {Object} options - The options object.
* @param {ServerResponse} options.res - The server response that stream events are written to.
* @param {ContentAggregator} options.aggregateContent - Callback invoked with each `{ event, data }` pair to aggregate content.
* @param {Array<UsageMetadata>} options.collectedUsage - The list of collected usage metadata.
* @returns {Record<string, t.EventHandler>} The default handlers.
* @throws {Error} If `res` or `aggregateContent` is missing.
*/
function getDefaultHandlers({ res }) {
if (!res) {
throw new Error('Request not found');
function getDefaultHandlers({ res, aggregateContent, collectedUsage }) {
if (!res || !aggregateContent) {
throw new Error(
`[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`,
);
}
const handlers = {
// [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(),
[GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(collectedUsage),
[GraphEvents.TOOL_END]: new ToolEndHandler(),
[GraphEvents.CHAT_MODEL_STREAM]: new ChatModelStreamHandler(),
[GraphEvents.ON_RUN_STEP]: {
@ -40,6 +80,7 @@ function getDefaultHandlers({ res }) {
*/
handle: (event, data) => {
sendEvent(res, { event, data });
aggregateContent({ event, data });
},
},
[GraphEvents.ON_RUN_STEP_DELTA]: {
@ -50,6 +91,7 @@ function getDefaultHandlers({ res }) {
*/
handle: (event, data) => {
sendEvent(res, { event, data });
aggregateContent({ event, data });
},
},
[GraphEvents.ON_RUN_STEP_COMPLETED]: {
@ -60,6 +102,7 @@ function getDefaultHandlers({ res }) {
*/
handle: (event, data) => {
sendEvent(res, { event, data });
aggregateContent({ event, data });
},
},
[GraphEvents.ON_MESSAGE_DELTA]: {
@ -70,6 +113,7 @@ function getDefaultHandlers({ res }) {
*/
handle: (event, data) => {
sendEvent(res, { event, data });
aggregateContent({ event, data });
},
},
};

View file

@ -7,9 +7,11 @@
// validateVisionModel,
// mapModelToAzureConfig,
// } = require('librechat-data-provider');
const { Callback } = require('@librechat/agents');
const { Callback, createMetadataAggregator } = require('@librechat/agents');
const {
Constants,
EModelEndpoint,
bedrockOutputParser,
providerEndpointMap,
removeNullishValues,
} = require('librechat-data-provider');
@ -23,15 +25,27 @@ const {
formatAgentMessages,
createContextHandlers,
} = require('~/app/clients/prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const BaseClient = require('~/app/clients/BaseClient');
// const { sleep } = require('~/server/utils');
const { createRun } = require('./run');
const { logger } = require('~/config');
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
// const providerSchemas = {
// [EModelEndpoint.bedrock]: true,
// };
const providerParsers = {
[EModelEndpoint.bedrock]: bedrockOutputParser,
};
class AgentClient extends BaseClient {
constructor(options = {}) {
super(options);
super(null, options);
/** @type {'discard' | 'summarize'} */
this.contextStrategy = 'discard';
@ -39,11 +53,31 @@ class AgentClient extends BaseClient {
/** @deprecated @type {true} - Is a Chat Completion Request */
this.isChatCompletion = true;
const { maxContextTokens, modelOptions = {}, ...clientOptions } = options;
/** @type {AgentRun} */
this.run;
const {
maxContextTokens,
modelOptions = {},
contentParts,
collectedUsage,
...clientOptions
} = options;
this.modelOptions = modelOptions;
this.maxContextTokens = maxContextTokens;
this.options = Object.assign({ endpoint: EModelEndpoint.agents }, clientOptions);
/** @type {MessageContentComplex[]} */
this.contentParts = contentParts;
/** @type {Array<UsageMetadata>} */
this.collectedUsage = collectedUsage;
this.options = Object.assign({ endpoint: options.endpoint }, clientOptions);
}
/**
* Returns the aggregated content parts for the current run.
* @returns {MessageContentComplex[]} */
getContentParts() {
return this.contentParts;
}
setOptions(options) {
@ -112,9 +146,27 @@ class AgentClient extends BaseClient {
}
getSaveOptions() {
const parseOptions = providerParsers[this.options.endpoint];
let runOptions =
this.options.endpoint === EModelEndpoint.agents
? {
model: undefined,
// TODO:
// would need to be override settings; otherwise, model needs to be undefined
// model: this.override.model,
// instructions: this.override.instructions,
// additional_instructions: this.override.additional_instructions,
}
: {};
if (parseOptions) {
runOptions = parseOptions(this.modelOptions);
}
return removeNullishValues(
Object.assign(
{
endpoint: this.options.endpoint,
agent_id: this.options.agent.id,
modelLabel: this.options.modelLabel,
maxContextTokens: this.options.maxContextTokens,
@ -122,15 +174,8 @@ class AgentClient extends BaseClient {
imageDetail: this.options.imageDetail,
spec: this.options.spec,
},
this.modelOptions,
{
model: undefined,
// TODO:
// would need to be override settings; otherwise, model needs to be undefined
// model: this.override.model,
// instructions: this.override.instructions,
// additional_instructions: this.override.additional_instructions,
},
// TODO: PARSE OPTIONS BY PROVIDER, MAY CONTAIN SENSITIVE DATA
runOptions,
),
);
}
@ -142,6 +187,16 @@ class AgentClient extends BaseClient {
};
}
async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments,
this.options.agent.provider,
);
message.image_urls = image_urls.length ? image_urls : undefined;
return files;
}
async buildMessages(
messages,
parentMessageId,
@ -270,25 +325,34 @@ class AgentClient extends BaseClient {
/** @type {sendCompletion} */
async sendCompletion(payload, opts = {}) {
this.modelOptions.user = this.user;
return await this.chatCompletion({
await this.chatCompletion({
payload,
onProgress: opts.onProgress,
abortController: opts.abortController,
});
return this.contentParts;
}
// async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) {
// await spendTokens(
// {
// context,
// model: this.modelOptions.model,
// conversationId: this.conversationId,
// user: this.user ?? this.options.req.user?.id,
// endpointTokenConfig: this.options.endpointTokenConfig,
// },
// { promptTokens, completionTokens },
// );
// }
/**
* @param {Object} params
* @param {string} [params.model]
* @param {string} [params.context='message']
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) {
for (const usage of collectedUsage) {
await spendTokens(
{
context,
model: model ?? this.modelOptions.model,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ promptTokens: usage.input_tokens, completionTokens: usage.output_tokens },
);
}
}
async chatCompletion({ payload, abortController = null }) {
try {
@ -398,9 +462,8 @@ class AgentClient extends BaseClient {
// });
// }
// const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
const run = await createRun({
req: this.options.req,
agent: this.options.agent,
tools: this.options.tools,
toolMap: this.options.toolMap,
@ -415,6 +478,7 @@ class AgentClient extends BaseClient {
thread_id: this.conversationId,
},
run_id: this.responseMessageId,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',
};
@ -423,8 +487,10 @@ class AgentClient extends BaseClient {
throw new Error('Failed to create run');
}
this.run = run;
const messages = formatAgentMessages(payload);
const runMessages = await run.processStream({ messages }, config, {
await run.processStream({ messages }, config, {
[Callback.TOOL_ERROR]: (graph, error, toolId) => {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
@ -433,14 +499,94 @@ class AgentClient extends BaseClient {
);
},
});
// console.dir(runMessages, { depth: null });
return runMessages;
this.recordCollectedUsage({ context: 'message' }).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
err,
);
});
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Unhandled error type',
if (!abortController.signal.aborted) {
logger.error(
'[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type',
err,
);
throw err;
}
logger.warn(
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
err,
);
throw err;
}
}
/**
*
* @param {Object} params
* @param {string} params.text
* @param {string} params.conversationId
*/
async titleConvo({ text }) {
if (!this.run) {
throw new Error('Run not initialized');
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const clientOptions = {};
const providerConfig = this.options.req.app.locals[this.options.agent.provider];
if (
providerConfig &&
providerConfig.titleModel &&
providerConfig.titleModel !== Constants.CURRENT_MODEL
) {
clientOptions.model = providerConfig.titleModel;
}
try {
const titleResult = await this.run.generateTitle({
inputText: text,
contentParts: this.contentParts,
clientOptions,
chainOptions: {
callbacks: [
{
handleLLMEnd,
},
],
},
});
const collectedUsage = collectedMetadata.map((item) => {
let input_tokens, output_tokens;
if (item.usage) {
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
} else if (item.tokenUsage) {
input_tokens = item.tokenUsage.promptTokens;
output_tokens = item.tokenUsage.completionTokens;
}
return {
input_tokens: input_tokens,
output_tokens: output_tokens,
};
});
this.recordCollectedUsage({
model: clientOptions.model,
context: 'title',
collectedUsage,
}).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
err,
);
});
return titleResult.title;
} catch (err) {
logger.error('[api/server/controllers/agents/client.js #titleConvo] Error', err);
return;
}
}

View file

@ -1,44 +0,0 @@
// Import the necessary modules
const path = require('path');
const base = path.resolve(__dirname, '..', '..', '..', '..', 'api');
console.log(base);
//api/server/controllers/agents/demo.js
require('module-alias')({ base });
const connectDb = require('~/lib/db/connectDb');
const AgentClient = require('./client');
// Define the user and message options
const user = 'user123';
const parentMessageId = 'pmid123';
const conversationId = 'cid456';
const maxContextTokens = 200000;
const req = {
user: { id: user },
};
const progressOptions = {
res: {},
};
// Define the message options
const messageOptions = {
user,
parentMessageId,
conversationId,
progressOptions,
};
async function main() {
await connectDb();
const client = new AgentClient({ req, maxContextTokens });
const text = 'Hello, this is a test message.';
try {
let response = await client.sendMessage(text, messageOptions);
console.log('Response:', response);
} catch (error) {
console.error('Error sending message:', error);
}
}
main();

View file

@ -1,4 +1,4 @@
const { Constants, getResponseSender } = require('librechat-data-provider');
const { Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
@ -9,22 +9,17 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
text,
endpointOption,
conversationId,
modelDisplayLabel,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
let sender;
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
let userMessagePromise;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
modelDisplayLabel,
});
const newConvo = !conversationId;
const user = req.user.id;
@ -39,6 +34,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (key === 'sender') {
sender = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
@ -46,6 +43,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
};
try {
/** @type {{ client: TAgentClient }} */
const { client } = await initializeClient({ req, res, endpointOption });
const getAbortData = () => ({
@ -54,8 +52,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
promptTokens,
conversationId,
userMessagePromise,
// text: getPartialText(),
messageId: responseMessageId,
content: client.getContentParts(),
parentMessageId: overrideParentMessageId ?? userMessageId,
});
@ -90,11 +88,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
};
let response = await client.sendMessage(text, messageOptions);
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise;
@ -103,7 +96,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
if (client.options.attachments) {
userMessage.files = client.options.attachments;
conversation.model = endpointOption.modelOptions.model;
delete userMessage.image_urls;
}

View file

@ -1,4 +1,4 @@
const { Run } = require('@librechat/agents');
const { Run, Providers } = require('@librechat/agents');
const { providerEndpointMap } = require('librechat-data-provider');
/**
@ -14,11 +14,12 @@ const { providerEndpointMap } = require('librechat-data-provider');
* Creates a new Run instance with custom handlers and configuration.
*
* @param {Object} options - The options for creating the Run instance.
* @param {ServerRequest} [options.req] - The server request.
* @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated.
* @param {Agent} options.agent - The agent for this run.
* @param {StructuredTool[] | undefined} [options.tools] - The tools to use in the run.
* @param {Record<string, StructuredTool[]> | undefined} [options.toolMap] - The tool map for the run.
* @param {Record<GraphEvents, EventHandler> | undefined} [options.customHandlers] - Custom event handlers.
* @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated.
* @param {ClientOptions} [options.modelOptions] - Optional model to use; if not provided, it will use the default from modelMap.
* @param {boolean} [options.streaming=true] - Whether to use streaming.
* @param {boolean} [options.streamUsage=true] - Whether to stream usage information.
@ -43,15 +44,22 @@ async function createRun({
modelOptions,
);
const graphConfig = {
runId,
llmConfig,
tools,
toolMap,
instructions: agent.instructions,
additional_instructions: agent.additional_instructions,
};
// TEMPORARY FOR TESTING
if (agent.provider === Providers.ANTHROPIC) {
graphConfig.streamBuffer = 2000;
}
return Run.create({
graphConfig: {
runId,
llmConfig,
tools,
toolMap,
instructions: agent.instructions,
additional_instructions: agent.additional_instructions,
},
graphConfig,
customHandlers,
});
}

View file

@ -1,5 +1,5 @@
const { nanoid } = require('nanoid');
const { FileContext } = require('librechat-data-provider');
const { FileContext, Constants } = require('librechat-data-provider');
const {
getAgent,
createAgent,
@ -9,6 +9,8 @@ const {
} = require('~/models/Agent');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { uploadImageBuffer } = require('~/server/services/Files/process');
const { getProjectByName } = require('~/models/Project');
const { updateAgentProjects } = require('~/models/Agent');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
@ -53,16 +55,35 @@ const createAgentHandler = async (req, res) => {
* @param {object} req - Express Request
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 200 - success response - application/json
* @param {object} req.user - Authenticated user information
* @param {string} req.user.id - User ID
* @returns {Promise<Agent>} 200 - success response - application/json
* @returns {Error} 404 - Agent not found
*/
const getAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const agent = await getAgent({ id });
const author = req.user.id;
let query = { id, author };
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, ['agentIds']);
if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) {
query = {
$or: [{ id, $in: globalProject.agentIds }, query],
};
}
const agent = await getAgent(query);
if (!agent) {
return res.status(404).json({ error: 'Agent not found' });
}
if (agent.author !== author) {
delete agent.author;
}
return res.status(200).json(agent);
} catch (error) {
logger.error('[/Agents/:id] Error retrieving agent', error);
@ -82,7 +103,17 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const updatedAgent = await updateAgent({ id, author: req.user.id }, req.body);
const { projectIds, removeProjectIds, ...updateData } = req.body;
let updatedAgent;
if (Object.keys(updateData).length > 0) {
updatedAgent = await updateAgent({ id, author: req.user.id }, updateData);
}
if (projectIds || removeProjectIds) {
updatedAgent = await updateAgentProjects(id, projectIds, removeProjectIds);
}
return res.json(updatedAgent);
} catch (error) {
logger.error('[/Agents/:id] Error updating Agent', error);
@ -119,13 +150,13 @@ const deleteAgentHandler = async (req, res) => {
* @param {object} req - Express Request
* @param {object} req.query - Request query
* @param {string} [req.query.user] - The user ID of the agent's author.
* @returns {AgentListResponse} 200 - success response - application/json
* @returns {Promise<AgentListResponse>} 200 - success response - application/json
*/
const getListAgentsHandler = async (req, res) => {
try {
const { user } = req.query;
const filter = user ? { author: user } : {};
const data = await getListAgents(filter);
const data = await getListAgents({
author: req.user.id,
});
return res.json(data);
} catch (error) {
logger.error('[/Agents] Error listing Agents', error);

View file

@ -106,6 +106,7 @@ const startServer = async () => {
app.use('/api/share', routes.share);
app.use('/api/roles', routes.roles);
app.use('/api/agents', routes.agents);
app.use('/api/bedrock', routes.bedrock);
app.use('/api/tags', routes.tags);

View file

@ -107,7 +107,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
finish_reason: 'incomplete',
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions.model,
model: endpointOption.modelOptions?.model ?? endpointOption.model_parameters?.model,
unfinished: false,
error: false,
isCreatedByUser: false,

View file

@ -5,6 +5,7 @@ const assistants = require('~/server/services/Endpoints/assistants');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { processFiles } = require('~/server/services/Files/process');
const anthropic = require('~/server/services/Endpoints/anthropic');
const bedrock = require('~/server/services/Endpoints/bedrock');
const openAI = require('~/server/services/Endpoints/openAI');
const agents = require('~/server/services/Endpoints/agents');
const custom = require('~/server/services/Endpoints/custom');
@ -17,6 +18,7 @@ const buildFunction = {
[EModelEndpoint.google]: google.buildOptions,
[EModelEndpoint.custom]: custom.buildOptions,
[EModelEndpoint.agents]: agents.buildOptions,
[EModelEndpoint.bedrock]: bedrock.buildOptions,
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
[EModelEndpoint.anthropic]: anthropic.buildOptions,
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,

View file

@ -41,7 +41,7 @@ router.post('/:agent_id', async (req, res) => {
return res.status(400).json({ message: 'No functions provided' });
}
let metadata = encryptMetadata(_metadata);
let metadata = await encryptMetadata(_metadata);
let { domain } = metadata;
domain = await domainParser(req, domain, true);

View file

@ -1,11 +1,30 @@
const multer = require('multer');
const express = require('express');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const v1 = require('~/server/controllers/agents/v1');
const actions = require('./actions');
const upload = multer();
const router = express.Router();
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
Permissions.USE,
Permissions.CREATE,
]);
const checkGlobalAgentShare = generateCheckAccess(
PermissionTypes.AGENTS,
[Permissions.USE, Permissions.CREATE],
{
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
},
);
router.use(requireJwtAuth);
router.use(checkAgentAccess);
/**
* Agent actions route.
* @route GET|POST /agents/actions
@ -27,7 +46,7 @@ router.use('/tools', (req, res) => {
* @param {AgentCreateParams} req.body - The agent creation parameters.
* @returns {Agent} 201 - Success response - application/json
*/
router.post('/', v1.createAgent);
router.post('/', checkAgentCreate, v1.createAgent);
/**
* Retrieves an agent.
@ -35,7 +54,7 @@ router.post('/', v1.createAgent);
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 200 - Success response - application/json
*/
router.get('/:id', v1.getAgent);
router.get('/:id', checkAgentAccess, v1.getAgent);
/**
* Updates an agent.
@ -44,7 +63,7 @@ router.get('/:id', v1.getAgent);
* @param {AgentUpdateParams} req.body - The agent update parameters.
* @returns {Agent} 200 - Success response - application/json
*/
router.patch('/:id', v1.updateAgent);
router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
/**
* Deletes an agent.
@ -52,7 +71,7 @@ router.patch('/:id', v1.updateAgent);
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 200 - success response - application/json
*/
router.delete('/:id', v1.deleteAgent);
router.delete('/:id', checkAgentCreate, v1.deleteAgent);
/**
* Returns a list of agents.
@ -60,9 +79,7 @@ router.delete('/:id', v1.deleteAgent);
* @param {AgentListParams} req.query - The agent list parameters for pagination and sorting.
* @returns {AgentListResponse} 200 - success response - application/json
*/
router.get('/', v1.getListAgents);
// TODO: handle private agents
router.get('/', checkAgentAccess, v1.getListAgents);
/**
* Uploads and updates an avatar for a specific agent.
@ -72,6 +89,6 @@ router.get('/', v1.getListAgents);
* @param {string} [req.body.metadata] - Optional metadata for the agent's avatar.
* @returns {Object} 200 - success response - application/json
*/
router.post('/avatar/:agent_id', upload.single('file'), v1.uploadAgentAvatar);
router.post('/avatar/:agent_id', checkAgentAccess, upload.single('file'), v1.uploadAgentAvatar);
module.exports = router;

View file

@ -0,0 +1,36 @@
const express = require('express');
const router = express.Router();
const {
setHeaders,
handleAbort,
// validateModel,
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/bedrock');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/bedrock/title');
router.post('/abort', handleAbort());
/**
 * @route POST /
 * @desc Chat with a Bedrock model. Builds the endpoint option, sets the response headers,
 * then delegates to AgentController with the Bedrock `initializeClient` and `addTitle`.
 * @access NOTE(review): previously documented as "Public"; presumably this router is only
 * reachable through a parent router that applies `requireJwtAuth` — confirm where it is mounted.
 * @param {express.Request} req - The request object, containing the request data.
 * @param {express.Response} res - The response object, used to send back a response.
 * @returns {void}
 */
router.post(
'/',
// Model/endpoint validation middleware is currently disabled for this route.
// validateModel,
// validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AgentController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View file

@ -0,0 +1,19 @@
const express = require('express');
const router = express.Router();
const {
uaParser,
checkBan,
requireJwtAuth,
// concurrentLimiter,
// messageIpLimiter,
// messageUserLimiter,
} = require('~/server/middleware');
const chat = require('./chat');
// Router-wide middleware, executed in this order for every route mounted below:
// requireJwtAuth -> checkBan -> uaParser.
// NOTE(review): the rate limiters in the import list are commented out and not applied.
router.use(requireJwtAuth, checkBan, uaParser);

// Chat (and abort) routes live in ./chat and are served under /chat.
router.use('/chat', chat);
module.exports = router;

View file

@ -1,5 +1,5 @@
const express = require('express');
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider');
const { getLdapConfig } = require('~/server/services/Config/ldap');
const { getProjectByName } = require('~/models/Project');
const { isEnabled } = require('~/server/utils');
@ -32,7 +32,7 @@ router.get('/', async function (req, res) {
return today.getMonth() === 1 && today.getDate() === 11;
};
const instanceProject = await getProjectByName('instance', '_id');
const instanceProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
const ldap = getLdapConfig();

View file

@ -8,6 +8,7 @@ const presets = require('./presets');
const prompts = require('./prompts');
const balance = require('./balance');
const plugins = require('./plugins');
const bedrock = require('./bedrock');
const search = require('./search');
const models = require('./models');
const convos = require('./convos');
@ -36,6 +37,7 @@ module.exports = {
files,
share,
agents,
bedrock,
convos,
search,
prompts,

View file

@ -24,6 +24,7 @@ const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
Permissions.USE,
Permissions.CREATE,
]);
const checkGlobalPromptShare = generateCheckAccess(
PermissionTypes.PROMPTS,
[Permissions.USE, Permissions.CREATE],

View file

@ -165,7 +165,7 @@ async function createActionTool({ action, requestBuilder, zodSchema, name, descr
* Encrypts sensitive metadata values for an action.
*
* @param {ActionMetadata} metadata - The action metadata to encrypt.
* @returns {ActionMetadata} The updated action metadata with encrypted values.
* @returns {Promise<ActionMetadata>} The updated action metadata with encrypted values.
*/
async function encryptMetadata(metadata) {
const encryptedMetadata = { ...metadata };

View file

@ -94,18 +94,19 @@ const AppService = async (app) => {
);
}
if (endpoints?.[EModelEndpoint.openAI]) {
endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI];
}
if (endpoints?.[EModelEndpoint.google]) {
endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google];
}
if (endpoints?.[EModelEndpoint.anthropic]) {
endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic];
}
if (endpoints?.[EModelEndpoint.gptPlugins]) {
endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins];
}
const endpointKeys = [
EModelEndpoint.openAI,
EModelEndpoint.google,
EModelEndpoint.bedrock,
EModelEndpoint.anthropic,
EModelEndpoint.gptPlugins,
];
endpointKeys.forEach((key) => {
if (endpoints?.[key]) {
endpointLocals[key] = endpoints[key];
}
});
app.locals = {
...defaultLocals,

View file

@ -45,6 +45,7 @@ module.exports = {
AZURE_ASSISTANTS_BASE_URL,
EModelEndpoint.azureAssistants,
),
[EModelEndpoint.bedrock]: generateConfig(process.env.BEDROCK_AWS_SECRET_ACCESS_KEY),
/* key will be part of separate config */
[EModelEndpoint.agents]: generateConfig(process.env.I_AM_A_TEAPOT),
},

View file

@ -9,22 +9,13 @@ const { config } = require('./EndpointService');
*/
async function loadDefaultEndpointsConfig(req) {
const { google, gptPlugins } = await loadAsyncEndpoints(req);
const {
openAI,
agents,
assistants,
azureAssistants,
bingAI,
anthropic,
azureOpenAI,
chatGPTBrowser,
} = config;
const { assistants, azureAssistants, bingAI, azureOpenAI, chatGPTBrowser } = config;
const enabledEndpoints = getEnabledEndpoints();
const endpointConfig = {
[EModelEndpoint.openAI]: openAI,
[EModelEndpoint.agents]: agents,
[EModelEndpoint.openAI]: config[EModelEndpoint.openAI],
[EModelEndpoint.agents]: config[EModelEndpoint.agents],
[EModelEndpoint.assistants]: assistants,
[EModelEndpoint.azureAssistants]: azureAssistants,
[EModelEndpoint.azureOpenAI]: azureOpenAI,
@ -32,7 +23,8 @@ async function loadDefaultEndpointsConfig(req) {
[EModelEndpoint.bingAI]: bingAI,
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
[EModelEndpoint.gptPlugins]: gptPlugins,
[EModelEndpoint.anthropic]: anthropic,
[EModelEndpoint.anthropic]: config[EModelEndpoint.anthropic],
[EModelEndpoint.bedrock]: config[EModelEndpoint.bedrock],
};
const orderedAndFilteredEndpoints = enabledEndpoints.reduce((config, key, index) => {

View file

@ -3,6 +3,7 @@ const { useAzurePlugins } = require('~/server/services/Config/EndpointService').
const {
getOpenAIModels,
getGoogleModels,
getBedrockModels,
getAnthropicModels,
getChatGPTBrowserModels,
} = require('~/server/services/ModelService');
@ -38,6 +39,7 @@ async function loadDefaultModels(req) {
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
[EModelEndpoint.assistants]: assistants,
[EModelEndpoint.azureAssistants]: azureAssistants,
[EModelEndpoint.bedrock]: getBedrockModels(),
};
}

View file

@ -2,7 +2,7 @@ const { getAgent } = require('~/models/Agent');
const { logger } = require('~/config');
const buildOptions = (req, endpoint, parsedBody) => {
const { agent_id, instructions, spec, ...rest } = parsedBody;
const { agent_id, instructions, spec, ...model_parameters } = parsedBody;
const agentPromise = getAgent({
id: agent_id,
@ -19,9 +19,7 @@ const buildOptions = (req, endpoint, parsedBody) => {
agent_id,
instructions,
spec,
modelOptions: {
...rest,
},
model_parameters,
};
return endpointOption;

View file

@ -11,7 +11,12 @@
const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider');
const { createContentAggregator } = require('@librechat/agents');
const {
EModelEndpoint,
providerEndpointMap,
getResponseSender,
} = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// for testing purposes
// const createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch');
@ -53,7 +58,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
}
// TODO: use endpointOption to determine options/modelOptions
const eventHandlers = getDefaultHandlers({ res });
const { contentParts, aggregateContent } = createContentAggregator();
const eventHandlers = getDefaultHandlers({ res, aggregateContent });
// const tools = [createTavilySearchTool()];
// const tools = [_getWeather];
@ -90,7 +96,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
}
// TODO: pass-in override settings that are specific to current run
endpointOption.modelOptions.model = agent.model;
endpointOption.model_parameters.model = agent.model;
const options = await getOptions({
req,
res,
@ -101,13 +107,21 @@ const initializeClient = async ({ req, res, endpointOption }) => {
});
modelOptions = Object.assign(modelOptions, options.llmConfig);
const sender = getResponseSender({
...endpointOption,
model: endpointOption.model_parameters.model,
});
const client = new AgentClient({
req,
agent,
tools,
sender,
toolMap,
contentParts,
modelOptions,
eventHandlers,
endpoint: EModelEndpoint.agents,
configOptions: options.configOptions,
maxContextTokens:
agent.max_context_tokens ??

View file

@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => {
const title = await client.titleConvo({
text,
responseText: response?.text,
responseText: response?.text ?? '',
conversationId: response.conversationId,
});
await titleCache.set(key, title, 120000);

View file

@ -0,0 +1,44 @@
const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { logger } = require('~/config');
/**
 * Builds the endpoint option object for a Bedrock request.
 *
 * Known presentation/conversation fields are lifted out of `parsedBody`; every
 * remaining field is treated as a model parameter and passed through
 * `bedrockInputParser`. If parsing throws, a warning is logged and the
 * unparsed parameters are used as-is.
 *
 * @param {string} endpoint - The endpoint identifier stored on the result.
 * @param {object} parsedBody - The parsed request body.
 * @returns {object} The endpoint option, with nullish values removed and an
 *   `artifactsPrompt` added when `artifacts` is a string.
 */
const buildOptions = (endpoint, parsedBody) => {
  const {
    modelLabel,
    promptPrefix,
    maxContextTokens,
    resendFiles = true,
    imageDetail,
    iconURL,
    greeting,
    spec,
    artifacts,
    ...rawModelParameters
  } = parsedBody;

  /** Best-effort parse: falls back to the raw parameters when the parser rejects them. */
  const parseModelParameters = () => {
    try {
      return bedrockInputParser.parse(rawModelParameters);
    } catch (error) {
      logger.warn('Failed to parse bedrock input', error);
      return rawModelParameters;
    }
  };

  const model_parameters = parseModelParameters();

  const options = removeNullishValues({
    endpoint,
    name: modelLabel,
    resendFiles,
    imageDetail,
    iconURL,
    greeting,
    spec,
    promptPrefix,
    maxContextTokens,
    model_parameters,
  });

  // Only a string `artifacts` value produces an artifacts prompt.
  if (typeof artifacts === 'string') {
    options.artifactsPrompt = generateArtifactsPrompt({ endpoint, artifacts });
  }

  return options;
};
module.exports = { buildOptions };

View file

@ -0,0 +1,7 @@
const build = require('./build');
const initialize = require('./initialize');
// Re-export everything from ./build and ./initialize as one flat object;
// on a name clash the ./initialize export wins because it is merged last.
module.exports = Object.assign({}, build, initialize);

View file

@ -0,0 +1,72 @@
const { createContentAggregator } = require('@librechat/agents');
const {
EModelEndpoint,
providerEndpointMap,
getResponseSender,
} = require('librechat-data-provider');
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
// const { loadAgentTools } = require('~/server/services/ToolService');
const getOptions = require('~/server/services/Endpoints/bedrock/options');
const AgentClient = require('~/server/controllers/agents/client');
const { getModelMaxTokens } = require('~/utils');
/**
 * Creates an `AgentClient` configured to talk to Bedrock for a single request.
 *
 * A minimal in-memory agent is synthesized from `endpointOption` (there is no
 * stored agent for plain Bedrock chats), the LLM/config options are resolved
 * through `getOptions`, and the result is handed to `AgentClient`.
 *
 * @param {object} params
 * @param {express.Request} params.req - The incoming request.
 * @param {express.Response} params.res - The response, passed to the event handlers.
 * @param {object} params.endpointOption - Endpoint option built for this request;
 *   `model_parameters` is read unconditionally and must be present.
 * @returns {Promise<{ client: AgentClient }>} The initialized client.
 * @throws {Error} When `endpointOption` is not provided.
 */
const initializeClient = async ({ req, res, endpointOption }) => {
  if (!endpointOption) {
    throw new Error('Endpoint option not provided');
  }

  /** @type {Array<UsageMetadata>} */
  const collectedUsage = [];
  const { contentParts, aggregateContent } = createContentAggregator();
  const eventHandlers = getDefaultHandlers({ res, aggregateContent, collectedUsage });

  // const tools = [createTavilySearchTool()];

  /** @type {Agent} */
  const agent = {
    id: EModelEndpoint.bedrock,
    name: endpointOption.name,
    instructions: endpointOption.promptPrefix,
    provider: EModelEndpoint.bedrock,
    model: endpointOption.model_parameters.model,
    model_parameters: endpointOption.model_parameters,
  };

  // TODO: pass-in override settings that are specific to current run
  const options = await getOptions({
    req,
    res,
    endpointOption,
  });

  // Values from the resolved LLM config take precedence over the agent's model.
  const modelOptions = Object.assign({ model: agent.model }, options.llmConfig);

  // Precedence: explicit per-request setting, then an agent-level value (not
  // set on the agent built above, kept for parity with stored agents), then
  // the model's known maximum. Previously the per-request
  // `endpointOption.maxContextTokens` was never consulted.
  const maxContextTokens =
    endpointOption.maxContextTokens ??
    agent.max_context_tokens ??
    getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]);

  const sender = getResponseSender({
    ...endpointOption,
    model: endpointOption.model_parameters.model,
  });

  const client = new AgentClient({
    req,
    agent,
    sender,
    // tools,
    // toolMap,
    modelOptions,
    contentParts,
    eventHandlers,
    collectedUsage,
    maxContextTokens,
    endpoint: EModelEndpoint.bedrock,
    configOptions: options.configOptions,
    attachments: endpointOption.attachments,
  });

  return { client };
};
module.exports = { initializeClient };

View file

@ -0,0 +1,90 @@
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
EModelEndpoint,
Constants,
AuthType,
removeNullishValues,
} = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { sleep } = require('~/server/utils');
/**
 * Resolves the LLM configuration and client config options for a Bedrock request.
 *
 * Credentials come from the user's stored key when `BEDROCK_AWS_SECRET_ACCESS_KEY`
 * equals `AuthType.USER_PROVIDED`; otherwise they are read from the
 * `BEDROCK_AWS_ACCESS_KEY_ID` / `BEDROCK_AWS_SECRET_ACCESS_KEY` environment variables.
 *
 * @param {object} params
 * @param {express.Request} params.req - The request; `req.body.key`, `req.user.id`
 *   and `req.app.locals` are read.
 * @param {object} params.endpointOption - Endpoint option; `model` and
 *   `model_parameters` are read.
 * @returns {Promise<{ llmConfig: object, configOptions: object }>}
 * @throws {Error} When no credentials could be resolved.
 */
const getOptions = async ({ req, endpointOption }) => {
  const env = process.env;
  const userProvided = env.BEDROCK_AWS_SECRET_ACCESS_KEY === AuthType.USER_PROVIDED;
  const keyExpiry = req.body.key;

  let credentials;
  if (userProvided) {
    credentials = await getUserKey({ userId: req.user.id, name: EModelEndpoint.bedrock });
  } else {
    credentials = {
      accessKeyId: env.BEDROCK_AWS_ACCESS_KEY_ID,
      secretAccessKey: env.BEDROCK_AWS_SECRET_ACCESS_KEY,
    };
  }

  if (!credentials) {
    throw new Error('Bedrock credentials not provided. Please provide them again.');
  }

  // The expiry check only applies to user-supplied keys that carry an expiry value.
  if (keyExpiry && userProvided) {
    checkUserKeyExpiry(keyExpiry, EModelEndpoint.bedrock);
  }

  /** @type {undefined | TBaseEndpoint} */
  const bedrockConfig = req.app.locals[EModelEndpoint.bedrock];
  /** @type {undefined | TBaseEndpoint} */
  const allConfig = req.app.locals.all;

  // A truthy `all` stream rate wins over the Bedrock-specific one, which wins over the default.
  /** @type {number} */
  const streamRate =
    allConfig?.streamRate || bedrockConfig?.streamRate || Constants.DEFAULT_STREAM_RATE;

  /** Delays after each streamed token when a stream rate is configured. */
  const throttleToken = async () => {
    if (streamRate) {
      await sleep(streamRate);
    }
  };

  // `model_parameters` is spread last so its values override the defaults above it.
  /** @type {import('@librechat/agents').BedrockConverseClientOptions} */
  const requestOptions = {
    credentials,
    model: endpointOption.model,
    region: env.BEDROCK_AWS_DEFAULT_REGION,
    streaming: true,
    streamUsage: true,
    callbacks: [{ handleLLMNewToken: throttleToken }],
    ...endpointOption.model_parameters,
  };

  const configOptions = {};
  if (env.PROXY) {
    configOptions.httpAgent = new HttpsProxyAgent(env.PROXY);
  }
  if (env.BEDROCK_REVERSE_PROXY) {
    configOptions.endpointHost = env.BEDROCK_REVERSE_PROXY;
  }

  return {
    llmConfig: removeNullishValues(requestOptions),
    configOptions,
  };
};
module.exports = getOptions;

View file

@ -0,0 +1,40 @@
const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models');
/**
 * Generates a conversation title via the client, caches it, and saves it on
 * the conversation.
 *
 * Nothing happens when titling is disabled through the `TITLE_CONVO`
 * environment variable, when the client opts out with `titleConvo === false`,
 * or when the request has already been aborted.
 *
 * @param {express.Request} req - The request; `req.user.id` is used for the cache key.
 * @param {object} params
 * @param {string} params.text - The user's message text.
 * @param {object} params.response - The response message; `conversationId` and `text` are read.
 * @param {object} params.client - The client used to generate the title.
 * @returns {Promise<void>}
 */
const addTitle = async (req, { text, response, client }) => {
  const { TITLE_CONVO = true } = process.env ?? {};

  const titlingDisabled = !isEnabled(TITLE_CONVO) || client.options.titleConvo === false;
  // If the request was aborted, don't generate the title.
  if (titlingDisabled || client.abortController.signal.aborted) {
    return;
  }

  const titleCache = getLogStores(CacheKeys.GEN_TITLE);
  const userId = req.user.id;
  const conversationId = response.conversationId;
  const cacheKey = `${userId}-${conversationId}`;

  const title = await client.titleConvo({
    text,
    responseText: response?.text ?? '',
    conversationId,
  });

  // Cache entry is written with a 120000 TTL argument before the conversation is saved.
  await titleCache.set(cacheKey, title, 120000);

  await saveConvo(
    req,
    { conversationId, title },
    { context: 'api/server/services/Endpoints/bedrock/title.js' },
  );
};
module.exports = addTitle;

View file

@ -49,7 +49,7 @@ const addTitle = async (req, { text, response, client }) => {
const title = await titleClient.titleConvo({
text,
responseText: response?.text,
responseText: response?.text ?? '',
conversationId: response.conversationId,
});
await titleCache.set(key, title, 120000);

View file

@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => {
const title = await client.titleConvo({
text,
responseText: response?.text,
responseText: response?.text ?? '',
conversationId: response.conversationId,
});
await titleCache.set(key, title, 120000);

View file

@ -23,7 +23,13 @@ async function fetchImageToBase64(url) {
}
}
const base64Only = new Set([EModelEndpoint.google, EModelEndpoint.anthropic, 'Ollama', 'ollama']);
const base64Only = new Set([
EModelEndpoint.google,
EModelEndpoint.anthropic,
'Ollama',
'ollama',
EModelEndpoint.bedrock,
]);
/**
* Encodes and formats the given files.

View file

@ -5,6 +5,21 @@ const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require
const { OllamaClient } = require('~/app/clients/OllamaClient');
const getLogStores = require('~/cache/getLogStores');
/**
 * Turns a comma-separated string into a list of non-empty, trimmed entries.
 * Anything that is not a non-empty string yields an empty list.
 * @param {string} input - The comma-separated input string.
 * @returns {string[]} The trimmed entries, with blank entries dropped.
 */
const splitAndTrim = (input) => {
  if (typeof input !== 'string' || input.length === 0) {
    return [];
  }

  const values = [];
  for (const piece of input.split(',')) {
    const trimmed = piece.trim();
    if (trimmed) {
      values.push(trimmed);
    }
  }
  return values;
};
const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config;
/**
@ -194,7 +209,7 @@ const getOpenAIModels = async (opts) => {
}
if (process.env[key]) {
models = String(process.env[key]).split(',');
models = splitAndTrim(process.env[key]);
return models;
}
@ -208,7 +223,7 @@ const getOpenAIModels = async (opts) => {
const getChatGPTBrowserModels = () => {
let models = ['text-davinci-002-render-sha', 'gpt-4'];
if (process.env.CHATGPT_MODELS) {
models = String(process.env.CHATGPT_MODELS).split(',');
models = splitAndTrim(process.env.CHATGPT_MODELS);
}
return models;
@ -217,7 +232,7 @@ const getChatGPTBrowserModels = () => {
const getAnthropicModels = () => {
let models = defaultModels[EModelEndpoint.anthropic];
if (process.env.ANTHROPIC_MODELS) {
models = String(process.env.ANTHROPIC_MODELS).split(',');
models = splitAndTrim(process.env.ANTHROPIC_MODELS);
}
return models;
@ -226,7 +241,16 @@ const getAnthropicModels = () => {
const getGoogleModels = () => {
let models = defaultModels[EModelEndpoint.google];
if (process.env.GOOGLE_MODELS) {
models = String(process.env.GOOGLE_MODELS).split(',');
models = splitAndTrim(process.env.GOOGLE_MODELS);
}
return models;
};
const getBedrockModels = () => {
let models = defaultModels[EModelEndpoint.bedrock];
if (process.env.BEDROCK_AWS_MODELS) {
models = splitAndTrim(process.env.BEDROCK_AWS_MODELS);
}
return models;
@ -234,7 +258,9 @@ const getGoogleModels = () => {
module.exports = {
fetchModels,
splitAndTrim,
getOpenAIModels,
getBedrockModels,
getChatGPTBrowserModels,
getAnthropicModels,
getGoogleModels,

View file

@ -1,7 +1,16 @@
const axios = require('axios');
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { logger } = require('~/config');
const { fetchModels, getOpenAIModels } = require('./ModelService');
const {
fetchModels,
splitAndTrim,
getOpenAIModels,
getGoogleModels,
getBedrockModels,
getAnthropicModels,
} = require('./ModelService');
jest.mock('~/utils', () => {
const originalUtils = jest.requireActual('~/utils');
return {
@ -329,3 +338,71 @@ describe('fetchModels with Ollama specific logic', () => {
);
});
});
// Unit tests for splitAndTrim: comma splitting, trimming, and non-string/empty inputs.
describe('splitAndTrim', () => {
  it('should split a string by commas and trim each value', () => {
    const result = splitAndTrim(' model1, model2 , model3,model4 ');
    expect(result).toEqual(['model1', 'model2', 'model3', 'model4']);
  });

  it('should return an empty array for empty input', () => {
    const result = splitAndTrim('');
    expect(result).toEqual([]);
  });

  it('should return an empty array for null input', () => {
    const result = splitAndTrim(null);
    expect(result).toEqual([]);
  });

  it('should return an empty array for undefined input', () => {
    const result = splitAndTrim(undefined);
    expect(result).toEqual([]);
  });

  it('should filter out empty values after trimming', () => {
    // Consecutive commas, whitespace-only entries and a trailing comma are all dropped.
    const result = splitAndTrim('model1,, ,model2,');
    expect(result).toEqual(['model1', 'model2']);
  });
});
// Verifies that getAnthropicModels falls back to the defaults and honors ANTHROPIC_MODELS.
describe('getAnthropicModels', () => {
  // Snapshot taken once so each test leaves process.env as the suite found it;
  // without this, the value assigned below leaks into later tests.
  const originalValue = process.env.ANTHROPIC_MODELS;

  afterEach(() => {
    if (originalValue === undefined) {
      delete process.env.ANTHROPIC_MODELS;
    } else {
      process.env.ANTHROPIC_MODELS = originalValue;
    }
  });

  it('returns default models when ANTHROPIC_MODELS is not set', () => {
    delete process.env.ANTHROPIC_MODELS;
    const models = getAnthropicModels();
    expect(models).toEqual(defaultModels[EModelEndpoint.anthropic]);
  });

  it('returns models from ANTHROPIC_MODELS when set', () => {
    // Surrounding whitespace is expected to be trimmed from each entry.
    process.env.ANTHROPIC_MODELS = 'claude-1, claude-2 ';
    const models = getAnthropicModels();
    expect(models).toEqual(['claude-1', 'claude-2']);
  });
});
// Verifies that getGoogleModels falls back to the defaults and honors GOOGLE_MODELS.
describe('getGoogleModels', () => {
  // Snapshot taken once so each test leaves process.env as the suite found it;
  // without this, the value assigned below leaks into later tests.
  const originalValue = process.env.GOOGLE_MODELS;

  afterEach(() => {
    if (originalValue === undefined) {
      delete process.env.GOOGLE_MODELS;
    } else {
      process.env.GOOGLE_MODELS = originalValue;
    }
  });

  it('returns default models when GOOGLE_MODELS is not set', () => {
    delete process.env.GOOGLE_MODELS;
    const models = getGoogleModels();
    expect(models).toEqual(defaultModels[EModelEndpoint.google]);
  });

  it('returns models from GOOGLE_MODELS when set', () => {
    // Surrounding whitespace is expected to be trimmed from each entry.
    process.env.GOOGLE_MODELS = 'gemini-pro, bard ';
    const models = getGoogleModels();
    expect(models).toEqual(['gemini-pro', 'bard']);
  });
});
// Verifies that getBedrockModels falls back to the defaults and honors BEDROCK_AWS_MODELS.
describe('getBedrockModels', () => {
  // Snapshot taken once so each test leaves process.env as the suite found it;
  // without this, the value assigned below leaks into later tests.
  const originalValue = process.env.BEDROCK_AWS_MODELS;

  afterEach(() => {
    if (originalValue === undefined) {
      delete process.env.BEDROCK_AWS_MODELS;
    } else {
      process.env.BEDROCK_AWS_MODELS = originalValue;
    }
  });

  it('returns default models when BEDROCK_AWS_MODELS is not set', () => {
    delete process.env.BEDROCK_AWS_MODELS;
    const models = getBedrockModels();
    expect(models).toEqual(defaultModels[EModelEndpoint.bedrock]);
  });

  it('returns models from BEDROCK_AWS_MODELS when set', () => {
    // Surrounding whitespace is expected to be trimmed from each entry.
    process.env.BEDROCK_AWS_MODELS = 'anthropic.claude-v2, ai21.j2-ultra ';
    const models = getBedrockModels();
    expect(models).toEqual(['anthropic.claude-v2', 'ai21.j2-ultra']);
  });
});