mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-24 19:26:14 +01:00
Merge branch 'main' into feature/entra-id-azure-integration
This commit is contained in:
commit
af661b1df2
293 changed files with 20207 additions and 13884 deletions
10
api/server/services/Config/__tests__/getCachedTools.spec.js
Normal file
10
api/server/services/Config/__tests__/getCachedTools.spec.js
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
const { ToolCacheKeys } = require('../getCachedTools');
|
||||
|
||||
describe('getCachedTools - Cache Isolation Security', () => {
|
||||
describe('ToolCacheKeys.MCP_SERVER', () => {
|
||||
it('should generate cache keys that include userId', () => {
|
||||
const key = ToolCacheKeys.MCP_SERVER('user123', 'github');
|
||||
expect(key).toBe('tools:mcp:user123:github');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores');
|
|||
const ToolCacheKeys = {
|
||||
/** Global tools available to all users */
|
||||
GLOBAL: 'tools:global',
|
||||
/** MCP tools cached by server name */
|
||||
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
|
||||
/** MCP tools cached by user ID and server name */
|
||||
MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`,
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves available tools from cache
|
||||
* @function getCachedTools
|
||||
* @param {Object} options - Options for retrieving tools
|
||||
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||
* @param {string} [options.serverName] - MCP server name to get cached tools for
|
||||
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
|
||||
*/
|
||||
async function getCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { serverName } = options;
|
||||
const { userId, serverName } = options;
|
||||
|
||||
// Return MCP server-specific tools if requested
|
||||
if (serverName) {
|
||||
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
|
||||
if (serverName && userId) {
|
||||
return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
|
||||
}
|
||||
|
||||
// Default to global tools
|
||||
|
|
@ -36,17 +37,18 @@ async function getCachedTools(options = {}) {
|
|||
* @function setCachedTools
|
||||
* @param {Object} tools - The tools object to cache
|
||||
* @param {Object} options - Options for caching tools
|
||||
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||
* @returns {Promise<boolean>} Whether the operation was successful
|
||||
*/
|
||||
async function setCachedTools(tools, options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { serverName, ttl } = options;
|
||||
const { userId, serverName, ttl } = options;
|
||||
|
||||
// Cache by MCP server if specified
|
||||
if (serverName) {
|
||||
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
|
||||
// Cache by MCP server if specified (requires userId)
|
||||
if (serverName && userId) {
|
||||
return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl);
|
||||
}
|
||||
|
||||
// Default to global cache
|
||||
|
|
@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) {
|
|||
* Invalidates cached tools
|
||||
* @function invalidateCachedTools
|
||||
* @param {Object} options - Options for invalidating tools
|
||||
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||
* @param {string} [options.serverName] - MCP server name to invalidate
|
||||
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function invalidateCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { serverName, invalidateGlobal = false } = options;
|
||||
const { userId, serverName, invalidateGlobal = false } = options;
|
||||
|
||||
const keysToDelete = [];
|
||||
|
||||
|
|
@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) {
|
|||
keysToDelete.push(ToolCacheKeys.GLOBAL);
|
||||
}
|
||||
|
||||
if (serverName) {
|
||||
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
|
||||
if (serverName && userId) {
|
||||
keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName));
|
||||
}
|
||||
|
||||
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets MCP tools for a specific server from cache or merges with global tools
|
||||
* Gets MCP tools for a specific server from cache
|
||||
* @function getMCPServerTools
|
||||
* @param {string} userId - The user ID
|
||||
* @param {string} serverName - The MCP server name
|
||||
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
|
||||
*/
|
||||
async function getMCPServerTools(serverName) {
|
||||
async function getMCPServerTools(userId, serverName) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
|
||||
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
|
||||
|
||||
if (serverTools) {
|
||||
return serverTools;
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ async function getEndpointsConfig(req) {
|
|||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint);
|
||||
const isAgents = isAgentsEndpoint(req.body?.endpointType || req.body?.endpoint);
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities =
|
||||
isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
const { isUserProvided, normalizeEndpointName } = require('@librechat/api');
|
||||
const { EModelEndpoint, extractEnvVariable } = require('librechat-data-provider');
|
||||
const { isUserProvided } = require('@librechat/api');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
extractEnvVariable,
|
||||
normalizeEndpointName,
|
||||
} = require('librechat-data-provider');
|
||||
const { fetchModels } = require('~/server/services/ModelService');
|
||||
const { getAppConfig } = require('./app');
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache');
|
|||
/**
|
||||
* Updates MCP tools in the cache for a specific server
|
||||
* @param {Object} params - Parameters for updating MCP tools
|
||||
* @param {string} params.userId - User ID for user-specific caching
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||
* @returns {Promise<LCAvailableTools>}
|
||||
*/
|
||||
async function updateMCPServerTools({ serverName, tools }) {
|
||||
async function updateMCPServerTools({ userId, serverName, tools }) {
|
||||
try {
|
||||
const serverTools = {};
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
|
|
@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) {
|
|||
};
|
||||
}
|
||||
|
||||
await setCachedTools(serverTools, { serverName });
|
||||
await setCachedTools(serverTools, { userId, serverName });
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
|
||||
logger.debug(
|
||||
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
|
||||
);
|
||||
return serverTools;
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
|
@ -65,21 +68,22 @@ async function mergeAppTools(appTools) {
|
|||
/**
|
||||
* Caches MCP server tools (no longer merges with global)
|
||||
* @param {object} params
|
||||
* @param {string} params.userId - User ID for user-specific caching
|
||||
* @param {string} params.serverName
|
||||
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function cacheMCPServerTools({ serverName, serverTools }) {
|
||||
async function cacheMCPServerTools({ userId, serverName, serverTools }) {
|
||||
try {
|
||||
const count = Object.keys(serverTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
// Only cache server-specific tools, no merging with global
|
||||
await setCachedTools(serverTools, { serverName });
|
||||
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
|
||||
await setCachedTools(serverTools, { userId, serverName });
|
||||
logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
|
||||
logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,12 +3,14 @@ const {
|
|||
primeResources,
|
||||
getModelMaxTokens,
|
||||
extractLibreChatParams,
|
||||
filterFilesByEndpointConfig,
|
||||
optionalChainWithEmptyCheck,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
paramEndpoints,
|
||||
isAgentsEndpoint,
|
||||
replaceSpecialVars,
|
||||
providerEndpointMap,
|
||||
|
|
@ -71,6 +73,9 @@ const initializeAgent = async ({
|
|||
|
||||
const { resendFiles, maxContextTokens, modelOptions } = extractLibreChatParams(_modelOptions);
|
||||
|
||||
const provider = agent.provider;
|
||||
agent.endpoint = provider;
|
||||
|
||||
if (isInitialAgent && conversationId != null && resendFiles) {
|
||||
const fileIds = (await getConvoFiles(conversationId)) ?? [];
|
||||
/** @type {Set<EToolResources>} */
|
||||
|
|
@ -88,6 +93,19 @@ const initializeAgent = async ({
|
|||
currentFiles = await processFiles(requestFiles);
|
||||
}
|
||||
|
||||
if (currentFiles && currentFiles.length) {
|
||||
let endpointType;
|
||||
if (!paramEndpoints.has(agent.endpoint)) {
|
||||
endpointType = EModelEndpoint.custom;
|
||||
}
|
||||
|
||||
currentFiles = filterFilesByEndpointConfig(req, {
|
||||
files: currentFiles,
|
||||
endpoint: agent.endpoint,
|
||||
endpointType,
|
||||
});
|
||||
}
|
||||
|
||||
const { attachments, tool_resources } = await primeResources({
|
||||
req,
|
||||
getFiles,
|
||||
|
|
@ -98,7 +116,6 @@ const initializeAgent = async ({
|
|||
requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const {
|
||||
tools: structuredTools,
|
||||
toolContextMap,
|
||||
|
|
@ -113,7 +130,6 @@ const initializeAgent = async ({
|
|||
tool_resources,
|
||||
})) ?? {};
|
||||
|
||||
agent.endpoint = provider;
|
||||
const { getOptions, overrideProvider } = getProviderConfig({ provider, appConfig });
|
||||
if (overrideProvider !== agent.provider) {
|
||||
agent.provider = overrideProvider;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const { validateAgentModel, getCustomEndpointConfig } = require('@librechat/api');
|
||||
const {
|
||||
validateAgentModel,
|
||||
getCustomEndpointConfig,
|
||||
createSequentialChainEdges,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
|
|
@ -119,44 +123,90 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
|
||||
const agent_ids = primaryConfig.agent_ids;
|
||||
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
|
||||
if (agent_ids?.length) {
|
||||
for (const agentId of agent_ids) {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
if (!agent) {
|
||||
throw new Error(`Agent ${agentId} not found`);
|
||||
|
||||
async function processAgent(agentId) {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
if (!agent) {
|
||||
throw new Error(`Agent ${agentId} not found`);
|
||||
}
|
||||
|
||||
const validationResult = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
});
|
||||
|
||||
if (!validationResult.isValid) {
|
||||
throw new Error(validationResult.error?.message);
|
||||
}
|
||||
|
||||
const config = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
});
|
||||
if (userMCPAuthMap != null) {
|
||||
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
|
||||
} else {
|
||||
userMCPAuthMap = config.userMCPAuthMap;
|
||||
}
|
||||
agentConfigs.set(agentId, config);
|
||||
}
|
||||
|
||||
let edges = primaryConfig.edges;
|
||||
const checkAgentInit = (agentId) => agentId === primaryConfig.id || agentConfigs.has(agentId);
|
||||
if ((edges?.length ?? 0) > 0) {
|
||||
for (const edge of edges) {
|
||||
if (Array.isArray(edge.to)) {
|
||||
for (const to of edge.to) {
|
||||
if (checkAgentInit(to)) {
|
||||
continue;
|
||||
}
|
||||
await processAgent(to);
|
||||
}
|
||||
} else if (typeof edge.to === 'string' && checkAgentInit(edge.to)) {
|
||||
continue;
|
||||
} else if (typeof edge.to === 'string') {
|
||||
await processAgent(edge.to);
|
||||
}
|
||||
|
||||
const validationResult = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
});
|
||||
|
||||
if (!validationResult.isValid) {
|
||||
throw new Error(validationResult.error?.message);
|
||||
if (Array.isArray(edge.from)) {
|
||||
for (const from of edge.from) {
|
||||
if (checkAgentInit(from)) {
|
||||
continue;
|
||||
}
|
||||
await processAgent(from);
|
||||
}
|
||||
} else if (typeof edge.from === 'string' && checkAgentInit(edge.from)) {
|
||||
continue;
|
||||
} else if (typeof edge.from === 'string') {
|
||||
await processAgent(edge.from);
|
||||
}
|
||||
|
||||
const config = await initializeAgent({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
});
|
||||
if (userMCPAuthMap != null) {
|
||||
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
|
||||
} else {
|
||||
userMCPAuthMap = config.userMCPAuthMap;
|
||||
}
|
||||
agentConfigs.set(agentId, config);
|
||||
}
|
||||
}
|
||||
|
||||
/** @deprecated Agent Chain */
|
||||
if (agent_ids?.length) {
|
||||
for (const agentId of agent_ids) {
|
||||
if (checkAgentInit(agentId)) {
|
||||
continue;
|
||||
}
|
||||
await processAgent(agentId);
|
||||
}
|
||||
|
||||
const chain = await createSequentialChainEdges([primaryConfig.id].concat(agent_ids), '{convo}');
|
||||
edges = edges ? edges.concat(chain) : chain;
|
||||
}
|
||||
|
||||
primaryConfig.edges = edges;
|
||||
|
||||
let endpointConfig = appConfig.endpoints?.[primaryConfig.endpoint];
|
||||
if (!isAgentsEndpoint(primaryConfig.endpoint) && !endpointConfig) {
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -27,13 +27,13 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
|||
const anthropicConfig = appConfig.endpoints?.[EModelEndpoint.anthropic];
|
||||
|
||||
if (anthropicConfig) {
|
||||
clientOptions.streamRate = anthropicConfig.streamRate;
|
||||
clientOptions._lc_stream_delay = anthropicConfig.streamRate;
|
||||
clientOptions.titleModel = anthropicConfig.titleModel;
|
||||
}
|
||||
|
||||
const allConfig = appConfig.endpoints?.all;
|
||||
if (allConfig) {
|
||||
clientOptions.streamRate = allConfig.streamRate;
|
||||
clientOptions._lc_stream_delay = allConfig.streamRate;
|
||||
}
|
||||
|
||||
if (optionsOnly) {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { createHandleLLMNewToken } = require('@librechat/api');
|
||||
const {
|
||||
AuthType,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
bedrockInputParser,
|
||||
bedrockOutputParser,
|
||||
|
|
@ -11,7 +9,6 @@ const {
|
|||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
|
||||
const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
||||
const appConfig = req.config;
|
||||
const {
|
||||
BEDROCK_AWS_SECRET_ACCESS_KEY,
|
||||
BEDROCK_AWS_ACCESS_KEY_ID,
|
||||
|
|
@ -47,10 +44,12 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
|||
checkUserKeyExpiry(expiresAt, EModelEndpoint.bedrock);
|
||||
}
|
||||
|
||||
/** @type {number} */
|
||||
/*
|
||||
Callback for stream rate no longer awaits and may end the stream prematurely
|
||||
/** @type {number}
|
||||
let streamRate = Constants.DEFAULT_STREAM_RATE;
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
/** @type {undefined | TBaseEndpoint}
|
||||
const bedrockConfig = appConfig.endpoints?.[EModelEndpoint.bedrock];
|
||||
|
||||
if (bedrockConfig && bedrockConfig.streamRate) {
|
||||
|
|
@ -61,6 +60,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
|||
if (allConfig && allConfig.streamRate) {
|
||||
streamRate = allConfig.streamRate;
|
||||
}
|
||||
*/
|
||||
|
||||
/** @type {BedrockClientOptions} */
|
||||
const requestOptions = {
|
||||
|
|
@ -88,12 +88,6 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
|
|||
llmConfig.endpointHost = BEDROCK_REVERSE_PROXY;
|
||||
}
|
||||
|
||||
llmConfig.callbacks = [
|
||||
{
|
||||
handleLLMNewToken: createHandleLLMNewToken(streamRate),
|
||||
},
|
||||
];
|
||||
|
||||
return {
|
||||
/** @type {BedrockClientOptions} */
|
||||
llmConfig,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ const {
|
|||
isUserProvided,
|
||||
getOpenAIConfig,
|
||||
getCustomEndpointConfig,
|
||||
createHandleLLMNewToken,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
CacheKeys,
|
||||
|
|
@ -157,11 +156,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
|||
if (!clientOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
options.llmConfig.callbacks = [
|
||||
{
|
||||
handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate),
|
||||
},
|
||||
];
|
||||
options.llmConfig._lc_stream_delay = clientOptions.streamRate;
|
||||
return options;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ jest.mock('@librechat/api', () => ({
|
|||
...jest.requireActual('@librechat/api'),
|
||||
resolveHeaders: jest.fn(),
|
||||
getOpenAIConfig: jest.fn(),
|
||||
createHandleLLMNewToken: jest.fn(),
|
||||
getCustomEndpointConfig: jest.fn().mockReturnValue({
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'https://test.com',
|
||||
|
|
|
|||
|
|
@ -5,9 +5,7 @@ const {
|
|||
isUserProvided,
|
||||
getOpenAIConfig,
|
||||
getAzureCredentials,
|
||||
createHandleLLMNewToken,
|
||||
shouldUseEntraId,
|
||||
getEntraIdAccessToken,
|
||||
} = require('@librechat/api');
|
||||
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
|
|
@ -167,11 +165,7 @@ const initializeClient = async ({
|
|||
if (!streamRate) {
|
||||
return options;
|
||||
}
|
||||
options.llmConfig.callbacks = [
|
||||
{
|
||||
handleLLMNewToken: createHandleLLMNewToken(streamRate),
|
||||
},
|
||||
];
|
||||
options.llmConfig._lc_stream_delay = streamRate;
|
||||
return options;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -227,7 +227,6 @@ class STTService {
|
|||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
...(apiKey && { 'api-key': apiKey }),
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
const axios = require('axios');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logAxiosError, validateImage } = require('@librechat/api');
|
||||
const {
|
||||
FileSources,
|
||||
VisionModes,
|
||||
ImageDetail,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
mergeFileConfig,
|
||||
getEndpointFileConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
|
||||
|
|
@ -84,11 +86,15 @@ const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]);
|
|||
* Encodes and formats the given files.
|
||||
* @param {ServerRequest} req - The request object.
|
||||
* @param {Array<MongoFile>} files - The array of files to encode and format.
|
||||
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
|
||||
* @param {object} params - Object containing provider/endpoint information
|
||||
* @param {Providers | EModelEndpoint | string} [params.provider] - The provider for the image
|
||||
* @param {string} [params.endpoint] - Optional: The endpoint for the image
|
||||
* @param {string} [mode] - Optional: The endpoint mode for the image.
|
||||
* @returns {Promise<{ files: MongoFile[]; image_urls: MessageContentImageUrl[] }>} - A promise that resolves to the result object containing the encoded images and file details.
|
||||
*/
|
||||
async function encodeAndFormat(req, files, endpoint, mode) {
|
||||
async function encodeAndFormat(req, files, params, mode) {
|
||||
const { provider, endpoint } = params;
|
||||
const effectiveEndpoint = endpoint ?? provider;
|
||||
const promises = [];
|
||||
/** @type {Record<FileSources, Pick<ReturnType<typeof getStrategyFunctions>, 'prepareImagePayload' | 'getDownloadStream'>>} */
|
||||
const encodingMethods = {};
|
||||
|
|
@ -134,7 +140,7 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
} catch (error) {
|
||||
logger.error('Error processing image from blob storage:', error);
|
||||
}
|
||||
} else if (source !== FileSources.local && base64Only.has(endpoint)) {
|
||||
} else if (source !== FileSources.local && base64Only.has(effectiveEndpoint)) {
|
||||
const [_file, imageURL] = await preparePayload(req, file);
|
||||
promises.push([_file, await fetchImageToBase64(imageURL)]);
|
||||
continue;
|
||||
|
|
@ -148,6 +154,17 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
const formattedImages = await Promise.all(promises);
|
||||
promises.length = 0;
|
||||
|
||||
/** Extract configured file size limit from fileConfig for this endpoint */
|
||||
let configuredFileSizeLimit;
|
||||
if (req.config?.fileConfig) {
|
||||
const fileConfig = mergeFileConfig(req.config.fileConfig);
|
||||
const endpointConfig = getEndpointFileConfig({
|
||||
fileConfig,
|
||||
endpoint: effectiveEndpoint,
|
||||
});
|
||||
configuredFileSizeLimit = endpointConfig?.fileSizeLimit;
|
||||
}
|
||||
|
||||
for (const [file, imageContent] of formattedImages) {
|
||||
const fileMetadata = {
|
||||
type: file.type,
|
||||
|
|
@ -168,6 +185,26 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
continue;
|
||||
}
|
||||
|
||||
/** Validate image buffer against size limits */
|
||||
if (file.height && file.width) {
|
||||
const imageBuffer = imageContent.startsWith('http')
|
||||
? null
|
||||
: Buffer.from(imageContent, 'base64');
|
||||
|
||||
if (imageBuffer) {
|
||||
const validation = await validateImage(
|
||||
imageBuffer,
|
||||
imageBuffer.length,
|
||||
effectiveEndpoint,
|
||||
configuredFileSizeLimit,
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Image validation failed for ${file.filename}: ${validation.error}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const imagePart = {
|
||||
type: ContentTypes.IMAGE_URL,
|
||||
image_url: {
|
||||
|
|
@ -184,15 +221,19 @@ async function encodeAndFormat(req, files, endpoint, mode) {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (endpoint && endpoint === EModelEndpoint.google && mode === VisionModes.generative) {
|
||||
if (
|
||||
effectiveEndpoint &&
|
||||
effectiveEndpoint === EModelEndpoint.google &&
|
||||
mode === VisionModes.generative
|
||||
) {
|
||||
delete imagePart.image_url;
|
||||
imagePart.inlineData = {
|
||||
mimeType: file.type,
|
||||
data: imageContent,
|
||||
};
|
||||
} else if (endpoint && endpoint === EModelEndpoint.google) {
|
||||
} else if (effectiveEndpoint && effectiveEndpoint === EModelEndpoint.google) {
|
||||
imagePart.image_url = imagePart.image_url.url;
|
||||
} else if (endpoint && endpoint === EModelEndpoint.anthropic) {
|
||||
} else if (effectiveEndpoint && effectiveEndpoint === EModelEndpoint.anthropic) {
|
||||
imagePart.type = 'image';
|
||||
imagePart.source = {
|
||||
type: 'base64',
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ const {
|
|||
checkOpenAIStorage,
|
||||
removeNullishValues,
|
||||
isAssistantsEndpoint,
|
||||
getEndpointFileConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { EnvVar } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
|
|
@ -994,7 +995,7 @@ async function saveBase64Image(
|
|||
*/
|
||||
function filterFile({ req, image, isAvatar }) {
|
||||
const { file } = req;
|
||||
const { endpoint, file_id, width, height } = req.body;
|
||||
const { endpoint, endpointType, file_id, width, height } = req.body;
|
||||
|
||||
if (!file_id && !isAvatar) {
|
||||
throw new Error('No file_id provided');
|
||||
|
|
@ -1016,9 +1017,13 @@ function filterFile({ req, image, isAvatar }) {
|
|||
const appConfig = req.config;
|
||||
const fileConfig = mergeFileConfig(appConfig.fileConfig);
|
||||
|
||||
const { fileSizeLimit: sizeLimit, supportedMimeTypes } =
|
||||
fileConfig.endpoints[endpoint] ?? fileConfig.endpoints.default;
|
||||
const fileSizeLimit = isAvatar === true ? fileConfig.avatarSizeLimit : sizeLimit;
|
||||
const endpointFileConfig = getEndpointFileConfig({
|
||||
endpoint,
|
||||
fileConfig,
|
||||
endpointType,
|
||||
});
|
||||
const fileSizeLimit =
|
||||
isAvatar === true ? fileConfig.avatarSizeLimit : endpointFileConfig.fileSizeLimit;
|
||||
|
||||
if (file.size > fileSizeLimit) {
|
||||
throw new Error(
|
||||
|
|
@ -1028,7 +1033,10 @@ function filterFile({ req, image, isAvatar }) {
|
|||
);
|
||||
}
|
||||
|
||||
const isSupportedMimeType = fileConfig.checkType(file.mimetype, supportedMimeTypes);
|
||||
const isSupportedMimeType = fileConfig.checkType(
|
||||
file.mimetype,
|
||||
endpointFileConfig.supportedMimeTypes,
|
||||
);
|
||||
|
||||
if (!isSupportedMimeType) {
|
||||
throw new Error('Unsupported file type');
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ const { findToken, createToken, updateToken } = require('~/models');
|
|||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
|
|
@ -450,7 +451,7 @@ async function getMCPSetupData(userId) {
|
|||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||
}
|
||||
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
const oauthServers = await mcpServersRegistry.getOAuthServers();
|
||||
|
||||
return {
|
||||
mcpConfig,
|
||||
|
|
|
|||
|
|
@ -50,6 +50,9 @@ jest.mock('@librechat/api', () => ({
|
|||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
mcpServersRegistry: {
|
||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
|
|
@ -100,6 +103,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
let mockGetOAuthReconnectionManager;
|
||||
let mockMcpServersRegistry;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
|
@ -108,6 +112,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||
mockMcpServersRegistry = require('@librechat/api').mcpServersRegistry;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
|
|
@ -125,8 +130,8 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
mockGetMCPManager.mockReturnValue({
|
||||
appConnections: { getAll: jest.fn(() => new Map()) },
|
||||
getUserConnections: jest.fn(() => new Map()),
|
||||
getOAuthServers: jest.fn(() => new Set()),
|
||||
});
|
||||
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set());
|
||||
});
|
||||
|
||||
it('should successfully return MCP setup data', async () => {
|
||||
|
|
@ -139,9 +144,9 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => mockAppConnections) },
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
getOAuthServers: jest.fn(() => mockOAuthServers),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(mockOAuthServers);
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
|
|
@ -149,7 +154,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
|
||||
expect(mockMcpServersRegistry.getOAuthServers).toHaveBeenCalled();
|
||||
|
||||
expect(result).toEqual({
|
||||
mcpConfig: mockConfig.mcpServers,
|
||||
|
|
@ -170,9 +175,9 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => null) },
|
||||
getUserConnections: jest.fn(() => null),
|
||||
getOAuthServers: jest.fn(() => new Set()),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set());
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ const {
|
|||
ImageVisionTool,
|
||||
openapiToFunction,
|
||||
AgentCapabilities,
|
||||
validateActionDomain,
|
||||
defaultAgentCapabilities,
|
||||
validateAndParseOpenAPISpec,
|
||||
} = require('librechat-data-provider');
|
||||
|
|
@ -236,12 +237,26 @@ async function processRequiredActions(client, requiredActions) {
|
|||
|
||||
// Validate and parse OpenAPI spec
|
||||
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
|
||||
if (!validationResult.spec) {
|
||||
if (!validationResult.spec || !validationResult.serverUrl) {
|
||||
throw new Error(
|
||||
`Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
|
||||
);
|
||||
}
|
||||
|
||||
// SECURITY: Validate the domain from the spec matches the stored domain
|
||||
// This is defense-in-depth to prevent any stored malicious actions
|
||||
const domainValidation = validateActionDomain(
|
||||
action.metadata.domain,
|
||||
validationResult.serverUrl,
|
||||
);
|
||||
if (!domainValidation.isValid) {
|
||||
logger.error(`Domain mismatch in stored action: ${domainValidation.message}`, {
|
||||
userId: client.req.user.id,
|
||||
action_id: action.action_id,
|
||||
});
|
||||
continue; // Skip this action rather than failing the entire request
|
||||
}
|
||||
|
||||
// Process the OpenAPI spec
|
||||
const { requestBuilders } = openapiToFunction(validationResult.spec);
|
||||
|
||||
|
|
@ -525,10 +540,25 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
|
||||
// Validate and parse OpenAPI spec once per action set
|
||||
const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec);
|
||||
if (!validationResult.spec) {
|
||||
if (!validationResult.spec || !validationResult.serverUrl) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// SECURITY: Validate the domain from the spec matches the stored domain
|
||||
// This is defense-in-depth to prevent any stored malicious actions
|
||||
const domainValidation = validateActionDomain(
|
||||
action.metadata.domain,
|
||||
validationResult.serverUrl,
|
||||
);
|
||||
if (!domainValidation.isValid) {
|
||||
logger.error(`Domain mismatch in stored action: ${domainValidation.message}`, {
|
||||
userId: req.user.id,
|
||||
agent_id: agent.id,
|
||||
action_id: action.action_id,
|
||||
});
|
||||
continue; // Skip this action rather than failing the entire request
|
||||
}
|
||||
|
||||
const encrypted = {
|
||||
oauth_client_id: action.metadata.oauth_client_id,
|
||||
oauth_client_secret: action.metadata.oauth_client_secret,
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ async function reinitMCPServer({
|
|||
if (connection && !oauthRequired) {
|
||||
tools = await connection.fetchTools();
|
||||
availableTools = await updateMCPServerTools({
|
||||
userId: user.id,
|
||||
serverName,
|
||||
tools,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ async function initializeMCPs() {
|
|||
const mcpManager = await createMCPManager(mcpServers);
|
||||
|
||||
try {
|
||||
const mcpTools = mcpManager.getAppToolFunctions() || {};
|
||||
const mcpTools = (await mcpManager.getAppToolFunctions()) || {};
|
||||
await mergeAppTools(mcpTools);
|
||||
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { Tool } = require('@langchain/core/tools');
|
||||
const { Calculator } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { zodToJsonSchema } = require('zod-to-json-schema');
|
||||
const { Tools, ImageVisionTool } = require('librechat-data-provider');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { getToolkitKey, oaiToolkit, ytToolkit } = require('@librechat/api');
|
||||
const { toolkits } = require('~/app/clients/tools/manifest');
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue