mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 08:25:23 +02:00
Merge remote-tracking branch 'origin/main' into dt-conf-logo
This commit is contained in:
commit
063d6fb721
473 changed files with 32771 additions and 15449 deletions
51
.env.example
51
.env.example
|
|
@ -68,6 +68,18 @@ DEBUG_CONSOLE=false
|
|||
# UID=1000
|
||||
# GID=1000
|
||||
|
||||
#==============#
|
||||
# Node Options #
|
||||
#==============#
|
||||
|
||||
# NOTE: NODE_MAX_OLD_SPACE_SIZE is NOT recognized by Node.js directly.
|
||||
# This variable is used as a build argument for Docker or CI/CD workflows,
|
||||
# and is NOT used by Node.js to set the heap size at runtime.
|
||||
# To configure Node.js memory, use NODE_OPTIONS, e.g.:
|
||||
# NODE_OPTIONS="--max-old-space-size=6144"
|
||||
# See: https://nodejs.org/api/cli.html#--max-old-space-sizesize-in-mib
|
||||
NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
#===============#
|
||||
# Configuration #
|
||||
#===============#
|
||||
|
|
@ -112,6 +124,10 @@ ANTHROPIC_API_KEY=user_provided
|
|||
# ANTHROPIC_MODELS=claude-opus-4-20250514,claude-sonnet-4-20250514,claude-3-7-sonnet-20250219,claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307
|
||||
# ANTHROPIC_REVERSE_PROXY=
|
||||
|
||||
# Set to true to use Anthropic models through Google Vertex AI instead of direct API
|
||||
# ANTHROPIC_USE_VERTEX=
|
||||
# ANTHROPIC_VERTEX_REGION=us-east5
|
||||
|
||||
#============#
|
||||
# Azure #
|
||||
#============#
|
||||
|
|
@ -169,8 +185,16 @@ GOOGLE_KEY=user_provided
|
|||
|
||||
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
||||
|
||||
# Google Cloud region for Vertex AI (used by both chat and image generation)
|
||||
# GOOGLE_LOC=us-central1
|
||||
|
||||
# Alternative region env var for Gemini Image Generation
|
||||
# GOOGLE_CLOUD_LOCATION=global
|
||||
|
||||
# Vertex AI Service Account Configuration
|
||||
# Path to your Google Cloud service account JSON file
|
||||
# GOOGLE_SERVICE_KEY_FILE=/path/to/service-account.json
|
||||
|
||||
# Google Safety Settings
|
||||
# NOTE: These settings apply to both Vertex AI and Gemini API (AI Studio)
|
||||
#
|
||||
|
|
@ -190,6 +214,27 @@ GOOGLE_KEY=user_provided
|
|||
# GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH
|
||||
# GOOGLE_SAFETY_CIVIC_INTEGRITY=BLOCK_ONLY_HIGH
|
||||
|
||||
#========================#
|
||||
# Gemini Image Generation #
|
||||
#========================#
|
||||
|
||||
# Gemini Image Generation Tool (for Agents)
|
||||
# Supports multiple authentication methods in priority order:
|
||||
# 1. User-provided API key (via GUI)
|
||||
# 2. GEMINI_API_KEY env var (admin-configured)
|
||||
# 3. GOOGLE_KEY env var (shared with Google chat endpoint)
|
||||
# 4. Vertex AI service account (via GOOGLE_SERVICE_KEY_FILE)
|
||||
|
||||
# Option A: Use dedicated Gemini API key for image generation
|
||||
# GEMINI_API_KEY=your-gemini-api-key
|
||||
|
||||
# Option B: Use Vertex AI (no API key needed, uses service account)
|
||||
# Set this to enable Vertex AI and allow tool without requiring API keys
|
||||
# GEMINI_VERTEX_ENABLED=true
|
||||
|
||||
# Vertex AI model for image generation (defaults to gemini-2.5-flash-image)
|
||||
# GEMINI_IMAGE_MODEL=gemini-2.5-flash-image
|
||||
|
||||
#============#
|
||||
# OpenAI #
|
||||
#============#
|
||||
|
|
@ -248,6 +293,7 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
|
|||
# IMAGE_GEN_OAI_API_KEY= # Create or reuse OpenAI API key for image generation tool
|
||||
# IMAGE_GEN_OAI_BASEURL= # Custom OpenAI base URL for image generation tool
|
||||
# IMAGE_GEN_OAI_AZURE_API_VERSION= # Custom Azure OpenAI deployments
|
||||
# IMAGE_GEN_OAI_MODEL=gpt-image-1 # OpenAI image model (e.g., gpt-image-1, gpt-image-1.5)
|
||||
# IMAGE_GEN_OAI_DESCRIPTION=
|
||||
# IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present
|
||||
# IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present
|
||||
|
|
@ -479,6 +525,8 @@ OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
|
|||
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
|
||||
# Set to true to use the OpenID Connect end session endpoint for logout
|
||||
OPENID_USE_END_SESSION_ENDPOINT=
|
||||
# URL to redirect to after OpenID logout (defaults to ${DOMAIN_CLIENT}/login)
|
||||
OPENID_POST_LOGOUT_REDIRECT_URI=
|
||||
|
||||
#========================#
|
||||
# SharePoint Integration #
|
||||
|
|
@ -656,6 +704,9 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
|||
|
||||
# Enable Redis for caching and session storage
|
||||
# USE_REDIS=true
|
||||
# Enable Redis for resumable LLM streams (defaults to USE_REDIS value if not set)
|
||||
# Set to false to use in-memory storage for streams while keeping Redis for other caches
|
||||
# USE_REDIS_STREAMS=true
|
||||
|
||||
# Single Redis instance
|
||||
# REDIS_URI=redis://127.0.0.1:6379
|
||||
|
|
|
|||
1
.github/workflows/backend-review.yml
vendored
1
.github/workflows/backend-review.yml
vendored
|
|
@ -24,6 +24,7 @@ jobs:
|
|||
BAN_DURATION: ${{ secrets.BAN_DURATION }}
|
||||
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
|
||||
NODE_ENV: CI
|
||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Use Node.js 20.x
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ on:
|
|||
- 'packages/api/src/cache/**'
|
||||
- 'packages/api/src/cluster/**'
|
||||
- 'packages/api/src/mcp/**'
|
||||
- 'packages/api/src/stream/**'
|
||||
- 'redis-config/**'
|
||||
- '.github/workflows/cache-integration-tests.yml'
|
||||
|
||||
|
|
|
|||
4
.github/workflows/frontend-review.yml
vendored
4
.github/workflows/frontend-review.yml
vendored
|
|
@ -16,6 +16,8 @@ jobs:
|
|||
name: Run frontend unit tests on Ubuntu
|
||||
timeout-minutes: 60
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Use Node.js 20.x
|
||||
|
|
@ -38,6 +40,8 @@ jobs:
|
|||
name: Run frontend unit tests on Windows
|
||||
timeout-minutes: 60
|
||||
runs-on: windows-latest
|
||||
env:
|
||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Use Node.js 20.x
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# v0.8.2-rc1
|
||||
# v0.8.2-rc2
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
|
|
@ -14,6 +14,9 @@ ENV LD_PRELOAD=/usr/lib/libjemalloc.so.2
|
|||
COPY --from=ghcr.io/astral-sh/uv:0.9.5-python3.12-alpine /usr/local/bin/uv /usr/local/bin/uvx /bin/
|
||||
RUN uv --version
|
||||
|
||||
# Set configurable max-old-space-size with default
|
||||
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
RUN mkdir -p /app && chown node:node /app
|
||||
WORKDIR /app
|
||||
|
||||
|
|
@ -39,8 +42,8 @@ RUN \
|
|||
COPY --chown=node:node . .
|
||||
|
||||
RUN \
|
||||
# React client build
|
||||
NODE_OPTIONS="--max-old-space-size=2048" npm run frontend; \
|
||||
# React client build with configurable memory
|
||||
NODE_OPTIONS="--max-old-space-size=${NODE_MAX_OLD_SPACE_SIZE}" npm run frontend; \
|
||||
npm prune --production; \
|
||||
npm cache clean --force
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
# Dockerfile.multi
|
||||
# v0.8.2-rc1
|
||||
# v0.8.2-rc2
|
||||
|
||||
# Set configurable max-old-space-size with default
|
||||
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
|
|
@ -7,6 +10,7 @@ FROM node:20-alpine AS base-min
|
|||
RUN apk add --no-cache jemalloc
|
||||
# Set environment variable to use jemalloc
|
||||
ENV LD_PRELOAD=/usr/lib/libjemalloc.so.2
|
||||
|
||||
WORKDIR /app
|
||||
RUN apk --no-cache add curl
|
||||
RUN npm config set fetch-retry-maxtimeout 600000 && \
|
||||
|
|
@ -59,7 +63,8 @@ COPY client ./
|
|||
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
|
||||
COPY --from=client-package-build /app/packages/client/dist /app/packages/client/dist
|
||||
COPY --from=client-package-build /app/packages/client/src /app/packages/client/src
|
||||
ENV NODE_OPTIONS="--max-old-space-size=2048"
|
||||
ARG NODE_MAX_OLD_SPACE_SIZE
|
||||
ENV NODE_OPTIONS="--max-old-space-size=${NODE_MAX_OLD_SPACE_SIZE}"
|
||||
RUN npm run build
|
||||
|
||||
# API setup (including client dist)
|
||||
|
|
@ -79,4 +84,4 @@ COPY --from=client-build /app/client/dist ./client/dist
|
|||
WORKDIR /app/api
|
||||
EXPOSE 3080
|
||||
ENV HOST=0.0.0.0
|
||||
CMD ["node", "server/index.js"]
|
||||
CMD ["node", "server/index.js"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ const {
|
|||
EModelEndpoint,
|
||||
isParamEndpoint,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
supportsBalanceCheck,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -714,7 +715,7 @@ class BaseClient {
|
|||
iconURL: this.options.iconURL,
|
||||
endpoint: this.options.endpoint,
|
||||
...(this.metadata ?? {}),
|
||||
metadata,
|
||||
metadata: Object.keys(metadata ?? {}).length > 0 ? metadata : undefined,
|
||||
};
|
||||
|
||||
if (typeof completion === 'string') {
|
||||
|
|
@ -937,6 +938,7 @@ class BaseClient {
|
|||
throw new Error('User mismatch.');
|
||||
}
|
||||
|
||||
const hasAddedConvo = this.options?.req?.body?.addedConvo != null;
|
||||
const savedMessage = await saveMessage(
|
||||
this.options?.req,
|
||||
{
|
||||
|
|
@ -944,6 +946,7 @@ class BaseClient {
|
|||
endpoint: this.options.endpoint,
|
||||
unfinished: false,
|
||||
user,
|
||||
...(hasAddedConvo && { addedConvo: true }),
|
||||
},
|
||||
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' },
|
||||
);
|
||||
|
|
@ -969,7 +972,7 @@ class BaseClient {
|
|||
const hasNonEphemeralAgent =
|
||||
isAgentsEndpoint(this.options.endpoint) &&
|
||||
endpointOptions?.agent_id &&
|
||||
endpointOptions.agent_id !== Constants.EPHEMERAL_AGENT_ID;
|
||||
!isEphemeralAgentId(endpointOptions.agent_id);
|
||||
if (hasNonEphemeralAgent) {
|
||||
exceptions.add('model');
|
||||
}
|
||||
|
|
@ -1024,7 +1027,8 @@ class BaseClient {
|
|||
* @param {Object} options - The options for the function.
|
||||
* @param {TMessage[]} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property.
|
||||
* @param {string} options.parentMessageId - The ID of the parent message to start the traversal from.
|
||||
* @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. If provided, it will be applied to each message in the resulting array.
|
||||
* @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. Applied conditionally based on mapCondition.
|
||||
* @param {(message: TMessage) => boolean} [options.mapCondition] - An optional function to determine whether mapMethod should be applied to a given message. If not provided and mapMethod is set, mapMethod applies to all messages.
|
||||
* @param {boolean} [options.summary=false] - If set to true, the traversal modifies messages with 'summary' and 'summaryTokenCount' properties and stops at the message with a 'summary' property.
|
||||
* @returns {TMessage[]} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'.
|
||||
*/
|
||||
|
|
@ -1032,6 +1036,7 @@ class BaseClient {
|
|||
messages,
|
||||
parentMessageId,
|
||||
mapMethod = null,
|
||||
mapCondition = null,
|
||||
summary = false,
|
||||
}) {
|
||||
if (!messages || messages.length === 0) {
|
||||
|
|
@ -1066,7 +1071,9 @@ class BaseClient {
|
|||
message.tokenCount = message.summaryTokenCount;
|
||||
}
|
||||
|
||||
orderedMessages.push(message);
|
||||
const shouldMap = mapMethod != null && (mapCondition != null ? mapCondition(message) : true);
|
||||
const processedMessage = shouldMap ? mapMethod(message) : message;
|
||||
orderedMessages.push(processedMessage);
|
||||
|
||||
if (summary && message.summary) {
|
||||
break;
|
||||
|
|
@ -1077,11 +1084,6 @@ class BaseClient {
|
|||
}
|
||||
|
||||
orderedMessages.reverse();
|
||||
|
||||
if (mapMethod) {
|
||||
return orderedMessages.map(mapMethod);
|
||||
}
|
||||
|
||||
return orderedMessages;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ const GoogleSearchAPI = require('./structured/GoogleSearch');
|
|||
const TraversaalSearch = require('./structured/TraversaalSearch');
|
||||
const createOpenAIImageTools = require('./structured/OpenAIImageTools');
|
||||
const TavilySearchResults = require('./structured/TavilySearchResults');
|
||||
const createGeminiImageTool = require('./structured/GeminiImageGen');
|
||||
|
||||
module.exports = {
|
||||
...manifest,
|
||||
|
|
@ -27,4 +28,5 @@ module.exports = {
|
|||
createYouTubeTools,
|
||||
TavilySearchResults,
|
||||
createOpenAIImageTools,
|
||||
createGeminiImageTool,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -179,5 +179,19 @@
|
|||
"description": "Provide your Flux API key from your user profile."
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Gemini Image Tools",
|
||||
"pluginKey": "gemini_image_gen",
|
||||
"toolkit": true,
|
||||
"description": "Generate high-quality images using Google's Gemini Image Models. Supports Gemini API or Vertex AI.",
|
||||
"icon": "assets/gemini_image_gen.svg",
|
||||
"authConfig": [
|
||||
{
|
||||
"authField": "GEMINI_API_KEY||GOOGLE_KEY||GEMINI_VERTEX_ENABLED",
|
||||
"label": "Gemini API Key (Optional if Vertex AI is configured)",
|
||||
"description": "Your Google Gemini API Key from <a href='https://aistudio.google.com/app/apikey' target='_blank'>Google AI Studio</a>. Leave blank if using Vertex AI with service account."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
576
api/app/clients/tools/structured/GeminiImageGen.js
Normal file
576
api/app/clients/tools/structured/GeminiImageGen.js
Normal file
|
|
@ -0,0 +1,576 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { v4 } = require('uuid');
|
||||
const { GoogleGenAI } = require('@google/genai');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
FileContext,
|
||||
ContentTypes,
|
||||
FileSources,
|
||||
EImageOutputType,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
geminiToolkit,
|
||||
loadServiceKey,
|
||||
getBalanceConfig,
|
||||
getTransactionsConfig,
|
||||
} = require('@librechat/api');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
/**
|
||||
* Get the default service key file path (consistent with main Google endpoint)
|
||||
* @returns {string} - The default path to the service key file
|
||||
*/
|
||||
function getDefaultServiceKeyPath() {
|
||||
return (
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE || path.join(process.cwd(), 'api', 'data', 'auth.json')
|
||||
);
|
||||
}
|
||||
|
||||
const displayMessage =
|
||||
"Gemini displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
|
||||
|
||||
/**
|
||||
* Replaces unwanted characters from the input string
|
||||
* @param {string} inputString - The input string to process
|
||||
* @returns {string} - The processed string
|
||||
*/
|
||||
function replaceUnwantedChars(inputString) {
|
||||
return inputString?.replace(/[^\w\s\-_.,!?()]/g, '') || '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and sanitize image format
|
||||
* @param {string} format - The format to validate
|
||||
* @returns {string} - Safe format
|
||||
*/
|
||||
function getSafeFormat(format) {
|
||||
const allowedFormats = ['png', 'jpg', 'jpeg', 'webp', 'gif'];
|
||||
return allowedFormats.includes(format?.toLowerCase()) ? format.toLowerCase() : 'png';
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert image buffer to target format if needed
|
||||
* @param {Buffer} inputBuffer - The input image buffer
|
||||
* @param {string} targetFormat - The target format (png, jpeg, webp)
|
||||
* @returns {Promise<{buffer: Buffer, format: string}>} - Converted buffer and format
|
||||
*/
|
||||
async function convertImageFormat(inputBuffer, targetFormat) {
|
||||
const metadata = await sharp(inputBuffer).metadata();
|
||||
const currentFormat = metadata.format;
|
||||
|
||||
// Normalize format names (jpg -> jpeg)
|
||||
const normalizedTarget = targetFormat === 'jpg' ? 'jpeg' : targetFormat.toLowerCase();
|
||||
const normalizedCurrent = currentFormat === 'jpg' ? 'jpeg' : currentFormat;
|
||||
|
||||
// If already in target format, return as-is
|
||||
if (normalizedCurrent === normalizedTarget) {
|
||||
return { buffer: inputBuffer, format: normalizedTarget };
|
||||
}
|
||||
|
||||
// Convert to target format
|
||||
const convertedBuffer = await sharp(inputBuffer).toFormat(normalizedTarget).toBuffer();
|
||||
return { buffer: convertedBuffer, format: normalizedTarget };
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize Gemini client (supports both Gemini API and Vertex AI)
|
||||
* Priority: API key (from options, resolved by loadAuthValues) > Vertex AI service account
|
||||
* @param {Object} options - Initialization options
|
||||
* @param {string} [options.GEMINI_API_KEY] - Gemini API key (resolved by loadAuthValues)
|
||||
* @param {string} [options.GOOGLE_KEY] - Google API key (resolved by loadAuthValues)
|
||||
* @returns {Promise<GoogleGenAI>} - The initialized client
|
||||
*/
|
||||
async function initializeGeminiClient(options = {}) {
|
||||
const geminiKey = options.GEMINI_API_KEY;
|
||||
if (geminiKey) {
|
||||
logger.debug('[GeminiImageGen] Using Gemini API with GEMINI_API_KEY');
|
||||
return new GoogleGenAI({ apiKey: geminiKey });
|
||||
}
|
||||
|
||||
const googleKey = options.GOOGLE_KEY;
|
||||
if (googleKey) {
|
||||
logger.debug('[GeminiImageGen] Using Gemini API with GOOGLE_KEY');
|
||||
return new GoogleGenAI({ apiKey: googleKey });
|
||||
}
|
||||
|
||||
// Fall back to Vertex AI with service account
|
||||
logger.debug('[GeminiImageGen] Using Vertex AI with service account');
|
||||
const credentialsPath = getDefaultServiceKeyPath();
|
||||
|
||||
// Use loadServiceKey for consistent loading (supports file paths, JSON strings, base64)
|
||||
const serviceKey = await loadServiceKey(credentialsPath);
|
||||
|
||||
if (!serviceKey || !serviceKey.project_id) {
|
||||
throw new Error(
|
||||
'Gemini Image Generation requires one of: user-provided API key, GEMINI_API_KEY or GOOGLE_KEY env var, or a valid Google service account. ' +
|
||||
`Service account file not found or invalid at: ${credentialsPath}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Set GOOGLE_APPLICATION_CREDENTIALS for any Google Cloud SDK dependencies
|
||||
try {
|
||||
await fs.promises.access(credentialsPath);
|
||||
process.env.GOOGLE_APPLICATION_CREDENTIALS = credentialsPath;
|
||||
} catch {
|
||||
// File doesn't exist, skip setting env var
|
||||
}
|
||||
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project: serviceKey.project_id,
|
||||
location: process.env.GOOGLE_LOC || process.env.GOOGLE_CLOUD_LOCATION || 'global',
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Save image to local filesystem
|
||||
* @param {string} base64Data - Base64 encoded image data
|
||||
* @param {string} format - Image format
|
||||
* @param {string} userId - User ID
|
||||
* @returns {Promise<string>} - The relative URL
|
||||
*/
|
||||
async function saveImageLocally(base64Data, format, userId) {
|
||||
const safeFormat = getSafeFormat(format);
|
||||
const safeUserId = userId ? path.basename(userId) : 'default';
|
||||
const imageName = `gemini-img-${v4()}.${safeFormat}`;
|
||||
const userDir = path.join(process.cwd(), 'client/public/images', safeUserId);
|
||||
|
||||
await fs.promises.mkdir(userDir, { recursive: true });
|
||||
|
||||
const filePath = path.join(userDir, imageName);
|
||||
await fs.promises.writeFile(filePath, Buffer.from(base64Data, 'base64'));
|
||||
|
||||
logger.debug('[GeminiImageGen] Image saved locally to:', filePath);
|
||||
return `/images/${safeUserId}/${imageName}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save image to cloud storage
|
||||
* @param {Object} params - Parameters
|
||||
* @returns {Promise<string|null>} - The storage URL or null
|
||||
*/
|
||||
async function saveToCloudStorage({ base64Data, format, processFileURL, fileStrategy, userId }) {
|
||||
if (!processFileURL || !fileStrategy || !userId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const safeFormat = getSafeFormat(format);
|
||||
const safeUserId = path.basename(userId);
|
||||
const dataURL = `data:image/${safeFormat};base64,${base64Data}`;
|
||||
const imageName = `gemini-img-${v4()}.${safeFormat}`;
|
||||
|
||||
const result = await processFileURL({
|
||||
URL: dataURL,
|
||||
basePath: 'images',
|
||||
userId: safeUserId,
|
||||
fileName: imageName,
|
||||
fileStrategy,
|
||||
context: FileContext.image_generation,
|
||||
});
|
||||
|
||||
return result.filepath;
|
||||
} catch (error) {
|
||||
logger.error('[GeminiImageGen] Error saving to cloud storage:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert image files to Gemini inline data format
|
||||
* @param {Object} params - Parameters
|
||||
* @returns {Promise<Array>} - Array of inline data objects
|
||||
*/
|
||||
async function convertImagesToInlineData({ imageFiles, image_ids, req, fileStrategy }) {
|
||||
if (!image_ids || image_ids.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const streamMethods = {};
|
||||
const requestFilesMap = Object.fromEntries(imageFiles.map((f) => [f.file_id, { ...f }]));
|
||||
const orderedFiles = new Array(image_ids.length);
|
||||
const idsToFetch = [];
|
||||
const indexOfMissing = Object.create(null);
|
||||
|
||||
for (let i = 0; i < image_ids.length; i++) {
|
||||
const id = image_ids[i];
|
||||
const file = requestFilesMap[id];
|
||||
if (file) {
|
||||
orderedFiles[i] = file;
|
||||
} else {
|
||||
idsToFetch.push(id);
|
||||
indexOfMissing[id] = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (idsToFetch.length && req?.user?.id) {
|
||||
const fetchedFiles = await getFiles(
|
||||
{
|
||||
user: req.user.id,
|
||||
file_id: { $in: idsToFetch },
|
||||
height: { $exists: true },
|
||||
width: { $exists: true },
|
||||
},
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
for (const file of fetchedFiles) {
|
||||
requestFilesMap[file.file_id] = file;
|
||||
orderedFiles[indexOfMissing[file.file_id]] = file;
|
||||
}
|
||||
}
|
||||
|
||||
const inlineDataArray = [];
|
||||
for (const imageFile of orderedFiles) {
|
||||
if (!imageFile) continue;
|
||||
|
||||
try {
|
||||
const source = imageFile.source || fileStrategy;
|
||||
if (!source) continue;
|
||||
|
||||
let getDownloadStream = streamMethods[source];
|
||||
if (!getDownloadStream) {
|
||||
({ getDownloadStream } = getStrategyFunctions(source));
|
||||
streamMethods[source] = getDownloadStream;
|
||||
}
|
||||
if (!getDownloadStream) continue;
|
||||
|
||||
const stream = await getDownloadStream(req, imageFile.filepath);
|
||||
if (!stream) continue;
|
||||
|
||||
const chunks = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
const buffer = Buffer.concat(chunks);
|
||||
const base64Data = buffer.toString('base64');
|
||||
const mimeType = imageFile.type || 'image/png';
|
||||
|
||||
inlineDataArray.push({
|
||||
inlineData: { mimeType, data: base64Data },
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[GeminiImageGen] Error processing image:', imageFile.file_id, error);
|
||||
}
|
||||
}
|
||||
|
||||
return inlineDataArray;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check for safety blocks in API response
|
||||
* @param {Object} response - The API response
|
||||
* @returns {Object|null} - Safety block info or null
|
||||
*/
|
||||
function checkForSafetyBlock(response) {
|
||||
if (!response?.candidates?.length) {
|
||||
return { reason: 'NO_CANDIDATES', message: 'No candidates returned' };
|
||||
}
|
||||
|
||||
const candidate = response.candidates[0];
|
||||
const finishReason = candidate.finishReason;
|
||||
|
||||
if (finishReason === 'SAFETY' || finishReason === 'PROHIBITED_CONTENT') {
|
||||
return { reason: finishReason, message: 'Content blocked by safety filters' };
|
||||
}
|
||||
|
||||
if (finishReason === 'RECITATION') {
|
||||
return { reason: finishReason, message: 'Content blocked due to recitation concerns' };
|
||||
}
|
||||
|
||||
if (candidate.safetyRatings) {
|
||||
for (const rating of candidate.safetyRatings) {
|
||||
if (rating.probability === 'HIGH' || rating.blocked === true) {
|
||||
return {
|
||||
reason: 'SAFETY_RATING',
|
||||
message: `Blocked due to ${rating.category}`,
|
||||
category: rating.category,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Record token usage for balance tracking
|
||||
* @param {Object} params - Parameters
|
||||
* @param {Object} params.usageMetadata - The usage metadata from API response
|
||||
* @param {Object} params.req - The request object
|
||||
* @param {string} params.userId - The user ID
|
||||
* @param {string} params.conversationId - The conversation ID
|
||||
* @param {string} params.model - The model name
|
||||
*/
|
||||
async function recordTokenUsage({ usageMetadata, req, userId, conversationId, model }) {
|
||||
if (!usageMetadata) {
|
||||
logger.debug('[GeminiImageGen] No usage metadata available for balance tracking');
|
||||
return;
|
||||
}
|
||||
|
||||
const appConfig = req?.config;
|
||||
const balance = getBalanceConfig(appConfig);
|
||||
const transactions = getTransactionsConfig(appConfig);
|
||||
|
||||
// Skip if neither balance nor transactions are enabled
|
||||
if (!balance?.enabled && transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const promptTokens = usageMetadata.prompt_token_count || usageMetadata.promptTokenCount || 0;
|
||||
const completionTokens =
|
||||
usageMetadata.candidates_token_count || usageMetadata.candidatesTokenCount || 0;
|
||||
|
||||
if (promptTokens === 0 && completionTokens === 0) {
|
||||
logger.debug('[GeminiImageGen] No tokens to record');
|
||||
return;
|
||||
}
|
||||
|
||||
logger.debug('[GeminiImageGen] Recording token usage:', {
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
model,
|
||||
conversationId,
|
||||
});
|
||||
|
||||
try {
|
||||
await spendTokens(
|
||||
{
|
||||
user: userId,
|
||||
model,
|
||||
conversationId,
|
||||
context: 'image_generation',
|
||||
balance,
|
||||
transactions,
|
||||
},
|
||||
{
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[GeminiImageGen] Error recording token usage:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates Gemini Image Generation tool
|
||||
* @param {Object} fields - Configuration fields
|
||||
* @returns {ReturnType<tool>} - The image generation tool
|
||||
*/
|
||||
function createGeminiImageTool(fields = {}) {
|
||||
const override = fields.override ?? false;
|
||||
|
||||
if (!override && !fields.isAgent) {
|
||||
throw new Error('This tool is only available for agents.');
|
||||
}
|
||||
|
||||
// Skip validation during tool creation - validation happens at runtime in initializeGeminiClient
|
||||
// This allows the tool to be added to agents when using Vertex AI without requiring API keys
|
||||
// The actual credentials check happens when the tool is invoked
|
||||
|
||||
const {
|
||||
req,
|
||||
imageFiles = [],
|
||||
processFileURL,
|
||||
userId,
|
||||
fileStrategy,
|
||||
GEMINI_API_KEY,
|
||||
GOOGLE_KEY,
|
||||
// GEMINI_VERTEX_ENABLED is used for auth validation only (not used in code)
|
||||
// When set as env var, it signals Vertex AI is configured and bypasses API key requirement
|
||||
} = fields;
|
||||
|
||||
const imageOutputType = fields.imageOutputType || EImageOutputType.PNG;
|
||||
|
||||
const geminiImageGenTool = tool(
|
||||
async ({ prompt, image_ids, aspectRatio, imageSize }, _runnableConfig) => {
|
||||
if (!prompt) {
|
||||
throw new Error('Missing required field: prompt');
|
||||
}
|
||||
|
||||
logger.debug('[GeminiImageGen] Generating image with prompt:', prompt?.substring(0, 100));
|
||||
logger.debug('[GeminiImageGen] Options:', { aspectRatio, imageSize });
|
||||
|
||||
// Initialize Gemini client with user-provided credentials
|
||||
let ai;
|
||||
try {
|
||||
ai = await initializeGeminiClient({
|
||||
GEMINI_API_KEY,
|
||||
GOOGLE_KEY,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[GeminiImageGen] Failed to initialize client:', error);
|
||||
return [
|
||||
[{ type: ContentTypes.TEXT, text: `Failed to initialize Gemini: ${error.message}` }],
|
||||
{ content: [], file_ids: [] },
|
||||
];
|
||||
}
|
||||
|
||||
// Build request contents
|
||||
const contents = [{ text: replaceUnwantedChars(prompt) }];
|
||||
|
||||
// Add context images if provided
|
||||
if (image_ids?.length > 0) {
|
||||
const contextImages = await convertImagesToInlineData({
|
||||
imageFiles,
|
||||
image_ids,
|
||||
req,
|
||||
fileStrategy,
|
||||
});
|
||||
contents.push(...contextImages);
|
||||
logger.debug('[GeminiImageGen] Added', contextImages.length, 'context images');
|
||||
}
|
||||
|
||||
// Generate image
|
||||
let apiResponse;
|
||||
const geminiModel = process.env.GEMINI_IMAGE_MODEL || 'gemini-2.5-flash-image';
|
||||
try {
|
||||
// Build config with optional imageConfig
|
||||
const config = {
|
||||
responseModalities: ['TEXT', 'IMAGE'],
|
||||
};
|
||||
|
||||
// Add imageConfig if aspectRatio or imageSize is specified
|
||||
// Note: gemini-2.5-flash-image doesn't support imageSize
|
||||
const supportsImageSize = !geminiModel.includes('gemini-2.5-flash-image');
|
||||
if (aspectRatio || (imageSize && supportsImageSize)) {
|
||||
config.imageConfig = {};
|
||||
if (aspectRatio) {
|
||||
config.imageConfig.aspectRatio = aspectRatio;
|
||||
}
|
||||
if (imageSize && supportsImageSize) {
|
||||
config.imageConfig.imageSize = imageSize;
|
||||
}
|
||||
}
|
||||
|
||||
apiResponse = await ai.models.generateContent({
|
||||
model: geminiModel,
|
||||
contents,
|
||||
config,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[GeminiImageGen] API error:', error);
|
||||
return [
|
||||
[{ type: ContentTypes.TEXT, text: `Image generation failed: ${error.message}` }],
|
||||
{ content: [], file_ids: [] },
|
||||
];
|
||||
}
|
||||
|
||||
// Check for safety blocks
|
||||
const safetyBlock = checkForSafetyBlock(apiResponse);
|
||||
if (safetyBlock) {
|
||||
logger.warn('[GeminiImageGen] Safety block:', safetyBlock);
|
||||
const errorMsg = 'Image blocked by content safety filters. Please try different content.';
|
||||
return [[{ type: ContentTypes.TEXT, text: errorMsg }], { content: [], file_ids: [] }];
|
||||
}
|
||||
|
||||
const rawImageData = apiResponse.candidates?.[0]?.content?.parts?.find((p) => p.inlineData)
|
||||
?.inlineData?.data;
|
||||
|
||||
if (!rawImageData) {
|
||||
logger.warn('[GeminiImageGen] No image data in response');
|
||||
return [
|
||||
[{ type: ContentTypes.TEXT, text: 'No image was generated. Please try again.' }],
|
||||
{ content: [], file_ids: [] },
|
||||
];
|
||||
}
|
||||
|
||||
const rawBuffer = Buffer.from(rawImageData, 'base64');
|
||||
const { buffer: convertedBuffer, format: outputFormat } = await convertImageFormat(
|
||||
rawBuffer,
|
||||
imageOutputType,
|
||||
);
|
||||
const imageData = convertedBuffer.toString('base64');
|
||||
const mimeType = outputFormat === 'jpeg' ? 'image/jpeg' : `image/${outputFormat}`;
|
||||
|
||||
logger.debug('[GeminiImageGen] Image format:', { outputFormat, mimeType });
|
||||
|
||||
let imageUrl;
|
||||
const useLocalStorage = !fileStrategy || fileStrategy === FileSources.local;
|
||||
|
||||
if (useLocalStorage) {
|
||||
try {
|
||||
imageUrl = await saveImageLocally(imageData, outputFormat, userId);
|
||||
} catch (error) {
|
||||
logger.error('[GeminiImageGen] Local save failed:', error);
|
||||
imageUrl = `data:${mimeType};base64,${imageData}`;
|
||||
}
|
||||
} else {
|
||||
const cloudUrl = await saveToCloudStorage({
|
||||
base64Data: imageData,
|
||||
format: outputFormat,
|
||||
processFileURL,
|
||||
fileStrategy,
|
||||
userId,
|
||||
});
|
||||
|
||||
if (cloudUrl) {
|
||||
imageUrl = cloudUrl;
|
||||
} else {
|
||||
// Fallback to local
|
||||
try {
|
||||
imageUrl = await saveImageLocally(imageData, outputFormat, userId);
|
||||
} catch (_error) {
|
||||
imageUrl = `data:${mimeType};base64,${imageData}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('[GeminiImageGen] Image URL:', imageUrl);
|
||||
|
||||
// For the artifact, we need a data URL (same as OpenAI)
|
||||
// The local file save is for persistence, but the response needs a data URL
|
||||
const dataUrl = `data:${mimeType};base64,${imageData}`;
|
||||
|
||||
// Return in content_and_artifact format (same as OpenAI)
|
||||
const file_ids = [v4()];
|
||||
const content = [
|
||||
{
|
||||
type: ContentTypes.IMAGE_URL,
|
||||
image_url: { url: dataUrl },
|
||||
},
|
||||
];
|
||||
|
||||
const textResponse = [
|
||||
{
|
||||
type: ContentTypes.TEXT,
|
||||
text:
|
||||
displayMessage +
|
||||
`\n\ngenerated_image_id: "${file_ids[0]}"` +
|
||||
(image_ids?.length > 0 ? `\nreferenced_image_ids: ["${image_ids.join('", "')}"]` : ''),
|
||||
},
|
||||
];
|
||||
|
||||
// Record token usage for balance tracking (don't await to avoid blocking response)
|
||||
const conversationId = _runnableConfig?.configurable?.thread_id;
|
||||
recordTokenUsage({
|
||||
usageMetadata: apiResponse.usageMetadata,
|
||||
req,
|
||||
userId,
|
||||
conversationId,
|
||||
model: geminiModel,
|
||||
}).catch((error) => {
|
||||
logger.error('[GeminiImageGen] Failed to record token usage:', error);
|
||||
});
|
||||
|
||||
return [textResponse, { content, file_ids }];
|
||||
},
|
||||
{
|
||||
...geminiToolkit.gemini_image_gen,
|
||||
responseFormat: 'content_and_artifact',
|
||||
},
|
||||
);
|
||||
|
||||
return geminiImageGenTool;
|
||||
}
|
||||
|
||||
// Export both for compatibility
|
||||
module.exports = createGeminiImageTool;
|
||||
module.exports.createGeminiImageTool = createGeminiImageTool;
|
||||
|
|
@ -78,6 +78,8 @@ function createOpenAIImageTools(fields = {}) {
|
|||
let apiKey = fields.IMAGE_GEN_OAI_API_KEY ?? getApiKey();
|
||||
const closureConfig = { apiKey };
|
||||
|
||||
const imageModel = process.env.IMAGE_GEN_OAI_MODEL || 'gpt-image-1';
|
||||
|
||||
let baseURL = 'https://api.openai.com/v1/';
|
||||
if (!override && process.env.IMAGE_GEN_OAI_BASEURL) {
|
||||
baseURL = extractBaseURL(process.env.IMAGE_GEN_OAI_BASEURL);
|
||||
|
|
@ -157,7 +159,7 @@ function createOpenAIImageTools(fields = {}) {
|
|||
|
||||
resp = await openai.images.generate(
|
||||
{
|
||||
model: 'gpt-image-1',
|
||||
model: imageModel,
|
||||
prompt: replaceUnwantedChars(prompt),
|
||||
n: Math.min(Math.max(1, n), 10),
|
||||
background,
|
||||
|
|
@ -239,7 +241,7 @@ Error Message: ${error.message}`);
|
|||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('model', 'gpt-image-1');
|
||||
formData.append('model', imageModel);
|
||||
formData.append('prompt', replaceUnwantedChars(prompt));
|
||||
// TODO: `mask` support
|
||||
// TODO: more than 1 image support
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const {
|
|||
createSafeUser,
|
||||
mcpToolPattern,
|
||||
loadWebSearchAuth,
|
||||
buildImageToolContext,
|
||||
} = require('@librechat/api');
|
||||
const { getMCPServersRegistry } = require('~/config');
|
||||
const {
|
||||
|
|
@ -35,6 +36,7 @@ const {
|
|||
StructuredWolfram,
|
||||
createYouTubeTools,
|
||||
TavilySearchResults,
|
||||
createGeminiImageTool,
|
||||
createOpenAIImageTools,
|
||||
} = require('../');
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
|
|
@ -192,21 +194,11 @@ const loadTools = async ({
|
|||
const authFields = getAuthFields('image_gen_oai');
|
||||
const authValues = await loadAuthValues({ userId: user, authFields });
|
||||
const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? [];
|
||||
let toolContext = '';
|
||||
for (let i = 0; i < imageFiles.length; i++) {
|
||||
const file = imageFiles[i];
|
||||
if (!file) {
|
||||
continue;
|
||||
}
|
||||
if (i === 0) {
|
||||
toolContext =
|
||||
'Image files provided in this request (their image IDs listed in order of appearance) available for image editing:';
|
||||
}
|
||||
toolContext += `\n\t- ${file.file_id}`;
|
||||
if (i === imageFiles.length - 1) {
|
||||
toolContext += `\n\nInclude any you need in the \`image_ids\` array when calling \`${EToolResources.image_edit}_oai\`. You may also include previously referenced or generated image IDs.`;
|
||||
}
|
||||
}
|
||||
const toolContext = buildImageToolContext({
|
||||
imageFiles,
|
||||
toolName: `${EToolResources.image_edit}_oai`,
|
||||
contextDescription: 'image editing',
|
||||
});
|
||||
if (toolContext) {
|
||||
toolContextMap.image_edit_oai = toolContext;
|
||||
}
|
||||
|
|
@ -219,6 +211,28 @@ const loadTools = async ({
|
|||
imageFiles,
|
||||
});
|
||||
},
|
||||
gemini_image_gen: async (toolContextMap) => {
|
||||
const authFields = getAuthFields('gemini_image_gen');
|
||||
const authValues = await loadAuthValues({ userId: user, authFields });
|
||||
const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? [];
|
||||
const toolContext = buildImageToolContext({
|
||||
imageFiles,
|
||||
toolName: 'gemini_image_gen',
|
||||
contextDescription: 'image context',
|
||||
});
|
||||
if (toolContext) {
|
||||
toolContextMap.gemini_image_gen = toolContext;
|
||||
}
|
||||
return createGeminiImageTool({
|
||||
...authValues,
|
||||
isAgent: !!agent,
|
||||
req: options.req,
|
||||
imageFiles,
|
||||
processFileURL: options.processFileURL,
|
||||
userId: user,
|
||||
fileStrategy,
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
const requestedTools = {};
|
||||
|
|
@ -241,6 +255,7 @@ const loadTools = async ({
|
|||
flux: imageGenOptions,
|
||||
dalle: imageGenOptions,
|
||||
'stable-diffusion': imageGenOptions,
|
||||
gemini_image_gen: imageGenOptions,
|
||||
};
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
|
|
@ -348,10 +363,10 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
/** Placeholder used for UI purposes */
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
serverName &&
|
||||
(await getMCPServersRegistry().getServerConfig(serverName, user)) == undefined
|
||||
) {
|
||||
const serverConfig = serverName
|
||||
? await getMCPServersRegistry().getServerConfig(serverName, user)
|
||||
: null;
|
||||
if (!serverConfig) {
|
||||
logger.warn(
|
||||
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
||||
);
|
||||
|
|
@ -362,6 +377,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
{
|
||||
type: 'all',
|
||||
serverName,
|
||||
config: serverConfig,
|
||||
},
|
||||
];
|
||||
continue;
|
||||
|
|
@ -372,6 +388,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
type: 'single',
|
||||
toolKey: tool,
|
||||
serverName,
|
||||
config: serverConfig,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
|
@ -432,9 +449,11 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
user: safeUser,
|
||||
userMCPAuthMap,
|
||||
res: options.res,
|
||||
streamId: options.req?._resumableStreamId || null,
|
||||
model: agent?.model ?? model,
|
||||
serverName: config.serverName,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
config: config.config,
|
||||
};
|
||||
|
||||
if (config.type === 'all' && toolConfigs.length === 1) {
|
||||
|
|
|
|||
9
api/cache/banViolation.js
vendored
9
api/cache/banViolation.js
vendored
|
|
@ -47,7 +47,16 @@ const banViolation = async (req, res, errorMessage) => {
|
|||
}
|
||||
|
||||
await deleteAllUserSessions({ userId: user_id });
|
||||
|
||||
/** Clear OpenID session tokens if present */
|
||||
if (req.session?.openidTokens) {
|
||||
delete req.session.openidTokens;
|
||||
}
|
||||
|
||||
res.clearCookie('refreshToken');
|
||||
res.clearCookie('openid_access_token');
|
||||
res.clearCookie('openid_user_id');
|
||||
res.clearCookie('token_provider');
|
||||
|
||||
const banLogs = getLogStores(ViolationTypes.BAN);
|
||||
const duration = errorMessage.duration || banLogs.opts.ttl;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { isEnabled, FlowStateManager } = require('@librechat/api');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { batchResetMeiliFlags } = require('./utils');
|
||||
|
||||
const Conversation = mongoose.models.Conversation;
|
||||
const Message = mongoose.models.Message;
|
||||
|
|
@ -189,6 +190,11 @@ async function ensureFilterableAttributes(client) {
|
|||
*/
|
||||
async function performSync(flowManager, flowId, flowType) {
|
||||
try {
|
||||
if (indexingDisabled === true) {
|
||||
logger.info('[indexSync] Indexing is disabled, skipping...');
|
||||
return { messagesSync: false, convosSync: false };
|
||||
}
|
||||
|
||||
const client = MeiliSearchClient.getInstance();
|
||||
|
||||
const { status } = await client.health();
|
||||
|
|
@ -196,11 +202,6 @@ async function performSync(flowManager, flowId, flowType) {
|
|||
throw new Error('Meilisearch not available');
|
||||
}
|
||||
|
||||
if (indexingDisabled === true) {
|
||||
logger.info('[indexSync] Indexing is disabled, skipping...');
|
||||
return { messagesSync: false, convosSync: false };
|
||||
}
|
||||
|
||||
/** Ensures indexes have proper filterable attributes configured */
|
||||
const { settingsUpdated, orphanedDocsFound: _orphanedDocsFound } =
|
||||
await ensureFilterableAttributes(client);
|
||||
|
|
@ -215,11 +216,8 @@ async function performSync(flowManager, flowId, flowType) {
|
|||
);
|
||||
|
||||
// Reset sync flags to force full re-sync
|
||||
await Message.collection.updateMany({ _meiliIndex: true }, { $set: { _meiliIndex: false } });
|
||||
await Conversation.collection.updateMany(
|
||||
{ _meiliIndex: true },
|
||||
{ $set: { _meiliIndex: false } },
|
||||
);
|
||||
await batchResetMeiliFlags(Message.collection);
|
||||
await batchResetMeiliFlags(Conversation.collection);
|
||||
}
|
||||
|
||||
// Check if we need to sync messages
|
||||
|
|
|
|||
90
api/db/utils.js
Normal file
90
api/db/utils.js
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
|
||||
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
||||
/**
|
||||
* Batch update documents in chunks to avoid timeouts on weak instances
|
||||
* @param {mongoose.Collection} collection - MongoDB collection
|
||||
* @returns {Promise<number>} - Total modified count
|
||||
* @throws {Error} - Throws if database operations fail (e.g., network issues, connection loss, permission problems)
|
||||
*/
|
||||
async function batchResetMeiliFlags(collection) {
|
||||
const DEFAULT_BATCH_SIZE = 1000;
|
||||
|
||||
let BATCH_SIZE = parseEnvInt('MEILI_SYNC_BATCH_SIZE', DEFAULT_BATCH_SIZE);
|
||||
if (BATCH_SIZE === 0) {
|
||||
logger.warn(
|
||||
`[batchResetMeiliFlags] MEILI_SYNC_BATCH_SIZE cannot be 0. Using default: ${DEFAULT_BATCH_SIZE}`,
|
||||
);
|
||||
BATCH_SIZE = DEFAULT_BATCH_SIZE;
|
||||
}
|
||||
|
||||
const BATCH_DELAY_MS = parseEnvInt('MEILI_SYNC_DELAY_MS', 100);
|
||||
let totalModified = 0;
|
||||
let hasMore = true;
|
||||
|
||||
try {
|
||||
while (hasMore) {
|
||||
const docs = await collection
|
||||
.find({ expiredAt: null, _meiliIndex: true }, { projection: { _id: 1 } })
|
||||
.limit(BATCH_SIZE)
|
||||
.toArray();
|
||||
|
||||
if (docs.length === 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
const ids = docs.map((doc) => doc._id);
|
||||
const result = await collection.updateMany(
|
||||
{ _id: { $in: ids } },
|
||||
{ $set: { _meiliIndex: false } },
|
||||
);
|
||||
|
||||
totalModified += result.modifiedCount;
|
||||
process.stdout.write(
|
||||
`\r Updating ${collection.collectionName}: ${totalModified} documents...`,
|
||||
);
|
||||
|
||||
if (docs.length < BATCH_SIZE) {
|
||||
hasMore = false;
|
||||
}
|
||||
|
||||
if (hasMore && BATCH_DELAY_MS > 0) {
|
||||
await sleep(BATCH_DELAY_MS);
|
||||
}
|
||||
}
|
||||
|
||||
return totalModified;
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to batch reset Meili flags for collection '${collection.collectionName}' after processing ${totalModified} documents: ${error.message}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse and validate an environment variable as a positive integer
|
||||
* @param {string} varName - Environment variable name
|
||||
* @param {number} defaultValue - Default value to use if invalid or missing
|
||||
* @returns {number} - Parsed value or default
|
||||
*/
|
||||
function parseEnvInt(varName, defaultValue) {
|
||||
const value = process.env[varName];
|
||||
if (!value) {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
const parsed = parseInt(value, 10);
|
||||
if (isNaN(parsed) || parsed < 0) {
|
||||
logger.warn(
|
||||
`[batchResetMeiliFlags] Invalid value for ${varName}="${value}". Expected a positive integer. Using default: ${defaultValue}`,
|
||||
);
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
batchResetMeiliFlags,
|
||||
};
|
||||
521
api/db/utils.spec.js
Normal file
521
api/db/utils.spec.js
Normal file
|
|
@ -0,0 +1,521 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { batchResetMeiliFlags } = require('./utils');
|
||||
|
||||
describe('batchResetMeiliFlags', () => {
|
||||
let mongoServer;
|
||||
let testCollection;
|
||||
const ORIGINAL_BATCH_SIZE = process.env.MEILI_SYNC_BATCH_SIZE;
|
||||
const ORIGINAL_BATCH_DELAY = process.env.MEILI_SYNC_DELAY_MS;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
|
||||
// Restore original env variables
|
||||
if (ORIGINAL_BATCH_SIZE !== undefined) {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = ORIGINAL_BATCH_SIZE;
|
||||
} else {
|
||||
delete process.env.MEILI_SYNC_BATCH_SIZE;
|
||||
}
|
||||
|
||||
if (ORIGINAL_BATCH_DELAY !== undefined) {
|
||||
process.env.MEILI_SYNC_DELAY_MS = ORIGINAL_BATCH_DELAY;
|
||||
} else {
|
||||
delete process.env.MEILI_SYNC_DELAY_MS;
|
||||
}
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a fresh collection for each test
|
||||
testCollection = mongoose.connection.db.collection('test_meili_batch');
|
||||
await testCollection.deleteMany({});
|
||||
|
||||
// Reset env variables to defaults
|
||||
delete process.env.MEILI_SYNC_BATCH_SIZE;
|
||||
delete process.env.MEILI_SYNC_DELAY_MS;
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
if (testCollection) {
|
||||
await testCollection.deleteMany({});
|
||||
}
|
||||
});
|
||||
|
||||
describe('basic functionality', () => {
|
||||
it('should reset _meiliIndex flag for documents with expiredAt: null and _meiliIndex: true', async () => {
|
||||
// Insert test documents
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true, name: 'doc1' },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true, name: 'doc2' },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true, name: 'doc3' },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(3);
|
||||
|
||||
const updatedDocs = await testCollection.find({ _meiliIndex: false }).toArray();
|
||||
expect(updatedDocs).toHaveLength(3);
|
||||
|
||||
const notUpdatedDocs = await testCollection.find({ _meiliIndex: true }).toArray();
|
||||
expect(notUpdatedDocs).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not modify documents with expiredAt set', async () => {
|
||||
const expiredDate = new Date();
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: expiredDate, _meiliIndex: true },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
|
||||
const expiredDoc = await testCollection.findOne({ expiredAt: expiredDate });
|
||||
expect(expiredDoc._meiliIndex).toBe(true);
|
||||
});
|
||||
|
||||
it('should not modify documents with _meiliIndex: false', async () => {
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: false },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
});
|
||||
|
||||
it('should return 0 when no documents match the criteria', async () => {
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: new Date(), _meiliIndex: true },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: false },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(0);
|
||||
});
|
||||
|
||||
it('should return 0 when collection is empty', async () => {
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('batch processing', () => {
|
||||
it('should process documents in batches according to MEILI_SYNC_BATCH_SIZE', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '2';
|
||||
|
||||
const docs = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
docs.push({
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
name: `doc${i}`,
|
||||
});
|
||||
}
|
||||
await testCollection.insertMany(docs);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(5);
|
||||
|
||||
const updatedDocs = await testCollection.find({ _meiliIndex: false }).toArray();
|
||||
expect(updatedDocs).toHaveLength(5);
|
||||
});
|
||||
|
||||
it('should handle large datasets with small batch sizes', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '10';
|
||||
|
||||
const docs = [];
|
||||
for (let i = 0; i < 25; i++) {
|
||||
docs.push({
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
}
|
||||
await testCollection.insertMany(docs);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(25);
|
||||
});
|
||||
|
||||
it('should use default batch size of 1000 when env variable is not set', async () => {
|
||||
// Create exactly 1000 documents to verify default batch behavior
|
||||
const docs = [];
|
||||
for (let i = 0; i < 1000; i++) {
|
||||
docs.push({
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
}
|
||||
await testCollection.insertMany(docs);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('return value', () => {
|
||||
it('should return correct modified count', async () => {
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
await expect(batchResetMeiliFlags(testCollection)).resolves.toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('batch delay', () => {
|
||||
it('should respect MEILI_SYNC_DELAY_MS between batches', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '2';
|
||||
process.env.MEILI_SYNC_DELAY_MS = '50';
|
||||
|
||||
const docs = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
docs.push({
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
}
|
||||
await testCollection.insertMany(docs);
|
||||
|
||||
const startTime = Date.now();
|
||||
await batchResetMeiliFlags(testCollection);
|
||||
const endTime = Date.now();
|
||||
|
||||
// With 5 documents and batch size 2, we need 3 batches
|
||||
// That means 2 delays between batches (not after the last one)
|
||||
// So minimum time should be around 100ms (2 * 50ms)
|
||||
// Using a slightly lower threshold to account for timing variations
|
||||
const elapsed = endTime - startTime;
|
||||
expect(elapsed).toBeGreaterThanOrEqual(80);
|
||||
});
|
||||
|
||||
it('should not delay when MEILI_SYNC_DELAY_MS is 0', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '2';
|
||||
process.env.MEILI_SYNC_DELAY_MS = '0';
|
||||
|
||||
const docs = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
docs.push({
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
}
|
||||
await testCollection.insertMany(docs);
|
||||
|
||||
const startTime = Date.now();
|
||||
await batchResetMeiliFlags(testCollection);
|
||||
const endTime = Date.now();
|
||||
|
||||
const elapsed = endTime - startTime;
|
||||
// Should complete without intentional delays, but database operations still take time
|
||||
// Just verify it completes and returns the correct count
|
||||
expect(elapsed).toBeLessThan(1000); // More reasonable upper bound
|
||||
|
||||
const result = await testCollection.countDocuments({ _meiliIndex: false });
|
||||
expect(result).toBe(5);
|
||||
});
|
||||
|
||||
it('should not delay after the last batch', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '3';
|
||||
process.env.MEILI_SYNC_DELAY_MS = '100';
|
||||
|
||||
// Exactly 3 documents - should fit in one batch, no delay
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
// Verify all 3 documents were processed in a single batch
|
||||
expect(result).toBe(3);
|
||||
|
||||
const updatedDocs = await testCollection.countDocuments({ _meiliIndex: false });
|
||||
expect(updatedDocs).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle documents without _meiliIndex field', async () => {
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
// Only one document has _meiliIndex: true
|
||||
expect(result).toBe(1);
|
||||
});
|
||||
|
||||
it('should handle mixed document states correctly', async () => {
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: false },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: new Date(), _meiliIndex: true },
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(2);
|
||||
|
||||
const flaggedDocs = await testCollection
|
||||
.find({ expiredAt: null, _meiliIndex: false })
|
||||
.toArray();
|
||||
expect(flaggedDocs).toHaveLength(3); // 2 were updated, 1 was already false
|
||||
});
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should throw error with context when find operation fails', async () => {
|
||||
const mockCollection = {
|
||||
collectionName: 'test_meili_batch',
|
||||
find: jest.fn().mockReturnValue({
|
||||
limit: jest.fn().mockReturnValue({
|
||||
toArray: jest.fn().mockRejectedValue(new Error('Network error')),
|
||||
}),
|
||||
}),
|
||||
};
|
||||
|
||||
await expect(batchResetMeiliFlags(mockCollection)).rejects.toThrow(
|
||||
"Failed to batch reset Meili flags for collection 'test_meili_batch' after processing 0 documents: Network error",
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error with context when updateMany operation fails', async () => {
|
||||
const mockCollection = {
|
||||
collectionName: 'test_meili_batch',
|
||||
find: jest.fn().mockReturnValue({
|
||||
limit: jest.fn().mockReturnValue({
|
||||
toArray: jest
|
||||
.fn()
|
||||
.mockResolvedValue([
|
||||
{ _id: new mongoose.Types.ObjectId() },
|
||||
{ _id: new mongoose.Types.ObjectId() },
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
updateMany: jest.fn().mockRejectedValue(new Error('Connection lost')),
|
||||
};
|
||||
|
||||
await expect(batchResetMeiliFlags(mockCollection)).rejects.toThrow(
|
||||
"Failed to batch reset Meili flags for collection 'test_meili_batch' after processing 0 documents: Connection lost",
|
||||
);
|
||||
});
|
||||
|
||||
it('should include documents processed count in error when failure occurs mid-batch', async () => {
|
||||
// Set batch size to 2 to force multiple batches
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '2';
|
||||
process.env.MEILI_SYNC_DELAY_MS = '0'; // No delay for faster test
|
||||
|
||||
let findCallCount = 0;
|
||||
let updateCallCount = 0;
|
||||
|
||||
const mockCollection = {
|
||||
collectionName: 'test_meili_batch',
|
||||
find: jest.fn().mockReturnValue({
|
||||
limit: jest.fn().mockReturnValue({
|
||||
toArray: jest.fn().mockImplementation(() => {
|
||||
findCallCount++;
|
||||
// Return 2 documents for first two calls (to keep loop going)
|
||||
// Return 2 documents for third call (to trigger third update which will fail)
|
||||
if (findCallCount <= 3) {
|
||||
return Promise.resolve([
|
||||
{ _id: new mongoose.Types.ObjectId() },
|
||||
{ _id: new mongoose.Types.ObjectId() },
|
||||
]);
|
||||
}
|
||||
// Should not reach here due to error
|
||||
return Promise.resolve([]);
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
updateMany: jest.fn().mockImplementation(() => {
|
||||
updateCallCount++;
|
||||
if (updateCallCount === 1) {
|
||||
return Promise.resolve({ modifiedCount: 2 });
|
||||
} else if (updateCallCount === 2) {
|
||||
return Promise.resolve({ modifiedCount: 2 });
|
||||
} else {
|
||||
return Promise.reject(new Error('Database timeout'));
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
await expect(batchResetMeiliFlags(mockCollection)).rejects.toThrow(
|
||||
"Failed to batch reset Meili flags for collection 'test_meili_batch' after processing 4 documents: Database timeout",
|
||||
);
|
||||
});
|
||||
|
||||
it('should use collection.collectionName in error messages', async () => {
|
||||
const mockCollection = {
|
||||
collectionName: 'messages',
|
||||
find: jest.fn().mockReturnValue({
|
||||
limit: jest.fn().mockReturnValue({
|
||||
toArray: jest.fn().mockRejectedValue(new Error('Permission denied')),
|
||||
}),
|
||||
}),
|
||||
};
|
||||
|
||||
await expect(batchResetMeiliFlags(mockCollection)).rejects.toThrow(
|
||||
"Failed to batch reset Meili flags for collection 'messages' after processing 0 documents: Permission denied",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('environment variable validation', () => {
|
||||
let warnSpy;
|
||||
|
||||
beforeEach(() => {
|
||||
// Mock logger.warn to track warning calls
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
warnSpy = jest.spyOn(logger, 'warn').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (warnSpy) {
|
||||
warnSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it('should log warning and use default when MEILI_SYNC_BATCH_SIZE is not a number', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = 'abc';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid value for MEILI_SYNC_BATCH_SIZE="abc"'),
|
||||
);
|
||||
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Using default: 1000'));
|
||||
});
|
||||
|
||||
it('should log warning and use default when MEILI_SYNC_DELAY_MS is not a number', async () => {
|
||||
process.env.MEILI_SYNC_DELAY_MS = 'xyz';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid value for MEILI_SYNC_DELAY_MS="xyz"'),
|
||||
);
|
||||
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Using default: 100'));
|
||||
});
|
||||
|
||||
it('should log warning and use default when MEILI_SYNC_BATCH_SIZE is negative', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '-50';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid value for MEILI_SYNC_BATCH_SIZE="-50"'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should log warning and use default when MEILI_SYNC_DELAY_MS is negative', async () => {
|
||||
process.env.MEILI_SYNC_DELAY_MS = '-100';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid value for MEILI_SYNC_DELAY_MS="-100"'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept valid positive integer values without warnings', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '500';
|
||||
process.env.MEILI_SYNC_DELAY_MS = '50';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should log warning and use default when MEILI_SYNC_BATCH_SIZE is zero', async () => {
|
||||
process.env.MEILI_SYNC_BATCH_SIZE = '0';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('MEILI_SYNC_BATCH_SIZE cannot be 0. Using default: 1000'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept zero as a valid value for MEILI_SYNC_DELAY_MS without warnings', async () => {
|
||||
process.env.MEILI_SYNC_DELAY_MS = '0';
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not log warnings when environment variables are not set', async () => {
|
||||
delete process.env.MEILI_SYNC_BATCH_SIZE;
|
||||
delete process.env.MEILI_SYNC_DELAY_MS;
|
||||
|
||||
await testCollection.insertMany([
|
||||
{ _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true },
|
||||
]);
|
||||
|
||||
const result = await batchResetMeiliFlags(testCollection);
|
||||
|
||||
expect(result).toBe(1);
|
||||
expect(warnSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,8 +1,17 @@
|
|||
const mongoose = require('mongoose');
|
||||
const crypto = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } =
|
||||
const { getCustomEndpointConfig } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
actionDelimiter,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
encodeEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, mcp_all, mcp_delimiter } =
|
||||
require('librechat-data-provider').Constants;
|
||||
const {
|
||||
removeAgentFromAllProjects,
|
||||
|
|
@ -92,7 +101,7 @@ const getAgents = async (searchParameter) => await Agent.find(searchParameter).l
|
|||
* @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_parameters: _m }) => {
|
||||
const loadEphemeralAgent = async ({ req, spec, endpoint, model_parameters: _m }) => {
|
||||
const { model, ...model_parameters } = _m;
|
||||
const modelSpecs = req.config?.modelSpecs?.list;
|
||||
/** @type {TModelSpec | null} */
|
||||
|
|
@ -139,8 +148,28 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
|
|||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
|
||||
// Get endpoint config for modelDisplayLabel fallback
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadEphemeralAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// For ephemeral agents, use modelLabel if provided, then model spec's label,
|
||||
// then modelDisplayLabel from endpoint config, otherwise empty string to show model name
|
||||
const sender =
|
||||
model_parameters?.modelLabel ?? modelSpec?.label ?? endpointConfig?.modelDisplayLabel ?? '';
|
||||
|
||||
// Encode ephemeral agent ID with endpoint, model, and computed sender for display
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender });
|
||||
|
||||
const result = {
|
||||
id: agent_id,
|
||||
id: ephemeralId,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
|
|
@ -169,8 +198,8 @@ const loadAgent = async ({ req, spec, agent_id, endpoint, model_parameters }) =>
|
|||
if (!agent_id) {
|
||||
return null;
|
||||
}
|
||||
if (agent_id === EPHEMERAL_AGENT_ID) {
|
||||
return await loadEphemeralAgent({ req, spec, agent_id, endpoint, model_parameters });
|
||||
if (isEphemeralAgentId(agent_id)) {
|
||||
return await loadEphemeralAgent({ req, spec, endpoint, model_parameters });
|
||||
}
|
||||
const agent = await getAgent({
|
||||
id: agent_id,
|
||||
|
|
@ -566,6 +595,11 @@ const deleteAgent = async (searchParameter) => {
|
|||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
});
|
||||
try {
|
||||
await Agent.updateMany({ 'edges.to': agent.id }, { $pull: { edges: { to: agent.id } } });
|
||||
} catch (error) {
|
||||
logger.error('[deleteAgent] Error removing agent from handoff edges', error);
|
||||
}
|
||||
}
|
||||
return agent;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -532,6 +532,49 @@ describe('models/Agent', () => {
|
|||
expect(aclEntriesAfter).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should remove handoff edges referencing deleted agent from other agents', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const targetAgentId = `agent_${uuidv4()}`;
|
||||
const sourceAgentId = `agent_${uuidv4()}`;
|
||||
|
||||
// Create target agent (handoff destination)
|
||||
await createAgent({
|
||||
id: targetAgentId,
|
||||
name: 'Target Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Create source agent with handoff edge to target
|
||||
await createAgent({
|
||||
id: sourceAgentId,
|
||||
name: 'Source Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
edges: [
|
||||
{
|
||||
from: sourceAgentId,
|
||||
to: targetAgentId,
|
||||
edgeType: 'handoff',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Verify edge exists before deletion
|
||||
const sourceAgentBefore = await getAgent({ id: sourceAgentId });
|
||||
expect(sourceAgentBefore.edges).toHaveLength(1);
|
||||
expect(sourceAgentBefore.edges[0].to).toBe(targetAgentId);
|
||||
|
||||
// Delete the target agent
|
||||
await deleteAgent({ id: targetAgentId });
|
||||
|
||||
// Verify the edge is removed from source agent
|
||||
const sourceAgentAfter = await getAgent({ id: sourceAgentId });
|
||||
expect(sourceAgentAfter.edges).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should list agents by author', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const otherAuthorId = new mongoose.Types.ObjectId();
|
||||
|
|
@ -1960,7 +2003,8 @@ describe('models/Agent', () => {
|
|||
});
|
||||
|
||||
if (result) {
|
||||
expect(result.id).toBe(EPHEMERAL_AGENT_ID);
|
||||
// Ephemeral agent ID is encoded with endpoint and model
|
||||
expect(result.id).toBe('openai__gpt-4');
|
||||
expect(result.instructions).toBe('Test instructions');
|
||||
expect(result.provider).toBe('openai');
|
||||
expect(result.model).toBe('gpt-4');
|
||||
|
|
@ -1978,7 +2022,7 @@ describe('models/Agent', () => {
|
|||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent({
|
||||
req: mockReq,
|
||||
agent_id: 'non_existent_agent',
|
||||
agent_id: 'agent_non_existent',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
|
@ -2105,7 +2149,7 @@ describe('models/Agent', () => {
|
|||
test('should handle loadAgent with malformed req object', async () => {
|
||||
const result = await loadAgent({
|
||||
req: null,
|
||||
agent_id: 'test',
|
||||
agent_id: 'agent_test',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ module.exports = {
|
|||
isArchived = false,
|
||||
tags,
|
||||
search,
|
||||
sortBy = 'createdAt',
|
||||
sortBy = 'updatedAt',
|
||||
sortDirection = 'desc',
|
||||
} = {},
|
||||
) => {
|
||||
|
|
@ -251,10 +251,12 @@ module.exports = {
|
|||
|
||||
let nextCursor = null;
|
||||
if (convos.length > limit) {
|
||||
const lastConvo = convos.pop();
|
||||
const primaryValue = lastConvo[finalSortBy];
|
||||
convos.pop(); // Remove extra item used to detect next page
|
||||
// Create cursor from the last RETURNED item (not the popped one)
|
||||
const lastReturned = convos[convos.length - 1];
|
||||
const primaryValue = lastReturned[finalSortBy];
|
||||
const primaryStr = finalSortBy === 'title' ? primaryValue : primaryValue.toISOString();
|
||||
const secondaryStr = lastConvo.updatedAt.toISOString();
|
||||
const secondaryStr = lastReturned.updatedAt.toISOString();
|
||||
const composite = { primary: primaryStr, secondary: secondaryStr };
|
||||
nextCursor = Buffer.from(JSON.stringify(composite)).toString('base64');
|
||||
}
|
||||
|
|
@ -290,8 +292,9 @@ module.exports = {
|
|||
const limited = filtered.slice(0, limit + 1);
|
||||
let nextCursor = null;
|
||||
if (limited.length > limit) {
|
||||
const lastConvo = limited.pop();
|
||||
nextCursor = lastConvo.updatedAt.toISOString();
|
||||
limited.pop(); // Remove extra item used to detect next page
|
||||
// Create cursor from the last RETURNED item (not the popped one)
|
||||
nextCursor = limited[limited.length - 1].updatedAt.toISOString();
|
||||
}
|
||||
|
||||
const convoMap = {};
|
||||
|
|
|
|||
|
|
@ -567,4 +567,267 @@ describe('Conversation Operations', () => {
|
|||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvosByCursor pagination', () => {
|
||||
/**
|
||||
* Helper to create conversations with specific timestamps
|
||||
* Uses collection.insertOne to bypass Mongoose timestamps entirely
|
||||
*/
|
||||
const createConvoWithTimestamps = async (index, createdAt, updatedAt) => {
|
||||
const conversationId = uuidv4();
|
||||
// Use collection-level insert to bypass Mongoose timestamps
|
||||
await Conversation.collection.insertOne({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
title: `Conversation ${index}`,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
isArchived: false,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
});
|
||||
return Conversation.findOne({ conversationId }).lean();
|
||||
};
|
||||
|
||||
it('should not skip conversations at page boundaries', async () => {
|
||||
// Create 30 conversations to ensure pagination (limit is 25)
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
const convos = [];
|
||||
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000); // Each 1 minute apart
|
||||
const convo = await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
convos.push(convo);
|
||||
}
|
||||
|
||||
// Fetch first page
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
expect(page1.conversations).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Fetch second page using cursor
|
||||
const page2 = await getConvosByCursor('user123', {
|
||||
limit: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Should get remaining 5 conversations
|
||||
expect(page2.conversations).toHaveLength(5);
|
||||
expect(page2.nextCursor).toBeNull();
|
||||
|
||||
// Verify no duplicates and no gaps
|
||||
const allIds = [
|
||||
...page1.conversations.map((c) => c.conversationId),
|
||||
...page2.conversations.map((c) => c.conversationId),
|
||||
];
|
||||
const uniqueIds = new Set(allIds);
|
||||
|
||||
expect(uniqueIds.size).toBe(30); // All 30 conversations accounted for
|
||||
expect(allIds.length).toBe(30); // No duplicates
|
||||
});
|
||||
|
||||
it('should include conversation at exact page boundary (item 26 bug fix)', async () => {
|
||||
// This test specifically verifies the fix for the bug where item 26
|
||||
// (the first item that should appear on page 2) was being skipped
|
||||
|
||||
const baseTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create exactly 26 conversations
|
||||
const convos = [];
|
||||
for (let i = 0; i < 26; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
const convo = await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
convos.push(convo);
|
||||
}
|
||||
|
||||
// The 26th conversation (index 25) should be on page 2
|
||||
const item26 = convos[25];
|
||||
|
||||
// Fetch first page with limit 25
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
expect(page1.conversations).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Item 26 should NOT be in page 1
|
||||
const page1Ids = page1.conversations.map((c) => c.conversationId);
|
||||
expect(page1Ids).not.toContain(item26.conversationId);
|
||||
|
||||
// Fetch second page
|
||||
const page2 = await getConvosByCursor('user123', {
|
||||
limit: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Item 26 MUST be in page 2 (this was the bug - it was being skipped)
|
||||
expect(page2.conversations).toHaveLength(1);
|
||||
expect(page2.conversations[0].conversationId).toBe(item26.conversationId);
|
||||
});
|
||||
|
||||
it('should sort by updatedAt DESC by default', async () => {
|
||||
// Create conversations with different updatedAt times
|
||||
// Note: createdAt is older but updatedAt varies
|
||||
const convo1 = await createConvoWithTimestamps(
|
||||
1,
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest created
|
||||
new Date('2026-01-03T00:00:00.000Z'), // most recently updated
|
||||
);
|
||||
|
||||
const convo2 = await createConvoWithTimestamps(
|
||||
2,
|
||||
new Date('2026-01-02T00:00:00.000Z'), // middle created
|
||||
new Date('2026-01-02T00:00:00.000Z'), // middle updated
|
||||
);
|
||||
|
||||
const convo3 = await createConvoWithTimestamps(
|
||||
3,
|
||||
new Date('2026-01-03T00:00:00.000Z'), // newest created
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest updated
|
||||
);
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// Should be sorted by updatedAt DESC (most recent first)
|
||||
expect(result.conversations).toHaveLength(3);
|
||||
expect(result.conversations[0].conversationId).toBe(convo1.conversationId); // Jan 3 updatedAt
|
||||
expect(result.conversations[1].conversationId).toBe(convo2.conversationId); // Jan 2 updatedAt
|
||||
expect(result.conversations[2].conversationId).toBe(convo3.conversationId); // Jan 1 updatedAt
|
||||
});
|
||||
|
||||
it('should handle conversations with same updatedAt (tie-breaker)', async () => {
|
||||
const sameTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create 3 conversations with exact same updatedAt
|
||||
const convo1 = await createConvoWithTimestamps(1, sameTime, sameTime);
|
||||
const convo2 = await createConvoWithTimestamps(2, sameTime, sameTime);
|
||||
const convo3 = await createConvoWithTimestamps(3, sameTime, sameTime);
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// All 3 should be returned (no skipping due to same timestamps)
|
||||
expect(result.conversations).toHaveLength(3);
|
||||
|
||||
const returnedIds = result.conversations.map((c) => c.conversationId);
|
||||
expect(returnedIds).toContain(convo1.conversationId);
|
||||
expect(returnedIds).toContain(convo2.conversationId);
|
||||
expect(returnedIds).toContain(convo3.conversationId);
|
||||
});
|
||||
|
||||
it('should handle cursor pagination with conversations updated during pagination', async () => {
|
||||
// Simulate the scenario where a conversation is updated between page fetches
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 30 conversations
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
}
|
||||
|
||||
// Fetch first page
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
expect(page1.conversations).toHaveLength(25);
|
||||
|
||||
// Now update one of the conversations that should be on page 2
|
||||
// to have a newer updatedAt (simulating user activity during pagination)
|
||||
const convosOnPage2 = await Conversation.find({ user: 'user123' })
|
||||
.sort({ updatedAt: -1 })
|
||||
.skip(25)
|
||||
.limit(5);
|
||||
|
||||
if (convosOnPage2.length > 0) {
|
||||
const updatedConvo = convosOnPage2[0];
|
||||
await Conversation.updateOne(
|
||||
{ _id: updatedConvo._id },
|
||||
{ updatedAt: new Date('2026-01-02T00:00:00.000Z') }, // Much newer
|
||||
);
|
||||
}
|
||||
|
||||
// Fetch second page with original cursor
|
||||
const page2 = await getConvosByCursor('user123', {
|
||||
limit: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// The updated conversation might not be in page 2 anymore
|
||||
// (it moved to the front), but we should still get remaining items
|
||||
// without errors and without infinite loops
|
||||
expect(page2.conversations.length).toBeGreaterThanOrEqual(0);
|
||||
});
|
||||
|
||||
it('should correctly decode and use cursor for pagination', async () => {
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 30 conversations
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
}
|
||||
|
||||
// Fetch first page
|
||||
const page1 = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
// Decode the cursor to verify it's based on the last RETURNED item
|
||||
const decodedCursor = JSON.parse(Buffer.from(page1.nextCursor, 'base64').toString());
|
||||
|
||||
// The cursor should match the last item in page1 (item at index 24)
|
||||
const lastReturnedItem = page1.conversations[24];
|
||||
|
||||
expect(new Date(decodedCursor.primary).getTime()).toBe(
|
||||
new Date(lastReturnedItem.updatedAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should support sortBy createdAt when explicitly requested', async () => {
|
||||
// Create conversations with different timestamps
|
||||
const convo1 = await createConvoWithTimestamps(
|
||||
1,
|
||||
new Date('2026-01-03T00:00:00.000Z'), // newest created
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest updated
|
||||
);
|
||||
|
||||
const convo2 = await createConvoWithTimestamps(
|
||||
2,
|
||||
new Date('2026-01-01T00:00:00.000Z'), // oldest created
|
||||
new Date('2026-01-03T00:00:00.000Z'), // newest updated
|
||||
);
|
||||
|
||||
// Verify timestamps were set correctly
|
||||
expect(new Date(convo1.createdAt).getTime()).toBe(
|
||||
new Date('2026-01-03T00:00:00.000Z').getTime(),
|
||||
);
|
||||
expect(new Date(convo2.createdAt).getTime()).toBe(
|
||||
new Date('2026-01-01T00:00:00.000Z').getTime(),
|
||||
);
|
||||
|
||||
const result = await getConvosByCursor('user123', { sortBy: 'createdAt' });
|
||||
|
||||
// Should be sorted by createdAt DESC
|
||||
expect(result.conversations).toHaveLength(2);
|
||||
expect(result.conversations[0].conversationId).toBe(convo1.conversationId); // Jan 3 createdAt
|
||||
expect(result.conversations[1].conversationId).toBe(convo2.conversationId); // Jan 1 createdAt
|
||||
});
|
||||
|
||||
it('should handle empty result set gracefully', async () => {
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
expect(result.conversations).toHaveLength(0);
|
||||
expect(result.nextCursor).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle exactly limit number of conversations (no next page)', async () => {
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create exactly 25 conversations (equal to default limit)
|
||||
for (let i = 0; i < 25; i++) {
|
||||
const updatedAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createConvoWithTimestamps(i, updatedAt, updatedAt);
|
||||
}
|
||||
|
||||
const result = await getConvosByCursor('user123', { limit: 25 });
|
||||
|
||||
expect(result.conversations).toHaveLength(25);
|
||||
expect(result.nextCursor).toBeNull(); // No next page
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -573,4 +573,326 @@ describe('Message Operations', () => {
|
|||
expect(bulk2.expiredAt).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Message cursor pagination', () => {
|
||||
/**
|
||||
* Helper to create messages with specific timestamps
|
||||
* Uses collection.insertOne to bypass Mongoose timestamps
|
||||
*/
|
||||
const createMessageWithTimestamp = async (index, conversationId, createdAt) => {
|
||||
const messageId = uuidv4();
|
||||
await Message.collection.insertOne({
|
||||
messageId,
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
text: `Message ${index}`,
|
||||
isCreatedByUser: index % 2 === 0,
|
||||
createdAt,
|
||||
updatedAt: createdAt,
|
||||
});
|
||||
return Message.findOne({ messageId }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Simulates the pagination logic from api/server/routes/messages.js
|
||||
* This tests the exact query pattern used in the route
|
||||
*/
|
||||
const getMessagesByCursor = async ({
|
||||
conversationId,
|
||||
user,
|
||||
pageSize = 25,
|
||||
cursor = null,
|
||||
sortBy = 'createdAt',
|
||||
sortDirection = 'desc',
|
||||
}) => {
|
||||
const sortOrder = sortDirection === 'asc' ? 1 : -1;
|
||||
const sortField = ['createdAt', 'updatedAt'].includes(sortBy) ? sortBy : 'createdAt';
|
||||
const cursorOperator = sortDirection === 'asc' ? '$gt' : '$lt';
|
||||
|
||||
const filter = { conversationId, user };
|
||||
if (cursor) {
|
||||
filter[sortField] = { [cursorOperator]: new Date(cursor) };
|
||||
}
|
||||
|
||||
const messages = await Message.find(filter)
|
||||
.sort({ [sortField]: sortOrder })
|
||||
.limit(pageSize + 1)
|
||||
.lean();
|
||||
|
||||
let nextCursor = null;
|
||||
if (messages.length > pageSize) {
|
||||
messages.pop(); // Remove extra item used to detect next page
|
||||
// Create cursor from the last RETURNED item (not the popped one)
|
||||
nextCursor = messages[messages.length - 1][sortField];
|
||||
}
|
||||
|
||||
return { messages, nextCursor };
|
||||
};
|
||||
|
||||
it('should return messages for a conversation with pagination', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 30 messages to test pagination
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000); // Each 1 minute apart
|
||||
await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
}
|
||||
|
||||
// Fetch first page (pageSize 25)
|
||||
const page1 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
});
|
||||
|
||||
expect(page1.messages).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Fetch second page using cursor
|
||||
const page2 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Should get remaining 5 messages
|
||||
expect(page2.messages).toHaveLength(5);
|
||||
expect(page2.nextCursor).toBeNull();
|
||||
|
||||
// Verify no duplicates and no gaps
|
||||
const allMessageIds = [
|
||||
...page1.messages.map((m) => m.messageId),
|
||||
...page2.messages.map((m) => m.messageId),
|
||||
];
|
||||
const uniqueIds = new Set(allMessageIds);
|
||||
|
||||
expect(uniqueIds.size).toBe(30); // All 30 messages accounted for
|
||||
expect(allMessageIds.length).toBe(30); // No duplicates
|
||||
});
|
||||
|
||||
it('should not skip message at page boundary (item 26 bug fix)', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create exactly 26 messages
|
||||
const messages = [];
|
||||
for (let i = 0; i < 26; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000);
|
||||
const msg = await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
messages.push(msg);
|
||||
}
|
||||
|
||||
// The 26th message (index 25) should be on page 2
|
||||
const item26 = messages[25];
|
||||
|
||||
// Fetch first page with pageSize 25
|
||||
const page1 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
});
|
||||
|
||||
expect(page1.messages).toHaveLength(25);
|
||||
expect(page1.nextCursor).toBeTruthy();
|
||||
|
||||
// Item 26 should NOT be in page 1
|
||||
const page1Ids = page1.messages.map((m) => m.messageId);
|
||||
expect(page1Ids).not.toContain(item26.messageId);
|
||||
|
||||
// Fetch second page
|
||||
const page2 = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
cursor: page1.nextCursor,
|
||||
});
|
||||
|
||||
// Item 26 MUST be in page 2 (this was the bug - it was being skipped)
|
||||
expect(page2.messages).toHaveLength(1);
|
||||
expect(page2.messages[0].messageId).toBe(item26.messageId);
|
||||
});
|
||||
|
||||
it('should sort by createdAt DESC by default', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Create messages with specific timestamps
|
||||
const msg1 = await createMessageWithTimestamp(
|
||||
1,
|
||||
conversationId,
|
||||
new Date('2026-01-01T00:00:00.000Z'),
|
||||
);
|
||||
const msg2 = await createMessageWithTimestamp(
|
||||
2,
|
||||
conversationId,
|
||||
new Date('2026-01-02T00:00:00.000Z'),
|
||||
);
|
||||
const msg3 = await createMessageWithTimestamp(
|
||||
3,
|
||||
conversationId,
|
||||
new Date('2026-01-03T00:00:00.000Z'),
|
||||
);
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
// Should be sorted by createdAt DESC (newest first) by default
|
||||
expect(result.messages).toHaveLength(3);
|
||||
expect(result.messages[0].messageId).toBe(msg3.messageId);
|
||||
expect(result.messages[1].messageId).toBe(msg2.messageId);
|
||||
expect(result.messages[2].messageId).toBe(msg1.messageId);
|
||||
});
|
||||
|
||||
it('should support ascending sort direction', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
const msg1 = await createMessageWithTimestamp(
|
||||
1,
|
||||
conversationId,
|
||||
new Date('2026-01-01T00:00:00.000Z'),
|
||||
);
|
||||
const msg2 = await createMessageWithTimestamp(
|
||||
2,
|
||||
conversationId,
|
||||
new Date('2026-01-02T00:00:00.000Z'),
|
||||
);
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
sortDirection: 'asc',
|
||||
});
|
||||
|
||||
// Should be sorted by createdAt ASC (oldest first)
|
||||
expect(result.messages).toHaveLength(2);
|
||||
expect(result.messages[0].messageId).toBe(msg1.messageId);
|
||||
expect(result.messages[1].messageId).toBe(msg2.messageId);
|
||||
});
|
||||
|
||||
it('should handle empty conversation', async () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
expect(result.messages).toHaveLength(0);
|
||||
expect(result.nextCursor).toBeNull();
|
||||
});
|
||||
|
||||
it('should only return messages for the specified user', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const createdAt = new Date();
|
||||
|
||||
// Create a message for user123
|
||||
await Message.collection.insertOne({
|
||||
messageId: uuidv4(),
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
text: 'User message',
|
||||
createdAt,
|
||||
updatedAt: createdAt,
|
||||
});
|
||||
|
||||
// Create a message for a different user
|
||||
await Message.collection.insertOne({
|
||||
messageId: uuidv4(),
|
||||
conversationId,
|
||||
user: 'otherUser',
|
||||
text: 'Other user message',
|
||||
createdAt,
|
||||
updatedAt: createdAt,
|
||||
});
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
// Should only return user123's message
|
||||
expect(result.messages).toHaveLength(1);
|
||||
expect(result.messages[0].user).toBe('user123');
|
||||
});
|
||||
|
||||
it('should handle exactly pageSize number of messages (no next page)', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create exactly 25 messages (equal to default pageSize)
|
||||
for (let i = 0; i < 25; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
}
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 25,
|
||||
});
|
||||
|
||||
expect(result.messages).toHaveLength(25);
|
||||
expect(result.nextCursor).toBeNull(); // No next page
|
||||
});
|
||||
|
||||
it('should handle pageSize of 1', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const baseTime = new Date('2026-01-01T00:00:00.000Z');
|
||||
|
||||
// Create 3 messages
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const createdAt = new Date(baseTime.getTime() - i * 60000);
|
||||
await createMessageWithTimestamp(i, conversationId, createdAt);
|
||||
}
|
||||
|
||||
// Fetch with pageSize 1
|
||||
let cursor = null;
|
||||
const allMessages = [];
|
||||
|
||||
for (let page = 0; page < 5; page++) {
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 1,
|
||||
cursor,
|
||||
});
|
||||
|
||||
allMessages.push(...result.messages);
|
||||
cursor = result.nextCursor;
|
||||
|
||||
if (!cursor) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Should get all 3 messages without duplicates
|
||||
expect(allMessages).toHaveLength(3);
|
||||
const uniqueIds = new Set(allMessages.map((m) => m.messageId));
|
||||
expect(uniqueIds.size).toBe(3);
|
||||
});
|
||||
|
||||
it('should handle messages with same createdAt timestamp', async () => {
|
||||
const conversationId = uuidv4();
|
||||
const sameTime = new Date('2026-01-01T12:00:00.000Z');
|
||||
|
||||
// Create multiple messages with the exact same timestamp
|
||||
const messages = [];
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const msg = await createMessageWithTimestamp(i, conversationId, sameTime);
|
||||
messages.push(msg);
|
||||
}
|
||||
|
||||
const result = await getMessagesByCursor({
|
||||
conversationId,
|
||||
user: 'user123',
|
||||
pageSize: 10,
|
||||
});
|
||||
|
||||
// All messages should be returned
|
||||
expect(result.messages).toHaveLength(5);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
|
@ -55,7 +55,7 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
SHARE: true,
|
||||
},
|
||||
});
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ describe('updateAccessPermissions', () => {
|
|||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
SHARE: true,
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
|
@ -83,7 +83,7 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
},
|
||||
});
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ describe('updateAccessPermissions', () => {
|
|||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -110,20 +110,20 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true },
|
||||
[PermissionTypes.PROMPTS]: { SHARE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
SHARE: true,
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ describe('updateAccessPermissions', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
},
|
||||
},
|
||||
}).save();
|
||||
|
|
@ -147,7 +147,7 @@ describe('updateAccessPermissions', () => {
|
|||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: false,
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -155,13 +155,13 @@ describe('updateAccessPermissions', () => {
|
|||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARE: false },
|
||||
[PermissionTypes.BOOKMARKS]: { USE: true },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true },
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARE: true },
|
||||
[PermissionTypes.BOOKMARKS]: { USE: false },
|
||||
});
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ describe('updateAccessPermissions', () => {
|
|||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: true,
|
||||
SHARE: true,
|
||||
});
|
||||
expect(updatedRole.permissions[PermissionTypes.BOOKMARKS]).toEqual({ USE: false });
|
||||
});
|
||||
|
|
@ -178,19 +178,19 @@ describe('updateAccessPermissions', () => {
|
|||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true },
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await getRoleByName(SystemRoles.USER);
|
||||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: true,
|
||||
SHARE: true,
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -214,13 +214,13 @@ describe('updateAccessPermissions', () => {
|
|||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false },
|
||||
[PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARE: false },
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: false },
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true },
|
||||
[PermissionTypes.PROMPTS]: { SHARE: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
|
|
@ -228,7 +228,7 @@ describe('updateAccessPermissions', () => {
|
|||
expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
SHARE: true,
|
||||
});
|
||||
expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true });
|
||||
});
|
||||
|
|
@ -271,7 +271,7 @@ describe('initializeRoles', () => {
|
|||
});
|
||||
|
||||
// Example: Check default values for ADMIN role
|
||||
expect(adminRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.PROMPTS].SHARE).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.BOOKMARKS].USE).toBe(true);
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBe(true);
|
||||
});
|
||||
|
|
@ -283,7 +283,7 @@ describe('initializeRoles', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
[Permissions.SHARE]: true,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false },
|
||||
},
|
||||
|
|
@ -320,7 +320,7 @@ describe('initializeRoles', () => {
|
|||
expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
expect(userRole.permissions[PermissionTypes.AGENTS].SHARE).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle multiple runs without duplicating or modifying data', async () => {
|
||||
|
|
@ -348,7 +348,7 @@ describe('initializeRoles', () => {
|
|||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
[Permissions.SHARE]: false,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]:
|
||||
roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.BOOKMARKS],
|
||||
|
|
@ -365,7 +365,7 @@ describe('initializeRoles', () => {
|
|||
expect(adminRole.permissions[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
expect(adminRole.permissions[PermissionTypes.AGENTS].SHARE).toBeDefined();
|
||||
});
|
||||
|
||||
it('should include MULTI_CONVO permissions when creating default roles', async () => {
|
||||
|
|
|
|||
218
api/models/loadAddedAgent.js
Normal file
218
api/models/loadAddedAgent.js
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getCustomEndpointConfig } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
appendAgentIdSuffix,
|
||||
encodeEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
|
||||
const { mcp_all, mcp_delimiter } = Constants;
|
||||
|
||||
/**
|
||||
* Constant for added conversation agent ID
|
||||
*/
|
||||
const ADDED_AGENT_ID = 'added_agent';
|
||||
|
||||
/**
|
||||
* Get an agent document based on the provided ID.
|
||||
* @param {Object} searchParameter - The search parameters to find the agent.
|
||||
* @param {string} searchParameter.id - The ID of the agent.
|
||||
* @returns {Promise<import('librechat-data-provider').Agent|null>}
|
||||
*/
|
||||
let getAgent;
|
||||
|
||||
/**
|
||||
* Set the getAgent function (dependency injection to avoid circular imports)
|
||||
* @param {Function} fn
|
||||
*/
|
||||
const setGetAgent = (fn) => {
|
||||
getAgent = fn;
|
||||
};
|
||||
|
||||
/**
|
||||
* Load an agent from an added conversation (TConversation).
|
||||
* Used for multi-convo parallel agent execution.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req
|
||||
* @param {import('librechat-data-provider').TConversation} params.conversation - The added conversation
|
||||
* @param {import('librechat-data-provider').Agent} [params.primaryAgent] - The primary agent (used to duplicate tools when both are ephemeral)
|
||||
* @returns {Promise<import('librechat-data-provider').Agent|null>} The agent config as a plain object, or null if invalid.
|
||||
*/
|
||||
const loadAddedAgent = async ({ req, conversation, primaryAgent }) => {
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// If there's an agent_id, load the existing agent
|
||||
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
|
||||
if (!getAgent) {
|
||||
throw new Error('getAgent not initialized - call setGetAgent first');
|
||||
}
|
||||
const agent = await getAgent({
|
||||
id: conversation.agent_id,
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
|
||||
return null;
|
||||
}
|
||||
|
||||
agent.version = agent.versions ? agent.versions.length : 0;
|
||||
// Append suffix to distinguish from primary agent (matches ephemeral format)
|
||||
// This is needed when both agents have the same ID or for consistent parallel content attribution
|
||||
agent.id = appendAgentIdSuffix(agent.id, 1);
|
||||
return agent;
|
||||
}
|
||||
|
||||
// Otherwise, create an ephemeral agent config from the conversation
|
||||
const { model, endpoint, promptPrefix, spec, ...rest } = conversation;
|
||||
|
||||
if (!endpoint || !model) {
|
||||
logger.warn('[loadAddedAgent] Missing required endpoint or model for ephemeral agent');
|
||||
return null;
|
||||
}
|
||||
|
||||
// If both primary and added agents are ephemeral, duplicate tools from primary agent
|
||||
const primaryIsEphemeral = primaryAgent && isEphemeralAgentId(primaryAgent.id);
|
||||
if (primaryIsEphemeral && Array.isArray(primaryAgent.tools)) {
|
||||
// Get endpoint config and model spec for display name fallbacks
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// Look up model spec for label fallback
|
||||
const modelSpecs = appConfig?.modelSpecs?.list;
|
||||
const modelSpec = spec != null && spec !== '' ? modelSpecs?.find((s) => s.name === spec) : null;
|
||||
|
||||
// For ephemeral agents, use modelLabel if provided, then model spec's label,
|
||||
// then modelDisplayLabel from endpoint config, otherwise empty string to show model name
|
||||
const sender = rest.modelLabel ?? modelSpec?.label ?? endpointConfig?.modelDisplayLabel ?? '';
|
||||
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
return {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters: {},
|
||||
model,
|
||||
tools: [...primaryAgent.tools],
|
||||
};
|
||||
}
|
||||
|
||||
// Extract ephemeral agent options from conversation if present
|
||||
const ephemeralAgent = rest.ephemeralAgent;
|
||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id;
|
||||
|
||||
// Check model spec for MCP servers
|
||||
const modelSpecs = req.config?.modelSpecs?.list;
|
||||
let modelSpec = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.find((s) => s.name === spec) || null;
|
||||
}
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {string[]} */
|
||||
const tools = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
// Build model_parameters from conversation fields
|
||||
const model_parameters = {};
|
||||
const paramKeys = [
|
||||
'temperature',
|
||||
'top_p',
|
||||
'topP',
|
||||
'topK',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'maxOutputTokens',
|
||||
'maxTokens',
|
||||
'max_tokens',
|
||||
];
|
||||
|
||||
for (const key of paramKeys) {
|
||||
if (rest[key] != null) {
|
||||
model_parameters[key] = rest[key];
|
||||
}
|
||||
}
|
||||
|
||||
// Get endpoint config for modelDisplayLabel fallback
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// For ephemeral agents, use modelLabel if provided, then model spec's label,
|
||||
// then modelDisplayLabel from endpoint config, otherwise empty string to show model name
|
||||
const sender = rest.modelLabel ?? modelSpec?.label ?? endpointConfig?.modelDisplayLabel ?? '';
|
||||
|
||||
/** Encoded ephemeral agent ID with endpoint, model, sender, and index=1 to distinguish from primary */
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
const result = {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
ADDED_AGENT_ID,
|
||||
loadAddedAgent,
|
||||
setGetAgent,
|
||||
};
|
||||
|
|
@ -113,6 +113,8 @@ const tokenValues = Object.assign(
|
|||
'gpt-4o-2024-05-13': { prompt: 5, completion: 15 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-5': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5.1': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5.2': { prompt: 1.75, completion: 14 },
|
||||
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
||||
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
||||
'gpt-5-pro': { prompt: 15, completion: 120 },
|
||||
|
|
@ -158,7 +160,9 @@ const tokenValues = Object.assign(
|
|||
'gemini-2.5-flash': { prompt: 0.3, completion: 2.5 },
|
||||
'gemini-2.5-flash-lite': { prompt: 0.1, completion: 0.4 },
|
||||
'gemini-2.5-pro': { prompt: 1.25, completion: 10 },
|
||||
'gemini-2.5-flash-image': { prompt: 0.15, completion: 30 },
|
||||
'gemini-3': { prompt: 2, completion: 12 },
|
||||
'gemini-3-pro-image': { prompt: 2, completion: 120 },
|
||||
'gemini-pro-vision': { prompt: 0.5, completion: 1.5 },
|
||||
grok: { prompt: 2.0, completion: 10.0 }, // Base pattern defaults to grok-2
|
||||
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
||||
|
|
|
|||
|
|
@ -36,6 +36,19 @@ describe('getValueKey', () => {
|
|||
expect(getValueKey('gpt-5-0130')).toBe('gpt-5');
|
||||
});
|
||||
|
||||
it('should return "gpt-5.1" for model name containing "gpt-5.1"', () => {
|
||||
expect(getValueKey('gpt-5.1')).toBe('gpt-5.1');
|
||||
expect(getValueKey('gpt-5.1-chat')).toBe('gpt-5.1');
|
||||
expect(getValueKey('gpt-5.1-codex')).toBe('gpt-5.1');
|
||||
expect(getValueKey('openai/gpt-5.1')).toBe('gpt-5.1');
|
||||
});
|
||||
|
||||
it('should return "gpt-5.2" for model name containing "gpt-5.2"', () => {
|
||||
expect(getValueKey('gpt-5.2')).toBe('gpt-5.2');
|
||||
expect(getValueKey('gpt-5.2-chat')).toBe('gpt-5.2');
|
||||
expect(getValueKey('openai/gpt-5.2')).toBe('gpt-5.2');
|
||||
});
|
||||
|
||||
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
||||
expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106');
|
||||
expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
||||
|
|
@ -311,6 +324,34 @@ describe('getMultiplier', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5.1', () => {
|
||||
expect(getMultiplier({ model: 'gpt-5.1', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.1'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5.1', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5.1'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5.1', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.1'].prompt,
|
||||
);
|
||||
expect(tokenValues['gpt-5.1'].prompt).toBe(1.25);
|
||||
expect(tokenValues['gpt-5.1'].completion).toBe(10);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5.2', () => {
|
||||
expect(getMultiplier({ model: 'gpt-5.2', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.2'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5.2', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5.2'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5.2', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.2'].prompt,
|
||||
);
|
||||
expect(tokenValues['gpt-5.2'].prompt).toBe(1.75);
|
||||
expect(tokenValues['gpt-5.2'].completion).toBe(14);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o', () => {
|
||||
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.2-rc1",
|
||||
"version": "v0.8.2-rc2",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
"server-dev": "echo 'please run this from the root directory'",
|
||||
"test": "cross-env NODE_ENV=test jest",
|
||||
"b:test": "NODE_ENV=test bun jest",
|
||||
"test:ci": "jest --ci",
|
||||
"test:ci": "jest --ci --logHeapUsage",
|
||||
"add-balance": "node ./add-balance.js",
|
||||
"list-balances": "node ./list-balances.js",
|
||||
"user-stats": "node ./user-stats.js",
|
||||
|
|
@ -34,20 +34,23 @@
|
|||
},
|
||||
"homepage": "https://librechat.ai",
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.71.0",
|
||||
"@anthropic-ai/vertex-sdk": "^0.14.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.941.0",
|
||||
"@aws-sdk/client-s3": "^3.758.0",
|
||||
"@aws-sdk/s3-request-presigner": "^3.758.0",
|
||||
"@azure/identity": "^4.7.0",
|
||||
"@azure/search-documents": "^12.0.0",
|
||||
"@azure/storage-blob": "^12.27.0",
|
||||
"@google/genai": "^1.19.0",
|
||||
"@googleapis/youtube": "^20.0.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/core": "^0.3.79",
|
||||
"@librechat/agents": "^3.0.52",
|
||||
"@langchain/core": "^0.3.80",
|
||||
"@librechat/agents": "^3.0.66",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@modelcontextprotocol/sdk": "^1.24.3",
|
||||
"@modelcontextprotocol/sdk": "^1.25.2",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"axios": "^1.12.1",
|
||||
|
|
@ -60,7 +63,7 @@
|
|||
"dedent": "^1.5.3",
|
||||
"dotenv": "^16.0.3",
|
||||
"eventsource": "^3.0.2",
|
||||
"express": "^5.1.0",
|
||||
"express": "^5.2.1",
|
||||
"express-mongo-sanitize": "^2.2.0",
|
||||
"express-rate-limit": "^8.2.1",
|
||||
"express-session": "^1.18.2",
|
||||
|
|
@ -79,6 +82,7 @@
|
|||
"klona": "^2.0.6",
|
||||
"librechat-data-provider": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"mathjs": "^15.1.0",
|
||||
"meilisearch": "^0.38.0",
|
||||
"memorystore": "^1.6.7",
|
||||
"mime": "^3.0.0",
|
||||
|
|
|
|||
|
|
@ -350,9 +350,6 @@ function disposeClient(client) {
|
|||
if (client.agentConfigs) {
|
||||
client.agentConfigs = null;
|
||||
}
|
||||
if (client.agentIdMap) {
|
||||
client.agentIdMap = null;
|
||||
}
|
||||
if (client.artifactPromises) {
|
||||
client.artifactPromises = null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,14 +66,17 @@ const resetPasswordController = async (req, res) => {
|
|||
};
|
||||
|
||||
const refreshController = async (req, res) => {
|
||||
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
||||
const token_provider = req.headers.cookie
|
||||
? cookies.parse(req.headers.cookie).token_provider
|
||||
: null;
|
||||
if (!refreshToken) {
|
||||
return res.status(200).send('Refresh token not provided');
|
||||
}
|
||||
if (token_provider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) === true) {
|
||||
const parsedCookies = req.headers.cookie ? cookies.parse(req.headers.cookie) : {};
|
||||
const token_provider = parsedCookies.token_provider;
|
||||
|
||||
if (token_provider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
/** For OpenID users, read refresh token from session to avoid large cookie issues */
|
||||
const refreshToken = req.session?.openidTokens?.refreshToken || parsedCookies.refreshToken;
|
||||
|
||||
if (!refreshToken) {
|
||||
return res.status(200).send('Refresh token not provided');
|
||||
}
|
||||
|
||||
try {
|
||||
const openIdConfig = getOpenIdConfig();
|
||||
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
|
||||
|
|
@ -110,7 +113,7 @@ const refreshController = async (req, res) => {
|
|||
);
|
||||
}
|
||||
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString(), refreshToken);
|
||||
const token = setOpenIDAuthTokens(tokenset, req, res, user._id.toString(), refreshToken);
|
||||
|
||||
user.federatedTokens = {
|
||||
access_token: tokenset.access_token,
|
||||
|
|
@ -125,6 +128,13 @@ const refreshController = async (req, res) => {
|
|||
return res.status(403).send('Invalid OpenID refresh token');
|
||||
}
|
||||
}
|
||||
|
||||
/** For non-OpenID users, read refresh token from cookies */
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
if (!refreshToken) {
|
||||
return res.status(200).send('Refresh token not provided');
|
||||
}
|
||||
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret -backupCodes');
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { sendEvent, GenerationJobManager } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -144,17 +144,38 @@ function checkIfLastAgent(last_agent_id, langgraph_node) {
|
|||
return langgraph_node?.endsWith(last_agent_id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to emit events either to res (standard mode) or to job emitter (resumable mode).
|
||||
* @param {ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} eventData - The event data to send
|
||||
*/
|
||||
function emitEvent(res, streamId, eventData) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default handlers for stream events.
|
||||
* @param {Object} options - The options object.
|
||||
* @param {ServerResponse} options.res - The options object.
|
||||
* @param {ContentAggregator} options.aggregateContent - The options object.
|
||||
* @param {ServerResponse} options.res - The server response object.
|
||||
* @param {ContentAggregator} options.aggregateContent - Content aggregator function.
|
||||
* @param {ToolEndCallback} options.toolEndCallback - Callback to use when tool ends.
|
||||
* @param {Array<UsageMetadata>} options.collectedUsage - The list of collected usage metadata.
|
||||
* @param {string | null} [options.streamId] - The stream ID for resumable mode, or null for standard mode.
|
||||
* @returns {Record<string, t.EventHandler>} The default handlers.
|
||||
* @throws {Error} If the request is not found.
|
||||
*/
|
||||
function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedUsage }) {
|
||||
function getDefaultHandlers({
|
||||
res,
|
||||
aggregateContent,
|
||||
toolEndCallback,
|
||||
collectedUsage,
|
||||
streamId = null,
|
||||
}) {
|
||||
if (!res || !aggregateContent) {
|
||||
throw new Error(
|
||||
`[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`,
|
||||
|
|
@ -173,16 +194,16 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.stepDetails.type === StepTypes.TOOL_CALLS) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else {
|
||||
const agentName = metadata?.name ?? 'Agent';
|
||||
const isToolCall = data?.stepDetails.type === StepTypes.TOOL_CALLS;
|
||||
const action = isToolCall ? 'performing a task...' : 'thinking...';
|
||||
sendEvent(res, {
|
||||
emitEvent(res, streamId, {
|
||||
event: 'on_agent_update',
|
||||
data: {
|
||||
runId: metadata?.run_id,
|
||||
|
|
@ -202,11 +223,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.delta.type === StepTypes.TOOL_CALLS) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -220,11 +241,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.result != null) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -238,9 +259,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -254,9 +275,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -266,15 +287,30 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
return handlers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to write attachment events either to res or to job emitter.
|
||||
* @param {ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} attachment - The attachment data
|
||||
*/
|
||||
function writeAttachment(res, streamId, attachment) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
|
||||
} else {
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Promise<MongoFile | { filename: string; filepath: string; expires: number;} | null>[]} params.artifactPromises
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode, or null for standard mode.
|
||||
* @returns {ToolEndCallback} The tool end callback.
|
||||
*/
|
||||
function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
function createToolEndCallback({ req, res, artifactPromises, streamId = null }) {
|
||||
/**
|
||||
* @type {ToolEndCallback}
|
||||
*/
|
||||
|
|
@ -302,10 +338,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
if (!attachment) {
|
||||
return null;
|
||||
}
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing file citations:', error);
|
||||
|
|
@ -314,8 +350,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
);
|
||||
}
|
||||
|
||||
// TODO: a lot of duplicated code in createToolEndCallback
|
||||
// we should refactor this to use a helper function in a follow-up PR
|
||||
if (output.artifact[Tools.ui_resources]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
|
|
@ -326,10 +360,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -348,10 +382,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
[Tools.web_search]: { ...output.artifact[Tools.web_search] },
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -388,7 +422,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
toolCallId: output.tool_call_id,
|
||||
conversationId: metadata.thread_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
|
|
@ -396,7 +430,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
writeAttachment(res, streamId, fileMetadata);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -435,7 +469,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
session_id: output.artifact.session_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
|
|
@ -443,7 +477,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
writeAttachment(res, streamId, fileMetadata);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing code output:', error);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const {
|
|||
getBalanceConfig,
|
||||
getProviderConfig,
|
||||
memoryInstructions,
|
||||
GenerationJobManager,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
filterMalformedContentParts,
|
||||
|
|
@ -36,14 +37,13 @@ const {
|
|||
EModelEndpoint,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
isEphemeralAgentId,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
|
@ -95,59 +95,137 @@ function logToolError(graph, error, toolId) {
|
|||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies agent labeling to conversation history when multi-agent patterns are detected.
|
||||
* Labels content parts by their originating agent to prevent identity confusion.
|
||||
*
|
||||
* @param {TMessage[]} orderedMessages - The ordered conversation messages
|
||||
* @param {Agent} primaryAgent - The primary agent configuration
|
||||
* @param {Map<string, Agent>} agentConfigs - Map of additional agent configurations
|
||||
* @returns {TMessage[]} Messages with agent labels applied where appropriate
|
||||
*/
|
||||
function applyAgentLabelsToHistory(orderedMessages, primaryAgent, agentConfigs) {
|
||||
const shouldLabelByAgent = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
/** Regex pattern to match agent ID suffix (____N) */
|
||||
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
|
||||
|
||||
if (!shouldLabelByAgent) {
|
||||
return orderedMessages;
|
||||
/**
|
||||
* Finds the primary agent ID within a set of agent IDs.
|
||||
* Primary = no suffix (____N) or lowest suffix number.
|
||||
* @param {Set<string>} agentIds
|
||||
* @returns {string | null}
|
||||
*/
|
||||
function findPrimaryAgentId(agentIds) {
|
||||
let primaryAgentId = null;
|
||||
let lowestSuffixIndex = Infinity;
|
||||
|
||||
for (const agentId of agentIds) {
|
||||
const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN);
|
||||
if (!suffixMatch) {
|
||||
return agentId;
|
||||
}
|
||||
const suffixIndex = parseInt(suffixMatch[1], 10);
|
||||
if (suffixIndex < lowestSuffixIndex) {
|
||||
lowestSuffixIndex = suffixIndex;
|
||||
primaryAgentId = agentId;
|
||||
}
|
||||
}
|
||||
|
||||
const processedMessages = [];
|
||||
return primaryAgentId;
|
||||
}
|
||||
|
||||
for (let i = 0; i < orderedMessages.length; i++) {
|
||||
const message = orderedMessages[i];
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
const agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
/**
|
||||
* Creates a mapMethod for getMessagesForConversation that processes agent content.
|
||||
* - Strips agentId/groupId metadata from all content
|
||||
* - For parallel agents (addedConvo with groupId): filters each group to its primary agent
|
||||
* - For handoffs (agentId without groupId): keeps all content from all agents
|
||||
* - For multi-agent: applies agent labels to content
|
||||
*
|
||||
* The key distinction:
|
||||
* - Parallel execution (addedConvo): Parts have both agentId AND groupId
|
||||
* - Handoffs: Parts only have agentId, no groupId
|
||||
*
|
||||
* @param {Agent} primaryAgent - Primary agent configuration
|
||||
* @param {Map<string, Agent>} [agentConfigs] - Additional agent configurations
|
||||
* @returns {(message: TMessage) => TMessage} Map method for processing messages
|
||||
*/
|
||||
function createMultiAgentMapper(primaryAgent, agentConfigs) {
|
||||
const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
|
||||
/** @type {Record<string, string> | null} */
|
||||
let agentNames = null;
|
||||
if (hasMultipleAgents) {
|
||||
agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
if (agentConfigs) {
|
||||
for (const [agentId, agentConfig] of agentConfigs.entries()) {
|
||||
agentNames[agentId] = agentConfig.name || agentConfig.id;
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
!message.isCreatedByUser &&
|
||||
message.metadata?.agentIdMap &&
|
||||
Array.isArray(message.content)
|
||||
) {
|
||||
try {
|
||||
const labeledContent = labelContentByAgent(
|
||||
message.content,
|
||||
message.metadata.agentIdMap,
|
||||
agentNames,
|
||||
);
|
||||
|
||||
processedMessages.push({ ...message, content: labeledContent });
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error applying agent labels to message:', error);
|
||||
processedMessages.push(message);
|
||||
}
|
||||
} else {
|
||||
processedMessages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
return processedMessages;
|
||||
return (message) => {
|
||||
if (message.isCreatedByUser || !Array.isArray(message.content)) {
|
||||
return message;
|
||||
}
|
||||
|
||||
// Check for metadata
|
||||
const hasAgentMetadata = message.content.some((part) => part?.agentId || part?.groupId != null);
|
||||
if (!hasAgentMetadata) {
|
||||
return message;
|
||||
}
|
||||
|
||||
try {
|
||||
// Build a map of groupId -> Set of agentIds, to find primary per group
|
||||
/** @type {Map<number, Set<string>>} */
|
||||
const groupAgentMap = new Map();
|
||||
|
||||
for (const part of message.content) {
|
||||
const groupId = part?.groupId;
|
||||
const agentId = part?.agentId;
|
||||
if (groupId != null && agentId) {
|
||||
if (!groupAgentMap.has(groupId)) {
|
||||
groupAgentMap.set(groupId, new Set());
|
||||
}
|
||||
groupAgentMap.get(groupId).add(agentId);
|
||||
}
|
||||
}
|
||||
|
||||
// For each group, find the primary agent
|
||||
/** @type {Map<number, string>} */
|
||||
const groupPrimaryMap = new Map();
|
||||
for (const [groupId, agentIds] of groupAgentMap) {
|
||||
const primary = findPrimaryAgentId(agentIds);
|
||||
if (primary) {
|
||||
groupPrimaryMap.set(groupId, primary);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {Array<TMessageContentParts>} */
|
||||
const filteredContent = [];
|
||||
/** @type {Record<number, string>} */
|
||||
const agentIdMap = {};
|
||||
|
||||
for (const part of message.content) {
|
||||
const agentId = part?.agentId;
|
||||
const groupId = part?.groupId;
|
||||
|
||||
// Filtering logic:
|
||||
// - No groupId (handoffs): always include
|
||||
// - Has groupId (parallel): only include if it's the primary for that group
|
||||
const isParallelPart = groupId != null;
|
||||
const groupPrimary = isParallelPart ? groupPrimaryMap.get(groupId) : null;
|
||||
const shouldInclude = !isParallelPart || !agentId || agentId === groupPrimary;
|
||||
|
||||
if (shouldInclude) {
|
||||
const newIndex = filteredContent.length;
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = part;
|
||||
filteredContent.push(cleanPart);
|
||||
if (agentId && hasMultipleAgents) {
|
||||
agentIdMap[newIndex] = agentId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const finalContent =
|
||||
Object.keys(agentIdMap).length > 0 && agentNames
|
||||
? labelContentByAgent(filteredContent, agentIdMap, agentNames)
|
||||
: filteredContent;
|
||||
|
||||
return { ...message, content: finalContent };
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error processing multi-agent message:', error);
|
||||
return message;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
|
|
@ -199,8 +277,6 @@ class AgentClient extends BaseClient {
|
|||
this.indexTokenCountMap = {};
|
||||
/** @type {(messages: BaseMessage[]) => Promise<void>} */
|
||||
this.processMemory;
|
||||
/** @type {Record<number, string> | null} */
|
||||
this.agentIdMap = null;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -210,9 +286,7 @@ class AgentClient extends BaseClient {
|
|||
return this.contentParts;
|
||||
}
|
||||
|
||||
setOptions(options) {
|
||||
logger.info('[api/server/controllers/agents/client.js] setOptions', options);
|
||||
}
|
||||
setOptions(_options) {}
|
||||
|
||||
/**
|
||||
* `AgentClient` is not opinionated about vision requests, so we don't do anything here
|
||||
|
|
@ -287,18 +361,15 @@ class AgentClient extends BaseClient {
|
|||
{ instructions = null, additional_instructions = null },
|
||||
opts,
|
||||
) {
|
||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||
/** Always pass mapMethod; getMessagesForConversation applies it only to messages with addedConvo flag */
|
||||
const orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
summary: this.shouldSummarize,
|
||||
mapMethod: createMultiAgentMapper(this.options.agent, this.agentConfigs),
|
||||
mapCondition: (message) => message.addedConvo === true,
|
||||
});
|
||||
|
||||
orderedMessages = applyAgentLabelsToHistory(
|
||||
orderedMessages,
|
||||
this.options.agent,
|
||||
this.agentConfigs,
|
||||
);
|
||||
|
||||
let payload;
|
||||
/** @type {number | undefined} */
|
||||
let promptTokens;
|
||||
|
|
@ -550,10 +621,9 @@ class AgentClient extends BaseClient {
|
|||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
endpointOption: {
|
||||
endpoint:
|
||||
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
endpoint: !isEphemeralAgentId(prelimAgent.id)
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -593,10 +663,12 @@ class AgentClient extends BaseClient {
|
|||
const userId = this.options.req.user.id + '';
|
||||
const messageId = this.responseMessageId + '';
|
||||
const conversationId = this.conversationId + '';
|
||||
const streamId = this.options.req?._resumableStreamId || null;
|
||||
const [withoutKeys, processMemory] = await createMemoryProcessor({
|
||||
userId,
|
||||
config,
|
||||
messageId,
|
||||
streamId,
|
||||
conversationId,
|
||||
memoryMethods: {
|
||||
setMemory: db.setMemory,
|
||||
|
|
@ -604,6 +676,7 @@ class AgentClient extends BaseClient {
|
|||
getFormattedMemories: db.getFormattedMemories,
|
||||
},
|
||||
res: this.options.res,
|
||||
user: createSafeUser(this.options.req.user),
|
||||
});
|
||||
|
||||
this.processMemory = processMemory;
|
||||
|
|
@ -690,9 +763,7 @@ class AgentClient extends BaseClient {
|
|||
});
|
||||
|
||||
const completion = filterMalformedContentParts(this.contentParts);
|
||||
const metadata = this.agentIdMap ? { agentIdMap: this.agentIdMap } : undefined;
|
||||
|
||||
return { completion, metadata };
|
||||
return { completion };
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -713,10 +784,16 @@ class AgentClient extends BaseClient {
|
|||
if (!collectedUsage || !collectedUsage.length) {
|
||||
return;
|
||||
}
|
||||
// Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
|
||||
const firstUsage = collectedUsage[0];
|
||||
const input_tokens =
|
||||
(collectedUsage[0]?.input_tokens || 0) +
|
||||
(Number(collectedUsage[0]?.input_token_details?.cache_creation) || 0) +
|
||||
(Number(collectedUsage[0]?.input_token_details?.cache_read) || 0);
|
||||
(firstUsage?.input_tokens || 0) +
|
||||
(Number(firstUsage?.input_token_details?.cache_creation) ||
|
||||
Number(firstUsage?.cache_creation_input_tokens) ||
|
||||
0) +
|
||||
(Number(firstUsage?.input_token_details?.cache_read) ||
|
||||
Number(firstUsage?.cache_read_input_tokens) ||
|
||||
0);
|
||||
|
||||
let output_tokens = 0;
|
||||
let previousTokens = input_tokens; // Start with original input
|
||||
|
|
@ -726,8 +803,13 @@ class AgentClient extends BaseClient {
|
|||
continue;
|
||||
}
|
||||
|
||||
const cache_creation = Number(usage.input_token_details?.cache_creation) || 0;
|
||||
const cache_read = Number(usage.input_token_details?.cache_read) || 0;
|
||||
// Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
|
||||
const cache_creation =
|
||||
Number(usage.input_token_details?.cache_creation) ||
|
||||
Number(usage.cache_creation_input_tokens) ||
|
||||
0;
|
||||
const cache_read =
|
||||
Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
|
||||
|
||||
const txMetadata = {
|
||||
context,
|
||||
|
|
@ -888,12 +970,10 @@ class AgentClient extends BaseClient {
|
|||
*/
|
||||
const runAgents = async (messages) => {
|
||||
const agents = [this.options.agent];
|
||||
if (
|
||||
this.agentConfigs &&
|
||||
this.agentConfigs.size > 0 &&
|
||||
((this.options.agent.edges?.length ?? 0) > 0 ||
|
||||
(await checkCapability(this.options.req, AgentCapabilities.chain)))
|
||||
) {
|
||||
// Include additional agents when:
|
||||
// - agentConfigs has agents (from addedConvo parallel execution or agent handoffs)
|
||||
// - Agents without incoming edges become start nodes and run in parallel automatically
|
||||
if (this.agentConfigs && this.agentConfigs.size > 0) {
|
||||
agents.push(...this.agentConfigs.values());
|
||||
}
|
||||
|
||||
|
|
@ -953,6 +1033,12 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
this.run = run;
|
||||
|
||||
const streamId = this.options.req?._resumableStreamId;
|
||||
if (streamId && run.Graph) {
|
||||
GenerationJobManager.setGraph(streamId, run.Graph);
|
||||
}
|
||||
|
||||
if (userMCPAuthMap != null) {
|
||||
config.configurable.userMCPAuthMap = userMCPAuthMap;
|
||||
}
|
||||
|
|
@ -983,24 +1069,6 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
/** Capture agent ID map if we have edges or multiple agents */
|
||||
const shouldStoreAgentMap =
|
||||
(this.options.agent.edges?.length ?? 0) > 0 || (this.agentConfigs?.size ?? 0) > 0;
|
||||
if (shouldStoreAgentMap && run?.Graph) {
|
||||
const contentPartAgentMap = run.Graph.getContentPartAgentMap();
|
||||
if (contentPartAgentMap && contentPartAgentMap.size > 0) {
|
||||
this.agentIdMap = Object.fromEntries(contentPartAgentMap);
|
||||
logger.debug('[AgentClient] Captured agent ID map:', {
|
||||
totalParts: this.contentParts.length,
|
||||
mappedParts: Object.keys(this.agentIdMap).length,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error capturing agent ID map:', error);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
|
|
@ -1052,6 +1120,14 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
||||
const { req, agent } = this.options;
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
logger.debug(
|
||||
`[api/server/controllers/agents/client.js #titleConvo] Skipping title generation for temporary conversation`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const appConfig = req.config;
|
||||
let endpoint = agent.endpoint;
|
||||
|
||||
|
|
|
|||
|
|
@ -336,6 +336,25 @@ describe('AgentClient - titleConvo', () => {
|
|||
expect(client.recordCollectedUsage).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should skip title generation for temporary chats', async () => {
|
||||
// Set isTemporary to true
|
||||
mockReq.body.isTemporary = true;
|
||||
|
||||
const text = 'Test temporary chat';
|
||||
const abortController = new AbortController();
|
||||
|
||||
const result = await client.titleConvo({ text, abortController });
|
||||
|
||||
// Should return undefined without generating title
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// generateTitle should NOT have been called
|
||||
expect(mockRun.generateTitle).not.toHaveBeenCalled();
|
||||
|
||||
// recordCollectedUsage should NOT have been called
|
||||
expect(client.recordCollectedUsage).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should skip title generation when titleConvo is false in all config', async () => {
|
||||
// Set titleConvo to false in "all" config
|
||||
mockReq.config = {
|
||||
|
|
@ -1611,4 +1630,223 @@ describe('AgentClient - titleConvo', () => {
|
|||
expect(mockProcessMemory).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessagesForConversation - mapMethod and mapCondition', () => {
|
||||
const createMessage = (id, parentId, text, extras = {}) => ({
|
||||
messageId: id,
|
||||
parentMessageId: parentId,
|
||||
text,
|
||||
isCreatedByUser: false,
|
||||
...extras,
|
||||
});
|
||||
|
||||
it('should apply mapMethod to all messages when mapCondition is not provided', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First message'),
|
||||
createMessage('msg-2', 'msg-1', 'Second message'),
|
||||
createMessage('msg-3', 'msg-2', 'Third message'),
|
||||
];
|
||||
|
||||
const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true }));
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-3',
|
||||
mapMethod,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(3);
|
||||
expect(mapMethod).toHaveBeenCalledTimes(3);
|
||||
result.forEach((msg) => {
|
||||
expect(msg.mapped).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should apply mapMethod only to messages where mapCondition returns true', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First message', { addedConvo: false }),
|
||||
createMessage('msg-2', 'msg-1', 'Second message', { addedConvo: true }),
|
||||
createMessage('msg-3', 'msg-2', 'Third message', { addedConvo: true }),
|
||||
createMessage('msg-4', 'msg-3', 'Fourth message', { addedConvo: false }),
|
||||
];
|
||||
|
||||
const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true }));
|
||||
const mapCondition = (msg) => msg.addedConvo === true;
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-4',
|
||||
mapMethod,
|
||||
mapCondition,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(4);
|
||||
expect(mapMethod).toHaveBeenCalledTimes(2);
|
||||
|
||||
expect(result[0].mapped).toBeUndefined();
|
||||
expect(result[1].mapped).toBe(true);
|
||||
expect(result[2].mapped).toBe(true);
|
||||
expect(result[3].mapped).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should not apply mapMethod when mapCondition returns false for all messages', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First message', { addedConvo: false }),
|
||||
createMessage('msg-2', 'msg-1', 'Second message', { addedConvo: false }),
|
||||
];
|
||||
|
||||
const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true }));
|
||||
const mapCondition = (msg) => msg.addedConvo === true;
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-2',
|
||||
mapMethod,
|
||||
mapCondition,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(mapMethod).not.toHaveBeenCalled();
|
||||
result.forEach((msg) => {
|
||||
expect(msg.mapped).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
it('should not call mapMethod when mapMethod is null', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First message'),
|
||||
createMessage('msg-2', 'msg-1', 'Second message'),
|
||||
];
|
||||
|
||||
const mapCondition = jest.fn(() => true);
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-2',
|
||||
mapMethod: null,
|
||||
mapCondition,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(mapCondition).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mapCondition with complex logic', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'User message', { isCreatedByUser: true, addedConvo: true }),
|
||||
createMessage('msg-2', 'msg-1', 'Assistant response', { addedConvo: true }),
|
||||
createMessage('msg-3', 'msg-2', 'Another user message', { isCreatedByUser: true }),
|
||||
createMessage('msg-4', 'msg-3', 'Another response', { addedConvo: true }),
|
||||
];
|
||||
|
||||
const mapMethod = jest.fn((msg) => ({ ...msg, processed: true }));
|
||||
const mapCondition = (msg) => msg.addedConvo === true && !msg.isCreatedByUser;
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-4',
|
||||
mapMethod,
|
||||
mapCondition,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(4);
|
||||
expect(mapMethod).toHaveBeenCalledTimes(2);
|
||||
|
||||
expect(result[0].processed).toBeUndefined();
|
||||
expect(result[1].processed).toBe(true);
|
||||
expect(result[2].processed).toBeUndefined();
|
||||
expect(result[3].processed).toBe(true);
|
||||
});
|
||||
|
||||
it('should preserve message order after applying mapMethod with mapCondition', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First', { addedConvo: true }),
|
||||
createMessage('msg-2', 'msg-1', 'Second', { addedConvo: false }),
|
||||
createMessage('msg-3', 'msg-2', 'Third', { addedConvo: true }),
|
||||
];
|
||||
|
||||
const mapMethod = (msg) => ({ ...msg, text: `[MAPPED] ${msg.text}` });
|
||||
const mapCondition = (msg) => msg.addedConvo === true;
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-3',
|
||||
mapMethod,
|
||||
mapCondition,
|
||||
});
|
||||
|
||||
expect(result[0].text).toBe('[MAPPED] First');
|
||||
expect(result[1].text).toBe('Second');
|
||||
expect(result[2].text).toBe('[MAPPED] Third');
|
||||
});
|
||||
|
||||
it('should work with summary option alongside mapMethod and mapCondition', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First', { addedConvo: false }),
|
||||
createMessage('msg-2', 'msg-1', 'Second', {
|
||||
summary: 'Summary of conversation',
|
||||
addedConvo: true,
|
||||
}),
|
||||
createMessage('msg-3', 'msg-2', 'Third', { addedConvo: true }),
|
||||
createMessage('msg-4', 'msg-3', 'Fourth', { addedConvo: false }),
|
||||
];
|
||||
|
||||
const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true }));
|
||||
const mapCondition = (msg) => msg.addedConvo === true;
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-4',
|
||||
mapMethod,
|
||||
mapCondition,
|
||||
summary: true,
|
||||
});
|
||||
|
||||
/** Traversal stops at msg-2 (has summary), so we get msg-4 -> msg-3 -> msg-2 */
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0].text).toBe('Summary of conversation');
|
||||
expect(result[0].role).toBe('system');
|
||||
expect(result[0].mapped).toBe(true);
|
||||
expect(result[1].mapped).toBe(true);
|
||||
expect(result[2].mapped).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle empty messages array', () => {
|
||||
const mapMethod = jest.fn();
|
||||
const mapCondition = jest.fn();
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages: [],
|
||||
parentMessageId: 'msg-1',
|
||||
mapMethod,
|
||||
mapCondition,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(0);
|
||||
expect(mapMethod).not.toHaveBeenCalled();
|
||||
expect(mapCondition).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle undefined mapCondition explicitly', () => {
|
||||
const messages = [
|
||||
createMessage('msg-1', null, 'First'),
|
||||
createMessage('msg-2', 'msg-1', 'Second'),
|
||||
];
|
||||
|
||||
const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true }));
|
||||
|
||||
const result = AgentClient.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId: 'msg-2',
|
||||
mapMethod,
|
||||
mapCondition: undefined,
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(mapMethod).toHaveBeenCalledTimes(2);
|
||||
result.forEach((msg) => {
|
||||
expect(msg.mapped).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { Constants, ViolationTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
sendEvent,
|
||||
getViolationInfo,
|
||||
GenerationJobManager,
|
||||
decrementPendingRequest,
|
||||
sanitizeFileForTransmit,
|
||||
sanitizeMessageForTransmit,
|
||||
checkAndIncrementPendingRequest,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
function createCloseHandler(abortController) {
|
||||
|
|
@ -31,12 +32,16 @@ function createCloseHandler(abortController) {
|
|||
};
|
||||
}
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
/**
|
||||
* Resumable Agent Controller - Generation runs independently of HTTP connection.
|
||||
* Returns streamId immediately, client subscribes separately via SSE.
|
||||
*/
|
||||
const ResumableAgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
conversationId: reqConversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
|
|
@ -44,18 +49,370 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let sender;
|
||||
let abortKey;
|
||||
const userId = req.user.id;
|
||||
|
||||
const { allowed, pendingRequests, limit } = await checkAndIncrementPendingRequest(userId);
|
||||
if (!allowed) {
|
||||
const violationInfo = getViolationInfo(pendingRequests, limit);
|
||||
await logViolation(req, res, ViolationTypes.CONCURRENT, violationInfo, violationInfo.score);
|
||||
return res.status(429).json(violationInfo);
|
||||
}
|
||||
|
||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||
const conversationId =
|
||||
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
|
||||
const streamId = conversationId;
|
||||
|
||||
let client = null;
|
||||
|
||||
try {
|
||||
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
|
||||
req._resumableStreamId = streamId;
|
||||
|
||||
// Send JSON response IMMEDIATELY so client can connect to SSE stream
|
||||
// This is critical: tool loading (MCP OAuth) may emit events that the client needs to receive
|
||||
res.json({ streamId, conversationId, status: 'started' });
|
||||
|
||||
// Note: We no longer use res.on('close') to abort since we send JSON immediately.
|
||||
// The response closes normally after res.json(), which is not an abort condition.
|
||||
// Abort handling is done through GenerationJobManager via the SSE stream connection.
|
||||
|
||||
// Track if partial response was already saved to avoid duplicates
|
||||
let partialResponseSaved = false;
|
||||
|
||||
/**
|
||||
* Listen for all subscribers leaving to save partial response.
|
||||
* This ensures the response is saved to DB even if all clients disconnect
|
||||
* while generation continues.
|
||||
*
|
||||
* Note: The messageId used here falls back to `${userMessage.messageId}_` if the
|
||||
* actual response messageId isn't available yet. The final response save will
|
||||
* overwrite this with the complete response using the same messageId pattern.
|
||||
*/
|
||||
job.emitter.on('allSubscribersLeft', async (aggregatedContent) => {
|
||||
if (partialResponseSaved || !aggregatedContent || aggregatedContent.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
||||
if (!resumeState?.userMessage) {
|
||||
logger.debug('[ResumableAgentController] No user message to save partial response for');
|
||||
return;
|
||||
}
|
||||
|
||||
partialResponseSaved = true;
|
||||
const responseConversationId = resumeState.conversationId || conversationId;
|
||||
|
||||
try {
|
||||
const partialMessage = {
|
||||
messageId: resumeState.responseMessageId || `${resumeState.userMessage.messageId}_`,
|
||||
conversationId: responseConversationId,
|
||||
parentMessageId: resumeState.userMessage.messageId,
|
||||
sender: client?.sender ?? 'AI',
|
||||
content: aggregatedContent,
|
||||
unfinished: true,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
user: userId,
|
||||
endpoint: endpointOption.endpoint,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
if (req.body?.agent_id) {
|
||||
partialMessage.agent_id = req.body.agent_id;
|
||||
}
|
||||
|
||||
await saveMessage(req, partialMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - partial response on disconnect',
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`[ResumableAgentController] Saved partial response for ${streamId}, content parts: ${aggregatedContent.length}`,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[ResumableAgentController] Error saving partial response:', error);
|
||||
// Reset flag so we can try again if subscribers reconnect and leave again
|
||||
partialResponseSaved = false;
|
||||
}
|
||||
});
|
||||
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
// Use the job's abort controller signal - allows abort via GenerationJobManager.abortJob()
|
||||
signal: job.abortController.signal,
|
||||
});
|
||||
|
||||
if (job.abortController.signal.aborted) {
|
||||
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
|
||||
await decrementPendingRequest(userId);
|
||||
return;
|
||||
}
|
||||
|
||||
client = result.client;
|
||||
|
||||
if (client?.sender) {
|
||||
GenerationJobManager.updateMetadata(streamId, { sender: client.sender });
|
||||
}
|
||||
|
||||
// Store reference to client's contentParts - graph will be set when run is created
|
||||
if (client?.contentParts) {
|
||||
GenerationJobManager.setContentParts(streamId, client.contentParts);
|
||||
}
|
||||
|
||||
let userMessage;
|
||||
|
||||
const getReqData = (data = {}) => {
|
||||
if (data.userMessage) {
|
||||
userMessage = data.userMessage;
|
||||
}
|
||||
// conversationId is pre-generated, no need to update from callback
|
||||
};
|
||||
|
||||
// Start background generation - readyPromise resolves immediately now
|
||||
// (sync mechanism handles late subscribers)
|
||||
const startGeneration = async () => {
|
||||
try {
|
||||
// Short timeout as safety net - promise should already be resolved
|
||||
await Promise.race([job.readyPromise, new Promise((resolve) => setTimeout(resolve, 100))]);
|
||||
} catch (waitError) {
|
||||
logger.warn(
|
||||
`[ResumableAgentController] Error waiting for subscriber: ${waitError.message}`,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const onStart = (userMsg, respMsgId, _isNewConvo) => {
|
||||
userMessage = userMsg;
|
||||
|
||||
// Store userMessage and responseMessageId upfront for resume capability
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
responseMessageId: respMsgId,
|
||||
userMessage: {
|
||||
messageId: userMsg.messageId,
|
||||
parentMessageId: userMsg.parentMessageId,
|
||||
conversationId: userMsg.conversationId,
|
||||
text: userMsg.text,
|
||||
},
|
||||
});
|
||||
|
||||
GenerationJobManager.emitChunk(streamId, {
|
||||
created: true,
|
||||
message: userMessage,
|
||||
streamId,
|
||||
});
|
||||
};
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
onStart,
|
||||
getReqData,
|
||||
isContinued,
|
||||
isRegenerate,
|
||||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController: job.abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res: {
|
||||
write: () => true,
|
||||
end: () => {},
|
||||
headersSent: false,
|
||||
writableEnded: false,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const response = await client.sendMessage(text, messageOptions);
|
||||
|
||||
const messageId = response.messageId;
|
||||
const endpoint = endpointOption.endpoint;
|
||||
response.endpoint = endpoint;
|
||||
|
||||
const databasePromise = response.databasePromise;
|
||||
delete response.databasePromise;
|
||||
|
||||
const { conversation: convoData = {} } = await databasePromise;
|
||||
const conversation = { ...convoData };
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
if (req.body.files && client.options?.attachments) {
|
||||
userMessage.files = [];
|
||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
||||
for (const attachment of client.options.attachments) {
|
||||
if (messageFiles.has(attachment.file_id)) {
|
||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
||||
}
|
||||
}
|
||||
delete userMessage.image_urls;
|
||||
}
|
||||
|
||||
// Check abort state BEFORE calling completeJob (which triggers abort signal for cleanup)
|
||||
const wasAbortedBeforeComplete = job.abortController.signal.aborted;
|
||||
const isNewConvo = !reqConversationId || reqConversationId === 'new';
|
||||
const shouldGenerateTitle =
|
||||
addTitle &&
|
||||
parentMessageId === Constants.NO_PARENT &&
|
||||
isNewConvo &&
|
||||
!wasAbortedBeforeComplete;
|
||||
|
||||
// Save user message BEFORE sending final event to avoid race condition
|
||||
// where client refetch happens before database is updated
|
||||
if (!client.skipSaveUserMessage && userMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - resumable user message',
|
||||
});
|
||||
}
|
||||
|
||||
if (!wasAbortedBeforeComplete) {
|
||||
const finalEvent = {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: { ...response },
|
||||
};
|
||||
|
||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||
GenerationJobManager.completeJob(streamId);
|
||||
await decrementPendingRequest(userId);
|
||||
|
||||
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user: userId },
|
||||
{ context: 'api/server/controllers/agents/request.js - resumable response end' },
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const finalEvent = {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: { ...response, error: true },
|
||||
error: { message: 'Request was aborted' },
|
||||
};
|
||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||
GenerationJobManager.completeJob(streamId, 'Request aborted');
|
||||
await decrementPendingRequest(userId);
|
||||
}
|
||||
|
||||
if (shouldGenerateTitle) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
client,
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[ResumableAgentController] Error in title generation', err);
|
||||
})
|
||||
.finally(() => {
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Check if this was an abort (not a real error)
|
||||
const wasAborted = job.abortController.signal.aborted || error.message?.includes('abort');
|
||||
|
||||
if (wasAborted) {
|
||||
logger.debug(`[ResumableAgentController] Generation aborted for ${streamId}`);
|
||||
// abortJob already handled emitDone and completeJob
|
||||
} else {
|
||||
logger.error(`[ResumableAgentController] Generation error for ${streamId}:`, error);
|
||||
GenerationJobManager.emitError(streamId, error.message || 'Generation failed');
|
||||
GenerationJobManager.completeJob(streamId, error.message);
|
||||
}
|
||||
|
||||
await decrementPendingRequest(userId);
|
||||
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
|
||||
// Don't continue to title generation after error/abort
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Start generation and handle any unhandled errors
|
||||
startGeneration().catch(async (err) => {
|
||||
logger.error(
|
||||
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
|
||||
);
|
||||
GenerationJobManager.completeJob(streamId, err.message);
|
||||
await decrementPendingRequest(userId);
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[ResumableAgentController] Initialization error:', error);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({ error: error.message || 'Failed to start generation' });
|
||||
} else {
|
||||
// JSON already sent, emit error to stream so client can receive it
|
||||
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
|
||||
}
|
||||
GenerationJobManager.completeJob(streamId, error.message);
|
||||
await decrementPendingRequest(userId);
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Agent Controller - Routes to ResumableAgentController for all requests.
|
||||
* The legacy non-resumable path is kept below but no longer used by default.
|
||||
*/
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
return ResumableAgentController(req, res, next, initializeClient, addTitle);
|
||||
};
|
||||
|
||||
/**
|
||||
* Legacy Non-resumable Agent Controller - Uses GenerationJobManager for abort handling.
|
||||
* Response is streamed directly to client via res, but abort state is managed centrally.
|
||||
* @deprecated Use ResumableAgentController instead
|
||||
*/
|
||||
const _LegacyAgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId: reqConversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||
const conversationId =
|
||||
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
|
||||
const streamId = conversationId;
|
||||
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let userMessagePromise;
|
||||
let getAbortData;
|
||||
let client = null;
|
||||
let cleanupHandlers = [];
|
||||
|
||||
const newConvo = !conversationId;
|
||||
// Match the same logic used for conversationId generation above
|
||||
const isNewConvo = !reqConversationId || reqConversationId === 'new';
|
||||
const userId = req.user.id;
|
||||
|
||||
// Create handler to avoid capturing the entire parent scope
|
||||
|
|
@ -64,24 +421,20 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
// Update job metadata with prompt tokens for abort handling
|
||||
GenerationJobManager.updateMetadata(streamId, { promptTokens: data[key] });
|
||||
} else if (key === 'sender') {
|
||||
sender = data[key];
|
||||
} else if (key === 'abortKey') {
|
||||
abortKey = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
GenerationJobManager.updateMetadata(streamId, { sender: data[key] });
|
||||
}
|
||||
// conversationId is pre-generated, no need to update from callback
|
||||
}
|
||||
};
|
||||
|
||||
// Create a function to handle final cleanup
|
||||
const performCleanup = () => {
|
||||
const performCleanup = async () => {
|
||||
logger.debug('[AgentController] Performing cleanup');
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
for (const handler of cleanupHandlers) {
|
||||
|
|
@ -95,10 +448,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
}
|
||||
|
||||
// Clean up abort controller
|
||||
if (abortKey) {
|
||||
logger.debug('[AgentController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
// Complete the job in GenerationJobManager
|
||||
if (streamId) {
|
||||
logger.debug('[AgentController] Completing job in GenerationJobManager');
|
||||
await GenerationJobManager.completeJob(streamId);
|
||||
}
|
||||
|
||||
// Dispose client properly
|
||||
|
|
@ -110,11 +463,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
client = null;
|
||||
getReqData = null;
|
||||
userMessage = null;
|
||||
getAbortData = null;
|
||||
endpointOption.agent = null;
|
||||
endpointOption = null;
|
||||
cleanupHandlers = null;
|
||||
userMessagePromise = null;
|
||||
|
||||
// Clear request data map
|
||||
if (requestDataMap.has(req)) {
|
||||
|
|
@ -136,6 +485,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
};
|
||||
cleanupHandlers.push(removePrelimHandler);
|
||||
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
|
|
@ -143,6 +493,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
endpointOption,
|
||||
signal: prelimAbortController.signal,
|
||||
});
|
||||
|
||||
if (prelimAbortController.signal?.aborted) {
|
||||
prelimAbortController = null;
|
||||
throw new Error('Request was aborted before initialization could complete');
|
||||
|
|
@ -161,28 +512,24 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
// Store request data in WeakMap keyed by req object
|
||||
requestDataMap.set(req, { client });
|
||||
|
||||
// Use WeakRef to allow GC but still access content if it exists
|
||||
const contentRef = new WeakRef(client.contentParts || []);
|
||||
// Create job in GenerationJobManager for abort handling
|
||||
// streamId === conversationId (pre-generated above)
|
||||
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
|
||||
|
||||
// Minimize closure scope - only capture small primitives and WeakRef
|
||||
getAbortData = () => {
|
||||
// Dereference WeakRef each time
|
||||
const content = contentRef.deref();
|
||||
// Store endpoint metadata for abort handling
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
sender: client?.sender,
|
||||
});
|
||||
|
||||
return {
|
||||
sender,
|
||||
content: content || [],
|
||||
userMessage,
|
||||
promptTokens,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
};
|
||||
};
|
||||
// Store content parts reference for abort
|
||||
if (client?.contentParts) {
|
||||
GenerationJobManager.setContentParts(streamId, client.contentParts);
|
||||
}
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
const closeHandler = createCloseHandler(abortController);
|
||||
const closeHandler = createCloseHandler(job.abortController);
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
|
|
@ -192,6 +539,27 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* onStart callback - stores user message and response ID for abort handling
|
||||
*/
|
||||
const onStart = (userMsg, respMsgId, _isNewConvo) => {
|
||||
sendEvent(res, { message: userMsg, created: true });
|
||||
userMessage = userMsg;
|
||||
userMessageId = userMsg.messageId;
|
||||
responseMessageId = respMsgId;
|
||||
|
||||
// Store metadata for abort handling (conversationId is pre-generated)
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
responseMessageId: respMsgId,
|
||||
userMessage: {
|
||||
messageId: userMsg.messageId,
|
||||
parentMessageId: userMsg.parentMessageId,
|
||||
conversationId,
|
||||
text: userMsg.text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
onStart,
|
||||
|
|
@ -201,7 +569,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController,
|
||||
abortController: job.abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
|
|
@ -241,7 +609,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
|
||||
// Only send if not aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
if (!job.abortController.signal.aborted) {
|
||||
// Create a new response object with minimal copies
|
||||
const finalResponse = { ...response };
|
||||
|
||||
|
|
@ -292,7 +660,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
|
||||
// Add title if needed - extract minimal data
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && isNewConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
|
|
@ -315,7 +683,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
// Handle error without capturing much scope
|
||||
handleAbortError(res, req, error, {
|
||||
conversationId,
|
||||
sender,
|
||||
sender: client?.sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
|
||||
userMessageId,
|
||||
|
|
|
|||
|
|
@ -109,6 +109,10 @@ const createAgentHandler = async (req, res) => {
|
|||
const validatedData = agentCreateSchema.parse(req.body);
|
||||
const { tools = [], ...agentData } = removeNullishValues(validatedData);
|
||||
|
||||
if (agentData.model_parameters && typeof agentData.model_parameters === 'object') {
|
||||
agentData.model_parameters = removeNullishValues(agentData.model_parameters, true);
|
||||
}
|
||||
|
||||
const { id: userId } = req.user;
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
|
|
@ -259,6 +263,11 @@ const updateAgentHandler = async (req, res) => {
|
|||
// Preserve explicit null for avatar to allow resetting the avatar
|
||||
const { avatar: avatarField, _id, ...rest } = validatedData;
|
||||
const updateData = removeNullishValues(rest);
|
||||
|
||||
if (updateData.model_parameters && typeof updateData.model_parameters === 'object') {
|
||||
updateData.model_parameters = removeNullishValues(updateData.model_parameters, true);
|
||||
}
|
||||
|
||||
if (avatarField === null) {
|
||||
updateData.avatar = avatarField;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -357,6 +357,46 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
});
|
||||
});
|
||||
|
||||
test('should remove empty strings from model_parameters (Issue Fix)', async () => {
|
||||
// This tests the fix for empty strings being sent to API instead of being omitted
|
||||
// When a user clears a numeric field (like max_tokens), it should be removed, not sent as ""
|
||||
const dataWithEmptyModelParams = {
|
||||
provider: 'azureOpenAI',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Empty Model Params',
|
||||
model_parameters: {
|
||||
temperature: 0.7, // Valid number - should be preserved
|
||||
max_tokens: '', // Empty string - should be removed
|
||||
maxContextTokens: '', // Empty string - should be removed
|
||||
topP: 0, // Zero value - should be preserved (not treated as empty)
|
||||
frequency_penalty: '', // Empty string - should be removed
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithEmptyModelParams;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.model_parameters).toBeDefined();
|
||||
// Valid numbers should be preserved
|
||||
expect(createdAgent.model_parameters.temperature).toBe(0.7);
|
||||
expect(createdAgent.model_parameters.topP).toBe(0);
|
||||
// Empty strings should be removed
|
||||
expect(createdAgent.model_parameters.max_tokens).toBeUndefined();
|
||||
expect(createdAgent.model_parameters.maxContextTokens).toBeUndefined();
|
||||
expect(createdAgent.model_parameters.frequency_penalty).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.model_parameters.temperature).toBe(0.7);
|
||||
expect(agentInDb.model_parameters.topP).toBe(0);
|
||||
expect(agentInDb.model_parameters.max_tokens).toBeUndefined();
|
||||
expect(agentInDb.model_parameters.maxContextTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle invalid avatar format', async () => {
|
||||
const dataWithInvalidAvatar = {
|
||||
provider: 'openai',
|
||||
|
|
@ -539,6 +579,49 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
|||
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should remove empty strings from model_parameters during update (Issue Fix)', async () => {
|
||||
// First create an agent with valid model_parameters
|
||||
await Agent.updateOne(
|
||||
{ id: existingAgentId },
|
||||
{
|
||||
model_parameters: {
|
||||
temperature: 0.5,
|
||||
max_tokens: 1000,
|
||||
maxContextTokens: 2000,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
// Simulate user clearing the fields (sends empty strings)
|
||||
mockReq.body = {
|
||||
model_parameters: {
|
||||
temperature: 0.7, // Change to new value
|
||||
max_tokens: '', // Clear this field (should be removed, not sent as "")
|
||||
maxContextTokens: '', // Clear this field (should be removed, not sent as "")
|
||||
},
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.model_parameters).toBeDefined();
|
||||
// Valid number should be updated
|
||||
expect(updatedAgent.model_parameters.temperature).toBe(0.7);
|
||||
// Empty strings should be removed, not sent as ""
|
||||
expect(updatedAgent.model_parameters.max_tokens).toBeUndefined();
|
||||
expect(updatedAgent.model_parameters.maxContextTokens).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.model_parameters.temperature).toBe(0.7);
|
||||
expect(agentInDb.model_parameters.max_tokens).toBeUndefined();
|
||||
expect(agentInDb.model_parameters.maxContextTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should return 404 for non-existent agent', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
|
||||
|
|
|
|||
|
|
@ -5,15 +5,28 @@ const { logoutUser } = require('~/server/services/AuthService');
|
|||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
||||
const logoutController = async (req, res) => {
|
||||
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
||||
const parsedCookies = req.headers.cookie ? cookies.parse(req.headers.cookie) : {};
|
||||
const isOpenIdUser = req.user?.openidId != null && req.user?.provider === 'openid';
|
||||
|
||||
/** For OpenID users, read refresh token from session; for others, use cookie */
|
||||
let refreshToken;
|
||||
if (isOpenIdUser && req.session?.openidTokens) {
|
||||
refreshToken = req.session.openidTokens.refreshToken;
|
||||
delete req.session.openidTokens;
|
||||
}
|
||||
refreshToken = refreshToken || parsedCookies.refreshToken;
|
||||
|
||||
try {
|
||||
const logout = await logoutUser(req, refreshToken);
|
||||
const { status, message } = logout;
|
||||
|
||||
res.clearCookie('refreshToken');
|
||||
res.clearCookie('openid_access_token');
|
||||
res.clearCookie('openid_user_id');
|
||||
res.clearCookie('token_provider');
|
||||
const response = { message };
|
||||
if (
|
||||
req.user.openidId != null &&
|
||||
isOpenIdUser &&
|
||||
isEnabled(process.env.OPENID_USE_END_SESSION_ENDPOINT) &&
|
||||
process.env.OPENID_ISSUER
|
||||
) {
|
||||
|
|
@ -27,7 +40,12 @@ const logoutController = async (req, res) => {
|
|||
? openIdConfig.serverMetadata().end_session_endpoint
|
||||
: null;
|
||||
if (endSessionEndpoint) {
|
||||
response.redirect = endSessionEndpoint;
|
||||
const endSessionUrl = new URL(endSessionEndpoint);
|
||||
/** Redirect back to app's login page after IdP logout */
|
||||
const postLogoutRedirectUri =
|
||||
process.env.OPENID_POST_LOGOUT_REDIRECT_URI || `${process.env.DOMAIN_CLIENT}/login`;
|
||||
endSessionUrl.searchParams.set('post_logout_redirect_uri', postLogoutRedirectUri);
|
||||
response.redirect = endSessionUrl.toString();
|
||||
} else {
|
||||
logger.warn(
|
||||
'[logoutController] end_session_endpoint not found in OpenID issuer metadata. Please verify that the issuer is correct.',
|
||||
|
|
|
|||
|
|
@ -6,10 +6,54 @@
|
|||
* @import { MCPServerDocument } from 'librechat-data-provider'
|
||||
*/
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
isMCPDomainNotAllowedError,
|
||||
isMCPInspectionFailedError,
|
||||
MCPErrorCodes,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||
|
||||
/**
|
||||
* Handles MCP-specific errors and sends appropriate HTTP responses.
|
||||
* @param {Error} error - The error to handle
|
||||
* @param {import('express').Response} res - Express response object
|
||||
* @returns {import('express').Response | null} Response if handled, null if not an MCP error
|
||||
*/
|
||||
function handleMCPError(error, res) {
|
||||
if (isMCPDomainNotAllowedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
if (isMCPInspectionFailedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
// Fallback for legacy string-based error handling (backwards compatibility)
|
||||
if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) {
|
||||
return res.status(403).json({
|
||||
error: MCPErrorCodes.DOMAIN_NOT_ALLOWED,
|
||||
message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''),
|
||||
});
|
||||
}
|
||||
|
||||
if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) {
|
||||
return res.status(400).json({
|
||||
error: MCPErrorCodes.INSPECTION_FAILED,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
*/
|
||||
|
|
@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => {
|
|||
});
|
||||
} catch (error) {
|
||||
logger.error('[createMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => {
|
|||
res.status(200).json(parsedConfig);
|
||||
} catch (error) {
|
||||
logger.error('[updateMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ const {
|
|||
performStartupChecks,
|
||||
handleJsonParseError,
|
||||
initializeFileStorage,
|
||||
GenerationJobManager,
|
||||
createStreamServices,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
|
|
@ -192,6 +194,11 @@ const startServer = async () => {
|
|||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
|
||||
// Configure stream services (auto-detects Redis from USE_REDIS env var)
|
||||
const streamServices = createStreamServices();
|
||||
GenerationJobManager.configure(streamServices);
|
||||
GenerationJobManager.initialize();
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,13 @@ jest.mock('~/app/clients/tools', () => ({
|
|||
toolkits: [],
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
createMCPServersRegistry: jest.fn(),
|
||||
createMCPManager: jest.fn().mockResolvedValue({
|
||||
getAppToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('Server Configuration', () => {
|
||||
// Increase the default timeout to allow for Mongo cleanup
|
||||
jest.setTimeout(30_000);
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
// abortControllers.js
|
||||
module.exports = new Map();
|
||||
|
|
@ -1,124 +1,102 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent, sanitizeMessageForTransmit } = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
countTokens,
|
||||
isEnabled,
|
||||
sendEvent,
|
||||
GenerationJobManager,
|
||||
sanitizeMessageForTransmit,
|
||||
} = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { abortRun } = require('./abortRun');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {boolean}
|
||||
* Abort an active message generation.
|
||||
* Uses GenerationJobManager for all agent requests.
|
||||
* Since streamId === conversationId, we can directly abort by conversationId.
|
||||
*/
|
||||
function cleanupAbortController(abortKey) {
|
||||
if (!abortControllers.has(abortKey)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
|
||||
if (!abortController) {
|
||||
abortControllers.delete(abortKey);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 1. Check if this controller has any composed signals and clean them up
|
||||
try {
|
||||
// This creates a temporary composed signal to use for cleanup
|
||||
const composedSignal = AbortSignal.any([abortController.signal]);
|
||||
|
||||
// Get all event types - in practice, AbortSignal typically only uses 'abort'
|
||||
const eventTypes = ['abort'];
|
||||
|
||||
// First, execute a dummy listener removal to handle potential composed signals
|
||||
for (const eventType of eventTypes) {
|
||||
const dummyHandler = () => {};
|
||||
composedSignal.addEventListener(eventType, dummyHandler);
|
||||
composedSignal.removeEventListener(eventType, dummyHandler);
|
||||
|
||||
const listeners = composedSignal.listeners?.(eventType) || [];
|
||||
for (const listener of listeners) {
|
||||
composedSignal.removeEventListener(eventType, listener);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug(`Error cleaning up composed signals: ${e}`);
|
||||
}
|
||||
|
||||
// 2. Abort the controller if not already aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
|
||||
// 3. Remove from registry
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
// 4. Clean up any data stored in the WeakMap
|
||||
if (abortDataMap.has(abortController)) {
|
||||
abortDataMap.delete(abortController);
|
||||
}
|
||||
|
||||
// 5. Clean up function references on the controller
|
||||
if (abortController.getAbortData) {
|
||||
abortController.getAbortData = null;
|
||||
}
|
||||
|
||||
if (abortController.abortCompletion) {
|
||||
abortController.abortCompletion = null;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {function(): void}
|
||||
*/
|
||||
function createCleanUpHandler(abortKey) {
|
||||
return function () {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, endpoint } = req.body;
|
||||
const { abortKey, endpoint } = req.body;
|
||||
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
// Use GenerationJobManager to abort the job (streamId === conversationId)
|
||||
const abortResult = await GenerationJobManager.abortJob(conversationId);
|
||||
|
||||
if (!abortResult.success) {
|
||||
if (!res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const { jobData, content, text } = abortResult;
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
// Count tokens and spend them
|
||||
const completionTokens = await countTokens(text);
|
||||
const promptTokens = jobData?.promptTokens ?? 0;
|
||||
|
||||
const finalEvent = await abortController.abortCompletion?.();
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
|
||||
JSON.stringify({ abortKey }),
|
||||
const responseMessage = {
|
||||
messageId: jobData?.responseMessageId,
|
||||
parentMessageId: jobData?.userMessage?.messageId,
|
||||
conversationId: jobData?.conversationId,
|
||||
content,
|
||||
text,
|
||||
sender: jobData?.sender ?? 'AI',
|
||||
finish_reason: 'incomplete',
|
||||
endpoint: jobData?.endpoint,
|
||||
iconURL: jobData?.iconURL,
|
||||
model: jobData?.model,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user: userId },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
cleanupAbortController(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user: userId },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
// Get conversation for title
|
||||
const conversation = await getConvo(userId, conversationId);
|
||||
|
||||
const finalEvent = {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: jobData?.userMessage
|
||||
? sanitizeMessageForTransmit({
|
||||
messageId: jobData.userMessage.messageId,
|
||||
parentMessageId: jobData.userMessage.parentMessageId,
|
||||
conversationId: jobData.userMessage.conversationId,
|
||||
text: jobData.userMessage.text,
|
||||
isCreatedByUser: true,
|
||||
})
|
||||
: null,
|
||||
responseMessage,
|
||||
};
|
||||
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${userId} | ${req.user.email} | Aborted request: ${conversationId}`,
|
||||
);
|
||||
|
||||
if (res.headersSent) {
|
||||
return sendEvent(res, finalEvent);
|
||||
}
|
||||
|
||||
|
|
@ -139,171 +117,13 @@ const handleAbort = function () {
|
|||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
|
||||
// Store minimal data in WeakMap to avoid circular references
|
||||
abortDataMap.set(abortController, {
|
||||
getAbortDataFn: getAbortData,
|
||||
userId: req.user.id,
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
});
|
||||
|
||||
// Replace the direct function reference with a wrapper that uses WeakMap
|
||||
abortController.getAbortData = function () {
|
||||
const data = abortDataMap.get(this);
|
||||
if (!data || typeof data.getAbortDataFn !== 'function') {
|
||||
return {};
|
||||
}
|
||||
|
||||
try {
|
||||
const result = data.getAbortDataFn();
|
||||
|
||||
// Create a copy without circular references
|
||||
const cleanResult = { ...result };
|
||||
|
||||
// If userMessagePromise exists, break its reference to client
|
||||
if (
|
||||
cleanResult.userMessagePromise &&
|
||||
typeof cleanResult.userMessagePromise.then === 'function'
|
||||
) {
|
||||
// Create a new promise that fulfills with the same result but doesn't reference the original
|
||||
const originalPromise = cleanResult.userMessagePromise;
|
||||
cleanResult.userMessagePromise = new Promise((resolve, reject) => {
|
||||
originalPromise.then(
|
||||
(result) => resolve({ ...result }),
|
||||
(error) => reject(error),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return cleanResult;
|
||||
} catch (err) {
|
||||
logger.error('[abortController.getAbortData] Error:', err);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
* @param {boolean} [isNewConvo]
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId, isNewConvo) => {
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const prelimAbortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const abortKey = isNewConvo
|
||||
? `${prelimAbortKey}${Constants.COMMON_DIVIDER}${Constants.NEW_CONVO}`
|
||||
: prelimAbortKey;
|
||||
getReqData({ abortKey });
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
const { overrideUserMessageId } = req?.body ?? {};
|
||||
|
||||
if (overrideUserMessageId != null && prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
|
||||
// Store minimal options
|
||||
const minimalOptions = {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
|
||||
const cleanupHandler = createCleanUpHandler(addedAbortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store minimal options
|
||||
const minimalOptions = {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...minimalOptions });
|
||||
const cleanupHandler = createCleanUpHandler(abortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
};
|
||||
|
||||
// Define abortCompletion without capturing the entire parent scope
|
||||
abortController.abortCompletion = async function () {
|
||||
this.abort();
|
||||
|
||||
// Get data from WeakMap
|
||||
const ctrlData = abortDataMap.get(this);
|
||||
if (!ctrlData || !ctrlData.getAbortDataFn) {
|
||||
return { final: true, conversation: {}, title: 'New Chat' };
|
||||
}
|
||||
|
||||
// Get abort data using stored function
|
||||
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||
ctrlData.getAbortDataFn();
|
||||
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = ctrlData.userId;
|
||||
|
||||
const responseMessage = {
|
||||
...responseData,
|
||||
conversationId,
|
||||
finish_reason: 'incomplete',
|
||||
endpoint: ctrlData.endpoint,
|
||||
iconURL: ctrlData.iconURL,
|
||||
model: ctrlData.modelOptions?.model ?? ctrlData.model_parameters?.model,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
const resolved = await userMessagePromise;
|
||||
conversation = resolved?.conversation;
|
||||
// Break reference to promise
|
||||
resolved.conversation = null;
|
||||
}
|
||||
|
||||
if (!conversation) {
|
||||
conversation = await getConvo(user, conversationId);
|
||||
}
|
||||
|
||||
return {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
};
|
||||
|
||||
return { abortController, onStart };
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle abort errors during generation.
|
||||
* @param {ServerResponse} res
|
||||
* @param {ServerRequest} req
|
||||
* @param {Error | unknown} error
|
||||
* @param {Partial<TMessage> & { partialText?: string }} data
|
||||
* @returns { Promise<void> }
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const handleAbortError = async (res, req, error, data) => {
|
||||
if (error?.message?.includes('base64')) {
|
||||
|
|
@ -368,8 +188,7 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
};
|
||||
}
|
||||
|
||||
const callback = createCleanUpHandler(conversationId);
|
||||
await sendError(req, res, options, callback);
|
||||
await sendError(req, res, options);
|
||||
};
|
||||
|
||||
if (partialText && partialText.length > 5) {
|
||||
|
|
@ -387,6 +206,4 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
module.exports = {
|
||||
handleAbort,
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants, isAgentsEndpoint, ResourceType } = require('librechat-data-provider');
|
||||
const {
|
||||
Constants,
|
||||
ResourceType,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
|
||||
|
|
@ -13,7 +18,8 @@ const { getAgent } = require('~/models/Agent');
|
|||
*/
|
||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||
// Handle ephemeral agents - they don't need permission checks
|
||||
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentCustomId)) {
|
||||
return null; // No permission check needed for ephemeral agents
|
||||
}
|
||||
|
||||
|
|
@ -62,7 +68,8 @@ const canAccessAgentFromBody = (options) => {
|
|||
}
|
||||
|
||||
// Skip permission checks for ephemeral agents
|
||||
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentId)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ describe('canAccessAgentResource middleware', () => {
|
|||
AGENTS: {
|
||||
USE: true,
|
||||
CREATE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
|
|||
|
|
@ -26,10 +26,10 @@ describe('canAccessMCPServerResource middleware', () => {
|
|||
await Role.create({
|
||||
name: 'test-role',
|
||||
permissions: {
|
||||
MCPSERVERS: {
|
||||
MCP_SERVERS: {
|
||||
USE: true,
|
||||
CREATE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ describe('fileAccess middleware', () => {
|
|||
AGENTS: {
|
||||
USE: true,
|
||||
CREATE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
SHARE: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
|
|||
|
|
@ -23,9 +23,10 @@ async function buildEndpointOption(req, res, next) {
|
|||
try {
|
||||
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
|
||||
);
|
||||
logger.error(`Error parsing compact conversation for endpoint ${endpoint}`, error);
|
||||
logger.debug({
|
||||
'Error parsing compact conversation': { endpoint, endpointType, conversation: req.body },
|
||||
});
|
||||
return handleError(res, { text: 'Error parsing conversation' });
|
||||
}
|
||||
|
||||
|
|
|
|||
84
api/server/middleware/checkSharePublicAccess.js
Normal file
84
api/server/middleware/checkSharePublicAccess.js
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ResourceType, PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
* Maps resource types to their corresponding permission types
|
||||
*/
|
||||
const resourceToPermissionType = {
|
||||
[ResourceType.AGENT]: PermissionTypes.AGENTS,
|
||||
[ResourceType.PROMPTGROUP]: PermissionTypes.PROMPTS,
|
||||
[ResourceType.MCPSERVER]: PermissionTypes.MCP_SERVERS,
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to check if user has SHARE_PUBLIC permission for a resource type
|
||||
* Only enforced when request body contains `public: true`
|
||||
* @param {import('express').Request} req - Express request
|
||||
* @param {import('express').Response} res - Express response
|
||||
* @param {import('express').NextFunction} next - Express next function
|
||||
*/
|
||||
const checkSharePublicAccess = async (req, res, next) => {
|
||||
try {
|
||||
const { public: isPublic } = req.body;
|
||||
|
||||
// Only check if trying to enable public sharing
|
||||
if (!isPublic) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const user = req.user;
|
||||
if (!user || !user.role) {
|
||||
return res.status(401).json({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
}
|
||||
|
||||
const { resourceType } = req.params;
|
||||
const permissionType = resourceToPermissionType[resourceType];
|
||||
|
||||
if (!permissionType) {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `Unsupported resource type for public sharing: ${resourceType}`,
|
||||
});
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (!role || !role.permissions) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'No permissions configured for user role',
|
||||
});
|
||||
}
|
||||
|
||||
const resourcePerms = role.permissions[permissionType] || {};
|
||||
const canSharePublic = resourcePerms[Permissions.SHARE_PUBLIC] === true;
|
||||
|
||||
if (!canSharePublic) {
|
||||
logger.warn(
|
||||
`[checkSharePublicAccess][${user.id}] User denied SHARE_PUBLIC for ${resourceType}`,
|
||||
);
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: `You do not have permission to share ${resourceType} resources publicly`,
|
||||
});
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[checkSharePublicAccess][${req.user?.id}] Error checking SHARE_PUBLIC permission`,
|
||||
error,
|
||||
);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check public sharing permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
checkSharePublicAccess,
|
||||
};
|
||||
164
api/server/middleware/checkSharePublicAccess.spec.js
Normal file
164
api/server/middleware/checkSharePublicAccess.spec.js
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
const { ResourceType, PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { checkSharePublicAccess } = require('./checkSharePublicAccess');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
jest.mock('~/models/Role');
|
||||
|
||||
describe('checkSharePublicAccess middleware', () => {
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
let mockNext;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = {
|
||||
user: { id: 'user123', role: 'USER' },
|
||||
params: { resourceType: ResourceType.AGENT },
|
||||
body: {},
|
||||
};
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
mockNext = jest.fn();
|
||||
});
|
||||
|
||||
it('should call next() when public is not true', async () => {
|
||||
mockReq.body = { public: false };
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
expect(mockRes.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should call next() when public is undefined', async () => {
|
||||
mockReq.body = { updated: [] };
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
expect(mockRes.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
mockReq.body = { public: true };
|
||||
mockReq.user = null;
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(401);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
expect(mockNext).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 403 when user role has no SHARE_PUBLIC permission for agents', async () => {
|
||||
mockReq.body = { public: true };
|
||||
mockReq.params = { resourceType: ResourceType.AGENT };
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.SHARE]: true,
|
||||
[Permissions.SHARE_PUBLIC]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: `You do not have permission to share ${ResourceType.AGENT} resources publicly`,
|
||||
});
|
||||
expect(mockNext).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should call next() when user has SHARE_PUBLIC permission for agents', async () => {
|
||||
mockReq.body = { public: true };
|
||||
mockReq.params = { resourceType: ResourceType.AGENT };
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.SHARE]: true,
|
||||
[Permissions.SHARE_PUBLIC]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
expect(mockRes.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should check prompts permission for promptgroup resource type', async () => {
|
||||
mockReq.body = { public: true };
|
||||
mockReq.params = { resourceType: ResourceType.PROMPTGROUP };
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARE_PUBLIC]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should check mcp_servers permission for mcpserver resource type', async () => {
|
||||
mockReq.body = { public: true };
|
||||
mockReq.params = { resourceType: ResourceType.MCPSERVER };
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.MCP_SERVERS]: {
|
||||
[Permissions.SHARE_PUBLIC]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockNext).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 400 for unsupported resource type', async () => {
|
||||
mockReq.body = { public: true };
|
||||
mockReq.params = { resourceType: 'unsupported' };
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'Bad Request',
|
||||
message: 'Unsupported resource type for public sharing: unsupported',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return 403 when role has no permissions object', async () => {
|
||||
mockReq.body = { public: true };
|
||||
getRoleByName.mockResolvedValue({ permissions: null });
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
|
||||
it('should return 500 on error', async () => {
|
||||
mockReq.body = { public: true };
|
||||
getRoleByName.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
await checkSharePublicAccess(mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check public sharing permissions',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
const { isEnabled } = require('@librechat/api');
|
||||
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const denyRequest = require('./denyRequest');
|
||||
|
||||
const {
|
||||
USE_REDIS,
|
||||
CONCURRENT_MESSAGE_MAX = 1,
|
||||
CONCURRENT_VIOLATION_SCORE: score,
|
||||
} = process.env ?? {};
|
||||
|
||||
/**
|
||||
* Middleware to limit concurrent requests for a user.
|
||||
*
|
||||
* This middleware checks if a user has exceeded a specified concurrent request limit.
|
||||
* If the user exceeds the limit, an error is returned. If the user is within the limit,
|
||||
* their request count is incremented. After the request is processed, the count is decremented.
|
||||
* If the `cache` store is not available, the middleware will skip its logic.
|
||||
*
|
||||
* @function
|
||||
* @param {Object} req - Express request object containing user information.
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
|
||||
*/
|
||||
const concurrentLimiter = async (req, res, next) => {
|
||||
const namespace = CacheKeys.PENDING_REQ;
|
||||
const cache = getLogStores(namespace);
|
||||
if (!cache) {
|
||||
return next();
|
||||
}
|
||||
|
||||
if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const userId = req.user?.id ?? req.user?._id ?? '';
|
||||
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
|
||||
const type = ViolationTypes.CONCURRENT;
|
||||
|
||||
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
|
||||
const pendingRequests = +((await cache.get(key)) ?? 0);
|
||||
|
||||
if (pendingRequests >= limit) {
|
||||
const errorMessage = {
|
||||
type,
|
||||
limit,
|
||||
pendingRequests,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
} else {
|
||||
await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
|
||||
}
|
||||
|
||||
// Ensure the requests are removed from the store once the request is done
|
||||
let cleared = false;
|
||||
const cleanUp = async () => {
|
||||
if (cleared) {
|
||||
return;
|
||||
}
|
||||
cleared = true;
|
||||
await clearPendingReq({ userId, cache });
|
||||
};
|
||||
|
||||
if (pendingRequests < limit) {
|
||||
res.on('finish', cleanUp);
|
||||
res.on('close', cleanUp);
|
||||
}
|
||||
|
||||
next();
|
||||
};
|
||||
|
||||
module.exports = concurrentLimiter;
|
||||
|
|
@ -3,7 +3,6 @@ const validateRegistration = require('./validateRegistration');
|
|||
const buildEndpointOption = require('./buildEndpointOption');
|
||||
const validateMessageReq = require('./validateMessageReq');
|
||||
const checkDomainAllowed = require('./checkDomainAllowed');
|
||||
const concurrentLimiter = require('./concurrentLimiter');
|
||||
const requireLocalAuth = require('./requireLocalAuth');
|
||||
const canDeleteAccount = require('./canDeleteAccount');
|
||||
const accessResources = require('./accessResources');
|
||||
|
|
@ -42,7 +41,6 @@ module.exports = {
|
|||
requireLocalAuth,
|
||||
canDeleteAccount,
|
||||
configMiddleware,
|
||||
concurrentLimiter,
|
||||
checkDomainAllowed,
|
||||
validateMessageReq,
|
||||
buildEndpointOption,
|
||||
|
|
|
|||
|
|
@ -51,9 +51,9 @@ describe('Access Middleware', () => {
|
|||
permissions: {
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARE]: true,
|
||||
},
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: true,
|
||||
|
|
@ -65,7 +65,7 @@ describe('Access Middleware', () => {
|
|||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
[Permissions.SHARE]: false,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -79,9 +79,9 @@ describe('Access Middleware', () => {
|
|||
permissions: {
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARE]: true,
|
||||
},
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: true,
|
||||
|
|
@ -93,7 +93,7 @@ describe('Access Middleware', () => {
|
|||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
[Permissions.SHARE]: true,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
|
|
@ -110,7 +110,7 @@ describe('Access Middleware', () => {
|
|||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
[Permissions.SHARE]: false,
|
||||
},
|
||||
// Has permissions for other types
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
|
|
@ -241,7 +241,7 @@ describe('Access Middleware', () => {
|
|||
req: {},
|
||||
user: { id: 'admin123', role: 'admin' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.SHARED_GLOBAL],
|
||||
permissions: [Permissions.SHARE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(shareResult).toBe(true);
|
||||
|
|
@ -318,7 +318,7 @@ describe('Access Middleware', () => {
|
|||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE, Permissions.SHARED_GLOBAL],
|
||||
permissions: [Permissions.USE, Permissions.CREATE, Permissions.SHARE],
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
|
@ -349,7 +349,7 @@ describe('Access Middleware', () => {
|
|||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
[Permissions.SHARE]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
|
|||
|
|
@ -6,6 +6,15 @@ const { logViolation, getLogStores } = require('~/cache');
|
|||
|
||||
const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {};
|
||||
|
||||
/**
|
||||
* Helper function to get conversationId from different request body structures.
|
||||
* @param {Object} body - The request body.
|
||||
* @returns {string|undefined} The conversationId.
|
||||
*/
|
||||
const getConversationId = (body) => {
|
||||
return body.conversationId ?? body.arg?.conversationId;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to validate user's authorization for a conversation.
|
||||
*
|
||||
|
|
@ -24,7 +33,7 @@ const validateConvoAccess = async (req, res, next) => {
|
|||
const namespace = ViolationTypes.CONVO_ACCESS;
|
||||
const cache = getLogStores(namespace);
|
||||
|
||||
const conversationId = req.body.conversationId;
|
||||
const conversationId = getConversationId(req.body);
|
||||
|
||||
if (!conversationId || conversationId === Constants.NEW_CONVO) {
|
||||
return next();
|
||||
|
|
|
|||
|
|
@ -68,17 +68,11 @@ function createValidateImageRequest(secureImageLinks) {
|
|||
}
|
||||
|
||||
const parsedCookies = cookies.parse(cookieHeader);
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const tokenProvider = parsedCookies.token_provider;
|
||||
let userIdForPath;
|
||||
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
/** For OpenID users with OPENID_REUSE_TOKENS, use openid_user_id cookie */
|
||||
const openidUserId = parsedCookies.openid_user_id;
|
||||
if (!openidUserId) {
|
||||
logger.warn('[validateImageRequest] No OpenID user ID cookie found');
|
||||
|
|
@ -92,6 +86,17 @@ function createValidateImageRequest(secureImageLinks) {
|
|||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
} else {
|
||||
/**
|
||||
* For non-OpenID users (or OpenID without REUSE_TOKENS), use refreshToken from cookies.
|
||||
* These users authenticate via setAuthTokens() which stores refreshToken in cookies.
|
||||
*/
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const validationResult = validateToken(refreshToken);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ jest.mock('~/server/middleware', () => ({
|
|||
forkUserLimiter: (req, res, next) => next(),
|
||||
})),
|
||||
configMiddleware: (req, res, next) => next(),
|
||||
validateConvoAccess: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils/import/fork', () => ({
|
||||
|
|
@ -108,7 +109,7 @@ describe('Convos Routes', () => {
|
|||
let app;
|
||||
let convosRouter;
|
||||
const { deleteAllSharedLinks, deleteConvoSharedLink } = require('~/models');
|
||||
const { deleteConvos } = require('~/models/Conversation');
|
||||
const { deleteConvos, saveConvo } = require('~/models/Conversation');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
|
||||
beforeAll(() => {
|
||||
|
|
@ -460,6 +461,138 @@ describe('Convos Routes', () => {
|
|||
expect(deleteConvoSharedLink).toHaveBeenCalledAfter(deleteConvos);
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /archive', () => {
|
||||
it('should archive a conversation successfully', async () => {
|
||||
const mockConversationId = 'conv-123';
|
||||
const mockArchivedConvo = {
|
||||
conversationId: mockConversationId,
|
||||
title: 'Test Conversation',
|
||||
isArchived: true,
|
||||
user: 'test-user-123',
|
||||
};
|
||||
|
||||
saveConvo.mockResolvedValue(mockArchivedConvo);
|
||||
|
||||
const response = await request(app)
|
||||
.post('/api/convos/archive')
|
||||
.send({
|
||||
arg: {
|
||||
conversationId: mockConversationId,
|
||||
isArchived: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual(mockArchivedConvo);
|
||||
expect(saveConvo).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ user: { id: 'test-user-123' } }),
|
||||
{ conversationId: mockConversationId, isArchived: true },
|
||||
{ context: `POST /api/convos/archive ${mockConversationId}` },
|
||||
);
|
||||
});
|
||||
|
||||
it('should unarchive a conversation successfully', async () => {
|
||||
const mockConversationId = 'conv-456';
|
||||
const mockUnarchivedConvo = {
|
||||
conversationId: mockConversationId,
|
||||
title: 'Unarchived Conversation',
|
||||
isArchived: false,
|
||||
user: 'test-user-123',
|
||||
};
|
||||
|
||||
saveConvo.mockResolvedValue(mockUnarchivedConvo);
|
||||
|
||||
const response = await request(app)
|
||||
.post('/api/convos/archive')
|
||||
.send({
|
||||
arg: {
|
||||
conversationId: mockConversationId,
|
||||
isArchived: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual(mockUnarchivedConvo);
|
||||
expect(saveConvo).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ user: { id: 'test-user-123' } }),
|
||||
{ conversationId: mockConversationId, isArchived: false },
|
||||
{ context: `POST /api/convos/archive ${mockConversationId}` },
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 400 when conversationId is missing', async () => {
|
||||
const response = await request(app)
|
||||
.post('/api/convos/archive')
|
||||
.send({
|
||||
arg: {
|
||||
isArchived: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body).toEqual({ error: 'conversationId is required' });
|
||||
expect(saveConvo).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 400 when isArchived is not a boolean', async () => {
|
||||
const response = await request(app)
|
||||
.post('/api/convos/archive')
|
||||
.send({
|
||||
arg: {
|
||||
conversationId: 'conv-123',
|
||||
isArchived: 'true',
|
||||
},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body).toEqual({ error: 'isArchived must be a boolean' });
|
||||
expect(saveConvo).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 400 when isArchived is undefined', async () => {
|
||||
const response = await request(app)
|
||||
.post('/api/convos/archive')
|
||||
.send({
|
||||
arg: {
|
||||
conversationId: 'conv-123',
|
||||
},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body).toEqual({ error: 'isArchived must be a boolean' });
|
||||
expect(saveConvo).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 500 when saveConvo fails', async () => {
|
||||
const mockConversationId = 'conv-error';
|
||||
saveConvo.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
const response = await request(app)
|
||||
.post('/api/convos/archive')
|
||||
.send({
|
||||
arg: {
|
||||
conversationId: mockConversationId,
|
||||
isArchived: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.text).toBe('Error archiving conversation');
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error archiving conversation', expect.any(Error));
|
||||
});
|
||||
|
||||
it('should handle empty arg object', async () => {
|
||||
const response = await request(app).post('/api/convos/archive').send({
|
||||
arg: {},
|
||||
});
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body).toEqual({ error: 'conversationId is required' });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ const express = require('express');
|
|||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { getBasePath } = require('@librechat/api');
|
||||
|
||||
const mockRegistryInstance = {
|
||||
getServerConfig: jest.fn(),
|
||||
|
|
@ -12,26 +13,36 @@ const mockRegistryInstance = {
|
|||
removeServer: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
jest.mock('@librechat/api', () => {
|
||||
const actual = jest.requireActual('@librechat/api');
|
||||
return {
|
||||
...actual,
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
// Error handling utilities (from @librechat/api mcp/errors)
|
||||
isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED',
|
||||
MCPErrorCodes: {
|
||||
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -271,27 +282,30 @@ describe('MCP Routes', () => {
|
|||
error: 'access_denied',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=access_denied');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when code is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=missing_code');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_code`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when state is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=missing_state');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when flow state is not found', async () => {
|
||||
|
|
@ -301,9 +315,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'invalid-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=invalid_state');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||
});
|
||||
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
|
|
@ -358,9 +373,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
'test-flow-id',
|
||||
'test-auth-code',
|
||||
|
|
@ -394,9 +410,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
|
||||
});
|
||||
|
||||
it('should handle system-level OAuth completion', async () => {
|
||||
|
|
@ -429,9 +446,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
|
|
@ -474,9 +492,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
|
@ -515,9 +534,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
|
||||
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
|
|
@ -573,9 +593,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify storeTokens was called with ORIGINAL flow state credentials
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||
|
|
@ -614,9 +635,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify completeOAuthFlow was NOT called (prevented duplicate)
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
|
||||
|
|
@ -1385,8 +1407,10 @@ describe('MCP Routes', () => {
|
|||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.expect(302);
|
||||
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(mockFlowManager.completeFlow).not.toHaveBeenCalled();
|
||||
expect(response.headers.location).toContain('/oauth/success');
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
|
||||
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
||||
|
|
@ -1433,7 +1457,9 @@ describe('MCP Routes', () => {
|
|||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.expect(302);
|
||||
|
||||
expect(response.headers.location).toContain('/oauth/success');
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -1527,23 +1553,19 @@ describe('MCP Routes', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should create MCP server with valid stdio config', async () => {
|
||||
const validConfig = {
|
||||
it('should reject stdio config for security reasons', async () => {
|
||||
const stdioConfig = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
title: 'Test Stdio Server',
|
||||
};
|
||||
|
||||
mockRegistryInstance.addServer.mockResolvedValue({
|
||||
serverName: 'test-stdio-server',
|
||||
config: validConfig,
|
||||
});
|
||||
const response = await request(app).post('/api/mcp/servers').send({ config: stdioConfig });
|
||||
|
||||
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
|
||||
|
||||
expect(response.status).toBe(201);
|
||||
expect(response.body.serverName).toBe('test-stdio-server');
|
||||
// Stdio transport is not allowed via API - only admins can configure it via YAML
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body.message).toBe('Invalid configuration');
|
||||
});
|
||||
|
||||
it('should return 400 for invalid configuration', async () => {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ const {
|
|||
} = require('~/server/controllers/PermissionsController');
|
||||
const { requireJwtAuth, checkBan, uaParser, canAccessResource } = require('~/server/middleware');
|
||||
const { checkPeoplePickerAccess } = require('~/server/middleware/checkPeoplePickerAccess');
|
||||
const { checkSharePublicAccess } = require('~/server/middleware/checkSharePublicAccess');
|
||||
const { findMCPServerById } = require('~/models');
|
||||
|
||||
const router = express.Router();
|
||||
|
|
@ -36,52 +37,67 @@ router.get('/search-principals', checkPeoplePickerAccess, searchPrincipals);
|
|||
*/
|
||||
router.get('/:resourceType/roles', getResourceRoles);
|
||||
|
||||
/**
|
||||
* Middleware factory to check resource access for permission-related operations.
|
||||
* SECURITY: Users must have SHARE permission to view or modify resource permissions.
|
||||
* @param {string} requiredPermission - The permission bit required (e.g., SHARE)
|
||||
* @returns Express middleware function
|
||||
*/
|
||||
const checkResourcePermissionAccess = (requiredPermission) => (req, res, next) => {
|
||||
const { resourceType } = req.params;
|
||||
let middleware;
|
||||
|
||||
if (resourceType === ResourceType.AGENT) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.PROMPTGROUP) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.MCPSERVER) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'resourceId',
|
||||
idResolver: findMCPServerById,
|
||||
});
|
||||
} else {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `Unsupported resource type: ${resourceType}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Execute the middleware
|
||||
middleware(req, res, next);
|
||||
};
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/{resourceId}
|
||||
* Get all permissions for a specific resource
|
||||
* SECURITY: Requires SHARE permission to view resource permissions
|
||||
*/
|
||||
router.get('/:resourceType/:resourceId', getResourcePermissions);
|
||||
router.get(
|
||||
'/:resourceType/:resourceId',
|
||||
checkResourcePermissionAccess(PermissionBits.SHARE),
|
||||
getResourcePermissions,
|
||||
);
|
||||
|
||||
/**
|
||||
* PUT /api/permissions/{resourceType}/{resourceId}
|
||||
* Bulk update permissions for a specific resource
|
||||
* SECURITY: Requires SHARE permission to modify resource permissions
|
||||
* SECURITY: Requires SHARE_PUBLIC permission to enable public sharing
|
||||
*/
|
||||
router.put(
|
||||
'/:resourceType/:resourceId',
|
||||
// Use middleware that dynamically handles resource type and permissions
|
||||
(req, res, next) => {
|
||||
const { resourceType } = req.params;
|
||||
let middleware;
|
||||
|
||||
if (resourceType === ResourceType.AGENT) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission: PermissionBits.SHARE,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.PROMPTGROUP) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermission: PermissionBits.SHARE,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.MCPSERVER) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
requiredPermission: PermissionBits.SHARE,
|
||||
resourceIdParam: 'resourceId',
|
||||
idResolver: findMCPServerById,
|
||||
});
|
||||
} else {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `Unsupported resource type: ${resourceType}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Execute the middleware
|
||||
middleware(req, res, next);
|
||||
},
|
||||
checkResourcePermissionAccess(PermissionBits.SHARE),
|
||||
checkSharePublicAccess,
|
||||
updateResourcePermissions,
|
||||
);
|
||||
|
||||
|
|
|
|||
228
api/server/routes/accessPermissions.test.js
Normal file
228
api/server/routes/accessPermissions.test.js
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { createMethods } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { ResourceType, PermissionBits } = require('librechat-data-provider');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
|
||||
/**
|
||||
* Mock the PermissionsController to isolate route testing
|
||||
*/
|
||||
jest.mock('~/server/controllers/PermissionsController', () => ({
|
||||
getUserEffectivePermissions: jest.fn((req, res) => res.json({ permissions: [] })),
|
||||
getAllEffectivePermissions: jest.fn((req, res) => res.json({ permissions: [] })),
|
||||
updateResourcePermissions: jest.fn((req, res) => res.json({ success: true })),
|
||||
getResourcePermissions: jest.fn((req, res) =>
|
||||
res.json({
|
||||
resourceType: req.params.resourceType,
|
||||
resourceId: req.params.resourceId,
|
||||
principals: [],
|
||||
public: false,
|
||||
}),
|
||||
),
|
||||
getResourceRoles: jest.fn((req, res) => res.json({ roles: [] })),
|
||||
searchPrincipals: jest.fn((req, res) => res.json({ principals: [] })),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/middleware/checkPeoplePickerAccess', () => ({
|
||||
checkPeoplePickerAccess: jest.fn((req, res, next) => next()),
|
||||
}));
|
||||
|
||||
// Import actual middleware to get canAccessResource
|
||||
const { canAccessResource } = require('~/server/middleware');
|
||||
const { findMCPServerById } = require('~/models');
|
||||
|
||||
/**
|
||||
* Security Tests for SBA-ADV-20251203-02
|
||||
*
|
||||
* These tests verify that users cannot query or modify agent permissions
|
||||
* without proper SHARE permission.
|
||||
*/
|
||||
describe('Access Permissions Routes - Security Tests (SBA-ADV-20251203-02)', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
let authorId;
|
||||
let attackerId;
|
||||
let agentId;
|
||||
let methods;
|
||||
let User;
|
||||
let modelsToCleanup = [];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
const { createModels } = require('@librechat/data-schemas');
|
||||
const models = createModels(mongoose);
|
||||
modelsToCleanup = Object.keys(models);
|
||||
Object.assign(mongoose.models, models);
|
||||
|
||||
methods = createMethods(mongoose);
|
||||
User = models.User;
|
||||
|
||||
await methods.seedDefaultRoles();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
for (const modelName of modelsToCleanup) {
|
||||
delete mongoose.models[modelName];
|
||||
}
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear all collections
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
await methods.seedDefaultRoles();
|
||||
|
||||
// Create author (owner of the agent)
|
||||
authorId = new mongoose.Types.ObjectId().toString();
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
name: 'Agent Owner',
|
||||
email: 'owner@example.com',
|
||||
username: 'owner@example.com',
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create attacker (should not have access)
|
||||
attackerId = new mongoose.Types.ObjectId().toString();
|
||||
await User.create({
|
||||
_id: attackerId,
|
||||
name: 'Attacker',
|
||||
email: 'attacker@example.com',
|
||||
username: 'attacker@example.com',
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create private agent owned by author
|
||||
const customAgentId = `agent_${uuidv4().replace(/-/g, '').substring(0, 20)}`;
|
||||
await createAgent({
|
||||
id: customAgentId,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
agentId = customAgentId;
|
||||
|
||||
// Create Express app with attacker as current user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Mock authentication middleware - attacker is the current user
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: attackerId, role: 'USER' };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
|
||||
// Middleware factory for permission access check (mirrors actual implementation)
|
||||
const checkResourcePermissionAccess = (requiredPermission) => (req, res, next) => {
|
||||
const { resourceType } = req.params;
|
||||
let middleware;
|
||||
|
||||
if (resourceType === ResourceType.AGENT) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.PROMPTGROUP) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.MCPSERVER) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.MCPSERVER,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'resourceId',
|
||||
idResolver: findMCPServerById,
|
||||
});
|
||||
} else {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `Unsupported resource type: ${resourceType}`,
|
||||
});
|
||||
}
|
||||
|
||||
middleware(req, res, next);
|
||||
};
|
||||
|
||||
// GET route with access control (THE FIX)
|
||||
app.get(
|
||||
'/permissions/:resourceType/:resourceId',
|
||||
checkResourcePermissionAccess(PermissionBits.SHARE),
|
||||
(req, res) =>
|
||||
res.json({
|
||||
resourceType: req.params.resourceType,
|
||||
resourceId: req.params.resourceId,
|
||||
principals: [],
|
||||
public: false,
|
||||
}),
|
||||
);
|
||||
|
||||
// PUT route with access control
|
||||
app.put(
|
||||
'/permissions/:resourceType/:resourceId',
|
||||
checkResourcePermissionAccess(PermissionBits.SHARE),
|
||||
(req, res) => res.json({ success: true }),
|
||||
);
|
||||
});
|
||||
|
||||
describe('GET /permissions/:resourceType/:resourceId', () => {
|
||||
it('should deny permission query for user without access (main vulnerability test)', async () => {
|
||||
/**
|
||||
* SECURITY TEST: This is the core test for SBA-ADV-20251203-02
|
||||
*
|
||||
* Before the fix, any authenticated user could query permissions for
|
||||
* any agent by just knowing the agent ID, exposing information about
|
||||
* who has access to private agents.
|
||||
*
|
||||
* After the fix, users must have SHARE permission to view permissions.
|
||||
*/
|
||||
const response = await request(app)
|
||||
.get(`/permissions/agent/${agentId}`)
|
||||
.set('Content-Type', 'application/json');
|
||||
|
||||
// Should be denied - attacker has no permission on the agent
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.error).toBe('Forbidden');
|
||||
});
|
||||
|
||||
it('should return 400 for unsupported resource type', async () => {
|
||||
const response = await request(app)
|
||||
.get(`/permissions/unsupported/${agentId}`)
|
||||
.set('Content-Type', 'application/json');
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body.message).toContain('Unsupported resource type');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PUT /permissions/:resourceType/:resourceId', () => {
|
||||
it('should deny permission update for user without access', async () => {
|
||||
const response = await request(app)
|
||||
.put(`/permissions/agent/${agentId}`)
|
||||
.set('Content-Type', 'application/json')
|
||||
.send({ principals: [] });
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.error).toBe('Forbidden');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
const express = require('express');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { getAccessToken } = require('@librechat/api');
|
||||
const { getAccessToken, getBasePath } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { findToken, updateToken, createToken } = require('~/models');
|
||||
|
|
@ -24,6 +24,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
const { code, state } = req.query;
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const basePath = getBasePath();
|
||||
let identifier = action_id;
|
||||
try {
|
||||
let decodedState;
|
||||
|
|
@ -32,17 +33,17 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
} catch (err) {
|
||||
logger.error('Error verifying state parameter:', err);
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
if (decodedState.action_id !== action_id) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
if (!decodedState.user) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
identifier = `${decodedState.user}:${action_id}`;
|
||||
const flowState = await flowManager.getFlowState(identifier, 'oauth');
|
||||
|
|
@ -72,12 +73,12 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
|
||||
/** Redirect to React success page */
|
||||
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('Error in OAuth callback:', error);
|
||||
await flowManager.failFlow(identifier, 'oauth', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ const express = require('express');
|
|||
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
|
||||
const {
|
||||
setHeaders,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
validateConvoAccess,
|
||||
|
|
@ -16,8 +15,6 @@ const { getRoleByName } = require('~/models/Role');
|
|||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(moderateText);
|
||||
|
||||
const checkAgentAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
|
|
@ -28,11 +25,11 @@ const checkAgentResourceAccess = canAccessAgentFromBody({
|
|||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
router.use(moderateText);
|
||||
router.use(checkAgentAccess);
|
||||
router.use(checkAgentResourceAccess);
|
||||
router.use(validateConvoAccess);
|
||||
router.use(buildEndpointOption);
|
||||
router.use(setHeaders);
|
||||
|
||||
const controller = async (req, res, next) => {
|
||||
await AgentController(req, res, next, initializeClient, addTitle);
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
const express = require('express');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, GenerationJobManager } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
messageIpLimiter,
|
||||
configMiddleware,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('~/server/middleware');
|
||||
const { v1 } = require('./v1');
|
||||
const chat = require('./chat');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
|
|
@ -22,13 +22,191 @@ router.use(uaParser);
|
|||
|
||||
router.use('/', v1);
|
||||
|
||||
/**
|
||||
* Stream endpoints - mounted before chatRouter to bypass rate limiters
|
||||
* These are GET requests and don't need message body validation or rate limiting
|
||||
*/
|
||||
|
||||
/**
|
||||
* @route GET /chat/stream/:streamId
|
||||
* @desc Subscribe to an ongoing generation job's SSE stream with replay support
|
||||
* @access Private
|
||||
* @description Sends sync event with resume state, replays missed chunks, then streams live
|
||||
* @query resume=true - Indicates this is a reconnection (sends sync event)
|
||||
*/
|
||||
router.get('/chat/stream/:streamId', async (req, res) => {
|
||||
const { streamId } = req.params;
|
||||
const isResume = req.query.resume === 'true';
|
||||
|
||||
const job = await GenerationJobManager.getJob(streamId);
|
||||
if (!job) {
|
||||
return res.status(404).json({
|
||||
error: 'Stream not found',
|
||||
message: 'The generation job does not exist or has expired.',
|
||||
});
|
||||
}
|
||||
|
||||
res.setHeader('Content-Encoding', 'identity');
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.flushHeaders();
|
||||
|
||||
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
|
||||
|
||||
// Send sync event with resume state for ALL reconnecting clients
|
||||
// This supports multi-tab scenarios where each tab needs run step data
|
||||
if (isResume) {
|
||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
||||
if (resumeState && !res.writableEnded) {
|
||||
// Send sync event with run steps AND aggregatedContent
|
||||
// Client will use aggregatedContent to initialize message state
|
||||
res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
logger.debug(
|
||||
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const result = await GenerationJobManager.subscribe(
|
||||
streamId,
|
||||
(event) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
}
|
||||
},
|
||||
(event) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
},
|
||||
(error) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
if (!result) {
|
||||
return res.status(404).json({ error: 'Failed to subscribe to stream' });
|
||||
}
|
||||
|
||||
req.on('close', () => {
|
||||
logger.debug(`[AgentStream] Client disconnected from ${streamId}`);
|
||||
result.unsubscribe();
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @route GET /chat/active
|
||||
* @desc Get all active generation job IDs for the current user
|
||||
* @access Private
|
||||
* @returns { activeJobIds: string[] }
|
||||
*/
|
||||
router.get('/chat/active', async (req, res) => {
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
|
||||
res.json({ activeJobIds });
|
||||
});
|
||||
|
||||
/**
|
||||
* @route GET /chat/status/:conversationId
|
||||
* @desc Check if there's an active generation job for a conversation
|
||||
* @access Private
|
||||
* @returns { active, streamId, status, aggregatedContent, createdAt, resumeState }
|
||||
*/
|
||||
router.get('/chat/status/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
|
||||
// streamId === conversationId, so we can use getJob directly
|
||||
const job = await GenerationJobManager.getJob(conversationId);
|
||||
|
||||
if (!job) {
|
||||
return res.json({ active: false });
|
||||
}
|
||||
|
||||
if (job.metadata.userId !== req.user.id) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
// Get resume state which contains aggregatedContent
|
||||
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
|
||||
const resumeState = await GenerationJobManager.getResumeState(conversationId);
|
||||
const isActive = job.status === 'running';
|
||||
|
||||
res.json({
|
||||
active: isActive,
|
||||
streamId: conversationId,
|
||||
status: job.status,
|
||||
aggregatedContent: resumeState?.aggregatedContent ?? [],
|
||||
createdAt: job.createdAt,
|
||||
resumeState,
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @route POST /chat/abort
|
||||
* @desc Abort an ongoing generation job
|
||||
* @access Private
|
||||
* @description Mounted before chatRouter to bypass buildEndpointOption middleware
|
||||
*/
|
||||
router.post('/chat/abort', async (req, res) => {
|
||||
logger.debug(`[AgentStream] ========== ABORT ENDPOINT HIT ==========`);
|
||||
logger.debug(`[AgentStream] Method: ${req.method}, Path: ${req.path}`);
|
||||
logger.debug(`[AgentStream] Body:`, req.body);
|
||||
|
||||
const { streamId, conversationId, abortKey } = req.body;
|
||||
const userId = req.user?.id;
|
||||
|
||||
// streamId === conversationId, so try any of the provided IDs
|
||||
// Skip "new" as it's a placeholder for new conversations, not an actual ID
|
||||
let jobStreamId =
|
||||
streamId || (conversationId !== 'new' ? conversationId : null) || abortKey?.split(':')[0];
|
||||
let job = jobStreamId ? await GenerationJobManager.getJob(jobStreamId) : null;
|
||||
|
||||
// Fallback: if job not found and we have a userId, look up active jobs for user
|
||||
// This handles the case where frontend sends "new" but job was created with a UUID
|
||||
if (!job && userId) {
|
||||
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
|
||||
if (activeJobIds.length > 0) {
|
||||
// Abort the most recent active job for this user
|
||||
jobStreamId = activeJobIds[0];
|
||||
job = await GenerationJobManager.getJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Found active job for user: ${jobStreamId}`);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(`[AgentStream] Computed jobStreamId: ${jobStreamId}`);
|
||||
|
||||
if (job && jobStreamId) {
|
||||
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
|
||||
await GenerationJobManager.abortJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`);
|
||||
return res.json({ success: true, aborted: jobStreamId });
|
||||
}
|
||||
|
||||
logger.warn(`[AgentStream] Job not found for streamId: ${jobStreamId}`);
|
||||
return res.status(404).json({ error: 'Job not found', streamId: jobStreamId });
|
||||
});
|
||||
|
||||
const chatRouter = express.Router();
|
||||
chatRouter.use(configMiddleware);
|
||||
|
||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
chatRouter.use(concurrentLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
||||
chatRouter.use(messageIpLimiter);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ const checkGlobalAgentShare = generateCheckAccess({
|
|||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
bodyProps: {
|
||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||
[Permissions.SHARE]: ['projectIds', 'removeProjectIds'],
|
||||
},
|
||||
getRoleByName,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
createImportLimiters,
|
||||
validateConvoAccess,
|
||||
createForkLimiters,
|
||||
configMiddleware,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -31,7 +32,7 @@ router.get('/', async (req, res) => {
|
|||
const cursor = req.query.cursor;
|
||||
const isArchived = isEnabled(req.query.isArchived);
|
||||
const search = req.query.search ? decodeURIComponent(req.query.search) : undefined;
|
||||
const sortBy = req.query.sortBy || 'createdAt';
|
||||
const sortBy = req.query.sortBy || 'updatedAt';
|
||||
const sortDirection = req.query.sortDirection || 'desc';
|
||||
|
||||
let tags;
|
||||
|
|
@ -67,16 +68,17 @@ router.get('/:conversationId', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/gen_title', async (req, res) => {
|
||||
const { conversationId } = req.body;
|
||||
router.get('/gen_title/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${conversationId}`;
|
||||
let title = await titleCache.get(key);
|
||||
|
||||
if (!title) {
|
||||
// Retry every 1s for up to 20s
|
||||
for (let i = 0; i < 20; i++) {
|
||||
await sleep(1000);
|
||||
// Exponential backoff: 500ms, 1s, 2s, 4s, 8s (total ~15.5s max wait)
|
||||
const delays = [500, 1000, 2000, 4000, 8000];
|
||||
for (const delay of delays) {
|
||||
await sleep(delay);
|
||||
title = await titleCache.get(key);
|
||||
if (title) {
|
||||
break;
|
||||
|
|
@ -150,17 +152,70 @@ router.delete('/all', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/update', async (req, res) => {
|
||||
const update = req.body.arg;
|
||||
/**
|
||||
* Archives or unarchives a conversation.
|
||||
* @route POST /archive
|
||||
* @param {string} req.body.arg.conversationId - The conversation ID to archive/unarchive.
|
||||
* @param {boolean} req.body.arg.isArchived - Whether to archive (true) or unarchive (false).
|
||||
* @returns {object} 200 - The updated conversation object.
|
||||
*/
|
||||
router.post('/archive', validateConvoAccess, async (req, res) => {
|
||||
const { conversationId, isArchived } = req.body.arg ?? {};
|
||||
|
||||
if (!update.conversationId) {
|
||||
if (!conversationId) {
|
||||
return res.status(400).json({ error: 'conversationId is required' });
|
||||
}
|
||||
|
||||
if (typeof isArchived !== 'boolean') {
|
||||
return res.status(400).json({ error: 'isArchived must be a boolean' });
|
||||
}
|
||||
|
||||
try {
|
||||
const dbResponse = await saveConvo(req, update, {
|
||||
context: `POST /api/convos/update ${update.conversationId}`,
|
||||
});
|
||||
const dbResponse = await saveConvo(
|
||||
req,
|
||||
{ conversationId, isArchived },
|
||||
{ context: `POST /api/convos/archive ${conversationId}` },
|
||||
);
|
||||
res.status(200).json(dbResponse);
|
||||
} catch (error) {
|
||||
logger.error('Error archiving conversation', error);
|
||||
res.status(500).send('Error archiving conversation');
|
||||
}
|
||||
});
|
||||
|
||||
/** Maximum allowed length for conversation titles */
|
||||
const MAX_CONVO_TITLE_LENGTH = 1024;
|
||||
|
||||
/**
|
||||
* Updates a conversation's title.
|
||||
* @route POST /update
|
||||
* @param {string} req.body.arg.conversationId - The conversation ID to update.
|
||||
* @param {string} req.body.arg.title - The new title for the conversation.
|
||||
* @returns {object} 201 - The updated conversation object.
|
||||
*/
|
||||
router.post('/update', validateConvoAccess, async (req, res) => {
|
||||
const { conversationId, title } = req.body.arg ?? {};
|
||||
|
||||
if (!conversationId) {
|
||||
return res.status(400).json({ error: 'conversationId is required' });
|
||||
}
|
||||
|
||||
if (title === undefined) {
|
||||
return res.status(400).json({ error: 'title is required' });
|
||||
}
|
||||
|
||||
if (typeof title !== 'string') {
|
||||
return res.status(400).json({ error: 'title must be a string' });
|
||||
}
|
||||
|
||||
const sanitizedTitle = title.trim().slice(0, MAX_CONVO_TITLE_LENGTH);
|
||||
|
||||
try {
|
||||
const dbResponse = await saveConvo(
|
||||
req,
|
||||
{ conversationId, title: sanitizedTitle },
|
||||
{ context: `POST /api/convos/update ${conversationId}` },
|
||||
);
|
||||
res.status(201).json(dbResponse);
|
||||
} catch (error) {
|
||||
logger.error('Error updating conversation', error);
|
||||
|
|
|
|||
|
|
@ -4,7 +4,12 @@ const mongoose = require('mongoose');
|
|||
const { v4: uuidv4 } = require('uuid');
|
||||
const { createMethods } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
||||
const {
|
||||
SystemRoles,
|
||||
AccessRoleIds,
|
||||
ResourceType,
|
||||
PrincipalType,
|
||||
} = require('librechat-data-provider');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { createFile } = require('~/models');
|
||||
|
||||
|
|
@ -13,7 +18,13 @@ jest.mock('~/server/services/Files/process', () => ({
|
|||
processDeleteRequest: jest.fn().mockResolvedValue({}),
|
||||
filterFile: jest.fn(),
|
||||
processFileUpload: jest.fn(),
|
||||
processAgentFileUpload: jest.fn(),
|
||||
processAgentFileUpload: jest.fn().mockImplementation(async ({ res }) => {
|
||||
// processAgentFileUpload sends response directly via res.json()
|
||||
return res.status(200).json({
|
||||
message: 'Agent file uploaded and processed successfully',
|
||||
file_id: 'test-file-id',
|
||||
});
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
|
|
@ -28,6 +39,31 @@ jest.mock('~/server/services/Tools/credentials', () => ({
|
|||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3FileUrls: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock fs.promises.unlink to prevent file cleanup errors in tests
|
||||
jest.mock('fs', () => {
|
||||
const actualFs = jest.requireActual('fs');
|
||||
return {
|
||||
...actualFs,
|
||||
promises: {
|
||||
...actualFs.promises,
|
||||
unlink: jest.fn().mockResolvedValue(undefined),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const { processAgentFileUpload } = require('~/server/services/Files/process');
|
||||
|
||||
// Import the router
|
||||
const router = require('~/server/routes/files/files');
|
||||
|
||||
|
|
@ -339,4 +375,347 @@ describe('File Routes - Agent Files Endpoint', () => {
|
|||
expect(response.body.map((f) => f.file_id)).toContain(otherUserFileId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /files - Agent File Upload Permission Check', () => {
|
||||
let agentCustomId;
|
||||
|
||||
beforeEach(async () => {
|
||||
agentCustomId = `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
/**
|
||||
* Helper to create an Express app with specific user context
|
||||
*/
|
||||
const createAppWithUser = (userId, userRole = SystemRoles.USER) => {
|
||||
const testApp = express();
|
||||
testApp.use(express.json());
|
||||
|
||||
// Mock multer - populate req.file
|
||||
testApp.use((req, res, next) => {
|
||||
if (req.method === 'POST') {
|
||||
req.file = {
|
||||
originalname: 'test.txt',
|
||||
mimetype: 'text/plain',
|
||||
size: 100,
|
||||
path: '/tmp/test.txt',
|
||||
};
|
||||
req.file_id = uuidv4();
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
testApp.use((req, res, next) => {
|
||||
req.user = { id: userId.toString(), role: userRole };
|
||||
req.app = { locals: {} };
|
||||
req.config = { fileStrategy: 'local' };
|
||||
next();
|
||||
});
|
||||
|
||||
testApp.use('/files', router);
|
||||
return testApp;
|
||||
};
|
||||
|
||||
it('should deny file upload to agent when user has no permission', async () => {
|
||||
// Create an agent owned by authorId
|
||||
await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.error).toBe('Forbidden');
|
||||
expect(response.body.message).toBe('Insufficient permissions to upload files to this agent');
|
||||
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow file upload to agent for agent author', async () => {
|
||||
// Create an agent owned by authorId
|
||||
await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(authorId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow file upload to agent for user with EDIT permission', async () => {
|
||||
// Create an agent owned by authorId
|
||||
const agent = await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Grant EDIT permission to otherUserId
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deny file upload to agent for user with only VIEW permission', async () => {
|
||||
// Create an agent owned by authorId
|
||||
const agent = await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to otherUserId
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'file_search',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.error).toBe('Forbidden');
|
||||
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow file upload for admin user regardless of agent ownership', async () => {
|
||||
// Create an agent owned by authorId
|
||||
await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Create app with admin user (otherUserId as admin)
|
||||
const testApp = createAppWithUser(otherUserId, SystemRoles.ADMIN);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 404 when uploading to non-existent agent', async () => {
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: 'agent_nonexistent123456789',
|
||||
tool_resource: 'context',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(response.body.error).toBe('Not Found');
|
||||
expect(response.body.message).toBe('Agent not found');
|
||||
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow file upload without agent_id (message attachment)', async () => {
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
file_id: uuidv4(),
|
||||
// No agent_id or tool_resource - this is a message attachment
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow file upload with agent_id but no tool_resource (message attachment)', async () => {
|
||||
// Create an agent owned by authorId
|
||||
await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
file_id: uuidv4(),
|
||||
// No tool_resource - permission check should not apply
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow message_file attachment to agent even without EDIT permission', async () => {
|
||||
// Create an agent owned by authorId
|
||||
const agent = await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to otherUserId
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
// message_file: true indicates this is a chat message attachment, not a permanent file upload
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
message_file: true,
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow message_file attachment (string "true") to agent even without EDIT permission', async () => {
|
||||
// Create an agent owned by authorId
|
||||
const agent = await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to otherUserId
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
// message_file as string "true" (from form data) should also be allowed
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
message_file: 'true',
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deny file upload when message_file is false (not a message attachment)', async () => {
|
||||
// Create an agent owned by authorId
|
||||
const agent = await createAgent({
|
||||
id: agentCustomId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to otherUserId
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const testApp = createAppWithUser(otherUserId);
|
||||
|
||||
// message_file: false should NOT bypass permission check
|
||||
const response = await request(testApp).post('/files').send({
|
||||
endpoint: 'agents',
|
||||
agent_id: agentCustomId,
|
||||
tool_resource: 'context',
|
||||
message_file: false,
|
||||
file_id: uuidv4(),
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.error).toBe('Forbidden');
|
||||
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ const {
|
|||
isUUID,
|
||||
CacheKeys,
|
||||
FileSources,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
EModelEndpoint,
|
||||
PermissionBits,
|
||||
|
|
@ -380,6 +381,50 @@ router.post('/', async (req, res) => {
|
|||
return await processFileUpload({ req, res, metadata });
|
||||
}
|
||||
|
||||
/**
|
||||
* Check agent permissions for permanent agent file uploads (not message attachments).
|
||||
* Message attachments (message_file=true) are temporary files for a single conversation
|
||||
* and should be allowed for users who can chat with the agent.
|
||||
* Permanent file uploads to tool_resources require EDIT permission.
|
||||
*/
|
||||
const isMessageAttachment = metadata.message_file === true || metadata.message_file === 'true';
|
||||
if (metadata.agent_id && metadata.tool_resource && !isMessageAttachment) {
|
||||
const userId = req.user.id;
|
||||
|
||||
/** Admin users bypass permission checks */
|
||||
if (req.user.role !== SystemRoles.ADMIN) {
|
||||
const agent = await getAgent({ id: metadata.agent_id });
|
||||
|
||||
if (!agent) {
|
||||
return res.status(404).json({
|
||||
error: 'Not Found',
|
||||
message: 'Agent not found',
|
||||
});
|
||||
}
|
||||
|
||||
/** Check if user is the author or has edit permission */
|
||||
if (agent.author.toString() !== userId) {
|
||||
const hasEditPermission = await checkPermission({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
if (!hasEditPermission) {
|
||||
logger.warn(
|
||||
`[/files] User ${userId} denied upload to agent ${metadata.agent_id} (insufficient permissions)`,
|
||||
);
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to upload files to this agent',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return await processAgentFileUpload({ req, res, metadata });
|
||||
} catch (error) {
|
||||
let message = 'Error processing file';
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ const {
|
|||
createSafeUser,
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
getBasePath,
|
||||
getUserMCPAuthMap,
|
||||
generateCheckAccess,
|
||||
} = require('@librechat/api');
|
||||
|
|
@ -105,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
|||
* This handles the OAuth callback after the user has authorized the application
|
||||
*/
|
||||
router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
const basePath = getBasePath();
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const { code, state, error: oauthError } = req.query;
|
||||
|
|
@ -118,17 +120,19 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
|
||||
if (oauthError) {
|
||||
logger.error('[MCP OAuth] OAuth error received', { error: oauthError });
|
||||
return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`);
|
||||
return res.redirect(
|
||||
`${basePath}/oauth/error?error=${encodeURIComponent(String(oauthError))}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!code || typeof code !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid code');
|
||||
return res.redirect('/oauth/error?error=missing_code');
|
||||
return res.redirect(`${basePath}/oauth/error?error=missing_code`);
|
||||
}
|
||||
|
||||
if (!state || typeof state !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid state');
|
||||
return res.redirect('/oauth/error?error=missing_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
||||
}
|
||||
|
||||
const flowId = state;
|
||||
|
|
@ -142,7 +146,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
|
||||
if (!flowState) {
|
||||
logger.error('[MCP OAuth] Flow state not found for flowId:', flowId);
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Flow state details', {
|
||||
|
|
@ -160,7 +164,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
flowId,
|
||||
serverName,
|
||||
});
|
||||
return res.redirect(`/oauth/success?serverName=${encodeURIComponent(serverName)}`);
|
||||
return res.redirect(`${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
|
|
@ -254,11 +258,11 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
}
|
||||
|
||||
/** Redirect to success page with flowId and serverName */
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] OAuth callback error', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -588,7 +592,7 @@ async function getOAuthHeaders(serverName, userId) {
|
|||
return serverConfig?.oauth_headers ?? {};
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
MCP Server CRUD Routes (User-Managed MCP Servers)
|
||||
*/
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const express = require('express');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ContentTypes } = require('librechat-data-provider');
|
||||
const { unescapeLaTeX, countTokens } = require('@librechat/api');
|
||||
|
|
@ -23,7 +24,7 @@ router.get('/', async (req, res) => {
|
|||
const user = req.user.id ?? '';
|
||||
const {
|
||||
cursor = null,
|
||||
sortBy = 'createdAt',
|
||||
sortBy = 'updatedAt',
|
||||
sortDirection = 'desc',
|
||||
pageSize: pageSizeRaw,
|
||||
conversationId,
|
||||
|
|
@ -54,7 +55,12 @@ router.get('/', async (req, res) => {
|
|||
.sort({ [sortField]: sortOrder })
|
||||
.limit(pageSize + 1)
|
||||
.lean();
|
||||
const nextCursor = messages.length > pageSize ? messages.pop()[sortField] : null;
|
||||
let nextCursor = null;
|
||||
if (messages.length > pageSize) {
|
||||
messages.pop(); // Remove extra item used to detect next page
|
||||
// Create cursor from the last RETURNED item (not the popped one)
|
||||
nextCursor = messages[messages.length - 1][sortField];
|
||||
}
|
||||
response = { messages, nextCursor };
|
||||
} else if (search) {
|
||||
const searchResults = await Message.meiliSearch(search, { filter: `user = "${user}"` }, true);
|
||||
|
|
@ -111,6 +117,91 @@ router.get('/', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Creates a new branch message from a specific agent's content within a parallel response message.
|
||||
* Filters the original message's content to only include parts attributed to the specified agentId.
|
||||
* Only available for non-user messages with content attributions.
|
||||
*
|
||||
* @route POST /branch
|
||||
* @param {string} req.body.messageId - The ID of the source message
|
||||
* @param {string} req.body.agentId - The agentId to filter content by
|
||||
* @returns {TMessage} The newly created branch message
|
||||
*/
|
||||
router.post('/branch', async (req, res) => {
|
||||
try {
|
||||
const { messageId, agentId } = req.body;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!messageId || !agentId) {
|
||||
return res.status(400).json({ error: 'messageId and agentId are required' });
|
||||
}
|
||||
|
||||
const sourceMessage = await getMessage({ user: userId, messageId });
|
||||
if (!sourceMessage) {
|
||||
return res.status(404).json({ error: 'Source message not found' });
|
||||
}
|
||||
|
||||
if (sourceMessage.isCreatedByUser) {
|
||||
return res.status(400).json({ error: 'Cannot branch from user messages' });
|
||||
}
|
||||
|
||||
if (!Array.isArray(sourceMessage.content)) {
|
||||
return res.status(400).json({ error: 'Message does not have content' });
|
||||
}
|
||||
|
||||
const hasAgentMetadata = sourceMessage.content.some((part) => part?.agentId);
|
||||
if (!hasAgentMetadata) {
|
||||
return res
|
||||
.status(400)
|
||||
.json({ error: 'Message does not have parallel content with attributions' });
|
||||
}
|
||||
|
||||
/** @type {Array<import('librechat-data-provider').TMessageContentParts>} */
|
||||
const filteredContent = [];
|
||||
for (const part of sourceMessage.content) {
|
||||
if (part?.agentId === agentId) {
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = part;
|
||||
filteredContent.push(cleanPart);
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredContent.length === 0) {
|
||||
return res.status(400).json({ error: 'No content found for the specified agentId' });
|
||||
}
|
||||
|
||||
const newMessageId = uuidv4();
|
||||
/** @type {import('librechat-data-provider').TMessage} */
|
||||
const newMessage = {
|
||||
messageId: newMessageId,
|
||||
conversationId: sourceMessage.conversationId,
|
||||
parentMessageId: sourceMessage.parentMessageId,
|
||||
attachments: sourceMessage.attachments,
|
||||
isCreatedByUser: false,
|
||||
model: sourceMessage.model,
|
||||
endpoint: sourceMessage.endpoint,
|
||||
sender: sourceMessage.sender,
|
||||
iconURL: sourceMessage.iconURL,
|
||||
content: filteredContent,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
user: userId,
|
||||
};
|
||||
|
||||
const savedMessage = await saveMessage(req, newMessage, {
|
||||
context: 'POST /api/messages/branch',
|
||||
});
|
||||
|
||||
if (!savedMessage) {
|
||||
return res.status(500).json({ error: 'Failed to save branch message' });
|
||||
}
|
||||
|
||||
res.status(201).json(savedMessage);
|
||||
} catch (error) {
|
||||
logger.error('Error creating branch message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/artifact/:messageId', async (req, res) => {
|
||||
try {
|
||||
const { messageId } = req.params;
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ const oauthHandler = async (req, res, next) => {
|
|||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
setOpenIDAuthTokens(req.user.tokenset, req, res, req.user._id.toString());
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ const checkGlobalPromptShare = generateCheckAccess({
|
|||
permissionType: PermissionTypes.PROMPTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
bodyProps: {
|
||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||
[Permissions.SHARE]: ['projectIds', 'removeProjectIds'],
|
||||
},
|
||||
getRoleByName,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ async function setupTestData() {
|
|||
case SystemRoles.USER:
|
||||
return { permissions: { PROMPTS: { USE: true, CREATE: true } } };
|
||||
case SystemRoles.ADMIN:
|
||||
return { permissions: { PROMPTS: { USE: true, CREATE: true, SHARED_GLOBAL: true } } };
|
||||
return { permissions: { PROMPTS: { USE: true, CREATE: true, SHARE: true } } };
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@ const { nanoid } = require('nanoid');
|
|||
const { tool } = require('@langchain/core/tools');
|
||||
const { GraphEvents, sleep } = require('@librechat/agents');
|
||||
const { logger, encryptV2, decryptV2 } = require('@librechat/data-schemas');
|
||||
const { sendEvent, logAxiosError, refreshAccessToken } = require('@librechat/api');
|
||||
const {
|
||||
sendEvent,
|
||||
logAxiosError,
|
||||
refreshAccessToken,
|
||||
GenerationJobManager,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -127,6 +132,7 @@ async function loadActionSets(searchParams) {
|
|||
* @param {string | undefined} [params.description] - The description for the tool.
|
||||
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
|
||||
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable streams.
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createActionTool({
|
||||
|
|
@ -138,6 +144,7 @@ async function createActionTool({
|
|||
name,
|
||||
description,
|
||||
encrypted,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolInput, config) => {
|
||||
|
|
@ -192,7 +199,12 @@ async function createActionTool({
|
|||
`${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
|
||||
'oauth_login',
|
||||
async () => {
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
logger.debug('Sent OAuth login request to client', { action_id, identifier });
|
||||
return true;
|
||||
},
|
||||
|
|
@ -217,7 +229,12 @@ async function createActionTool({
|
|||
logger.debug('Received OAuth Authorization response', { action_id, identifier });
|
||||
data.delta.auth = undefined;
|
||||
data.delta.expires_at = undefined;
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const successEventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, successEventData);
|
||||
} else {
|
||||
sendEvent(res, successEventData);
|
||||
}
|
||||
await sleep(3000);
|
||||
metadata.oauth_access_token = result.access_token;
|
||||
metadata.oauth_refresh_token = result.refresh_token;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, checkEmailConfig, isEmailDomainAllowed } = require('@librechat/api');
|
||||
const {
|
||||
logger,
|
||||
DEFAULT_SESSION_EXPIRY,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
} = require('@librechat/data-schemas');
|
||||
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api');
|
||||
const {
|
||||
findUser,
|
||||
findToken,
|
||||
|
|
@ -369,19 +373,21 @@ const setAuthTokens = async (userId, res, _session = null) => {
|
|||
let session = _session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
const expiresIn = math(process.env.REFRESH_TOKEN_EXPIRY, DEFAULT_REFRESH_TOKEN_EXPIRY);
|
||||
|
||||
if (session && session._id && session.expiration != null) {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
const result = await createSession(userId);
|
||||
const result = await createSession(userId, { expiresIn });
|
||||
session = result.session;
|
||||
refreshToken = result.refreshToken;
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
const sessionExpiry = math(process.env.SESSION_EXPIRY, DEFAULT_SESSION_EXPIRY);
|
||||
const token = await generateToken(user, sessionExpiry);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
|
|
@ -405,23 +411,26 @@ const setAuthTokens = async (userId, res, _session = null) => {
|
|||
/**
|
||||
* @function setOpenIDAuthTokens
|
||||
* Set OpenID Authentication Tokens
|
||||
* //type tokenset from openid-client
|
||||
* Stores tokens server-side in express-session to avoid large cookie sizes
|
||||
* that can exceed HTTP/2 header limits (especially for users with many group memberships).
|
||||
*
|
||||
* @param {import('openid-client').TokenEndpointResponse & import('openid-client').TokenEndpointResponseHelpers} tokenset
|
||||
* - The tokenset object containing access and refresh tokens
|
||||
* @param {Object} req - request object (for session access)
|
||||
* @param {Object} res - response object
|
||||
* @param {string} [userId] - Optional MongoDB user ID for image path validation
|
||||
* @returns {String} - access token
|
||||
*/
|
||||
const setOpenIDAuthTokens = (tokenset, res, userId, existingRefreshToken) => {
|
||||
const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) => {
|
||||
try {
|
||||
if (!tokenset) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
return;
|
||||
}
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY
|
||||
? eval(REFRESH_TOKEN_EXPIRY)
|
||||
: 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
const expiryInMilliseconds = math(
|
||||
process.env.REFRESH_TOKEN_EXPIRY,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
);
|
||||
const expirationDate = new Date(Date.now() + expiryInMilliseconds);
|
||||
if (tokenset == null) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
|
|
@ -439,18 +448,30 @@ const setOpenIDAuthTokens = (tokenset, res, userId, existingRefreshToken) => {
|
|||
return;
|
||||
}
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
res.cookie('openid_access_token', tokenset.access_token, {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
/** Store tokens server-side in session to avoid large cookies */
|
||||
if (req.session) {
|
||||
req.session.openidTokens = {
|
||||
accessToken: tokenset.access_token,
|
||||
refreshToken: refreshToken,
|
||||
expiresAt: expirationDate.getTime(),
|
||||
};
|
||||
} else {
|
||||
logger.warn('[setOpenIDAuthTokens] No session available, falling back to cookies');
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
res.cookie('openid_access_token', tokenset.access_token, {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
}
|
||||
|
||||
/** Small cookie to indicate token provider (required for auth middleware) */
|
||||
res.cookie('token_provider', 'openid', {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { isUserProvided } = require('@librechat/api');
|
||||
const { isUserProvided, isEnabled } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { generateConfig } = require('~/server/utils/handleText');
|
||||
|
||||
|
|
@ -23,7 +23,9 @@ module.exports = {
|
|||
openAIApiKey,
|
||||
azureOpenAIApiKey,
|
||||
userProvidedOpenAI,
|
||||
[EModelEndpoint.anthropic]: generateConfig(anthropicApiKey),
|
||||
[EModelEndpoint.anthropic]: generateConfig(
|
||||
anthropicApiKey || isEnabled(process.env.ANTHROPIC_USE_VERTEX),
|
||||
),
|
||||
[EModelEndpoint.openAI]: generateConfig(openAIApiKey, OPENAI_REVERSE_PROXY),
|
||||
[EModelEndpoint.azureOpenAI]: generateConfig(azureOpenAIApiKey, AZURE_OPENAI_BASEURL),
|
||||
[EModelEndpoint.assistants]: generateConfig(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { CacheKeys, Time } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
|
|
@ -39,12 +39,12 @@ async function getCachedTools(options = {}) {
|
|||
* @param {Object} options - Options for caching tools
|
||||
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds (default: 12 hours)
|
||||
* @returns {Promise<boolean>} Whether the operation was successful
|
||||
*/
|
||||
async function setCachedTools(tools, options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { userId, serverName, ttl } = options;
|
||||
const { userId, serverName, ttl = Time.TWELVE_HOURS } = options;
|
||||
|
||||
// Cache by MCP server if specified (requires userId)
|
||||
if (serverName && userId) {
|
||||
|
|
|
|||
|
|
@ -43,6 +43,14 @@ async function getEndpointsConfig(req) {
|
|||
};
|
||||
}
|
||||
|
||||
// Enable Anthropic endpoint when Vertex AI is configured in YAML
|
||||
if (appConfig.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig?.enabled) {
|
||||
/** @type {Omit<TConfig, 'order'>} */
|
||||
mergedConfig[EModelEndpoint.anthropic] = {
|
||||
userProvide: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (appConfig.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
/** @type {Omit<TConfig, 'order'>} */
|
||||
mergedConfig[EModelEndpoint.azureAssistants] = {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const {
|
|||
getOpenAIModels,
|
||||
getGoogleModels,
|
||||
} = require('@librechat/api');
|
||||
const { getAppConfig } = require('./app');
|
||||
|
||||
/**
|
||||
* Loads the default models for the application.
|
||||
|
|
@ -15,16 +16,21 @@ const {
|
|||
*/
|
||||
async function loadDefaultModels(req) {
|
||||
try {
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig;
|
||||
|
||||
const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] =
|
||||
await Promise.all([
|
||||
getOpenAIModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching OpenAI models:', error);
|
||||
return [];
|
||||
}),
|
||||
getAnthropicModels({ user: req.user.id }).catch((error) => {
|
||||
logger.error('Error fetching Anthropic models:', error);
|
||||
return [];
|
||||
}),
|
||||
getAnthropicModels({ user: req.user.id, vertexModels: vertexConfig?.modelNames }).catch(
|
||||
(error) => {
|
||||
logger.error('Error fetching Anthropic models:', error);
|
||||
return [];
|
||||
},
|
||||
),
|
||||
getOpenAIModels({ user: req.user.id, azure: true }).catch((error) => {
|
||||
logger.error('Error fetching Azure OpenAI models:', error);
|
||||
return [];
|
||||
|
|
|
|||
136
api/server/services/Endpoints/agents/addedConvo.js
Normal file
136
api/server/services/Endpoints/agents/addedConvo.js
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { initializeAgent, validateAgentModel } = require('@librechat/api');
|
||||
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
// Initialize the getAgent dependency
|
||||
setGetAgent(getAgent);
|
||||
|
||||
/**
|
||||
* Process addedConvo for parallel agent execution.
|
||||
* Creates a parallel agent config from an added conversation.
|
||||
*
|
||||
* When an added agent has no incoming edges, it becomes a start node
|
||||
* and runs in parallel with the primary agent automatically.
|
||||
*
|
||||
* Edge cases handled:
|
||||
* - Primary agent has edges (handoffs): Added agent runs in parallel with primary,
|
||||
* but doesn't participate in the primary's handoff graph
|
||||
* - Primary agent has agent_ids (legacy chain): Added agent runs in parallel with primary,
|
||||
* but doesn't participate in the chain
|
||||
* - Primary agent has both: Added agent is independent, runs parallel from start
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req
|
||||
* @param {import('express').Response} params.res
|
||||
* @param {Object} params.endpointOption - The endpoint option containing addedConvo
|
||||
* @param {Object} params.modelsConfig - The models configuration
|
||||
* @param {Function} params.logViolation - Function to log violations
|
||||
* @param {Function} params.loadTools - Function to load agent tools
|
||||
* @param {Array} params.requestFiles - Request files
|
||||
* @param {string} params.conversationId - The conversation ID
|
||||
* @param {Set} params.allowedProviders - Set of allowed providers
|
||||
* @param {Map} params.agentConfigs - Map of agent configs to add to
|
||||
* @param {string} params.primaryAgentId - The primary agent ID
|
||||
* @param {Object|undefined} params.userMCPAuthMap - User MCP auth map to merge into
|
||||
* @returns {Promise<{userMCPAuthMap: Object|undefined}>} The updated userMCPAuthMap
|
||||
*/
|
||||
const processAddedConvo = async ({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
allowedProviders,
|
||||
agentConfigs,
|
||||
primaryAgentId,
|
||||
primaryAgent,
|
||||
userMCPAuthMap,
|
||||
}) => {
|
||||
const addedConvo = endpointOption.addedConvo;
|
||||
logger.debug('[processAddedConvo] Called with addedConvo:', {
|
||||
hasAddedConvo: addedConvo != null,
|
||||
addedConvoEndpoint: addedConvo?.endpoint,
|
||||
addedConvoModel: addedConvo?.model,
|
||||
addedConvoAgentId: addedConvo?.agent_id,
|
||||
});
|
||||
if (addedConvo == null) {
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
try {
|
||||
const addedAgent = await loadAddedAgent({ req, conversation: addedConvo, primaryAgent });
|
||||
if (!addedAgent) {
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
const addedValidation = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
agent: addedAgent,
|
||||
});
|
||||
|
||||
if (!addedValidation.isValid) {
|
||||
logger.warn(
|
||||
`[processAddedConvo] Added agent validation failed: ${addedValidation.error?.message}`,
|
||||
);
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
const addedConfig = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
agent: addedAgent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
},
|
||||
{
|
||||
getConvoFiles,
|
||||
getFiles: db.getFiles,
|
||||
getUserKey: db.getUserKey,
|
||||
updateFilesUsage: db.updateFilesUsage,
|
||||
getUserKeyValues: db.getUserKeyValues,
|
||||
getToolFilesByIds: db.getToolFilesByIds,
|
||||
},
|
||||
);
|
||||
|
||||
if (userMCPAuthMap != null) {
|
||||
Object.assign(userMCPAuthMap, addedConfig.userMCPAuthMap ?? {});
|
||||
} else {
|
||||
userMCPAuthMap = addedConfig.userMCPAuthMap;
|
||||
}
|
||||
|
||||
const addedAgentId = addedConfig.id || ADDED_AGENT_ID;
|
||||
agentConfigs.set(addedAgentId, addedConfig);
|
||||
|
||||
// No edges needed - agent without incoming edges becomes a start node
|
||||
// and runs in parallel with the primary agent automatically.
|
||||
// This is independent of any edges/agent_ids the primary agent has.
|
||||
|
||||
logger.debug(
|
||||
`[processAddedConvo] Added parallel agent: ${addedAgentId} (primary: ${primaryAgentId}, ` +
|
||||
`primary has edges: ${!!endpointOption.edges}, primary has agent_ids: ${!!endpointOption.agent_ids})`,
|
||||
);
|
||||
|
||||
return { userMCPAuthMap };
|
||||
} catch (err) {
|
||||
logger.error('[processAddedConvo] Error processing addedConvo for parallel agent', err);
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
processAddedConvo,
|
||||
ADDED_AGENT_ID,
|
||||
};
|
||||
|
|
@ -15,6 +15,9 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
|||
return undefined;
|
||||
});
|
||||
|
||||
/** @type {import('librechat-data-provider').TConversation | undefined} */
|
||||
const addedConvo = req.body?.addedConvo;
|
||||
|
||||
return removeNullishValues({
|
||||
spec,
|
||||
iconURL,
|
||||
|
|
@ -23,6 +26,7 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
|||
endpointType,
|
||||
model_parameters,
|
||||
agent: agentPromise,
|
||||
addedConvo,
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@ const {
|
|||
validateAgentModel,
|
||||
getCustomEndpointConfig,
|
||||
createSequentialChainEdges,
|
||||
createEdgeCollector,
|
||||
filterOrphanedEdges,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
getResponseSender,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
createToolEndCallback,
|
||||
|
|
@ -20,14 +22,17 @@ const { getModelsConfig } = require('~/server/controllers/ModelController');
|
|||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { processAddedConvo } = require('./addedConvo');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { logViolation } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
* @param {AbortSignal} signal
|
||||
* Creates a tool loader function for the agent.
|
||||
* @param {AbortSignal} signal - The abort signal
|
||||
* @param {string | null} [streamId] - The stream ID for resumable mode
|
||||
*/
|
||||
function createToolLoader(signal) {
|
||||
function createToolLoader(signal, streamId = null) {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
|
|
@ -52,6 +57,7 @@ function createToolLoader(signal) {
|
|||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
streamId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading tools for agent ' + agentId, error);
|
||||
|
|
@ -65,18 +71,21 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
}
|
||||
const appConfig = req.config;
|
||||
|
||||
// TODO: use endpointOption to determine options/modelOptions
|
||||
/** @type {string | null} */
|
||||
const streamId = req._resumableStreamId || null;
|
||||
|
||||
/** @type {Array<UsageMetadata>} */
|
||||
const collectedUsage = [];
|
||||
/** @type {ArtifactPromises} */
|
||||
const artifactPromises = [];
|
||||
const { contentParts, aggregateContent } = createContentAggregator();
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId });
|
||||
const eventHandlers = getDefaultHandlers({
|
||||
res,
|
||||
aggregateContent,
|
||||
toolEndCallback,
|
||||
collectedUsage,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (!endpointOption.agent) {
|
||||
|
|
@ -105,7 +114,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
const agentConfigs = new Map();
|
||||
const allowedProviders = new Set(appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
const loadTools = createToolLoader(signal);
|
||||
const loadTools = createToolLoader(signal, streamId);
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
/** @type {string} */
|
||||
|
|
@ -136,10 +145,17 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
const agent_ids = primaryConfig.agent_ids;
|
||||
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
|
||||
|
||||
/** @type {Set<string>} Track agents that failed to load (orphaned references) */
|
||||
const skippedAgentIds = new Set();
|
||||
|
||||
async function processAgent(agentId) {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
if (!agent) {
|
||||
throw new Error(`Agent ${agentId} not found`);
|
||||
logger.warn(
|
||||
`[processAgent] Handoff agent ${agentId} not found, skipping (orphaned reference)`,
|
||||
);
|
||||
skippedAgentIds.add(agentId);
|
||||
return null;
|
||||
}
|
||||
|
||||
const validationResult = await validateAgentModel({
|
||||
|
|
@ -180,37 +196,31 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
userMCPAuthMap = config.userMCPAuthMap;
|
||||
}
|
||||
agentConfigs.set(agentId, config);
|
||||
return agent;
|
||||
}
|
||||
|
||||
let edges = primaryConfig.edges;
|
||||
const checkAgentInit = (agentId) => agentId === primaryConfig.id || agentConfigs.has(agentId);
|
||||
if ((edges?.length ?? 0) > 0) {
|
||||
for (const edge of edges) {
|
||||
if (Array.isArray(edge.to)) {
|
||||
for (const to of edge.to) {
|
||||
if (checkAgentInit(to)) {
|
||||
continue;
|
||||
}
|
||||
await processAgent(to);
|
||||
}
|
||||
} else if (typeof edge.to === 'string' && checkAgentInit(edge.to)) {
|
||||
continue;
|
||||
} else if (typeof edge.to === 'string') {
|
||||
await processAgent(edge.to);
|
||||
}
|
||||
|
||||
if (Array.isArray(edge.from)) {
|
||||
for (const from of edge.from) {
|
||||
if (checkAgentInit(from)) {
|
||||
continue;
|
||||
}
|
||||
await processAgent(from);
|
||||
}
|
||||
} else if (typeof edge.from === 'string' && checkAgentInit(edge.from)) {
|
||||
continue;
|
||||
} else if (typeof edge.from === 'string') {
|
||||
await processAgent(edge.from);
|
||||
// Graph topology discovery for recursive agent handoffs (BFS)
|
||||
const { edgeMap, agentsToProcess, collectEdges } = createEdgeCollector(
|
||||
checkAgentInit,
|
||||
skippedAgentIds,
|
||||
);
|
||||
|
||||
// Seed with primary agent's edges
|
||||
collectEdges(primaryConfig.edges);
|
||||
|
||||
// BFS to load and merge all connected agents (enables transitive handoffs: A->B->C)
|
||||
while (agentsToProcess.size > 0) {
|
||||
const agentId = agentsToProcess.values().next().value;
|
||||
agentsToProcess.delete(agentId);
|
||||
try {
|
||||
const agent = await processAgent(agentId);
|
||||
if (agent?.edges?.length) {
|
||||
collectEdges(agent.edges);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`[initializeClient] Error processing agent ${agentId}:`, err);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -222,11 +232,42 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
}
|
||||
await processAgent(agentId);
|
||||
}
|
||||
|
||||
const chain = await createSequentialChainEdges([primaryConfig.id].concat(agent_ids), '{convo}');
|
||||
edges = edges ? edges.concat(chain) : chain;
|
||||
collectEdges(chain);
|
||||
}
|
||||
|
||||
let edges = Array.from(edgeMap.values());
|
||||
|
||||
/** Multi-Convo: Process addedConvo for parallel agent execution */
|
||||
const { userMCPAuthMap: updatedMCPAuthMap } = await processAddedConvo({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
allowedProviders,
|
||||
agentConfigs,
|
||||
primaryAgentId: primaryConfig.id,
|
||||
primaryAgent,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
if (updatedMCPAuthMap) {
|
||||
userMCPAuthMap = updatedMCPAuthMap;
|
||||
}
|
||||
|
||||
// Ensure edges is an array when we have multiple agents (multi-agent mode)
|
||||
// MultiAgentGraph.categorizeEdges requires edges to be iterable
|
||||
if (agentConfigs.size > 0 && !edges) {
|
||||
edges = [];
|
||||
}
|
||||
|
||||
// Filter out edges referencing non-existent agents (orphaned references)
|
||||
edges = filterOrphanedEdges(edges, skippedAgentIds);
|
||||
|
||||
primaryConfig.edges = edges;
|
||||
|
||||
let endpointConfig = appConfig.endpoints?.[primaryConfig.endpoint];
|
||||
|
|
@ -270,10 +311,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
endpointType: endpointOption.endpointType,
|
||||
resendFiles: primaryConfig.resendFiles ?? true,
|
||||
maxContextTokens: primaryConfig.maxContextTokens,
|
||||
endpoint:
|
||||
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
|
||||
? primaryConfig.endpoint
|
||||
: EModelEndpoint.agents,
|
||||
endpoint: isEphemeralAgentId(primaryConfig.id) ? primaryConfig.endpoint : EModelEndpoint.agents,
|
||||
});
|
||||
|
||||
return { client, userMCPAuthMap };
|
||||
|
|
|
|||
|
|
@ -17,6 +17,11 @@ const addTitle = async (req, { text, response, client }) => {
|
|||
return;
|
||||
}
|
||||
|
||||
// Skip title generation for temporary conversations
|
||||
if (req?.body?.isTemporary) {
|
||||
return;
|
||||
}
|
||||
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${response.conversationId}`;
|
||||
/** @type {NodeJS.Timeout} */
|
||||
|
|
|
|||
|
|
@ -50,6 +50,11 @@ const addTitle = async (req, { text, responseText, conversationId }) => {
|
|||
return;
|
||||
}
|
||||
|
||||
// Skip title generation for temporary conversations
|
||||
if (req?.body?.isTemporary) {
|
||||
return;
|
||||
}
|
||||
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${conversationId}`;
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ const base64Only = new Set([
|
|||
EModelEndpoint.bedrock,
|
||||
]);
|
||||
|
||||
const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]);
|
||||
const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3, FileSources.firebase]);
|
||||
|
||||
/**
|
||||
* Encodes and formats the given files.
|
||||
|
|
@ -127,7 +127,7 @@ async function encodeAndFormat(req, files, params, mode) {
|
|||
}
|
||||
|
||||
const preparePayload = encodingMethods[source].prepareImagePayload;
|
||||
/* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob storage. */
|
||||
/* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob/Firebase storage. */
|
||||
if (blobStorageSources.has(source)) {
|
||||
try {
|
||||
const downloadStream = encodingMethods[source].getDownloadStream;
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ const {
|
|||
const {
|
||||
sendEvent,
|
||||
MCPOAuthHandler,
|
||||
isMCPDomainAllowed,
|
||||
normalizeServerName,
|
||||
convertWithResolvedRefs,
|
||||
GenerationJobManager,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
|
|
@ -21,13 +23,14 @@ const {
|
|||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getMCPManager,
|
||||
getFlowStateManager,
|
||||
getOAuthReconnectionManager,
|
||||
getMCPServersRegistry,
|
||||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
|
|
@ -35,8 +38,9 @@ const { getLogStores } = require('~/cache');
|
|||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
||||
function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) {
|
||||
/**
|
||||
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
||||
* @returns {void}
|
||||
|
|
@ -52,7 +56,12 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
|||
expires_at: Date.now() + Time.TWO_MINUTES,
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -63,8 +72,9 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
|||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {number} [params.index]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
||||
function createRunStepEmitter({ res, runId, stepId, toolCall, index, streamId = null }) {
|
||||
return function () {
|
||||
/** @type {import('@librechat/agents').RunStep} */
|
||||
const data = {
|
||||
|
|
@ -77,7 +87,12 @@ function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
|||
tool_calls: [toolCall],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -108,10 +123,9 @@ function createOAuthStart({ flowId, flowManager, callback }) {
|
|||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string} params.loginFlowId - The ID of the login flow.
|
||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createOAuthEnd({ res, stepId, toolCall }) {
|
||||
function createOAuthEnd({ res, stepId, toolCall, streamId = null }) {
|
||||
return async function () {
|
||||
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
||||
const data = {
|
||||
|
|
@ -121,7 +135,12 @@ function createOAuthEnd({ res, stepId, toolCall }) {
|
|||
tool_calls: [{ ...toolCall }],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
logger.debug('Sent OAuth login success to client');
|
||||
};
|
||||
}
|
||||
|
|
@ -137,7 +156,9 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
|
|||
return function () {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
// Clean up both mcp_oauth and mcp_get_tokens flows
|
||||
flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted'));
|
||||
flowManager.failFlow(flowId, 'mcp_get_tokens', new Error('Tool call aborted'));
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -162,10 +183,19 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
|||
* @param {AbortSignal} params.signal
|
||||
* @param {string} params.model
|
||||
* @param {number} [params.index]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
|
||||
async function reconnectServer({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
|
|
@ -176,36 +206,60 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
const runStepEmitter = createRunStepEmitter({
|
||||
res,
|
||||
index,
|
||||
runId,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
flowId,
|
||||
callback,
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
user,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
// Set up abort handler to clean up OAuth flows if request is aborted
|
||||
const oauthFlowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
const abortHandler = () => {
|
||||
logger.info(
|
||||
`[MCP][User: ${user.id}][${serverName}] Tool loading aborted, cleaning up OAuth flows`,
|
||||
);
|
||||
// Clean up both mcp_oauth and mcp_get_tokens flows
|
||||
flowManager.failFlow(oauthFlowId, 'mcp_oauth', new Error('Tool loading aborted'));
|
||||
flowManager.failFlow(oauthFlowId, 'mcp_get_tokens', new Error('Tool loading aborted'));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
|
||||
try {
|
||||
const runStepEmitter = createRunStepEmitter({
|
||||
res,
|
||||
index,
|
||||
runId,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
flowId,
|
||||
callback,
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
user,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
} finally {
|
||||
// Clean up abort handler to prevent memory leaks
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -222,11 +276,45 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
||||
async function createMCPTools({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
config,
|
||||
provider,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
const result = await reconnectServer({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
return;
|
||||
|
|
@ -239,8 +327,10 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
user,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
availableTools: result.availableTools,
|
||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||
config: serverConfig,
|
||||
});
|
||||
if (toolInstance) {
|
||||
serverTools.push(toolInstance);
|
||||
|
|
@ -259,9 +349,11 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
* @param {string} params.model - The model for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {LCAvailableTools} [params.availableTools]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({
|
||||
|
|
@ -273,9 +365,25 @@ async function createMCPTool({
|
|||
provider,
|
||||
userMCPAuthMap,
|
||||
availableTools,
|
||||
config,
|
||||
streamId = null,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
// Runtime domain validation: check if the server's domain is still allowed
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
|
|
@ -289,6 +397,7 @@ async function createMCPTool({
|
|||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||
}
|
||||
|
|
@ -304,10 +413,18 @@ async function createMCPTool({
|
|||
toolName,
|
||||
serverName,
|
||||
toolDefinition,
|
||||
streamId,
|
||||
});
|
||||
}
|
||||
|
||||
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
|
||||
function createToolInstance({
|
||||
res,
|
||||
toolName,
|
||||
serverName,
|
||||
toolDefinition,
|
||||
provider: _provider,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {LCTool} */
|
||||
const { description, parameters } = toolDefinition;
|
||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||
|
|
@ -343,6 +460,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
|
|||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const oauthStart = createOAuthStart({
|
||||
flowId,
|
||||
|
|
@ -353,6 +471,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
|
|||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (derivedSignal) {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,4 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
// Mock all dependencies - define mocks before imports
|
||||
// Mock all dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
|
|||
},
|
||||
}));
|
||||
|
||||
// Create mock registry instance
|
||||
const mockRegistryInstance = {
|
||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
|
||||
|
||||
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
|
||||
|
||||
jest.mock('@librechat/api', () => {
|
||||
// Access mock via getter to avoid hoisting issues
|
||||
return {
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
get isMCPDomainAllowed() {
|
||||
return mockIsMCPDomainAllowed;
|
||||
},
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
CacheKeys: {
|
||||
|
|
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
|
|||
|
||||
jest.mock('./Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
getAppConfig: jest.fn(),
|
||||
get getAppConfig() {
|
||||
return mockGetAppConfig;
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
|
|
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
|
|||
createFlowWithHandler: jest.fn(),
|
||||
failFlow: jest.fn(),
|
||||
});
|
||||
|
||||
// Reset domain validation mock to default (allow all)
|
||||
mockIsMCPDomainAllowed.mockReset();
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
// Reset registry mocks
|
||||
mockRegistryInstance.getServerConfig.mockReset();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||
|
||||
// Reset getAppConfig mock to default (no restrictions)
|
||||
mockGetAppConfig.mockReset();
|
||||
mockGetAppConfig.mockResolvedValue({});
|
||||
});
|
||||
|
||||
describe('createMCPTools', () => {
|
||||
|
|
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('Runtime domain validation', () => {
|
||||
it('should skip tool creation when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://disallowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Should return undefined for disallowed domain
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
|
||||
// Verify domain validation was called with correct parameters
|
||||
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
|
||||
{ url: 'https://disallowed-domain.com/sse' },
|
||||
['allowed-domain.com'],
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow tool creation when domain is allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'admin' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://allowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return true (domain allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
|
||||
});
|
||||
|
||||
it('should skip domain validation for stdio transports (no URL)', async () => {
|
||||
const mockUser = { id: 'stdio-test-user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config without URL (stdio transport)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
command: 'npx',
|
||||
args: ['@modelcontextprotocol/server'],
|
||||
});
|
||||
|
||||
// Mock getAppConfig (should not be called for stdio)
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
|
||||
});
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully without domain check
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
|
||||
expect(mockGetAppConfig).not.toHaveBeenCalled();
|
||||
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return empty array from createMCPTools when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTools({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
serverName: 'test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
config: serverConfig,
|
||||
});
|
||||
|
||||
// Should return empty array for disallowed domain
|
||||
expect(result).toEqual([]);
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed early
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
});
|
||||
|
||||
it('should use user role when fetching domain restrictions', async () => {
|
||||
const adminUser = { id: 'admin-user', role: 'admin' };
|
||||
const regularUser = { id: 'regular-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock different responses based on role
|
||||
mockGetAppConfig
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
|
||||
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Call with admin user
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: adminUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Reset and call with regular user
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: regularUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Verify getAppConfig was called with correct roles
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('User parameter integrity', () => {
|
||||
it('should preserve user object properties through the call chain', async () => {
|
||||
const complexUser = {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ const {
|
|||
} = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
ErrorTypes,
|
||||
ContentTypes,
|
||||
imageGenTools,
|
||||
|
|
@ -18,6 +17,7 @@ const {
|
|||
ImageVisionTool,
|
||||
openapiToFunction,
|
||||
AgentCapabilities,
|
||||
isEphemeralAgentId,
|
||||
validateActionDomain,
|
||||
defaultAgentCapabilities,
|
||||
validateAndParseOpenAPISpec,
|
||||
|
|
@ -369,7 +369,15 @@ async function processRequiredActions(client, requiredActions) {
|
|||
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
|
||||
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
|
||||
*/
|
||||
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
|
||||
async function loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
openAIApiKey,
|
||||
streamId = null,
|
||||
}) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return {};
|
||||
} else if (
|
||||
|
|
@ -385,7 +393,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
|
||||
/** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */
|
||||
if (enabledCapabilities.size === 0 && agent.id === Constants.EPHEMERAL_AGENT_ID) {
|
||||
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
|
||||
enabledCapabilities = new Set(
|
||||
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
|
||||
);
|
||||
|
|
@ -422,7 +430,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
/** @type {ReturnType<typeof createOnSearchResults>} */
|
||||
let webSearchCallbacks;
|
||||
if (includesWebSearch) {
|
||||
webSearchCallbacks = createOnSearchResults(res);
|
||||
webSearchCallbacks = createOnSearchResults(res, streamId);
|
||||
}
|
||||
|
||||
/** @type {Record<string, Record<string, string>>} */
|
||||
|
|
@ -622,6 +630,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
encrypted,
|
||||
name: toolName,
|
||||
description: functionSig.description,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (!tool) {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,29 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { Tools } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { GenerationJobManager } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Helper to write attachment events either to res or to job emitter.
|
||||
* @param {import('http').ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} attachment - The attachment data
|
||||
*/
|
||||
function writeAttachment(res, streamId, attachment) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
|
||||
} else {
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a function to handle search results and stream them as attachments
|
||||
* @param {import('http').ServerResponse} res - The HTTP server response object
|
||||
* @param {string | null} [streamId] - The stream ID for resumable mode, or null for standard mode
|
||||
* @returns {{ onSearchResults: function(SearchResult, GraphRunnableConfig): void; onGetHighlights: function(string): void}} - Function that takes search results and returns or streams an attachment
|
||||
*/
|
||||
function createOnSearchResults(res) {
|
||||
function createOnSearchResults(res, streamId = null) {
|
||||
const context = {
|
||||
sourceMap: new Map(),
|
||||
searchResultData: undefined,
|
||||
|
|
@ -70,7 +86,7 @@ function createOnSearchResults(res) {
|
|||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -92,7 +108,7 @@ function createOnSearchResults(res) {
|
|||
}
|
||||
|
||||
const attachment = buildAttachment(context);
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -9,29 +9,31 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config');
|
|||
async function initializeMCPs() {
|
||||
const appConfig = await getAppConfig();
|
||||
const mcpServers = appConfig.mcpConfig;
|
||||
if (!mcpServers) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize MCPServersRegistry first (required for MCPManager)
|
||||
try {
|
||||
createMCPServersRegistry(mongoose);
|
||||
createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains);
|
||||
} catch (error) {
|
||||
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
const mcpManager = await createMCPManager(mcpServers);
|
||||
|
||||
try {
|
||||
const mcpTools = (await mcpManager.getAppToolFunctions()) || {};
|
||||
await mergeAppTools(mcpTools);
|
||||
const mcpManager = await createMCPManager(mcpServers || {});
|
||||
|
||||
logger.info(
|
||||
`MCP servers initialized successfully. Added ${Object.keys(mcpTools).length} MCP tools.`,
|
||||
);
|
||||
if (mcpServers && Object.keys(mcpServers).length > 0) {
|
||||
const mcpTools = (await mcpManager.getAppToolFunctions()) || {};
|
||||
await mergeAppTools(mcpTools);
|
||||
const serverCount = Object.keys(mcpServers).length;
|
||||
const toolCount = Object.keys(mcpTools).length;
|
||||
logger.info(
|
||||
`[MCP] Initialized with ${serverCount} configured ${serverCount === 1 ? 'server' : 'servers'} and ${toolCount} ${toolCount === 1 ? 'tool' : 'tools'}.`,
|
||||
);
|
||||
} else {
|
||||
logger.debug('[MCP] No servers configured. MCPManager ready for UI-based servers.');
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize MCP servers:', error);
|
||||
logger.error('[MCP] Failed to initialize MCPManager:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
281
api/server/services/initializeMCPs.spec.js
Normal file
281
api/server/services/initializeMCPs.spec.js
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
/**
|
||||
* Tests for initializeMCPs.js
|
||||
*
|
||||
* These tests verify that MCPServersRegistry and MCPManager are ALWAYS initialized,
|
||||
* even when no explicitly configured MCP servers exist. This is critical for the
|
||||
* "Dynamic MCP Server Management" feature (v0.8.2-rc1) which allows users to
|
||||
* add MCP servers via the UI without requiring explicit configuration.
|
||||
*
|
||||
* Bug fixed: Previously, MCPManager was only initialized when mcpServers existed
|
||||
* in librechat.yaml, causing "MCPManager has not been initialized" errors when
|
||||
* users tried to create MCP servers via the UI.
|
||||
*/
|
||||
|
||||
// Mock dependencies before imports
|
||||
jest.mock('mongoose', () => ({
|
||||
connection: { readyState: 1 },
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock config functions
|
||||
const mockGetAppConfig = jest.fn();
|
||||
const mockMergeAppTools = jest.fn();
|
||||
|
||||
jest.mock('./Config', () => ({
|
||||
get getAppConfig() {
|
||||
return mockGetAppConfig;
|
||||
},
|
||||
get mergeAppTools() {
|
||||
return mockMergeAppTools;
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock MCP singletons
|
||||
const mockCreateMCPServersRegistry = jest.fn();
|
||||
const mockCreateMCPManager = jest.fn();
|
||||
const mockMCPManagerInstance = {
|
||||
getAppToolFunctions: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
get createMCPServersRegistry() {
|
||||
return mockCreateMCPServersRegistry;
|
||||
},
|
||||
get createMCPManager() {
|
||||
return mockCreateMCPManager;
|
||||
},
|
||||
}));
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const initializeMCPs = require('./initializeMCPs');
|
||||
|
||||
describe('initializeMCPs', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Default: successful initialization
|
||||
mockCreateMCPServersRegistry.mockReturnValue(undefined);
|
||||
mockCreateMCPManager.mockResolvedValue(mockMCPManagerInstance);
|
||||
mockMCPManagerInstance.getAppToolFunctions.mockResolvedValue({});
|
||||
mockMergeAppTools.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
describe('MCPServersRegistry initialization', () => {
|
||||
it('should ALWAYS initialize MCPServersRegistry even without configured servers', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: null, // No configured servers
|
||||
mcpSettings: { allowedDomains: ['localhost'] },
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockCreateMCPServersRegistry).toHaveBeenCalledTimes(1);
|
||||
expect(mockCreateMCPServersRegistry).toHaveBeenCalledWith(
|
||||
expect.anything(), // mongoose
|
||||
['localhost'],
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass allowedDomains from mcpSettings to registry', async () => {
|
||||
const allowedDomains = ['localhost', '*.example.com', 'trusted-mcp.com'];
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: null,
|
||||
mcpSettings: { allowedDomains },
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockCreateMCPServersRegistry).toHaveBeenCalledWith(expect.anything(), allowedDomains);
|
||||
});
|
||||
|
||||
it('should handle undefined mcpSettings gracefully', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: null,
|
||||
// mcpSettings is undefined
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockCreateMCPServersRegistry).toHaveBeenCalledWith(expect.anything(), undefined);
|
||||
});
|
||||
|
||||
it('should throw and log error if MCPServersRegistry initialization fails', async () => {
|
||||
const registryError = new Error('Registry initialization failed');
|
||||
mockCreateMCPServersRegistry.mockImplementation(() => {
|
||||
throw registryError;
|
||||
});
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: null });
|
||||
|
||||
await expect(initializeMCPs()).rejects.toThrow('Registry initialization failed');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'[MCP] Failed to initialize MCPServersRegistry:',
|
||||
registryError,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('MCPManager initialization', () => {
|
||||
it('should ALWAYS initialize MCPManager even without configured servers', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: null, // No configured servers
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
// MCPManager should be created with empty object when no configured servers
|
||||
expect(mockCreateMCPManager).toHaveBeenCalledTimes(1);
|
||||
expect(mockCreateMCPManager).toHaveBeenCalledWith({});
|
||||
});
|
||||
|
||||
it('should initialize MCPManager with configured servers when provided', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': { type: 'sse', url: 'http://localhost:3001/sse' },
|
||||
'local-server': { type: 'stdio', command: 'node', args: ['server.js'] },
|
||||
};
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: mcpServers });
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockCreateMCPManager).toHaveBeenCalledWith(mcpServers);
|
||||
});
|
||||
|
||||
it('should throw and log error if MCPManager initialization fails', async () => {
|
||||
const managerError = new Error('Manager initialization failed');
|
||||
mockCreateMCPManager.mockRejectedValue(managerError);
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: null });
|
||||
|
||||
await expect(initializeMCPs()).rejects.toThrow('Manager initialization failed');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'[MCP] Failed to initialize MCPManager:',
|
||||
managerError,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool merging behavior', () => {
|
||||
it('should NOT merge tools when no configured servers exist', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: null, // No configured servers
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
|
||||
expect(mockMergeAppTools).not.toHaveBeenCalled();
|
||||
expect(logger.debug).toHaveBeenCalledWith(
|
||||
'[MCP] No servers configured. MCPManager ready for UI-based servers.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT merge tools when mcpConfig is empty object', async () => {
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: {}, // Empty object
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockMCPManagerInstance.getAppToolFunctions).not.toHaveBeenCalled();
|
||||
expect(mockMergeAppTools).not.toHaveBeenCalled();
|
||||
expect(logger.debug).toHaveBeenCalledWith(
|
||||
'[MCP] No servers configured. MCPManager ready for UI-based servers.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should merge tools when configured servers exist', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': { type: 'sse', url: 'http://localhost:3001/sse' },
|
||||
};
|
||||
const mcpTools = {
|
||||
tool1: jest.fn(),
|
||||
tool2: jest.fn(),
|
||||
};
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: mcpServers });
|
||||
mockMCPManagerInstance.getAppToolFunctions.mockResolvedValue(mcpTools);
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(mockMCPManagerInstance.getAppToolFunctions).toHaveBeenCalledTimes(1);
|
||||
expect(mockMergeAppTools).toHaveBeenCalledWith(mcpTools);
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
'[MCP] Initialized with 1 configured server and 2 tools.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle null return from getAppToolFunctions', async () => {
|
||||
const mcpServers = { 'test-server': { type: 'sse', url: 'http://localhost:3001' } };
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: mcpServers });
|
||||
mockMCPManagerInstance.getAppToolFunctions.mockResolvedValue(null);
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
// Should use empty object fallback
|
||||
expect(mockMergeAppTools).toHaveBeenCalledWith({});
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
'[MCP] Initialized with 1 configured server and 0 tools.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Initialization order', () => {
|
||||
it('should initialize Registry before Manager', async () => {
|
||||
const callOrder = [];
|
||||
|
||||
mockCreateMCPServersRegistry.mockImplementation(() => {
|
||||
callOrder.push('registry');
|
||||
});
|
||||
mockCreateMCPManager.mockImplementation(async () => {
|
||||
callOrder.push('manager');
|
||||
return mockMCPManagerInstance;
|
||||
});
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: null });
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
expect(callOrder).toEqual(['registry', 'manager']);
|
||||
});
|
||||
|
||||
it('should not attempt MCPManager initialization if Registry fails', async () => {
|
||||
mockCreateMCPServersRegistry.mockImplementation(() => {
|
||||
throw new Error('Registry failed');
|
||||
});
|
||||
mockGetAppConfig.mockResolvedValue({ mcpConfig: null });
|
||||
|
||||
await expect(initializeMCPs()).rejects.toThrow('Registry failed');
|
||||
expect(mockCreateMCPManager).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('UI-based MCP server management support', () => {
|
||||
/**
|
||||
* This test documents the critical fix:
|
||||
* MCPManager must be initialized even without configured servers to support
|
||||
* the "Dynamic MCP Server Management" feature where users create
|
||||
* MCP servers via the UI.
|
||||
*/
|
||||
it('should support UI-based server creation without explicit configuration', async () => {
|
||||
// Scenario: User has no MCP servers in librechat.yaml but wants to
|
||||
// add servers via the UI
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpConfig: null,
|
||||
mcpSettings: undefined,
|
||||
});
|
||||
|
||||
await initializeMCPs();
|
||||
|
||||
// Both singletons must be initialized for UI-based management to work
|
||||
expect(mockCreateMCPServersRegistry).toHaveBeenCalledTimes(1);
|
||||
expect(mockCreateMCPManager).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Verify manager was created with empty config (not null/undefined)
|
||||
expect(mockCreateMCPManager).toHaveBeenCalledWith({});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -5,7 +5,7 @@ const { Calculator } = require('@librechat/agents');
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { zodToJsonSchema } = require('zod-to-json-schema');
|
||||
const { Tools, ImageVisionTool } = require('librechat-data-provider');
|
||||
const { getToolkitKey, oaiToolkit, ytToolkit } = require('@librechat/api');
|
||||
const { getToolkitKey, oaiToolkit, ytToolkit, geminiToolkit } = require('@librechat/api');
|
||||
const { toolkits } = require('~/app/clients/tools/manifest');
|
||||
|
||||
/**
|
||||
|
|
@ -84,6 +84,7 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
|
|||
new Calculator(),
|
||||
...Object.values(oaiToolkit),
|
||||
...Object.values(ytToolkit),
|
||||
...Object.values(geminiToolkit),
|
||||
];
|
||||
for (const toolInstance of basicToolInstances) {
|
||||
const formattedTool = formatToOpenAIAssistantTool(toolInstance);
|
||||
|
|
|
|||
|
|
@ -243,6 +243,133 @@ describe('Import Timestamp Ordering', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('ChatGPT Import - Timestamp Issues', () => {
|
||||
test('should correct timestamp inversions (child before parent)', async () => {
|
||||
// Simulate ChatGPT export with timestamp inversion (like tool call results)
|
||||
const jsonData = [
|
||||
{
|
||||
title: 'Timestamp Inversion Test',
|
||||
create_time: 1000,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['parent-msg'],
|
||||
},
|
||||
'parent-msg': {
|
||||
id: 'parent-msg',
|
||||
message: {
|
||||
id: 'parent-msg',
|
||||
author: { role: 'user' },
|
||||
create_time: 1000.1, // Parent: 1000.1
|
||||
content: { content_type: 'text', parts: ['Parent message'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['child-msg'],
|
||||
},
|
||||
'child-msg': {
|
||||
id: 'child-msg',
|
||||
message: {
|
||||
id: 'child-msg',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1000.095, // Child: 1000.095 (5ms BEFORE parent)
|
||||
content: { content_type: 'text', parts: ['Child message'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'parent-msg',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.messages;
|
||||
const parent = savedMessages.find((msg) => msg.text === 'Parent message');
|
||||
const child = savedMessages.find((msg) => msg.text === 'Child message');
|
||||
|
||||
expect(parent).toBeDefined();
|
||||
expect(child).toBeDefined();
|
||||
|
||||
// Child timestamp should be adjusted to be after parent
|
||||
expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(parent.createdAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
test('should use conv.create_time for null message timestamps', async () => {
|
||||
const convCreateTime = 1500000000; // Conversation create time
|
||||
const jsonData = [
|
||||
{
|
||||
title: 'Null Timestamp Test',
|
||||
create_time: convCreateTime,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['msg-with-null-time'],
|
||||
},
|
||||
'msg-with-null-time': {
|
||||
id: 'msg-with-null-time',
|
||||
message: {
|
||||
id: 'msg-with-null-time',
|
||||
author: { role: 'user' },
|
||||
create_time: null, // Null timestamp
|
||||
content: { content_type: 'text', parts: ['Message with null time'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['msg-with-valid-time'],
|
||||
},
|
||||
'msg-with-valid-time': {
|
||||
id: 'msg-with-valid-time',
|
||||
message: {
|
||||
id: 'msg-with-valid-time',
|
||||
author: { role: 'assistant' },
|
||||
create_time: convCreateTime + 10, // Valid timestamp
|
||||
content: { content_type: 'text', parts: ['Message with valid time'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'msg-with-null-time',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.messages;
|
||||
const nullTimeMsg = savedMessages.find((msg) => msg.text === 'Message with null time');
|
||||
const validTimeMsg = savedMessages.find((msg) => msg.text === 'Message with valid time');
|
||||
|
||||
expect(nullTimeMsg).toBeDefined();
|
||||
expect(validTimeMsg).toBeDefined();
|
||||
|
||||
// Null timestamp should fall back to conv.create_time
|
||||
expect(nullTimeMsg.createdAt).toEqual(new Date(convCreateTime * 1000));
|
||||
|
||||
// Child should still be after parent (timestamp adjustment)
|
||||
expect(new Date(validTimeMsg.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(nullTimeMsg.createdAt).getTime(),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Comparison with Fork Functionality', () => {
|
||||
test('fork functionality correctly handles timestamp issues (for comparison)', async () => {
|
||||
const { cloneMessagesWithTimestamps } = require('./fork');
|
||||
|
|
|
|||
|
|
@ -13,8 +13,14 @@ const getLogStores = require('~/cache/getLogStores');
|
|||
* @throws {Error} - If the import type is not supported.
|
||||
*/
|
||||
function getImporter(jsonData) {
|
||||
// For ChatGPT
|
||||
// For array-based formats (ChatGPT or Claude)
|
||||
if (Array.isArray(jsonData)) {
|
||||
// Claude format has chat_messages array in each conversation
|
||||
if (jsonData.length > 0 && jsonData[0]?.chat_messages) {
|
||||
logger.info('Importing Claude conversation');
|
||||
return importClaudeConvo;
|
||||
}
|
||||
// ChatGPT format has mapping object in each conversation
|
||||
logger.info('Importing ChatGPT conversation');
|
||||
return importChatGptConvo;
|
||||
}
|
||||
|
|
@ -71,6 +77,111 @@ async function importChatBotUiConvo(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts text and thinking content from a Claude message.
|
||||
* @param {Object} msg - Claude message object with content array and optional text field.
|
||||
* @returns {{textContent: string, thinkingContent: string}} Extracted text and thinking content.
|
||||
*/
|
||||
function extractClaudeContent(msg) {
|
||||
let textContent = '';
|
||||
let thinkingContent = '';
|
||||
|
||||
for (const part of msg.content || []) {
|
||||
if (part.type === 'text' && part.text) {
|
||||
textContent += part.text;
|
||||
} else if (part.type === 'thinking' && part.thinking) {
|
||||
thinkingContent += part.thinking;
|
||||
}
|
||||
}
|
||||
|
||||
// Use the text field as fallback if content array is empty
|
||||
if (!textContent && msg.text) {
|
||||
textContent = msg.text;
|
||||
}
|
||||
|
||||
return { textContent, thinkingContent };
|
||||
}
|
||||
|
||||
/**
|
||||
* Imports Claude conversations from provided JSON data.
|
||||
* Claude export format: array of conversations with chat_messages array.
|
||||
*
|
||||
* @param {Array} jsonData - Array of Claude conversation objects to be imported.
|
||||
* @param {string} requestUserId - The ID of the user who initiated the import process.
|
||||
* @param {Function} builderFactory - Factory function to create a new import batch builder instance.
|
||||
* @returns {Promise<void>} Promise that resolves when all conversations have been imported.
|
||||
*/
|
||||
async function importClaudeConvo(
|
||||
jsonData,
|
||||
requestUserId,
|
||||
builderFactory = createImportBatchBuilder,
|
||||
) {
|
||||
try {
|
||||
const importBatchBuilder = builderFactory(requestUserId);
|
||||
|
||||
for (const conv of jsonData) {
|
||||
importBatchBuilder.startConversation(EModelEndpoint.anthropic);
|
||||
|
||||
let lastMessageId = Constants.NO_PARENT;
|
||||
let lastTimestamp = null;
|
||||
|
||||
for (const msg of conv.chat_messages || []) {
|
||||
const isCreatedByUser = msg.sender === 'human';
|
||||
const messageId = uuidv4();
|
||||
|
||||
const { textContent, thinkingContent } = extractClaudeContent(msg);
|
||||
|
||||
// Skip empty messages
|
||||
if (!textContent && !thinkingContent) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse timestamp, fallback to conversation create_time or current time
|
||||
const messageTime = msg.created_at || conv.created_at;
|
||||
let createdAt = messageTime ? new Date(messageTime) : new Date();
|
||||
|
||||
// Ensure timestamp is after the previous message.
|
||||
// Messages are sorted by createdAt and buildTree expects parents to appear before children.
|
||||
// This guards against any potential ordering issues in exports.
|
||||
if (lastTimestamp && createdAt <= lastTimestamp) {
|
||||
createdAt = new Date(lastTimestamp.getTime() + 1);
|
||||
}
|
||||
lastTimestamp = createdAt;
|
||||
|
||||
const message = {
|
||||
messageId,
|
||||
parentMessageId: lastMessageId,
|
||||
text: textContent,
|
||||
sender: isCreatedByUser ? 'user' : 'Claude',
|
||||
isCreatedByUser,
|
||||
user: requestUserId,
|
||||
endpoint: EModelEndpoint.anthropic,
|
||||
createdAt,
|
||||
};
|
||||
|
||||
// Add content array with thinking if present
|
||||
if (thinkingContent && !isCreatedByUser) {
|
||||
message.content = [
|
||||
{ type: 'think', think: thinkingContent },
|
||||
{ type: 'text', text: textContent },
|
||||
];
|
||||
}
|
||||
|
||||
importBatchBuilder.saveMessage(message);
|
||||
lastMessageId = messageId;
|
||||
}
|
||||
|
||||
const createdAt = conv.created_at ? new Date(conv.created_at) : new Date();
|
||||
importBatchBuilder.finishConversation(conv.name || 'Imported Claude Chat', createdAt);
|
||||
}
|
||||
|
||||
await importBatchBuilder.saveBatch();
|
||||
logger.info(`user: ${requestUserId} | Claude conversation imported`);
|
||||
} catch (error) {
|
||||
logger.error(`user: ${requestUserId} | Error creating conversation from Claude file`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Imports a LibreChat conversation from JSON.
|
||||
*
|
||||
|
|
@ -213,11 +324,11 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
|||
}
|
||||
|
||||
/**
|
||||
* Helper function to find the nearest non-system parent
|
||||
* Helper function to find the nearest valid parent (skips system, reasoning_recap, and thoughts messages)
|
||||
* @param {string} parentId - The ID of the parent message.
|
||||
* @returns {string} The ID of the nearest non-system parent message.
|
||||
* @returns {string} The ID of the nearest valid parent message.
|
||||
*/
|
||||
const findNonSystemParent = (parentId) => {
|
||||
const findValidParent = (parentId) => {
|
||||
if (!parentId || !messageMap.has(parentId)) {
|
||||
return Constants.NO_PARENT;
|
||||
}
|
||||
|
|
@ -227,14 +338,62 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
|||
return Constants.NO_PARENT;
|
||||
}
|
||||
|
||||
/* If parent is a system message, traverse up to find the nearest non-system parent */
|
||||
if (parentMapping.message.author?.role === 'system') {
|
||||
return findNonSystemParent(parentMapping.parent);
|
||||
/* If parent is a system message, reasoning_recap, or thoughts, traverse up to find the nearest valid parent */
|
||||
const contentType = parentMapping.message.content?.content_type;
|
||||
const shouldSkip =
|
||||
parentMapping.message.author?.role === 'system' ||
|
||||
contentType === 'reasoning_recap' ||
|
||||
contentType === 'thoughts';
|
||||
|
||||
if (shouldSkip) {
|
||||
return findValidParent(parentMapping.parent);
|
||||
}
|
||||
|
||||
return messageMap.get(parentId);
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper function to find thinking content from parent chain (thoughts messages)
|
||||
* @param {string} parentId - The ID of the parent message.
|
||||
* @param {Set} visited - Set of already-visited IDs to prevent cycles.
|
||||
* @returns {Array} The thinking content array (empty if not found).
|
||||
*/
|
||||
const findThinkingContent = (parentId, visited = new Set()) => {
|
||||
// Guard against circular references in malformed imports
|
||||
if (!parentId || visited.has(parentId)) {
|
||||
return [];
|
||||
}
|
||||
visited.add(parentId);
|
||||
|
||||
const parentMapping = conv.mapping[parentId];
|
||||
if (!parentMapping?.message) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const contentType = parentMapping.message.content?.content_type;
|
||||
|
||||
// If this is a thoughts message, extract the thinking content
|
||||
if (contentType === 'thoughts') {
|
||||
const thoughts = parentMapping.message.content.thoughts || [];
|
||||
const thinkingText = thoughts
|
||||
.map((t) => t.content || t.summary || '')
|
||||
.filter(Boolean)
|
||||
.join('\n\n');
|
||||
|
||||
if (thinkingText) {
|
||||
return [{ type: 'think', think: thinkingText }];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
// If this is reasoning_recap, look at its parent for thoughts
|
||||
if (contentType === 'reasoning_recap') {
|
||||
return findThinkingContent(parentMapping.parent, visited);
|
||||
}
|
||||
|
||||
return [];
|
||||
};
|
||||
|
||||
// Create and save messages using the mapped IDs
|
||||
const messages = [];
|
||||
for (const [id, mapping] of Object.entries(conv.mapping)) {
|
||||
|
|
@ -247,8 +406,20 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
|||
continue;
|
||||
}
|
||||
|
||||
const contentType = mapping.message.content?.content_type;
|
||||
|
||||
// Skip thoughts messages - they will be merged into the response message
|
||||
if (contentType === 'thoughts') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip reasoning_recap messages (just summaries like "Thought for 44s")
|
||||
if (contentType === 'reasoning_recap') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const newMessageId = messageMap.get(id);
|
||||
const parentMessageId = findNonSystemParent(mapping.parent);
|
||||
const parentMessageId = findValidParent(mapping.parent);
|
||||
|
||||
const messageText = formatMessageText(mapping.message);
|
||||
|
||||
|
|
@ -266,7 +437,12 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
|||
}
|
||||
}
|
||||
|
||||
messages.push({
|
||||
// Use create_time from ChatGPT export to ensure proper message ordering
|
||||
// For null timestamps, use the conversation's create_time as fallback, or current time as last resort
|
||||
const messageTime = mapping.message.create_time || conv.create_time;
|
||||
const createdAt = messageTime ? new Date(messageTime * 1000) : new Date();
|
||||
|
||||
const message = {
|
||||
messageId: newMessageId,
|
||||
parentMessageId,
|
||||
text: messageText,
|
||||
|
|
@ -275,9 +451,23 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
|||
model,
|
||||
user: requestUserId,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
createdAt,
|
||||
};
|
||||
|
||||
// For assistant messages, check if there's thinking content in the parent chain
|
||||
if (!isCreatedByUser) {
|
||||
const thinkingContent = findThinkingContent(mapping.parent);
|
||||
if (thinkingContent.length > 0) {
|
||||
// Combine thinking content with the text response
|
||||
message.content = [...thinkingContent, { type: 'text', text: messageText }];
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
adjustTimestampsForOrdering(messages);
|
||||
|
||||
for (const message of messages) {
|
||||
importBatchBuilder.saveMessage(message);
|
||||
}
|
||||
|
|
@ -325,17 +515,18 @@ function processAssistantMessage(messageData, messageText) {
|
|||
/**
|
||||
* Formats the text content of a message based on its content type and author role.
|
||||
* @param {ChatGPTMessage} messageData - The message data.
|
||||
* @returns {string} - The updated message text after processing.
|
||||
* @returns {string} - The formatted message text.
|
||||
*/
|
||||
function formatMessageText(messageData) {
|
||||
const isText = messageData.content.content_type === 'text';
|
||||
const contentType = messageData.content.content_type;
|
||||
const isText = contentType === 'text';
|
||||
let messageText = '';
|
||||
|
||||
if (isText && messageData.content.parts) {
|
||||
messageText = messageData.content.parts.join(' ');
|
||||
} else if (messageData.content.content_type === 'code') {
|
||||
} else if (contentType === 'code') {
|
||||
messageText = `\`\`\`${messageData.content.language}\n${messageData.content.text}\n\`\`\``;
|
||||
} else if (messageData.content.content_type === 'execution_output') {
|
||||
} else if (contentType === 'execution_output') {
|
||||
messageText = `Execution Output:\n> ${messageData.content.text}`;
|
||||
} else if (messageData.content.parts) {
|
||||
for (const part of messageData.content.parts) {
|
||||
|
|
@ -357,4 +548,33 @@ function formatMessageText(messageData) {
|
|||
return messageText;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjusts message timestamps to ensure children always come after parents.
|
||||
* Messages are sorted by createdAt and buildTree expects parents to appear before children.
|
||||
* ChatGPT exports can have slight timestamp inversions (e.g., tool call results
|
||||
* arriving a few ms before their parent). Uses multiple passes to handle cascading adjustments.
|
||||
*
|
||||
* @param {Array} messages - Array of message objects with messageId, parentMessageId, and createdAt.
|
||||
*/
|
||||
function adjustTimestampsForOrdering(messages) {
|
||||
const timestampMap = new Map();
|
||||
messages.forEach((msg) => timestampMap.set(msg.messageId, msg.createdAt));
|
||||
|
||||
let hasChanges = true;
|
||||
while (hasChanges) {
|
||||
hasChanges = false;
|
||||
for (const message of messages) {
|
||||
if (message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT) {
|
||||
const parentTimestamp = timestampMap.get(message.parentMessageId);
|
||||
if (parentTimestamp && message.createdAt <= parentTimestamp) {
|
||||
// Bump child timestamp to 1ms after parent
|
||||
message.createdAt = new Date(parentTimestamp.getTime() + 1);
|
||||
timestampMap.set(message.messageId, message.createdAt);
|
||||
hasChanges = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { getImporter, processAssistantMessage };
|
||||
|
|
|
|||
|
|
@ -497,6 +497,262 @@ describe('importChatGptConvo', () => {
|
|||
expect(userMsg.sender).toBe('user');
|
||||
expect(userMsg.isCreatedByUser).toBe(true);
|
||||
});
|
||||
|
||||
it('should merge thinking content into assistant message', async () => {
|
||||
const testData = [
|
||||
{
|
||||
title: 'Thinking Content Test',
|
||||
create_time: 1000,
|
||||
update_time: 2000,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['user-msg-1'],
|
||||
},
|
||||
'user-msg-1': {
|
||||
id: 'user-msg-1',
|
||||
message: {
|
||||
id: 'user-msg-1',
|
||||
author: { role: 'user' },
|
||||
create_time: 1,
|
||||
content: { content_type: 'text', parts: ['What is 2+2?'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['thoughts-msg'],
|
||||
},
|
||||
'thoughts-msg': {
|
||||
id: 'thoughts-msg',
|
||||
message: {
|
||||
id: 'thoughts-msg',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 2,
|
||||
content: {
|
||||
content_type: 'thoughts',
|
||||
thoughts: [
|
||||
{ content: 'Let me think about this math problem.' },
|
||||
{ content: 'Adding 2 and 2 together gives 4.' },
|
||||
],
|
||||
},
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'user-msg-1',
|
||||
children: ['reasoning-recap-msg'],
|
||||
},
|
||||
'reasoning-recap-msg': {
|
||||
id: 'reasoning-recap-msg',
|
||||
message: {
|
||||
id: 'reasoning-recap-msg',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 3,
|
||||
content: {
|
||||
content_type: 'reasoning_recap',
|
||||
recap_text: 'Thought for 2 seconds',
|
||||
},
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'thoughts-msg',
|
||||
children: ['assistant-msg-1'],
|
||||
},
|
||||
'assistant-msg-1': {
|
||||
id: 'assistant-msg-1',
|
||||
message: {
|
||||
id: 'assistant-msg-1',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 4,
|
||||
content: { content_type: 'text', parts: ['The answer is 4.'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'reasoning-recap-msg',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(testData);
|
||||
await importer(testData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
// Should only have 2 messages: user message and assistant response
|
||||
// (thoughts and reasoning_recap should be merged/skipped)
|
||||
expect(savedMessages).toHaveLength(2);
|
||||
|
||||
const userMsg = savedMessages.find((msg) => msg.text === 'What is 2+2?');
|
||||
const assistantMsg = savedMessages.find((msg) => msg.text === 'The answer is 4.');
|
||||
|
||||
expect(userMsg).toBeDefined();
|
||||
expect(assistantMsg).toBeDefined();
|
||||
|
||||
// Assistant message should have content array with thinking block
|
||||
expect(assistantMsg.content).toBeDefined();
|
||||
expect(assistantMsg.content).toHaveLength(2);
|
||||
expect(assistantMsg.content[0].type).toBe('think');
|
||||
expect(assistantMsg.content[0].think).toContain('Let me think about this math problem.');
|
||||
expect(assistantMsg.content[0].think).toContain('Adding 2 and 2 together gives 4.');
|
||||
expect(assistantMsg.content[1].type).toBe('text');
|
||||
expect(assistantMsg.content[1].text).toBe('The answer is 4.');
|
||||
|
||||
// Verify parent-child relationship is correct (skips thoughts and reasoning_recap)
|
||||
expect(assistantMsg.parentMessageId).toBe(userMsg.messageId);
|
||||
});
|
||||
|
||||
it('should skip reasoning_recap and thoughts messages as separate entries', async () => {
|
||||
const testData = [
|
||||
{
|
||||
title: 'Skip Thinking Messages Test',
|
||||
create_time: 1000,
|
||||
update_time: 2000,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['user-msg-1'],
|
||||
},
|
||||
'user-msg-1': {
|
||||
id: 'user-msg-1',
|
||||
message: {
|
||||
id: 'user-msg-1',
|
||||
author: { role: 'user' },
|
||||
create_time: 1,
|
||||
content: { content_type: 'text', parts: ['Hello'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['thoughts-msg'],
|
||||
},
|
||||
'thoughts-msg': {
|
||||
id: 'thoughts-msg',
|
||||
message: {
|
||||
id: 'thoughts-msg',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 2,
|
||||
content: {
|
||||
content_type: 'thoughts',
|
||||
thoughts: [{ content: 'Thinking...' }],
|
||||
},
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'user-msg-1',
|
||||
children: ['reasoning-recap-msg'],
|
||||
},
|
||||
'reasoning-recap-msg': {
|
||||
id: 'reasoning-recap-msg',
|
||||
message: {
|
||||
id: 'reasoning-recap-msg',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 3,
|
||||
content: {
|
||||
content_type: 'reasoning_recap',
|
||||
recap_text: 'Thought for 1 second',
|
||||
},
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'thoughts-msg',
|
||||
children: ['assistant-msg-1'],
|
||||
},
|
||||
'assistant-msg-1': {
|
||||
id: 'assistant-msg-1',
|
||||
message: {
|
||||
id: 'assistant-msg-1',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 4,
|
||||
content: { content_type: 'text', parts: ['Hi there!'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'reasoning-recap-msg',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(testData);
|
||||
await importer(testData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
// Verify no messages have thoughts or reasoning_recap content types
|
||||
const thoughtsMessages = savedMessages.filter(
|
||||
(msg) =>
|
||||
msg.text === '' || msg.text?.includes('Thinking...') || msg.text?.includes('Thought for'),
|
||||
);
|
||||
expect(thoughtsMessages).toHaveLength(0);
|
||||
|
||||
// Only user and assistant text messages should be saved
|
||||
expect(savedMessages).toHaveLength(2);
|
||||
expect(savedMessages.map((m) => m.text).sort()).toEqual(['Hello', 'Hi there!'].sort());
|
||||
});
|
||||
|
||||
it('should set createdAt from ChatGPT create_time', async () => {
|
||||
const testData = [
|
||||
{
|
||||
title: 'Timestamp Test',
|
||||
create_time: 1000,
|
||||
update_time: 2000,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['user-msg-1'],
|
||||
},
|
||||
'user-msg-1': {
|
||||
id: 'user-msg-1',
|
||||
message: {
|
||||
id: 'user-msg-1',
|
||||
author: { role: 'user' },
|
||||
create_time: 1000,
|
||||
content: { content_type: 'text', parts: ['Test message'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['assistant-msg-1'],
|
||||
},
|
||||
'assistant-msg-1': {
|
||||
id: 'assistant-msg-1',
|
||||
message: {
|
||||
id: 'assistant-msg-1',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 2000,
|
||||
content: { content_type: 'text', parts: ['Response'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'user-msg-1',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(testData);
|
||||
await importer(testData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
const userMsg = savedMessages.find((msg) => msg.text === 'Test message');
|
||||
const assistantMsg = savedMessages.find((msg) => msg.text === 'Response');
|
||||
|
||||
// Verify createdAt is set from create_time (converted from Unix timestamp)
|
||||
expect(userMsg.createdAt).toEqual(new Date(1000 * 1000));
|
||||
expect(assistantMsg.createdAt).toEqual(new Date(2000 * 1000));
|
||||
});
|
||||
});
|
||||
|
||||
describe('importLibreChatConvo', () => {
|
||||
|
|
@ -1057,3 +1313,301 @@ describe('processAssistantMessage', () => {
|
|||
expect(duration).toBeLessThan(100);
|
||||
});
|
||||
});
|
||||
|
||||
describe('importClaudeConvo', () => {
|
||||
it('should import basic Claude conversation correctly', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'Test Conversation',
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:01.000Z',
|
||||
content: [{ type: 'text', text: 'Hello Claude' }],
|
||||
},
|
||||
{
|
||||
uuid: 'msg-2',
|
||||
sender: 'assistant',
|
||||
created_at: '2025-01-15T10:00:02.000Z',
|
||||
content: [{ type: 'text', text: 'Hello! How can I help you?' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'startConversation');
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.anthropic);
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(2);
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalledWith(
|
||||
'Test Conversation',
|
||||
expect.any(Date),
|
||||
);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
// Check user message
|
||||
const userMsg = savedMessages.find((msg) => msg.text === 'Hello Claude');
|
||||
expect(userMsg.isCreatedByUser).toBe(true);
|
||||
expect(userMsg.sender).toBe('user');
|
||||
expect(userMsg.endpoint).toBe(EModelEndpoint.anthropic);
|
||||
|
||||
// Check assistant message
|
||||
const assistantMsg = savedMessages.find((msg) => msg.text === 'Hello! How can I help you?');
|
||||
expect(assistantMsg.isCreatedByUser).toBe(false);
|
||||
expect(assistantMsg.sender).toBe('Claude');
|
||||
expect(assistantMsg.parentMessageId).toBe(userMsg.messageId);
|
||||
});
|
||||
|
||||
it('should merge thinking content into assistant message', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'Thinking Test',
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:01.000Z',
|
||||
content: [{ type: 'text', text: 'What is 2+2?' }],
|
||||
},
|
||||
{
|
||||
uuid: 'msg-2',
|
||||
sender: 'assistant',
|
||||
created_at: '2025-01-15T10:00:02.000Z',
|
||||
content: [
|
||||
{ type: 'thinking', thinking: 'Let me calculate this simple math problem.' },
|
||||
{ type: 'text', text: 'The answer is 4.' },
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
const assistantMsg = savedMessages.find((msg) => msg.text === 'The answer is 4.');
|
||||
|
||||
expect(assistantMsg.content).toBeDefined();
|
||||
expect(assistantMsg.content).toHaveLength(2);
|
||||
expect(assistantMsg.content[0].type).toBe('think');
|
||||
expect(assistantMsg.content[0].think).toBe('Let me calculate this simple math problem.');
|
||||
expect(assistantMsg.content[1].type).toBe('text');
|
||||
expect(assistantMsg.content[1].text).toBe('The answer is 4.');
|
||||
});
|
||||
|
||||
it('should not include model field (Claude exports do not contain model info)', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'No Model Test',
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:01.000Z',
|
||||
content: [{ type: 'text', text: 'Hello' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
// Model should not be explicitly set (will use ImportBatchBuilder default)
|
||||
expect(savedMessages[0]).not.toHaveProperty('model');
|
||||
});
|
||||
|
||||
it('should correct timestamp inversions (child before parent)', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'Timestamp Inversion Test',
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:05.000Z', // Later timestamp
|
||||
content: [{ type: 'text', text: 'First message' }],
|
||||
},
|
||||
{
|
||||
uuid: 'msg-2',
|
||||
sender: 'assistant',
|
||||
created_at: '2025-01-15T10:00:02.000Z', // Earlier timestamp (inverted)
|
||||
content: [{ type: 'text', text: 'Second message' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
const firstMsg = savedMessages.find((msg) => msg.text === 'First message');
|
||||
const secondMsg = savedMessages.find((msg) => msg.text === 'Second message');
|
||||
|
||||
// Second message should have timestamp adjusted to be after first
|
||||
expect(new Date(secondMsg.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(firstMsg.createdAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use conversation create_time for null message timestamps', async () => {
|
||||
const convCreateTime = '2025-01-15T10:00:00.000Z';
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'Null Timestamp Test',
|
||||
created_at: convCreateTime,
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: null, // Null timestamp
|
||||
content: [{ type: 'text', text: 'Message with null time' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
expect(savedMessages[0].createdAt).toEqual(new Date(convCreateTime));
|
||||
});
|
||||
|
||||
it('should use text field as fallback when content array is empty', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'Text Fallback Test',
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:01.000Z',
|
||||
text: 'Fallback text content',
|
||||
content: [], // Empty content array
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
expect(savedMessages[0].text).toBe('Fallback text content');
|
||||
});
|
||||
|
||||
it('should skip empty messages', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: 'Skip Empty Test',
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:01.000Z',
|
||||
content: [{ type: 'text', text: 'Valid message' }],
|
||||
},
|
||||
{
|
||||
uuid: 'msg-2',
|
||||
sender: 'assistant',
|
||||
created_at: '2025-01-15T10:00:02.000Z',
|
||||
content: [], // Empty content
|
||||
text: '', // Empty text
|
||||
},
|
||||
{
|
||||
uuid: 'msg-3',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:03.000Z',
|
||||
content: [{ type: 'text', text: 'Another valid message' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Should only save 2 messages (empty one skipped)
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should use default name for unnamed conversations', async () => {
|
||||
const jsonData = [
|
||||
{
|
||||
uuid: 'conv-123',
|
||||
name: '', // Empty name
|
||||
created_at: '2025-01-15T10:00:00.000Z',
|
||||
chat_messages: [
|
||||
{
|
||||
uuid: 'msg-1',
|
||||
sender: 'human',
|
||||
created_at: '2025-01-15T10:00:01.000Z',
|
||||
content: [{ type: 'text', text: 'Hello' }],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalledWith(
|
||||
'Imported Claude Chat',
|
||||
expect.any(Date),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ const jwksRsa = require('jwks-rsa');
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { isEnabled, findOpenIDUser, math } = require('@librechat/api');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { isEnabled, findOpenIDUser } = require('@librechat/api');
|
||||
const { updateUser, findUser } = require('~/models');
|
||||
|
||||
/**
|
||||
|
|
@ -27,9 +27,7 @@ const { updateUser, findUser } = require('~/models');
|
|||
const openIdJwtLogin = (openIdConfig) => {
|
||||
let jwksRsaOptions = {
|
||||
cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true,
|
||||
cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME
|
||||
? eval(process.env.OPENID_JWKS_URL_CACHE_TIME)
|
||||
: 60000,
|
||||
cacheMaxAge: math(process.env.OPENID_JWKS_URL_CACHE_TIME, 60000),
|
||||
jwksUri: openIdConfig.serverMetadata().jwks_uri,
|
||||
};
|
||||
|
||||
|
|
@ -83,10 +81,18 @@ const openIdJwtLogin = (openIdConfig) => {
|
|||
await updateUser(user.id, updateData);
|
||||
}
|
||||
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const parsedCookies = cookieHeader ? cookies.parse(cookieHeader) : {};
|
||||
const accessToken = parsedCookies.openid_access_token;
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
/** Read tokens from session (server-side) to avoid large cookie issues */
|
||||
const sessionTokens = req.session?.openidTokens;
|
||||
let accessToken = sessionTokens?.accessToken;
|
||||
let refreshToken = sessionTokens?.refreshToken;
|
||||
|
||||
/** Fallback to cookies for backward compatibility */
|
||||
if (!accessToken || !refreshToken) {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
const parsedCookies = cookieHeader ? cookies.parse(cookieHeader) : {};
|
||||
accessToken = accessToken || parsedCookies.openid_access_token;
|
||||
refreshToken = refreshToken || parsedCookies.refreshToken;
|
||||
}
|
||||
|
||||
user.federatedTokens = {
|
||||
access_token: accessToken || rawToken,
|
||||
|
|
|
|||
|
|
@ -11,3 +11,7 @@ OPENAI_API_KEY=your-api-key
|
|||
BAN_VIOLATIONS=true
|
||||
BAN_DURATION=7200000
|
||||
BAN_INTERVAL=20
|
||||
|
||||
# NODE_MAX_OLD_SPACE_SIZE is only used as a Docker build argument.
|
||||
# Node.js does NOT recognize this environment variable for heap size.
|
||||
NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
|
|
|||
162
api/test/app/clients/tools/structured/OpenAIImageTools.test.js
Normal file
162
api/test/app/clients/tools/structured/OpenAIImageTools.test.js
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
const OpenAI = require('openai');
|
||||
const createOpenAIImageTools = require('~/app/clients/tools/structured/OpenAIImageTools');
|
||||
|
||||
jest.mock('openai');
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
logAxiosError: jest.fn(),
|
||||
oaiToolkit: {
|
||||
image_gen_oai: {
|
||||
name: 'image_gen_oai',
|
||||
description: 'Generate an image',
|
||||
schema: {},
|
||||
},
|
||||
image_edit_oai: {
|
||||
name: 'image_edit_oai',
|
||||
description: 'Edit an image',
|
||||
schema: {},
|
||||
},
|
||||
},
|
||||
extractBaseURL: jest.fn((url) => url),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getFiles: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
describe('OpenAIImageTools - IMAGE_GEN_OAI_MODEL environment variable', () => {
|
||||
let originalEnv;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
originalEnv = { ...process.env };
|
||||
|
||||
process.env.IMAGE_GEN_OAI_API_KEY = 'test-api-key';
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should use default model "gpt-image-1" when IMAGE_GEN_OAI_MODEL is not set', async () => {
|
||||
delete process.env.IMAGE_GEN_OAI_MODEL;
|
||||
|
||||
const [imageGenTool] = createOpenAIImageTools({
|
||||
isAgent: true,
|
||||
override: false,
|
||||
req: { user: { id: 'test-user' } },
|
||||
});
|
||||
|
||||
const mockGenerate = jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: mockGenerate,
|
||||
},
|
||||
}));
|
||||
|
||||
await imageGenTool.func({ prompt: 'test prompt' });
|
||||
|
||||
expect(mockGenerate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gpt-image-1',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use "gpt-image-1.5" when IMAGE_GEN_OAI_MODEL is set to "gpt-image-1.5"', async () => {
|
||||
process.env.IMAGE_GEN_OAI_MODEL = 'gpt-image-1.5';
|
||||
|
||||
const mockGenerate = jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: mockGenerate,
|
||||
},
|
||||
}));
|
||||
|
||||
const [imageGenTool] = createOpenAIImageTools({
|
||||
isAgent: true,
|
||||
override: false,
|
||||
req: { user: { id: 'test-user' } },
|
||||
});
|
||||
|
||||
await imageGenTool.func({ prompt: 'test prompt' });
|
||||
|
||||
expect(mockGenerate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gpt-image-1.5',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use custom model name from IMAGE_GEN_OAI_MODEL environment variable', async () => {
|
||||
process.env.IMAGE_GEN_OAI_MODEL = 'custom-image-model';
|
||||
|
||||
const mockGenerate = jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: mockGenerate,
|
||||
},
|
||||
}));
|
||||
|
||||
const [imageGenTool] = createOpenAIImageTools({
|
||||
isAgent: true,
|
||||
override: false,
|
||||
req: { user: { id: 'test-user' } },
|
||||
});
|
||||
|
||||
await imageGenTool.func({ prompt: 'test prompt' });
|
||||
|
||||
expect(mockGenerate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'custom-image-model',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
12
bun.lock
12
bun.lock
|
|
@ -36,7 +36,7 @@
|
|||
},
|
||||
"api": {
|
||||
"name": "@librechat/backend",
|
||||
"version": "0.8.2-rc1",
|
||||
"version": "0.8.2-rc2",
|
||||
"dependencies": {
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.941.0",
|
||||
"@aws-sdk/client-s3": "^3.758.0",
|
||||
|
|
@ -124,7 +124,7 @@
|
|||
},
|
||||
"client": {
|
||||
"name": "@librechat/frontend",
|
||||
"version": "0.8.2-rc1",
|
||||
"version": "0.8.2-rc2",
|
||||
"dependencies": {
|
||||
"@ariakit/react": "^0.4.15",
|
||||
"@ariakit/react-core": "^0.4.17",
|
||||
|
|
@ -254,7 +254,7 @@
|
|||
},
|
||||
"packages/api": {
|
||||
"name": "@librechat/api",
|
||||
"version": "1.7.10",
|
||||
"version": "1.7.20",
|
||||
"devDependencies": {
|
||||
"@babel/preset-env": "^7.21.5",
|
||||
"@babel/preset-react": "^7.18.6",
|
||||
|
|
@ -321,7 +321,7 @@
|
|||
},
|
||||
"packages/client": {
|
||||
"name": "@librechat/client",
|
||||
"version": "0.4.2",
|
||||
"version": "0.4.4",
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.28.5",
|
||||
"@babel/preset-env": "^7.28.5",
|
||||
|
|
@ -409,7 +409,7 @@
|
|||
},
|
||||
"packages/data-provider": {
|
||||
"name": "librechat-data-provider",
|
||||
"version": "0.8.210",
|
||||
"version": "0.8.220",
|
||||
"dependencies": {
|
||||
"axios": "^1.12.1",
|
||||
"dayjs": "^1.11.13",
|
||||
|
|
@ -447,7 +447,7 @@
|
|||
},
|
||||
"packages/data-schemas": {
|
||||
"name": "@librechat/data-schemas",
|
||||
"version": "0.0.32",
|
||||
"version": "0.0.33",
|
||||
"devDependencies": {
|
||||
"@rollup/plugin-alias": "^5.1.0",
|
||||
"@rollup/plugin-commonjs": "^29.0.0",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/** v0.8.2-rc1 */
|
||||
/** v0.8.2-rc2 */
|
||||
module.exports = {
|
||||
roots: ['<rootDir>/src'],
|
||||
testEnvironment: 'jsdom',
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@librechat/frontend",
|
||||
"version": "v0.8.2-rc1",
|
||||
"version": "v0.8.2-rc2",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
"dev": "cross-env NODE_ENV=development vite",
|
||||
"preview-prod": "cross-env NODE_ENV=development vite preview",
|
||||
"test": "cross-env NODE_ENV=development jest --watch",
|
||||
"test:ci": "cross-env NODE_ENV=development jest --ci",
|
||||
"test:ci": "cross-env NODE_ENV=development jest --ci --logHeapUsage",
|
||||
"b:test": "NODE_ENV=test bunx jest --watch",
|
||||
"b:build": "NODE_ENV=production bun --bun vite build",
|
||||
"b:dev": "NODE_ENV=development bunx vite"
|
||||
|
|
@ -39,10 +39,10 @@
|
|||
"@marsidev/react-turnstile": "^1.1.0",
|
||||
"@mcp-ui/client": "^5.7.0",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.1.15",
|
||||
"@radix-ui/react-alert-dialog": "1.0.2",
|
||||
"@radix-ui/react-checkbox": "^1.0.3",
|
||||
"@radix-ui/react-collapsible": "^1.0.3",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dialog": "1.0.2",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.1",
|
||||
"@radix-ui/react-hover-card": "^1.0.5",
|
||||
"@radix-ui/react-icons": "^1.3.0",
|
||||
|
|
@ -80,6 +80,7 @@
|
|||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.394.0",
|
||||
"match-sorter": "^8.1.0",
|
||||
"mermaid": "^11.12.2",
|
||||
"micromark-extension-llm-math": "^3.1.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"rc-input-number": "^7.4.2",
|
||||
|
|
@ -95,7 +96,7 @@
|
|||
"react-lazy-load-image-component": "^1.6.0",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-resizable-panels": "^3.0.6",
|
||||
"react-router-dom": "^6.11.2",
|
||||
"react-router-dom": "^6.30.3",
|
||||
"react-speech-recognition": "^3.10.0",
|
||||
"react-textarea-autosize": "^8.4.0",
|
||||
"react-transition-group": "^4.4.5",
|
||||
|
|
@ -109,9 +110,11 @@
|
|||
"remark-math": "^6.0.0",
|
||||
"remark-supersub": "^1.0.0",
|
||||
"sse.js": "^2.5.0",
|
||||
"swr": "^2.3.8",
|
||||
"tailwind-merge": "^1.9.1",
|
||||
"tailwindcss-animate": "^1.0.5",
|
||||
"tailwindcss-radix": "^2.8.0",
|
||||
"ts-md5": "^1.3.1",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
|
|||
23
client/public/assets/gemini_image_gen.svg
Normal file
23
client/public/assets/gemini_image_gen.svg
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
<svg width="200" height="200" viewBox="0 0 200 200" xmlns="http://www.w3.org/2000/svg">
|
||||
<defs>
|
||||
<linearGradient id="starGradient" x1="0%" y1="0%" x2="100%" y2="100%">
|
||||
<stop offset="0%" style="stop-color:#9B59B6"/>
|
||||
<stop offset="50%" style="stop-color:#5B8DEF"/>
|
||||
<stop offset="100%" style="stop-color:#00BFFF"/>
|
||||
</linearGradient>
|
||||
<filter id="shadow" x="-20%" y="-20%" width="140%" height="140%">
|
||||
<feDropShadow dx="0" dy="4" stdDeviation="8" flood-color="#000" flood-opacity="0.1"/>
|
||||
</filter>
|
||||
</defs>
|
||||
|
||||
<!-- Rounded square background -->
|
||||
<rect x="20" y="20" width="160" height="160" rx="35" ry="35" fill="#F5F5F7" filter="url(#shadow)"/>
|
||||
|
||||
<!-- 4-pointed star -->
|
||||
<path d="M100 40
|
||||
C100 70, 85 85, 55 100
|
||||
C85 115, 100 130, 100 160
|
||||
C100 130, 115 115, 145 100
|
||||
C115 85, 100 70, 100 40 Z"
|
||||
fill="url(#starGradient)"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 926 B |
|
|
@ -18,6 +18,16 @@ const App = () => {
|
|||
const { setError } = useApiErrorBoundary();
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
// Always attempt network requests, even when navigator.onLine is false
|
||||
// This is needed because localhost is reachable without WiFi
|
||||
networkMode: 'always',
|
||||
},
|
||||
mutations: {
|
||||
networkMode: 'always',
|
||||
},
|
||||
},
|
||||
queryCache: new QueryCache({
|
||||
onError: (error) => {
|
||||
if (error?.response?.status === 401) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
import useAddedResponse from '~/hooks/Chat/useAddedResponse';
|
||||
type TAddedChatContext = ReturnType<typeof useAddedResponse>;
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
import type { ConvoGenerator } from '~/common';
|
||||
|
||||
type TAddedChatContext = {
|
||||
conversation: TConversation | null;
|
||||
setConversation: SetterOrUpdater<TConversation | null>;
|
||||
generateConversation: ConvoGenerator;
|
||||
};
|
||||
|
||||
export const AddedChatContext = createContext<TAddedChatContext>({} as TAddedChatContext);
|
||||
export const useAddedChatContext = () => useContext(AddedChatContext);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { getEndpointField } from 'librechat-data-provider';
|
||||
import { getEndpointField, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { EModelEndpoint } from 'librechat-data-provider';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from './AgentsMapContext';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface DragDropContextValue {
|
||||
|
|
@ -9,6 +10,7 @@ interface DragDropContextValue {
|
|||
agentId: string | null | undefined;
|
||||
endpoint: string | null | undefined;
|
||||
endpointType?: EModelEndpoint | undefined;
|
||||
useResponsesApi?: boolean;
|
||||
}
|
||||
|
||||
const DragDropContext = createContext<DragDropContextValue | undefined>(undefined);
|
||||
|
|
@ -16,6 +18,7 @@ const DragDropContext = createContext<DragDropContextValue | undefined>(undefine
|
|||
export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
||||
const { conversation } = useChatContext();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
|
||||
const endpointType = useMemo(() => {
|
||||
return (
|
||||
|
|
@ -24,6 +27,34 @@ export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
|||
);
|
||||
}, [conversation?.endpoint, endpointsConfig]);
|
||||
|
||||
const needsAgentFetch = useMemo(() => {
|
||||
const isAgents = isAgentsEndpoint(conversation?.endpoint);
|
||||
if (!isAgents || !conversation?.agent_id) {
|
||||
return false;
|
||||
}
|
||||
const agent = agentsMap?.[conversation.agent_id];
|
||||
return !agent?.model_parameters;
|
||||
}, [conversation?.endpoint, conversation?.agent_id, agentsMap]);
|
||||
|
||||
const { data: agentData } = useGetAgentByIdQuery(conversation?.agent_id, {
|
||||
enabled: needsAgentFetch,
|
||||
});
|
||||
|
||||
const useResponsesApi = useMemo(() => {
|
||||
const isAgents = isAgentsEndpoint(conversation?.endpoint);
|
||||
if (!isAgents || !conversation?.agent_id || conversation?.useResponsesApi) {
|
||||
return conversation?.useResponsesApi;
|
||||
}
|
||||
const agent = agentData || agentsMap?.[conversation.agent_id];
|
||||
return agent?.model_parameters?.useResponsesApi;
|
||||
}, [
|
||||
conversation?.endpoint,
|
||||
conversation?.agent_id,
|
||||
conversation?.useResponsesApi,
|
||||
agentData,
|
||||
agentsMap,
|
||||
]);
|
||||
|
||||
/** Context value only created when conversation fields change */
|
||||
const contextValue = useMemo<DragDropContextValue>(
|
||||
() => ({
|
||||
|
|
@ -31,8 +62,15 @@ export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
|||
agentId: conversation?.agent_id,
|
||||
endpoint: conversation?.endpoint,
|
||||
endpointType: endpointType,
|
||||
useResponsesApi: useResponsesApi,
|
||||
}),
|
||||
[conversation?.conversationId, conversation?.agent_id, conversation?.endpoint, endpointType],
|
||||
[
|
||||
conversation?.conversationId,
|
||||
conversation?.agent_id,
|
||||
conversation?.endpoint,
|
||||
useResponsesApi,
|
||||
endpointType,
|
||||
],
|
||||
);
|
||||
|
||||
return <DragDropContext.Provider value={contextValue}>{children}</DragDropContext.Provider>;
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue