mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
Merge branch 'dev' into feat/prompt-enhancement
This commit is contained in:
commit
e1af9d21f0
309 changed files with 12487 additions and 6311 deletions
15
.env.example
15
.env.example
|
|
@ -58,7 +58,7 @@ DEBUG_CONSOLE=false
|
||||||
# Endpoints #
|
# Endpoints #
|
||||||
#===================================================#
|
#===================================================#
|
||||||
|
|
||||||
# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
|
# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
|
||||||
|
|
||||||
PROXY=
|
PROXY=
|
||||||
|
|
||||||
|
|
@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
|
||||||
# GOOGLE_AUTH_HEADER=true
|
# GOOGLE_AUTH_HEADER=true
|
||||||
|
|
||||||
# Gemini API (AI Studio)
|
# Gemini API (AI Studio)
|
||||||
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
|
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
|
||||||
|
|
||||||
# Vertex AI
|
# Vertex AI
|
||||||
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
|
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
|
||||||
|
|
||||||
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
||||||
|
|
||||||
|
|
@ -349,6 +349,11 @@ REGISTRATION_VIOLATION_SCORE=1
|
||||||
CONCURRENT_VIOLATION_SCORE=1
|
CONCURRENT_VIOLATION_SCORE=1
|
||||||
MESSAGE_VIOLATION_SCORE=1
|
MESSAGE_VIOLATION_SCORE=1
|
||||||
NON_BROWSER_VIOLATION_SCORE=20
|
NON_BROWSER_VIOLATION_SCORE=20
|
||||||
|
TTS_VIOLATION_SCORE=0
|
||||||
|
STT_VIOLATION_SCORE=0
|
||||||
|
FORK_VIOLATION_SCORE=0
|
||||||
|
IMPORT_VIOLATION_SCORE=0
|
||||||
|
FILE_UPLOAD_VIOLATION_SCORE=0
|
||||||
|
|
||||||
LOGIN_MAX=7
|
LOGIN_MAX=7
|
||||||
LOGIN_WINDOW=5
|
LOGIN_WINDOW=5
|
||||||
|
|
@ -453,8 +458,8 @@ OPENID_REUSE_TOKENS=
|
||||||
OPENID_JWKS_URL_CACHE_ENABLED=
|
OPENID_JWKS_URL_CACHE_ENABLED=
|
||||||
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
|
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
|
||||||
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
|
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
|
||||||
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
|
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
|
||||||
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
|
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
|
||||||
# Set to true to use the OpenID Connect end session endpoint for logout
|
# Set to true to use the OpenID Connect end session endpoint for logout
|
||||||
OPENID_USE_END_SESSION_ENDPOINT=
|
OPENID_USE_END_SESSION_ENDPOINT=
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# v0.7.8
|
# v0.7.9-rc1
|
||||||
|
|
||||||
# Base node image
|
# Base node image
|
||||||
FROM node:20-alpine AS node
|
FROM node:20-alpine AS node
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
# Dockerfile.multi
|
# Dockerfile.multi
|
||||||
# v0.7.8
|
# v0.7.9-rc1
|
||||||
|
|
||||||
# Base for all builds
|
# Base for all builds
|
||||||
FROM node:20-alpine AS base-min
|
FROM node:20-alpine AS base-min
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@
|
||||||
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
|
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
|
||||||
|
|
||||||
- 🤖 **AI Model Selection**:
|
- 🤖 **AI Model Selection**:
|
||||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
|
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
|
||||||
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
|
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
|
||||||
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
|
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
|
||||||
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
|
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
|
||||||
|
|
@ -66,10 +66,9 @@
|
||||||
- 🔦 **Agents & Tools Integration**:
|
- 🔦 **Agents & Tools Integration**:
|
||||||
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
|
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
|
||||||
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
|
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
|
||||||
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
|
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
|
||||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
|
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
|
||||||
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
|
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
|
||||||
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
|
|
||||||
|
|
||||||
- 🔍 **Web Search**:
|
- 🔍 **Web Search**:
|
||||||
- Search the internet and retrieve relevant information to enhance your AI context
|
- Search the internet and retrieve relevant information to enhance your AI context
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ const {
|
||||||
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
|
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
|
||||||
const { checkBalance } = require('~/models/balanceMethods');
|
const { checkBalance } = require('~/models/balanceMethods');
|
||||||
const { truncateToolCallOutputs } = require('./prompts');
|
const { truncateToolCallOutputs } = require('./prompts');
|
||||||
const { addSpaceIfNeeded } = require('~/server/utils');
|
|
||||||
const { getFiles } = require('~/models/File');
|
const { getFiles } = require('~/models/File');
|
||||||
const TextStream = require('./TextStream');
|
const TextStream = require('./TextStream');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
@ -572,7 +571,7 @@ class BaseClient {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const { generation = '' } = opts;
|
const { editedContent } = opts;
|
||||||
|
|
||||||
// It's not necessary to push to currentMessages
|
// It's not necessary to push to currentMessages
|
||||||
// depending on subclass implementation of handling messages
|
// depending on subclass implementation of handling messages
|
||||||
|
|
@ -587,11 +586,21 @@ class BaseClient {
|
||||||
isCreatedByUser: false,
|
isCreatedByUser: false,
|
||||||
model: this.modelOptions?.model ?? this.model,
|
model: this.modelOptions?.model ?? this.model,
|
||||||
sender: this.sender,
|
sender: this.sender,
|
||||||
text: generation,
|
|
||||||
};
|
};
|
||||||
this.currentMessages.push(userMessage, latestMessage);
|
this.currentMessages.push(userMessage, latestMessage);
|
||||||
} else {
|
} else if (editedContent != null) {
|
||||||
latestMessage.text = generation;
|
// Handle editedContent for content parts
|
||||||
|
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
|
||||||
|
const { index, text, type } = editedContent;
|
||||||
|
if (index >= 0 && index < latestMessage.content.length) {
|
||||||
|
const contentPart = latestMessage.content[index];
|
||||||
|
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
|
||||||
|
contentPart[ContentTypes.THINK] = text;
|
||||||
|
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
|
||||||
|
contentPart[ContentTypes.TEXT] = text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
this.continued = true;
|
this.continued = true;
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -672,16 +681,32 @@ class BaseClient {
|
||||||
};
|
};
|
||||||
|
|
||||||
if (typeof completion === 'string') {
|
if (typeof completion === 'string') {
|
||||||
responseMessage.text = addSpaceIfNeeded(generation) + completion;
|
responseMessage.text = completion;
|
||||||
} else if (
|
} else if (
|
||||||
Array.isArray(completion) &&
|
Array.isArray(completion) &&
|
||||||
(this.clientName === EModelEndpoint.agents ||
|
(this.clientName === EModelEndpoint.agents ||
|
||||||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
|
isParamEndpoint(this.options.endpoint, this.options.endpointType))
|
||||||
) {
|
) {
|
||||||
responseMessage.text = '';
|
responseMessage.text = '';
|
||||||
responseMessage.content = completion;
|
|
||||||
|
if (!opts.editedContent || this.currentMessages.length === 0) {
|
||||||
|
responseMessage.content = completion;
|
||||||
|
} else {
|
||||||
|
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
|
||||||
|
if (!latestMessage?.content) {
|
||||||
|
responseMessage.content = completion;
|
||||||
|
} else {
|
||||||
|
const existingContent = [...latestMessage.content];
|
||||||
|
const { type: editedType } = opts.editedContent;
|
||||||
|
responseMessage.content = this.mergeEditedContent(
|
||||||
|
existingContent,
|
||||||
|
completion,
|
||||||
|
editedType,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (Array.isArray(completion)) {
|
} else if (Array.isArray(completion)) {
|
||||||
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
|
responseMessage.text = completion.join('');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|
@ -792,7 +817,8 @@ class BaseClient {
|
||||||
|
|
||||||
userMessage.tokenCount = userMessageTokenCount;
|
userMessage.tokenCount = userMessageTokenCount;
|
||||||
/*
|
/*
|
||||||
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
|
Note: `AgentController` saves the user message if not saved here
|
||||||
|
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
|
||||||
*/
|
*/
|
||||||
if (typeof opts?.getReqData === 'function') {
|
if (typeof opts?.getReqData === 'function') {
|
||||||
opts.getReqData({
|
opts.getReqData({
|
||||||
|
|
@ -801,7 +827,8 @@ class BaseClient {
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
Note: we update the user message to be sure it gets the calculated token count;
|
Note: we update the user message to be sure it gets the calculated token count;
|
||||||
though `AskController` saves the user message, EditController does not
|
though `AgentController` saves the user message if not saved here
|
||||||
|
(noted by `savedMessageIds`), EditController does not
|
||||||
*/
|
*/
|
||||||
await userMessagePromise;
|
await userMessagePromise;
|
||||||
await this.updateMessageInDatabase({
|
await this.updateMessageInDatabase({
|
||||||
|
|
@ -1093,6 +1120,50 @@ class BaseClient {
|
||||||
return numTokens;
|
return numTokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Merges completion content with existing content when editing TEXT or THINK types
|
||||||
|
* @param {Array} existingContent - The existing content array
|
||||||
|
* @param {Array} newCompletion - The new completion content
|
||||||
|
* @param {string} editedType - The type of content being edited
|
||||||
|
* @returns {Array} The merged content array
|
||||||
|
*/
|
||||||
|
mergeEditedContent(existingContent, newCompletion, editedType) {
|
||||||
|
if (!newCompletion.length) {
|
||||||
|
return existingContent.concat(newCompletion);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
|
||||||
|
return existingContent.concat(newCompletion);
|
||||||
|
}
|
||||||
|
|
||||||
|
const lastIndex = existingContent.length - 1;
|
||||||
|
const lastExisting = existingContent[lastIndex];
|
||||||
|
const firstNew = newCompletion[0];
|
||||||
|
|
||||||
|
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
|
||||||
|
return existingContent.concat(newCompletion);
|
||||||
|
}
|
||||||
|
|
||||||
|
const mergedContent = [...existingContent];
|
||||||
|
if (editedType === ContentTypes.TEXT) {
|
||||||
|
mergedContent[lastIndex] = {
|
||||||
|
...mergedContent[lastIndex],
|
||||||
|
[ContentTypes.TEXT]:
|
||||||
|
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
mergedContent[lastIndex] = {
|
||||||
|
...mergedContent[lastIndex],
|
||||||
|
[ContentTypes.THINK]:
|
||||||
|
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
|
||||||
|
(firstNew[ContentTypes.THINK] || ''),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add remaining completion items
|
||||||
|
return mergedContent.concat(newCompletion.slice(1));
|
||||||
|
}
|
||||||
|
|
||||||
async sendPayload(payload, opts = {}) {
|
async sendPayload(payload, opts = {}) {
|
||||||
if (opts && typeof opts === 'object') {
|
if (opts && typeof opts === 'object') {
|
||||||
this.setOptions(opts);
|
this.setOptions(opts);
|
||||||
|
|
|
||||||
|
|
@ -1,804 +0,0 @@
|
||||||
const { Keyv } = require('keyv');
|
|
||||||
const crypto = require('crypto');
|
|
||||||
const { CohereClient } = require('cohere-ai');
|
|
||||||
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
|
|
||||||
const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api');
|
|
||||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
|
||||||
const {
|
|
||||||
ImageDetail,
|
|
||||||
EModelEndpoint,
|
|
||||||
resolveHeaders,
|
|
||||||
CohereConstants,
|
|
||||||
mapModelToAzureConfig,
|
|
||||||
} = require('librechat-data-provider');
|
|
||||||
const { createContextHandlers } = require('./prompts');
|
|
||||||
const { createCoherePayload } = require('./llm');
|
|
||||||
const { extractBaseURL } = require('~/utils');
|
|
||||||
const BaseClient = require('./BaseClient');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const CHATGPT_MODEL = 'gpt-3.5-turbo';
|
|
||||||
const tokenizersCache = {};
|
|
||||||
|
|
||||||
class ChatGPTClient extends BaseClient {
|
|
||||||
constructor(apiKey, options = {}, cacheOptions = {}) {
|
|
||||||
super(apiKey, options, cacheOptions);
|
|
||||||
|
|
||||||
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
|
|
||||||
this.conversationsCache = new Keyv(cacheOptions);
|
|
||||||
this.setOptions(options);
|
|
||||||
}
|
|
||||||
|
|
||||||
setOptions(options) {
|
|
||||||
if (this.options && !this.options.replaceOptions) {
|
|
||||||
// nested options aren't spread properly, so we need to do this manually
|
|
||||||
this.options.modelOptions = {
|
|
||||||
...this.options.modelOptions,
|
|
||||||
...options.modelOptions,
|
|
||||||
};
|
|
||||||
delete options.modelOptions;
|
|
||||||
// now we can merge options
|
|
||||||
this.options = {
|
|
||||||
...this.options,
|
|
||||||
...options,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
this.options = options;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.openaiApiKey) {
|
|
||||||
this.apiKey = this.options.openaiApiKey;
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelOptions = this.options.modelOptions || {};
|
|
||||||
this.modelOptions = {
|
|
||||||
...modelOptions,
|
|
||||||
// set some good defaults (check for undefined in some cases because they may be 0)
|
|
||||||
model: modelOptions.model || CHATGPT_MODEL,
|
|
||||||
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
|
|
||||||
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
|
|
||||||
presence_penalty:
|
|
||||||
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
|
|
||||||
stop: modelOptions.stop,
|
|
||||||
};
|
|
||||||
|
|
||||||
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
|
|
||||||
const { isChatGptModel } = this;
|
|
||||||
this.isUnofficialChatGptModel =
|
|
||||||
this.modelOptions.model.startsWith('text-chat') ||
|
|
||||||
this.modelOptions.model.startsWith('text-davinci-002-render');
|
|
||||||
const { isUnofficialChatGptModel } = this;
|
|
||||||
|
|
||||||
// Davinci models have a max context length of 4097 tokens.
|
|
||||||
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
|
|
||||||
// I decided to reserve 1024 tokens for the response.
|
|
||||||
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
|
|
||||||
// Earlier messages will be dropped until the prompt is within the limit.
|
|
||||||
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
|
|
||||||
this.maxPromptTokens =
|
|
||||||
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
|
|
||||||
|
|
||||||
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
|
|
||||||
throw new Error(
|
|
||||||
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
|
|
||||||
this.maxPromptTokens + this.maxResponseTokens
|
|
||||||
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.userLabel = this.options.userLabel || 'User';
|
|
||||||
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
|
|
||||||
|
|
||||||
if (isChatGptModel) {
|
|
||||||
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
|
|
||||||
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
|
|
||||||
// without tripping the stop sequences, so I'm using "||>" instead.
|
|
||||||
this.startToken = '||>';
|
|
||||||
this.endToken = '';
|
|
||||||
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
|
|
||||||
} else if (isUnofficialChatGptModel) {
|
|
||||||
this.startToken = '<|im_start|>';
|
|
||||||
this.endToken = '<|im_end|>';
|
|
||||||
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
|
|
||||||
'<|im_start|>': 100264,
|
|
||||||
'<|im_end|>': 100265,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
|
|
||||||
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
|
|
||||||
// as a single token. So we're using this instead.
|
|
||||||
this.startToken = '||>';
|
|
||||||
this.endToken = '';
|
|
||||||
try {
|
|
||||||
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
|
|
||||||
} catch {
|
|
||||||
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!this.modelOptions.stop) {
|
|
||||||
const stopTokens = [this.startToken];
|
|
||||||
if (this.endToken && this.endToken !== this.startToken) {
|
|
||||||
stopTokens.push(this.endToken);
|
|
||||||
}
|
|
||||||
stopTokens.push(`\n${this.userLabel}:`);
|
|
||||||
stopTokens.push('<|diff_marker|>');
|
|
||||||
// I chose not to do one for `chatGptLabel` because I've never seen it happen
|
|
||||||
this.modelOptions.stop = stopTokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.reverseProxyUrl) {
|
|
||||||
this.completionsUrl = this.options.reverseProxyUrl;
|
|
||||||
} else if (isChatGptModel) {
|
|
||||||
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
|
|
||||||
} else {
|
|
||||||
this.completionsUrl = 'https://api.openai.com/v1/completions';
|
|
||||||
}
|
|
||||||
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
|
||||||
if (tokenizersCache[encoding]) {
|
|
||||||
return tokenizersCache[encoding];
|
|
||||||
}
|
|
||||||
let tokenizer;
|
|
||||||
if (isModelName) {
|
|
||||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
|
||||||
} else {
|
|
||||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
|
||||||
}
|
|
||||||
tokenizersCache[encoding] = tokenizer;
|
|
||||||
return tokenizer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {getCompletion} */
|
|
||||||
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
|
|
||||||
if (!abortController) {
|
|
||||||
abortController = new AbortController();
|
|
||||||
}
|
|
||||||
|
|
||||||
let modelOptions = { ...this.modelOptions };
|
|
||||||
if (typeof onProgress === 'function') {
|
|
||||||
modelOptions.stream = true;
|
|
||||||
}
|
|
||||||
if (this.isChatGptModel) {
|
|
||||||
modelOptions.messages = input;
|
|
||||||
} else {
|
|
||||||
modelOptions.prompt = input;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.useOpenRouter && modelOptions.prompt) {
|
|
||||||
delete modelOptions.stop;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { debug } = this.options;
|
|
||||||
let baseURL = this.completionsUrl;
|
|
||||||
if (debug) {
|
|
||||||
console.debug();
|
|
||||||
console.debug(baseURL);
|
|
||||||
console.debug(modelOptions);
|
|
||||||
console.debug();
|
|
||||||
}
|
|
||||||
|
|
||||||
const opts = {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
if (this.isVisionModel) {
|
|
||||||
modelOptions.max_tokens = 4000;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {TAzureConfig | undefined} */
|
|
||||||
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
|
|
||||||
|
|
||||||
const isAzure = this.azure || this.options.azure;
|
|
||||||
if (
|
|
||||||
(isAzure && this.isVisionModel && azureConfig) ||
|
|
||||||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
|
|
||||||
) {
|
|
||||||
const { modelGroupMap, groupMap } = azureConfig;
|
|
||||||
const {
|
|
||||||
azureOptions,
|
|
||||||
baseURL,
|
|
||||||
headers = {},
|
|
||||||
serverless,
|
|
||||||
} = mapModelToAzureConfig({
|
|
||||||
modelName: modelOptions.model,
|
|
||||||
modelGroupMap,
|
|
||||||
groupMap,
|
|
||||||
});
|
|
||||||
opts.headers = resolveHeaders(headers);
|
|
||||||
this.langchainProxy = extractBaseURL(baseURL);
|
|
||||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
|
||||||
|
|
||||||
const groupName = modelGroupMap[modelOptions.model].group;
|
|
||||||
this.options.addParams = azureConfig.groupMap[groupName].addParams;
|
|
||||||
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
|
|
||||||
// Note: `forcePrompt` not re-assigned as only chat models are vision models
|
|
||||||
|
|
||||||
this.azure = !serverless && azureOptions;
|
|
||||||
this.azureEndpoint =
|
|
||||||
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
|
|
||||||
if (serverless === true) {
|
|
||||||
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
|
|
||||||
? { 'api-version': azureOptions.azureOpenAIApiVersion }
|
|
||||||
: undefined;
|
|
||||||
this.options.headers['api-key'] = this.apiKey;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.defaultQuery) {
|
|
||||||
opts.defaultQuery = this.options.defaultQuery;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.headers) {
|
|
||||||
opts.headers = { ...opts.headers, ...this.options.headers };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isAzure) {
|
|
||||||
// Azure does not accept `model` in the body, so we need to remove it.
|
|
||||||
delete modelOptions.model;
|
|
||||||
|
|
||||||
baseURL = this.langchainProxy
|
|
||||||
? constructAzureURL({
|
|
||||||
baseURL: this.langchainProxy,
|
|
||||||
azureOptions: this.azure,
|
|
||||||
})
|
|
||||||
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
|
|
||||||
|
|
||||||
if (this.options.forcePrompt) {
|
|
||||||
baseURL += '/completions';
|
|
||||||
} else {
|
|
||||||
baseURL += '/chat/completions';
|
|
||||||
}
|
|
||||||
|
|
||||||
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
|
|
||||||
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
|
|
||||||
} else if (this.apiKey) {
|
|
||||||
opts.headers.Authorization = `Bearer ${this.apiKey}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (process.env.OPENAI_ORGANIZATION) {
|
|
||||||
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.useOpenRouter) {
|
|
||||||
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
|
|
||||||
opts.headers['X-Title'] = 'LibreChat';
|
|
||||||
}
|
|
||||||
|
|
||||||
/* hacky fixes for Mistral AI API:
|
|
||||||
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
|
|
||||||
- If there is only one message and it's a system message, change the role to user
|
|
||||||
*/
|
|
||||||
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
|
|
||||||
const { messages } = modelOptions;
|
|
||||||
|
|
||||||
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
|
|
||||||
|
|
||||||
if (systemMessageIndex > 0) {
|
|
||||||
const [systemMessage] = messages.splice(systemMessageIndex, 1);
|
|
||||||
messages.unshift(systemMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
modelOptions.messages = messages;
|
|
||||||
|
|
||||||
if (messages.length === 1 && messages[0].role === 'system') {
|
|
||||||
modelOptions.messages[0].role = 'user';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.addParams && typeof this.options.addParams === 'object') {
|
|
||||||
modelOptions = {
|
|
||||||
...modelOptions,
|
|
||||||
...this.options.addParams,
|
|
||||||
};
|
|
||||||
logger.debug('[ChatGPTClient] chatCompletion: added params', {
|
|
||||||
addParams: this.options.addParams,
|
|
||||||
modelOptions,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
|
|
||||||
this.options.dropParams.forEach((param) => {
|
|
||||||
delete modelOptions[param];
|
|
||||||
});
|
|
||||||
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
|
|
||||||
dropParams: this.options.dropParams,
|
|
||||||
modelOptions,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (baseURL.startsWith(CohereConstants.API_URL)) {
|
|
||||||
const payload = createCoherePayload({ modelOptions });
|
|
||||||
return await this.cohereChatCompletion({ payload, onTokenProgress });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
|
|
||||||
baseURL = baseURL.split('v1')[0] + 'v1/completions';
|
|
||||||
} else if (
|
|
||||||
baseURL.includes('v1') &&
|
|
||||||
!baseURL.includes('/chat/completions') &&
|
|
||||||
this.isChatCompletion
|
|
||||||
) {
|
|
||||||
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
|
|
||||||
}
|
|
||||||
|
|
||||||
const BASE_URL = new URL(baseURL);
|
|
||||||
if (opts.defaultQuery) {
|
|
||||||
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
|
|
||||||
BASE_URL.searchParams.append(key, value);
|
|
||||||
});
|
|
||||||
delete opts.defaultQuery;
|
|
||||||
}
|
|
||||||
|
|
||||||
const completionsURL = BASE_URL.toString();
|
|
||||||
opts.body = JSON.stringify(modelOptions);
|
|
||||||
|
|
||||||
if (modelOptions.stream) {
|
|
||||||
return new Promise(async (resolve, reject) => {
|
|
||||||
try {
|
|
||||||
let done = false;
|
|
||||||
await fetchEventSource(completionsURL, {
|
|
||||||
...opts,
|
|
||||||
signal: abortController.signal,
|
|
||||||
async onopen(response) {
|
|
||||||
if (response.status === 200) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (debug) {
|
|
||||||
console.debug(response);
|
|
||||||
}
|
|
||||||
let error;
|
|
||||||
try {
|
|
||||||
const body = await response.text();
|
|
||||||
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
|
|
||||||
error.status = response.status;
|
|
||||||
error.json = JSON.parse(body);
|
|
||||||
} catch {
|
|
||||||
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
|
|
||||||
}
|
|
||||||
throw error;
|
|
||||||
},
|
|
||||||
onclose() {
|
|
||||||
if (debug) {
|
|
||||||
console.debug('Server closed the connection unexpectedly, returning...');
|
|
||||||
}
|
|
||||||
// workaround for private API not sending [DONE] event
|
|
||||||
if (!done) {
|
|
||||||
onProgress('[DONE]');
|
|
||||||
resolve();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onerror(err) {
|
|
||||||
if (debug) {
|
|
||||||
console.debug(err);
|
|
||||||
}
|
|
||||||
// rethrow to stop the operation
|
|
||||||
throw err;
|
|
||||||
},
|
|
||||||
onmessage(message) {
|
|
||||||
if (debug) {
|
|
||||||
console.debug(message);
|
|
||||||
}
|
|
||||||
if (!message.data || message.event === 'ping') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (message.data === '[DONE]') {
|
|
||||||
onProgress('[DONE]');
|
|
||||||
resolve();
|
|
||||||
done = true;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
onProgress(JSON.parse(message.data));
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} catch (err) {
|
|
||||||
reject(err);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
const response = await fetch(completionsURL, {
|
|
||||||
...opts,
|
|
||||||
signal: abortController.signal,
|
|
||||||
});
|
|
||||||
if (response.status !== 200) {
|
|
||||||
const body = await response.text();
|
|
||||||
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
|
|
||||||
error.status = response.status;
|
|
||||||
try {
|
|
||||||
error.json = JSON.parse(body);
|
|
||||||
} catch {
|
|
||||||
error.body = body;
|
|
||||||
}
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
return response.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {cohereChatCompletion} */
|
|
||||||
async cohereChatCompletion({ payload, onTokenProgress }) {
|
|
||||||
const cohere = new CohereClient({
|
|
||||||
token: this.apiKey,
|
|
||||||
environment: this.completionsUrl,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!payload.stream) {
|
|
||||||
const chatResponse = await cohere.chat(payload);
|
|
||||||
return chatResponse.text;
|
|
||||||
}
|
|
||||||
|
|
||||||
const chatStream = await cohere.chatStream(payload);
|
|
||||||
let reply = '';
|
|
||||||
for await (const message of chatStream) {
|
|
||||||
if (!message) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (message.eventType === 'text-generation' && message.text) {
|
|
||||||
onTokenProgress(message.text);
|
|
||||||
reply += message.text;
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
Cohere API Chinese Unicode character replacement hotfix.
|
|
||||||
Should be un-commented when the following issue is resolved:
|
|
||||||
https://github.com/cohere-ai/cohere-typescript/issues/151
|
|
||||||
|
|
||||||
else if (message.eventType === 'stream-end' && message.response) {
|
|
||||||
reply = message.response.text;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
return reply;
|
|
||||||
}
|
|
||||||
|
|
||||||
async generateTitle(userMessage, botMessage) {
|
|
||||||
const instructionsPayload = {
|
|
||||||
role: 'system',
|
|
||||||
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
|
|
||||||
|
|
||||||
||>Message:
|
|
||||||
${userMessage.message}
|
|
||||||
||>Response:
|
|
||||||
${botMessage.message}
|
|
||||||
|
|
||||||
||>Title:`,
|
|
||||||
};
|
|
||||||
|
|
||||||
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
|
|
||||||
titleGenClientOptions.modelOptions = {
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
temperature: 0,
|
|
||||||
presence_penalty: 0,
|
|
||||||
frequency_penalty: 0,
|
|
||||||
};
|
|
||||||
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
|
|
||||||
const result = await titleGenClient.getCompletion([instructionsPayload], null);
|
|
||||||
// remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim
|
|
||||||
return result.choices[0].message.content
|
|
||||||
.replace(/[^a-zA-Z0-9' ]/g, '')
|
|
||||||
.replace(/\s+/g, ' ')
|
|
||||||
.trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
async sendMessage(message, opts = {}) {
|
|
||||||
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
|
|
||||||
this.setOptions(opts.clientOptions);
|
|
||||||
}
|
|
||||||
|
|
||||||
const conversationId = opts.conversationId || crypto.randomUUID();
|
|
||||||
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
|
|
||||||
|
|
||||||
let conversation =
|
|
||||||
typeof opts.conversation === 'object'
|
|
||||||
? opts.conversation
|
|
||||||
: await this.conversationsCache.get(conversationId);
|
|
||||||
|
|
||||||
let isNewConversation = false;
|
|
||||||
if (!conversation) {
|
|
||||||
conversation = {
|
|
||||||
messages: [],
|
|
||||||
createdAt: Date.now(),
|
|
||||||
};
|
|
||||||
isNewConversation = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
|
|
||||||
|
|
||||||
const userMessage = {
|
|
||||||
id: crypto.randomUUID(),
|
|
||||||
parentMessageId,
|
|
||||||
role: 'User',
|
|
||||||
message,
|
|
||||||
};
|
|
||||||
conversation.messages.push(userMessage);
|
|
||||||
|
|
||||||
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
|
|
||||||
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
|
|
||||||
const { prompt: payload, context } = await this.buildPrompt(
|
|
||||||
conversation.messages,
|
|
||||||
userMessage.id,
|
|
||||||
{
|
|
||||||
isChatGptModel: this.isChatGptModel,
|
|
||||||
promptPrefix: opts.promptPrefix,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
if (this.options.keepNecessaryMessagesOnly) {
|
|
||||||
conversation.messages = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
let reply = '';
|
|
||||||
let result = null;
|
|
||||||
if (typeof opts.onProgress === 'function') {
|
|
||||||
await this.getCompletion(
|
|
||||||
payload,
|
|
||||||
(progressMessage) => {
|
|
||||||
if (progressMessage === '[DONE]') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const token = this.isChatGptModel
|
|
||||||
? progressMessage.choices[0].delta.content
|
|
||||||
: progressMessage.choices[0].text;
|
|
||||||
// first event's delta content is always undefined
|
|
||||||
if (!token) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (this.options.debug) {
|
|
||||||
console.debug(token);
|
|
||||||
}
|
|
||||||
if (token === this.endToken) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
opts.onProgress(token);
|
|
||||||
reply += token;
|
|
||||||
},
|
|
||||||
opts.abortController || new AbortController(),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
result = await this.getCompletion(
|
|
||||||
payload,
|
|
||||||
null,
|
|
||||||
opts.abortController || new AbortController(),
|
|
||||||
);
|
|
||||||
if (this.options.debug) {
|
|
||||||
console.debug(JSON.stringify(result));
|
|
||||||
}
|
|
||||||
if (this.isChatGptModel) {
|
|
||||||
reply = result.choices[0].message.content;
|
|
||||||
} else {
|
|
||||||
reply = result.choices[0].text.replace(this.endToken, '');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// avoids some rendering issues when using the CLI app
|
|
||||||
if (this.options.debug) {
|
|
||||||
console.debug();
|
|
||||||
}
|
|
||||||
|
|
||||||
reply = reply.trim();
|
|
||||||
|
|
||||||
const replyMessage = {
|
|
||||||
id: crypto.randomUUID(),
|
|
||||||
parentMessageId: userMessage.id,
|
|
||||||
role: 'ChatGPT',
|
|
||||||
message: reply,
|
|
||||||
};
|
|
||||||
conversation.messages.push(replyMessage);
|
|
||||||
|
|
||||||
const returnData = {
|
|
||||||
response: replyMessage.message,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: replyMessage.parentMessageId,
|
|
||||||
messageId: replyMessage.id,
|
|
||||||
details: result || {},
|
|
||||||
};
|
|
||||||
|
|
||||||
if (shouldGenerateTitle) {
|
|
||||||
conversation.title = await this.generateTitle(userMessage, replyMessage);
|
|
||||||
returnData.title = conversation.title;
|
|
||||||
}
|
|
||||||
|
|
||||||
await this.conversationsCache.set(conversationId, conversation);
|
|
||||||
|
|
||||||
if (this.options.returnConversation) {
|
|
||||||
returnData.conversation = conversation;
|
|
||||||
}
|
|
||||||
|
|
||||||
return returnData;
|
|
||||||
}
|
|
||||||
|
|
||||||
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
|
|
||||||
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
|
|
||||||
|
|
||||||
// Handle attachments and create augmentedPrompt
|
|
||||||
if (this.options.attachments) {
|
|
||||||
const attachments = await this.options.attachments;
|
|
||||||
const lastMessage = messages[messages.length - 1];
|
|
||||||
|
|
||||||
if (this.message_file_map) {
|
|
||||||
this.message_file_map[lastMessage.messageId] = attachments;
|
|
||||||
} else {
|
|
||||||
this.message_file_map = {
|
|
||||||
[lastMessage.messageId]: attachments,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const files = await this.addImageURLs(lastMessage, attachments);
|
|
||||||
this.options.attachments = files;
|
|
||||||
|
|
||||||
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.message_file_map) {
|
|
||||||
this.contextHandlers = createContextHandlers(
|
|
||||||
this.options.req,
|
|
||||||
messages[messages.length - 1].text,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate image token cost and process embedded files
|
|
||||||
messages.forEach((message, i) => {
|
|
||||||
if (this.message_file_map && this.message_file_map[message.messageId]) {
|
|
||||||
const attachments = this.message_file_map[message.messageId];
|
|
||||||
for (const file of attachments) {
|
|
||||||
if (file.embedded) {
|
|
||||||
this.contextHandlers?.processFile(file);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
messages[i].tokenCount =
|
|
||||||
(messages[i].tokenCount || 0) +
|
|
||||||
this.calculateImageTokenCost({
|
|
||||||
width: file.width,
|
|
||||||
height: file.height,
|
|
||||||
detail: this.options.imageDetail ?? ImageDetail.auto,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if (this.contextHandlers) {
|
|
||||||
this.augmentedPrompt = await this.contextHandlers.createContext();
|
|
||||||
promptPrefix = this.augmentedPrompt + promptPrefix;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (promptPrefix) {
|
|
||||||
// If the prompt prefix doesn't end with the end token, add it.
|
|
||||||
if (!promptPrefix.endsWith(`${this.endToken}`)) {
|
|
||||||
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
|
|
||||||
}
|
|
||||||
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
|
|
||||||
}
|
|
||||||
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
|
|
||||||
|
|
||||||
const instructionsPayload = {
|
|
||||||
role: 'system',
|
|
||||||
content: promptPrefix,
|
|
||||||
};
|
|
||||||
|
|
||||||
const messagePayload = {
|
|
||||||
role: 'system',
|
|
||||||
content: promptSuffix,
|
|
||||||
};
|
|
||||||
|
|
||||||
let currentTokenCount;
|
|
||||||
if (isChatGptModel) {
|
|
||||||
currentTokenCount =
|
|
||||||
this.getTokenCountForMessage(instructionsPayload) +
|
|
||||||
this.getTokenCountForMessage(messagePayload);
|
|
||||||
} else {
|
|
||||||
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
|
|
||||||
}
|
|
||||||
let promptBody = '';
|
|
||||||
const maxTokenCount = this.maxPromptTokens;
|
|
||||||
|
|
||||||
const context = [];
|
|
||||||
|
|
||||||
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
|
|
||||||
// Do this within a recursive async function so that it doesn't block the event loop for too long.
|
|
||||||
const buildPromptBody = async () => {
|
|
||||||
if (currentTokenCount < maxTokenCount && messages.length > 0) {
|
|
||||||
const message = messages.pop();
|
|
||||||
const roleLabel =
|
|
||||||
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
|
|
||||||
? this.userLabel
|
|
||||||
: this.chatGptLabel;
|
|
||||||
const messageString = `${this.startToken}${roleLabel}:\n${
|
|
||||||
message?.text ?? message?.message
|
|
||||||
}${this.endToken}\n`;
|
|
||||||
let newPromptBody;
|
|
||||||
if (promptBody || isChatGptModel) {
|
|
||||||
newPromptBody = `${messageString}${promptBody}`;
|
|
||||||
} else {
|
|
||||||
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
|
|
||||||
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
|
|
||||||
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
|
|
||||||
// like "what's the last thing I wrote?".
|
|
||||||
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
context.unshift(message);
|
|
||||||
|
|
||||||
const tokenCountForMessage = this.getTokenCount(messageString);
|
|
||||||
const newTokenCount = currentTokenCount + tokenCountForMessage;
|
|
||||||
if (newTokenCount > maxTokenCount) {
|
|
||||||
if (promptBody) {
|
|
||||||
// This message would put us over the token limit, so don't add it.
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// This is the first message, so we can't add it. Just throw an error.
|
|
||||||
throw new Error(
|
|
||||||
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
promptBody = newPromptBody;
|
|
||||||
currentTokenCount = newTokenCount;
|
|
||||||
// wait for next tick to avoid blocking the event loop
|
|
||||||
await new Promise((resolve) => setImmediate(resolve));
|
|
||||||
return buildPromptBody();
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
await buildPromptBody();
|
|
||||||
|
|
||||||
const prompt = `${promptBody}${promptSuffix}`;
|
|
||||||
if (isChatGptModel) {
|
|
||||||
messagePayload.content = prompt;
|
|
||||||
// Add 3 tokens for Assistant Label priming after all messages have been counted.
|
|
||||||
currentTokenCount += 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
|
|
||||||
this.modelOptions.max_tokens = Math.min(
|
|
||||||
this.maxContextTokens - currentTokenCount,
|
|
||||||
this.maxResponseTokens,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isChatGptModel) {
|
|
||||||
return { prompt: [instructionsPayload, messagePayload], context };
|
|
||||||
}
|
|
||||||
return { prompt, context, promptTokens: currentTokenCount };
|
|
||||||
}
|
|
||||||
|
|
||||||
getTokenCount(text) {
|
|
||||||
return this.gptEncoder.encode(text, 'all').length;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Algorithm adapted from "6. Counting tokens for chat API calls" of
|
|
||||||
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
*
|
|
||||||
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
|
|
||||||
*
|
|
||||||
* @param {Object} message
|
|
||||||
*/
|
|
||||||
getTokenCountForMessage(message) {
|
|
||||||
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
|
|
||||||
let tokensPerMessage = 3;
|
|
||||||
let tokensPerName = 1;
|
|
||||||
|
|
||||||
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
|
|
||||||
tokensPerMessage = 4;
|
|
||||||
tokensPerName = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let numTokens = tokensPerMessage;
|
|
||||||
for (let [key, value] of Object.entries(message)) {
|
|
||||||
numTokens += this.getTokenCount(value);
|
|
||||||
if (key === 'name') {
|
|
||||||
numTokens += tokensPerName;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return numTokens;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports = ChatGPTClient;
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
const { google } = require('googleapis');
|
const { google } = require('googleapis');
|
||||||
const { Tokenizer } = require('@librechat/api');
|
|
||||||
const { concat } = require('@langchain/core/utils/stream');
|
const { concat } = require('@langchain/core/utils/stream');
|
||||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||||
|
const { Tokenizer, getSafetySettings } = require('@librechat/api');
|
||||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||||
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
|
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
|
||||||
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
||||||
|
|
@ -12,13 +12,13 @@ const {
|
||||||
endpointSettings,
|
endpointSettings,
|
||||||
parseTextParts,
|
parseTextParts,
|
||||||
EModelEndpoint,
|
EModelEndpoint,
|
||||||
|
googleSettings,
|
||||||
ContentTypes,
|
ContentTypes,
|
||||||
VisionModes,
|
VisionModes,
|
||||||
ErrorTypes,
|
ErrorTypes,
|
||||||
Constants,
|
Constants,
|
||||||
AuthKeys,
|
AuthKeys,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
|
|
||||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||||
const { spendTokens } = require('~/models/spendTokens');
|
const { spendTokens } = require('~/models/spendTokens');
|
||||||
const { getModelMaxTokens } = require('~/utils');
|
const { getModelMaxTokens } = require('~/utils');
|
||||||
|
|
@ -166,6 +166,16 @@ class GoogleClient extends BaseClient {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add thinking configuration
|
||||||
|
this.modelOptions.thinkingConfig = {
|
||||||
|
thinkingBudget:
|
||||||
|
(this.modelOptions.thinking ?? googleSettings.thinking.default)
|
||||||
|
? this.modelOptions.thinkingBudget
|
||||||
|
: 0,
|
||||||
|
};
|
||||||
|
delete this.modelOptions.thinking;
|
||||||
|
delete this.modelOptions.thinkingBudget;
|
||||||
|
|
||||||
this.sender =
|
this.sender =
|
||||||
this.options.sender ??
|
this.options.sender ??
|
||||||
getResponseSender({
|
getResponseSender({
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ const {
|
||||||
isEnabled,
|
isEnabled,
|
||||||
Tokenizer,
|
Tokenizer,
|
||||||
createFetch,
|
createFetch,
|
||||||
|
resolveHeaders,
|
||||||
constructAzureURL,
|
constructAzureURL,
|
||||||
genAzureChatCompletion,
|
genAzureChatCompletion,
|
||||||
createStreamEventHandlers,
|
createStreamEventHandlers,
|
||||||
|
|
@ -15,7 +16,6 @@ const {
|
||||||
ContentTypes,
|
ContentTypes,
|
||||||
parseTextParts,
|
parseTextParts,
|
||||||
EModelEndpoint,
|
EModelEndpoint,
|
||||||
resolveHeaders,
|
|
||||||
KnownEndpoints,
|
KnownEndpoints,
|
||||||
openAISettings,
|
openAISettings,
|
||||||
ImageDetailCost,
|
ImageDetailCost,
|
||||||
|
|
@ -37,7 +37,6 @@ const { addSpaceIfNeeded, sleep } = require('~/server/utils');
|
||||||
const { spendTokens } = require('~/models/spendTokens');
|
const { spendTokens } = require('~/models/spendTokens');
|
||||||
const { handleOpenAIErrors } = require('./tools/util');
|
const { handleOpenAIErrors } = require('./tools/util');
|
||||||
const { createLLM, RunManager } = require('./llm');
|
const { createLLM, RunManager } = require('./llm');
|
||||||
const ChatGPTClient = require('./ChatGPTClient');
|
|
||||||
const { summaryBuffer } = require('./memory');
|
const { summaryBuffer } = require('./memory');
|
||||||
const { runTitleChain } = require('./chains');
|
const { runTitleChain } = require('./chains');
|
||||||
const { tokenSplit } = require('./document');
|
const { tokenSplit } = require('./document');
|
||||||
|
|
@ -47,12 +46,6 @@ const { logger } = require('~/config');
|
||||||
class OpenAIClient extends BaseClient {
|
class OpenAIClient extends BaseClient {
|
||||||
constructor(apiKey, options = {}) {
|
constructor(apiKey, options = {}) {
|
||||||
super(apiKey, options);
|
super(apiKey, options);
|
||||||
this.ChatGPTClient = new ChatGPTClient();
|
|
||||||
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
|
|
||||||
/** @type {getCompletion} */
|
|
||||||
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
|
|
||||||
/** @type {cohereChatCompletion} */
|
|
||||||
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
|
|
||||||
this.contextStrategy = options.contextStrategy
|
this.contextStrategy = options.contextStrategy
|
||||||
? options.contextStrategy.toLowerCase()
|
? options.contextStrategy.toLowerCase()
|
||||||
: 'discard';
|
: 'discard';
|
||||||
|
|
@ -379,23 +372,12 @@ class OpenAIClient extends BaseClient {
|
||||||
return files;
|
return files;
|
||||||
}
|
}
|
||||||
|
|
||||||
async buildMessages(
|
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
|
||||||
messages,
|
|
||||||
parentMessageId,
|
|
||||||
{ isChatCompletion = false, promptPrefix = null },
|
|
||||||
opts,
|
|
||||||
) {
|
|
||||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||||
messages,
|
messages,
|
||||||
parentMessageId,
|
parentMessageId,
|
||||||
summary: this.shouldSummarize,
|
summary: this.shouldSummarize,
|
||||||
});
|
});
|
||||||
if (!isChatCompletion) {
|
|
||||||
return await this.buildPrompt(orderedMessages, {
|
|
||||||
isChatGptModel: isChatCompletion,
|
|
||||||
promptPrefix,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let payload;
|
let payload;
|
||||||
let instructions;
|
let instructions;
|
||||||
|
|
|
||||||
|
|
@ -1,542 +0,0 @@
|
||||||
const OpenAIClient = require('./OpenAIClient');
|
|
||||||
const { CallbackManager } = require('@langchain/core/callbacks/manager');
|
|
||||||
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
|
|
||||||
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
|
|
||||||
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
|
|
||||||
const { processFileURL } = require('~/server/services/Files/process');
|
|
||||||
const { EModelEndpoint } = require('librechat-data-provider');
|
|
||||||
const { checkBalance } = require('~/models/balanceMethods');
|
|
||||||
const { formatLangChainMessages } = require('./prompts');
|
|
||||||
const { extractBaseURL } = require('~/utils');
|
|
||||||
const { loadTools } = require('./tools/util');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
class PluginsClient extends OpenAIClient {
|
|
||||||
constructor(apiKey, options = {}) {
|
|
||||||
super(apiKey, options);
|
|
||||||
this.sender = options.sender ?? 'Assistant';
|
|
||||||
this.tools = [];
|
|
||||||
this.actions = [];
|
|
||||||
this.setOptions(options);
|
|
||||||
this.openAIApiKey = this.apiKey;
|
|
||||||
this.executor = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
setOptions(options) {
|
|
||||||
this.agentOptions = { ...options.agentOptions };
|
|
||||||
this.functionsAgent = this.agentOptions?.agent === 'functions';
|
|
||||||
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
|
|
||||||
|
|
||||||
super.setOptions(options);
|
|
||||||
|
|
||||||
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
|
|
||||||
|
|
||||||
if (this.options.reverseProxyUrl) {
|
|
||||||
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
getSaveOptions() {
|
|
||||||
return {
|
|
||||||
artifacts: this.options.artifacts,
|
|
||||||
chatGptLabel: this.options.chatGptLabel,
|
|
||||||
modelLabel: this.options.modelLabel,
|
|
||||||
promptPrefix: this.options.promptPrefix,
|
|
||||||
tools: this.options.tools,
|
|
||||||
...this.modelOptions,
|
|
||||||
agentOptions: this.agentOptions,
|
|
||||||
iconURL: this.options.iconURL,
|
|
||||||
greeting: this.options.greeting,
|
|
||||||
spec: this.options.spec,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
saveLatestAction(action) {
|
|
||||||
this.actions.push(action);
|
|
||||||
}
|
|
||||||
|
|
||||||
getFunctionModelName(input) {
|
|
||||||
if (/-(?!0314)\d{4}/.test(input)) {
|
|
||||||
return input;
|
|
||||||
} else if (input.includes('gpt-3.5-turbo')) {
|
|
||||||
return 'gpt-3.5-turbo';
|
|
||||||
} else if (input.includes('gpt-4')) {
|
|
||||||
return 'gpt-4';
|
|
||||||
} else {
|
|
||||||
return 'gpt-3.5-turbo';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
getBuildMessagesOptions(opts) {
|
|
||||||
return {
|
|
||||||
isChatCompletion: true,
|
|
||||||
promptPrefix: opts.promptPrefix,
|
|
||||||
abortController: opts.abortController,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
|
|
||||||
const modelOptions = {
|
|
||||||
modelName: this.agentOptions.model,
|
|
||||||
temperature: this.agentOptions.temperature,
|
|
||||||
};
|
|
||||||
|
|
||||||
const model = this.initializeLLM({
|
|
||||||
...modelOptions,
|
|
||||||
context: 'plugins',
|
|
||||||
initialMessageCount: this.currentMessages.length + 1,
|
|
||||||
});
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Map Messages to Langchain format
|
|
||||||
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
|
|
||||||
userName: this.options?.name,
|
|
||||||
});
|
|
||||||
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
|
|
||||||
|
|
||||||
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
|
|
||||||
const memory = new BufferMemory({
|
|
||||||
llm: model,
|
|
||||||
chatHistory: new ChatMessageHistory(pastMessages),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { loadedTools } = await loadTools({
|
|
||||||
user,
|
|
||||||
model,
|
|
||||||
tools: this.options.tools,
|
|
||||||
functions: this.functionsAgent,
|
|
||||||
options: {
|
|
||||||
memory,
|
|
||||||
signal: this.abortController.signal,
|
|
||||||
openAIApiKey: this.openAIApiKey,
|
|
||||||
conversationId: this.conversationId,
|
|
||||||
fileStrategy: this.options.req.app.locals.fileStrategy,
|
|
||||||
processFileURL,
|
|
||||||
message,
|
|
||||||
},
|
|
||||||
useSpecs: true,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (loadedTools.length === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.tools = loadedTools;
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
|
|
||||||
logger.debug(
|
|
||||||
'[PluginsClient] Loaded Tools',
|
|
||||||
this.tools.map((tool) => tool.name),
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleAction = (action, runId, callback = null) => {
|
|
||||||
this.saveLatestAction(action);
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
|
|
||||||
|
|
||||||
if (typeof callback === 'function') {
|
|
||||||
callback(action, runId);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// initialize agent
|
|
||||||
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
|
|
||||||
|
|
||||||
let customInstructions = (this.options.promptPrefix ?? '').trim();
|
|
||||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
|
||||||
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
this.executor = await initializer({
|
|
||||||
model,
|
|
||||||
signal,
|
|
||||||
pastMessages,
|
|
||||||
tools: this.tools,
|
|
||||||
customInstructions,
|
|
||||||
verbose: this.options.debug,
|
|
||||||
returnIntermediateSteps: true,
|
|
||||||
customName: this.options.chatGptLabel,
|
|
||||||
currentDateString: this.currentDateString,
|
|
||||||
callbackManager: CallbackManager.fromHandlers({
|
|
||||||
async handleAgentAction(action, runId) {
|
|
||||||
handleAction(action, runId, onAgentAction);
|
|
||||||
},
|
|
||||||
async handleChainEnd(action) {
|
|
||||||
if (typeof onChainEnd === 'function') {
|
|
||||||
onChainEnd(action);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient] Loaded agent.');
|
|
||||||
}
|
|
||||||
|
|
||||||
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
|
|
||||||
let errorMessage = '';
|
|
||||||
const maxAttempts = 1;
|
|
||||||
|
|
||||||
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
|
|
||||||
const errorInput = buildErrorInput({
|
|
||||||
message,
|
|
||||||
errorMessage,
|
|
||||||
actions: this.actions,
|
|
||||||
functionsAgent: this.functionsAgent,
|
|
||||||
});
|
|
||||||
const input = attempts > 1 ? errorInput : message;
|
|
||||||
|
|
||||||
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
|
|
||||||
|
|
||||||
if (errorMessage.length > 0) {
|
|
||||||
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
this.result = await this.executor.call({ input, signal }, [
|
|
||||||
{
|
|
||||||
async handleToolStart(...args) {
|
|
||||||
await onToolStart(...args);
|
|
||||||
},
|
|
||||||
async handleToolEnd(...args) {
|
|
||||||
await onToolEnd(...args);
|
|
||||||
},
|
|
||||||
async handleLLMEnd(output) {
|
|
||||||
const { generations } = output;
|
|
||||||
const { text } = generations[0][0];
|
|
||||||
if (text && typeof stream === 'function') {
|
|
||||||
await stream(text);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
break; // Exit the loop if the function call is successful
|
|
||||||
} catch (err) {
|
|
||||||
logger.error('[PluginsClient] executorCall error:', err);
|
|
||||||
if (attempts === maxAttempts) {
|
|
||||||
const { run } = this.runManager.getRunByConversationId(this.conversationId);
|
|
||||||
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
|
|
||||||
this.result.output = run && run.error ? run.error : defaultOutput;
|
|
||||||
this.result.errorMessage = run && run.error ? run.error : err.message;
|
|
||||||
this.result.intermediateSteps = this.actions;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param {TMessage} responseMessage
|
|
||||||
* @param {Partial<TMessage>} saveOptions
|
|
||||||
* @param {string} user
|
|
||||||
* @returns
|
|
||||||
*/
|
|
||||||
async handleResponseMessage(responseMessage, saveOptions, user) {
|
|
||||||
const { output, errorMessage, ...result } = this.result;
|
|
||||||
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
|
|
||||||
output,
|
|
||||||
errorMessage,
|
|
||||||
...result,
|
|
||||||
});
|
|
||||||
const { error } = responseMessage;
|
|
||||||
if (!error) {
|
|
||||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
|
||||||
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record usage only when completion is skipped as it is already recorded in the agent phase.
|
|
||||||
if (!this.agentOptions.skipCompletion && !error) {
|
|
||||||
await this.recordTokenUsage(responseMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
|
||||||
delete responseMessage.tokenCount;
|
|
||||||
return { ...responseMessage, ...result, databasePromise };
|
|
||||||
}
|
|
||||||
|
|
||||||
async sendMessage(message, opts = {}) {
|
|
||||||
/** @type {Promise<TMessage>} */
|
|
||||||
let userMessagePromise;
|
|
||||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
|
||||||
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
|
|
||||||
|
|
||||||
if (includedTools.length > 0) {
|
|
||||||
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
|
|
||||||
this.options.tools = tools;
|
|
||||||
} else {
|
|
||||||
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
|
|
||||||
this.options.tools = tools;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a message is edited, no tools can be used.
|
|
||||||
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
|
||||||
if (completionMode) {
|
|
||||||
this.setOptions(opts);
|
|
||||||
return super.sendMessage(message, opts);
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
|
|
||||||
const {
|
|
||||||
user,
|
|
||||||
conversationId,
|
|
||||||
responseMessageId,
|
|
||||||
saveOptions,
|
|
||||||
userMessage,
|
|
||||||
onAgentAction,
|
|
||||||
onChainEnd,
|
|
||||||
onToolStart,
|
|
||||||
onToolEnd,
|
|
||||||
} = await this.handleStartMethods(message, opts);
|
|
||||||
|
|
||||||
if (opts.progressCallback) {
|
|
||||||
opts.onProgress = opts.progressCallback.call(null, {
|
|
||||||
...(opts.progressOptions ?? {}),
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
this.currentMessages.push(userMessage);
|
|
||||||
|
|
||||||
let {
|
|
||||||
prompt: payload,
|
|
||||||
tokenCountMap,
|
|
||||||
promptTokens,
|
|
||||||
} = await this.buildMessages(
|
|
||||||
this.currentMessages,
|
|
||||||
userMessage.messageId,
|
|
||||||
this.getBuildMessagesOptions({
|
|
||||||
promptPrefix: null,
|
|
||||||
abortController: this.abortController,
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (tokenCountMap) {
|
|
||||||
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
|
|
||||||
if (tokenCountMap[userMessage.messageId]) {
|
|
||||||
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
|
||||||
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
|
|
||||||
}
|
|
||||||
this.handleTokenCountMap(tokenCountMap);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.result = {};
|
|
||||||
if (payload) {
|
|
||||||
this.currentMessages = payload;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!this.skipSaveUserMessage) {
|
|
||||||
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
|
||||||
if (typeof opts?.getReqData === 'function') {
|
|
||||||
opts.getReqData({
|
|
||||||
userMessagePromise,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const balance = this.options.req?.app?.locals?.balance;
|
|
||||||
if (balance?.enabled) {
|
|
||||||
await checkBalance({
|
|
||||||
req: this.options.req,
|
|
||||||
res: this.options.res,
|
|
||||||
txData: {
|
|
||||||
user: this.user,
|
|
||||||
tokenType: 'prompt',
|
|
||||||
amount: promptTokens,
|
|
||||||
debug: this.options.debug,
|
|
||||||
model: this.modelOptions.model,
|
|
||||||
endpoint: EModelEndpoint.openAI,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const responseMessage = {
|
|
||||||
endpoint: EModelEndpoint.gptPlugins,
|
|
||||||
iconURL: this.options.iconURL,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
isCreatedByUser: false,
|
|
||||||
model: this.modelOptions.model,
|
|
||||||
sender: this.sender,
|
|
||||||
promptTokens,
|
|
||||||
};
|
|
||||||
|
|
||||||
await this.initialize({
|
|
||||||
user,
|
|
||||||
message,
|
|
||||||
onAgentAction,
|
|
||||||
onChainEnd,
|
|
||||||
signal: this.abortController.signal,
|
|
||||||
onProgress: opts.onProgress,
|
|
||||||
});
|
|
||||||
|
|
||||||
// const stream = async (text) => {
|
|
||||||
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
|
|
||||||
// };
|
|
||||||
await this.executorCall(message, {
|
|
||||||
signal: this.abortController.signal,
|
|
||||||
// stream,
|
|
||||||
onToolStart,
|
|
||||||
onToolEnd,
|
|
||||||
});
|
|
||||||
|
|
||||||
// If message was aborted mid-generation
|
|
||||||
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
|
|
||||||
responseMessage.text = 'Cancelled.';
|
|
||||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If error occurred during generation (likely token_balance)
|
|
||||||
if (this.result?.errorMessage?.length > 0) {
|
|
||||||
responseMessage.error = true;
|
|
||||||
responseMessage.text = this.result.output;
|
|
||||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
|
|
||||||
const partialText = opts.getPartialText();
|
|
||||||
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
|
|
||||||
responseMessage.text =
|
|
||||||
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
|
|
||||||
addImages(this.result.intermediateSteps, responseMessage);
|
|
||||||
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
|
|
||||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.agentOptions.skipCompletion && this.result.output) {
|
|
||||||
responseMessage.text = this.result.output;
|
|
||||||
addImages(this.result.intermediateSteps, responseMessage);
|
|
||||||
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
|
|
||||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
|
|
||||||
|
|
||||||
const promptPrefix = buildPromptPrefix({
|
|
||||||
result: this.result,
|
|
||||||
message,
|
|
||||||
functionsAgent: this.functionsAgent,
|
|
||||||
});
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient]', { promptPrefix });
|
|
||||||
|
|
||||||
payload = await this.buildCompletionPrompt({
|
|
||||||
messages: this.currentMessages,
|
|
||||||
promptPrefix,
|
|
||||||
});
|
|
||||||
|
|
||||||
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
|
|
||||||
responseMessage.text = await this.sendCompletion(payload, opts);
|
|
||||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
|
|
||||||
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
|
|
||||||
|
|
||||||
const orderedMessages = messages;
|
|
||||||
let promptPrefix = _promptPrefix.trim();
|
|
||||||
// If the prompt prefix doesn't end with the end token, add it.
|
|
||||||
if (!promptPrefix.endsWith(`${this.endToken}`)) {
|
|
||||||
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
|
|
||||||
}
|
|
||||||
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
|
|
||||||
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
|
|
||||||
|
|
||||||
const instructionsPayload = {
|
|
||||||
role: 'system',
|
|
||||||
content: promptPrefix,
|
|
||||||
};
|
|
||||||
|
|
||||||
const messagePayload = {
|
|
||||||
role: 'system',
|
|
||||||
content: promptSuffix,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (this.isGpt3) {
|
|
||||||
instructionsPayload.role = 'user';
|
|
||||||
messagePayload.role = 'user';
|
|
||||||
instructionsPayload.content += `\n${promptSuffix}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
// testing if this works with browser endpoint
|
|
||||||
if (!this.isGpt3 && this.options.reverseProxyUrl) {
|
|
||||||
instructionsPayload.role = 'user';
|
|
||||||
}
|
|
||||||
|
|
||||||
let currentTokenCount =
|
|
||||||
this.getTokenCountForMessage(instructionsPayload) +
|
|
||||||
this.getTokenCountForMessage(messagePayload);
|
|
||||||
|
|
||||||
let promptBody = '';
|
|
||||||
const maxTokenCount = this.maxPromptTokens;
|
|
||||||
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
|
|
||||||
// Do this within a recursive async function so that it doesn't block the event loop for too long.
|
|
||||||
const buildPromptBody = async () => {
|
|
||||||
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
|
|
||||||
const message = orderedMessages.pop();
|
|
||||||
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
|
|
||||||
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
|
|
||||||
let messageString = `${this.startToken}${roleLabel}:\n${
|
|
||||||
message.text ?? message.content ?? ''
|
|
||||||
}${this.endToken}\n`;
|
|
||||||
let newPromptBody = `${messageString}${promptBody}`;
|
|
||||||
|
|
||||||
const tokenCountForMessage = this.getTokenCount(messageString);
|
|
||||||
const newTokenCount = currentTokenCount + tokenCountForMessage;
|
|
||||||
if (newTokenCount > maxTokenCount) {
|
|
||||||
if (promptBody) {
|
|
||||||
// This message would put us over the token limit, so don't add it.
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// This is the first message, so we can't add it. Just throw an error.
|
|
||||||
throw new Error(
|
|
||||||
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
promptBody = newPromptBody;
|
|
||||||
currentTokenCount = newTokenCount;
|
|
||||||
// wait for next tick to avoid blocking the event loop
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
|
||||||
return buildPromptBody();
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
await buildPromptBody();
|
|
||||||
const prompt = promptBody;
|
|
||||||
messagePayload.content = prompt;
|
|
||||||
// Add 2 tokens for metadata after all messages have been counted.
|
|
||||||
currentTokenCount += 2;
|
|
||||||
|
|
||||||
if (this.isGpt3 && messagePayload.content.length > 0) {
|
|
||||||
const context = 'Chat History:\n';
|
|
||||||
messagePayload.content = `${context}${prompt}`;
|
|
||||||
currentTokenCount += this.getTokenCount(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
|
|
||||||
this.modelOptions.max_tokens = Math.min(
|
|
||||||
this.maxContextTokens - currentTokenCount,
|
|
||||||
this.maxResponseTokens,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (this.isGpt3) {
|
|
||||||
messagePayload.content += promptSuffix;
|
|
||||||
return [instructionsPayload, messagePayload];
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = [messagePayload, instructionsPayload];
|
|
||||||
|
|
||||||
if (this.functionsAgent && !this.isGpt3) {
|
|
||||||
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.filter((message) => message.content.length > 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports = PluginsClient;
|
|
||||||
|
|
@ -1,15 +1,11 @@
|
||||||
const ChatGPTClient = require('./ChatGPTClient');
|
|
||||||
const OpenAIClient = require('./OpenAIClient');
|
const OpenAIClient = require('./OpenAIClient');
|
||||||
const PluginsClient = require('./PluginsClient');
|
|
||||||
const GoogleClient = require('./GoogleClient');
|
const GoogleClient = require('./GoogleClient');
|
||||||
const TextStream = require('./TextStream');
|
const TextStream = require('./TextStream');
|
||||||
const AnthropicClient = require('./AnthropicClient');
|
const AnthropicClient = require('./AnthropicClient');
|
||||||
const toolUtils = require('./tools/util');
|
const toolUtils = require('./tools/util');
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
ChatGPTClient,
|
|
||||||
OpenAIClient,
|
OpenAIClient,
|
||||||
PluginsClient,
|
|
||||||
GoogleClient,
|
GoogleClient,
|
||||||
TextStream,
|
TextStream,
|
||||||
AnthropicClient,
|
AnthropicClient,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
const axios = require('axios');
|
const axios = require('axios');
|
||||||
const { isEnabled } = require('~/server/utils');
|
const { isEnabled } = require('@librechat/api');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||||
|
|
||||||
const footer = `Use the context as your learned knowledge to better answer the user.
|
const footer = `Use the context as your learned knowledge to better answer the user.
|
||||||
|
|
||||||
|
|
@ -18,7 +19,7 @@ function createContextHandlers(req, userMessageContent) {
|
||||||
const queryPromises = [];
|
const queryPromises = [];
|
||||||
const processedFiles = [];
|
const processedFiles = [];
|
||||||
const processedIds = new Set();
|
const processedIds = new Set();
|
||||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
const jwtToken = generateShortLivedToken(req.user.id);
|
||||||
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
|
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
|
||||||
|
|
||||||
const query = async (file) => {
|
const query = async (file) => {
|
||||||
|
|
@ -96,35 +97,35 @@ function createContextHandlers(req, userMessageContent) {
|
||||||
resolvedQueries.length === 0
|
resolvedQueries.length === 0
|
||||||
? '\n\tThe semantic search did not return any results.'
|
? '\n\tThe semantic search did not return any results.'
|
||||||
: resolvedQueries
|
: resolvedQueries
|
||||||
.map((queryResult, index) => {
|
.map((queryResult, index) => {
|
||||||
const file = processedFiles[index];
|
const file = processedFiles[index];
|
||||||
let contextItems = queryResult.data;
|
let contextItems = queryResult.data;
|
||||||
|
|
||||||
const generateContext = (currentContext) =>
|
const generateContext = (currentContext) =>
|
||||||
`
|
`
|
||||||
<file>
|
<file>
|
||||||
<filename>${file.filename}</filename>
|
<filename>${file.filename}</filename>
|
||||||
<context>${currentContext}
|
<context>${currentContext}
|
||||||
</context>
|
</context>
|
||||||
</file>`;
|
</file>`;
|
||||||
|
|
||||||
if (useFullContext) {
|
if (useFullContext) {
|
||||||
return generateContext(`\n${contextItems}`);
|
return generateContext(`\n${contextItems}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
contextItems = queryResult.data
|
contextItems = queryResult.data
|
||||||
.map((item) => {
|
.map((item) => {
|
||||||
const pageContent = item[0].page_content;
|
const pageContent = item[0].page_content;
|
||||||
return `
|
return `
|
||||||
<contextItem>
|
<contextItem>
|
||||||
<![CDATA[${pageContent?.trim()}]]>
|
<![CDATA[${pageContent?.trim()}]]>
|
||||||
</contextItem>`;
|
</contextItem>`;
|
||||||
})
|
})
|
||||||
.join('');
|
.join('');
|
||||||
|
|
||||||
return generateContext(contextItems);
|
return generateContext(contextItems);
|
||||||
})
|
})
|
||||||
.join('');
|
.join('');
|
||||||
|
|
||||||
if (useFullContext) {
|
if (useFullContext) {
|
||||||
const prompt = `${header}
|
const prompt = `${header}
|
||||||
|
|
|
||||||
|
|
@ -531,44 +531,6 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('sendMessage/getCompletion/chatCompletion', () => {
|
|
||||||
afterEach(() => {
|
|
||||||
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
|
||||||
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
|
|
||||||
const model = 'text-davinci-003';
|
|
||||||
const onProgress = jest.fn().mockImplementation(() => ({}));
|
|
||||||
|
|
||||||
const testClient = new OpenAIClient('test-api-key', {
|
|
||||||
...defaultOptions,
|
|
||||||
modelOptions: { model },
|
|
||||||
});
|
|
||||||
|
|
||||||
const getCompletion = jest.spyOn(testClient, 'getCompletion');
|
|
||||||
await testClient.sendMessage('Hi mom!', { onProgress });
|
|
||||||
|
|
||||||
expect(getCompletion).toHaveBeenCalled();
|
|
||||||
expect(getCompletion.mock.calls.length).toBe(1);
|
|
||||||
|
|
||||||
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
|
|
||||||
|
|
||||||
expect(fetchEventSource).toHaveBeenCalled();
|
|
||||||
expect(fetchEventSource.mock.calls.length).toBe(1);
|
|
||||||
|
|
||||||
// Check if the first argument (url) is correct
|
|
||||||
const firstCallArgs = fetchEventSource.mock.calls[0];
|
|
||||||
|
|
||||||
const expectedURL = 'https://api.openai.com/v1/completions';
|
|
||||||
expect(firstCallArgs[0]).toBe(expectedURL);
|
|
||||||
|
|
||||||
const requestBody = JSON.parse(firstCallArgs[1].body);
|
|
||||||
expect(requestBody).toHaveProperty('model');
|
|
||||||
expect(requestBody.model).toBe(model);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('checkVisionRequest functionality', () => {
|
describe('checkVisionRequest functionality', () => {
|
||||||
let client;
|
let client;
|
||||||
const attachments = [{ type: 'image/png' }];
|
const attachments = [{ type: 'image/png' }];
|
||||||
|
|
|
||||||
|
|
@ -1,314 +0,0 @@
|
||||||
const crypto = require('crypto');
|
|
||||||
const { Constants } = require('librechat-data-provider');
|
|
||||||
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
|
|
||||||
const PluginsClient = require('../PluginsClient');
|
|
||||||
|
|
||||||
jest.mock('~/db/connect');
|
|
||||||
jest.mock('~/models/Conversation', () => {
|
|
||||||
return function () {
|
|
||||||
return {
|
|
||||||
save: jest.fn(),
|
|
||||||
deleteConvos: jest.fn(),
|
|
||||||
};
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const defaultAzureOptions = {
|
|
||||||
azureOpenAIApiInstanceName: 'your-instance-name',
|
|
||||||
azureOpenAIApiDeploymentName: 'your-deployment-name',
|
|
||||||
azureOpenAIApiVersion: '2020-07-01-preview',
|
|
||||||
};
|
|
||||||
|
|
||||||
describe('PluginsClient', () => {
|
|
||||||
let TestAgent;
|
|
||||||
let options = {
|
|
||||||
tools: [],
|
|
||||||
modelOptions: {
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
temperature: 0,
|
|
||||||
max_tokens: 2,
|
|
||||||
},
|
|
||||||
agentOptions: {
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let parentMessageId;
|
|
||||||
let conversationId;
|
|
||||||
const fakeMessages = [];
|
|
||||||
const userMessage = 'Hello, ChatGPT!';
|
|
||||||
const apiKey = 'fake-api-key';
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
TestAgent = new PluginsClient(apiKey, options);
|
|
||||||
TestAgent.loadHistory = jest
|
|
||||||
.fn()
|
|
||||||
.mockImplementation((conversationId, parentMessageId = null) => {
|
|
||||||
if (!conversationId) {
|
|
||||||
TestAgent.currentMessages = [];
|
|
||||||
return Promise.resolve([]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
|
|
||||||
messages: fakeMessages,
|
|
||||||
parentMessageId,
|
|
||||||
});
|
|
||||||
|
|
||||||
const chatMessages = orderedMessages.map((msg) =>
|
|
||||||
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
|
|
||||||
? new HumanMessage(msg.text)
|
|
||||||
: new AIMessage(msg.text),
|
|
||||||
);
|
|
||||||
|
|
||||||
TestAgent.currentMessages = orderedMessages;
|
|
||||||
return Promise.resolve(chatMessages);
|
|
||||||
});
|
|
||||||
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
|
|
||||||
if (opts && typeof opts === 'object') {
|
|
||||||
TestAgent.setOptions(opts);
|
|
||||||
}
|
|
||||||
const conversationId = opts.conversationId || crypto.randomUUID();
|
|
||||||
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
|
|
||||||
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
|
|
||||||
this.pastMessages = await TestAgent.loadHistory(
|
|
||||||
conversationId,
|
|
||||||
TestAgent.options?.parentMessageId,
|
|
||||||
);
|
|
||||||
|
|
||||||
const userMessage = {
|
|
||||||
text: message,
|
|
||||||
sender: 'ChatGPT',
|
|
||||||
isCreatedByUser: true,
|
|
||||||
messageId: userMessageId,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId,
|
|
||||||
};
|
|
||||||
|
|
||||||
const response = {
|
|
||||||
sender: 'ChatGPT',
|
|
||||||
text: 'Hello, User!',
|
|
||||||
isCreatedByUser: false,
|
|
||||||
messageId: crypto.randomUUID(),
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
conversationId,
|
|
||||||
};
|
|
||||||
|
|
||||||
fakeMessages.push(userMessage);
|
|
||||||
fakeMessages.push(response);
|
|
||||||
return response;
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
test('initializes PluginsClient without crashing', () => {
|
|
||||||
expect(TestAgent).toBeInstanceOf(PluginsClient);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('check setOptions function', () => {
|
|
||||||
expect(TestAgent.agentIsGpt3).toBe(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('sendMessage', () => {
|
|
||||||
test('sendMessage should return a response message', async () => {
|
|
||||||
const expectedResult = expect.objectContaining({
|
|
||||||
sender: 'ChatGPT',
|
|
||||||
text: expect.any(String),
|
|
||||||
isCreatedByUser: false,
|
|
||||||
messageId: expect.any(String),
|
|
||||||
parentMessageId: expect.any(String),
|
|
||||||
conversationId: expect.any(String),
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await TestAgent.sendMessage(userMessage);
|
|
||||||
parentMessageId = response.messageId;
|
|
||||||
conversationId = response.conversationId;
|
|
||||||
expect(response).toEqual(expectedResult);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
|
|
||||||
const userMessage = 'Second message in the conversation';
|
|
||||||
const opts = {
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
};
|
|
||||||
|
|
||||||
const expectedResult = expect.objectContaining({
|
|
||||||
sender: 'ChatGPT',
|
|
||||||
text: expect.any(String),
|
|
||||||
isCreatedByUser: false,
|
|
||||||
messageId: expect.any(String),
|
|
||||||
parentMessageId: expect.any(String),
|
|
||||||
conversationId: opts.conversationId,
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await TestAgent.sendMessage(userMessage, opts);
|
|
||||||
parentMessageId = response.messageId;
|
|
||||||
expect(response.conversationId).toEqual(conversationId);
|
|
||||||
expect(response).toEqual(expectedResult);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return chat history', async () => {
|
|
||||||
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
|
|
||||||
expect(TestAgent.currentMessages).toHaveLength(4);
|
|
||||||
expect(chatMessages[0].text).toEqual(userMessage);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getFunctionModelName', () => {
|
|
||||||
let client;
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
client = new PluginsClient('dummy_api_key');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return the input when it includes a dash followed by four digits', () => {
|
|
||||||
expect(client.getFunctionModelName('-1234')).toBe('-1234');
|
|
||||||
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return the input for all function-capable models (`0613` models and above)', () => {
|
|
||||||
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
|
|
||||||
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
|
|
||||||
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
|
|
||||||
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
|
|
||||||
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
|
||||||
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
|
|
||||||
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
|
|
||||||
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
|
|
||||||
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
|
|
||||||
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
|
|
||||||
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
|
|
||||||
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return "gpt-4" when the input includes "gpt-4"', () => {
|
|
||||||
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
|
|
||||||
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
|
|
||||||
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('Azure OpenAI tests specific to Plugins', () => {
|
|
||||||
// TODO: add more tests for Azure OpenAI integration with Plugins
|
|
||||||
// let client;
|
|
||||||
// beforeEach(() => {
|
|
||||||
// client = new PluginsClient('dummy_api_key');
|
|
||||||
// });
|
|
||||||
|
|
||||||
test('should not call getFunctionModelName when azure options are set', () => {
|
|
||||||
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
|
|
||||||
const model = 'gpt-4-turbo';
|
|
||||||
|
|
||||||
// note, without the azure change in PR #1766, `getFunctionModelName` is called twice
|
|
||||||
const testClient = new PluginsClient('dummy_api_key', {
|
|
||||||
agentOptions: {
|
|
||||||
model,
|
|
||||||
agent: 'functions',
|
|
||||||
},
|
|
||||||
azure: defaultAzureOptions,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(spy).not.toHaveBeenCalled();
|
|
||||||
expect(testClient.agentOptions.model).toBe(model);
|
|
||||||
|
|
||||||
spy.mockRestore();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('sendMessage with filtered tools', () => {
|
|
||||||
let TestAgent;
|
|
||||||
const apiKey = 'fake-api-key';
|
|
||||||
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
TestAgent = new PluginsClient(apiKey, {
|
|
||||||
tools: mockTools,
|
|
||||||
modelOptions: {
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
temperature: 0,
|
|
||||||
max_tokens: 2,
|
|
||||||
},
|
|
||||||
agentOptions: {
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
TestAgent.options.req = {
|
|
||||||
app: {
|
|
||||||
locals: {},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
|
|
||||||
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
|
|
||||||
|
|
||||||
if (includedTools.length > 0) {
|
|
||||||
const tools = TestAgent.options.tools.filter((plugin) =>
|
|
||||||
includedTools.includes(plugin.name),
|
|
||||||
);
|
|
||||||
TestAgent.options.tools = tools;
|
|
||||||
} else {
|
|
||||||
const tools = TestAgent.options.tools.filter(
|
|
||||||
(plugin) => !filteredTools.includes(plugin.name),
|
|
||||||
);
|
|
||||||
TestAgent.options.tools = tools;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
text: 'Mocked response',
|
|
||||||
tools: TestAgent.options.tools,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should filter out tools when filteredTools is provided', async () => {
|
|
||||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
|
||||||
const response = await TestAgent.sendMessage('Test message');
|
|
||||||
expect(response.tools).toHaveLength(2);
|
|
||||||
expect(response.tools).toEqual(
|
|
||||||
expect.arrayContaining([
|
|
||||||
expect.objectContaining({ name: 'tool2' }),
|
|
||||||
expect.objectContaining({ name: 'tool4' }),
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should only include specified tools when includedTools is provided', async () => {
|
|
||||||
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
|
|
||||||
const response = await TestAgent.sendMessage('Test message');
|
|
||||||
expect(response.tools).toHaveLength(2);
|
|
||||||
expect(response.tools).toEqual(
|
|
||||||
expect.arrayContaining([
|
|
||||||
expect.objectContaining({ name: 'tool2' }),
|
|
||||||
expect.objectContaining({ name: 'tool4' }),
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should prioritize includedTools over filteredTools', async () => {
|
|
||||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
|
||||||
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
|
|
||||||
const response = await TestAgent.sendMessage('Test message');
|
|
||||||
expect(response.tools).toHaveLength(2);
|
|
||||||
expect(response.tools).toEqual(
|
|
||||||
expect.arrayContaining([
|
|
||||||
expect.objectContaining({ name: 'tool1' }),
|
|
||||||
expect.objectContaining({ name: 'tool2' }),
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('should not modify tools when no filters are provided', async () => {
|
|
||||||
const response = await TestAgent.sendMessage('Test message');
|
|
||||||
expect(response.tools).toHaveLength(4);
|
|
||||||
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
@ -107,6 +107,12 @@ const getImageEditPromptDescription = () => {
|
||||||
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
|
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
function createAbortHandler() {
|
||||||
|
return function () {
|
||||||
|
logger.debug('[ImageGenOAI] Image generation aborted');
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates OpenAI Image tools (generation and editing)
|
* Creates OpenAI Image tools (generation and editing)
|
||||||
* @param {Object} fields - Configuration fields
|
* @param {Object} fields - Configuration fields
|
||||||
|
|
@ -201,10 +207,18 @@ function createOpenAIImageTools(fields = {}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp;
|
let resp;
|
||||||
|
/** @type {AbortSignal} */
|
||||||
|
let derivedSignal = null;
|
||||||
|
/** @type {() => void} */
|
||||||
|
let abortHandler = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const derivedSignal = runnableConfig?.signal
|
if (runnableConfig?.signal) {
|
||||||
? AbortSignal.any([runnableConfig.signal])
|
derivedSignal = AbortSignal.any([runnableConfig.signal]);
|
||||||
: undefined;
|
abortHandler = createAbortHandler();
|
||||||
|
derivedSignal.addEventListener('abort', abortHandler, { once: true });
|
||||||
|
}
|
||||||
|
|
||||||
resp = await openai.images.generate(
|
resp = await openai.images.generate(
|
||||||
{
|
{
|
||||||
model: 'gpt-image-1',
|
model: 'gpt-image-1',
|
||||||
|
|
@ -228,6 +242,10 @@ function createOpenAIImageTools(fields = {}) {
|
||||||
logAxiosError({ error, message });
|
logAxiosError({ error, message });
|
||||||
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
|
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
|
||||||
Error Message: ${error.message}`);
|
Error Message: ${error.message}`);
|
||||||
|
} finally {
|
||||||
|
if (abortHandler && derivedSignal) {
|
||||||
|
derivedSignal.removeEventListener('abort', abortHandler);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!resp) {
|
if (!resp) {
|
||||||
|
|
@ -409,10 +427,17 @@ Error Message: ${error.message}`);
|
||||||
headers['Authorization'] = `Bearer ${apiKey}`;
|
headers['Authorization'] = `Bearer ${apiKey}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** @type {AbortSignal} */
|
||||||
|
let derivedSignal = null;
|
||||||
|
/** @type {() => void} */
|
||||||
|
let abortHandler = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const derivedSignal = runnableConfig?.signal
|
if (runnableConfig?.signal) {
|
||||||
? AbortSignal.any([runnableConfig.signal])
|
derivedSignal = AbortSignal.any([runnableConfig.signal]);
|
||||||
: undefined;
|
abortHandler = createAbortHandler();
|
||||||
|
derivedSignal.addEventListener('abort', abortHandler, { once: true });
|
||||||
|
}
|
||||||
|
|
||||||
/** @type {import('axios').AxiosRequestConfig} */
|
/** @type {import('axios').AxiosRequestConfig} */
|
||||||
const axiosConfig = {
|
const axiosConfig = {
|
||||||
|
|
@ -467,6 +492,10 @@ Error Message: ${error.message}`);
|
||||||
logAxiosError({ error, message });
|
logAxiosError({ error, message });
|
||||||
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
|
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
|
||||||
Error Message: ${error.message || 'Unknown error'}`);
|
Error Message: ${error.message || 'Unknown error'}`);
|
||||||
|
} finally {
|
||||||
|
if (abortHandler && derivedSignal) {
|
||||||
|
derivedSignal.removeEventListener('abort', abortHandler);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
const { z } = require('zod');
|
const { z } = require('zod');
|
||||||
const axios = require('axios');
|
const axios = require('axios');
|
||||||
const { tool } = require('@langchain/core/tools');
|
const { tool } = require('@langchain/core/tools');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { Tools, EToolResources } = require('librechat-data-provider');
|
const { Tools, EToolResources } = require('librechat-data-provider');
|
||||||
|
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||||
const { getFiles } = require('~/models/File');
|
const { getFiles } = require('~/models/File');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
@ -59,7 +60,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||||
if (files.length === 0) {
|
if (files.length === 0) {
|
||||||
return 'No files to search. Instruct the user to add files for the search.';
|
return 'No files to search. Instruct the user to add files for the search.';
|
||||||
}
|
}
|
||||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
const jwtToken = generateShortLivedToken(req.user.id);
|
||||||
if (!jwtToken) {
|
if (!jwtToken) {
|
||||||
return 'There was an error authenticating the file search request.';
|
return 'There was an error authenticating the file search request.';
|
||||||
}
|
}
|
||||||
|
|
|
||||||
3
api/cache/banViolation.js
vendored
3
api/cache/banViolation.js
vendored
|
|
@ -1,7 +1,8 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { isEnabled, math } = require('@librechat/api');
|
||||||
const { ViolationTypes } = require('librechat-data-provider');
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
const { isEnabled, math, removePorts } = require('~/server/utils');
|
|
||||||
const { deleteAllUserSessions } = require('~/models');
|
const { deleteAllUserSessions } = require('~/models');
|
||||||
|
const { removePorts } = require('~/server/utils');
|
||||||
const getLogStores = require('./getLogStores');
|
const getLogStores = require('./getLogStores');
|
||||||
|
|
||||||
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
|
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
|
||||||
|
|
|
||||||
2
api/cache/getLogStores.js
vendored
2
api/cache/getLogStores.js
vendored
|
|
@ -1,7 +1,7 @@
|
||||||
const { Keyv } = require('keyv');
|
const { Keyv } = require('keyv');
|
||||||
|
const { isEnabled, math } = require('@librechat/api');
|
||||||
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
||||||
const { logFile, violationFile } = require('./keyvFiles');
|
const { logFile, violationFile } = require('./keyvFiles');
|
||||||
const { isEnabled, math } = require('~/server/utils');
|
|
||||||
const keyvRedis = require('./keyvRedis');
|
const keyvRedis = require('./keyvRedis');
|
||||||
const keyvMongo = require('./keyvMongo');
|
const keyvMongo = require('./keyvMongo');
|
||||||
|
|
||||||
|
|
|
||||||
2
api/cache/logViolation.js
vendored
2
api/cache/logViolation.js
vendored
|
|
@ -9,7 +9,7 @@ const banViolation = require('./banViolation');
|
||||||
* @param {Object} res - Express response object.
|
* @param {Object} res - Express response object.
|
||||||
* @param {string} type - The type of violation.
|
* @param {string} type - The type of violation.
|
||||||
* @param {Object} errorMessage - The error message to log.
|
* @param {Object} errorMessage - The error message to log.
|
||||||
* @param {number} [score=1] - The severity of the violation. Defaults to 1
|
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
|
||||||
*/
|
*/
|
||||||
const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
||||||
const userId = req.user?.id ?? req.user?._id;
|
const userId = req.user?.id ?? req.user?._id;
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,9 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||||
if (ephemeralAgent?.execute_code === true) {
|
if (ephemeralAgent?.execute_code === true) {
|
||||||
tools.push(Tools.execute_code);
|
tools.push(Tools.execute_code);
|
||||||
}
|
}
|
||||||
|
if (ephemeralAgent?.file_search === true) {
|
||||||
|
tools.push(Tools.file_search);
|
||||||
|
}
|
||||||
if (ephemeralAgent?.web_search === true) {
|
if (ephemeralAgent?.web_search === true) {
|
||||||
tools.push(Tools.web_search);
|
tools.push(Tools.web_search);
|
||||||
}
|
}
|
||||||
|
|
@ -87,7 +90,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||||
}
|
}
|
||||||
|
|
||||||
const instructions = req.body.promptPrefix;
|
const instructions = req.body.promptPrefix;
|
||||||
return {
|
const result = {
|
||||||
id: agent_id,
|
id: agent_id,
|
||||||
instructions,
|
instructions,
|
||||||
provider: endpoint,
|
provider: endpoint,
|
||||||
|
|
@ -95,6 +98,11 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||||
model,
|
model,
|
||||||
tools,
|
tools,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||||
|
result.artifacts = ephemeralAgent.artifacts;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
@ -413,7 +413,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
@ -670,7 +670,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
@ -1332,7 +1332,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
@ -1514,7 +1514,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
@ -1798,7 +1798,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
@ -2350,7 +2350,7 @@ describe('models/Agent', () => {
|
||||||
const mongoUri = mongoServer.getUri();
|
const mongoUri = mongoServer.getUri();
|
||||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
await mongoose.connect(mongoUri);
|
await mongoose.connect(mongoUri);
|
||||||
});
|
}, 20000);
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
await mongoose.disconnect();
|
await mongoose.disconnect();
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||||
|
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||||
const { getMessages, deleteMessages } = require('./Message');
|
const { getMessages, deleteMessages } = require('./Message');
|
||||||
const { Conversation } = require('~/db/models');
|
const { Conversation } = require('~/db/models');
|
||||||
|
|
||||||
|
|
@ -98,10 +100,15 @@ module.exports = {
|
||||||
update.conversationId = newConversationId;
|
update.conversationId = newConversationId;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (req.body.isTemporary) {
|
if (req?.body?.isTemporary) {
|
||||||
const expiredAt = new Date();
|
try {
|
||||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
const customConfig = await getCustomConfig();
|
||||||
update.expiredAt = expiredAt;
|
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Error creating temporary chat expiration date:', err);
|
||||||
|
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
|
||||||
|
update.expiredAt = null;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
update.expiredAt = null;
|
update.expiredAt = null;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { EToolResources } = require('librechat-data-provider');
|
const { EToolResources, FileContext } = require('librechat-data-provider');
|
||||||
const { File } = require('~/db/models');
|
const { File } = require('~/db/models');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -32,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
|
||||||
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
|
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
|
||||||
*/
|
*/
|
||||||
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
|
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
|
||||||
if (!fileIds || !fileIds.length) {
|
if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const filter = {
|
const filter = {
|
||||||
file_id: { $in: fileIds },
|
file_id: { $in: fileIds },
|
||||||
|
$or: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
if (toolResourceSet.size) {
|
if (toolResourceSet.has(EToolResources.ocr)) {
|
||||||
filter.$or = [];
|
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (toolResourceSet.has(EToolResources.file_search)) {
|
if (toolResourceSet.has(EToolResources.file_search)) {
|
||||||
filter.$or.push({ embedded: true });
|
filter.$or.push({ embedded: true });
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
const { z } = require('zod');
|
const { z } = require('zod');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||||
|
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||||
const { Message } = require('~/db/models');
|
const { Message } = require('~/db/models');
|
||||||
|
|
||||||
const idSchema = z.string().uuid();
|
const idSchema = z.string().uuid();
|
||||||
|
|
@ -54,9 +56,14 @@ async function saveMessage(req, params, metadata) {
|
||||||
};
|
};
|
||||||
|
|
||||||
if (req?.body?.isTemporary) {
|
if (req?.body?.isTemporary) {
|
||||||
const expiredAt = new Date();
|
try {
|
||||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
const customConfig = await getCustomConfig();
|
||||||
update.expiredAt = expiredAt;
|
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Error creating temporary chat expiration date:', err);
|
||||||
|
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||||
|
update.expiredAt = null;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
update.expiredAt = null;
|
update.expiredAt = null;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "@librechat/backend",
|
"name": "@librechat/backend",
|
||||||
"version": "v0.7.8",
|
"version": "v0.7.9-rc1",
|
||||||
"description": "",
|
"description": "",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "echo 'please run this from the root directory'",
|
"start": "echo 'please run this from the root directory'",
|
||||||
|
|
@ -48,14 +48,13 @@
|
||||||
"@langchain/google-genai": "^0.2.13",
|
"@langchain/google-genai": "^0.2.13",
|
||||||
"@langchain/google-vertexai": "^0.2.13",
|
"@langchain/google-vertexai": "^0.2.13",
|
||||||
"@langchain/textsplitters": "^0.1.0",
|
"@langchain/textsplitters": "^0.1.0",
|
||||||
"@librechat/agents": "^2.4.41",
|
"@librechat/agents": "^2.4.56",
|
||||||
"@librechat/api": "*",
|
"@librechat/api": "*",
|
||||||
"@librechat/data-schemas": "*",
|
"@librechat/data-schemas": "*",
|
||||||
"@node-saml/passport-saml": "^5.0.0",
|
"@node-saml/passport-saml": "^5.0.0",
|
||||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||||
"axios": "^1.8.2",
|
"axios": "^1.8.2",
|
||||||
"bcryptjs": "^2.4.3",
|
"bcryptjs": "^2.4.3",
|
||||||
"cohere-ai": "^7.9.1",
|
|
||||||
"compression": "^1.7.4",
|
"compression": "^1.7.4",
|
||||||
"connect-redis": "^7.1.0",
|
"connect-redis": "^7.1.0",
|
||||||
"cookie": "^0.7.2",
|
"cookie": "^0.7.2",
|
||||||
|
|
|
||||||
|
|
@ -169,9 +169,6 @@ function disposeClient(client) {
|
||||||
client.isGenerativeModel = null;
|
client.isGenerativeModel = null;
|
||||||
}
|
}
|
||||||
// Properties specific to OpenAIClient
|
// Properties specific to OpenAIClient
|
||||||
if (client.ChatGPTClient) {
|
|
||||||
client.ChatGPTClient = null;
|
|
||||||
}
|
|
||||||
if (client.completionsUrl) {
|
if (client.completionsUrl) {
|
||||||
client.completionsUrl = null;
|
client.completionsUrl = null;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,282 +0,0 @@
|
||||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
|
||||||
const {
|
|
||||||
handleAbortError,
|
|
||||||
createAbortController,
|
|
||||||
cleanupAbortController,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
const {
|
|
||||||
disposeClient,
|
|
||||||
processReqData,
|
|
||||||
clientRegistry,
|
|
||||||
requestDataMap,
|
|
||||||
} = require('~/server/cleanup');
|
|
||||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
|
||||||
const { saveMessage } = require('~/models');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|
||||||
let {
|
|
||||||
text,
|
|
||||||
endpointOption,
|
|
||||||
conversationId,
|
|
||||||
modelDisplayLabel,
|
|
||||||
parentMessageId = null,
|
|
||||||
overrideParentMessageId = null,
|
|
||||||
} = req.body;
|
|
||||||
|
|
||||||
let client = null;
|
|
||||||
let abortKey = null;
|
|
||||||
let cleanupHandlers = [];
|
|
||||||
let clientRef = null;
|
|
||||||
|
|
||||||
logger.debug('[AskController]', {
|
|
||||||
text,
|
|
||||||
conversationId,
|
|
||||||
...endpointOption,
|
|
||||||
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
|
|
||||||
});
|
|
||||||
|
|
||||||
let userMessage = null;
|
|
||||||
let userMessagePromise = null;
|
|
||||||
let promptTokens = null;
|
|
||||||
let userMessageId = null;
|
|
||||||
let responseMessageId = null;
|
|
||||||
let getAbortData = null;
|
|
||||||
|
|
||||||
const sender = getResponseSender({
|
|
||||||
...endpointOption,
|
|
||||||
model: endpointOption.modelOptions.model,
|
|
||||||
modelDisplayLabel,
|
|
||||||
});
|
|
||||||
const initialConversationId = conversationId;
|
|
||||||
const newConvo = !initialConversationId;
|
|
||||||
const userId = req.user.id;
|
|
||||||
|
|
||||||
let reqDataContext = {
|
|
||||||
userMessage,
|
|
||||||
userMessagePromise,
|
|
||||||
responseMessageId,
|
|
||||||
promptTokens,
|
|
||||||
conversationId,
|
|
||||||
userMessageId,
|
|
||||||
};
|
|
||||||
|
|
||||||
const updateReqData = (data = {}) => {
|
|
||||||
reqDataContext = processReqData(data, reqDataContext);
|
|
||||||
abortKey = reqDataContext.abortKey;
|
|
||||||
userMessage = reqDataContext.userMessage;
|
|
||||||
userMessagePromise = reqDataContext.userMessagePromise;
|
|
||||||
responseMessageId = reqDataContext.responseMessageId;
|
|
||||||
promptTokens = reqDataContext.promptTokens;
|
|
||||||
conversationId = reqDataContext.conversationId;
|
|
||||||
userMessageId = reqDataContext.userMessageId;
|
|
||||||
};
|
|
||||||
|
|
||||||
let { onProgress: progressCallback, getPartialText } = createOnProgress();
|
|
||||||
|
|
||||||
const performCleanup = () => {
|
|
||||||
logger.debug('[AskController] Performing cleanup');
|
|
||||||
if (Array.isArray(cleanupHandlers)) {
|
|
||||||
for (const handler of cleanupHandlers) {
|
|
||||||
try {
|
|
||||||
if (typeof handler === 'function') {
|
|
||||||
handler();
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
// Ignore
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (abortKey) {
|
|
||||||
logger.debug('[AskController] Cleaning up abort controller');
|
|
||||||
cleanupAbortController(abortKey);
|
|
||||||
abortKey = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (client) {
|
|
||||||
disposeClient(client);
|
|
||||||
client = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
reqDataContext = null;
|
|
||||||
userMessage = null;
|
|
||||||
userMessagePromise = null;
|
|
||||||
promptTokens = null;
|
|
||||||
getAbortData = null;
|
|
||||||
progressCallback = null;
|
|
||||||
endpointOption = null;
|
|
||||||
cleanupHandlers = null;
|
|
||||||
addTitle = null;
|
|
||||||
|
|
||||||
if (requestDataMap.has(req)) {
|
|
||||||
requestDataMap.delete(req);
|
|
||||||
}
|
|
||||||
logger.debug('[AskController] Cleanup completed');
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
({ client } = await initializeClient({ req, res, endpointOption }));
|
|
||||||
if (clientRegistry && client) {
|
|
||||||
clientRegistry.register(client, { userId }, client);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (client) {
|
|
||||||
requestDataMap.set(req, { client });
|
|
||||||
}
|
|
||||||
|
|
||||||
clientRef = new WeakRef(client);
|
|
||||||
|
|
||||||
getAbortData = () => {
|
|
||||||
const currentClient = clientRef?.deref();
|
|
||||||
const currentText =
|
|
||||||
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
|
|
||||||
|
|
||||||
return {
|
|
||||||
sender,
|
|
||||||
conversationId,
|
|
||||||
messageId: reqDataContext.responseMessageId,
|
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
|
||||||
text: currentText,
|
|
||||||
userMessage: userMessage,
|
|
||||||
userMessagePromise: userMessagePromise,
|
|
||||||
promptTokens: reqDataContext.promptTokens,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const { onStart, abortController } = createAbortController(
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
getAbortData,
|
|
||||||
updateReqData,
|
|
||||||
);
|
|
||||||
|
|
||||||
const closeHandler = () => {
|
|
||||||
logger.debug('[AskController] Request closed');
|
|
||||||
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
abortController.abort();
|
|
||||||
logger.debug('[AskController] Request aborted on close');
|
|
||||||
};
|
|
||||||
|
|
||||||
res.on('close', closeHandler);
|
|
||||||
cleanupHandlers.push(() => {
|
|
||||||
try {
|
|
||||||
res.removeListener('close', closeHandler);
|
|
||||||
} catch (e) {
|
|
||||||
// Ignore
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
const messageOptions = {
|
|
||||||
user: userId,
|
|
||||||
parentMessageId,
|
|
||||||
conversationId: reqDataContext.conversationId,
|
|
||||||
overrideParentMessageId,
|
|
||||||
getReqData: updateReqData,
|
|
||||||
onStart,
|
|
||||||
abortController,
|
|
||||||
progressCallback,
|
|
||||||
progressOptions: {
|
|
||||||
res,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
/** @type {TMessage} */
|
|
||||||
let response = await client.sendMessage(text, messageOptions);
|
|
||||||
response.endpoint = endpointOption.endpoint;
|
|
||||||
|
|
||||||
const databasePromise = response.databasePromise;
|
|
||||||
delete response.databasePromise;
|
|
||||||
|
|
||||||
const { conversation: convoData = {} } = await databasePromise;
|
|
||||||
const conversation = { ...convoData };
|
|
||||||
conversation.title =
|
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
|
||||||
|
|
||||||
const latestUserMessage = reqDataContext.userMessage;
|
|
||||||
|
|
||||||
if (client?.options?.attachments && latestUserMessage) {
|
|
||||||
latestUserMessage.files = client.options.attachments;
|
|
||||||
if (endpointOption?.modelOptions?.model) {
|
|
||||||
conversation.model = endpointOption.modelOptions.model;
|
|
||||||
}
|
|
||||||
delete latestUserMessage.image_urls;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!abortController.signal.aborted) {
|
|
||||||
const finalResponseMessage = { ...response };
|
|
||||||
|
|
||||||
sendMessage(res, {
|
|
||||||
final: true,
|
|
||||||
conversation,
|
|
||||||
title: conversation.title,
|
|
||||||
requestMessage: latestUserMessage,
|
|
||||||
responseMessage: finalResponseMessage,
|
|
||||||
});
|
|
||||||
res.end();
|
|
||||||
|
|
||||||
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
|
|
||||||
await saveMessage(
|
|
||||||
req,
|
|
||||||
{ ...finalResponseMessage, user: userId },
|
|
||||||
{ context: 'api/server/controllers/AskController.js - response end' },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!client?.skipSaveUserMessage && latestUserMessage) {
|
|
||||||
await saveMessage(req, latestUserMessage, {
|
|
||||||
context: "api/server/controllers/AskController.js - don't skip saving user message",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
|
|
||||||
addTitle(req, {
|
|
||||||
text,
|
|
||||||
response: { ...response },
|
|
||||||
client,
|
|
||||||
})
|
|
||||||
.then(() => {
|
|
||||||
logger.debug('[AskController] Title generation started');
|
|
||||||
})
|
|
||||||
.catch((err) => {
|
|
||||||
logger.error('[AskController] Error in title generation', err);
|
|
||||||
})
|
|
||||||
.finally(() => {
|
|
||||||
logger.debug('[AskController] Title generation completed');
|
|
||||||
performCleanup();
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
performCleanup();
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[AskController] Error handling request', error);
|
|
||||||
let partialText = '';
|
|
||||||
try {
|
|
||||||
const currentClient = clientRef?.deref();
|
|
||||||
partialText =
|
|
||||||
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
|
|
||||||
} catch (getTextError) {
|
|
||||||
logger.error('[AskController] Error calling getText() during error handling', getTextError);
|
|
||||||
}
|
|
||||||
|
|
||||||
handleAbortError(res, req, error, {
|
|
||||||
sender,
|
|
||||||
partialText,
|
|
||||||
conversationId: reqDataContext.conversationId,
|
|
||||||
messageId: reqDataContext.responseMessageId,
|
|
||||||
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
|
|
||||||
userMessageId: reqDataContext.userMessageId,
|
|
||||||
})
|
|
||||||
.catch((err) => {
|
|
||||||
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
|
|
||||||
})
|
|
||||||
.finally(() => {
|
|
||||||
performCleanup();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = AskController;
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
const cookies = require('cookie');
|
const cookies = require('cookie');
|
||||||
const jwt = require('jsonwebtoken');
|
const jwt = require('jsonwebtoken');
|
||||||
const openIdClient = require('openid-client');
|
const openIdClient = require('openid-client');
|
||||||
|
const { isEnabled } = require('@librechat/api');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
registerUser,
|
|
||||||
resetPassword,
|
|
||||||
setAuthTokens,
|
|
||||||
requestPasswordReset,
|
requestPasswordReset,
|
||||||
setOpenIDAuthTokens,
|
setOpenIDAuthTokens,
|
||||||
|
resetPassword,
|
||||||
|
setAuthTokens,
|
||||||
|
registerUser,
|
||||||
} = require('~/server/services/AuthService');
|
} = require('~/server/services/AuthService');
|
||||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||||
const { getOpenIdConfig } = require('~/strategies');
|
const { getOpenIdConfig } = require('~/strategies');
|
||||||
const { isEnabled } = require('~/server/utils');
|
|
||||||
|
|
||||||
const registrationController = async (req, res) => {
|
const registrationController = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { getResponseSender } = require('librechat-data-provider');
|
const { getResponseSender } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
handleAbortError,
|
handleAbortError,
|
||||||
|
|
@ -10,9 +12,8 @@ const {
|
||||||
clientRegistry,
|
clientRegistry,
|
||||||
requestDataMap,
|
requestDataMap,
|
||||||
} = require('~/server/cleanup');
|
} = require('~/server/cleanup');
|
||||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
const { createOnProgress } = require('~/server/utils');
|
||||||
const { saveMessage } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const EditController = async (req, res, next, initializeClient) => {
|
const EditController = async (req, res, next, initializeClient) => {
|
||||||
let {
|
let {
|
||||||
|
|
@ -84,7 +85,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (abortKey) {
|
if (abortKey) {
|
||||||
logger.debug('[AskController] Cleaning up abort controller');
|
logger.debug('[EditController] Cleaning up abort controller');
|
||||||
cleanupAbortController(abortKey);
|
cleanupAbortController(abortKey);
|
||||||
abortKey = null;
|
abortKey = null;
|
||||||
}
|
}
|
||||||
|
|
@ -198,7 +199,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||||
const finalUserMessage = reqDataContext.userMessage;
|
const finalUserMessage = reqDataContext.userMessage;
|
||||||
const finalResponseMessage = { ...response };
|
const finalResponseMessage = { ...response };
|
||||||
|
|
||||||
sendMessage(res, {
|
sendEvent(res, {
|
||||||
final: true,
|
final: true,
|
||||||
conversation,
|
conversation,
|
||||||
title: conversation.title,
|
title: conversation.title,
|
||||||
|
|
|
||||||
|
|
@ -24,17 +24,23 @@ const handleValidationError = (err, res) => {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// eslint-disable-next-line no-unused-vars
|
module.exports = (err, _req, res, _next) => {
|
||||||
module.exports = (err, req, res, next) => {
|
|
||||||
try {
|
try {
|
||||||
if (err.name === 'ValidationError') {
|
if (err.name === 'ValidationError') {
|
||||||
return (err = handleValidationError(err, res));
|
return handleValidationError(err, res);
|
||||||
}
|
}
|
||||||
if (err.code && err.code == 11000) {
|
if (err.code && err.code == 11000) {
|
||||||
return (err = handleDuplicateKeyError(err, res));
|
return handleDuplicateKeyError(err, res);
|
||||||
}
|
}
|
||||||
} catch (err) {
|
// Special handling for errors like SyntaxError
|
||||||
|
if (err.statusCode && err.body) {
|
||||||
|
return res.status(err.statusCode).send(err.body);
|
||||||
|
}
|
||||||
|
|
||||||
logger.error('ErrorController => error', err);
|
logger.error('ErrorController => error', err);
|
||||||
res.status(500).send('An unknown error occurred.');
|
return res.status(500).send('An unknown error occurred.');
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('ErrorController => processing error', err);
|
||||||
|
return res.status(500).send('Processing error in ErrorController.');
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
241
api/server/controllers/ErrorController.spec.js
Normal file
241
api/server/controllers/ErrorController.spec.js
Normal file
|
|
@ -0,0 +1,241 @@
|
||||||
|
const errorController = require('./ErrorController');
|
||||||
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
|
// Mock the logger
|
||||||
|
jest.mock('~/config', () => ({
|
||||||
|
logger: {
|
||||||
|
error: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe('ErrorController', () => {
|
||||||
|
let mockReq, mockRes, mockNext;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockReq = {};
|
||||||
|
mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
send: jest.fn(),
|
||||||
|
};
|
||||||
|
mockNext = jest.fn();
|
||||||
|
logger.error.mockClear();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('ValidationError handling', () => {
|
||||||
|
it('should handle ValidationError with single error', () => {
|
||||||
|
const validationError = {
|
||||||
|
name: 'ValidationError',
|
||||||
|
errors: {
|
||||||
|
email: { message: 'Email is required', path: 'email' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(validationError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
messages: '["Email is required"]',
|
||||||
|
fields: '["email"]',
|
||||||
|
});
|
||||||
|
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle ValidationError with multiple errors', () => {
|
||||||
|
const validationError = {
|
||||||
|
name: 'ValidationError',
|
||||||
|
errors: {
|
||||||
|
email: { message: 'Email is required', path: 'email' },
|
||||||
|
password: { message: 'Password is required', path: 'password' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(validationError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
messages: '"Email is required Password is required"',
|
||||||
|
fields: '["email","password"]',
|
||||||
|
});
|
||||||
|
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle ValidationError with empty errors object', () => {
|
||||||
|
const validationError = {
|
||||||
|
name: 'ValidationError',
|
||||||
|
errors: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(validationError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
messages: '[]',
|
||||||
|
fields: '[]',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Duplicate key error handling', () => {
|
||||||
|
it('should handle duplicate key error (code 11000)', () => {
|
||||||
|
const duplicateKeyError = {
|
||||||
|
code: 11000,
|
||||||
|
keyValue: { email: 'test@example.com' },
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
messages: 'An document with that ["email"] already exists.',
|
||||||
|
fields: '["email"]',
|
||||||
|
});
|
||||||
|
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle duplicate key error with multiple fields', () => {
|
||||||
|
const duplicateKeyError = {
|
||||||
|
code: 11000,
|
||||||
|
keyValue: { email: 'test@example.com', username: 'testuser' },
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
messages: 'An document with that ["email","username"] already exists.',
|
||||||
|
fields: '["email","username"]',
|
||||||
|
});
|
||||||
|
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle error with code 11000 as string', () => {
|
||||||
|
const duplicateKeyError = {
|
||||||
|
code: '11000',
|
||||||
|
keyValue: { email: 'test@example.com' },
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({
|
||||||
|
messages: 'An document with that ["email"] already exists.',
|
||||||
|
fields: '["email"]',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('SyntaxError handling', () => {
|
||||||
|
it('should handle errors with statusCode and body', () => {
|
||||||
|
const syntaxError = {
|
||||||
|
statusCode: 400,
|
||||||
|
body: 'Invalid JSON syntax',
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(syntaxError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle errors with different statusCode and body', () => {
|
||||||
|
const customError = {
|
||||||
|
statusCode: 422,
|
||||||
|
body: { error: 'Unprocessable entity' },
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(customError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(422);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle error with statusCode but no body', () => {
|
||||||
|
const partialError = {
|
||||||
|
statusCode: 400,
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(partialError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle error with body but no statusCode', () => {
|
||||||
|
const partialError = {
|
||||||
|
body: 'Some error message',
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(partialError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Unknown error handling', () => {
|
||||||
|
it('should handle unknown errors', () => {
|
||||||
|
const unknownError = new Error('Some unknown error');
|
||||||
|
|
||||||
|
errorController(unknownError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle errors with code other than 11000', () => {
|
||||||
|
const mongoError = {
|
||||||
|
code: 11100,
|
||||||
|
message: 'Some MongoDB error',
|
||||||
|
};
|
||||||
|
|
||||||
|
errorController(mongoError, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||||
|
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle null/undefined errors', () => {
|
||||||
|
errorController(null, mockReq, mockRes, mockNext);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||||
|
expect(logger.error).toHaveBeenCalledWith(
|
||||||
|
'ErrorController => processing error',
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Catch block handling', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
// Restore logger mock to normal behavior for these tests
|
||||||
|
logger.error.mockRestore();
|
||||||
|
logger.error = jest.fn();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle errors when logger.error throws', () => {
|
||||||
|
// Create fresh mocks for this test
|
||||||
|
const freshMockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
send: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Mock logger to throw on the first call, succeed on the second
|
||||||
|
logger.error
|
||||||
|
.mockImplementationOnce(() => {
|
||||||
|
throw new Error('Logger error');
|
||||||
|
})
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
const testError = new Error('Test error');
|
||||||
|
|
||||||
|
errorController(testError, mockReq, freshMockRes, mockNext);
|
||||||
|
|
||||||
|
expect(freshMockRes.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||||
|
expect(logger.error).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
195
api/server/controllers/agents/__tests__/v1.spec.js
Normal file
195
api/server/controllers/agents/__tests__/v1.spec.js
Normal file
|
|
@ -0,0 +1,195 @@
|
||||||
|
const { duplicateAgent } = require('../v1');
|
||||||
|
const { getAgent, createAgent } = require('~/models/Agent');
|
||||||
|
const { getActions } = require('~/models/Action');
|
||||||
|
const { nanoid } = require('nanoid');
|
||||||
|
|
||||||
|
jest.mock('~/models/Agent');
|
||||||
|
jest.mock('~/models/Action');
|
||||||
|
jest.mock('nanoid');
|
||||||
|
|
||||||
|
describe('duplicateAgent', () => {
|
||||||
|
let req, res;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
req = {
|
||||||
|
params: { id: 'agent_123' },
|
||||||
|
user: { id: 'user_456' },
|
||||||
|
};
|
||||||
|
res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn(),
|
||||||
|
};
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should duplicate an agent successfully', async () => {
|
||||||
|
const mockAgent = {
|
||||||
|
id: 'agent_123',
|
||||||
|
name: 'Test Agent',
|
||||||
|
description: 'Test Description',
|
||||||
|
instructions: 'Test Instructions',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['file_search'],
|
||||||
|
actions: [],
|
||||||
|
author: 'user_789',
|
||||||
|
versions: [{ name: 'Test Agent', version: 1 }],
|
||||||
|
__v: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockNewAgent = {
|
||||||
|
id: 'agent_new_123',
|
||||||
|
name: 'Test Agent (1/2/23, 12:34)',
|
||||||
|
description: 'Test Description',
|
||||||
|
instructions: 'Test Instructions',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['file_search'],
|
||||||
|
actions: [],
|
||||||
|
author: 'user_456',
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Test Agent (1/2/23, 12:34)',
|
||||||
|
description: 'Test Description',
|
||||||
|
instructions: 'Test Instructions',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['file_search'],
|
||||||
|
actions: [],
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
getAgent.mockResolvedValue(mockAgent);
|
||||||
|
getActions.mockResolvedValue([]);
|
||||||
|
nanoid.mockReturnValue('new_123');
|
||||||
|
createAgent.mockResolvedValue(mockNewAgent);
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
|
||||||
|
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
|
||||||
|
expect(createAgent).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'agent_new_123',
|
||||||
|
author: 'user_456',
|
||||||
|
name: expect.stringContaining('Test Agent ('),
|
||||||
|
description: 'Test Description',
|
||||||
|
instructions: 'Test Instructions',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['file_search'],
|
||||||
|
actions: [],
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(createAgent).toHaveBeenCalledWith(
|
||||||
|
expect.not.objectContaining({
|
||||||
|
versions: expect.anything(),
|
||||||
|
__v: expect.anything(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(201);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({
|
||||||
|
agent: mockNewAgent,
|
||||||
|
actions: [],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
|
||||||
|
const mockAgent = {
|
||||||
|
id: 'agent_123',
|
||||||
|
name: 'Test Agent',
|
||||||
|
description: 'Test Description',
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Test Agent',
|
||||||
|
versions: [{ name: 'Nested' }],
|
||||||
|
__v: 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
__v: 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockNewAgent = {
|
||||||
|
id: 'agent_new_123',
|
||||||
|
name: 'Test Agent (1/2/23, 12:34)',
|
||||||
|
description: 'Test Description',
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Test Agent (1/2/23, 12:34)',
|
||||||
|
description: 'Test Description',
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
getAgent.mockResolvedValue(mockAgent);
|
||||||
|
getActions.mockResolvedValue([]);
|
||||||
|
nanoid.mockReturnValue('new_123');
|
||||||
|
createAgent.mockResolvedValue(mockNewAgent);
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(mockNewAgent.versions).toHaveLength(1);
|
||||||
|
|
||||||
|
const firstVersion = mockNewAgent.versions[0];
|
||||||
|
expect(firstVersion).not.toHaveProperty('versions');
|
||||||
|
expect(firstVersion).not.toHaveProperty('__v');
|
||||||
|
|
||||||
|
expect(mockNewAgent).not.toHaveProperty('__v');
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(201);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 404 if agent not found', async () => {
|
||||||
|
getAgent.mockResolvedValue(null);
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(404);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({
|
||||||
|
error: 'Agent not found',
|
||||||
|
status: 'error',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle tool_resources.ocr correctly', async () => {
|
||||||
|
const mockAgent = {
|
||||||
|
id: 'agent_123',
|
||||||
|
name: 'Test Agent',
|
||||||
|
tool_resources: {
|
||||||
|
ocr: { enabled: true, config: 'test' },
|
||||||
|
other: { should: 'not be copied' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
getAgent.mockResolvedValue(mockAgent);
|
||||||
|
getActions.mockResolvedValue([]);
|
||||||
|
nanoid.mockReturnValue('new_123');
|
||||||
|
createAgent.mockResolvedValue({ id: 'agent_new_123' });
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(createAgent).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
tool_resources: {
|
||||||
|
ocr: { enabled: true, config: 'test' },
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle errors gracefully', async () => {
|
||||||
|
getAgent.mockRejectedValue(new Error('Database error'));
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -4,11 +4,13 @@ const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
createRun,
|
createRun,
|
||||||
Tokenizer,
|
Tokenizer,
|
||||||
|
checkAccess,
|
||||||
memoryInstructions,
|
memoryInstructions,
|
||||||
createMemoryProcessor,
|
createMemoryProcessor,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
Callback,
|
Callback,
|
||||||
|
Providers,
|
||||||
GraphEvents,
|
GraphEvents,
|
||||||
formatMessage,
|
formatMessage,
|
||||||
formatAgentMessages,
|
formatAgentMessages,
|
||||||
|
|
@ -31,22 +33,29 @@ const {
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||||
const {
|
const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config');
|
||||||
getCustomEndpointConfig,
|
|
||||||
createGetMCPAuthMap,
|
|
||||||
checkCapability,
|
|
||||||
} = require('~/server/services/Config');
|
|
||||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||||
const { checkAccess } = require('~/server/middleware/roles/access');
|
|
||||||
const BaseClient = require('~/app/clients/BaseClient');
|
const BaseClient = require('~/app/clients/BaseClient');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
const { loadAgent } = require('~/models/Agent');
|
const { loadAgent } = require('~/models/Agent');
|
||||||
const { getMCPManager } = require('~/config');
|
const { getMCPManager } = require('~/config');
|
||||||
|
|
||||||
|
const omitTitleOptions = new Set([
|
||||||
|
'stream',
|
||||||
|
'thinking',
|
||||||
|
'streaming',
|
||||||
|
'clientOptions',
|
||||||
|
'thinkingConfig',
|
||||||
|
'thinkingBudget',
|
||||||
|
'includeThoughts',
|
||||||
|
'maxOutputTokens',
|
||||||
|
]);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {ServerRequest} req
|
* @param {ServerRequest} req
|
||||||
* @param {Agent} agent
|
* @param {Agent} agent
|
||||||
|
|
@ -393,7 +402,12 @@ class AgentClient extends BaseClient {
|
||||||
if (user.personalization?.memories === false) {
|
if (user.personalization?.memories === false) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
|
const hasAccess = await checkAccess({
|
||||||
|
user,
|
||||||
|
permissionType: PermissionTypes.MEMORIES,
|
||||||
|
permissions: [Permissions.USE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
if (!hasAccess) {
|
if (!hasAccess) {
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -511,7 +525,10 @@ class AgentClient extends BaseClient {
|
||||||
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return await this.processMemory(messagesToProcess);
|
|
||||||
|
const bufferString = getBufferString(messagesToProcess);
|
||||||
|
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
|
||||||
|
return await this.processMemory([bufferMessage]);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Memory Agent failed to process memory', error);
|
logger.error('Memory Agent failed to process memory', error);
|
||||||
}
|
}
|
||||||
|
|
@ -677,7 +694,7 @@ class AgentClient extends BaseClient {
|
||||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||||
user: this.options.req.user,
|
user: this.options.req.user,
|
||||||
},
|
},
|
||||||
recursionLimit: agentsEConfig?.recursionLimit,
|
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
streamMode: 'values',
|
streamMode: 'values',
|
||||||
version: 'v2',
|
version: 'v2',
|
||||||
|
|
@ -983,23 +1000,26 @@ class AgentClient extends BaseClient {
|
||||||
throw new Error('Run not initialized');
|
throw new Error('Run not initialized');
|
||||||
}
|
}
|
||||||
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
||||||
const endpoint = this.options.agent.endpoint;
|
const { req, res, agent } = this.options;
|
||||||
const { req, res } = this.options;
|
const endpoint = agent.endpoint;
|
||||||
|
|
||||||
/** @type {import('@librechat/agents').ClientOptions} */
|
/** @type {import('@librechat/agents').ClientOptions} */
|
||||||
let clientOptions = {
|
let clientOptions = {
|
||||||
maxTokens: 75,
|
maxTokens: 75,
|
||||||
|
model: agent.model_parameters.model,
|
||||||
};
|
};
|
||||||
let endpointConfig = req.app.locals[endpoint];
|
|
||||||
|
const { getOptions, overrideProvider, customEndpointConfig } =
|
||||||
|
await getProviderConfig(endpoint);
|
||||||
|
|
||||||
|
/** @type {TEndpoint | undefined} */
|
||||||
|
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
|
||||||
if (!endpointConfig) {
|
if (!endpointConfig) {
|
||||||
try {
|
logger.warn(
|
||||||
endpointConfig = await getCustomEndpointConfig(endpoint);
|
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
|
||||||
} catch (err) {
|
);
|
||||||
logger.error(
|
|
||||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
|
|
||||||
err,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
endpointConfig &&
|
endpointConfig &&
|
||||||
endpointConfig.titleModel &&
|
endpointConfig.titleModel &&
|
||||||
|
|
@ -1007,30 +1027,50 @@ class AgentClient extends BaseClient {
|
||||||
) {
|
) {
|
||||||
clientOptions.model = endpointConfig.titleModel;
|
clientOptions.model = endpointConfig.titleModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const options = await getOptions({
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
optionsOnly: true,
|
||||||
|
overrideEndpoint: endpoint,
|
||||||
|
overrideModel: clientOptions.model,
|
||||||
|
endpointOption: { model_parameters: clientOptions },
|
||||||
|
});
|
||||||
|
|
||||||
|
let provider = options.provider ?? overrideProvider ?? agent.provider;
|
||||||
if (
|
if (
|
||||||
endpoint === EModelEndpoint.azureOpenAI &&
|
endpoint === EModelEndpoint.azureOpenAI &&
|
||||||
clientOptions.model &&
|
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||||
this.options.agent.model_parameters.model !== clientOptions.model
|
|
||||||
) {
|
) {
|
||||||
clientOptions =
|
provider = Providers.OPENAI;
|
||||||
(
|
|
||||||
await initOpenAI({
|
|
||||||
req,
|
|
||||||
res,
|
|
||||||
optionsOnly: true,
|
|
||||||
overrideModel: clientOptions.model,
|
|
||||||
overrideEndpoint: endpoint,
|
|
||||||
endpointOption: {
|
|
||||||
model_parameters: clientOptions,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
)?.llmConfig ?? clientOptions;
|
|
||||||
}
|
}
|
||||||
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
|
||||||
|
/** @type {import('@librechat/agents').ClientOptions} */
|
||||||
|
clientOptions = { ...options.llmConfig };
|
||||||
|
if (options.configOptions) {
|
||||||
|
clientOptions.configuration = options.configOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure maxTokens is set for non-o1 models
|
||||||
|
if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) {
|
||||||
|
clientOptions.maxTokens = 75;
|
||||||
|
} else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||||
delete clientOptions.maxTokens;
|
delete clientOptions.maxTokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientOptions = Object.assign(
|
||||||
|
Object.fromEntries(
|
||||||
|
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (provider === Providers.GOOGLE) {
|
||||||
|
clientOptions.json = true;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const titleResult = await this.run.generateTitle({
|
const titleResult = await this.run.generateTitle({
|
||||||
|
provider,
|
||||||
inputText: text,
|
inputText: text,
|
||||||
contentParts: this.contentParts,
|
contentParts: this.contentParts,
|
||||||
clientOptions,
|
clientOptions,
|
||||||
|
|
@ -1048,8 +1088,10 @@ class AgentClient extends BaseClient {
|
||||||
let input_tokens, output_tokens;
|
let input_tokens, output_tokens;
|
||||||
|
|
||||||
if (item.usage) {
|
if (item.usage) {
|
||||||
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
|
input_tokens =
|
||||||
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
|
item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens;
|
||||||
|
output_tokens =
|
||||||
|
item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens;
|
||||||
} else if (item.tokenUsage) {
|
} else if (item.tokenUsage) {
|
||||||
input_tokens = item.tokenUsage.promptTokens;
|
input_tokens = item.tokenUsage.promptTokens;
|
||||||
output_tokens = item.tokenUsage.completionTokens;
|
output_tokens = item.tokenUsage.completionTokens;
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
// errorHandler.js
|
// errorHandler.js
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
|
||||||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const { sendResponse } = require('~/server/middleware/error');
|
||||||
const { recordUsage } = require('~/server/services/Threads');
|
const { recordUsage } = require('~/server/services/Threads');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const { sendResponse } = require('~/server/utils');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef {Object} ErrorHandlerContext
|
* @typedef {Object} ErrorHandlerContext
|
||||||
|
|
@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||||
} else if (/Files.*are invalid/.test(error.message)) {
|
} else if (/Files.*are invalid/.test(error.message)) {
|
||||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||||
endpoint === 'azureAssistants'
|
endpoint === 'azureAssistants'
|
||||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
|
||||||
: ''
|
: ''
|
||||||
}`;
|
}`;
|
||||||
return sendResponse(req, res, messageData, errorMessage);
|
return sendResponse(req, res, messageData, errorMessage);
|
||||||
|
|
|
||||||
|
|
@ -1,106 +0,0 @@
|
||||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
|
||||||
const { resolveHeaders } = require('librechat-data-provider');
|
|
||||||
const { createLLM } = require('~/app/clients/llm');
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initializes and returns a Language Learning Model (LLM) instance.
|
|
||||||
*
|
|
||||||
* @param {Object} options - Configuration options for the LLM.
|
|
||||||
* @param {string} options.model - The model identifier.
|
|
||||||
* @param {string} options.modelName - The specific name of the model.
|
|
||||||
* @param {number} options.temperature - The temperature setting for the model.
|
|
||||||
* @param {number} options.presence_penalty - The presence penalty for the model.
|
|
||||||
* @param {number} options.frequency_penalty - The frequency penalty for the model.
|
|
||||||
* @param {number} options.max_tokens - The maximum number of tokens for the model output.
|
|
||||||
* @param {boolean} options.streaming - Whether to use streaming for the model output.
|
|
||||||
* @param {Object} options.context - The context for the conversation.
|
|
||||||
* @param {number} options.tokenBuffer - The token buffer size.
|
|
||||||
* @param {number} options.initialMessageCount - The initial message count.
|
|
||||||
* @param {string} options.conversationId - The ID of the conversation.
|
|
||||||
* @param {string} options.user - The user identifier.
|
|
||||||
* @param {string} options.langchainProxy - The langchain proxy URL.
|
|
||||||
* @param {boolean} options.useOpenRouter - Whether to use OpenRouter.
|
|
||||||
* @param {Object} options.options - Additional options.
|
|
||||||
* @param {Object} options.options.headers - Custom headers for the request.
|
|
||||||
* @param {string} options.options.proxy - Proxy URL.
|
|
||||||
* @param {Object} options.options.req - The request object.
|
|
||||||
* @param {Object} options.options.res - The response object.
|
|
||||||
* @param {boolean} options.options.debug - Whether to enable debug mode.
|
|
||||||
* @param {string} options.apiKey - The API key for authentication.
|
|
||||||
* @param {Object} options.azure - Azure-specific configuration.
|
|
||||||
* @param {Object} options.abortController - The AbortController instance.
|
|
||||||
* @returns {Object} The initialized LLM instance.
|
|
||||||
*/
|
|
||||||
function initializeLLM(options) {
|
|
||||||
const {
|
|
||||||
model,
|
|
||||||
modelName,
|
|
||||||
temperature,
|
|
||||||
presence_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
max_tokens,
|
|
||||||
streaming,
|
|
||||||
user,
|
|
||||||
langchainProxy,
|
|
||||||
useOpenRouter,
|
|
||||||
options: { headers, proxy },
|
|
||||||
apiKey,
|
|
||||||
azure,
|
|
||||||
} = options;
|
|
||||||
|
|
||||||
const modelOptions = {
|
|
||||||
modelName: modelName || model,
|
|
||||||
temperature,
|
|
||||||
presence_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
user,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (max_tokens) {
|
|
||||||
modelOptions.max_tokens = max_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
const configOptions = {};
|
|
||||||
|
|
||||||
if (langchainProxy) {
|
|
||||||
configOptions.basePath = langchainProxy;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (useOpenRouter) {
|
|
||||||
configOptions.basePath = 'https://openrouter.ai/api/v1';
|
|
||||||
configOptions.baseOptions = {
|
|
||||||
headers: {
|
|
||||||
'HTTP-Referer': 'https://librechat.ai',
|
|
||||||
'X-Title': 'LibreChat',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
|
|
||||||
configOptions.baseOptions = {
|
|
||||||
headers: resolveHeaders({
|
|
||||||
...headers,
|
|
||||||
...configOptions?.baseOptions?.headers,
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (proxy) {
|
|
||||||
configOptions.httpAgent = new HttpsProxyAgent(proxy);
|
|
||||||
configOptions.httpsAgent = new HttpsProxyAgent(proxy);
|
|
||||||
}
|
|
||||||
|
|
||||||
const llm = createLLM({
|
|
||||||
modelOptions,
|
|
||||||
configOptions,
|
|
||||||
openAIApiKey: apiKey,
|
|
||||||
azure,
|
|
||||||
streaming,
|
|
||||||
});
|
|
||||||
|
|
||||||
return llm;
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports = {
|
|
||||||
initializeLLM,
|
|
||||||
};
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { Constants } = require('librechat-data-provider');
|
const { Constants } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
handleAbortError,
|
handleAbortError,
|
||||||
|
|
@ -5,17 +7,18 @@ const {
|
||||||
cleanupAbortController,
|
cleanupAbortController,
|
||||||
} = require('~/server/middleware');
|
} = require('~/server/middleware');
|
||||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||||
const { sendMessage } = require('~/server/utils');
|
|
||||||
const { saveMessage } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
let {
|
let {
|
||||||
text,
|
text,
|
||||||
endpointOption,
|
endpointOption,
|
||||||
conversationId,
|
conversationId,
|
||||||
|
isContinued = false,
|
||||||
|
editedContent = null,
|
||||||
parentMessageId = null,
|
parentMessageId = null,
|
||||||
overrideParentMessageId = null,
|
overrideParentMessageId = null,
|
||||||
|
responseMessageId: editedResponseMessageId = null,
|
||||||
} = req.body;
|
} = req.body;
|
||||||
|
|
||||||
let sender;
|
let sender;
|
||||||
|
|
@ -67,7 +70,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
handler();
|
handler();
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
// Ignore cleanup errors
|
logger.error('[AgentController] Error in cleanup handler', e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -155,7 +158,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
try {
|
try {
|
||||||
res.removeListener('close', closeHandler);
|
res.removeListener('close', closeHandler);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
// Ignore
|
logger.error('[AgentController] Error removing close listener', e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -163,10 +166,14 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
user: userId,
|
user: userId,
|
||||||
onStart,
|
onStart,
|
||||||
getReqData,
|
getReqData,
|
||||||
|
isContinued,
|
||||||
|
editedContent,
|
||||||
conversationId,
|
conversationId,
|
||||||
parentMessageId,
|
parentMessageId,
|
||||||
abortController,
|
abortController,
|
||||||
overrideParentMessageId,
|
overrideParentMessageId,
|
||||||
|
isEdited: !!editedContent,
|
||||||
|
responseMessageId: editedResponseMessageId,
|
||||||
progressOptions: {
|
progressOptions: {
|
||||||
res,
|
res,
|
||||||
},
|
},
|
||||||
|
|
@ -206,7 +213,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
// Create a new response object with minimal copies
|
// Create a new response object with minimal copies
|
||||||
const finalResponse = { ...response };
|
const finalResponse = { ...response };
|
||||||
|
|
||||||
sendMessage(res, {
|
sendEvent(res, {
|
||||||
final: true,
|
final: true,
|
||||||
conversation,
|
conversation,
|
||||||
title: conversation.title,
|
title: conversation.title,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
const { z } = require('zod');
|
||||||
const fs = require('fs').promises;
|
const fs = require('fs').promises;
|
||||||
const { nanoid } = require('nanoid');
|
const { nanoid } = require('nanoid');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
Tools,
|
Tools,
|
||||||
Constants,
|
Constants,
|
||||||
|
|
@ -8,6 +10,7 @@ const {
|
||||||
SystemRoles,
|
SystemRoles,
|
||||||
EToolResources,
|
EToolResources,
|
||||||
actionDelimiter,
|
actionDelimiter,
|
||||||
|
removeNullishValues,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
getAgent,
|
getAgent,
|
||||||
|
|
@ -30,6 +33,7 @@ const { deleteFileByFilter } = require('~/models/File');
|
||||||
const systemTools = {
|
const systemTools = {
|
||||||
[Tools.execute_code]: true,
|
[Tools.execute_code]: true,
|
||||||
[Tools.file_search]: true,
|
[Tools.file_search]: true,
|
||||||
|
[Tools.web_search]: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -42,9 +46,13 @@ const systemTools = {
|
||||||
*/
|
*/
|
||||||
const createAgentHandler = async (req, res) => {
|
const createAgentHandler = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
|
const validatedData = agentCreateSchema.parse(req.body);
|
||||||
|
const { tools = [], ...agentData } = removeNullishValues(validatedData);
|
||||||
|
|
||||||
const { id: userId } = req.user;
|
const { id: userId } = req.user;
|
||||||
|
|
||||||
|
agentData.id = `agent_${nanoid()}`;
|
||||||
|
agentData.author = userId;
|
||||||
agentData.tools = [];
|
agentData.tools = [];
|
||||||
|
|
||||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||||
|
|
@ -58,19 +66,13 @@ const createAgentHandler = async (req, res) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Object.assign(agentData, {
|
|
||||||
author: userId,
|
|
||||||
name,
|
|
||||||
description,
|
|
||||||
instructions,
|
|
||||||
provider,
|
|
||||||
model,
|
|
||||||
});
|
|
||||||
|
|
||||||
agentData.id = `agent_${nanoid()}`;
|
|
||||||
const agent = await createAgent(agentData);
|
const agent = await createAgent(agentData);
|
||||||
res.status(201).json(agent);
|
res.status(201).json(agent);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof z.ZodError) {
|
||||||
|
logger.error('[/Agents] Validation error', error.errors);
|
||||||
|
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||||
|
}
|
||||||
logger.error('[/Agents] Error creating agent', error);
|
logger.error('[/Agents] Error creating agent', error);
|
||||||
res.status(500).json({ error: error.message });
|
res.status(500).json({ error: error.message });
|
||||||
}
|
}
|
||||||
|
|
@ -154,14 +156,16 @@ const getAgentHandler = async (req, res) => {
|
||||||
const updateAgentHandler = async (req, res) => {
|
const updateAgentHandler = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const id = req.params.id;
|
const id = req.params.id;
|
||||||
const { projectIds, removeProjectIds, ...updateData } = req.body;
|
const validatedData = agentUpdateSchema.parse(req.body);
|
||||||
|
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
|
||||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||||
const existingAgent = await getAgent({ id });
|
const existingAgent = await getAgent({ id });
|
||||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
|
||||||
|
|
||||||
if (!existingAgent) {
|
if (!existingAgent) {
|
||||||
return res.status(404).json({ error: 'Agent not found' });
|
return res.status(404).json({ error: 'Agent not found' });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||||
|
|
||||||
if (!hasEditPermission) {
|
if (!hasEditPermission) {
|
||||||
|
|
@ -200,6 +204,11 @@ const updateAgentHandler = async (req, res) => {
|
||||||
|
|
||||||
return res.json(updatedAgent);
|
return res.json(updatedAgent);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof z.ZodError) {
|
||||||
|
logger.error('[/Agents/:id] Validation error', error.errors);
|
||||||
|
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||||
|
}
|
||||||
|
|
||||||
logger.error('[/Agents/:id] Error updating Agent', error);
|
logger.error('[/Agents/:id] Error updating Agent', error);
|
||||||
|
|
||||||
if (error.statusCode === 409) {
|
if (error.statusCode === 409) {
|
||||||
|
|
@ -242,6 +251,8 @@ const duplicateAgentHandler = async (req, res) => {
|
||||||
createdAt: _createdAt,
|
createdAt: _createdAt,
|
||||||
updatedAt: _updatedAt,
|
updatedAt: _updatedAt,
|
||||||
tool_resources: _tool_resources = {},
|
tool_resources: _tool_resources = {},
|
||||||
|
versions: _versions,
|
||||||
|
__v: _v,
|
||||||
...cloneData
|
...cloneData
|
||||||
} = agent;
|
} = agent;
|
||||||
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
|
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
|
||||||
|
|
|
||||||
659
api/server/controllers/agents/v1.spec.js
Normal file
659
api/server/controllers/agents/v1.spec.js
Normal file
|
|
@ -0,0 +1,659 @@
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const { v4: uuidv4 } = require('uuid');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
const { agentSchema } = require('@librechat/data-schemas');
|
||||||
|
|
||||||
|
// Only mock the dependencies that are not database-related
|
||||||
|
jest.mock('~/server/services/Config', () => ({
|
||||||
|
getCachedTools: jest.fn().mockResolvedValue({
|
||||||
|
web_search: true,
|
||||||
|
execute_code: true,
|
||||||
|
file_search: true,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Project', () => ({
|
||||||
|
getProjectByName: jest.fn().mockResolvedValue(null),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/strategies', () => ({
|
||||||
|
getStrategyFunctions: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/images/avatar', () => ({
|
||||||
|
resizeAvatar: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||||
|
refreshS3Url: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/process', () => ({
|
||||||
|
filterFile: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Action', () => ({
|
||||||
|
updateAction: jest.fn(),
|
||||||
|
getActions: jest.fn().mockResolvedValue([]),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/File', () => ({
|
||||||
|
deleteFileByFilter: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||||
|
*/
|
||||||
|
let Agent;
|
||||||
|
|
||||||
|
describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||||
|
let mongoServer;
|
||||||
|
let mockReq;
|
||||||
|
let mockRes;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
const mongoUri = mongoServer.getUri();
|
||||||
|
await mongoose.connect(mongoUri);
|
||||||
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
|
}, 20000);
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await Agent.deleteMany({});
|
||||||
|
|
||||||
|
// Reset all mocks
|
||||||
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
// Setup mock request and response objects
|
||||||
|
mockReq = {
|
||||||
|
user: {
|
||||||
|
id: new mongoose.Types.ObjectId().toString(),
|
||||||
|
role: 'USER',
|
||||||
|
},
|
||||||
|
body: {},
|
||||||
|
params: {},
|
||||||
|
app: {
|
||||||
|
locals: {
|
||||||
|
fileStrategy: 'local',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnThis(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createAgentHandler', () => {
|
||||||
|
test('should create agent with allowed fields only', async () => {
|
||||||
|
const validData = {
|
||||||
|
name: 'Test Agent',
|
||||||
|
description: 'A test agent',
|
||||||
|
instructions: 'Be helpful',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['web_search'],
|
||||||
|
model_parameters: { temperature: 0.7 },
|
||||||
|
tool_resources: {
|
||||||
|
file_search: { file_ids: ['file1', 'file2'] },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = validData;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(createdAgent.name).toBe('Test Agent');
|
||||||
|
expect(createdAgent.description).toBe('A test agent');
|
||||||
|
expect(createdAgent.provider).toBe('openai');
|
||||||
|
expect(createdAgent.model).toBe('gpt-4');
|
||||||
|
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
|
||||||
|
expect(createdAgent.tools).toContain('web_search');
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||||
|
expect(agentInDb).toBeDefined();
|
||||||
|
expect(agentInDb.name).toBe('Test Agent');
|
||||||
|
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
|
||||||
|
const maliciousData = {
|
||||||
|
// Required fields
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'Malicious Agent',
|
||||||
|
|
||||||
|
// Unauthorized fields that should be stripped
|
||||||
|
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
|
||||||
|
authorName: 'Hacker', // Should be stripped
|
||||||
|
isCollaborative: true, // Should be stripped on creation
|
||||||
|
versions: [], // Should be stripped
|
||||||
|
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||||
|
id: 'custom_agent_id', // Should be overridden
|
||||||
|
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||||
|
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = maliciousData;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
|
||||||
|
// Verify unauthorized fields were not set
|
||||||
|
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
|
||||||
|
expect(createdAgent.authorName).toBeUndefined();
|
||||||
|
expect(createdAgent.isCollaborative).toBeFalsy();
|
||||||
|
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
|
||||||
|
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
|
||||||
|
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
|
||||||
|
|
||||||
|
// Verify timestamps are recent (not the malicious dates)
|
||||||
|
const createdTime = new Date(createdAgent.createdAt).getTime();
|
||||||
|
const now = Date.now();
|
||||||
|
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||||
|
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||||
|
expect(agentInDb.authorName).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should validate required fields', async () => {
|
||||||
|
const invalidData = {
|
||||||
|
name: 'Missing Required Fields',
|
||||||
|
// Missing provider and model
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = invalidData;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.json).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
error: 'Invalid request data',
|
||||||
|
details: expect.any(Array),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify nothing was created in database
|
||||||
|
const count = await Agent.countDocuments();
|
||||||
|
expect(count).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle tool_resources validation', async () => {
|
||||||
|
const dataWithInvalidToolResources = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'Agent with Tool Resources',
|
||||||
|
tool_resources: {
|
||||||
|
// Valid resources
|
||||||
|
file_search: {
|
||||||
|
file_ids: ['file1', 'file2'],
|
||||||
|
vector_store_ids: ['vs1'],
|
||||||
|
},
|
||||||
|
execute_code: {
|
||||||
|
file_ids: ['file3'],
|
||||||
|
},
|
||||||
|
// Invalid resource (should be stripped by schema)
|
||||||
|
invalid_resource: {
|
||||||
|
file_ids: ['file4'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = dataWithInvalidToolResources;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(createdAgent.tool_resources).toBeDefined();
|
||||||
|
expect(createdAgent.tool_resources.file_search).toBeDefined();
|
||||||
|
expect(createdAgent.tool_resources.execute_code).toBeDefined();
|
||||||
|
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||||
|
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle avatar validation', async () => {
|
||||||
|
const dataWithAvatar = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'Agent with Avatar',
|
||||||
|
avatar: {
|
||||||
|
filepath: 'https://example.com/avatar.png',
|
||||||
|
source: 's3',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = dataWithAvatar;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(createdAgent.avatar).toEqual({
|
||||||
|
filepath: 'https://example.com/avatar.png',
|
||||||
|
source: 's3',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle invalid avatar format', async () => {
|
||||||
|
const dataWithInvalidAvatar = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'Agent with Invalid Avatar',
|
||||||
|
avatar: 'just-a-string', // Invalid format
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = dataWithInvalidAvatar;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.json).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
error: 'Invalid request data',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('updateAgentHandler', () => {
|
||||||
|
let existingAgentId;
|
||||||
|
let existingAgentAuthorId;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
// Create an existing agent for update tests
|
||||||
|
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||||
|
const agent = await Agent.create({
|
||||||
|
id: `agent_${uuidv4()}`,
|
||||||
|
name: 'Original Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
author: existingAgentAuthorId,
|
||||||
|
description: 'Original description',
|
||||||
|
isCollaborative: false,
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Original Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
description: 'Original description',
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
existingAgentId = agent.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should update agent with allowed fields only', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Updated Agent',
|
||||||
|
description: 'Updated description',
|
||||||
|
model: 'gpt-4',
|
||||||
|
isCollaborative: true, // This IS allowed in updates
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).not.toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.name).toBe('Updated Agent');
|
||||||
|
expect(updatedAgent.description).toBe('Updated description');
|
||||||
|
expect(updatedAgent.model).toBe('gpt-4');
|
||||||
|
expect(updatedAgent.isCollaborative).toBe(true);
|
||||||
|
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||||
|
expect(agentInDb.name).toBe('Updated Agent');
|
||||||
|
expect(agentInDb.isCollaborative).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Updated Name',
|
||||||
|
|
||||||
|
// Unauthorized fields that should be stripped
|
||||||
|
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
|
||||||
|
authorName: 'Hacker', // Should be stripped
|
||||||
|
id: 'different_agent_id', // Should be stripped
|
||||||
|
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||||
|
versions: [], // Should be stripped
|
||||||
|
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||||
|
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
|
||||||
|
// Verify unauthorized fields were not changed
|
||||||
|
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
|
||||||
|
expect(updatedAgent.authorName).toBeUndefined();
|
||||||
|
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
|
||||||
|
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||||
|
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
|
||||||
|
expect(agentInDb.id).toBe(existingAgentId);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should reject update from non-author when not collaborative', async () => {
|
||||||
|
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||||
|
mockReq.user.id = differentUserId; // Different user
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Unauthorized Update',
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||||
|
expect(mockRes.json).toHaveBeenCalledWith({
|
||||||
|
error: 'You do not have permission to modify this non-collaborative agent',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify agent was not modified in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||||
|
expect(agentInDb.name).toBe('Original Agent');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should allow update from non-author when collaborative', async () => {
|
||||||
|
// First make the agent collaborative
|
||||||
|
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
|
||||||
|
|
||||||
|
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||||
|
mockReq.user.id = differentUserId; // Different user
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Collaborative Update',
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.name).toBe('Collaborative Update');
|
||||||
|
// Author field should be removed for non-author
|
||||||
|
expect(updatedAgent.author).toBeUndefined();
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||||
|
expect(agentInDb.name).toBe('Collaborative Update');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should allow admin to update any agent', async () => {
|
||||||
|
const adminUserId = new mongoose.Types.ObjectId().toString();
|
||||||
|
mockReq.user.id = adminUserId;
|
||||||
|
mockReq.user.role = 'ADMIN'; // Set as admin
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Admin Update',
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.name).toBe('Admin Update');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle projectIds updates', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
|
||||||
|
const projectId1 = new mongoose.Types.ObjectId().toString();
|
||||||
|
const projectId2 = new mongoose.Types.ObjectId().toString();
|
||||||
|
|
||||||
|
mockReq.body = {
|
||||||
|
projectIds: [projectId1, projectId2],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent).toBeDefined();
|
||||||
|
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should validate tool_resources in updates', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
tool_resources: {
|
||||||
|
ocr: {
|
||||||
|
file_ids: ['ocr1', 'ocr2'],
|
||||||
|
},
|
||||||
|
execute_code: {
|
||||||
|
file_ids: ['img1'],
|
||||||
|
},
|
||||||
|
// Invalid tool resource
|
||||||
|
invalid_tool: {
|
||||||
|
file_ids: ['invalid'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.tool_resources).toBeDefined();
|
||||||
|
expect(updatedAgent.tool_resources.ocr).toBeDefined();
|
||||||
|
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
|
||||||
|
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should return 404 for non-existent agent', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Update Non-existent',
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(404);
|
||||||
|
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle validation errors properly', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
model_parameters: 'invalid-not-an-object', // Should be an object
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockRes.json).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
error: 'Invalid request data',
|
||||||
|
details: expect.any(Array),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Mass Assignment Attack Scenarios', () => {
|
||||||
|
test('should prevent setting system fields during creation', async () => {
|
||||||
|
const systemFields = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'System Fields Test',
|
||||||
|
|
||||||
|
// System fields that should never be settable by users
|
||||||
|
__v: 99,
|
||||||
|
_id: new mongoose.Types.ObjectId(),
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Fake Version',
|
||||||
|
provider: 'fake',
|
||||||
|
model: 'fake-model',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
mockReq.body = systemFields;
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
|
||||||
|
// Verify system fields were not affected
|
||||||
|
expect(createdAgent.__v).not.toBe(99);
|
||||||
|
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
|
||||||
|
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
|
||||||
|
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||||
|
expect(agentInDb.__v).not.toBe(99);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should prevent privilege escalation through isCollaborative', async () => {
|
||||||
|
// Create a non-collaborative agent
|
||||||
|
const authorId = new mongoose.Types.ObjectId();
|
||||||
|
const agent = await Agent.create({
|
||||||
|
id: `agent_${uuidv4()}`,
|
||||||
|
name: 'Private Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
isCollaborative: false,
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Private Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
// Try to make it collaborative as a different user
|
||||||
|
const attackerId = new mongoose.Types.ObjectId().toString();
|
||||||
|
mockReq.user.id = attackerId;
|
||||||
|
mockReq.params.id = agent.id;
|
||||||
|
mockReq.body = {
|
||||||
|
isCollaborative: true, // Trying to escalate privileges
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
// Should be rejected
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||||
|
|
||||||
|
// Verify in database that it's still not collaborative
|
||||||
|
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||||
|
expect(agentInDb.isCollaborative).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should prevent author hijacking', async () => {
|
||||||
|
const originalAuthorId = new mongoose.Types.ObjectId();
|
||||||
|
const attackerId = new mongoose.Types.ObjectId();
|
||||||
|
|
||||||
|
// Admin creates an agent
|
||||||
|
mockReq.user.id = originalAuthorId.toString();
|
||||||
|
mockReq.user.role = 'ADMIN';
|
||||||
|
mockReq.body = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'Admin Agent',
|
||||||
|
author: attackerId.toString(), // Trying to set different author
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
|
||||||
|
// Author should be the actual user, not the attempted value
|
||||||
|
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
|
||||||
|
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
|
||||||
|
|
||||||
|
// Verify in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||||
|
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should strip unknown fields to prevent future vulnerabilities', async () => {
|
||||||
|
mockReq.body = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'Future Proof Test',
|
||||||
|
|
||||||
|
// Unknown fields that might be added in future
|
||||||
|
superAdminAccess: true,
|
||||||
|
bypassAllChecks: true,
|
||||||
|
internalFlag: 'secret',
|
||||||
|
futureFeature: 'exploit',
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
|
||||||
|
// Verify unknown fields were stripped
|
||||||
|
expect(createdAgent.superAdminAccess).toBeUndefined();
|
||||||
|
expect(createdAgent.bypassAllChecks).toBeUndefined();
|
||||||
|
expect(createdAgent.internalFlag).toBeUndefined();
|
||||||
|
expect(createdAgent.futureFeature).toBeUndefined();
|
||||||
|
|
||||||
|
// Also check in database
|
||||||
|
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
|
||||||
|
expect(agentInDb.superAdminAccess).toBeUndefined();
|
||||||
|
expect(agentInDb.bypassAllChecks).toBeUndefined();
|
||||||
|
expect(agentInDb.internalFlag).toBeUndefined();
|
||||||
|
expect(agentInDb.futureFeature).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
const { v4 } = require('uuid');
|
const { v4 } = require('uuid');
|
||||||
|
const { sleep } = require('@librechat/agents');
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
Time,
|
Time,
|
||||||
Constants,
|
Constants,
|
||||||
|
|
@ -19,20 +22,20 @@ const {
|
||||||
addThreadMetadata,
|
addThreadMetadata,
|
||||||
saveAssistantMessage,
|
saveAssistantMessage,
|
||||||
} = require('~/server/services/Threads');
|
} = require('~/server/services/Threads');
|
||||||
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
|
|
||||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||||
const { createRunBody } = require('~/server/services/createRunBody');
|
const { createRunBody } = require('~/server/services/createRunBody');
|
||||||
|
const { sendResponse } = require('~/server/middleware/error');
|
||||||
const { getTransactions } = require('~/models/Transaction');
|
const { getTransactions } = require('~/models/Transaction');
|
||||||
const { checkBalance } = require('~/models/balanceMethods');
|
const { checkBalance } = require('~/models/balanceMethods');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
const { countTokens } = require('~/server/utils');
|
||||||
const { getModelMaxTokens } = require('~/utils');
|
const { getModelMaxTokens } = require('~/utils');
|
||||||
const { getOpenAIClient } = require('./helpers');
|
const { getOpenAIClient } = require('./helpers');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @route POST /
|
* @route POST /
|
||||||
|
|
@ -471,7 +474,7 @@ const chatV1 = async (req, res) => {
|
||||||
await Promise.all(promises);
|
await Promise.all(promises);
|
||||||
|
|
||||||
const sendInitialResponse = () => {
|
const sendInitialResponse = () => {
|
||||||
sendMessage(res, {
|
sendEvent(res, {
|
||||||
sync: true,
|
sync: true,
|
||||||
conversationId,
|
conversationId,
|
||||||
// messages: previousMessages,
|
// messages: previousMessages,
|
||||||
|
|
@ -587,7 +590,7 @@ const chatV1 = async (req, res) => {
|
||||||
iconURL: endpointOption.iconURL,
|
iconURL: endpointOption.iconURL,
|
||||||
};
|
};
|
||||||
|
|
||||||
sendMessage(res, {
|
sendEvent(res, {
|
||||||
final: true,
|
final: true,
|
||||||
conversation,
|
conversation,
|
||||||
requestMessage: {
|
requestMessage: {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
const { v4 } = require('uuid');
|
const { v4 } = require('uuid');
|
||||||
|
const { sleep } = require('@librechat/agents');
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
Time,
|
Time,
|
||||||
Constants,
|
Constants,
|
||||||
|
|
@ -22,15 +25,14 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
|
||||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||||
const { sendMessage, sleep, countTokens } = require('~/server/utils');
|
|
||||||
const { createRunBody } = require('~/server/services/createRunBody');
|
const { createRunBody } = require('~/server/services/createRunBody');
|
||||||
const { getTransactions } = require('~/models/Transaction');
|
const { getTransactions } = require('~/models/Transaction');
|
||||||
const { checkBalance } = require('~/models/balanceMethods');
|
const { checkBalance } = require('~/models/balanceMethods');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
const { countTokens } = require('~/server/utils');
|
||||||
const { getModelMaxTokens } = require('~/utils');
|
const { getModelMaxTokens } = require('~/utils');
|
||||||
const { getOpenAIClient } = require('./helpers');
|
const { getOpenAIClient } = require('./helpers');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @route POST /
|
* @route POST /
|
||||||
|
|
@ -309,7 +311,7 @@ const chatV2 = async (req, res) => {
|
||||||
await Promise.all(promises);
|
await Promise.all(promises);
|
||||||
|
|
||||||
const sendInitialResponse = () => {
|
const sendInitialResponse = () => {
|
||||||
sendMessage(res, {
|
sendEvent(res, {
|
||||||
sync: true,
|
sync: true,
|
||||||
conversationId,
|
conversationId,
|
||||||
// messages: previousMessages,
|
// messages: previousMessages,
|
||||||
|
|
@ -432,7 +434,7 @@ const chatV2 = async (req, res) => {
|
||||||
iconURL: endpointOption.iconURL,
|
iconURL: endpointOption.iconURL,
|
||||||
};
|
};
|
||||||
|
|
||||||
sendMessage(res, {
|
sendEvent(res, {
|
||||||
final: true,
|
final: true,
|
||||||
conversation,
|
conversation,
|
||||||
requestMessage: {
|
requestMessage: {
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
// errorHandler.js
|
// errorHandler.js
|
||||||
const { sendResponse } = require('~/server/utils');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { logger } = require('~/config');
|
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
|
||||||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
|
||||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||||
|
const { sendResponse } = require('~/server/middleware/error');
|
||||||
|
const { getConvo } = require('~/models/Conversation');
|
||||||
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef {Object} ErrorHandlerContext
|
* @typedef {Object} ErrorHandlerContext
|
||||||
|
|
@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||||
} else if (/Files.*are invalid/.test(error.message)) {
|
} else if (/Files.*are invalid/.test(error.message)) {
|
||||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||||
endpoint === 'azureAssistants'
|
endpoint === 'azureAssistants'
|
||||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
|
||||||
: ''
|
: ''
|
||||||
}`;
|
}`;
|
||||||
return sendResponse(req, res, messageData, errorMessage);
|
return sendResponse(req, res, messageData, errorMessage);
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
const { nanoid } = require('nanoid');
|
const { nanoid } = require('nanoid');
|
||||||
const { EnvVar } = require('@librechat/agents');
|
const { EnvVar } = require('@librechat/agents');
|
||||||
|
const { checkAccess } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
Tools,
|
Tools,
|
||||||
AuthType,
|
AuthType,
|
||||||
|
|
@ -13,9 +15,8 @@ const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||||
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
||||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||||
const { loadTools } = require('~/app/clients/tools/util');
|
const { loadTools } = require('~/app/clients/tools/util');
|
||||||
const { checkAccess } = require('~/server/middleware');
|
const { getRoleByName } = require('~/models/Role');
|
||||||
const { getMessage } = require('~/models/Message');
|
const { getMessage } = require('~/models/Message');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const fieldsMap = {
|
const fieldsMap = {
|
||||||
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
|
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
|
||||||
|
|
@ -79,6 +80,7 @@ const verifyToolAuth = async (req, res) => {
|
||||||
throwError: false,
|
throwError: false,
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
logger.error('Error loading auth values', error);
|
||||||
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
|
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +134,12 @@ const callTool = async (req, res) => {
|
||||||
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
|
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
|
||||||
let hasAccess = true;
|
let hasAccess = true;
|
||||||
if (toolAccessPermType[toolId]) {
|
if (toolAccessPermType[toolId]) {
|
||||||
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
|
hasAccess = await checkAccess({
|
||||||
|
user: req.user,
|
||||||
|
permissionType: toolAccessPermType[toolId],
|
||||||
|
permissions: [Permissions.USE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
if (!hasAccess) {
|
if (!hasAccess) {
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,6 @@ const startServer = async () => {
|
||||||
|
|
||||||
/* Middleware */
|
/* Middleware */
|
||||||
app.use(noIndex);
|
app.use(noIndex);
|
||||||
app.use(errorController);
|
|
||||||
app.use(express.json({ limit: '3mb' }));
|
app.use(express.json({ limit: '3mb' }));
|
||||||
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
|
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
|
||||||
app.use(mongoSanitize());
|
app.use(mongoSanitize());
|
||||||
|
|
@ -97,7 +96,6 @@ const startServer = async () => {
|
||||||
app.use('/api/actions', routes.actions);
|
app.use('/api/actions', routes.actions);
|
||||||
app.use('/api/keys', routes.keys);
|
app.use('/api/keys', routes.keys);
|
||||||
app.use('/api/user', routes.user);
|
app.use('/api/user', routes.user);
|
||||||
app.use('/api/ask', routes.ask);
|
|
||||||
app.use('/api/search', routes.search);
|
app.use('/api/search', routes.search);
|
||||||
app.use('/api/edit', routes.edit);
|
app.use('/api/edit', routes.edit);
|
||||||
app.use('/api/messages', routes.messages);
|
app.use('/api/messages', routes.messages);
|
||||||
|
|
@ -118,11 +116,13 @@ const startServer = async () => {
|
||||||
app.use('/api/roles', routes.roles);
|
app.use('/api/roles', routes.roles);
|
||||||
app.use('/api/agents', routes.agents);
|
app.use('/api/agents', routes.agents);
|
||||||
app.use('/api/banner', routes.banner);
|
app.use('/api/banner', routes.banner);
|
||||||
app.use('/api/bedrock', routes.bedrock);
|
|
||||||
app.use('/api/memories', routes.memories);
|
app.use('/api/memories', routes.memories);
|
||||||
app.use('/api/tags', routes.tags);
|
app.use('/api/tags', routes.tags);
|
||||||
app.use('/api/mcp', routes.mcp);
|
app.use('/api/mcp', routes.mcp);
|
||||||
|
|
||||||
|
// Add the error controller one more time after all routes
|
||||||
|
app.use(errorController);
|
||||||
|
|
||||||
app.use((req, res) => {
|
app.use((req, res) => {
|
||||||
res.set({
|
res.set({
|
||||||
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
const fs = require('fs');
|
const fs = require('fs');
|
||||||
const path = require('path');
|
|
||||||
const request = require('supertest');
|
const request = require('supertest');
|
||||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
const mongoose = require('mongoose');
|
const mongoose = require('mongoose');
|
||||||
|
|
@ -59,6 +58,30 @@ describe('Server Configuration', () => {
|
||||||
expect(response.headers['pragma']).toBe('no-cache');
|
expect(response.headers['pragma']).toBe('no-cache');
|
||||||
expect(response.headers['expires']).toBe('0');
|
expect(response.headers['expires']).toBe('0');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return 500 for unknown errors via ErrorController', async () => {
|
||||||
|
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
|
||||||
|
|
||||||
|
// Mock MongoDB operations to fail
|
||||||
|
const originalFindOne = mongoose.models.User.findOne;
|
||||||
|
const mockError = new Error('MongoDB operation failed');
|
||||||
|
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
|
||||||
|
throw mockError;
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await request(app).post('/api/auth/login').send({
|
||||||
|
email: 'test@example.com',
|
||||||
|
password: 'password123',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(500);
|
||||||
|
expect(response.text).toBe('An unknown error occurred.');
|
||||||
|
} finally {
|
||||||
|
// Restore original function
|
||||||
|
mongoose.models.User.findOne = originalFindOne;
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely
|
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
// abortMiddleware.js
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
|
||||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
|
||||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||||
|
const { sendError } = require('~/server/middleware/error');
|
||||||
const { spendTokens } = require('~/models/spendTokens');
|
const { spendTokens } = require('~/models/spendTokens');
|
||||||
const abortControllers = require('./abortControllers');
|
const abortControllers = require('./abortControllers');
|
||||||
const { saveMessage, getConvo } = require('~/models');
|
const { saveMessage, getConvo } = require('~/models');
|
||||||
const { abortRun } = require('./abortRun');
|
const { abortRun } = require('./abortRun');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const abortDataMap = new WeakMap();
|
const abortDataMap = new WeakMap();
|
||||||
|
|
||||||
|
|
@ -101,7 +101,7 @@ async function abortMessage(req, res) {
|
||||||
cleanupAbortController(abortKey);
|
cleanupAbortController(abortKey);
|
||||||
|
|
||||||
if (res.headersSent && finalEvent) {
|
if (res.headersSent && finalEvent) {
|
||||||
return sendMessage(res, finalEvent);
|
return sendEvent(res, finalEvent);
|
||||||
}
|
}
|
||||||
|
|
||||||
res.setHeader('Content-Type', 'application/json');
|
res.setHeader('Content-Type', 'application/json');
|
||||||
|
|
@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||||
* @param {string} responseMessageId
|
* @param {string} responseMessageId
|
||||||
*/
|
*/
|
||||||
const onStart = (userMessage, responseMessageId) => {
|
const onStart = (userMessage, responseMessageId) => {
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
sendEvent(res, { message: userMessage, created: true });
|
||||||
|
|
||||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||||
getReqData({ abortKey });
|
getReqData({ abortKey });
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||||
const { deleteMessages } = require('~/models/Message');
|
const { deleteMessages } = require('~/models/Message');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
const { sendMessage } = require('~/server/utils');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const three_minutes = 1000 * 60 * 3;
|
const three_minutes = 1000 * 60 * 3;
|
||||||
|
|
||||||
|
|
@ -34,7 +34,7 @@ async function abortRun(req, res) {
|
||||||
const [thread_id, run_id] = runValues.split(':');
|
const [thread_id, run_id] = runValues.split(':');
|
||||||
|
|
||||||
if (!run_id) {
|
if (!run_id) {
|
||||||
logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id });
|
logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id });
|
||||||
return res.status(204).send({ message: 'Run not found' });
|
return res.status(204).send({ message: 'Run not found' });
|
||||||
} else if (run_id === 'cancelled') {
|
} else if (run_id === 'cancelled') {
|
||||||
logger.warn('[abortRun] Run already cancelled', { thread_id });
|
logger.warn('[abortRun] Run already cancelled', { thread_id });
|
||||||
|
|
@ -93,7 +93,7 @@ async function abortRun(req, res) {
|
||||||
};
|
};
|
||||||
|
|
||||||
if (res.headersSent && finalEvent) {
|
if (res.headersSent && finalEvent) {
|
||||||
return sendMessage(res, finalEvent);
|
return sendEvent(res, finalEvent);
|
||||||
}
|
}
|
||||||
|
|
||||||
res.json(finalEvent);
|
res.json(finalEvent);
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
parseCompactConvo,
|
EndpointURLs,
|
||||||
EModelEndpoint,
|
EModelEndpoint,
|
||||||
isAgentsEndpoint,
|
isAgentsEndpoint,
|
||||||
EndpointURLs,
|
parseCompactConvo,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
|
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
|
||||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
|
||||||
const assistants = require('~/server/services/Endpoints/assistants');
|
const assistants = require('~/server/services/Endpoints/assistants');
|
||||||
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
|
|
||||||
const { processFiles } = require('~/server/services/Files/process');
|
const { processFiles } = require('~/server/services/Files/process');
|
||||||
const anthropic = require('~/server/services/Endpoints/anthropic');
|
const anthropic = require('~/server/services/Endpoints/anthropic');
|
||||||
const bedrock = require('~/server/services/Endpoints/bedrock');
|
const bedrock = require('~/server/services/Endpoints/bedrock');
|
||||||
|
|
@ -25,7 +24,6 @@ const buildFunction = {
|
||||||
[EModelEndpoint.bedrock]: bedrock.buildOptions,
|
[EModelEndpoint.bedrock]: bedrock.buildOptions,
|
||||||
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
|
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
|
||||||
[EModelEndpoint.anthropic]: anthropic.buildOptions,
|
[EModelEndpoint.anthropic]: anthropic.buildOptions,
|
||||||
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
|
|
||||||
[EModelEndpoint.assistants]: assistants.buildOptions,
|
[EModelEndpoint.assistants]: assistants.buildOptions,
|
||||||
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
|
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
|
||||||
};
|
};
|
||||||
|
|
@ -36,6 +34,9 @@ async function buildEndpointOption(req, res, next) {
|
||||||
try {
|
try {
|
||||||
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
logger.warn(
|
||||||
|
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
|
||||||
|
);
|
||||||
return handleError(res, { text: 'Error parsing conversation' });
|
return handleError(res, { text: 'Error parsing conversation' });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -57,15 +58,6 @@ async function buildEndpointOption(req, res, next) {
|
||||||
return handleError(res, { text: 'Model spec mismatch' });
|
return handleError(res, { text: 'Model spec mismatch' });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins &&
|
|
||||||
currentModelSpec.preset.tools
|
|
||||||
) {
|
|
||||||
return handleError(res, {
|
|
||||||
text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
currentModelSpec.preset.spec = spec;
|
currentModelSpec.preset.spec = spec;
|
||||||
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
|
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
|
||||||
|
|
@ -77,6 +69,7 @@ async function buildEndpointOption(req, res, next) {
|
||||||
conversation: currentModelSpec.preset,
|
conversation: currentModelSpec.preset,
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
logger.error(`Error parsing model spec for endpoint ${endpoint}`, error);
|
||||||
return handleError(res, { text: 'Error parsing model spec' });
|
return handleError(res, { text: 'Error parsing model spec' });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -84,20 +77,23 @@ async function buildEndpointOption(req, res, next) {
|
||||||
try {
|
try {
|
||||||
const isAgents =
|
const isAgents =
|
||||||
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
|
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
|
||||||
const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
|
const builder = isAgents
|
||||||
const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
|
? (...args) => buildFunction[EModelEndpoint.agents](req, ...args)
|
||||||
|
: buildFunction[endpointType ?? endpoint];
|
||||||
|
|
||||||
// TODO: use object params
|
// TODO: use object params
|
||||||
req.body.endpointOption = await builder(endpoint, parsedBody, endpointType);
|
req.body.endpointOption = await builder(endpoint, parsedBody, endpointType);
|
||||||
|
|
||||||
// TODO: use `getModelsConfig` only when necessary
|
|
||||||
const modelsConfig = await getModelsConfig(req);
|
|
||||||
req.body.endpointOption.modelsConfig = modelsConfig;
|
|
||||||
if (req.body.files && !isAgents) {
|
if (req.body.files && !isAgents) {
|
||||||
req.body.endpointOption.attachments = processFiles(req.body.files);
|
req.body.endpointOption.attachments = processFiles(req.body.files);
|
||||||
}
|
}
|
||||||
|
|
||||||
next();
|
next();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
`Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`,
|
||||||
|
error,
|
||||||
|
);
|
||||||
return handleError(res, { text: 'Error building endpoint option' });
|
return handleError(res, { text: 'Error building endpoint option' });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ const message = 'Your account has been temporarily banned due to violations of o
|
||||||
* @function
|
* @function
|
||||||
* @param {Object} req - Express Request object.
|
* @param {Object} req - Express Request object.
|
||||||
* @param {Object} res - Express Response object.
|
* @param {Object} res - Express Response object.
|
||||||
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
|
|
||||||
*
|
*
|
||||||
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
|
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
|
||||||
*/
|
*/
|
||||||
|
|
@ -135,6 +134,7 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||||
return await banResponse(req, res);
|
return await banResponse(req, res);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error in checkBan middleware:', error);
|
logger.error('Error in checkBan middleware:', error);
|
||||||
|
return next(error);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||||
const { sendMessage, sendError } = require('~/server/utils');
|
const { sendError } = require('~/server/middleware/error');
|
||||||
const { saveMessage } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -36,7 +37,7 @@ const denyRequest = async (req, res, errorMessage) => {
|
||||||
isCreatedByUser: true,
|
isCreatedByUser: true,
|
||||||
text,
|
text,
|
||||||
};
|
};
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
sendEvent(res, { message: userMessage, created: true });
|
||||||
|
|
||||||
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
|
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,9 @@
|
||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { parseConvo } = require('librechat-data-provider');
|
const { parseConvo } = require('librechat-data-provider');
|
||||||
|
const { sendEvent, handleError } = require('@librechat/api');
|
||||||
const { saveMessage, getMessages } = require('~/models/Message');
|
const { saveMessage, getMessages } = require('~/models/Message');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sends error data in Server Sent Events format and ends the response.
|
|
||||||
* @param {object} res - The server response.
|
|
||||||
* @param {string} message - The error message.
|
|
||||||
*/
|
|
||||||
const handleError = (res, message) => {
|
|
||||||
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
|
|
||||||
res.end();
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sends message data in Server Sent Events format.
|
|
||||||
* @param {Express.Response} res - - The server response.
|
|
||||||
* @param {string | Object} message - The message to be sent.
|
|
||||||
* @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'.
|
|
||||||
*/
|
|
||||||
const sendMessage = (res, message, event = 'message') => {
|
|
||||||
if (typeof message === 'string' && message.length === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
|
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
|
||||||
|
|
@ -91,7 +69,7 @@ const sendError = async (req, res, options, callback) => {
|
||||||
convo = parseConvo(errorMessage);
|
convo = parseConvo(errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
return sendMessage(res, {
|
return sendEvent(res, {
|
||||||
final: true,
|
final: true,
|
||||||
requestMessage: query?.[0] ? query[0] : requestMessage,
|
requestMessage: query?.[0] ? query[0] : requestMessage,
|
||||||
responseMessage: errorMessage,
|
responseMessage: errorMessage,
|
||||||
|
|
@ -120,12 +98,10 @@ const sendResponse = (req, res, data, errorMessage) => {
|
||||||
if (errorMessage) {
|
if (errorMessage) {
|
||||||
return sendError(req, res, { ...data, text: errorMessage });
|
return sendError(req, res, { ...data, text: errorMessage });
|
||||||
}
|
}
|
||||||
return sendMessage(res, data);
|
return sendEvent(res, data);
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
sendResponse,
|
|
||||||
handleError,
|
|
||||||
sendMessage,
|
|
||||||
sendError,
|
sendError,
|
||||||
|
sendResponse,
|
||||||
};
|
};
|
||||||
95
api/server/middleware/limiters/forkLimiters.js
Normal file
95
api/server/middleware/limiters/forkLimiters.js
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
const rateLimit = require('express-rate-limit');
|
||||||
|
const { isEnabled } = require('@librechat/api');
|
||||||
|
const { RedisStore } = require('rate-limit-redis');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const ioredisClient = require('~/cache/ioredisClient');
|
||||||
|
const logViolation = require('~/cache/logViolation');
|
||||||
|
|
||||||
|
const getEnvironmentVariables = () => {
|
||||||
|
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
|
||||||
|
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
|
||||||
|
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
|
||||||
|
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
|
||||||
|
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
|
||||||
|
|
||||||
|
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
|
||||||
|
const forkIpMax = FORK_IP_MAX;
|
||||||
|
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
|
||||||
|
|
||||||
|
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
|
||||||
|
const forkUserMax = FORK_USER_MAX;
|
||||||
|
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
|
||||||
|
|
||||||
|
return {
|
||||||
|
forkIpWindowMs,
|
||||||
|
forkIpMax,
|
||||||
|
forkIpWindowInMinutes,
|
||||||
|
forkUserWindowMs,
|
||||||
|
forkUserMax,
|
||||||
|
forkUserWindowInMinutes,
|
||||||
|
forkViolationScore: FORK_VIOLATION_SCORE,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createForkHandler = (ip = true) => {
|
||||||
|
const {
|
||||||
|
forkIpMax,
|
||||||
|
forkUserMax,
|
||||||
|
forkViolationScore,
|
||||||
|
forkIpWindowInMinutes,
|
||||||
|
forkUserWindowInMinutes,
|
||||||
|
} = getEnvironmentVariables();
|
||||||
|
|
||||||
|
return async (req, res) => {
|
||||||
|
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||||
|
const errorMessage = {
|
||||||
|
type,
|
||||||
|
max: ip ? forkIpMax : forkUserMax,
|
||||||
|
limiter: ip ? 'ip' : 'user',
|
||||||
|
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
|
||||||
|
};
|
||||||
|
|
||||||
|
await logViolation(req, res, type, errorMessage, forkViolationScore);
|
||||||
|
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createForkLimiters = () => {
|
||||||
|
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
|
||||||
|
|
||||||
|
const ipLimiterOptions = {
|
||||||
|
windowMs: forkIpWindowMs,
|
||||||
|
max: forkIpMax,
|
||||||
|
handler: createForkHandler(),
|
||||||
|
};
|
||||||
|
const userLimiterOptions = {
|
||||||
|
windowMs: forkUserWindowMs,
|
||||||
|
max: forkUserMax,
|
||||||
|
handler: createForkHandler(false),
|
||||||
|
keyGenerator: function (req) {
|
||||||
|
return req.user?.id;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||||
|
logger.debug('Using Redis for fork rate limiters.');
|
||||||
|
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||||
|
const ipStore = new RedisStore({
|
||||||
|
sendCommand,
|
||||||
|
prefix: 'fork_ip_limiter:',
|
||||||
|
});
|
||||||
|
const userStore = new RedisStore({
|
||||||
|
sendCommand,
|
||||||
|
prefix: 'fork_user_limiter:',
|
||||||
|
});
|
||||||
|
ipLimiterOptions.store = ipStore;
|
||||||
|
userLimiterOptions.store = userStore;
|
||||||
|
}
|
||||||
|
|
||||||
|
const forkIpLimiter = rateLimit(ipLimiterOptions);
|
||||||
|
const forkUserLimiter = rateLimit(userLimiterOptions);
|
||||||
|
return { forkIpLimiter, forkUserLimiter };
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = { createForkLimiters };
|
||||||
|
|
@ -1,16 +1,17 @@
|
||||||
const rateLimit = require('express-rate-limit');
|
const rateLimit = require('express-rate-limit');
|
||||||
|
const { isEnabled } = require('@librechat/api');
|
||||||
const { RedisStore } = require('rate-limit-redis');
|
const { RedisStore } = require('rate-limit-redis');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { ViolationTypes } = require('librechat-data-provider');
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
const ioredisClient = require('~/cache/ioredisClient');
|
const ioredisClient = require('~/cache/ioredisClient');
|
||||||
const logViolation = require('~/cache/logViolation');
|
const logViolation = require('~/cache/logViolation');
|
||||||
const { isEnabled } = require('~/server/utils');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const getEnvironmentVariables = () => {
|
const getEnvironmentVariables = () => {
|
||||||
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
|
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
|
||||||
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
|
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
|
||||||
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
|
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
|
||||||
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
|
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
|
||||||
|
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
|
||||||
|
|
||||||
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
|
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
|
||||||
const importIpMax = IMPORT_IP_MAX;
|
const importIpMax = IMPORT_IP_MAX;
|
||||||
|
|
@ -27,12 +28,18 @@ const getEnvironmentVariables = () => {
|
||||||
importUserWindowMs,
|
importUserWindowMs,
|
||||||
importUserMax,
|
importUserMax,
|
||||||
importUserWindowInMinutes,
|
importUserWindowInMinutes,
|
||||||
|
importViolationScore: IMPORT_VIOLATION_SCORE,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
const createImportHandler = (ip = true) => {
|
const createImportHandler = (ip = true) => {
|
||||||
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
|
const {
|
||||||
getEnvironmentVariables();
|
importIpMax,
|
||||||
|
importUserMax,
|
||||||
|
importViolationScore,
|
||||||
|
importIpWindowInMinutes,
|
||||||
|
importUserWindowInMinutes,
|
||||||
|
} = getEnvironmentVariables();
|
||||||
|
|
||||||
return async (req, res) => {
|
return async (req, res) => {
|
||||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||||
|
|
@ -43,7 +50,7 @@ const createImportHandler = (ip = true) => {
|
||||||
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
|
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage);
|
await logViolation(req, res, type, errorMessage, importViolationScore);
|
||||||
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
|
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ const createSTTLimiters = require('./sttLimiters');
|
||||||
const loginLimiter = require('./loginLimiter');
|
const loginLimiter = require('./loginLimiter');
|
||||||
const importLimiters = require('./importLimiters');
|
const importLimiters = require('./importLimiters');
|
||||||
const uploadLimiters = require('./uploadLimiters');
|
const uploadLimiters = require('./uploadLimiters');
|
||||||
|
const forkLimiters = require('./forkLimiters');
|
||||||
const registerLimiter = require('./registerLimiter');
|
const registerLimiter = require('./registerLimiter');
|
||||||
const toolCallLimiter = require('./toolCallLimiter');
|
const toolCallLimiter = require('./toolCallLimiter');
|
||||||
const messageLimiters = require('./messageLimiters');
|
const messageLimiters = require('./messageLimiters');
|
||||||
|
|
@ -14,6 +15,7 @@ module.exports = {
|
||||||
...uploadLimiters,
|
...uploadLimiters,
|
||||||
...importLimiters,
|
...importLimiters,
|
||||||
...messageLimiters,
|
...messageLimiters,
|
||||||
|
...forkLimiters,
|
||||||
loginLimiter,
|
loginLimiter,
|
||||||
registerLimiter,
|
registerLimiter,
|
||||||
toolCallLimiter,
|
toolCallLimiter,
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ const {
|
||||||
MESSAGE_IP_WINDOW = 1,
|
MESSAGE_IP_WINDOW = 1,
|
||||||
MESSAGE_USER_MAX = 40,
|
MESSAGE_USER_MAX = 40,
|
||||||
MESSAGE_USER_WINDOW = 1,
|
MESSAGE_USER_WINDOW = 1,
|
||||||
|
MESSAGE_VIOLATION_SCORE: score,
|
||||||
} = process.env;
|
} = process.env;
|
||||||
|
|
||||||
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
|
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
|
||||||
|
|
@ -39,7 +40,7 @@ const createHandler = (ip = true) => {
|
||||||
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
|
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage);
|
await logViolation(req, res, type, errorMessage, score);
|
||||||
return await denyRequest(req, res, errorMessage);
|
return await denyRequest(req, res, errorMessage);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
|
||||||
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
||||||
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
||||||
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
||||||
|
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
|
||||||
|
|
||||||
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
||||||
const sttIpMax = STT_IP_MAX;
|
const sttIpMax = STT_IP_MAX;
|
||||||
|
|
@ -27,11 +28,12 @@ const getEnvironmentVariables = () => {
|
||||||
sttUserWindowMs,
|
sttUserWindowMs,
|
||||||
sttUserMax,
|
sttUserMax,
|
||||||
sttUserWindowInMinutes,
|
sttUserWindowInMinutes,
|
||||||
|
sttViolationScore: STT_VIOLATION_SCORE,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
const createSTTHandler = (ip = true) => {
|
const createSTTHandler = (ip = true) => {
|
||||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
|
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
|
||||||
getEnvironmentVariables();
|
getEnvironmentVariables();
|
||||||
|
|
||||||
return async (req, res) => {
|
return async (req, res) => {
|
||||||
|
|
@ -43,7 +45,7 @@ const createSTTHandler = (ip = true) => {
|
||||||
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage);
|
await logViolation(req, res, type, errorMessage, sttViolationScore);
|
||||||
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ const logViolation = require('~/cache/logViolation');
|
||||||
const { isEnabled } = require('~/server/utils');
|
const { isEnabled } = require('~/server/utils');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
|
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
|
||||||
|
|
||||||
const handler = async (req, res) => {
|
const handler = async (req, res) => {
|
||||||
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
||||||
const errorMessage = {
|
const errorMessage = {
|
||||||
|
|
@ -15,7 +17,7 @@ const handler = async (req, res) => {
|
||||||
windowInMinutes: 1,
|
windowInMinutes: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage, 0);
|
await logViolation(req, res, type, errorMessage, score);
|
||||||
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
|
||||||
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
||||||
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
||||||
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
||||||
|
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
|
||||||
|
|
||||||
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
||||||
const ttsIpMax = TTS_IP_MAX;
|
const ttsIpMax = TTS_IP_MAX;
|
||||||
|
|
@ -27,11 +28,12 @@ const getEnvironmentVariables = () => {
|
||||||
ttsUserWindowMs,
|
ttsUserWindowMs,
|
||||||
ttsUserMax,
|
ttsUserMax,
|
||||||
ttsUserWindowInMinutes,
|
ttsUserWindowInMinutes,
|
||||||
|
ttsViolationScore: TTS_VIOLATION_SCORE,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
const createTTSHandler = (ip = true) => {
|
const createTTSHandler = (ip = true) => {
|
||||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
|
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
|
||||||
getEnvironmentVariables();
|
getEnvironmentVariables();
|
||||||
|
|
||||||
return async (req, res) => {
|
return async (req, res) => {
|
||||||
|
|
@ -43,7 +45,7 @@ const createTTSHandler = (ip = true) => {
|
||||||
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage);
|
await logViolation(req, res, type, errorMessage, ttsViolationScore);
|
||||||
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
|
||||||
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
|
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
|
||||||
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
|
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
|
||||||
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
|
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
|
||||||
|
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
|
||||||
|
|
||||||
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
|
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
|
||||||
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
|
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
|
||||||
|
|
@ -27,6 +28,7 @@ const getEnvironmentVariables = () => {
|
||||||
fileUploadUserWindowMs,
|
fileUploadUserWindowMs,
|
||||||
fileUploadUserMax,
|
fileUploadUserMax,
|
||||||
fileUploadUserWindowInMinutes,
|
fileUploadUserWindowInMinutes,
|
||||||
|
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -36,6 +38,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||||
fileUploadIpWindowInMinutes,
|
fileUploadIpWindowInMinutes,
|
||||||
fileUploadUserMax,
|
fileUploadUserMax,
|
||||||
fileUploadUserWindowInMinutes,
|
fileUploadUserWindowInMinutes,
|
||||||
|
fileUploadViolationScore,
|
||||||
} = getEnvironmentVariables();
|
} = getEnvironmentVariables();
|
||||||
|
|
||||||
return async (req, res) => {
|
return async (req, res) => {
|
||||||
|
|
@ -47,7 +50,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||||
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
|
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage);
|
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
|
||||||
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
|
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,78 +0,0 @@
|
||||||
const { getRoleByName } = require('~/models/Role');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Core function to check if a user has one or more required permissions
|
|
||||||
*
|
|
||||||
* @param {object} user - The user object
|
|
||||||
* @param {PermissionTypes} permissionType - The type of permission to check
|
|
||||||
* @param {Permissions[]} permissions - The list of specific permissions to check
|
|
||||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
|
|
||||||
* @param {object} [checkObject] - The object to check properties against
|
|
||||||
* @returns {Promise<boolean>} Whether the user has the required permissions
|
|
||||||
*/
|
|
||||||
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
|
|
||||||
if (!user) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const role = await getRoleByName(user.role);
|
|
||||||
if (role && role.permissions && role.permissions[permissionType]) {
|
|
||||||
const hasAnyPermission = permissions.some((permission) => {
|
|
||||||
if (role.permissions[permissionType][permission]) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bodyProps[permission] && checkObject) {
|
|
||||||
return bodyProps[permission].some((prop) =>
|
|
||||||
Object.prototype.hasOwnProperty.call(checkObject, prop),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
|
|
||||||
return hasAnyPermission;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
|
||||||
*
|
|
||||||
* @param {PermissionTypes} permissionType - The type of permission to check.
|
|
||||||
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
|
||||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
|
||||||
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
|
|
||||||
*/
|
|
||||||
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
|
||||||
return async (req, res, next) => {
|
|
||||||
try {
|
|
||||||
const hasAccess = await checkAccess(
|
|
||||||
req.user,
|
|
||||||
permissionType,
|
|
||||||
permissions,
|
|
||||||
bodyProps,
|
|
||||||
req.body,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (hasAccess) {
|
|
||||||
return next();
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.warn(
|
|
||||||
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
|
|
||||||
);
|
|
||||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
|
||||||
} catch (error) {
|
|
||||||
logger.error(error);
|
|
||||||
return res.status(500).json({ message: `Server error: ${error.message}` });
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = {
|
|
||||||
checkAccess,
|
|
||||||
generateCheckAccess,
|
|
||||||
};
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
const checkAdmin = require('./admin');
|
const checkAdmin = require('./admin');
|
||||||
const { checkAccess, generateCheckAccess } = require('./access');
|
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
checkAdmin,
|
checkAdmin,
|
||||||
checkAccess,
|
|
||||||
generateCheckAccess,
|
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,28 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { nanoid } = require('nanoid');
|
const { nanoid } = require('nanoid');
|
||||||
const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { generateCheckAccess } = require('@librechat/api');
|
||||||
|
const {
|
||||||
|
SystemRoles,
|
||||||
|
Permissions,
|
||||||
|
PermissionTypes,
|
||||||
|
actionDelimiter,
|
||||||
|
removeNullishValues,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
|
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
|
||||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||||
const { getAgent, updateAgent } = require('~/models/Agent');
|
const { getAgent, updateAgent } = require('~/models/Agent');
|
||||||
const { logger } = require('~/config');
|
const { getRoleByName } = require('~/models/Role');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
|
const checkAgentCreate = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.AGENTS,
|
||||||
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
// If the user has ADMIN role
|
// If the user has ADMIN role
|
||||||
// then action edition is possible even if not owner of the assistant
|
// then action edition is possible even if not owner of the assistant
|
||||||
const isAdmin = (req) => {
|
const isAdmin = (req) => {
|
||||||
|
|
@ -41,7 +55,7 @@ router.get('/', async (req, res) => {
|
||||||
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
|
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
|
||||||
* @returns {Object} 200 - success response - application/json
|
* @returns {Object} 200 - success response - application/json
|
||||||
*/
|
*/
|
||||||
router.post('/:agent_id', async (req, res) => {
|
router.post('/:agent_id', checkAgentCreate, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { agent_id } = req.params;
|
const { agent_id } = req.params;
|
||||||
|
|
||||||
|
|
@ -149,7 +163,7 @@ router.post('/:agent_id', async (req, res) => {
|
||||||
* @param {string} req.params.action_id - The ID of the action to delete.
|
* @param {string} req.params.action_id - The ID of the action to delete.
|
||||||
* @returns {Object} 200 - success response - application/json
|
* @returns {Object} 200 - success response - application/json
|
||||||
*/
|
*/
|
||||||
router.delete('/:agent_id/:action_id', async (req, res) => {
|
router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { agent_id, action_id } = req.params;
|
const { agent_id, action_id } = req.params;
|
||||||
const admin = isAdmin(req);
|
const admin = isAdmin(req);
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,28 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
|
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
|
||||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
setHeaders,
|
setHeaders,
|
||||||
moderateText,
|
moderateText,
|
||||||
// validateModel,
|
// validateModel,
|
||||||
generateCheckAccess,
|
|
||||||
validateConvoAccess,
|
validateConvoAccess,
|
||||||
buildEndpointOption,
|
buildEndpointOption,
|
||||||
} = require('~/server/middleware');
|
} = require('~/server/middleware');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/agents');
|
const { initializeClient } = require('~/server/services/Endpoints/agents');
|
||||||
const AgentController = require('~/server/controllers/agents/request');
|
const AgentController = require('~/server/controllers/agents/request');
|
||||||
const addTitle = require('~/server/services/Endpoints/agents/title');
|
const addTitle = require('~/server/services/Endpoints/agents/title');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
router.use(moderateText);
|
router.use(moderateText);
|
||||||
|
|
||||||
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
const checkAgentAccess = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.AGENTS,
|
||||||
|
permissions: [Permissions.USE],
|
||||||
|
skipCheck: skipAgentCheck,
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
router.use(checkAgentAccess);
|
router.use(checkAgentAccess);
|
||||||
router.use(validateConvoAccess);
|
router.use(validateConvoAccess);
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,36 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
|
const { generateCheckAccess } = require('@librechat/api');
|
||||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
const v1 = require('~/server/controllers/agents/v1');
|
const v1 = require('~/server/controllers/agents/v1');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
const actions = require('./actions');
|
const actions = require('./actions');
|
||||||
const tools = require('./tools');
|
const tools = require('./tools');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const avatar = express.Router();
|
const avatar = express.Router();
|
||||||
|
|
||||||
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
const checkAgentAccess = generateCheckAccess({
|
||||||
const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
|
permissionType: PermissionTypes.AGENTS,
|
||||||
Permissions.USE,
|
permissions: [Permissions.USE],
|
||||||
Permissions.CREATE,
|
getRoleByName,
|
||||||
]);
|
});
|
||||||
|
const checkAgentCreate = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.AGENTS,
|
||||||
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
const checkGlobalAgentShare = generateCheckAccess(
|
const checkGlobalAgentShare = generateCheckAccess({
|
||||||
PermissionTypes.AGENTS,
|
permissionType: PermissionTypes.AGENTS,
|
||||||
[Permissions.USE, Permissions.CREATE],
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
{
|
bodyProps: {
|
||||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||||
},
|
},
|
||||||
);
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
router.use(requireJwtAuth);
|
router.use(requireJwtAuth);
|
||||||
router.use(checkAgentAccess);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Agent actions route.
|
* Agent actions route.
|
||||||
|
|
|
||||||
|
|
@ -1,63 +0,0 @@
|
||||||
const { Keyv } = require('keyv');
|
|
||||||
const { KeyvFile } = require('keyv-file');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => {
|
|
||||||
try {
|
|
||||||
const conversationsCache = new Keyv({
|
|
||||||
store: new KeyvFile({ filename: './data/cache.json' }),
|
|
||||||
namespace: 'chatgpt', // should be 'bing' for bing/sydney
|
|
||||||
});
|
|
||||||
|
|
||||||
const {
|
|
||||||
conversationId,
|
|
||||||
messageId: userMessageId,
|
|
||||||
parentMessageId: userParentMessageId,
|
|
||||||
text: userText,
|
|
||||||
} = userMessage;
|
|
||||||
const {
|
|
||||||
messageId: responseMessageId,
|
|
||||||
parentMessageId: responseParentMessageId,
|
|
||||||
text: responseText,
|
|
||||||
} = responseMessage;
|
|
||||||
|
|
||||||
let conversation = await conversationsCache.get(conversationId);
|
|
||||||
// used to generate a title for the conversation if none exists
|
|
||||||
// let isNewConversation = false;
|
|
||||||
if (!conversation) {
|
|
||||||
conversation = {
|
|
||||||
messages: [],
|
|
||||||
createdAt: Date.now(),
|
|
||||||
};
|
|
||||||
// isNewConversation = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const roles = (options) => {
|
|
||||||
if (endpoint === 'openAI') {
|
|
||||||
return options?.chatGptLabel || 'ChatGPT';
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let _userMessage = {
|
|
||||||
id: userMessageId,
|
|
||||||
parentMessageId: userParentMessageId,
|
|
||||||
role: 'User',
|
|
||||||
message: userText,
|
|
||||||
};
|
|
||||||
|
|
||||||
let _responseMessage = {
|
|
||||||
id: responseMessageId,
|
|
||||||
parentMessageId: responseParentMessageId,
|
|
||||||
role: roles(endpointOption),
|
|
||||||
message: responseText,
|
|
||||||
};
|
|
||||||
|
|
||||||
conversation.messages.push(_userMessage, _responseMessage);
|
|
||||||
|
|
||||||
await conversationsCache.set(conversationId, conversation);
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[addToCache] Error adding conversation to cache', error);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = addToCache;
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const AskController = require('~/server/controllers/AskController');
|
|
||||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic');
|
|
||||||
const {
|
|
||||||
setHeaders,
|
|
||||||
handleAbort,
|
|
||||||
validateModel,
|
|
||||||
validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
validateEndpoint,
|
|
||||||
validateModel,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res, next) => {
|
|
||||||
await AskController(req, res, next, initializeClient, addTitle);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const AskController = require('~/server/controllers/AskController');
|
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/custom');
|
|
||||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
|
||||||
const {
|
|
||||||
setHeaders,
|
|
||||||
validateModel,
|
|
||||||
validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
validateEndpoint,
|
|
||||||
validateModel,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res, next) => {
|
|
||||||
await AskController(req, res, next, initializeClient, addTitle);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const AskController = require('~/server/controllers/AskController');
|
|
||||||
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
|
|
||||||
const {
|
|
||||||
setHeaders,
|
|
||||||
validateModel,
|
|
||||||
validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
validateEndpoint,
|
|
||||||
validateModel,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res, next) => {
|
|
||||||
await AskController(req, res, next, initializeClient, addTitle);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,241 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
|
||||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
|
||||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
|
||||||
const { saveMessage, updateMessage } = require('~/models');
|
|
||||||
const {
|
|
||||||
handleAbort,
|
|
||||||
createAbortController,
|
|
||||||
handleAbortError,
|
|
||||||
setHeaders,
|
|
||||||
validateModel,
|
|
||||||
validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
moderateText,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
const { validateTools } = require('~/app');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.use(moderateText);
|
|
||||||
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
validateEndpoint,
|
|
||||||
validateModel,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res) => {
|
|
||||||
let {
|
|
||||||
text,
|
|
||||||
endpointOption,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId = null,
|
|
||||||
overrideParentMessageId = null,
|
|
||||||
} = req.body;
|
|
||||||
|
|
||||||
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
|
|
||||||
|
|
||||||
let userMessage;
|
|
||||||
let userMessagePromise;
|
|
||||||
let promptTokens;
|
|
||||||
let userMessageId;
|
|
||||||
let responseMessageId;
|
|
||||||
const sender = getResponseSender({
|
|
||||||
...endpointOption,
|
|
||||||
model: endpointOption.modelOptions.model,
|
|
||||||
});
|
|
||||||
const newConvo = !conversationId;
|
|
||||||
const user = req.user.id;
|
|
||||||
|
|
||||||
const plugins = [];
|
|
||||||
|
|
||||||
const getReqData = (data = {}) => {
|
|
||||||
for (let key in data) {
|
|
||||||
if (key === 'userMessage') {
|
|
||||||
userMessage = data[key];
|
|
||||||
userMessageId = data[key].messageId;
|
|
||||||
} else if (key === 'userMessagePromise') {
|
|
||||||
userMessagePromise = data[key];
|
|
||||||
} else if (key === 'responseMessageId') {
|
|
||||||
responseMessageId = data[key];
|
|
||||||
} else if (key === 'promptTokens') {
|
|
||||||
promptTokens = data[key];
|
|
||||||
} else if (!conversationId && key === 'conversationId') {
|
|
||||||
conversationId = data[key];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let streaming = null;
|
|
||||||
let timer = null;
|
|
||||||
|
|
||||||
const {
|
|
||||||
onProgress: progressCallback,
|
|
||||||
sendIntermediateMessage,
|
|
||||||
getPartialText,
|
|
||||||
} = createOnProgress({
|
|
||||||
onProgress: () => {
|
|
||||||
if (timer) {
|
|
||||||
clearTimeout(timer);
|
|
||||||
}
|
|
||||||
|
|
||||||
streaming = new Promise((resolve) => {
|
|
||||||
timer = setTimeout(() => {
|
|
||||||
resolve();
|
|
||||||
}, 250);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const pluginMap = new Map();
|
|
||||||
const onAgentAction = async (action, runId) => {
|
|
||||||
pluginMap.set(runId, action.tool);
|
|
||||||
sendIntermediateMessage(res, {
|
|
||||||
plugins,
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
const onToolStart = async (tool, input, runId, parentRunId) => {
|
|
||||||
const pluginName = pluginMap.get(parentRunId);
|
|
||||||
const latestPlugin = {
|
|
||||||
runId,
|
|
||||||
loading: true,
|
|
||||||
inputs: [input],
|
|
||||||
latest: pluginName,
|
|
||||||
outputs: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (streaming) {
|
|
||||||
await streaming;
|
|
||||||
}
|
|
||||||
const extraTokens = ':::plugin:::\n';
|
|
||||||
plugins.push(latestPlugin);
|
|
||||||
sendIntermediateMessage(
|
|
||||||
res,
|
|
||||||
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
|
|
||||||
extraTokens,
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onToolEnd = async (output, runId) => {
|
|
||||||
if (streaming) {
|
|
||||||
await streaming;
|
|
||||||
}
|
|
||||||
|
|
||||||
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
|
|
||||||
|
|
||||||
if (pluginIndex !== -1) {
|
|
||||||
plugins[pluginIndex].loading = false;
|
|
||||||
plugins[pluginIndex].outputs = output;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const getAbortData = () => ({
|
|
||||||
sender,
|
|
||||||
conversationId,
|
|
||||||
userMessagePromise,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
|
||||||
text: getPartialText(),
|
|
||||||
plugins: plugins.map((p) => ({ ...p, loading: false })),
|
|
||||||
userMessage,
|
|
||||||
promptTokens,
|
|
||||||
});
|
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
|
||||||
|
|
||||||
try {
|
|
||||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
|
||||||
const { client } = await initializeClient({ req, res, endpointOption });
|
|
||||||
|
|
||||||
const onChainEnd = () => {
|
|
||||||
if (!client.skipSaveUserMessage) {
|
|
||||||
saveMessage(
|
|
||||||
req,
|
|
||||||
{ ...userMessage, user },
|
|
||||||
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
sendIntermediateMessage(res, {
|
|
||||||
plugins,
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = await client.sendMessage(text, {
|
|
||||||
user,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
overrideParentMessageId,
|
|
||||||
getReqData,
|
|
||||||
onAgentAction,
|
|
||||||
onChainEnd,
|
|
||||||
onToolStart,
|
|
||||||
onToolEnd,
|
|
||||||
onStart,
|
|
||||||
getPartialText,
|
|
||||||
...endpointOption,
|
|
||||||
progressCallback,
|
|
||||||
progressOptions: {
|
|
||||||
res,
|
|
||||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
plugins,
|
|
||||||
},
|
|
||||||
abortController,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (overrideParentMessageId) {
|
|
||||||
response.parentMessageId = overrideParentMessageId;
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug('[/ask/gptPlugins]', response);
|
|
||||||
|
|
||||||
const { conversation = {} } = await response.databasePromise;
|
|
||||||
delete response.databasePromise;
|
|
||||||
conversation.title =
|
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
|
||||||
|
|
||||||
sendMessage(res, {
|
|
||||||
title: conversation.title,
|
|
||||||
final: true,
|
|
||||||
conversation,
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: response,
|
|
||||||
});
|
|
||||||
res.end();
|
|
||||||
|
|
||||||
if (parentMessageId === Constants.NO_PARENT && newConvo) {
|
|
||||||
addTitle(req, {
|
|
||||||
text,
|
|
||||||
response,
|
|
||||||
client,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
|
||||||
if (response.plugins?.length > 0) {
|
|
||||||
await updateMessage(
|
|
||||||
req,
|
|
||||||
{ ...response, user },
|
|
||||||
{ context: 'api/server/routes/ask/gptPlugins.js - save plugins used' },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
const partialText = getPartialText();
|
|
||||||
handleAbortError(res, req, error, {
|
|
||||||
partialText,
|
|
||||||
conversationId,
|
|
||||||
sender,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
parentMessageId: userMessageId ?? parentMessageId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,47 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const { EModelEndpoint } = require('librechat-data-provider');
|
|
||||||
const {
|
|
||||||
uaParser,
|
|
||||||
checkBan,
|
|
||||||
requireJwtAuth,
|
|
||||||
messageIpLimiter,
|
|
||||||
concurrentLimiter,
|
|
||||||
messageUserLimiter,
|
|
||||||
validateConvoAccess,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
const { isEnabled } = require('~/server/utils');
|
|
||||||
const gptPlugins = require('./gptPlugins');
|
|
||||||
const anthropic = require('./anthropic');
|
|
||||||
const custom = require('./custom');
|
|
||||||
const google = require('./google');
|
|
||||||
const openAI = require('./openAI');
|
|
||||||
|
|
||||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.use(requireJwtAuth);
|
|
||||||
router.use(checkBan);
|
|
||||||
router.use(uaParser);
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
|
||||||
router.use(concurrentLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
|
||||||
router.use(messageIpLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_MESSAGE_USER)) {
|
|
||||||
router.use(messageUserLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
router.use(validateConvoAccess);
|
|
||||||
|
|
||||||
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
|
|
||||||
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
|
|
||||||
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
|
|
||||||
router.use(`/${EModelEndpoint.google}`, google);
|
|
||||||
router.use(`/${EModelEndpoint.custom}`, custom);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const AskController = require('~/server/controllers/AskController');
|
|
||||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
|
|
||||||
const {
|
|
||||||
handleAbort,
|
|
||||||
setHeaders,
|
|
||||||
validateModel,
|
|
||||||
validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
moderateText,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
router.use(moderateText);
|
|
||||||
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
validateEndpoint,
|
|
||||||
validateModel,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res, next) => {
|
|
||||||
await AskController(req, res, next, initializeClient, addTitle);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
const {
|
|
||||||
setHeaders,
|
|
||||||
handleAbort,
|
|
||||||
moderateText,
|
|
||||||
// validateModel,
|
|
||||||
// validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/bedrock');
|
|
||||||
const AgentController = require('~/server/controllers/agents/request');
|
|
||||||
const addTitle = require('~/server/services/Endpoints/agents/title');
|
|
||||||
|
|
||||||
router.use(moderateText);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @route POST /
|
|
||||||
* @desc Chat with an assistant
|
|
||||||
* @access Public
|
|
||||||
* @param {express.Request} req - The request object, containing the request data.
|
|
||||||
* @param {express.Response} res - The response object, used to send back a response.
|
|
||||||
* @returns {void}
|
|
||||||
*/
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
// validateModel,
|
|
||||||
// validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res, next) => {
|
|
||||||
await AgentController(req, res, next, initializeClient, addTitle);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const {
|
|
||||||
uaParser,
|
|
||||||
checkBan,
|
|
||||||
requireJwtAuth,
|
|
||||||
messageIpLimiter,
|
|
||||||
concurrentLimiter,
|
|
||||||
messageUserLimiter,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
const { isEnabled } = require('~/server/utils');
|
|
||||||
const chat = require('./chat');
|
|
||||||
|
|
||||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.use(requireJwtAuth);
|
|
||||||
router.use(checkBan);
|
|
||||||
router.use(uaParser);
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
|
||||||
router.use(concurrentLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
|
||||||
router.use(messageIpLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isEnabled(LIMIT_MESSAGE_USER)) {
|
|
||||||
router.use(messageUserLimiter);
|
|
||||||
}
|
|
||||||
|
|
||||||
router.use('/chat', chat);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -1,16 +1,17 @@
|
||||||
const multer = require('multer');
|
const multer = require('multer');
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
|
const { sleep } = require('@librechat/agents');
|
||||||
|
const { isEnabled } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
||||||
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
|
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
|
||||||
|
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
|
||||||
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
||||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||||
const { importConversations } = require('~/server/utils/import');
|
const { importConversations } = require('~/server/utils/import');
|
||||||
const { createImportLimiters } = require('~/server/middleware');
|
|
||||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||||
const { isEnabled, sleep } = require('~/server/utils');
|
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const assistantClients = {
|
const assistantClients = {
|
||||||
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
|
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
|
||||||
|
|
@ -43,6 +44,7 @@ router.get('/', async (req, res) => {
|
||||||
});
|
});
|
||||||
res.status(200).json(result);
|
res.status(200).json(result);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
logger.error('Error fetching conversations', error);
|
||||||
res.status(500).json({ error: 'Error fetching conversations' });
|
res.status(500).json({ error: 'Error fetching conversations' });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
@ -156,6 +158,7 @@ router.post('/update', async (req, res) => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const { importIpLimiter, importUserLimiter } = createImportLimiters();
|
const { importIpLimiter, importUserLimiter } = createImportLimiters();
|
||||||
|
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
|
||||||
const upload = multer({ storage: storage, fileFilter: importFileFilter });
|
const upload = multer({ storage: storage, fileFilter: importFileFilter });
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -189,7 +192,7 @@ router.post(
|
||||||
* @param {express.Response<TForkConvoResponse>} res - Express response object.
|
* @param {express.Response<TForkConvoResponse>} res - Express response object.
|
||||||
* @returns {Promise<void>} - The response after forking the conversation.
|
* @returns {Promise<void>} - The response after forking the conversation.
|
||||||
*/
|
*/
|
||||||
router.post('/fork', async (req, res) => {
|
router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
/** @type {TForkConvoRequest} */
|
/** @type {TForkConvoRequest} */
|
||||||
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;
|
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;
|
||||||
|
|
|
||||||
|
|
@ -1,207 +0,0 @@
|
||||||
const express = require('express');
|
|
||||||
const { getResponseSender } = require('librechat-data-provider');
|
|
||||||
const {
|
|
||||||
setHeaders,
|
|
||||||
moderateText,
|
|
||||||
validateModel,
|
|
||||||
handleAbortError,
|
|
||||||
validateEndpoint,
|
|
||||||
buildEndpointOption,
|
|
||||||
createAbortController,
|
|
||||||
} = require('~/server/middleware');
|
|
||||||
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
|
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
|
||||||
const { saveMessage, updateMessage } = require('~/models');
|
|
||||||
const { validateTools } = require('~/app');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const router = express.Router();
|
|
||||||
|
|
||||||
router.use(moderateText);
|
|
||||||
|
|
||||||
router.post(
|
|
||||||
'/',
|
|
||||||
validateEndpoint,
|
|
||||||
validateModel,
|
|
||||||
buildEndpointOption,
|
|
||||||
setHeaders,
|
|
||||||
async (req, res) => {
|
|
||||||
let {
|
|
||||||
text,
|
|
||||||
generation,
|
|
||||||
endpointOption,
|
|
||||||
conversationId,
|
|
||||||
responseMessageId,
|
|
||||||
isContinued = false,
|
|
||||||
parentMessageId = null,
|
|
||||||
overrideParentMessageId = null,
|
|
||||||
} = req.body;
|
|
||||||
|
|
||||||
logger.debug('[/edit/gptPlugins]', {
|
|
||||||
text,
|
|
||||||
generation,
|
|
||||||
isContinued,
|
|
||||||
conversationId,
|
|
||||||
...endpointOption,
|
|
||||||
});
|
|
||||||
|
|
||||||
let userMessage;
|
|
||||||
let userMessagePromise;
|
|
||||||
let promptTokens;
|
|
||||||
const sender = getResponseSender({
|
|
||||||
...endpointOption,
|
|
||||||
model: endpointOption.modelOptions.model,
|
|
||||||
});
|
|
||||||
const userMessageId = parentMessageId;
|
|
||||||
const user = req.user.id;
|
|
||||||
|
|
||||||
const plugin = {
|
|
||||||
loading: true,
|
|
||||||
inputs: [],
|
|
||||||
latest: null,
|
|
||||||
outputs: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
const getReqData = (data = {}) => {
|
|
||||||
for (let key in data) {
|
|
||||||
if (key === 'userMessage') {
|
|
||||||
userMessage = data[key];
|
|
||||||
} else if (key === 'userMessagePromise') {
|
|
||||||
userMessagePromise = data[key];
|
|
||||||
} else if (key === 'responseMessageId') {
|
|
||||||
responseMessageId = data[key];
|
|
||||||
} else if (key === 'promptTokens') {
|
|
||||||
promptTokens = data[key];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const {
|
|
||||||
onProgress: progressCallback,
|
|
||||||
sendIntermediateMessage,
|
|
||||||
getPartialText,
|
|
||||||
} = createOnProgress({
|
|
||||||
generation,
|
|
||||||
onProgress: () => {
|
|
||||||
if (plugin.loading === true) {
|
|
||||||
plugin.loading = false;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const onChainEnd = (data) => {
|
|
||||||
let { intermediateSteps: steps } = data;
|
|
||||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
|
||||||
plugin.loading = false;
|
|
||||||
saveMessage(
|
|
||||||
req,
|
|
||||||
{ ...userMessage, user },
|
|
||||||
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
|
|
||||||
);
|
|
||||||
sendIntermediateMessage(res, {
|
|
||||||
plugin,
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
});
|
|
||||||
// logger.debug('CHAIN END', plugin.outputs);
|
|
||||||
};
|
|
||||||
|
|
||||||
const getAbortData = () => ({
|
|
||||||
sender,
|
|
||||||
conversationId,
|
|
||||||
userMessagePromise,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
|
||||||
text: getPartialText(),
|
|
||||||
plugin: { ...plugin, loading: false },
|
|
||||||
userMessage,
|
|
||||||
promptTokens,
|
|
||||||
});
|
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
|
||||||
|
|
||||||
try {
|
|
||||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
|
||||||
const { client } = await initializeClient({ req, res, endpointOption });
|
|
||||||
|
|
||||||
const onAgentAction = (action, start = false) => {
|
|
||||||
const formattedAction = formatAction(action);
|
|
||||||
plugin.inputs.push(formattedAction);
|
|
||||||
plugin.latest = formattedAction.plugin;
|
|
||||||
if (!start && !client.skipSaveUserMessage) {
|
|
||||||
saveMessage(
|
|
||||||
req,
|
|
||||||
{ ...userMessage, user },
|
|
||||||
{ context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
sendIntermediateMessage(res, {
|
|
||||||
plugin,
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
});
|
|
||||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = await client.sendMessage(text, {
|
|
||||||
user,
|
|
||||||
generation,
|
|
||||||
isContinued,
|
|
||||||
isEdited: true,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId,
|
|
||||||
responseMessageId,
|
|
||||||
overrideParentMessageId,
|
|
||||||
getReqData,
|
|
||||||
onAgentAction,
|
|
||||||
onChainEnd,
|
|
||||||
onStart,
|
|
||||||
...endpointOption,
|
|
||||||
progressCallback,
|
|
||||||
progressOptions: {
|
|
||||||
res,
|
|
||||||
plugin,
|
|
||||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
},
|
|
||||||
abortController,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (overrideParentMessageId) {
|
|
||||||
response.parentMessageId = overrideParentMessageId;
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
|
|
||||||
|
|
||||||
const { conversation = {} } = await response.databasePromise;
|
|
||||||
delete response.databasePromise;
|
|
||||||
conversation.title =
|
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
|
||||||
|
|
||||||
sendMessage(res, {
|
|
||||||
title: conversation.title,
|
|
||||||
final: true,
|
|
||||||
conversation,
|
|
||||||
requestMessage: userMessage,
|
|
||||||
responseMessage: response,
|
|
||||||
});
|
|
||||||
res.end();
|
|
||||||
|
|
||||||
response.plugin = { ...plugin, loading: false };
|
|
||||||
await updateMessage(
|
|
||||||
req,
|
|
||||||
{ ...response, user },
|
|
||||||
{ context: 'api/server/routes/edit/gptPlugins.js' },
|
|
||||||
);
|
|
||||||
} catch (error) {
|
|
||||||
const partialText = getPartialText();
|
|
||||||
handleAbortError(res, req, error, {
|
|
||||||
partialText,
|
|
||||||
conversationId,
|
|
||||||
sender,
|
|
||||||
messageId: responseMessageId,
|
|
||||||
parentMessageId: userMessageId ?? parentMessageId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
module.exports = router;
|
|
||||||
|
|
@ -3,7 +3,6 @@ const openAI = require('./openAI');
|
||||||
const custom = require('./custom');
|
const custom = require('./custom');
|
||||||
const google = require('./google');
|
const google = require('./google');
|
||||||
const anthropic = require('./anthropic');
|
const anthropic = require('./anthropic');
|
||||||
const gptPlugins = require('./gptPlugins');
|
|
||||||
const { isEnabled } = require('~/server/utils');
|
const { isEnabled } = require('~/server/utils');
|
||||||
const { EModelEndpoint } = require('librechat-data-provider');
|
const { EModelEndpoint } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
|
|
@ -39,7 +38,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||||
router.use(validateConvoAccess);
|
router.use(validateConvoAccess);
|
||||||
|
|
||||||
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
|
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
|
||||||
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
|
|
||||||
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
|
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
|
||||||
router.use(`/${EModelEndpoint.google}`, google);
|
router.use(`/${EModelEndpoint.google}`, google);
|
||||||
router.use(`/${EModelEndpoint.custom}`, custom);
|
router.use(`/${EModelEndpoint.custom}`, custom);
|
||||||
|
|
|
||||||
|
|
@ -283,7 +283,10 @@ router.post('/', async (req, res) => {
|
||||||
message += ': ' + error.message;
|
message += ': ' + error.message;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error.message?.includes('Invalid file format')) {
|
if (
|
||||||
|
error.message?.includes('Invalid file format') ||
|
||||||
|
error.message?.includes('No OCR result')
|
||||||
|
) {
|
||||||
message = error.message;
|
message = error.message;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -477,7 +477,9 @@ describe('Multer Configuration', () => {
|
||||||
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
|
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
|
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
|
||||||
expect(error.code).toBe('EACCES');
|
// On Linux, this typically returns EACCES (permission denied)
|
||||||
|
// On macOS/Darwin, this returns ENOENT (no such file or directory)
|
||||||
|
expect(['EACCES', 'ENOENT']).toContain(error.code);
|
||||||
done();
|
done();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ const presets = require('./presets');
|
||||||
const prompts = require('./prompts');
|
const prompts = require('./prompts');
|
||||||
const balance = require('./balance');
|
const balance = require('./balance');
|
||||||
const plugins = require('./plugins');
|
const plugins = require('./plugins');
|
||||||
const bedrock = require('./bedrock');
|
|
||||||
const actions = require('./actions');
|
const actions = require('./actions');
|
||||||
const banner = require('./banner');
|
const banner = require('./banner');
|
||||||
const search = require('./search');
|
const search = require('./search');
|
||||||
|
|
@ -26,11 +25,9 @@ const auth = require('./auth');
|
||||||
const edit = require('./edit');
|
const edit = require('./edit');
|
||||||
const keys = require('./keys');
|
const keys = require('./keys');
|
||||||
const user = require('./user');
|
const user = require('./user');
|
||||||
const ask = require('./ask');
|
|
||||||
const mcp = require('./mcp');
|
const mcp = require('./mcp');
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
ask,
|
|
||||||
edit,
|
edit,
|
||||||
auth,
|
auth,
|
||||||
keys,
|
keys,
|
||||||
|
|
@ -46,7 +43,6 @@ module.exports = {
|
||||||
search,
|
search,
|
||||||
config,
|
config,
|
||||||
models,
|
models,
|
||||||
bedrock,
|
|
||||||
prompts,
|
prompts,
|
||||||
plugins,
|
plugins,
|
||||||
actions,
|
actions,
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,43 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { Tokenizer } = require('@librechat/api');
|
const { Tokenizer, generateCheckAccess } = require('@librechat/api');
|
||||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
getAllUserMemories,
|
getAllUserMemories,
|
||||||
toggleUserMemories,
|
toggleUserMemories,
|
||||||
createMemory,
|
createMemory,
|
||||||
setMemory,
|
|
||||||
deleteMemory,
|
deleteMemory,
|
||||||
|
setMemory,
|
||||||
} = require('~/models');
|
} = require('~/models');
|
||||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
|
const checkMemoryRead = generateCheckAccess({
|
||||||
Permissions.USE,
|
permissionType: PermissionTypes.MEMORIES,
|
||||||
Permissions.READ,
|
permissions: [Permissions.USE, Permissions.READ],
|
||||||
]);
|
getRoleByName,
|
||||||
const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
});
|
||||||
Permissions.USE,
|
const checkMemoryCreate = generateCheckAccess({
|
||||||
Permissions.CREATE,
|
permissionType: PermissionTypes.MEMORIES,
|
||||||
]);
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
getRoleByName,
|
||||||
Permissions.USE,
|
});
|
||||||
Permissions.UPDATE,
|
const checkMemoryUpdate = generateCheckAccess({
|
||||||
]);
|
permissionType: PermissionTypes.MEMORIES,
|
||||||
const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
|
permissions: [Permissions.USE, Permissions.UPDATE],
|
||||||
Permissions.USE,
|
getRoleByName,
|
||||||
Permissions.UPDATE,
|
});
|
||||||
]);
|
const checkMemoryDelete = generateCheckAccess({
|
||||||
const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
|
permissionType: PermissionTypes.MEMORIES,
|
||||||
Permissions.USE,
|
permissions: [Permissions.USE, Permissions.UPDATE],
|
||||||
Permissions.OPT_OUT,
|
getRoleByName,
|
||||||
]);
|
});
|
||||||
|
const checkMemoryOptOut = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.MEMORIES,
|
||||||
|
permissions: [Permissions.USE, Permissions.OPT_OUT],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
router.use(requireJwtAuth);
|
router.use(requireJwtAuth);
|
||||||
|
|
||||||
|
|
@ -166,40 +172,68 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
|
||||||
/**
|
/**
|
||||||
* PATCH /memories/:key
|
* PATCH /memories/:key
|
||||||
* Updates the value of an existing memory entry for the authenticated user.
|
* Updates the value of an existing memory entry for the authenticated user.
|
||||||
* Body: { value: string }
|
* Body: { key?: string, value: string }
|
||||||
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
|
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
|
||||||
*/
|
*/
|
||||||
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
|
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
|
||||||
const { key } = req.params;
|
const { key: urlKey } = req.params;
|
||||||
const { value } = req.body || {};
|
const { key: bodyKey, value } = req.body || {};
|
||||||
|
|
||||||
if (typeof value !== 'string' || value.trim() === '') {
|
if (typeof value !== 'string' || value.trim() === '') {
|
||||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use the key from the body if provided, otherwise use the key from the URL
|
||||||
|
const newKey = bodyKey || urlKey;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||||
|
|
||||||
const memories = await getAllUserMemories(req.user.id);
|
const memories = await getAllUserMemories(req.user.id);
|
||||||
const existingMemory = memories.find((m) => m.key === key);
|
const existingMemory = memories.find((m) => m.key === urlKey);
|
||||||
|
|
||||||
if (!existingMemory) {
|
if (!existingMemory) {
|
||||||
return res.status(404).json({ error: 'Memory not found.' });
|
return res.status(404).json({ error: 'Memory not found.' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await setMemory({
|
// If the key is changing, we need to handle it specially
|
||||||
userId: req.user.id,
|
if (newKey !== urlKey) {
|
||||||
key,
|
const keyExists = memories.find((m) => m.key === newKey);
|
||||||
value,
|
if (keyExists) {
|
||||||
tokenCount,
|
return res.status(409).json({ error: 'Memory with this key already exists.' });
|
||||||
});
|
}
|
||||||
|
|
||||||
if (!result.ok) {
|
const createResult = await createMemory({
|
||||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
userId: req.user.id,
|
||||||
|
key: newKey,
|
||||||
|
value,
|
||||||
|
tokenCount,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!createResult.ok) {
|
||||||
|
return res.status(500).json({ error: 'Failed to create new memory.' });
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey });
|
||||||
|
if (!deleteResult.ok) {
|
||||||
|
return res.status(500).json({ error: 'Failed to delete old memory.' });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Key is not changing, just update the value
|
||||||
|
const result = await setMemory({
|
||||||
|
userId: req.user.id,
|
||||||
|
key: newKey,
|
||||||
|
value,
|
||||||
|
tokenCount,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.ok) {
|
||||||
|
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||||
const updatedMemory = updatedMemories.find((m) => m.key === key);
|
const updatedMemory = updatedMemories.find((m) => m.key === newKey);
|
||||||
|
|
||||||
res.json({ updated: true, memory: updatedMemory });
|
res.json({ updated: true, memory: updatedMemory });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
|
||||||
|
|
@ -235,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) =
|
||||||
return res.status(400).json({ error: 'Content part not found' });
|
return res.status(400).json({ error: 'Content part not found' });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (updatedContent[index].type !== ContentTypes.TEXT) {
|
const currentPartType = updatedContent[index].type;
|
||||||
|
if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) {
|
||||||
return res.status(400).json({ error: 'Cannot update non-text content' });
|
return res.status(400).json({ error: 'Cannot update non-text content' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const oldText = updatedContent[index].text;
|
const oldText = updatedContent[index][currentPartType];
|
||||||
updatedContent[index] = { type: ContentTypes.TEXT, text };
|
updatedContent[index] = { type: currentPartType, [currentPartType]: text };
|
||||||
|
|
||||||
let tokenCount = message.tokenCount;
|
let tokenCount = message.tokenCount;
|
||||||
if (tokenCount !== undefined) {
|
if (tokenCount !== undefined) {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { generateCheckAccess } = require('@librechat/api');
|
||||||
|
const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
getPrompt,
|
getPrompt,
|
||||||
getPrompts,
|
getPrompts,
|
||||||
|
|
@ -16,23 +18,30 @@ const {
|
||||||
} = require('~/models/Prompt');
|
} = require('~/models/Prompt');
|
||||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||||
const { getUserById, updateUser } = require('~/models');
|
const { getUserById, updateUser } = require('~/models');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
|
const checkPromptAccess = generateCheckAccess({
|
||||||
const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
|
permissionType: PermissionTypes.PROMPTS,
|
||||||
Permissions.USE,
|
permissions: [Permissions.USE],
|
||||||
Permissions.CREATE,
|
getRoleByName,
|
||||||
]);
|
});
|
||||||
|
const checkPromptCreate = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.PROMPTS,
|
||||||
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
const checkGlobalPromptShare = generateCheckAccess(
|
const checkGlobalPromptShare = generateCheckAccess({
|
||||||
PermissionTypes.PROMPTS,
|
permissionType: PermissionTypes.PROMPTS,
|
||||||
[Permissions.USE, Permissions.CREATE],
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
{
|
bodyProps: {
|
||||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||||
},
|
},
|
||||||
);
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
router.use(requireJwtAuth);
|
router.use(requireJwtAuth);
|
||||||
router.use(checkPromptAccess);
|
router.use(checkPromptAccess);
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,24 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { generateCheckAccess } = require('@librechat/api');
|
||||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
getConversationTags,
|
updateTagsForConversation,
|
||||||
updateConversationTag,
|
updateConversationTag,
|
||||||
createConversationTag,
|
createConversationTag,
|
||||||
deleteConversationTag,
|
deleteConversationTag,
|
||||||
updateTagsForConversation,
|
getConversationTags,
|
||||||
} = require('~/models/ConversationTag');
|
} = require('~/models/ConversationTag');
|
||||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
const { logger } = require('~/config');
|
const { getRoleByName } = require('~/models/Role');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
const checkBookmarkAccess = generateCheckAccess(PermissionTypes.BOOKMARKS, [Permissions.USE]);
|
const checkBookmarkAccess = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.BOOKMARKS,
|
||||||
|
permissions: [Permissions.USE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
router.use(requireJwtAuth);
|
router.use(requireJwtAuth);
|
||||||
router.use(checkBookmarkAccess);
|
router.use(checkBookmarkAccess);
|
||||||
|
|
|
||||||
|
|
@ -152,12 +152,14 @@ describe('AppService', () => {
|
||||||
filteredTools: undefined,
|
filteredTools: undefined,
|
||||||
includedTools: undefined,
|
includedTools: undefined,
|
||||||
webSearch: {
|
webSearch: {
|
||||||
|
safeSearch: 1,
|
||||||
|
jinaApiKey: '${JINA_API_KEY}',
|
||||||
cohereApiKey: '${COHERE_API_KEY}',
|
cohereApiKey: '${COHERE_API_KEY}',
|
||||||
|
serperApiKey: '${SERPER_API_KEY}',
|
||||||
|
searxngApiKey: '${SEARXNG_API_KEY}',
|
||||||
firecrawlApiKey: '${FIRECRAWL_API_KEY}',
|
firecrawlApiKey: '${FIRECRAWL_API_KEY}',
|
||||||
firecrawlApiUrl: '${FIRECRAWL_API_URL}',
|
firecrawlApiUrl: '${FIRECRAWL_API_URL}',
|
||||||
jinaApiKey: '${JINA_API_KEY}',
|
searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}',
|
||||||
safeSearch: 1,
|
|
||||||
serperApiKey: '${SERPER_API_KEY}',
|
|
||||||
},
|
},
|
||||||
memory: undefined,
|
memory: undefined,
|
||||||
agents: {
|
agents: {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
const { klona } = require('klona');
|
const { klona } = require('klona');
|
||||||
|
const { sleep } = require('@librechat/agents');
|
||||||
|
const { sendEvent } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
StepTypes,
|
StepTypes,
|
||||||
RunStatus,
|
RunStatus,
|
||||||
|
|
@ -11,11 +14,10 @@ const {
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
|
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
|
||||||
const { processRequiredActions } = require('~/server/services/ToolService');
|
const { processRequiredActions } = require('~/server/services/ToolService');
|
||||||
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
|
|
||||||
const { RunManager, waitForRun } = require('~/server/services/Runs');
|
const { RunManager, waitForRun } = require('~/server/services/Runs');
|
||||||
const { processMessages } = require('~/server/services/Threads');
|
const { processMessages } = require('~/server/services/Threads');
|
||||||
|
const { createOnProgress } = require('~/server/utils');
|
||||||
const { TextStream } = require('~/app/clients');
|
const { TextStream } = require('~/app/clients');
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sorts, processes, and flattens messages to a single string.
|
* Sorts, processes, and flattens messages to a single string.
|
||||||
|
|
@ -64,7 +66,7 @@ async function createOnTextProgress({
|
||||||
};
|
};
|
||||||
|
|
||||||
logger.debug('Content data:', contentData);
|
logger.debug('Content data:', contentData);
|
||||||
sendMessage(openai.res, contentData);
|
sendEvent(openai.res, contentData);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
const bcrypt = require('bcryptjs');
|
const bcrypt = require('bcryptjs');
|
||||||
|
const jwt = require('jsonwebtoken');
|
||||||
const { webcrypto } = require('node:crypto');
|
const { webcrypto } = require('node:crypto');
|
||||||
const { isEnabled } = require('@librechat/api');
|
const { isEnabled } = require('@librechat/api');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
|
@ -499,6 +500,18 @@ const resendVerificationEmail = async (req) => {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* Generate a short-lived JWT token
|
||||||
|
* @param {String} userId - The ID of the user
|
||||||
|
* @param {String} [expireIn='5m'] - The expiration time for the token (default is 5 minutes)
|
||||||
|
* @returns {String} - The generated JWT token
|
||||||
|
*/
|
||||||
|
const generateShortLivedToken = (userId, expireIn = '5m') => {
|
||||||
|
return jwt.sign({ id: userId }, process.env.JWT_SECRET, {
|
||||||
|
expiresIn: expireIn,
|
||||||
|
algorithm: 'HS256',
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
logoutUser,
|
logoutUser,
|
||||||
|
|
@ -506,7 +519,8 @@ module.exports = {
|
||||||
registerUser,
|
registerUser,
|
||||||
setAuthTokens,
|
setAuthTokens,
|
||||||
resetPassword,
|
resetPassword,
|
||||||
|
setOpenIDAuthTokens,
|
||||||
requestPasswordReset,
|
requestPasswordReset,
|
||||||
resendVerificationEmail,
|
resendVerificationEmail,
|
||||||
setOpenIDAuthTokens,
|
generateShortLivedToken,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
const { isUserProvided } = require('@librechat/api');
|
||||||
const { EModelEndpoint } = require('librechat-data-provider');
|
const { EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { isUserProvided, generateConfig } = require('~/server/utils');
|
const { generateConfig } = require('~/server/utils/handleText');
|
||||||
|
|
||||||
const {
|
const {
|
||||||
OPENAI_API_KEY: openAIApiKey,
|
OPENAI_API_KEY: openAIApiKey,
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ async function getBalanceConfig() {
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param {string | EModelEndpoint} endpoint
|
* @param {string | EModelEndpoint} endpoint
|
||||||
|
* @returns {Promise<TEndpoint | undefined>}
|
||||||
*/
|
*/
|
||||||
const getCustomEndpointConfig = async (endpoint) => {
|
const getCustomEndpointConfig = async (endpoint) => {
|
||||||
const customConfig = await getCustomConfig();
|
const customConfig = await getCustomConfig();
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,10 @@
|
||||||
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
|
const {
|
||||||
|
CacheKeys,
|
||||||
|
EModelEndpoint,
|
||||||
|
isAgentsEndpoint,
|
||||||
|
orderEndpointsConfig,
|
||||||
|
defaultAgentCapabilities,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
|
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
|
||||||
const loadConfigEndpoints = require('./loadConfigEndpoints');
|
const loadConfigEndpoints = require('./loadConfigEndpoints');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
|
@ -80,8 +86,12 @@ async function getEndpointsConfig(req) {
|
||||||
* @returns {Promise<boolean>}
|
* @returns {Promise<boolean>}
|
||||||
*/
|
*/
|
||||||
const checkCapability = async (req, capability) => {
|
const checkCapability = async (req, capability) => {
|
||||||
|
const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint);
|
||||||
const endpointsConfig = await getEndpointsConfig(req);
|
const endpointsConfig = await getEndpointsConfig(req);
|
||||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
const capabilities =
|
||||||
|
isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
|
||||||
|
? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [])
|
||||||
|
: defaultAgentCapabilities;
|
||||||
return capabilities.includes(capability);
|
return capabilities.includes(capability);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
const path = require('path');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { loadServiceKey, isUserProvided } = require('@librechat/api');
|
||||||
const { EModelEndpoint } = require('librechat-data-provider');
|
const { EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { isUserProvided } = require('~/server/utils');
|
|
||||||
const { config } = require('./EndpointService');
|
const { config } = require('./EndpointService');
|
||||||
|
|
||||||
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config;
|
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config;
|
||||||
|
|
@ -9,37 +11,41 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
|
||||||
* @param {Express.Request} req - The request object
|
* @param {Express.Request} req - The request object
|
||||||
*/
|
*/
|
||||||
async function loadAsyncEndpoints(req) {
|
async function loadAsyncEndpoints(req) {
|
||||||
let i = 0;
|
|
||||||
let serviceKey, googleUserProvides;
|
let serviceKey, googleUserProvides;
|
||||||
try {
|
|
||||||
serviceKey = require('~/data/auth.json');
|
/** Check if GOOGLE_KEY is provided at all(including 'user_provided') */
|
||||||
} catch (e) {
|
const isGoogleKeyProvided = googleKey && googleKey.trim() !== '';
|
||||||
if (i === 0) {
|
|
||||||
i++;
|
if (isGoogleKeyProvided) {
|
||||||
|
/** If GOOGLE_KEY is provided, check if it's user_provided */
|
||||||
|
googleUserProvides = isUserProvided(googleKey);
|
||||||
|
} else {
|
||||||
|
/** Only attempt to load service key if GOOGLE_KEY is not provided */
|
||||||
|
const serviceKeyPath =
|
||||||
|
process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json');
|
||||||
|
|
||||||
|
try {
|
||||||
|
serviceKey = await loadServiceKey(serviceKeyPath);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error loading service key', error);
|
||||||
|
serviceKey = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isUserProvided(googleKey)) {
|
const google = serviceKey || isGoogleKeyProvided ? { userProvide: googleUserProvides } : false;
|
||||||
googleUserProvides = true;
|
|
||||||
if (i <= 1) {
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
|
|
||||||
|
|
||||||
const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
|
const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
|
||||||
const gptPlugins =
|
const gptPlugins =
|
||||||
useAzure || openAIApiKey || azureOpenAIApiKey
|
useAzure || openAIApiKey || azureOpenAIApiKey
|
||||||
? {
|
? {
|
||||||
availableAgents: ['classic', 'functions'],
|
availableAgents: ['classic', 'functions'],
|
||||||
userProvide: useAzure ? false : userProvidedOpenAI,
|
userProvide: useAzure ? false : userProvidedOpenAI,
|
||||||
userProvideURL: useAzure
|
userProvideURL: useAzure
|
||||||
? false
|
? false
|
||||||
: config[EModelEndpoint.openAI]?.userProvideURL ||
|
: config[EModelEndpoint.openAI]?.userProvideURL ||
|
||||||
config[EModelEndpoint.azureOpenAI]?.userProvideURL,
|
config[EModelEndpoint.azureOpenAI]?.userProvideURL,
|
||||||
azure: useAzurePlugins || useAzure,
|
azure: useAzurePlugins || useAzure,
|
||||||
}
|
}
|
||||||
: false;
|
: false;
|
||||||
|
|
||||||
return { google, gptPlugins };
|
return { google, gptPlugins };
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const {
|
|
||||||
CacheKeys,
|
|
||||||
configSchema,
|
|
||||||
EImageOutputType,
|
|
||||||
validateSettingDefinitions,
|
|
||||||
agentParamSettings,
|
|
||||||
paramSettings,
|
|
||||||
} = require('librechat-data-provider');
|
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
|
||||||
const loadYaml = require('~/utils/loadYaml');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
const axios = require('axios');
|
const axios = require('axios');
|
||||||
const yaml = require('js-yaml');
|
const yaml = require('js-yaml');
|
||||||
const keyBy = require('lodash/keyBy');
|
const keyBy = require('lodash/keyBy');
|
||||||
|
const { loadYaml } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const {
|
||||||
|
CacheKeys,
|
||||||
|
configSchema,
|
||||||
|
paramSettings,
|
||||||
|
EImageOutputType,
|
||||||
|
agentParamSettings,
|
||||||
|
validateSettingDefinitions,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
|
||||||
const projectRoot = path.resolve(__dirname, '..', '..', '..', '..');
|
const projectRoot = path.resolve(__dirname, '..', '..', '..', '..');
|
||||||
const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml');
|
const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml');
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
jest.mock('axios');
|
jest.mock('axios');
|
||||||
jest.mock('~/cache/getLogStores');
|
jest.mock('~/cache/getLogStores');
|
||||||
jest.mock('~/utils/loadYaml');
|
jest.mock('@librechat/api', () => ({
|
||||||
|
...jest.requireActual('@librechat/api'),
|
||||||
|
loadYaml: jest.fn(),
|
||||||
|
}));
|
||||||
jest.mock('librechat-data-provider', () => {
|
jest.mock('librechat-data-provider', () => {
|
||||||
const actual = jest.requireActual('librechat-data-provider');
|
const actual = jest.requireActual('librechat-data-provider');
|
||||||
return {
|
return {
|
||||||
|
|
@ -30,11 +33,22 @@ jest.mock('librechat-data-provider', () => {
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => {
|
||||||
|
return {
|
||||||
|
logger: {
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
debug: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
const axios = require('axios');
|
const axios = require('axios');
|
||||||
|
const { loadYaml } = require('@librechat/api');
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const loadCustomConfig = require('./loadCustomConfig');
|
const loadCustomConfig = require('./loadCustomConfig');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
const loadYaml = require('~/utils/loadYaml');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
describe('loadCustomConfig', () => {
|
describe('loadCustomConfig', () => {
|
||||||
const mockSet = jest.fn();
|
const mockSet = jest.fn();
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,9 @@
|
||||||
const { Providers } = require('@librechat/agents');
|
const { Providers } = require('@librechat/agents');
|
||||||
const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api');
|
const {
|
||||||
|
primeResources,
|
||||||
|
extractLibreChatParams,
|
||||||
|
optionalChainWithEmptyCheck,
|
||||||
|
} = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
ErrorTypes,
|
ErrorTypes,
|
||||||
EModelEndpoint,
|
EModelEndpoint,
|
||||||
|
|
@ -7,30 +11,12 @@ const {
|
||||||
replaceSpecialVars,
|
replaceSpecialVars,
|
||||||
providerEndpointMap,
|
providerEndpointMap,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
|
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||||
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
|
|
||||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
|
||||||
const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
|
||||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
|
||||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
|
||||||
const { processFiles } = require('~/server/services/Files/process');
|
const { processFiles } = require('~/server/services/Files/process');
|
||||||
|
const { getFiles, getToolFilesByIds } = require('~/models/File');
|
||||||
const { getConvoFiles } = require('~/models/Conversation');
|
const { getConvoFiles } = require('~/models/Conversation');
|
||||||
const { getToolFilesByIds } = require('~/models/File');
|
|
||||||
const { getModelMaxTokens } = require('~/utils');
|
const { getModelMaxTokens } = require('~/utils');
|
||||||
const { getFiles } = require('~/models/File');
|
|
||||||
|
|
||||||
const providerConfigMap = {
|
|
||||||
[Providers.XAI]: initCustom,
|
|
||||||
[Providers.OLLAMA]: initCustom,
|
|
||||||
[Providers.DEEPSEEK]: initCustom,
|
|
||||||
[Providers.OPENROUTER]: initCustom,
|
|
||||||
[EModelEndpoint.openAI]: initOpenAI,
|
|
||||||
[EModelEndpoint.google]: initGoogle,
|
|
||||||
[EModelEndpoint.azureOpenAI]: initOpenAI,
|
|
||||||
[EModelEndpoint.anthropic]: initAnthropic,
|
|
||||||
[EModelEndpoint.bedrock]: getBedrockOptions,
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {object} params
|
* @param {object} params
|
||||||
|
|
@ -71,7 +57,7 @@ const initializeAgent = async ({
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
const { resendFiles = true, ...modelOptions } = _modelOptions;
|
const { resendFiles, maxContextTokens, modelOptions } = extractLibreChatParams(_modelOptions);
|
||||||
|
|
||||||
if (isInitialAgent && conversationId != null && resendFiles) {
|
if (isInitialAgent && conversationId != null && resendFiles) {
|
||||||
const fileIds = (await getConvoFiles(conversationId)) ?? [];
|
const fileIds = (await getConvoFiles(conversationId)) ?? [];
|
||||||
|
|
@ -99,7 +85,7 @@ const initializeAgent = async ({
|
||||||
});
|
});
|
||||||
|
|
||||||
const provider = agent.provider;
|
const provider = agent.provider;
|
||||||
const { tools, toolContextMap } =
|
const { tools: structuredTools, toolContextMap } =
|
||||||
(await loadTools?.({
|
(await loadTools?.({
|
||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
|
|
@ -111,17 +97,9 @@ const initializeAgent = async ({
|
||||||
})) ?? {};
|
})) ?? {};
|
||||||
|
|
||||||
agent.endpoint = provider;
|
agent.endpoint = provider;
|
||||||
let getOptions = providerConfigMap[provider];
|
const { getOptions, overrideProvider } = await getProviderConfig(provider);
|
||||||
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
|
if (overrideProvider) {
|
||||||
agent.provider = provider.toLowerCase();
|
agent.provider = overrideProvider;
|
||||||
getOptions = providerConfigMap[agent.provider];
|
|
||||||
} else if (!getOptions) {
|
|
||||||
const customEndpointConfig = await getCustomEndpointConfig(provider);
|
|
||||||
if (!customEndpointConfig) {
|
|
||||||
throw new Error(`Provider ${provider} not supported`);
|
|
||||||
}
|
|
||||||
getOptions = initCustom;
|
|
||||||
agent.provider = Providers.OPENAI;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const _endpointOption =
|
const _endpointOption =
|
||||||
|
|
@ -145,9 +123,8 @@ const initializeAgent = async ({
|
||||||
modelOptions.maxTokens,
|
modelOptions.maxTokens,
|
||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
const maxContextTokens = optionalChainWithEmptyCheck(
|
const agentMaxContextTokens = optionalChainWithEmptyCheck(
|
||||||
modelOptions.maxContextTokens,
|
maxContextTokens,
|
||||||
modelOptions.max_context_tokens,
|
|
||||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||||
4096,
|
4096,
|
||||||
);
|
);
|
||||||
|
|
@ -163,6 +140,24 @@ const initializeAgent = async ({
|
||||||
agent.provider = options.provider;
|
agent.provider = options.provider;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** @type {import('@librechat/agents').GenericTool[]} */
|
||||||
|
let tools = options.tools?.length ? options.tools : structuredTools;
|
||||||
|
if (
|
||||||
|
(agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) &&
|
||||||
|
options.tools?.length &&
|
||||||
|
structuredTools?.length
|
||||||
|
) {
|
||||||
|
throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
|
||||||
|
} else if (
|
||||||
|
(agent.provider === Providers.OPENAI ||
|
||||||
|
agent.provider === Providers.AZURE ||
|
||||||
|
agent.provider === Providers.ANTHROPIC) &&
|
||||||
|
options.tools?.length &&
|
||||||
|
structuredTools?.length
|
||||||
|
) {
|
||||||
|
tools = structuredTools.concat(options.tools);
|
||||||
|
}
|
||||||
|
|
||||||
/** @type {import('@librechat/agents').ClientOptions} */
|
/** @type {import('@librechat/agents').ClientOptions} */
|
||||||
agent.model_parameters = { ...options.llmConfig };
|
agent.model_parameters = { ...options.llmConfig };
|
||||||
if (options.configOptions) {
|
if (options.configOptions) {
|
||||||
|
|
@ -185,11 +180,11 @@ const initializeAgent = async ({
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...agent,
|
...agent,
|
||||||
tools,
|
|
||||||
attachments,
|
attachments,
|
||||||
resendFiles,
|
resendFiles,
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
tools,
|
||||||
|
maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
const { isAgentsEndpoint, Constants } = require('librechat-data-provider');
|
const { isAgentsEndpoint, removeNullishValues, Constants } = require('librechat-data-provider');
|
||||||
const { loadAgent } = require('~/models/Agent');
|
const { loadAgent } = require('~/models/Agent');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
||||||
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
|
const { spec, iconURL, agent_id, instructions, ...model_parameters } = parsedBody;
|
||||||
parsedBody;
|
|
||||||
const agentPromise = loadAgent({
|
const agentPromise = loadAgent({
|
||||||
req,
|
req,
|
||||||
agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
|
agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
|
||||||
|
|
@ -15,19 +14,16 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
||||||
return undefined;
|
return undefined;
|
||||||
});
|
});
|
||||||
|
|
||||||
const endpointOption = {
|
return removeNullishValues({
|
||||||
spec,
|
spec,
|
||||||
iconURL,
|
iconURL,
|
||||||
endpoint,
|
endpoint,
|
||||||
agent_id,
|
agent_id,
|
||||||
endpointType,
|
endpointType,
|
||||||
instructions,
|
instructions,
|
||||||
maxContextTokens,
|
|
||||||
model_parameters,
|
model_parameters,
|
||||||
agent: agentPromise,
|
agent: agentPromise,
|
||||||
};
|
});
|
||||||
|
|
||||||
return endpointOption;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = { buildOptions };
|
module.exports = { buildOptions };
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,17 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { createContentAggregator } = require('@librechat/agents');
|
const { createContentAggregator } = require('@librechat/agents');
|
||||||
const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider');
|
|
||||||
const {
|
const {
|
||||||
getDefaultHandlers,
|
Constants,
|
||||||
|
EModelEndpoint,
|
||||||
|
isAgentsEndpoint,
|
||||||
|
getResponseSender,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
|
const {
|
||||||
createToolEndCallback,
|
createToolEndCallback,
|
||||||
|
getDefaultHandlers,
|
||||||
} = require('~/server/controllers/agents/callbacks');
|
} = require('~/server/controllers/agents/callbacks');
|
||||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||||
|
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||||
const AgentClient = require('~/server/controllers/agents/client');
|
const AgentClient = require('~/server/controllers/agents/client');
|
||||||
const { getAgent } = require('~/models/Agent');
|
const { getAgent } = require('~/models/Agent');
|
||||||
|
|
@ -61,6 +67,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||||
}
|
}
|
||||||
|
|
||||||
const primaryAgent = await endpointOption.agent;
|
const primaryAgent = await endpointOption.agent;
|
||||||
|
delete endpointOption.agent;
|
||||||
if (!primaryAgent) {
|
if (!primaryAgent) {
|
||||||
throw new Error('Agent not found');
|
throw new Error('Agent not found');
|
||||||
}
|
}
|
||||||
|
|
@ -108,11 +115,25 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let endpointConfig = req.app.locals[primaryConfig.endpoint];
|
||||||
|
if (!isAgentsEndpoint(primaryConfig.endpoint) && !endpointConfig) {
|
||||||
|
try {
|
||||||
|
endpointConfig = await getCustomEndpointConfig(primaryConfig.endpoint);
|
||||||
|
} catch (err) {
|
||||||
|
logger.error(
|
||||||
|
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
|
||||||
|
err,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const sender =
|
const sender =
|
||||||
primaryAgent.name ??
|
primaryAgent.name ??
|
||||||
getResponseSender({
|
getResponseSender({
|
||||||
...endpointOption,
|
...endpointOption,
|
||||||
model: endpointOption.model_parameters.model,
|
model: endpointOption.model_parameters.model,
|
||||||
|
modelDisplayLabel: endpointConfig?.modelDisplayLabel,
|
||||||
|
modelLabel: endpointOption.model_parameters.modelLabel,
|
||||||
});
|
});
|
||||||
|
|
||||||
const client = new AgentClient({
|
const client = new AgentClient({
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => {
|
||||||
let timeoutId;
|
let timeoutId;
|
||||||
try {
|
try {
|
||||||
const timeoutPromise = new Promise((_, reject) => {
|
const timeoutPromise = new Promise((_, reject) => {
|
||||||
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
|
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 45000);
|
||||||
}).catch((error) => {
|
}).catch((error) => {
|
||||||
logger.error('Title error:', error);
|
logger.error('Title error:', error);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||||
{
|
{
|
||||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||||
proxy: PROXY ?? null,
|
proxy: PROXY ?? null,
|
||||||
modelOptions: endpointOption.model_parameters,
|
modelOptions: endpointOption?.model_parameters ?? {},
|
||||||
},
|
},
|
||||||
clientOptions,
|
clientOptions,
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -75,9 +75,20 @@ function getLLMConfig(apiKey, options = {}) {
|
||||||
|
|
||||||
if (options.reverseProxyUrl) {
|
if (options.reverseProxyUrl) {
|
||||||
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
|
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
|
||||||
|
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tools = [];
|
||||||
|
|
||||||
|
if (mergedOptions.web_search) {
|
||||||
|
tools.push({
|
||||||
|
type: 'web_search_20250305',
|
||||||
|
name: 'web_search',
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
tools,
|
||||||
/** @type {AnthropicClientOptions} */
|
/** @type {AnthropicClientOptions} */
|
||||||
llmConfig: removeNullishValues(requestOptions),
|
llmConfig: removeNullishValues(requestOptions),
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,45 @@
|
||||||
const { anthropicSettings } = require('librechat-data-provider');
|
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
|
||||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||||
|
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
|
||||||
|
|
||||||
jest.mock('https-proxy-agent', () => ({
|
jest.mock('https-proxy-agent', () => ({
|
||||||
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
|
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('./helpers', () => ({
|
||||||
|
checkPromptCacheSupport: jest.fn(),
|
||||||
|
getClaudeHeaders: jest.fn(),
|
||||||
|
configureReasoning: jest.fn((requestOptions) => requestOptions),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('librechat-data-provider', () => ({
|
||||||
|
anthropicSettings: {
|
||||||
|
model: { default: 'claude-3-opus-20240229' },
|
||||||
|
maxOutputTokens: { default: 4096, reset: jest.fn(() => 4096) },
|
||||||
|
thinking: { default: false },
|
||||||
|
promptCache: { default: false },
|
||||||
|
thinkingBudget: { default: null },
|
||||||
|
},
|
||||||
|
removeNullishValues: jest.fn((obj) => {
|
||||||
|
const result = {};
|
||||||
|
for (const key in obj) {
|
||||||
|
if (obj[key] !== null && obj[key] !== undefined) {
|
||||||
|
result[key] = obj[key];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
describe('getLLMConfig', () => {
|
describe('getLLMConfig', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
checkPromptCacheSupport.mockReturnValue(false);
|
||||||
|
getClaudeHeaders.mockReturnValue(undefined);
|
||||||
|
configureReasoning.mockImplementation((requestOptions) => requestOptions);
|
||||||
|
anthropicSettings.maxOutputTokens.reset.mockReturnValue(4096);
|
||||||
|
});
|
||||||
|
|
||||||
it('should create a basic configuration with default values', () => {
|
it('should create a basic configuration with default values', () => {
|
||||||
const result = getLLMConfig('test-api-key', { modelOptions: {} });
|
const result = getLLMConfig('test-api-key', { modelOptions: {} });
|
||||||
|
|
||||||
|
|
@ -36,6 +70,7 @@ describe('getLLMConfig', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
|
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
|
||||||
|
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should include topK and topP for non-Claude-3.7 models', () => {
|
it('should include topK and topP for non-Claude-3.7 models', () => {
|
||||||
|
|
@ -65,6 +100,11 @@ describe('getLLMConfig', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => {
|
it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => {
|
||||||
|
configureReasoning.mockImplementation((requestOptions) => {
|
||||||
|
requestOptions.thinking = { type: 'enabled' };
|
||||||
|
return requestOptions;
|
||||||
|
});
|
||||||
|
|
||||||
const result = getLLMConfig('test-api-key', {
|
const result = getLLMConfig('test-api-key', {
|
||||||
modelOptions: {
|
modelOptions: {
|
||||||
model: 'claude-3-7-sonnet',
|
model: 'claude-3-7-sonnet',
|
||||||
|
|
@ -78,6 +118,11 @@ describe('getLLMConfig', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => {
|
it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => {
|
||||||
|
configureReasoning.mockImplementation((requestOptions) => {
|
||||||
|
requestOptions.thinking = { type: 'enabled' };
|
||||||
|
return requestOptions;
|
||||||
|
});
|
||||||
|
|
||||||
const result = getLLMConfig('test-api-key', {
|
const result = getLLMConfig('test-api-key', {
|
||||||
modelOptions: {
|
modelOptions: {
|
||||||
model: 'claude-3.7-sonnet',
|
model: 'claude-3.7-sonnet',
|
||||||
|
|
@ -154,4 +199,160 @@ describe('getLLMConfig', () => {
|
||||||
expect(result3.llmConfig).toHaveProperty('topK', 10);
|
expect(result3.llmConfig).toHaveProperty('topK', 10);
|
||||||
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
|
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Edge cases', () => {
|
||||||
|
it('should handle missing apiKey', () => {
|
||||||
|
const result = getLLMConfig(undefined, { modelOptions: {} });
|
||||||
|
expect(result.llmConfig).not.toHaveProperty('apiKey');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty modelOptions', () => {
|
||||||
|
expect(() => {
|
||||||
|
getLLMConfig('test-api-key', {});
|
||||||
|
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle no options parameter', () => {
|
||||||
|
expect(() => {
|
||||||
|
getLLMConfig('test-api-key');
|
||||||
|
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle temperature, stop sequences, and stream settings', () => {
|
||||||
|
const result = getLLMConfig('test-api-key', {
|
||||||
|
modelOptions: {
|
||||||
|
temperature: 0.7,
|
||||||
|
stop: ['\n\n', 'END'],
|
||||||
|
stream: false,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
|
||||||
|
expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']);
|
||||||
|
expect(result.llmConfig).toHaveProperty('stream', false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle maxOutputTokens when explicitly set to falsy value', () => {
|
||||||
|
anthropicSettings.maxOutputTokens.reset.mockReturnValue(8192);
|
||||||
|
const result = getLLMConfig('test-api-key', {
|
||||||
|
modelOptions: {
|
||||||
|
model: 'claude-3-opus',
|
||||||
|
maxOutputTokens: null,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(anthropicSettings.maxOutputTokens.reset).toHaveBeenCalledWith('claude-3-opus');
|
||||||
|
expect(result.llmConfig).toHaveProperty('maxTokens', 8192);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle both proxy and reverseProxyUrl', () => {
|
||||||
|
const result = getLLMConfig('test-api-key', {
|
||||||
|
modelOptions: {},
|
||||||
|
proxy: 'http://proxy:8080',
|
||||||
|
reverseProxyUrl: 'https://reverse-proxy.com',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
|
||||||
|
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
|
||||||
|
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
|
||||||
|
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
|
||||||
|
'ProxyAgent',
|
||||||
|
);
|
||||||
|
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com');
|
||||||
|
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle prompt cache with supported model', () => {
|
||||||
|
checkPromptCacheSupport.mockReturnValue(true);
|
||||||
|
getClaudeHeaders.mockReturnValue({ 'anthropic-beta': 'prompt-caching-2024-07-31' });
|
||||||
|
|
||||||
|
const result = getLLMConfig('test-api-key', {
|
||||||
|
modelOptions: {
|
||||||
|
model: 'claude-3-5-sonnet',
|
||||||
|
promptCache: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(checkPromptCacheSupport).toHaveBeenCalledWith('claude-3-5-sonnet');
|
||||||
|
expect(getClaudeHeaders).toHaveBeenCalledWith('claude-3-5-sonnet', true);
|
||||||
|
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
|
||||||
|
'anthropic-beta': 'prompt-caching-2024-07-31',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle thinking and thinkingBudget options', () => {
|
||||||
|
configureReasoning.mockImplementation((requestOptions, systemOptions) => {
|
||||||
|
if (systemOptions.thinking) {
|
||||||
|
requestOptions.thinking = { type: 'enabled' };
|
||||||
|
}
|
||||||
|
if (systemOptions.thinkingBudget) {
|
||||||
|
requestOptions.thinking = {
|
||||||
|
...requestOptions.thinking,
|
||||||
|
budget_tokens: systemOptions.thinkingBudget,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return requestOptions;
|
||||||
|
});
|
||||||
|
|
||||||
|
getLLMConfig('test-api-key', {
|
||||||
|
modelOptions: {
|
||||||
|
model: 'claude-3-7-sonnet',
|
||||||
|
thinking: true,
|
||||||
|
thinkingBudget: 5000,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(configureReasoning).toHaveBeenCalledWith(
|
||||||
|
expect.any(Object),
|
||||||
|
expect.objectContaining({
|
||||||
|
thinking: true,
|
||||||
|
promptCache: false,
|
||||||
|
thinkingBudget: 5000,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should remove system options from modelOptions', () => {
|
||||||
|
const modelOptions = {
|
||||||
|
model: 'claude-3-opus',
|
||||||
|
thinking: true,
|
||||||
|
promptCache: true,
|
||||||
|
thinkingBudget: 1000,
|
||||||
|
temperature: 0.5,
|
||||||
|
};
|
||||||
|
|
||||||
|
getLLMConfig('test-api-key', { modelOptions });
|
||||||
|
|
||||||
|
expect(modelOptions).not.toHaveProperty('thinking');
|
||||||
|
expect(modelOptions).not.toHaveProperty('promptCache');
|
||||||
|
expect(modelOptions).not.toHaveProperty('thinkingBudget');
|
||||||
|
expect(modelOptions).toHaveProperty('temperature', 0.5);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle all nullish values removal', () => {
|
||||||
|
removeNullishValues.mockImplementation((obj) => {
|
||||||
|
const cleaned = {};
|
||||||
|
Object.entries(obj).forEach(([key, value]) => {
|
||||||
|
if (value !== null && value !== undefined) {
|
||||||
|
cleaned[key] = value;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return cleaned;
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = getLLMConfig('test-api-key', {
|
||||||
|
modelOptions: {
|
||||||
|
temperature: null,
|
||||||
|
topP: undefined,
|
||||||
|
topK: 0,
|
||||||
|
stop: [],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||||
|
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||||
|
expect(result.llmConfig).toHaveProperty('topK', 0);
|
||||||
|
expect(result.llmConfig).toHaveProperty('stopSequences', []);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,7 @@
|
||||||
const OpenAI = require('openai');
|
const OpenAI = require('openai');
|
||||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||||
const { constructAzureURL, isUserProvided } = require('@librechat/api');
|
const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api');
|
||||||
const {
|
const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider');
|
||||||
ErrorTypes,
|
|
||||||
EModelEndpoint,
|
|
||||||
resolveHeaders,
|
|
||||||
mapModelToAzureConfig,
|
|
||||||
} = require('librechat-data-provider');
|
|
||||||
const {
|
const {
|
||||||
getUserKeyValues,
|
getUserKeyValues,
|
||||||
getUserKeyExpiry,
|
getUserKeyExpiry,
|
||||||
|
|
@ -114,11 +109,14 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
|
||||||
|
|
||||||
apiKey = azureOptions.azureOpenAIApiKey;
|
apiKey = azureOptions.azureOpenAIApiKey;
|
||||||
opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
|
opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
|
||||||
opts.defaultHeaders = resolveHeaders({
|
opts.defaultHeaders = resolveHeaders(
|
||||||
...headers,
|
{
|
||||||
'api-key': apiKey,
|
...headers,
|
||||||
'OpenAI-Beta': `assistants=${version}`,
|
'api-key': apiKey,
|
||||||
});
|
'OpenAI-Beta': `assistants=${version}`,
|
||||||
|
},
|
||||||
|
req.user,
|
||||||
|
);
|
||||||
opts.model = azureOptions.azureOpenAIApiDeploymentName;
|
opts.model = azureOptions.azureOpenAIApiDeploymentName;
|
||||||
|
|
||||||
if (initAppClient) {
|
if (initAppClient) {
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
||||||
|
|
||||||
/** @type {BedrockClientOptions} */
|
/** @type {BedrockClientOptions} */
|
||||||
const requestOptions = {
|
const requestOptions = {
|
||||||
model: overrideModel ?? endpointOption.model,
|
model: overrideModel ?? endpointOption?.model,
|
||||||
region: BEDROCK_AWS_DEFAULT_REGION,
|
region: BEDROCK_AWS_DEFAULT_REGION,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -76,7 +76,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
||||||
|
|
||||||
const llmConfig = bedrockOutputParser(
|
const llmConfig = bedrockOutputParser(
|
||||||
bedrockInputParser.parse(
|
bedrockInputParser.parse(
|
||||||
removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
|
removeNullishValues(Object.assign(requestOptions, endpointOption?.model_parameters ?? {})),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ const {
|
||||||
extractEnvVariable,
|
extractEnvVariable,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const { Providers } = require('@librechat/agents');
|
const { Providers } = require('@librechat/agents');
|
||||||
const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api');
|
const { getOpenAIConfig, createHandleLLMNewToken, resolveHeaders } = require('@librechat/api');
|
||||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||||
const { fetchModels } = require('~/server/services/ModelService');
|
const { fetchModels } = require('~/server/services/ModelService');
|
||||||
|
|
@ -28,12 +28,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||||
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
|
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
|
||||||
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
|
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
|
||||||
|
|
||||||
let resolvedHeaders = {};
|
let resolvedHeaders = resolveHeaders(endpointConfig.headers, req.user);
|
||||||
if (endpointConfig.headers && typeof endpointConfig.headers === 'object') {
|
|
||||||
Object.keys(endpointConfig.headers).forEach((key) => {
|
|
||||||
resolvedHeaders[key] = extractEnvVariable(endpointConfig.headers[key]);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (CUSTOM_API_KEY.match(envVarRegex)) {
|
if (CUSTOM_API_KEY.match(envVarRegex)) {
|
||||||
throw new Error(`Missing API Key for ${endpoint}.`);
|
throw new Error(`Missing API Key for ${endpoint}.`);
|
||||||
|
|
@ -134,7 +129,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||||
};
|
};
|
||||||
|
|
||||||
if (optionsOnly) {
|
if (optionsOnly) {
|
||||||
const modelOptions = endpointOption.model_parameters;
|
const modelOptions = endpointOption?.model_parameters ?? {};
|
||||||
if (endpoint !== Providers.OLLAMA) {
|
if (endpoint !== Providers.OLLAMA) {
|
||||||
clientOptions = Object.assign(
|
clientOptions = Object.assign(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue