mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-30 15:18:50 +01:00
Merge branch 'dev' into feat/context-window-ui
This commit is contained in:
commit
cb8322ca85
407 changed files with 25479 additions and 19894 deletions
16
.env.example
16
.env.example
|
|
@ -68,6 +68,18 @@ DEBUG_CONSOLE=false
|
|||
# UID=1000
|
||||
# GID=1000
|
||||
|
||||
#==============#
|
||||
# Node Options #
|
||||
#==============#
|
||||
|
||||
# NOTE: NODE_MAX_OLD_SPACE_SIZE is NOT recognized by Node.js directly.
|
||||
# This variable is used as a build argument for Docker or CI/CD workflows,
|
||||
# and is NOT used by Node.js to set the heap size at runtime.
|
||||
# To configure Node.js memory, use NODE_OPTIONS, e.g.:
|
||||
# NODE_OPTIONS="--max-old-space-size=6144"
|
||||
# See: https://nodejs.org/api/cli.html#--max-old-space-sizesize-in-mib
|
||||
NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
#===============#
|
||||
# Configuration #
|
||||
#===============#
|
||||
|
|
@ -248,6 +260,7 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
|
|||
# IMAGE_GEN_OAI_API_KEY= # Create or reuse OpenAI API key for image generation tool
|
||||
# IMAGE_GEN_OAI_BASEURL= # Custom OpenAI base URL for image generation tool
|
||||
# IMAGE_GEN_OAI_AZURE_API_VERSION= # Custom Azure OpenAI deployments
|
||||
# IMAGE_GEN_OAI_MODEL=gpt-image-1 # OpenAI image model (e.g., gpt-image-1, gpt-image-1.5)
|
||||
# IMAGE_GEN_OAI_DESCRIPTION=
|
||||
# IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present
|
||||
# IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present
|
||||
|
|
@ -656,6 +669,9 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
|||
|
||||
# Enable Redis for caching and session storage
|
||||
# USE_REDIS=true
|
||||
# Enable Redis for resumable LLM streams (defaults to USE_REDIS value if not set)
|
||||
# Set to false to use in-memory storage for streams while keeping Redis for other caches
|
||||
# USE_REDIS_STREAMS=true
|
||||
|
||||
# Single Redis instance
|
||||
# REDIS_URI=redis://127.0.0.1:6379
|
||||
|
|
|
|||
1
.github/workflows/backend-review.yml
vendored
1
.github/workflows/backend-review.yml
vendored
|
|
@ -24,6 +24,7 @@ jobs:
|
|||
BAN_DURATION: ${{ secrets.BAN_DURATION }}
|
||||
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
|
||||
NODE_ENV: CI
|
||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Use Node.js 20.x
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ on:
|
|||
- 'packages/api/src/cache/**'
|
||||
- 'packages/api/src/cluster/**'
|
||||
- 'packages/api/src/mcp/**'
|
||||
- 'packages/api/src/stream/**'
|
||||
- 'redis-config/**'
|
||||
- '.github/workflows/cache-integration-tests.yml'
|
||||
|
||||
|
|
|
|||
4
.github/workflows/frontend-review.yml
vendored
4
.github/workflows/frontend-review.yml
vendored
|
|
@ -16,6 +16,8 @@ jobs:
|
|||
name: Run frontend unit tests on Ubuntu
|
||||
timeout-minutes: 60
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Use Node.js 20.x
|
||||
|
|
@ -38,6 +40,8 @@ jobs:
|
|||
name: Run frontend unit tests on Windows
|
||||
timeout-minutes: 60
|
||||
runs-on: windows-latest
|
||||
env:
|
||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Use Node.js 20.x
|
||||
|
|
|
|||
11
Dockerfile
11
Dockerfile
|
|
@ -1,4 +1,4 @@
|
|||
# v0.8.1
|
||||
# v0.8.2-rc1
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
|
|
@ -14,6 +14,9 @@ ENV LD_PRELOAD=/usr/lib/libjemalloc.so.2
|
|||
COPY --from=ghcr.io/astral-sh/uv:0.9.5-python3.12-alpine /usr/local/bin/uv /usr/local/bin/uvx /bin/
|
||||
RUN uv --version
|
||||
|
||||
# Set configurable max-old-space-size with default
|
||||
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
RUN mkdir -p /app && chown node:node /app
|
||||
WORKDIR /app
|
||||
|
||||
|
|
@ -30,7 +33,7 @@ RUN \
|
|||
# Allow mounting of these files, which have no default
|
||||
touch .env ; \
|
||||
# Create directories for the volumes to inherit the correct permissions
|
||||
mkdir -p /app/client/public/images /app/api/logs /app/uploads ; \
|
||||
mkdir -p /app/client/public/images /app/logs /app/uploads ; \
|
||||
npm config set fetch-retry-maxtimeout 600000 ; \
|
||||
npm config set fetch-retries 5 ; \
|
||||
npm config set fetch-retry-mintimeout 15000 ; \
|
||||
|
|
@ -39,8 +42,8 @@ RUN \
|
|||
COPY --chown=node:node . .
|
||||
|
||||
RUN \
|
||||
# React client build
|
||||
NODE_OPTIONS="--max-old-space-size=2048" npm run frontend; \
|
||||
# React client build with configurable memory
|
||||
NODE_OPTIONS="--max-old-space-size=${NODE_MAX_OLD_SPACE_SIZE}" npm run frontend; \
|
||||
npm prune --production; \
|
||||
npm cache clean --force
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
# Dockerfile.multi
|
||||
# v0.8.1
|
||||
# v0.8.2-rc1
|
||||
|
||||
# Set configurable max-old-space-size with default
|
||||
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
|
|
@ -7,6 +10,7 @@ FROM node:20-alpine AS base-min
|
|||
RUN apk add --no-cache jemalloc
|
||||
# Set environment variable to use jemalloc
|
||||
ENV LD_PRELOAD=/usr/lib/libjemalloc.so.2
|
||||
|
||||
WORKDIR /app
|
||||
RUN apk --no-cache add curl
|
||||
RUN npm config set fetch-retry-maxtimeout 600000 && \
|
||||
|
|
@ -59,7 +63,7 @@ COPY client ./
|
|||
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
|
||||
COPY --from=client-package-build /app/packages/client/dist /app/packages/client/dist
|
||||
COPY --from=client-package-build /app/packages/client/src /app/packages/client/src
|
||||
ENV NODE_OPTIONS="--max-old-space-size=2048"
|
||||
ENV NODE_OPTIONS="--max-old-space-size=${NODE_MAX_OLD_SPACE_SIZE}"
|
||||
RUN npm run build
|
||||
|
||||
# API setup (including client dist)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ const {
|
|||
EModelEndpoint,
|
||||
isParamEndpoint,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
supportsBalanceCheck,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -714,7 +715,7 @@ class BaseClient {
|
|||
iconURL: this.options.iconURL,
|
||||
endpoint: this.options.endpoint,
|
||||
...(this.metadata ?? {}),
|
||||
metadata,
|
||||
metadata: Object.keys(metadata ?? {}).length > 0 ? metadata : undefined,
|
||||
};
|
||||
|
||||
if (typeof completion === 'string') {
|
||||
|
|
@ -965,6 +966,13 @@ class BaseClient {
|
|||
|
||||
const unsetFields = {};
|
||||
const exceptions = new Set(['spec', 'iconURL']);
|
||||
const hasNonEphemeralAgent =
|
||||
isAgentsEndpoint(this.options.endpoint) &&
|
||||
endpointOptions?.agent_id &&
|
||||
!isEphemeralAgentId(endpointOptions.agent_id);
|
||||
if (hasNonEphemeralAgent) {
|
||||
exceptions.add('model');
|
||||
}
|
||||
if (existingConvo != null) {
|
||||
this.fetchedConvo = true;
|
||||
for (const key in existingConvo) {
|
||||
|
|
|
|||
|
|
@ -18,17 +18,17 @@ function generateShadcnPrompt(options) {
|
|||
Here are the components that are available, along with how to import them, and how to use them:
|
||||
|
||||
${Object.values(components)
|
||||
.map((component) => {
|
||||
if (useXML) {
|
||||
return dedent`
|
||||
.map((component) => {
|
||||
if (useXML) {
|
||||
return dedent`
|
||||
<component>
|
||||
<name>${component.componentName}</name>
|
||||
<import-instructions>${component.importDocs}</import-instructions>
|
||||
<usage-instructions>${component.usageDocs}</usage-instructions>
|
||||
</component>
|
||||
`;
|
||||
} else {
|
||||
return dedent`
|
||||
} else {
|
||||
return dedent`
|
||||
# ${component.componentName}
|
||||
|
||||
## Import Instructions
|
||||
|
|
@ -37,9 +37,9 @@ function generateShadcnPrompt(options) {
|
|||
## Usage Instructions
|
||||
${component.usageDocs}
|
||||
`;
|
||||
}
|
||||
})
|
||||
.join('\n\n')}
|
||||
}
|
||||
})
|
||||
.join('\n\n')}
|
||||
`;
|
||||
|
||||
return systemPrompt;
|
||||
|
|
|
|||
|
|
@ -78,6 +78,8 @@ function createOpenAIImageTools(fields = {}) {
|
|||
let apiKey = fields.IMAGE_GEN_OAI_API_KEY ?? getApiKey();
|
||||
const closureConfig = { apiKey };
|
||||
|
||||
const imageModel = process.env.IMAGE_GEN_OAI_MODEL || 'gpt-image-1';
|
||||
|
||||
let baseURL = 'https://api.openai.com/v1/';
|
||||
if (!override && process.env.IMAGE_GEN_OAI_BASEURL) {
|
||||
baseURL = extractBaseURL(process.env.IMAGE_GEN_OAI_BASEURL);
|
||||
|
|
@ -157,7 +159,7 @@ function createOpenAIImageTools(fields = {}) {
|
|||
|
||||
resp = await openai.images.generate(
|
||||
{
|
||||
model: 'gpt-image-1',
|
||||
model: imageModel,
|
||||
prompt: replaceUnwantedChars(prompt),
|
||||
n: Math.min(Math.max(1, n), 10),
|
||||
background,
|
||||
|
|
@ -239,7 +241,7 @@ Error Message: ${error.message}`);
|
|||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('model', 'gpt-image-1');
|
||||
formData.append('model', imageModel);
|
||||
formData.append('prompt', replaceUnwantedChars(prompt));
|
||||
// TODO: `mask` support
|
||||
// TODO: more than 1 image support
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ class OpenWeather extends Tool {
|
|||
|
||||
if (['current_forecast', 'timestamp', 'daily_aggregation', 'overview'].includes(action)) {
|
||||
if (typeof finalLat !== 'number' || typeof finalLon !== 'number') {
|
||||
return 'Error: lat and lon are required and must be numbers for this action (or specify \'city\').';
|
||||
return "Error: lat and lon are required and must be numbers for this action (or specify 'city').";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ class OpenWeather extends Tool {
|
|||
let dt;
|
||||
if (action === 'timestamp') {
|
||||
if (!date) {
|
||||
return 'Error: For timestamp action, a \'date\' in YYYY-MM-DD format is required.';
|
||||
return "Error: For timestamp action, a 'date' in YYYY-MM-DD format is required.";
|
||||
}
|
||||
dt = this.convertDateToUnix(date);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -348,10 +348,10 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
/** Placeholder used for UI purposes */
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
serverName &&
|
||||
(await getMCPServersRegistry().getServerConfig(serverName, user)) == undefined
|
||||
) {
|
||||
const serverConfig = serverName
|
||||
? await getMCPServersRegistry().getServerConfig(serverName, user)
|
||||
: null;
|
||||
if (!serverConfig) {
|
||||
logger.warn(
|
||||
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
||||
);
|
||||
|
|
@ -362,6 +362,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
{
|
||||
type: 'all',
|
||||
serverName,
|
||||
config: serverConfig,
|
||||
},
|
||||
];
|
||||
continue;
|
||||
|
|
@ -372,6 +373,7 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
type: 'single',
|
||||
toolKey: tool,
|
||||
serverName,
|
||||
config: serverConfig,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
|
@ -432,9 +434,11 @@ Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|new
|
|||
user: safeUser,
|
||||
userMCPAuthMap,
|
||||
res: options.res,
|
||||
streamId: options.req?._resumableStreamId || null,
|
||||
model: agent?.model ?? model,
|
||||
serverName: config.serverName,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
config: config.config,
|
||||
};
|
||||
|
||||
if (config.type === 'all' && toolConfigs.length === 1) {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,35 @@
|
|||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const winston = require('winston');
|
||||
require('winston-daily-rotate-file');
|
||||
|
||||
const logDir = path.join(__dirname, '..', 'logs');
|
||||
/**
|
||||
* Determine the log directory.
|
||||
* Priority:
|
||||
* 1. LIBRECHAT_LOG_DIR environment variable (allows user override)
|
||||
* 2. /app/logs if running in Docker (bind-mounted with correct permissions)
|
||||
* 3. api/logs relative to this file (local development)
|
||||
*/
|
||||
const getLogDir = () => {
|
||||
if (process.env.LIBRECHAT_LOG_DIR) {
|
||||
return process.env.LIBRECHAT_LOG_DIR;
|
||||
}
|
||||
|
||||
// Check if running in Docker container (cwd is /app)
|
||||
if (process.cwd() === '/app') {
|
||||
const dockerLogDir = '/app/logs';
|
||||
// Ensure the directory exists
|
||||
if (!fs.existsSync(dockerLogDir)) {
|
||||
fs.mkdirSync(dockerLogDir, { recursive: true });
|
||||
}
|
||||
return dockerLogDir;
|
||||
}
|
||||
|
||||
// Local development: use api/logs relative to this file
|
||||
return path.join(__dirname, '..', 'logs');
|
||||
};
|
||||
|
||||
const logDir = getLogDir();
|
||||
|
||||
const { NODE_ENV, DEBUG_LOGGING = false } = process.env;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,36 @@
|
|||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const winston = require('winston');
|
||||
require('winston-daily-rotate-file');
|
||||
const { redactFormat, redactMessage, debugTraverse, jsonTruncateFormat } = require('./parsers');
|
||||
|
||||
const logDir = path.join(__dirname, '..', 'logs');
|
||||
/**
|
||||
* Determine the log directory.
|
||||
* Priority:
|
||||
* 1. LIBRECHAT_LOG_DIR environment variable (allows user override)
|
||||
* 2. /app/logs if running in Docker (bind-mounted with correct permissions)
|
||||
* 3. api/logs relative to this file (local development)
|
||||
*/
|
||||
const getLogDir = () => {
|
||||
if (process.env.LIBRECHAT_LOG_DIR) {
|
||||
return process.env.LIBRECHAT_LOG_DIR;
|
||||
}
|
||||
|
||||
// Check if running in Docker container (cwd is /app)
|
||||
if (process.cwd() === '/app') {
|
||||
const dockerLogDir = '/app/logs';
|
||||
// Ensure the directory exists
|
||||
if (!fs.existsSync(dockerLogDir)) {
|
||||
fs.mkdirSync(dockerLogDir, { recursive: true });
|
||||
}
|
||||
return dockerLogDir;
|
||||
}
|
||||
|
||||
// Local development: use api/logs relative to this file
|
||||
return path.join(__dirname, '..', 'logs');
|
||||
};
|
||||
|
||||
const logDir = getLogDir();
|
||||
|
||||
const { NODE_ENV, DEBUG_LOGGING = true, CONSOLE_JSON = false, DEBUG_CONSOLE = false } = process.env;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
const mongoose = require('mongoose');
|
||||
const crypto = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } =
|
||||
const { getCustomEndpointConfig } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
actionDelimiter,
|
||||
isAgentsEndpoint,
|
||||
getResponseSender,
|
||||
isEphemeralAgentId,
|
||||
encodeEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, mcp_all, mcp_delimiter } =
|
||||
require('librechat-data-provider').Constants;
|
||||
const {
|
||||
removeAgentFromAllProjects,
|
||||
|
|
@ -92,7 +102,7 @@ const getAgents = async (searchParameter) => await Agent.find(searchParameter).l
|
|||
* @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_parameters: _m }) => {
|
||||
const loadEphemeralAgent = async ({ req, spec, endpoint, model_parameters: _m }) => {
|
||||
const { model, ...model_parameters } = _m;
|
||||
const modelSpecs = req.config?.modelSpecs?.list;
|
||||
/** @type {TModelSpec | null} */
|
||||
|
|
@ -139,8 +149,28 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
|
|||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
|
||||
// Compute display name using getResponseSender (same logic used for addedConvo agents)
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadEphemeralAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
const sender = getResponseSender({
|
||||
modelLabel: model_parameters?.modelLabel,
|
||||
modelDisplayLabel: endpointConfig?.modelDisplayLabel,
|
||||
});
|
||||
|
||||
// Encode ephemeral agent ID with endpoint, model, and computed sender for display
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender });
|
||||
|
||||
const result = {
|
||||
id: agent_id,
|
||||
id: ephemeralId,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
|
|
@ -169,8 +199,8 @@ const loadAgent = async ({ req, spec, agent_id, endpoint, model_parameters }) =>
|
|||
if (!agent_id) {
|
||||
return null;
|
||||
}
|
||||
if (agent_id === EPHEMERAL_AGENT_ID) {
|
||||
return await loadEphemeralAgent({ req, spec, agent_id, endpoint, model_parameters });
|
||||
if (isEphemeralAgentId(agent_id)) {
|
||||
return await loadEphemeralAgent({ req, spec, endpoint, model_parameters });
|
||||
}
|
||||
const agent = await getAgent({
|
||||
id: agent_id,
|
||||
|
|
|
|||
|
|
@ -1960,7 +1960,8 @@ describe('models/Agent', () => {
|
|||
});
|
||||
|
||||
if (result) {
|
||||
expect(result.id).toBe(EPHEMERAL_AGENT_ID);
|
||||
// Ephemeral agent ID is encoded with endpoint and model
|
||||
expect(result.id).toBe('openai__gpt-4');
|
||||
expect(result.instructions).toBe('Test instructions');
|
||||
expect(result.provider).toBe('openai');
|
||||
expect(result.model).toBe('gpt-4');
|
||||
|
|
@ -1978,7 +1979,7 @@ describe('models/Agent', () => {
|
|||
const mockReq = { user: { id: 'user123' } };
|
||||
const result = await loadAgent({
|
||||
req: mockReq,
|
||||
agent_id: 'non_existent_agent',
|
||||
agent_id: 'agent_non_existent',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
|
@ -2105,7 +2106,7 @@ describe('models/Agent', () => {
|
|||
test('should handle loadAgent with malformed req object', async () => {
|
||||
const result = await loadAgent({
|
||||
req: null,
|
||||
agent_id: 'test',
|
||||
agent_id: 'agent_test',
|
||||
endpoint: 'openai',
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
|
|
|||
218
api/models/loadAddedAgent.js
Normal file
218
api/models/loadAddedAgent.js
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getCustomEndpointConfig } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
isAgentsEndpoint,
|
||||
getResponseSender,
|
||||
isEphemeralAgentId,
|
||||
appendAgentIdSuffix,
|
||||
encodeEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
|
||||
const { mcp_all, mcp_delimiter } = Constants;
|
||||
|
||||
/**
|
||||
* Constant for added conversation agent ID
|
||||
*/
|
||||
const ADDED_AGENT_ID = 'added_agent';
|
||||
|
||||
/**
|
||||
* Get an agent document based on the provided ID.
|
||||
* @param {Object} searchParameter - The search parameters to find the agent.
|
||||
* @param {string} searchParameter.id - The ID of the agent.
|
||||
* @returns {Promise<import('librechat-data-provider').Agent|null>}
|
||||
*/
|
||||
let getAgent;
|
||||
|
||||
/**
|
||||
* Set the getAgent function (dependency injection to avoid circular imports)
|
||||
* @param {Function} fn
|
||||
*/
|
||||
const setGetAgent = (fn) => {
|
||||
getAgent = fn;
|
||||
};
|
||||
|
||||
/**
|
||||
* Load an agent from an added conversation (TConversation).
|
||||
* Used for multi-convo parallel agent execution.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req
|
||||
* @param {import('librechat-data-provider').TConversation} params.conversation - The added conversation
|
||||
* @param {import('librechat-data-provider').Agent} [params.primaryAgent] - The primary agent (used to duplicate tools when both are ephemeral)
|
||||
* @returns {Promise<import('librechat-data-provider').Agent|null>} The agent config as a plain object, or null if invalid.
|
||||
*/
|
||||
const loadAddedAgent = async ({ req, conversation, primaryAgent }) => {
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// If there's an agent_id, load the existing agent
|
||||
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
|
||||
if (!getAgent) {
|
||||
throw new Error('getAgent not initialized - call setGetAgent first');
|
||||
}
|
||||
const agent = await getAgent({
|
||||
id: conversation.agent_id,
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
|
||||
return null;
|
||||
}
|
||||
|
||||
agent.version = agent.versions ? agent.versions.length : 0;
|
||||
// Append suffix to distinguish from primary agent (matches ephemeral format)
|
||||
// This is needed when both agents have the same ID or for consistent parallel content attribution
|
||||
agent.id = appendAgentIdSuffix(agent.id, 1);
|
||||
return agent;
|
||||
}
|
||||
|
||||
// Otherwise, create an ephemeral agent config from the conversation
|
||||
const { model, endpoint, promptPrefix, spec, ...rest } = conversation;
|
||||
|
||||
if (!endpoint || !model) {
|
||||
logger.warn('[loadAddedAgent] Missing required endpoint or model for ephemeral agent');
|
||||
return null;
|
||||
}
|
||||
|
||||
// If both primary and added agents are ephemeral, duplicate tools from primary agent
|
||||
const primaryIsEphemeral = primaryAgent && isEphemeralAgentId(primaryAgent.id);
|
||||
if (primaryIsEphemeral && Array.isArray(primaryAgent.tools)) {
|
||||
// Get display name using getResponseSender
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
const sender = getResponseSender({
|
||||
modelLabel: rest.modelLabel,
|
||||
modelDisplayLabel: endpointConfig?.modelDisplayLabel,
|
||||
});
|
||||
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
return {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters: {},
|
||||
model,
|
||||
tools: [...primaryAgent.tools],
|
||||
};
|
||||
}
|
||||
|
||||
// Extract ephemeral agent options from conversation if present
|
||||
const ephemeralAgent = rest.ephemeralAgent;
|
||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||
const userId = req.user?.id;
|
||||
|
||||
// Check model spec for MCP servers
|
||||
const modelSpecs = req.config?.modelSpecs?.list;
|
||||
let modelSpec = null;
|
||||
if (spec != null && spec !== '') {
|
||||
modelSpec = modelSpecs?.find((s) => s.name === spec) || null;
|
||||
}
|
||||
if (modelSpec?.mcpServers) {
|
||||
for (const mcpServer of modelSpec.mcpServers) {
|
||||
mcpServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {string[]} */
|
||||
const tools = [];
|
||||
if (ephemeralAgent?.execute_code === true || modelSpec?.executeCode === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true || modelSpec?.fileSearch === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true || modelSpec?.webSearch === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await getMCPServerTools(userId, mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
}
|
||||
}
|
||||
|
||||
// Build model_parameters from conversation fields
|
||||
const model_parameters = {};
|
||||
const paramKeys = [
|
||||
'temperature',
|
||||
'top_p',
|
||||
'topP',
|
||||
'topK',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'maxOutputTokens',
|
||||
'maxTokens',
|
||||
'max_tokens',
|
||||
];
|
||||
|
||||
for (const key of paramKeys) {
|
||||
if (rest[key] != null) {
|
||||
model_parameters[key] = rest[key];
|
||||
}
|
||||
}
|
||||
|
||||
// Get endpoint config for modelDisplayLabel (same pattern as initialize.js)
|
||||
const appConfig = req.config;
|
||||
let endpointConfig = appConfig?.endpoints?.[endpoint];
|
||||
if (!isAgentsEndpoint(endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = getCustomEndpointConfig({ endpoint, appConfig });
|
||||
} catch (err) {
|
||||
logger.error('[loadAddedAgent] Error getting custom endpoint config', err);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute display name using getResponseSender (same logic used for main agent)
|
||||
const sender = getResponseSender({
|
||||
modelLabel: rest.modelLabel,
|
||||
modelDisplayLabel: endpointConfig?.modelDisplayLabel,
|
||||
});
|
||||
|
||||
/** Encoded ephemeral agent ID with endpoint, model, sender, and index=1 to distinguish from primary */
|
||||
const ephemeralId = encodeEphemeralAgentId({ endpoint, model, sender, index: 1 });
|
||||
|
||||
const result = {
|
||||
id: ephemeralId,
|
||||
instructions: promptPrefix || '',
|
||||
provider: endpoint,
|
||||
model_parameters,
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
ADDED_AGENT_ID,
|
||||
loadAddedAgent,
|
||||
setGetAgent,
|
||||
};
|
||||
|
|
@ -113,6 +113,8 @@ const tokenValues = Object.assign(
|
|||
'gpt-4o-2024-05-13': { prompt: 5, completion: 15 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-5': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5.1': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5.2': { prompt: 1.75, completion: 14 },
|
||||
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
||||
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
||||
'gpt-5-pro': { prompt: 15, completion: 120 },
|
||||
|
|
|
|||
|
|
@ -35,6 +35,19 @@ describe('getValueKey', () => {
|
|||
expect(getValueKey('gpt-5-0130')).toBe('gpt-5');
|
||||
});
|
||||
|
||||
it('should return "gpt-5.1" for model name containing "gpt-5.1"', () => {
|
||||
expect(getValueKey('gpt-5.1')).toBe('gpt-5.1');
|
||||
expect(getValueKey('gpt-5.1-chat')).toBe('gpt-5.1');
|
||||
expect(getValueKey('gpt-5.1-codex')).toBe('gpt-5.1');
|
||||
expect(getValueKey('openai/gpt-5.1')).toBe('gpt-5.1');
|
||||
});
|
||||
|
||||
it('should return "gpt-5.2" for model name containing "gpt-5.2"', () => {
|
||||
expect(getValueKey('gpt-5.2')).toBe('gpt-5.2');
|
||||
expect(getValueKey('gpt-5.2-chat')).toBe('gpt-5.2');
|
||||
expect(getValueKey('openai/gpt-5.2')).toBe('gpt-5.2');
|
||||
});
|
||||
|
||||
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
||||
expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106');
|
||||
expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
||||
|
|
@ -310,6 +323,34 @@ describe('getMultiplier', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5.1', () => {
|
||||
expect(getMultiplier({ model: 'gpt-5.1', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.1'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5.1', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5.1'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5.1', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.1'].prompt,
|
||||
);
|
||||
expect(tokenValues['gpt-5.1'].prompt).toBe(1.25);
|
||||
expect(tokenValues['gpt-5.1'].completion).toBe(10);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5.2', () => {
|
||||
expect(getMultiplier({ model: 'gpt-5.2', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.2'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5.2', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5.2'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5.2', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5.2'].prompt,
|
||||
);
|
||||
expect(tokenValues['gpt-5.2'].prompt).toBe(1.75);
|
||||
expect(tokenValues['gpt-5.2'].completion).toBe(14);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o', () => {
|
||||
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.1",
|
||||
"version": "v0.8.2-rc1",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
"server-dev": "echo 'please run this from the root directory'",
|
||||
"test": "cross-env NODE_ENV=test jest",
|
||||
"b:test": "NODE_ENV=test bun jest",
|
||||
"test:ci": "jest --ci",
|
||||
"test:ci": "jest --ci --logHeapUsage",
|
||||
"add-balance": "node ./add-balance.js",
|
||||
"list-balances": "node ./list-balances.js",
|
||||
"user-stats": "node ./user-stats.js",
|
||||
|
|
@ -42,12 +42,12 @@
|
|||
"@azure/storage-blob": "^12.27.0",
|
||||
"@googleapis/youtube": "^20.0.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/core": "^0.3.79",
|
||||
"@librechat/agents": "^3.0.50",
|
||||
"@langchain/core": "^0.3.80",
|
||||
"@librechat/agents": "^3.0.61",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@modelcontextprotocol/sdk": "^1.24.3",
|
||||
"@modelcontextprotocol/sdk": "^1.25.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"axios": "^1.12.1",
|
||||
|
|
@ -79,6 +79,7 @@
|
|||
"klona": "^2.0.6",
|
||||
"librechat-data-provider": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"mathjs": "^15.1.0",
|
||||
"meilisearch": "^0.38.0",
|
||||
"memorystore": "^1.6.7",
|
||||
"mime": "^3.0.0",
|
||||
|
|
|
|||
|
|
@ -350,9 +350,6 @@ function disposeClient(client) {
|
|||
if (client.agentConfigs) {
|
||||
client.agentConfigs = null;
|
||||
}
|
||||
if (client.agentIdMap) {
|
||||
client.agentIdMap = null;
|
||||
}
|
||||
if (client.artifactPromises) {
|
||||
client.artifactPromises = null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ const {
|
|||
setAuthTokens,
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const {
|
||||
deleteAllUserSessions,
|
||||
getUserById,
|
||||
findSession,
|
||||
updateUser,
|
||||
findUser,
|
||||
} = require('~/models');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
|
@ -72,16 +78,38 @@ const refreshController = async (req, res) => {
|
|||
const openIdConfig = getOpenIdConfig();
|
||||
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
|
||||
const claims = tokenset.claims();
|
||||
const { user, error } = await findOpenIDUser({
|
||||
const { user, error, migration } = await findOpenIDUser({
|
||||
findUser,
|
||||
email: claims.email,
|
||||
openidId: claims.sub,
|
||||
idOnTheSource: claims.oid,
|
||||
strategyName: 'refreshController',
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`[refreshController] findOpenIDUser result: user=${user?.email ?? 'null'}, error=${error ?? 'null'}, migration=${migration}, userOpenidId=${user?.openidId ?? 'null'}, claimsSub=${claims.sub}`,
|
||||
);
|
||||
|
||||
if (error || !user) {
|
||||
logger.warn(
|
||||
`[refreshController] Redirecting to /login: error=${error ?? 'null'}, user=${user ? 'exists' : 'null'}`,
|
||||
);
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
|
||||
// Handle migration: update user with openidId if found by email without openidId
|
||||
// Also handle case where user has mismatched openidId (e.g., after database switch)
|
||||
if (migration || user.openidId !== claims.sub) {
|
||||
const reason = migration ? 'migration' : 'openidId mismatch';
|
||||
await updateUser(user._id.toString(), {
|
||||
provider: 'openid',
|
||||
openidId: claims.sub,
|
||||
});
|
||||
logger.info(
|
||||
`[refreshController] Updated user ${user.email} openidId (${reason}): ${user.openidId ?? 'null'} -> ${claims.sub}`,
|
||||
);
|
||||
}
|
||||
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString(), refreshToken);
|
||||
|
||||
user.federatedTokens = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { sendEvent, GenerationJobManager } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -144,17 +144,38 @@ function checkIfLastAgent(last_agent_id, langgraph_node) {
|
|||
return langgraph_node?.endsWith(last_agent_id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to emit events either to res (standard mode) or to job emitter (resumable mode).
|
||||
* @param {ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} eventData - The event data to send
|
||||
*/
|
||||
function emitEvent(res, streamId, eventData) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default handlers for stream events.
|
||||
* @param {Object} options - The options object.
|
||||
* @param {ServerResponse} options.res - The options object.
|
||||
* @param {ContentAggregator} options.aggregateContent - The options object.
|
||||
* @param {ServerResponse} options.res - The server response object.
|
||||
* @param {ContentAggregator} options.aggregateContent - Content aggregator function.
|
||||
* @param {ToolEndCallback} options.toolEndCallback - Callback to use when tool ends.
|
||||
* @param {Array<UsageMetadata>} options.collectedUsage - The list of collected usage metadata.
|
||||
* @param {string | null} [options.streamId] - The stream ID for resumable mode, or null for standard mode.
|
||||
* @returns {Record<string, t.EventHandler>} The default handlers.
|
||||
* @throws {Error} If the request is not found.
|
||||
*/
|
||||
function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedUsage }) {
|
||||
function getDefaultHandlers({
|
||||
res,
|
||||
aggregateContent,
|
||||
toolEndCallback,
|
||||
collectedUsage,
|
||||
streamId = null,
|
||||
}) {
|
||||
if (!res || !aggregateContent) {
|
||||
throw new Error(
|
||||
`[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`,
|
||||
|
|
@ -173,16 +194,16 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.stepDetails.type === StepTypes.TOOL_CALLS) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else {
|
||||
const agentName = metadata?.name ?? 'Agent';
|
||||
const isToolCall = data?.stepDetails.type === StepTypes.TOOL_CALLS;
|
||||
const action = isToolCall ? 'performing a task...' : 'thinking...';
|
||||
sendEvent(res, {
|
||||
emitEvent(res, streamId, {
|
||||
event: 'on_agent_update',
|
||||
data: {
|
||||
runId: metadata?.run_id,
|
||||
|
|
@ -202,11 +223,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.delta.type === StepTypes.TOOL_CALLS) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -220,11 +241,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.result != null) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -238,9 +259,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -254,9 +275,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -266,15 +287,30 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
return handlers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to write attachment events either to res or to job emitter.
|
||||
* @param {ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} attachment - The attachment data
|
||||
*/
|
||||
function writeAttachment(res, streamId, attachment) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
|
||||
} else {
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Promise<MongoFile | { filename: string; filepath: string; expires: number;} | null>[]} params.artifactPromises
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode, or null for standard mode.
|
||||
* @returns {ToolEndCallback} The tool end callback.
|
||||
*/
|
||||
function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
function createToolEndCallback({ req, res, artifactPromises, streamId = null }) {
|
||||
/**
|
||||
* @type {ToolEndCallback}
|
||||
*/
|
||||
|
|
@ -302,10 +338,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
if (!attachment) {
|
||||
return null;
|
||||
}
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing file citations:', error);
|
||||
|
|
@ -314,8 +350,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
);
|
||||
}
|
||||
|
||||
// TODO: a lot of duplicated code in createToolEndCallback
|
||||
// we should refactor this to use a helper function in a follow-up PR
|
||||
if (output.artifact[Tools.ui_resources]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
|
|
@ -326,10 +360,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -348,10 +382,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
[Tools.web_search]: { ...output.artifact[Tools.web_search] },
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -388,7 +422,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
toolCallId: output.tool_call_id,
|
||||
conversationId: metadata.thread_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
|
|
@ -396,7 +430,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
writeAttachment(res, streamId, fileMetadata);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -435,7 +469,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
session_id: output.artifact.session_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
|
|
@ -443,7 +477,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
writeAttachment(res, streamId, fileMetadata);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing code output:', error);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const {
|
|||
getBalanceConfig,
|
||||
getProviderConfig,
|
||||
memoryInstructions,
|
||||
GenerationJobManager,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
filterMalformedContentParts,
|
||||
|
|
@ -36,14 +37,13 @@ const {
|
|||
EModelEndpoint,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
isEphemeralAgentId,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
|
@ -95,59 +95,101 @@ function logToolError(graph, error, toolId) {
|
|||
});
|
||||
}
|
||||
|
||||
/** Regex pattern to match agent ID suffix (____N) */
|
||||
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
|
||||
|
||||
/**
|
||||
* Applies agent labeling to conversation history when multi-agent patterns are detected.
|
||||
* Labels content parts by their originating agent to prevent identity confusion.
|
||||
* Creates a mapMethod for getMessagesForConversation that processes agent content.
|
||||
* - Strips agentId/groupId metadata from all content
|
||||
* - For multi-agent: filters to primary agent content only (no suffix or lowest suffix)
|
||||
* - For multi-agent: applies agent labels to content
|
||||
*
|
||||
* @param {TMessage[]} orderedMessages - The ordered conversation messages
|
||||
* @param {Agent} primaryAgent - The primary agent configuration
|
||||
* @param {Map<string, Agent>} agentConfigs - Map of additional agent configurations
|
||||
* @returns {TMessage[]} Messages with agent labels applied where appropriate
|
||||
* @param {Agent} primaryAgent - Primary agent configuration
|
||||
* @param {Map<string, Agent>} [agentConfigs] - Additional agent configurations
|
||||
* @returns {(message: TMessage) => TMessage} Map method for processing messages
|
||||
*/
|
||||
function applyAgentLabelsToHistory(orderedMessages, primaryAgent, agentConfigs) {
|
||||
const shouldLabelByAgent = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
|
||||
if (!shouldLabelByAgent) {
|
||||
return orderedMessages;
|
||||
}
|
||||
|
||||
const processedMessages = [];
|
||||
|
||||
for (let i = 0; i < orderedMessages.length; i++) {
|
||||
const message = orderedMessages[i];
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
const agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
function createMultiAgentMapper(primaryAgent, agentConfigs) {
|
||||
const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
|
||||
/** @type {Record<string, string> | null} */
|
||||
let agentNames = null;
|
||||
if (hasMultipleAgents) {
|
||||
agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
if (agentConfigs) {
|
||||
for (const [agentId, agentConfig] of agentConfigs.entries()) {
|
||||
agentNames[agentId] = agentConfig.name || agentConfig.id;
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
!message.isCreatedByUser &&
|
||||
message.metadata?.agentIdMap &&
|
||||
Array.isArray(message.content)
|
||||
) {
|
||||
try {
|
||||
const labeledContent = labelContentByAgent(
|
||||
message.content,
|
||||
message.metadata.agentIdMap,
|
||||
agentNames,
|
||||
);
|
||||
|
||||
processedMessages.push({ ...message, content: labeledContent });
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error applying agent labels to message:', error);
|
||||
processedMessages.push(message);
|
||||
}
|
||||
} else {
|
||||
processedMessages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
return processedMessages;
|
||||
return (message) => {
|
||||
if (message.isCreatedByUser || !Array.isArray(message.content)) {
|
||||
return message;
|
||||
}
|
||||
|
||||
// Find primary agent ID (no suffix, or lowest suffix number) - only needed for multi-agent
|
||||
let primaryAgentId = null;
|
||||
let hasAgentMetadata = false;
|
||||
|
||||
if (hasMultipleAgents) {
|
||||
let lowestSuffixIndex = Infinity;
|
||||
for (const part of message.content) {
|
||||
const agentId = part?.agentId;
|
||||
if (!agentId) {
|
||||
continue;
|
||||
}
|
||||
hasAgentMetadata = true;
|
||||
|
||||
const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN);
|
||||
if (!suffixMatch) {
|
||||
primaryAgentId = agentId;
|
||||
break;
|
||||
}
|
||||
const suffixIndex = parseInt(suffixMatch[1], 10);
|
||||
if (suffixIndex < lowestSuffixIndex) {
|
||||
lowestSuffixIndex = suffixIndex;
|
||||
primaryAgentId = agentId;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single agent: just check if any metadata exists
|
||||
hasAgentMetadata = message.content.some((part) => part?.agentId || part?.groupId);
|
||||
}
|
||||
|
||||
if (!hasAgentMetadata) {
|
||||
return message;
|
||||
}
|
||||
|
||||
try {
|
||||
/** @type {Array<TMessageContentParts>} */
|
||||
const filteredContent = [];
|
||||
/** @type {Record<number, string>} */
|
||||
const agentIdMap = {};
|
||||
|
||||
for (const part of message.content) {
|
||||
const agentId = part?.agentId;
|
||||
// For single agent: include all parts; for multi-agent: filter to primary
|
||||
if (!hasMultipleAgents || !agentId || agentId === primaryAgentId) {
|
||||
const newIndex = filteredContent.length;
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = part;
|
||||
filteredContent.push(cleanPart);
|
||||
if (agentId && hasMultipleAgents) {
|
||||
agentIdMap[newIndex] = agentId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const finalContent =
|
||||
Object.keys(agentIdMap).length > 0 && agentNames
|
||||
? labelContentByAgent(filteredContent, agentIdMap, agentNames)
|
||||
: filteredContent;
|
||||
|
||||
return { ...message, content: finalContent };
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error processing multi-agent message:', error);
|
||||
return message;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
|
|
@ -199,8 +241,6 @@ class AgentClient extends BaseClient {
|
|||
this.indexTokenCountMap = {};
|
||||
/** @type {(messages: BaseMessage[]) => Promise<void>} */
|
||||
this.processMemory;
|
||||
/** @type {Record<number, string> | null} */
|
||||
this.agentIdMap = null;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -289,18 +329,13 @@ class AgentClient extends BaseClient {
|
|||
{ instructions = null, additional_instructions = null },
|
||||
opts,
|
||||
) {
|
||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||
const orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
summary: this.shouldSummarize,
|
||||
mapMethod: createMultiAgentMapper(this.options.agent, this.agentConfigs),
|
||||
});
|
||||
|
||||
orderedMessages = applyAgentLabelsToHistory(
|
||||
orderedMessages,
|
||||
this.options.agent,
|
||||
this.agentConfigs,
|
||||
);
|
||||
|
||||
let payload;
|
||||
/** @type {number | undefined} */
|
||||
let promptTokens;
|
||||
|
|
@ -552,10 +587,9 @@ class AgentClient extends BaseClient {
|
|||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
endpointOption: {
|
||||
endpoint:
|
||||
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
endpoint: !isEphemeralAgentId(prelimAgent.id)
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -595,10 +629,12 @@ class AgentClient extends BaseClient {
|
|||
const userId = this.options.req.user.id + '';
|
||||
const messageId = this.responseMessageId + '';
|
||||
const conversationId = this.conversationId + '';
|
||||
const streamId = this.options.req?._resumableStreamId || null;
|
||||
const [withoutKeys, processMemory] = await createMemoryProcessor({
|
||||
userId,
|
||||
config,
|
||||
messageId,
|
||||
streamId,
|
||||
conversationId,
|
||||
memoryMethods: {
|
||||
setMemory: db.setMemory,
|
||||
|
|
@ -692,9 +728,7 @@ class AgentClient extends BaseClient {
|
|||
});
|
||||
|
||||
const completion = filterMalformedContentParts(this.contentParts);
|
||||
const metadata = this.agentIdMap ? { agentIdMap: this.agentIdMap } : undefined;
|
||||
|
||||
return { completion, metadata };
|
||||
return { completion };
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -890,12 +924,10 @@ class AgentClient extends BaseClient {
|
|||
*/
|
||||
const runAgents = async (messages) => {
|
||||
const agents = [this.options.agent];
|
||||
if (
|
||||
this.agentConfigs &&
|
||||
this.agentConfigs.size > 0 &&
|
||||
((this.options.agent.edges?.length ?? 0) > 0 ||
|
||||
(await checkCapability(this.options.req, AgentCapabilities.chain)))
|
||||
) {
|
||||
// Include additional agents when:
|
||||
// - agentConfigs has agents (from addedConvo parallel execution or agent handoffs)
|
||||
// - Agents without incoming edges become start nodes and run in parallel automatically
|
||||
if (this.agentConfigs && this.agentConfigs.size > 0) {
|
||||
agents.push(...this.agentConfigs.values());
|
||||
}
|
||||
|
||||
|
|
@ -955,6 +987,12 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
this.run = run;
|
||||
|
||||
const streamId = this.options.req?._resumableStreamId;
|
||||
if (streamId && run.Graph) {
|
||||
GenerationJobManager.setGraph(streamId, run.Graph);
|
||||
}
|
||||
|
||||
if (userMCPAuthMap != null) {
|
||||
config.configurable.userMCPAuthMap = userMCPAuthMap;
|
||||
}
|
||||
|
|
@ -985,24 +1023,6 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
/** Capture agent ID map if we have edges or multiple agents */
|
||||
const shouldStoreAgentMap =
|
||||
(this.options.agent.edges?.length ?? 0) > 0 || (this.agentConfigs?.size ?? 0) > 0;
|
||||
if (shouldStoreAgentMap && run?.Graph) {
|
||||
const contentPartAgentMap = run.Graph.getContentPartAgentMap();
|
||||
if (contentPartAgentMap && contentPartAgentMap.size > 0) {
|
||||
this.agentIdMap = Object.fromEntries(contentPartAgentMap);
|
||||
logger.debug('[AgentClient] Captured agent ID map:', {
|
||||
totalParts: this.contentParts.length,
|
||||
mappedParts: Object.keys(this.agentIdMap).length,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error capturing agent ID map:', error);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
|
|
|
|||
|
|
@ -2,14 +2,11 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
sendEvent,
|
||||
GenerationJobManager,
|
||||
sanitizeFileForTransmit,
|
||||
sanitizeMessageForTransmit,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
|
|
@ -31,12 +28,16 @@ function createCloseHandler(abortController) {
|
|||
};
|
||||
}
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
/**
|
||||
* Resumable Agent Controller - Generation runs independently of HTTP connection.
|
||||
* Returns streamId immediately, client subscribes separately via SSE.
|
||||
*/
|
||||
const ResumableAgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
conversationId: reqConversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
|
|
@ -44,18 +45,354 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let sender;
|
||||
let abortKey;
|
||||
const userId = req.user.id;
|
||||
|
||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||
const conversationId =
|
||||
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
|
||||
const streamId = conversationId;
|
||||
|
||||
let client = null;
|
||||
|
||||
try {
|
||||
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
|
||||
req._resumableStreamId = streamId;
|
||||
|
||||
// Send JSON response IMMEDIATELY so client can connect to SSE stream
|
||||
// This is critical: tool loading (MCP OAuth) may emit events that the client needs to receive
|
||||
res.json({ streamId, conversationId, status: 'started' });
|
||||
|
||||
// Note: We no longer use res.on('close') to abort since we send JSON immediately.
|
||||
// The response closes normally after res.json(), which is not an abort condition.
|
||||
// Abort handling is done through GenerationJobManager via the SSE stream connection.
|
||||
|
||||
// Track if partial response was already saved to avoid duplicates
|
||||
let partialResponseSaved = false;
|
||||
|
||||
/**
|
||||
* Listen for all subscribers leaving to save partial response.
|
||||
* This ensures the response is saved to DB even if all clients disconnect
|
||||
* while generation continues.
|
||||
*
|
||||
* Note: The messageId used here falls back to `${userMessage.messageId}_` if the
|
||||
* actual response messageId isn't available yet. The final response save will
|
||||
* overwrite this with the complete response using the same messageId pattern.
|
||||
*/
|
||||
job.emitter.on('allSubscribersLeft', async (aggregatedContent) => {
|
||||
if (partialResponseSaved || !aggregatedContent || aggregatedContent.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
||||
if (!resumeState?.userMessage) {
|
||||
logger.debug('[ResumableAgentController] No user message to save partial response for');
|
||||
return;
|
||||
}
|
||||
|
||||
partialResponseSaved = true;
|
||||
const responseConversationId = resumeState.conversationId || conversationId;
|
||||
|
||||
try {
|
||||
const partialMessage = {
|
||||
messageId: resumeState.responseMessageId || `${resumeState.userMessage.messageId}_`,
|
||||
conversationId: responseConversationId,
|
||||
parentMessageId: resumeState.userMessage.messageId,
|
||||
sender: client?.sender ?? 'AI',
|
||||
content: aggregatedContent,
|
||||
unfinished: true,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
user: userId,
|
||||
endpoint: endpointOption.endpoint,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
if (req.body?.agent_id) {
|
||||
partialMessage.agent_id = req.body.agent_id;
|
||||
}
|
||||
|
||||
await saveMessage(req, partialMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - partial response on disconnect',
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`[ResumableAgentController] Saved partial response for ${streamId}, content parts: ${aggregatedContent.length}`,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[ResumableAgentController] Error saving partial response:', error);
|
||||
// Reset flag so we can try again if subscribers reconnect and leave again
|
||||
partialResponseSaved = false;
|
||||
}
|
||||
});
|
||||
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
// Use the job's abort controller signal - allows abort via GenerationJobManager.abortJob()
|
||||
signal: job.abortController.signal,
|
||||
});
|
||||
|
||||
if (job.abortController.signal.aborted) {
|
||||
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
|
||||
return;
|
||||
}
|
||||
|
||||
client = result.client;
|
||||
|
||||
if (client?.sender) {
|
||||
GenerationJobManager.updateMetadata(streamId, { sender: client.sender });
|
||||
}
|
||||
|
||||
// Store reference to client's contentParts - graph will be set when run is created
|
||||
if (client?.contentParts) {
|
||||
GenerationJobManager.setContentParts(streamId, client.contentParts);
|
||||
}
|
||||
|
||||
let userMessage;
|
||||
|
||||
const getReqData = (data = {}) => {
|
||||
if (data.userMessage) {
|
||||
userMessage = data.userMessage;
|
||||
}
|
||||
// conversationId is pre-generated, no need to update from callback
|
||||
};
|
||||
|
||||
// Start background generation - readyPromise resolves immediately now
|
||||
// (sync mechanism handles late subscribers)
|
||||
const startGeneration = async () => {
|
||||
try {
|
||||
// Short timeout as safety net - promise should already be resolved
|
||||
await Promise.race([job.readyPromise, new Promise((resolve) => setTimeout(resolve, 100))]);
|
||||
} catch (waitError) {
|
||||
logger.warn(
|
||||
`[ResumableAgentController] Error waiting for subscriber: ${waitError.message}`,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const onStart = (userMsg, respMsgId, _isNewConvo) => {
|
||||
userMessage = userMsg;
|
||||
|
||||
// Store userMessage and responseMessageId upfront for resume capability
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
responseMessageId: respMsgId,
|
||||
userMessage: {
|
||||
messageId: userMsg.messageId,
|
||||
parentMessageId: userMsg.parentMessageId,
|
||||
conversationId: userMsg.conversationId,
|
||||
text: userMsg.text,
|
||||
},
|
||||
});
|
||||
|
||||
GenerationJobManager.emitChunk(streamId, {
|
||||
created: true,
|
||||
message: userMessage,
|
||||
streamId,
|
||||
});
|
||||
};
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
onStart,
|
||||
getReqData,
|
||||
isContinued,
|
||||
isRegenerate,
|
||||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController: job.abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res: {
|
||||
write: () => true,
|
||||
end: () => {},
|
||||
headersSent: false,
|
||||
writableEnded: false,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const response = await client.sendMessage(text, messageOptions);
|
||||
|
||||
const messageId = response.messageId;
|
||||
const endpoint = endpointOption.endpoint;
|
||||
response.endpoint = endpoint;
|
||||
|
||||
const databasePromise = response.databasePromise;
|
||||
delete response.databasePromise;
|
||||
|
||||
const { conversation: convoData = {} } = await databasePromise;
|
||||
const conversation = { ...convoData };
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
if (req.body.files && client.options?.attachments) {
|
||||
userMessage.files = [];
|
||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
||||
for (const attachment of client.options.attachments) {
|
||||
if (messageFiles.has(attachment.file_id)) {
|
||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
||||
}
|
||||
}
|
||||
delete userMessage.image_urls;
|
||||
}
|
||||
|
||||
// Check abort state BEFORE calling completeJob (which triggers abort signal for cleanup)
|
||||
const wasAbortedBeforeComplete = job.abortController.signal.aborted;
|
||||
const isNewConvo = !reqConversationId || reqConversationId === 'new';
|
||||
const shouldGenerateTitle =
|
||||
addTitle &&
|
||||
parentMessageId === Constants.NO_PARENT &&
|
||||
isNewConvo &&
|
||||
!wasAbortedBeforeComplete;
|
||||
|
||||
if (!wasAbortedBeforeComplete) {
|
||||
const finalEvent = {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: { ...response },
|
||||
};
|
||||
|
||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||
GenerationJobManager.completeJob(streamId);
|
||||
|
||||
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user: userId },
|
||||
{ context: 'api/server/controllers/agents/request.js - resumable response end' },
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const finalEvent = {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: { ...response, error: true },
|
||||
error: { message: 'Request was aborted' },
|
||||
};
|
||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||
GenerationJobManager.completeJob(streamId, 'Request aborted');
|
||||
}
|
||||
|
||||
if (!client.skipSaveUserMessage && userMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - resumable user message',
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldGenerateTitle) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
client,
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[ResumableAgentController] Error in title generation', err);
|
||||
})
|
||||
.finally(() => {
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Check if this was an abort (not a real error)
|
||||
const wasAborted = job.abortController.signal.aborted || error.message?.includes('abort');
|
||||
|
||||
if (wasAborted) {
|
||||
logger.debug(`[ResumableAgentController] Generation aborted for ${streamId}`);
|
||||
// abortJob already handled emitDone and completeJob
|
||||
} else {
|
||||
logger.error(`[ResumableAgentController] Generation error for ${streamId}:`, error);
|
||||
GenerationJobManager.emitError(streamId, error.message || 'Generation failed');
|
||||
GenerationJobManager.completeJob(streamId, error.message);
|
||||
}
|
||||
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
|
||||
// Don't continue to title generation after error/abort
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Start generation and handle any unhandled errors
|
||||
startGeneration().catch((err) => {
|
||||
logger.error(
|
||||
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
|
||||
);
|
||||
GenerationJobManager.completeJob(streamId, err.message);
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[ResumableAgentController] Initialization error:', error);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({ error: error.message || 'Failed to start generation' });
|
||||
} else {
|
||||
// JSON already sent, emit error to stream so client can receive it
|
||||
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
|
||||
}
|
||||
GenerationJobManager.completeJob(streamId, error.message);
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Agent Controller - Routes to ResumableAgentController for all requests.
|
||||
* The legacy non-resumable path is kept below but no longer used by default.
|
||||
*/
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
return ResumableAgentController(req, res, next, initializeClient, addTitle);
|
||||
};
|
||||
|
||||
/**
|
||||
* Legacy Non-resumable Agent Controller - Uses GenerationJobManager for abort handling.
|
||||
* Response is streamed directly to client via res, but abort state is managed centrally.
|
||||
* @deprecated Use ResumableAgentController instead
|
||||
*/
|
||||
const _LegacyAgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId: reqConversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||
const conversationId =
|
||||
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
|
||||
const streamId = conversationId;
|
||||
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let userMessagePromise;
|
||||
let getAbortData;
|
||||
let client = null;
|
||||
let cleanupHandlers = [];
|
||||
|
||||
const newConvo = !conversationId;
|
||||
// Match the same logic used for conversationId generation above
|
||||
const isNewConvo = !reqConversationId || reqConversationId === 'new';
|
||||
const userId = req.user.id;
|
||||
|
||||
// Create handler to avoid capturing the entire parent scope
|
||||
|
|
@ -64,24 +401,20 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
// Update job metadata with prompt tokens for abort handling
|
||||
GenerationJobManager.updateMetadata(streamId, { promptTokens: data[key] });
|
||||
} else if (key === 'sender') {
|
||||
sender = data[key];
|
||||
} else if (key === 'abortKey') {
|
||||
abortKey = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
GenerationJobManager.updateMetadata(streamId, { sender: data[key] });
|
||||
}
|
||||
// conversationId is pre-generated, no need to update from callback
|
||||
}
|
||||
};
|
||||
|
||||
// Create a function to handle final cleanup
|
||||
const performCleanup = () => {
|
||||
const performCleanup = async () => {
|
||||
logger.debug('[AgentController] Performing cleanup');
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
for (const handler of cleanupHandlers) {
|
||||
|
|
@ -95,10 +428,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
}
|
||||
|
||||
// Clean up abort controller
|
||||
if (abortKey) {
|
||||
logger.debug('[AgentController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
// Complete the job in GenerationJobManager
|
||||
if (streamId) {
|
||||
logger.debug('[AgentController] Completing job in GenerationJobManager');
|
||||
await GenerationJobManager.completeJob(streamId);
|
||||
}
|
||||
|
||||
// Dispose client properly
|
||||
|
|
@ -110,11 +443,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
client = null;
|
||||
getReqData = null;
|
||||
userMessage = null;
|
||||
getAbortData = null;
|
||||
endpointOption.agent = null;
|
||||
endpointOption = null;
|
||||
cleanupHandlers = null;
|
||||
userMessagePromise = null;
|
||||
|
||||
// Clear request data map
|
||||
if (requestDataMap.has(req)) {
|
||||
|
|
@ -136,6 +465,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
};
|
||||
cleanupHandlers.push(removePrelimHandler);
|
||||
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
|
|
@ -143,6 +473,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
endpointOption,
|
||||
signal: prelimAbortController.signal,
|
||||
});
|
||||
|
||||
if (prelimAbortController.signal?.aborted) {
|
||||
prelimAbortController = null;
|
||||
throw new Error('Request was aborted before initialization could complete');
|
||||
|
|
@ -161,28 +492,24 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
// Store request data in WeakMap keyed by req object
|
||||
requestDataMap.set(req, { client });
|
||||
|
||||
// Use WeakRef to allow GC but still access content if it exists
|
||||
const contentRef = new WeakRef(client.contentParts || []);
|
||||
// Create job in GenerationJobManager for abort handling
|
||||
// streamId === conversationId (pre-generated above)
|
||||
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
|
||||
|
||||
// Minimize closure scope - only capture small primitives and WeakRef
|
||||
getAbortData = () => {
|
||||
// Dereference WeakRef each time
|
||||
const content = contentRef.deref();
|
||||
// Store endpoint metadata for abort handling
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
sender: client?.sender,
|
||||
});
|
||||
|
||||
return {
|
||||
sender,
|
||||
content: content || [],
|
||||
userMessage,
|
||||
promptTokens,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
};
|
||||
};
|
||||
// Store content parts reference for abort
|
||||
if (client?.contentParts) {
|
||||
GenerationJobManager.setContentParts(streamId, client.contentParts);
|
||||
}
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
const closeHandler = createCloseHandler(abortController);
|
||||
const closeHandler = createCloseHandler(job.abortController);
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
|
|
@ -192,6 +519,27 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* onStart callback - stores user message and response ID for abort handling
|
||||
*/
|
||||
const onStart = (userMsg, respMsgId, _isNewConvo) => {
|
||||
sendEvent(res, { message: userMsg, created: true });
|
||||
userMessage = userMsg;
|
||||
userMessageId = userMsg.messageId;
|
||||
responseMessageId = respMsgId;
|
||||
|
||||
// Store metadata for abort handling (conversationId is pre-generated)
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
responseMessageId: respMsgId,
|
||||
userMessage: {
|
||||
messageId: userMsg.messageId,
|
||||
parentMessageId: userMsg.parentMessageId,
|
||||
conversationId,
|
||||
text: userMsg.text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
onStart,
|
||||
|
|
@ -201,7 +549,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController,
|
||||
abortController: job.abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
|
|
@ -241,7 +589,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
|
||||
// Only send if not aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
if (!job.abortController.signal.aborted) {
|
||||
// Create a new response object with minimal copies
|
||||
const finalResponse = { ...response };
|
||||
|
||||
|
|
@ -292,7 +640,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
|
||||
// Add title if needed - extract minimal data
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && isNewConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
|
|
@ -315,7 +663,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
// Handle error without capturing much scope
|
||||
handleAbortError(res, req, error, {
|
||||
conversationId,
|
||||
sender,
|
||||
sender: client?.sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
|
||||
userMessageId,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,54 @@
|
|||
* @import { MCPServerDocument } from 'librechat-data-provider'
|
||||
*/
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
isMCPDomainNotAllowedError,
|
||||
isMCPInspectionFailedError,
|
||||
MCPErrorCodes,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||
|
||||
/**
|
||||
* Handles MCP-specific errors and sends appropriate HTTP responses.
|
||||
* @param {Error} error - The error to handle
|
||||
* @param {import('express').Response} res - Express response object
|
||||
* @returns {import('express').Response | null} Response if handled, null if not an MCP error
|
||||
*/
|
||||
function handleMCPError(error, res) {
|
||||
if (isMCPDomainNotAllowedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
if (isMCPInspectionFailedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
// Fallback for legacy string-based error handling (backwards compatibility)
|
||||
if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) {
|
||||
return res.status(403).json({
|
||||
error: MCPErrorCodes.DOMAIN_NOT_ALLOWED,
|
||||
message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''),
|
||||
});
|
||||
}
|
||||
|
||||
if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) {
|
||||
return res.status(400).json({
|
||||
error: MCPErrorCodes.INSPECTION_FAILED,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
*/
|
||||
|
|
@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => {
|
|||
});
|
||||
} catch (error) {
|
||||
logger.error('[createMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => {
|
|||
res.status(200).json(parsedConfig);
|
||||
} catch (error) {
|
||||
logger.error('[updateMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ const {
|
|||
performStartupChecks,
|
||||
handleJsonParseError,
|
||||
initializeFileStorage,
|
||||
GenerationJobManager,
|
||||
createStreamServices,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
|
|
@ -192,6 +194,11 @@ const startServer = async () => {
|
|||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
|
||||
// Configure stream services (auto-detects Redis from USE_REDIS env var)
|
||||
const streamServices = createStreamServices();
|
||||
GenerationJobManager.configure(streamServices);
|
||||
GenerationJobManager.initialize();
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
// abortControllers.js
|
||||
module.exports = new Map();
|
||||
|
|
@ -1,124 +1,102 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent, sanitizeMessageForTransmit } = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
countTokens,
|
||||
isEnabled,
|
||||
sendEvent,
|
||||
GenerationJobManager,
|
||||
sanitizeMessageForTransmit,
|
||||
} = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { abortRun } = require('./abortRun');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {boolean}
|
||||
* Abort an active message generation.
|
||||
* Uses GenerationJobManager for all agent requests.
|
||||
* Since streamId === conversationId, we can directly abort by conversationId.
|
||||
*/
|
||||
function cleanupAbortController(abortKey) {
|
||||
if (!abortControllers.has(abortKey)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
|
||||
if (!abortController) {
|
||||
abortControllers.delete(abortKey);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 1. Check if this controller has any composed signals and clean them up
|
||||
try {
|
||||
// This creates a temporary composed signal to use for cleanup
|
||||
const composedSignal = AbortSignal.any([abortController.signal]);
|
||||
|
||||
// Get all event types - in practice, AbortSignal typically only uses 'abort'
|
||||
const eventTypes = ['abort'];
|
||||
|
||||
// First, execute a dummy listener removal to handle potential composed signals
|
||||
for (const eventType of eventTypes) {
|
||||
const dummyHandler = () => {};
|
||||
composedSignal.addEventListener(eventType, dummyHandler);
|
||||
composedSignal.removeEventListener(eventType, dummyHandler);
|
||||
|
||||
const listeners = composedSignal.listeners?.(eventType) || [];
|
||||
for (const listener of listeners) {
|
||||
composedSignal.removeEventListener(eventType, listener);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug(`Error cleaning up composed signals: ${e}`);
|
||||
}
|
||||
|
||||
// 2. Abort the controller if not already aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
|
||||
// 3. Remove from registry
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
// 4. Clean up any data stored in the WeakMap
|
||||
if (abortDataMap.has(abortController)) {
|
||||
abortDataMap.delete(abortController);
|
||||
}
|
||||
|
||||
// 5. Clean up function references on the controller
|
||||
if (abortController.getAbortData) {
|
||||
abortController.getAbortData = null;
|
||||
}
|
||||
|
||||
if (abortController.abortCompletion) {
|
||||
abortController.abortCompletion = null;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {function(): void}
|
||||
*/
|
||||
function createCleanUpHandler(abortKey) {
|
||||
return function () {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, endpoint } = req.body;
|
||||
const { abortKey, endpoint } = req.body;
|
||||
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
// Use GenerationJobManager to abort the job (streamId === conversationId)
|
||||
const abortResult = await GenerationJobManager.abortJob(conversationId);
|
||||
|
||||
if (!abortResult.success) {
|
||||
if (!res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const { jobData, content, text } = abortResult;
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
// Count tokens and spend them
|
||||
const completionTokens = await countTokens(text);
|
||||
const promptTokens = jobData?.promptTokens ?? 0;
|
||||
|
||||
const finalEvent = await abortController.abortCompletion?.();
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
|
||||
JSON.stringify({ abortKey }),
|
||||
const responseMessage = {
|
||||
messageId: jobData?.responseMessageId,
|
||||
parentMessageId: jobData?.userMessage?.messageId,
|
||||
conversationId: jobData?.conversationId,
|
||||
content,
|
||||
text,
|
||||
sender: jobData?.sender ?? 'AI',
|
||||
finish_reason: 'incomplete',
|
||||
endpoint: jobData?.endpoint,
|
||||
iconURL: jobData?.iconURL,
|
||||
model: jobData?.model,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user: userId },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
cleanupAbortController(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user: userId },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
// Get conversation for title
|
||||
const conversation = await getConvo(userId, conversationId);
|
||||
|
||||
const finalEvent = {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: jobData?.userMessage
|
||||
? sanitizeMessageForTransmit({
|
||||
messageId: jobData.userMessage.messageId,
|
||||
parentMessageId: jobData.userMessage.parentMessageId,
|
||||
conversationId: jobData.userMessage.conversationId,
|
||||
text: jobData.userMessage.text,
|
||||
isCreatedByUser: true,
|
||||
})
|
||||
: null,
|
||||
responseMessage,
|
||||
};
|
||||
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${userId} | ${req.user.email} | Aborted request: ${conversationId}`,
|
||||
);
|
||||
|
||||
if (res.headersSent) {
|
||||
return sendEvent(res, finalEvent);
|
||||
}
|
||||
|
||||
|
|
@ -139,171 +117,13 @@ const handleAbort = function () {
|
|||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
|
||||
// Store minimal data in WeakMap to avoid circular references
|
||||
abortDataMap.set(abortController, {
|
||||
getAbortDataFn: getAbortData,
|
||||
userId: req.user.id,
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
});
|
||||
|
||||
// Replace the direct function reference with a wrapper that uses WeakMap
|
||||
abortController.getAbortData = function () {
|
||||
const data = abortDataMap.get(this);
|
||||
if (!data || typeof data.getAbortDataFn !== 'function') {
|
||||
return {};
|
||||
}
|
||||
|
||||
try {
|
||||
const result = data.getAbortDataFn();
|
||||
|
||||
// Create a copy without circular references
|
||||
const cleanResult = { ...result };
|
||||
|
||||
// If userMessagePromise exists, break its reference to client
|
||||
if (
|
||||
cleanResult.userMessagePromise &&
|
||||
typeof cleanResult.userMessagePromise.then === 'function'
|
||||
) {
|
||||
// Create a new promise that fulfills with the same result but doesn't reference the original
|
||||
const originalPromise = cleanResult.userMessagePromise;
|
||||
cleanResult.userMessagePromise = new Promise((resolve, reject) => {
|
||||
originalPromise.then(
|
||||
(result) => resolve({ ...result }),
|
||||
(error) => reject(error),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return cleanResult;
|
||||
} catch (err) {
|
||||
logger.error('[abortController.getAbortData] Error:', err);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
* @param {boolean} [isNewConvo]
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId, isNewConvo) => {
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const prelimAbortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const abortKey = isNewConvo
|
||||
? `${prelimAbortKey}${Constants.COMMON_DIVIDER}${Constants.NEW_CONVO}`
|
||||
: prelimAbortKey;
|
||||
getReqData({ abortKey });
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
const { overrideUserMessageId } = req?.body ?? {};
|
||||
|
||||
if (overrideUserMessageId != null && prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
|
||||
// Store minimal options
|
||||
const minimalOptions = {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
|
||||
const cleanupHandler = createCleanUpHandler(addedAbortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store minimal options
|
||||
const minimalOptions = {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...minimalOptions });
|
||||
const cleanupHandler = createCleanUpHandler(abortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
};
|
||||
|
||||
// Define abortCompletion without capturing the entire parent scope
|
||||
abortController.abortCompletion = async function () {
|
||||
this.abort();
|
||||
|
||||
// Get data from WeakMap
|
||||
const ctrlData = abortDataMap.get(this);
|
||||
if (!ctrlData || !ctrlData.getAbortDataFn) {
|
||||
return { final: true, conversation: {}, title: 'New Chat' };
|
||||
}
|
||||
|
||||
// Get abort data using stored function
|
||||
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||
ctrlData.getAbortDataFn();
|
||||
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = ctrlData.userId;
|
||||
|
||||
const responseMessage = {
|
||||
...responseData,
|
||||
conversationId,
|
||||
finish_reason: 'incomplete',
|
||||
endpoint: ctrlData.endpoint,
|
||||
iconURL: ctrlData.iconURL,
|
||||
model: ctrlData.modelOptions?.model ?? ctrlData.model_parameters?.model,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
const resolved = await userMessagePromise;
|
||||
conversation = resolved?.conversation;
|
||||
// Break reference to promise
|
||||
resolved.conversation = null;
|
||||
}
|
||||
|
||||
if (!conversation) {
|
||||
conversation = await getConvo(user, conversationId);
|
||||
}
|
||||
|
||||
return {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
};
|
||||
|
||||
return { abortController, onStart };
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle abort errors during generation.
|
||||
* @param {ServerResponse} res
|
||||
* @param {ServerRequest} req
|
||||
* @param {Error | unknown} error
|
||||
* @param {Partial<TMessage> & { partialText?: string }} data
|
||||
* @returns { Promise<void> }
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const handleAbortError = async (res, req, error, data) => {
|
||||
if (error?.message?.includes('base64')) {
|
||||
|
|
@ -368,8 +188,7 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
};
|
||||
}
|
||||
|
||||
const callback = createCleanUpHandler(conversationId);
|
||||
await sendError(req, res, options, callback);
|
||||
await sendError(req, res, options);
|
||||
};
|
||||
|
||||
if (partialText && partialText.length > 5) {
|
||||
|
|
@ -387,6 +206,4 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
module.exports = {
|
||||
handleAbort,
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants, isAgentsEndpoint, ResourceType } = require('librechat-data-provider');
|
||||
const {
|
||||
Constants,
|
||||
ResourceType,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
|
||||
|
|
@ -13,7 +18,8 @@ const { getAgent } = require('~/models/Agent');
|
|||
*/
|
||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||
// Handle ephemeral agents - they don't need permission checks
|
||||
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentCustomId)) {
|
||||
return null; // No permission check needed for ephemeral agents
|
||||
}
|
||||
|
||||
|
|
@ -62,7 +68,8 @@ const canAccessAgentFromBody = (options) => {
|
|||
}
|
||||
|
||||
// Skip permission checks for ephemeral agents
|
||||
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentId)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,10 @@ async function buildEndpointOption(req, res, next) {
|
|||
try {
|
||||
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
|
||||
);
|
||||
logger.error(`Error parsing compact conversation for endpoint ${endpoint}`, error);
|
||||
logger.debug({
|
||||
'Error parsing compact conversation': { endpoint, endpointType, conversation: req.body },
|
||||
});
|
||||
return handleError(res, { text: 'Error parsing conversation' });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,15 @@ const { logViolation, getLogStores } = require('~/cache');
|
|||
|
||||
const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {};
|
||||
|
||||
/**
|
||||
* Helper function to get conversationId from different request body structures.
|
||||
* @param {Object} body - The request body.
|
||||
* @returns {string|undefined} The conversationId.
|
||||
*/
|
||||
const getConversationId = (body) => {
|
||||
return body.conversationId ?? body.arg?.conversationId;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to validate user's authorization for a conversation.
|
||||
*
|
||||
|
|
@ -24,7 +33,7 @@ const validateConvoAccess = async (req, res, next) => {
|
|||
const namespace = ViolationTypes.CONVO_ACCESS;
|
||||
const cache = getLogStores(namespace);
|
||||
|
||||
const conversationId = req.body.conversationId;
|
||||
const conversationId = getConversationId(req.body);
|
||||
|
||||
if (!conversationId || conversationId === Constants.NEW_CONVO) {
|
||||
return next();
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ afterEach(() => {
|
|||
|
||||
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
|
||||
|
||||
// eslint-disable-next-line jest/no-disabled-tests
|
||||
describe.skip('GET /', () => {
|
||||
it('should return 200 and the correct body', async () => {
|
||||
process.env.APP_TITLE = 'Test Title';
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ jest.mock('~/server/middleware', () => ({
|
|||
forkUserLimiter: (req, res, next) => next(),
|
||||
})),
|
||||
configMiddleware: (req, res, next) => next(),
|
||||
validateConvoAccess: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils/import/fork', () => ({
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ const express = require('express');
|
|||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { getBasePath } = require('@librechat/api');
|
||||
|
||||
const mockRegistryInstance = {
|
||||
getServerConfig: jest.fn(),
|
||||
|
|
@ -12,26 +13,36 @@ const mockRegistryInstance = {
|
|||
removeServer: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
jest.mock('@librechat/api', () => {
|
||||
const actual = jest.requireActual('@librechat/api');
|
||||
return {
|
||||
...actual,
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
// Error handling utilities (from @librechat/api mcp/errors)
|
||||
isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED',
|
||||
MCPErrorCodes: {
|
||||
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -271,27 +282,30 @@ describe('MCP Routes', () => {
|
|||
error: 'access_denied',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=access_denied');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when code is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=missing_code');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_code`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when state is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=missing_state');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when flow state is not found', async () => {
|
||||
|
|
@ -301,9 +315,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'invalid-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=invalid_state');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||
});
|
||||
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
|
|
@ -358,9 +373,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
'test-flow-id',
|
||||
'test-auth-code',
|
||||
|
|
@ -394,9 +410,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
|
||||
});
|
||||
|
||||
it('should handle system-level OAuth completion', async () => {
|
||||
|
|
@ -429,9 +446,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
|
|
@ -474,9 +492,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
|
@ -515,9 +534,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
|
||||
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
|
|
@ -573,9 +593,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify storeTokens was called with ORIGINAL flow state credentials
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||
|
|
@ -614,9 +635,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify completeOAuthFlow was NOT called (prevented duplicate)
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
|
||||
|
|
@ -1385,8 +1407,10 @@ describe('MCP Routes', () => {
|
|||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.expect(302);
|
||||
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(mockFlowManager.completeFlow).not.toHaveBeenCalled();
|
||||
expect(response.headers.location).toContain('/oauth/success');
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
|
||||
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
||||
|
|
@ -1433,7 +1457,9 @@ describe('MCP Routes', () => {
|
|||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.expect(302);
|
||||
|
||||
expect(response.headers.location).toContain('/oauth/success');
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const express = require('express');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { getAccessToken } = require('@librechat/api');
|
||||
const { getAccessToken, getBasePath } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { findToken, updateToken, createToken } = require('~/models');
|
||||
|
|
@ -24,6 +24,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
const { code, state } = req.query;
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const basePath = getBasePath();
|
||||
let identifier = action_id;
|
||||
try {
|
||||
let decodedState;
|
||||
|
|
@ -32,17 +33,17 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
} catch (err) {
|
||||
logger.error('Error verifying state parameter:', err);
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
if (decodedState.action_id !== action_id) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
if (!decodedState.user) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
identifier = `${decodedState.user}:${action_id}`;
|
||||
const flowState = await flowManager.getFlowState(identifier, 'oauth');
|
||||
|
|
@ -72,12 +73,12 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
|
||||
/** Redirect to React success page */
|
||||
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('Error in OAuth callback:', error);
|
||||
await flowManager.failFlow(identifier, 'oauth', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ const express = require('express');
|
|||
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
|
||||
const {
|
||||
setHeaders,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
validateConvoAccess,
|
||||
|
|
@ -16,8 +15,6 @@ const { getRoleByName } = require('~/models/Role');
|
|||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(moderateText);
|
||||
|
||||
const checkAgentAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
|
|
@ -28,11 +25,11 @@ const checkAgentResourceAccess = canAccessAgentFromBody({
|
|||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
router.use(moderateText);
|
||||
router.use(checkAgentAccess);
|
||||
router.use(checkAgentResourceAccess);
|
||||
router.use(validateConvoAccess);
|
||||
router.use(buildEndpointOption);
|
||||
router.use(setHeaders);
|
||||
|
||||
const controller = async (req, res, next) => {
|
||||
await AgentController(req, res, next, initializeClient, addTitle);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const express = require('express');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, GenerationJobManager } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
|
|
@ -22,6 +23,188 @@ router.use(uaParser);
|
|||
|
||||
router.use('/', v1);
|
||||
|
||||
/**
|
||||
* Stream endpoints - mounted before chatRouter to bypass rate limiters
|
||||
* These are GET requests and don't need message body validation or rate limiting
|
||||
*/
|
||||
|
||||
/**
|
||||
* @route GET /chat/stream/:streamId
|
||||
* @desc Subscribe to an ongoing generation job's SSE stream with replay support
|
||||
* @access Private
|
||||
* @description Sends sync event with resume state, replays missed chunks, then streams live
|
||||
* @query resume=true - Indicates this is a reconnection (sends sync event)
|
||||
*/
|
||||
router.get('/chat/stream/:streamId', async (req, res) => {
|
||||
const { streamId } = req.params;
|
||||
const isResume = req.query.resume === 'true';
|
||||
|
||||
const job = await GenerationJobManager.getJob(streamId);
|
||||
if (!job) {
|
||||
return res.status(404).json({
|
||||
error: 'Stream not found',
|
||||
message: 'The generation job does not exist or has expired.',
|
||||
});
|
||||
}
|
||||
|
||||
res.setHeader('Content-Encoding', 'identity');
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.flushHeaders();
|
||||
|
||||
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
|
||||
|
||||
// Send sync event with resume state for ALL reconnecting clients
|
||||
// This supports multi-tab scenarios where each tab needs run step data
|
||||
if (isResume) {
|
||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
||||
if (resumeState && !res.writableEnded) {
|
||||
// Send sync event with run steps AND aggregatedContent
|
||||
// Client will use aggregatedContent to initialize message state
|
||||
res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
logger.debug(
|
||||
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const result = await GenerationJobManager.subscribe(
|
||||
streamId,
|
||||
(event) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
}
|
||||
},
|
||||
(event) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
},
|
||||
(error) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
if (!result) {
|
||||
return res.status(404).json({ error: 'Failed to subscribe to stream' });
|
||||
}
|
||||
|
||||
req.on('close', () => {
|
||||
logger.debug(`[AgentStream] Client disconnected from ${streamId}`);
|
||||
result.unsubscribe();
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @route GET /chat/active
|
||||
* @desc Get all active generation job IDs for the current user
|
||||
* @access Private
|
||||
* @returns { activeJobIds: string[] }
|
||||
*/
|
||||
router.get('/chat/active', async (req, res) => {
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
|
||||
res.json({ activeJobIds });
|
||||
});
|
||||
|
||||
/**
|
||||
* @route GET /chat/status/:conversationId
|
||||
* @desc Check if there's an active generation job for a conversation
|
||||
* @access Private
|
||||
* @returns { active, streamId, status, aggregatedContent, createdAt, resumeState }
|
||||
*/
|
||||
router.get('/chat/status/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
|
||||
// streamId === conversationId, so we can use getJob directly
|
||||
const job = await GenerationJobManager.getJob(conversationId);
|
||||
|
||||
if (!job) {
|
||||
return res.json({ active: false });
|
||||
}
|
||||
|
||||
if (job.metadata.userId !== req.user.id) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
// Get resume state which contains aggregatedContent
|
||||
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
|
||||
const resumeState = await GenerationJobManager.getResumeState(conversationId);
|
||||
const isActive = job.status === 'running';
|
||||
|
||||
res.json({
|
||||
active: isActive,
|
||||
streamId: conversationId,
|
||||
status: job.status,
|
||||
aggregatedContent: resumeState?.aggregatedContent ?? [],
|
||||
createdAt: job.createdAt,
|
||||
resumeState,
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @route POST /chat/abort
|
||||
* @desc Abort an ongoing generation job
|
||||
* @access Private
|
||||
* @description Mounted before chatRouter to bypass buildEndpointOption middleware
|
||||
*/
|
||||
router.post('/chat/abort', async (req, res) => {
|
||||
logger.debug(`[AgentStream] ========== ABORT ENDPOINT HIT ==========`);
|
||||
logger.debug(`[AgentStream] Method: ${req.method}, Path: ${req.path}`);
|
||||
logger.debug(`[AgentStream] Body:`, req.body);
|
||||
|
||||
const { streamId, conversationId, abortKey } = req.body;
|
||||
const userId = req.user?.id;
|
||||
|
||||
// streamId === conversationId, so try any of the provided IDs
|
||||
// Skip "new" as it's a placeholder for new conversations, not an actual ID
|
||||
let jobStreamId =
|
||||
streamId || (conversationId !== 'new' ? conversationId : null) || abortKey?.split(':')[0];
|
||||
let job = jobStreamId ? await GenerationJobManager.getJob(jobStreamId) : null;
|
||||
|
||||
// Fallback: if job not found and we have a userId, look up active jobs for user
|
||||
// This handles the case where frontend sends "new" but job was created with a UUID
|
||||
if (!job && userId) {
|
||||
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
|
||||
if (activeJobIds.length > 0) {
|
||||
// Abort the most recent active job for this user
|
||||
jobStreamId = activeJobIds[0];
|
||||
job = await GenerationJobManager.getJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Found active job for user: ${jobStreamId}`);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(`[AgentStream] Computed jobStreamId: ${jobStreamId}`);
|
||||
|
||||
if (job && jobStreamId) {
|
||||
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
|
||||
await GenerationJobManager.abortJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`);
|
||||
return res.json({ success: true, aborted: jobStreamId });
|
||||
}
|
||||
|
||||
logger.warn(`[AgentStream] Job not found for streamId: ${jobStreamId}`);
|
||||
return res.status(404).json({ error: 'Job not found', streamId: jobStreamId });
|
||||
});
|
||||
|
||||
const chatRouter = express.Router();
|
||||
chatRouter.use(configMiddleware);
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
createImportLimiters,
|
||||
validateConvoAccess,
|
||||
createForkLimiters,
|
||||
configMiddleware,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -67,16 +68,17 @@ router.get('/:conversationId', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/gen_title', async (req, res) => {
|
||||
const { conversationId } = req.body;
|
||||
router.get('/gen_title/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${conversationId}`;
|
||||
let title = await titleCache.get(key);
|
||||
|
||||
if (!title) {
|
||||
// Retry every 1s for up to 20s
|
||||
for (let i = 0; i < 20; i++) {
|
||||
await sleep(1000);
|
||||
// Exponential backoff: 500ms, 1s, 2s, 4s, 8s (total ~15.5s max wait)
|
||||
const delays = [500, 1000, 2000, 4000, 8000];
|
||||
for (const delay of delays) {
|
||||
await sleep(delay);
|
||||
title = await titleCache.get(key);
|
||||
if (title) {
|
||||
break;
|
||||
|
|
@ -150,17 +152,39 @@ router.delete('/all', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/update', async (req, res) => {
|
||||
const update = req.body.arg;
|
||||
/** Maximum allowed length for conversation titles */
|
||||
const MAX_CONVO_TITLE_LENGTH = 1024;
|
||||
|
||||
if (!update.conversationId) {
|
||||
/**
|
||||
* Updates a conversation's title.
|
||||
* @route POST /update
|
||||
* @param {string} req.body.arg.conversationId - The conversation ID to update.
|
||||
* @param {string} req.body.arg.title - The new title for the conversation.
|
||||
* @returns {object} 201 - The updated conversation object.
|
||||
*/
|
||||
router.post('/update', validateConvoAccess, async (req, res) => {
|
||||
const { conversationId, title } = req.body.arg ?? {};
|
||||
|
||||
if (!conversationId) {
|
||||
return res.status(400).json({ error: 'conversationId is required' });
|
||||
}
|
||||
|
||||
if (title === undefined) {
|
||||
return res.status(400).json({ error: 'title is required' });
|
||||
}
|
||||
|
||||
if (typeof title !== 'string') {
|
||||
return res.status(400).json({ error: 'title must be a string' });
|
||||
}
|
||||
|
||||
const sanitizedTitle = title.trim().slice(0, MAX_CONVO_TITLE_LENGTH);
|
||||
|
||||
try {
|
||||
const dbResponse = await saveConvo(req, update, {
|
||||
context: `POST /api/convos/update ${update.conversationId}`,
|
||||
});
|
||||
const dbResponse = await saveConvo(
|
||||
req,
|
||||
{ conversationId, title: sanitizedTitle },
|
||||
{ context: `POST /api/convos/update ${conversationId}` },
|
||||
);
|
||||
res.status(201).json(dbResponse);
|
||||
} catch (error) {
|
||||
logger.error('Error updating conversation', error);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ const {
|
|||
createSafeUser,
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
getBasePath,
|
||||
getUserMCPAuthMap,
|
||||
generateCheckAccess,
|
||||
} = require('@librechat/api');
|
||||
|
|
@ -105,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
|||
* This handles the OAuth callback after the user has authorized the application
|
||||
*/
|
||||
router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
const basePath = getBasePath();
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const { code, state, error: oauthError } = req.query;
|
||||
|
|
@ -118,17 +120,19 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
|
||||
if (oauthError) {
|
||||
logger.error('[MCP OAuth] OAuth error received', { error: oauthError });
|
||||
return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`);
|
||||
return res.redirect(
|
||||
`${basePath}/oauth/error?error=${encodeURIComponent(String(oauthError))}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!code || typeof code !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid code');
|
||||
return res.redirect('/oauth/error?error=missing_code');
|
||||
return res.redirect(`${basePath}/oauth/error?error=missing_code`);
|
||||
}
|
||||
|
||||
if (!state || typeof state !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid state');
|
||||
return res.redirect('/oauth/error?error=missing_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
||||
}
|
||||
|
||||
const flowId = state;
|
||||
|
|
@ -142,7 +146,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
|
||||
if (!flowState) {
|
||||
logger.error('[MCP OAuth] Flow state not found for flowId:', flowId);
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Flow state details', {
|
||||
|
|
@ -160,7 +164,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
flowId,
|
||||
serverName,
|
||||
});
|
||||
return res.redirect(`/oauth/success?serverName=${encodeURIComponent(serverName)}`);
|
||||
return res.redirect(`${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
|
|
@ -254,11 +258,11 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
}
|
||||
|
||||
/** Redirect to success page with flowId and serverName */
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] OAuth callback error', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -588,7 +592,7 @@ async function getOAuthHeaders(serverName, userId) {
|
|||
return serverConfig?.oauth_headers ?? {};
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
MCP Server CRUD Routes (User-Managed MCP Servers)
|
||||
*/
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const express = require('express');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ContentTypes } = require('librechat-data-provider');
|
||||
const { unescapeLaTeX, countTokens } = require('@librechat/api');
|
||||
|
|
@ -111,6 +112,91 @@ router.get('/', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Creates a new branch message from a specific agent's content within a parallel response message.
|
||||
* Filters the original message's content to only include parts attributed to the specified agentId.
|
||||
* Only available for non-user messages with content attributions.
|
||||
*
|
||||
* @route POST /branch
|
||||
* @param {string} req.body.messageId - The ID of the source message
|
||||
* @param {string} req.body.agentId - The agentId to filter content by
|
||||
* @returns {TMessage} The newly created branch message
|
||||
*/
|
||||
router.post('/branch', async (req, res) => {
|
||||
try {
|
||||
const { messageId, agentId } = req.body;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!messageId || !agentId) {
|
||||
return res.status(400).json({ error: 'messageId and agentId are required' });
|
||||
}
|
||||
|
||||
const sourceMessage = await getMessage({ user: userId, messageId });
|
||||
if (!sourceMessage) {
|
||||
return res.status(404).json({ error: 'Source message not found' });
|
||||
}
|
||||
|
||||
if (sourceMessage.isCreatedByUser) {
|
||||
return res.status(400).json({ error: 'Cannot branch from user messages' });
|
||||
}
|
||||
|
||||
if (!Array.isArray(sourceMessage.content)) {
|
||||
return res.status(400).json({ error: 'Message does not have content' });
|
||||
}
|
||||
|
||||
const hasAgentMetadata = sourceMessage.content.some((part) => part?.agentId);
|
||||
if (!hasAgentMetadata) {
|
||||
return res
|
||||
.status(400)
|
||||
.json({ error: 'Message does not have parallel content with attributions' });
|
||||
}
|
||||
|
||||
/** @type {Array<import('librechat-data-provider').TMessageContentParts>} */
|
||||
const filteredContent = [];
|
||||
for (const part of sourceMessage.content) {
|
||||
if (part?.agentId === agentId) {
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = part;
|
||||
filteredContent.push(cleanPart);
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredContent.length === 0) {
|
||||
return res.status(400).json({ error: 'No content found for the specified agentId' });
|
||||
}
|
||||
|
||||
const newMessageId = uuidv4();
|
||||
/** @type {import('librechat-data-provider').TMessage} */
|
||||
const newMessage = {
|
||||
messageId: newMessageId,
|
||||
conversationId: sourceMessage.conversationId,
|
||||
parentMessageId: sourceMessage.parentMessageId,
|
||||
attachments: sourceMessage.attachments,
|
||||
isCreatedByUser: false,
|
||||
model: sourceMessage.model,
|
||||
endpoint: sourceMessage.endpoint,
|
||||
sender: sourceMessage.sender,
|
||||
iconURL: sourceMessage.iconURL,
|
||||
content: filteredContent,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
user: userId,
|
||||
};
|
||||
|
||||
const savedMessage = await saveMessage(req, newMessage, {
|
||||
context: 'POST /api/messages/branch',
|
||||
});
|
||||
|
||||
if (!savedMessage) {
|
||||
return res.status(500).json({ error: 'Failed to save branch message' });
|
||||
}
|
||||
|
||||
res.status(201).json(savedMessage);
|
||||
} catch (error) {
|
||||
logger.error('Error creating branch message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/artifact/:messageId', async (req, res) => {
|
||||
try {
|
||||
const { messageId } = req.params;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@ const { nanoid } = require('nanoid');
|
|||
const { tool } = require('@langchain/core/tools');
|
||||
const { GraphEvents, sleep } = require('@librechat/agents');
|
||||
const { logger, encryptV2, decryptV2 } = require('@librechat/data-schemas');
|
||||
const { sendEvent, logAxiosError, refreshAccessToken } = require('@librechat/api');
|
||||
const {
|
||||
sendEvent,
|
||||
logAxiosError,
|
||||
refreshAccessToken,
|
||||
GenerationJobManager,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -127,6 +132,7 @@ async function loadActionSets(searchParams) {
|
|||
* @param {string | undefined} [params.description] - The description for the tool.
|
||||
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
|
||||
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable streams.
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createActionTool({
|
||||
|
|
@ -138,6 +144,7 @@ async function createActionTool({
|
|||
name,
|
||||
description,
|
||||
encrypted,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolInput, config) => {
|
||||
|
|
@ -192,7 +199,12 @@ async function createActionTool({
|
|||
`${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
|
||||
'oauth_login',
|
||||
async () => {
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
logger.debug('Sent OAuth login request to client', { action_id, identifier });
|
||||
return true;
|
||||
},
|
||||
|
|
@ -217,7 +229,12 @@ async function createActionTool({
|
|||
logger.debug('Received OAuth Authorization response', { action_id, identifier });
|
||||
data.delta.auth = undefined;
|
||||
data.delta.expires_at = undefined;
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const successEventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, successEventData);
|
||||
} else {
|
||||
sendEvent(res, successEventData);
|
||||
}
|
||||
await sleep(3000);
|
||||
metadata.oauth_access_token = result.access_token;
|
||||
metadata.oauth_refresh_token = result.refresh_token;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, checkEmailConfig, isEmailDomainAllowed } = require('@librechat/api');
|
||||
const {
|
||||
logger,
|
||||
DEFAULT_SESSION_EXPIRY,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
} = require('@librechat/data-schemas');
|
||||
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api');
|
||||
const {
|
||||
findUser,
|
||||
findToken,
|
||||
|
|
@ -369,19 +373,21 @@ const setAuthTokens = async (userId, res, _session = null) => {
|
|||
let session = _session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
const expiresIn = math(process.env.REFRESH_TOKEN_EXPIRY, DEFAULT_REFRESH_TOKEN_EXPIRY);
|
||||
|
||||
if (session && session._id && session.expiration != null) {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
const result = await createSession(userId);
|
||||
const result = await createSession(userId, { expiresIn });
|
||||
session = result.session;
|
||||
refreshToken = result.refreshToken;
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
const sessionExpiry = math(process.env.SESSION_EXPIRY, DEFAULT_SESSION_EXPIRY);
|
||||
const token = await generateToken(user, sessionExpiry);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
|
|
@ -418,10 +424,10 @@ const setOpenIDAuthTokens = (tokenset, res, userId, existingRefreshToken) => {
|
|||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
return;
|
||||
}
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY
|
||||
? eval(REFRESH_TOKEN_EXPIRY)
|
||||
: 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
const expiryInMilliseconds = math(
|
||||
process.env.REFRESH_TOKEN_EXPIRY,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
);
|
||||
const expirationDate = new Date(Date.now() + expiryInMilliseconds);
|
||||
if (tokenset == null) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { CacheKeys, Time } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
|
|
@ -39,12 +39,12 @@ async function getCachedTools(options = {}) {
|
|||
* @param {Object} options - Options for caching tools
|
||||
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds (default: 12 hours)
|
||||
* @returns {Promise<boolean>} Whether the operation was successful
|
||||
*/
|
||||
async function setCachedTools(tools, options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { userId, serverName, ttl } = options;
|
||||
const { userId, serverName, ttl = Time.TWELVE_HOURS } = options;
|
||||
|
||||
// Cache by MCP server if specified (requires userId)
|
||||
if (serverName && userId) {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@ async function getEndpointsConfig(req) {
|
|||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
if (cachedEndpointsConfig) {
|
||||
return cachedEndpointsConfig;
|
||||
if (cachedEndpointsConfig.gptPlugins) {
|
||||
await cache.delete(CacheKeys.ENDPOINT_CONFIG);
|
||||
} else {
|
||||
return cachedEndpointsConfig;
|
||||
}
|
||||
}
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
|
|
|
|||
136
api/server/services/Endpoints/agents/addedConvo.js
Normal file
136
api/server/services/Endpoints/agents/addedConvo.js
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { initializeAgent, validateAgentModel } = require('@librechat/api');
|
||||
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
// Initialize the getAgent dependency
|
||||
setGetAgent(getAgent);
|
||||
|
||||
/**
|
||||
* Process addedConvo for parallel agent execution.
|
||||
* Creates a parallel agent config from an added conversation.
|
||||
*
|
||||
* When an added agent has no incoming edges, it becomes a start node
|
||||
* and runs in parallel with the primary agent automatically.
|
||||
*
|
||||
* Edge cases handled:
|
||||
* - Primary agent has edges (handoffs): Added agent runs in parallel with primary,
|
||||
* but doesn't participate in the primary's handoff graph
|
||||
* - Primary agent has agent_ids (legacy chain): Added agent runs in parallel with primary,
|
||||
* but doesn't participate in the chain
|
||||
* - Primary agent has both: Added agent is independent, runs parallel from start
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req
|
||||
* @param {import('express').Response} params.res
|
||||
* @param {Object} params.endpointOption - The endpoint option containing addedConvo
|
||||
* @param {Object} params.modelsConfig - The models configuration
|
||||
* @param {Function} params.logViolation - Function to log violations
|
||||
* @param {Function} params.loadTools - Function to load agent tools
|
||||
* @param {Array} params.requestFiles - Request files
|
||||
* @param {string} params.conversationId - The conversation ID
|
||||
* @param {Set} params.allowedProviders - Set of allowed providers
|
||||
* @param {Map} params.agentConfigs - Map of agent configs to add to
|
||||
* @param {string} params.primaryAgentId - The primary agent ID
|
||||
* @param {Object|undefined} params.userMCPAuthMap - User MCP auth map to merge into
|
||||
* @returns {Promise<{userMCPAuthMap: Object|undefined}>} The updated userMCPAuthMap
|
||||
*/
|
||||
const processAddedConvo = async ({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
allowedProviders,
|
||||
agentConfigs,
|
||||
primaryAgentId,
|
||||
primaryAgent,
|
||||
userMCPAuthMap,
|
||||
}) => {
|
||||
const addedConvo = endpointOption.addedConvo;
|
||||
logger.debug('[processAddedConvo] Called with addedConvo:', {
|
||||
hasAddedConvo: addedConvo != null,
|
||||
addedConvoEndpoint: addedConvo?.endpoint,
|
||||
addedConvoModel: addedConvo?.model,
|
||||
addedConvoAgentId: addedConvo?.agent_id,
|
||||
});
|
||||
if (addedConvo == null) {
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
try {
|
||||
const addedAgent = await loadAddedAgent({ req, conversation: addedConvo, primaryAgent });
|
||||
if (!addedAgent) {
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
const addedValidation = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
agent: addedAgent,
|
||||
});
|
||||
|
||||
if (!addedValidation.isValid) {
|
||||
logger.warn(
|
||||
`[processAddedConvo] Added agent validation failed: ${addedValidation.error?.message}`,
|
||||
);
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
const addedConfig = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
agent: addedAgent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
},
|
||||
{
|
||||
getConvoFiles,
|
||||
getFiles: db.getFiles,
|
||||
getUserKey: db.getUserKey,
|
||||
updateFilesUsage: db.updateFilesUsage,
|
||||
getUserKeyValues: db.getUserKeyValues,
|
||||
getToolFilesByIds: db.getToolFilesByIds,
|
||||
},
|
||||
);
|
||||
|
||||
if (userMCPAuthMap != null) {
|
||||
Object.assign(userMCPAuthMap, addedConfig.userMCPAuthMap ?? {});
|
||||
} else {
|
||||
userMCPAuthMap = addedConfig.userMCPAuthMap;
|
||||
}
|
||||
|
||||
const addedAgentId = addedConfig.id || ADDED_AGENT_ID;
|
||||
agentConfigs.set(addedAgentId, addedConfig);
|
||||
|
||||
// No edges needed - agent without incoming edges becomes a start node
|
||||
// and runs in parallel with the primary agent automatically.
|
||||
// This is independent of any edges/agent_ids the primary agent has.
|
||||
|
||||
logger.debug(
|
||||
`[processAddedConvo] Added parallel agent: ${addedAgentId} (primary: ${primaryAgentId}, ` +
|
||||
`primary has edges: ${!!endpointOption.edges}, primary has agent_ids: ${!!endpointOption.agent_ids})`,
|
||||
);
|
||||
|
||||
return { userMCPAuthMap };
|
||||
} catch (err) {
|
||||
logger.error('[processAddedConvo] Error processing addedConvo for parallel agent', err);
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
processAddedConvo,
|
||||
ADDED_AGENT_ID,
|
||||
};
|
||||
|
|
@ -15,6 +15,9 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
|||
return undefined;
|
||||
});
|
||||
|
||||
/** @type {import('librechat-data-provider').TConversation | undefined} */
|
||||
const addedConvo = req.body?.addedConvo;
|
||||
|
||||
return removeNullishValues({
|
||||
spec,
|
||||
iconURL,
|
||||
|
|
@ -23,6 +26,7 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
|||
endpointType,
|
||||
model_parameters,
|
||||
agent: agentPromise,
|
||||
addedConvo,
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ const {
|
|||
createSequentialChainEdges,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
getResponseSender,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
createToolEndCallback,
|
||||
|
|
@ -20,14 +20,17 @@ const { getModelsConfig } = require('~/server/controllers/ModelController');
|
|||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { processAddedConvo } = require('./addedConvo');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { logViolation } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
* @param {AbortSignal} signal
|
||||
* Creates a tool loader function for the agent.
|
||||
* @param {AbortSignal} signal - The abort signal
|
||||
* @param {string | null} [streamId] - The stream ID for resumable mode
|
||||
*/
|
||||
function createToolLoader(signal) {
|
||||
function createToolLoader(signal, streamId = null) {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
|
|
@ -52,6 +55,7 @@ function createToolLoader(signal) {
|
|||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
streamId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading tools for agent ' + agentId, error);
|
||||
|
|
@ -65,18 +69,21 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
}
|
||||
const appConfig = req.config;
|
||||
|
||||
// TODO: use endpointOption to determine options/modelOptions
|
||||
/** @type {string | null} */
|
||||
const streamId = req._resumableStreamId || null;
|
||||
|
||||
/** @type {Array<UsageMetadata>} */
|
||||
const collectedUsage = [];
|
||||
/** @type {ArtifactPromises} */
|
||||
const artifactPromises = [];
|
||||
const { contentParts, aggregateContent } = createContentAggregator();
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId });
|
||||
const eventHandlers = getDefaultHandlers({
|
||||
res,
|
||||
aggregateContent,
|
||||
toolEndCallback,
|
||||
collectedUsage,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (!endpointOption.agent) {
|
||||
|
|
@ -105,7 +112,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
const agentConfigs = new Map();
|
||||
const allowedProviders = new Set(appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
const loadTools = createToolLoader(signal);
|
||||
const loadTools = createToolLoader(signal, streamId);
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
/** @type {string} */
|
||||
|
|
@ -227,6 +234,33 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
edges = edges ? edges.concat(chain) : chain;
|
||||
}
|
||||
|
||||
/** Multi-Convo: Process addedConvo for parallel agent execution */
|
||||
const { userMCPAuthMap: updatedMCPAuthMap } = await processAddedConvo({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
allowedProviders,
|
||||
agentConfigs,
|
||||
primaryAgentId: primaryConfig.id,
|
||||
primaryAgent,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
if (updatedMCPAuthMap) {
|
||||
userMCPAuthMap = updatedMCPAuthMap;
|
||||
}
|
||||
|
||||
// Ensure edges is an array when we have multiple agents (multi-agent mode)
|
||||
// MultiAgentGraph.categorizeEdges requires edges to be iterable
|
||||
if (agentConfigs.size > 0 && !edges) {
|
||||
edges = [];
|
||||
}
|
||||
|
||||
primaryConfig.edges = edges;
|
||||
|
||||
let endpointConfig = appConfig.endpoints?.[primaryConfig.endpoint];
|
||||
|
|
@ -270,10 +304,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
endpointType: endpointOption.endpointType,
|
||||
resendFiles: primaryConfig.resendFiles ?? true,
|
||||
maxContextTokens: primaryConfig.maxContextTokens,
|
||||
endpoint:
|
||||
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
|
||||
? primaryConfig.endpoint
|
||||
: EModelEndpoint.agents,
|
||||
endpoint: isEphemeralAgentId(primaryConfig.id) ? primaryConfig.endpoint : EModelEndpoint.agents,
|
||||
});
|
||||
|
||||
return { client, userMCPAuthMap };
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
|||
const { getAssistant } = require('~/models/Assistant');
|
||||
|
||||
const buildOptions = async (endpoint, parsedBody) => {
|
||||
|
||||
const { promptPrefix, assistant_id, iconURL, greeting, spec, artifacts, ...modelOptions } =
|
||||
parsedBody;
|
||||
const endpointOption = removeNullishValues({
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ const {
|
|||
const {
|
||||
sendEvent,
|
||||
MCPOAuthHandler,
|
||||
isMCPDomainAllowed,
|
||||
normalizeServerName,
|
||||
convertWithResolvedRefs,
|
||||
GenerationJobManager,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
|
|
@ -21,13 +23,14 @@ const {
|
|||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getMCPManager,
|
||||
getFlowStateManager,
|
||||
getOAuthReconnectionManager,
|
||||
getMCPServersRegistry,
|
||||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
|
|
@ -35,8 +38,9 @@ const { getLogStores } = require('~/cache');
|
|||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
||||
function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) {
|
||||
/**
|
||||
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
||||
* @returns {void}
|
||||
|
|
@ -52,7 +56,12 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
|||
expires_at: Date.now() + Time.TWO_MINUTES,
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -63,8 +72,9 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
|||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {number} [params.index]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
||||
function createRunStepEmitter({ res, runId, stepId, toolCall, index, streamId = null }) {
|
||||
return function () {
|
||||
/** @type {import('@librechat/agents').RunStep} */
|
||||
const data = {
|
||||
|
|
@ -77,7 +87,12 @@ function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
|||
tool_calls: [toolCall],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -108,10 +123,9 @@ function createOAuthStart({ flowId, flowManager, callback }) {
|
|||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string} params.loginFlowId - The ID of the login flow.
|
||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createOAuthEnd({ res, stepId, toolCall }) {
|
||||
function createOAuthEnd({ res, stepId, toolCall, streamId = null }) {
|
||||
return async function () {
|
||||
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
||||
const data = {
|
||||
|
|
@ -121,7 +135,12 @@ function createOAuthEnd({ res, stepId, toolCall }) {
|
|||
tool_calls: [{ ...toolCall }],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
logger.debug('Sent OAuth login success to client');
|
||||
};
|
||||
}
|
||||
|
|
@ -137,7 +156,9 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
|
|||
return function () {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
// Clean up both mcp_oauth and mcp_get_tokens flows
|
||||
flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted'));
|
||||
flowManager.failFlow(flowId, 'mcp_get_tokens', new Error('Tool call aborted'));
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -162,10 +183,19 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
|||
* @param {AbortSignal} params.signal
|
||||
* @param {string} params.model
|
||||
* @param {number} [params.index]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
|
||||
async function reconnectServer({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
|
|
@ -176,36 +206,60 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
const runStepEmitter = createRunStepEmitter({
|
||||
res,
|
||||
index,
|
||||
runId,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
flowId,
|
||||
callback,
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
user,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
// Set up abort handler to clean up OAuth flows if request is aborted
|
||||
const oauthFlowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
const abortHandler = () => {
|
||||
logger.info(
|
||||
`[MCP][User: ${user.id}][${serverName}] Tool loading aborted, cleaning up OAuth flows`,
|
||||
);
|
||||
// Clean up both mcp_oauth and mcp_get_tokens flows
|
||||
flowManager.failFlow(oauthFlowId, 'mcp_oauth', new Error('Tool loading aborted'));
|
||||
flowManager.failFlow(oauthFlowId, 'mcp_get_tokens', new Error('Tool loading aborted'));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
|
||||
try {
|
||||
const runStepEmitter = createRunStepEmitter({
|
||||
res,
|
||||
index,
|
||||
runId,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
flowId,
|
||||
callback,
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
user,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
} finally {
|
||||
// Clean up abort handler to prevent memory leaks
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -222,11 +276,45 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
||||
async function createMCPTools({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
config,
|
||||
provider,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
const result = await reconnectServer({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
return;
|
||||
|
|
@ -239,8 +327,10 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
user,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
availableTools: result.availableTools,
|
||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||
config: serverConfig,
|
||||
});
|
||||
if (toolInstance) {
|
||||
serverTools.push(toolInstance);
|
||||
|
|
@ -259,9 +349,11 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
* @param {string} params.model - The model for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {LCAvailableTools} [params.availableTools]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({
|
||||
|
|
@ -273,9 +365,25 @@ async function createMCPTool({
|
|||
provider,
|
||||
userMCPAuthMap,
|
||||
availableTools,
|
||||
config,
|
||||
streamId = null,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
// Runtime domain validation: check if the server's domain is still allowed
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
|
|
@ -289,6 +397,7 @@ async function createMCPTool({
|
|||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||
}
|
||||
|
|
@ -304,10 +413,18 @@ async function createMCPTool({
|
|||
toolName,
|
||||
serverName,
|
||||
toolDefinition,
|
||||
streamId,
|
||||
});
|
||||
}
|
||||
|
||||
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
|
||||
function createToolInstance({
|
||||
res,
|
||||
toolName,
|
||||
serverName,
|
||||
toolDefinition,
|
||||
provider: _provider,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {LCTool} */
|
||||
const { description, parameters } = toolDefinition;
|
||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||
|
|
@ -343,6 +460,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
|
|||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const oauthStart = createOAuthStart({
|
||||
flowId,
|
||||
|
|
@ -353,6 +471,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
|
|||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (derivedSignal) {
|
||||
|
|
@ -448,7 +567,10 @@ async function getMCPSetupData(userId) {
|
|||
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
|
||||
let appConnections = new Map();
|
||||
try {
|
||||
appConnections = (await mcpManager.appConnections?.getAll()) || new Map();
|
||||
// Use getLoaded() instead of getAll() to avoid forcing connection creation
|
||||
// getAll() creates connections for all servers, which is problematic for servers
|
||||
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders)
|
||||
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map();
|
||||
} catch (error) {
|
||||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,4 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
// Mock all dependencies - define mocks before imports
|
||||
// Mock all dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
|
|||
},
|
||||
}));
|
||||
|
||||
// Create mock registry instance
|
||||
const mockRegistryInstance = {
|
||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
|
||||
|
||||
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
|
||||
|
||||
jest.mock('@librechat/api', () => {
|
||||
// Access mock via getter to avoid hoisting issues
|
||||
return {
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
get isMCPDomainAllowed() {
|
||||
return mockIsMCPDomainAllowed;
|
||||
},
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
CacheKeys: {
|
||||
|
|
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
|
|||
|
||||
jest.mock('./Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
getAppConfig: jest.fn(),
|
||||
get getAppConfig() {
|
||||
return mockGetAppConfig;
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
|
|
@ -128,7 +144,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
|
||||
beforeEach(() => {
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
appConnections: { getAll: jest.fn(() => new Map()) },
|
||||
appConnections: { getLoaded: jest.fn(() => new Map()) },
|
||||
getUserConnections: jest.fn(() => new Map()),
|
||||
});
|
||||
mockRegistryInstance.getOAuthServers.mockResolvedValue(new Set());
|
||||
|
|
@ -143,7 +159,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
const mockOAuthServers = new Set(['server2']);
|
||||
|
||||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => Promise.resolve(mockAppConnections)) },
|
||||
appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
|
@ -153,7 +169,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
|
||||
expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId);
|
||||
|
||||
|
|
@ -174,7 +190,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
|
||||
|
||||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => Promise.resolve(null)) },
|
||||
appConnections: { getLoaded: jest.fn(() => Promise.resolve(null)) },
|
||||
getUserConnections: jest.fn(() => null),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
|
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
|
|||
createFlowWithHandler: jest.fn(),
|
||||
failFlow: jest.fn(),
|
||||
});
|
||||
|
||||
// Reset domain validation mock to default (allow all)
|
||||
mockIsMCPDomainAllowed.mockReset();
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
// Reset registry mocks
|
||||
mockRegistryInstance.getServerConfig.mockReset();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||
|
||||
// Reset getAppConfig mock to default (no restrictions)
|
||||
mockGetAppConfig.mockReset();
|
||||
mockGetAppConfig.mockResolvedValue({});
|
||||
});
|
||||
|
||||
describe('createMCPTools', () => {
|
||||
|
|
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('Runtime domain validation', () => {
|
||||
it('should skip tool creation when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://disallowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Should return undefined for disallowed domain
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
|
||||
// Verify domain validation was called with correct parameters
|
||||
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
|
||||
{ url: 'https://disallowed-domain.com/sse' },
|
||||
['allowed-domain.com'],
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow tool creation when domain is allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'admin' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://allowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return true (domain allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
|
||||
});
|
||||
|
||||
it('should skip domain validation for stdio transports (no URL)', async () => {
|
||||
const mockUser = { id: 'stdio-test-user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config without URL (stdio transport)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
command: 'npx',
|
||||
args: ['@modelcontextprotocol/server'],
|
||||
});
|
||||
|
||||
// Mock getAppConfig (should not be called for stdio)
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
|
||||
});
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully without domain check
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
|
||||
expect(mockGetAppConfig).not.toHaveBeenCalled();
|
||||
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return empty array from createMCPTools when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTools({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
serverName: 'test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
config: serverConfig,
|
||||
});
|
||||
|
||||
// Should return empty array for disallowed domain
|
||||
expect(result).toEqual([]);
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed early
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
});
|
||||
|
||||
it('should use user role when fetching domain restrictions', async () => {
|
||||
const adminUser = { id: 'admin-user', role: 'admin' };
|
||||
const regularUser = { id: 'regular-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock different responses based on role
|
||||
mockGetAppConfig
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
|
||||
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Call with admin user
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: adminUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Reset and call with regular user
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: regularUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Verify getAppConfig was called with correct roles
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('User parameter integrity', () => {
|
||||
it('should preserve user object properties through the call chain', async () => {
|
||||
const complexUser = {
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ describe('processMessages', () => {
|
|||
type: 'text',
|
||||
text: {
|
||||
value:
|
||||
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
|
||||
"The text you have uploaded is from the book \"Harry Potter and the Philosopher's Stone\" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.",
|
||||
annotations: [
|
||||
{
|
||||
type: 'file_citation',
|
||||
|
|
@ -424,7 +424,7 @@ These points highlight Harry's initial experiences in the magical world and set
|
|||
type: 'text',
|
||||
text: {
|
||||
value:
|
||||
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
|
||||
"The text you have uploaded is from the book \"Harry Potter and the Philosopher's Stone\" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.",
|
||||
annotations: [
|
||||
{
|
||||
type: 'file_citation',
|
||||
|
|
@ -582,7 +582,7 @@ These points highlight Harry's initial experiences in the magical world and set
|
|||
type: 'text',
|
||||
text: {
|
||||
value:
|
||||
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation【11:2†source】.',
|
||||
"This is a test ^1^ with pre-existing citation-like text. Here's a real citation【11:2†source】.",
|
||||
annotations: [
|
||||
{
|
||||
type: 'file_citation',
|
||||
|
|
@ -610,7 +610,7 @@ These points highlight Harry's initial experiences in the magical world and set
|
|||
});
|
||||
|
||||
const expectedText =
|
||||
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation^1^.\n\n^1.^ test.txt';
|
||||
"This is a test ^1^ with pre-existing citation-like text. Here's a real citation^1^.\n\n^1.^ test.txt";
|
||||
|
||||
expect(result.text).toBe(expectedText);
|
||||
expect(result.edited).toBe(true);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ const {
|
|||
} = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
ErrorTypes,
|
||||
ContentTypes,
|
||||
imageGenTools,
|
||||
|
|
@ -18,6 +17,7 @@ const {
|
|||
ImageVisionTool,
|
||||
openapiToFunction,
|
||||
AgentCapabilities,
|
||||
isEphemeralAgentId,
|
||||
validateActionDomain,
|
||||
defaultAgentCapabilities,
|
||||
validateAndParseOpenAPISpec,
|
||||
|
|
@ -369,7 +369,15 @@ async function processRequiredActions(client, requiredActions) {
|
|||
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
|
||||
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
|
||||
*/
|
||||
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
|
||||
async function loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
openAIApiKey,
|
||||
streamId = null,
|
||||
}) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return {};
|
||||
} else if (
|
||||
|
|
@ -385,7 +393,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
|
||||
/** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */
|
||||
if (enabledCapabilities.size === 0 && agent.id === Constants.EPHEMERAL_AGENT_ID) {
|
||||
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
|
||||
enabledCapabilities = new Set(
|
||||
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
|
||||
);
|
||||
|
|
@ -422,7 +430,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
/** @type {ReturnType<typeof createOnSearchResults>} */
|
||||
let webSearchCallbacks;
|
||||
if (includesWebSearch) {
|
||||
webSearchCallbacks = createOnSearchResults(res);
|
||||
webSearchCallbacks = createOnSearchResults(res, streamId);
|
||||
}
|
||||
|
||||
/** @type {Record<string, Record<string, string>>} */
|
||||
|
|
@ -622,6 +630,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
encrypted,
|
||||
name: toolName,
|
||||
description: functionSig.description,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (!tool) {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,29 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { Tools } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { GenerationJobManager } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Helper to write attachment events either to res or to job emitter.
|
||||
* @param {import('http').ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} attachment - The attachment data
|
||||
*/
|
||||
function writeAttachment(res, streamId, attachment) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
|
||||
} else {
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a function to handle search results and stream them as attachments
|
||||
* @param {import('http').ServerResponse} res - The HTTP server response object
|
||||
* @param {string | null} [streamId] - The stream ID for resumable mode, or null for standard mode
|
||||
* @returns {{ onSearchResults: function(SearchResult, GraphRunnableConfig): void; onGetHighlights: function(string): void}} - Function that takes search results and returns or streams an attachment
|
||||
*/
|
||||
function createOnSearchResults(res) {
|
||||
function createOnSearchResults(res, streamId = null) {
|
||||
const context = {
|
||||
sourceMap: new Map(),
|
||||
searchResultData: undefined,
|
||||
|
|
@ -70,7 +86,7 @@ function createOnSearchResults(res) {
|
|||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -92,7 +108,7 @@ function createOnSearchResults(res) {
|
|||
}
|
||||
|
||||
const attachment = buildAttachment(context);
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ async function initializeMCPs() {
|
|||
}
|
||||
|
||||
// Initialize MCPServersRegistry first (required for MCPManager)
|
||||
// Pass allowedDomains from mcpSettings for domain validation
|
||||
try {
|
||||
createMCPServersRegistry(mongoose);
|
||||
createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains);
|
||||
} catch (error) {
|
||||
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
|
||||
throw error;
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ const jwksRsa = require('jwks-rsa');
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { isEnabled, findOpenIDUser, math } = require('@librechat/api');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { isEnabled, findOpenIDUser } = require('@librechat/api');
|
||||
const { updateUser, findUser } = require('~/models');
|
||||
|
||||
/**
|
||||
|
|
@ -27,9 +27,7 @@ const { updateUser, findUser } = require('~/models');
|
|||
const openIdJwtLogin = (openIdConfig) => {
|
||||
let jwksRsaOptions = {
|
||||
cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true,
|
||||
cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME
|
||||
? eval(process.env.OPENID_JWKS_URL_CACHE_TIME)
|
||||
: 60000,
|
||||
cacheMaxAge: math(process.env.OPENID_JWKS_URL_CACHE_TIME, 60000),
|
||||
jwksUri: openIdConfig.serverMetadata().jwks_uri,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -11,3 +11,7 @@ OPENAI_API_KEY=your-api-key
|
|||
BAN_VIOLATIONS=true
|
||||
BAN_DURATION=7200000
|
||||
BAN_INTERVAL=20
|
||||
|
||||
# NODE_MAX_OLD_SPACE_SIZE is only used as a Docker build argument.
|
||||
# Node.js does NOT recognize this environment variable for heap size.
|
||||
NODE_MAX_OLD_SPACE_SIZE=6144
|
||||
|
|
|
|||
162
api/test/app/clients/tools/structured/OpenAIImageTools.test.js
Normal file
162
api/test/app/clients/tools/structured/OpenAIImageTools.test.js
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
const OpenAI = require('openai');
|
||||
const createOpenAIImageTools = require('~/app/clients/tools/structured/OpenAIImageTools');
|
||||
|
||||
jest.mock('openai');
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
logAxiosError: jest.fn(),
|
||||
oaiToolkit: {
|
||||
image_gen_oai: {
|
||||
name: 'image_gen_oai',
|
||||
description: 'Generate an image',
|
||||
schema: {},
|
||||
},
|
||||
image_edit_oai: {
|
||||
name: 'image_edit_oai',
|
||||
description: 'Edit an image',
|
||||
schema: {},
|
||||
},
|
||||
},
|
||||
extractBaseURL: jest.fn((url) => url),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getFiles: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
describe('OpenAIImageTools - IMAGE_GEN_OAI_MODEL environment variable', () => {
|
||||
let originalEnv;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
originalEnv = { ...process.env };
|
||||
|
||||
process.env.IMAGE_GEN_OAI_API_KEY = 'test-api-key';
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should use default model "gpt-image-1" when IMAGE_GEN_OAI_MODEL is not set', async () => {
|
||||
delete process.env.IMAGE_GEN_OAI_MODEL;
|
||||
|
||||
const [imageGenTool] = createOpenAIImageTools({
|
||||
isAgent: true,
|
||||
override: false,
|
||||
req: { user: { id: 'test-user' } },
|
||||
});
|
||||
|
||||
const mockGenerate = jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: mockGenerate,
|
||||
},
|
||||
}));
|
||||
|
||||
await imageGenTool.func({ prompt: 'test prompt' });
|
||||
|
||||
expect(mockGenerate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gpt-image-1',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use "gpt-image-1.5" when IMAGE_GEN_OAI_MODEL is set to "gpt-image-1.5"', async () => {
|
||||
process.env.IMAGE_GEN_OAI_MODEL = 'gpt-image-1.5';
|
||||
|
||||
const mockGenerate = jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: mockGenerate,
|
||||
},
|
||||
}));
|
||||
|
||||
const [imageGenTool] = createOpenAIImageTools({
|
||||
isAgent: true,
|
||||
override: false,
|
||||
req: { user: { id: 'test-user' } },
|
||||
});
|
||||
|
||||
await imageGenTool.func({ prompt: 'test prompt' });
|
||||
|
||||
expect(mockGenerate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gpt-image-1.5',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use custom model name from IMAGE_GEN_OAI_MODEL environment variable', async () => {
|
||||
process.env.IMAGE_GEN_OAI_MODEL = 'custom-image-model';
|
||||
|
||||
const mockGenerate = jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
b64_json: 'base64-encoded-image-data',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: mockGenerate,
|
||||
},
|
||||
}));
|
||||
|
||||
const [imageGenTool] = createOpenAIImageTools({
|
||||
isAgent: true,
|
||||
override: false,
|
||||
req: { user: { id: 'test-user' } },
|
||||
});
|
||||
|
||||
await imageGenTool.func({ prompt: 'test prompt' });
|
||||
|
||||
expect(mockGenerate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'custom-image-model',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
/** v0.8.1 */
|
||||
/** v0.8.2-rc1 */
|
||||
module.exports = {
|
||||
roots: ['<rootDir>/src'],
|
||||
testEnvironment: 'jsdom',
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@librechat/frontend",
|
||||
"version": "v0.8.1",
|
||||
"version": "v0.8.2-rc1",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
"dev": "cross-env NODE_ENV=development vite",
|
||||
"preview-prod": "cross-env NODE_ENV=development vite preview",
|
||||
"test": "cross-env NODE_ENV=development jest --watch",
|
||||
"test:ci": "cross-env NODE_ENV=development jest --ci",
|
||||
"test:ci": "cross-env NODE_ENV=development jest --ci --logHeapUsage",
|
||||
"b:test": "NODE_ENV=test bunx jest --watch",
|
||||
"b:build": "NODE_ENV=production bun --bun vite build",
|
||||
"b:dev": "NODE_ENV=development bunx vite"
|
||||
|
|
@ -39,10 +39,10 @@
|
|||
"@marsidev/react-turnstile": "^1.1.0",
|
||||
"@mcp-ui/client": "^5.7.0",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.1.15",
|
||||
"@radix-ui/react-alert-dialog": "1.0.2",
|
||||
"@radix-ui/react-checkbox": "^1.0.3",
|
||||
"@radix-ui/react-collapsible": "^1.0.3",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dialog": "1.0.2",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.1",
|
||||
"@radix-ui/react-hover-card": "^1.0.5",
|
||||
"@radix-ui/react-icons": "^1.3.0",
|
||||
|
|
@ -80,6 +80,7 @@
|
|||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.394.0",
|
||||
"match-sorter": "^8.1.0",
|
||||
"mermaid": "^11.12.2",
|
||||
"micromark-extension-llm-math": "^3.1.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"rc-input-number": "^7.4.2",
|
||||
|
|
@ -109,9 +110,11 @@
|
|||
"remark-math": "^6.0.0",
|
||||
"remark-supersub": "^1.0.0",
|
||||
"sse.js": "^2.5.0",
|
||||
"swr": "^2.3.8",
|
||||
"tailwind-merge": "^1.9.1",
|
||||
"tailwindcss-animate": "^1.0.5",
|
||||
"tailwindcss-radix": "^2.8.0",
|
||||
"ts-md5": "^1.3.1",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
|
|
|||
12
client/src/@types/i18next.d.ts
vendored
12
client/src/@types/i18next.d.ts
vendored
|
|
@ -1,9 +1,9 @@
|
|||
import { defaultNS, resources } from '~/locales/i18n';
|
||||
|
||||
declare module 'i18next' {
|
||||
interface CustomTypeOptions {
|
||||
defaultNS: typeof defaultNS;
|
||||
resources: typeof resources.en;
|
||||
strictKeyChecks: true
|
||||
}
|
||||
}
|
||||
interface CustomTypeOptions {
|
||||
defaultNS: typeof defaultNS;
|
||||
resources: typeof resources.en;
|
||||
strictKeyChecks: true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,16 @@ const App = () => {
|
|||
const { setError } = useApiErrorBoundary();
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
// Always attempt network requests, even when navigator.onLine is false
|
||||
// This is needed because localhost is reachable without WiFi
|
||||
networkMode: 'always',
|
||||
},
|
||||
mutations: {
|
||||
networkMode: 'always',
|
||||
},
|
||||
},
|
||||
queryCache: new QueryCache({
|
||||
onError: (error) => {
|
||||
if (error?.response?.status === 401) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
import useAddedResponse from '~/hooks/Chat/useAddedResponse';
|
||||
type TAddedChatContext = ReturnType<typeof useAddedResponse>;
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
import type { ConvoGenerator } from '~/common';
|
||||
|
||||
type TAddedChatContext = {
|
||||
conversation: TConversation | null;
|
||||
setConversation: SetterOrUpdater<TConversation | null>;
|
||||
generateConversation: ConvoGenerator;
|
||||
};
|
||||
|
||||
export const AddedChatContext = createContext<TAddedChatContext>({} as TAddedChatContext);
|
||||
export const useAddedChatContext = () => useContext(AddedChatContext);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { getEndpointField } from 'librechat-data-provider';
|
||||
import { getEndpointField, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { EModelEndpoint } from 'librechat-data-provider';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from './AgentsMapContext';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface DragDropContextValue {
|
||||
|
|
@ -9,6 +10,7 @@ interface DragDropContextValue {
|
|||
agentId: string | null | undefined;
|
||||
endpoint: string | null | undefined;
|
||||
endpointType?: EModelEndpoint | undefined;
|
||||
useResponsesApi?: boolean;
|
||||
}
|
||||
|
||||
const DragDropContext = createContext<DragDropContextValue | undefined>(undefined);
|
||||
|
|
@ -16,6 +18,7 @@ const DragDropContext = createContext<DragDropContextValue | undefined>(undefine
|
|||
export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
||||
const { conversation } = useChatContext();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
|
||||
const endpointType = useMemo(() => {
|
||||
return (
|
||||
|
|
@ -24,6 +27,34 @@ export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
|||
);
|
||||
}, [conversation?.endpoint, endpointsConfig]);
|
||||
|
||||
const needsAgentFetch = useMemo(() => {
|
||||
const isAgents = isAgentsEndpoint(conversation?.endpoint);
|
||||
if (!isAgents || !conversation?.agent_id) {
|
||||
return false;
|
||||
}
|
||||
const agent = agentsMap?.[conversation.agent_id];
|
||||
return !agent?.model_parameters;
|
||||
}, [conversation?.endpoint, conversation?.agent_id, agentsMap]);
|
||||
|
||||
const { data: agentData } = useGetAgentByIdQuery(conversation?.agent_id, {
|
||||
enabled: needsAgentFetch,
|
||||
});
|
||||
|
||||
const useResponsesApi = useMemo(() => {
|
||||
const isAgents = isAgentsEndpoint(conversation?.endpoint);
|
||||
if (!isAgents || !conversation?.agent_id || conversation?.useResponsesApi) {
|
||||
return conversation?.useResponsesApi;
|
||||
}
|
||||
const agent = agentData || agentsMap?.[conversation.agent_id];
|
||||
return agent?.model_parameters?.useResponsesApi;
|
||||
}, [
|
||||
conversation?.endpoint,
|
||||
conversation?.agent_id,
|
||||
conversation?.useResponsesApi,
|
||||
agentData,
|
||||
agentsMap,
|
||||
]);
|
||||
|
||||
/** Context value only created when conversation fields change */
|
||||
const contextValue = useMemo<DragDropContextValue>(
|
||||
() => ({
|
||||
|
|
@ -31,8 +62,15 @@ export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
|||
agentId: conversation?.agent_id,
|
||||
endpoint: conversation?.endpoint,
|
||||
endpointType: endpointType,
|
||||
useResponsesApi: useResponsesApi,
|
||||
}),
|
||||
[conversation?.conversationId, conversation?.agent_id, conversation?.endpoint, endpointType],
|
||||
[
|
||||
conversation?.conversationId,
|
||||
conversation?.agent_id,
|
||||
conversation?.endpoint,
|
||||
useResponsesApi,
|
||||
endpointType,
|
||||
],
|
||||
);
|
||||
|
||||
return <DragDropContext.Provider value={contextValue}>{children}</DragDropContext.Provider>;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { useAddedChatContext } from './AddedChatContext';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface MessagesViewContextValue {
|
||||
|
|
@ -9,7 +8,6 @@ interface MessagesViewContextValue {
|
|||
|
||||
/** Submission and control states */
|
||||
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
|
||||
isSubmittingFamily: boolean;
|
||||
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
|
||||
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
|
||||
|
||||
|
|
@ -34,13 +32,12 @@ export type { MessagesViewContextValue };
|
|||
|
||||
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
|
||||
const chatContext = useChatContext();
|
||||
const addedChatContext = useAddedChatContext();
|
||||
|
||||
const {
|
||||
ask,
|
||||
index,
|
||||
regenerate,
|
||||
isSubmitting: isSubmittingRoot,
|
||||
isSubmitting,
|
||||
conversation,
|
||||
latestMessage,
|
||||
setAbortScroll,
|
||||
|
|
@ -51,8 +48,6 @@ export function MessagesViewProvider({ children }: { children: React.ReactNode }
|
|||
setMessages,
|
||||
} = chatContext;
|
||||
|
||||
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
|
||||
|
||||
/** Memoize conversation-related values */
|
||||
const conversationValues = useMemo(
|
||||
() => ({
|
||||
|
|
@ -65,12 +60,11 @@ export function MessagesViewProvider({ children }: { children: React.ReactNode }
|
|||
/** Memoize submission states */
|
||||
const submissionStates = useMemo(
|
||||
() => ({
|
||||
isSubmitting: isSubmittingRoot,
|
||||
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
|
||||
abortScroll,
|
||||
isSubmitting,
|
||||
setAbortScroll,
|
||||
}),
|
||||
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
|
||||
[isSubmitting, abortScroll, setAbortScroll],
|
||||
);
|
||||
|
||||
/** Memoize message operations (these are typically stable references) */
|
||||
|
|
@ -127,11 +121,10 @@ export function useMessagesConversation() {
|
|||
|
||||
/** Hook for components that only need submission states */
|
||||
export function useMessagesSubmission() {
|
||||
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
|
||||
useMessagesViewContext();
|
||||
const { isSubmitting, abortScroll, setAbortScroll } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
|
||||
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
|
||||
() => ({ isSubmitting, abortScroll, setAbortScroll }),
|
||||
[isSubmitting, abortScroll, setAbortScroll],
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { RefObject } from 'react';
|
||||
import { Constants, FileSources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import { FileSources, EModelEndpoint, isEphemeralAgentId } from 'librechat-data-provider';
|
||||
import type { UseMutationResult } from '@tanstack/react-query';
|
||||
import type * as InputNumberPrimitive from 'rc-input-number';
|
||||
import type { SetterOrUpdater, RecoilState } from 'recoil';
|
||||
|
|
@ -10,7 +10,7 @@ import type { TranslationKeys } from '~/hooks';
|
|||
import { MCPServerDefinition } from '~/hooks/MCP/useMCPServerManager';
|
||||
|
||||
export function isEphemeralAgent(agentId: string | null | undefined): boolean {
|
||||
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
|
||||
return isEphemeralAgentId(agentId);
|
||||
}
|
||||
|
||||
export interface ConfigFieldDetail {
|
||||
|
|
@ -356,6 +356,8 @@ export type TOptions = {
|
|||
isResubmission?: boolean;
|
||||
/** Currently only utilized when `isResubmission === true`, uses that message's currently attached files */
|
||||
overrideFiles?: t.TMessage['files'];
|
||||
/** Added conversation for multi-convo feature - sent to server as part of submission payload */
|
||||
addedConvo?: t.TConversation;
|
||||
};
|
||||
|
||||
export type TAskFunction = (props: TAskProps, options?: TOptions) => void;
|
||||
|
|
|
|||
|
|
@ -1,21 +1,23 @@
|
|||
import React, { useMemo } from 'react';
|
||||
import { Label } from '@librechat/client';
|
||||
import React, { useMemo, useState } from 'react';
|
||||
import { Label, OGDialog, OGDialogTrigger } from '@librechat/client';
|
||||
import type t from 'librechat-data-provider';
|
||||
import { useLocalize, TranslationKeys, useAgentCategories } from '~/hooks';
|
||||
import { cn, renderAgentAvatar, getContactDisplayName } from '~/utils';
|
||||
import AgentDetailContent from './AgentDetailContent';
|
||||
|
||||
interface AgentCardProps {
|
||||
agent: t.Agent; // The agent data to display
|
||||
onClick: () => void; // Callback when card is clicked
|
||||
className?: string; // Additional CSS classes
|
||||
agent: t.Agent;
|
||||
onSelect?: (agent: t.Agent) => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Card component to display agent information
|
||||
* Card component to display agent information with integrated detail dialog
|
||||
*/
|
||||
const AgentCard: React.FC<AgentCardProps> = ({ agent, onClick, className = '' }) => {
|
||||
const AgentCard: React.FC<AgentCardProps> = ({ agent, onSelect, className = '' }) => {
|
||||
const localize = useLocalize();
|
||||
const { categories } = useAgentCategories();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
const categoryLabel = useMemo(() => {
|
||||
if (!agent.category) return '';
|
||||
|
|
@ -31,82 +33,89 @@ const AgentCard: React.FC<AgentCardProps> = ({ agent, onClick, className = '' })
|
|||
return agent.category.charAt(0).toUpperCase() + agent.category.slice(1);
|
||||
}, [agent.category, categories, localize]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'group relative h-40 overflow-hidden rounded-xl border border-border-light',
|
||||
'cursor-pointer shadow-sm transition-all duration-200 hover:border-border-medium hover:shadow-lg',
|
||||
'bg-surface-tertiary hover:bg-surface-hover',
|
||||
'space-y-3 p-4',
|
||||
className,
|
||||
)}
|
||||
onClick={onClick}
|
||||
aria-label={localize('com_agents_agent_card_label', {
|
||||
name: agent.name,
|
||||
description: agent.description ?? '',
|
||||
})}
|
||||
aria-describedby={`agent-${agent.id}-description`}
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
onClick();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
{/* Left column: Avatar and Category */}
|
||||
<div className="flex h-full flex-shrink-0 flex-col justify-between space-y-4">
|
||||
<div className="flex-shrink-0">{renderAgentAvatar(agent, { size: 'sm' })}</div>
|
||||
const displayName = getContactDisplayName(agent);
|
||||
|
||||
{/* Category tag */}
|
||||
{agent.category && (
|
||||
<div className="inline-flex items-center rounded-md border-border-xheavy bg-surface-active-alt px-2 py-1 text-xs font-medium">
|
||||
<Label className="line-clamp-1 font-normal">{categoryLabel}</Label>
|
||||
const handleOpenChange = (open: boolean) => {
|
||||
setIsOpen(open);
|
||||
if (open && onSelect) {
|
||||
onSelect(agent);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<OGDialog open={isOpen} onOpenChange={handleOpenChange}>
|
||||
<OGDialogTrigger asChild>
|
||||
<div
|
||||
className={cn(
|
||||
'group relative flex h-32 gap-5 overflow-hidden rounded-xl',
|
||||
'cursor-pointer select-none px-6 py-4',
|
||||
'bg-surface-tertiary transition-colors duration-150 hover:bg-surface-hover',
|
||||
'md:h-36 lg:h-40',
|
||||
'[&_*]:cursor-pointer',
|
||||
className,
|
||||
)}
|
||||
aria-label={localize('com_agents_agent_card_label', {
|
||||
name: agent.name,
|
||||
description: agent.description ?? '',
|
||||
})}
|
||||
aria-describedby={agent.description ? `agent-${agent.id}-description` : undefined}
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
setIsOpen(true);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{/* Category badge - top right */}
|
||||
{categoryLabel && (
|
||||
<span className="absolute right-4 top-3 rounded-md bg-surface-hover px-2 py-0.5 text-xs text-text-secondary">
|
||||
{categoryLabel}
|
||||
</span>
|
||||
)}
|
||||
|
||||
{/* Avatar */}
|
||||
<div className="flex-shrink-0 self-center">
|
||||
<div className="overflow-hidden rounded-full shadow-[0_0_15px_rgba(0,0,0,0.3)] dark:shadow-[0_0_15px_rgba(0,0,0,0.5)]">
|
||||
{renderAgentAvatar(agent, { size: 'sm', showBorder: false })}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex min-w-0 flex-1 flex-col justify-center overflow-hidden">
|
||||
{/* Agent name */}
|
||||
<Label className="line-clamp-2 text-base font-semibold text-text-primary md:text-lg">
|
||||
{agent.name}
|
||||
</Label>
|
||||
|
||||
{/* Agent description */}
|
||||
{agent.description && (
|
||||
<p
|
||||
id={`agent-${agent.id}-description`}
|
||||
className="mt-0.5 line-clamp-2 text-sm leading-snug text-text-secondary md:line-clamp-5"
|
||||
aria-label={localize('com_agents_description_card', {
|
||||
description: agent.description,
|
||||
})}
|
||||
>
|
||||
{agent.description}
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Author */}
|
||||
{displayName && (
|
||||
<div className="mt-1 text-xs text-text-tertiary">
|
||||
<span className="truncate">
|
||||
{localize('com_ui_by_author', { 0: displayName || '' })}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right column: Name, description, and other content */}
|
||||
<div className="flex h-full min-w-0 flex-1 flex-col justify-between space-y-1">
|
||||
<div className="space-y-1">
|
||||
{/* Agent name */}
|
||||
<Label className="mb-1 line-clamp-1 text-xl font-semibold text-text-primary">
|
||||
{agent.name}
|
||||
</Label>
|
||||
|
||||
{/* Agent description */}
|
||||
<p
|
||||
id={`agent-${agent.id}-description`}
|
||||
className="line-clamp-3 text-sm leading-relaxed text-text-primary"
|
||||
{...(agent.description
|
||||
? { 'aria-label': `Description: ${agent.description}` }
|
||||
: {})}
|
||||
>
|
||||
{agent.description ?? ''}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Owner info */}
|
||||
{(() => {
|
||||
const displayName = getContactDisplayName(agent);
|
||||
if (displayName) {
|
||||
return (
|
||||
<div className="flex justify-end">
|
||||
<div className="flex items-center text-sm text-text-secondary">
|
||||
<Label>{displayName}</Label>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
})()}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</OGDialogTrigger>
|
||||
|
||||
<AgentDetailContent agent={agent} />
|
||||
</OGDialog>
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
192
client/src/components/Agents/AgentDetailContent.tsx
Normal file
192
client/src/components/Agents/AgentDetailContent.tsx
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
import React from 'react';
|
||||
import { Link, Pin, PinOff } from 'lucide-react';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { OGDialogContent, Button, useToastContext } from '@librechat/client';
|
||||
import {
|
||||
QueryKeys,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
PermissionBits,
|
||||
LocalStorageKeys,
|
||||
AgentListResponse,
|
||||
} from 'librechat-data-provider';
|
||||
import type t from 'librechat-data-provider';
|
||||
import { useLocalize, useDefaultConvo, useFavorites } from '~/hooks';
|
||||
import { renderAgentAvatar, clearMessagesCache } from '~/utils';
|
||||
import { useChatContext } from '~/Providers';
|
||||
|
||||
interface SupportContact {
|
||||
name?: string;
|
||||
email?: string;
|
||||
}
|
||||
|
||||
interface AgentWithSupport extends t.Agent {
|
||||
support_contact?: SupportContact;
|
||||
}
|
||||
|
||||
interface AgentDetailContentProps {
|
||||
agent: AgentWithSupport;
|
||||
}
|
||||
|
||||
/**
|
||||
* Dialog content for displaying agent details
|
||||
* Used inside OGDialog with OGDialogTrigger for proper focus management
|
||||
*/
|
||||
const AgentDetailContent: React.FC<AgentDetailContentProps> = ({ agent }) => {
|
||||
const localize = useLocalize();
|
||||
const queryClient = useQueryClient();
|
||||
const { showToast } = useToastContext();
|
||||
const getDefaultConversation = useDefaultConvo();
|
||||
const { conversation, newConversation } = useChatContext();
|
||||
const { isFavoriteAgent, toggleFavoriteAgent } = useFavorites();
|
||||
const isFavorite = isFavoriteAgent(agent?.id);
|
||||
|
||||
const handleFavoriteClick = () => {
|
||||
if (agent) {
|
||||
toggleFavoriteAgent(agent.id);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Navigate to chat with the selected agent
|
||||
*/
|
||||
const handleStartChat = () => {
|
||||
if (agent) {
|
||||
const keys = [QueryKeys.agents, { requiredPermission: PermissionBits.EDIT }];
|
||||
const listResp = queryClient.getQueryData<AgentListResponse>(keys);
|
||||
if (listResp != null) {
|
||||
if (!listResp.data.some((a) => a.id === agent.id)) {
|
||||
const currentAgents = [agent, ...JSON.parse(JSON.stringify(listResp.data))];
|
||||
queryClient.setQueryData<AgentListResponse>(keys, { ...listResp, data: currentAgents });
|
||||
}
|
||||
}
|
||||
|
||||
localStorage.setItem(`${LocalStorageKeys.AGENT_ID_PREFIX}0`, agent.id);
|
||||
|
||||
clearMessagesCache(queryClient, conversation?.conversationId);
|
||||
queryClient.invalidateQueries([QueryKeys.messages]);
|
||||
|
||||
/** Template with agent configuration */
|
||||
const template = {
|
||||
conversationId: Constants.NEW_CONVO as string,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
agent_id: agent.id,
|
||||
title: localize('com_agents_chat_with', { name: agent.name || localize('com_ui_agent') }),
|
||||
};
|
||||
|
||||
const currentConvo = getDefaultConversation({
|
||||
conversation: { ...(conversation ?? {}), ...template },
|
||||
preset: template,
|
||||
});
|
||||
|
||||
newConversation({
|
||||
template: currentConvo,
|
||||
preset: template,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Copy the agent's shareable link to clipboard
|
||||
*/
|
||||
const handleCopyLink = () => {
|
||||
const baseUrl = new URL(window.location.origin);
|
||||
const chatUrl = `${baseUrl.origin}/c/new?agent_id=${agent.id}`;
|
||||
navigator.clipboard
|
||||
.writeText(chatUrl)
|
||||
.then(() => {
|
||||
showToast({
|
||||
message: localize('com_agents_link_copied'),
|
||||
});
|
||||
})
|
||||
.catch(() => {
|
||||
showToast({
|
||||
message: localize('com_agents_link_copy_failed'),
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Format contact information with mailto links when appropriate
|
||||
*/
|
||||
const formatContact = () => {
|
||||
if (!agent?.support_contact) return null;
|
||||
|
||||
const { name, email } = agent.support_contact;
|
||||
|
||||
if (name && email) {
|
||||
return (
|
||||
<a href={`mailto:${email}`} className="text-primary hover:underline">
|
||||
{name}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
if (email) {
|
||||
return (
|
||||
<a href={`mailto:${email}`} className="text-primary hover:underline">
|
||||
{email}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
if (name) {
|
||||
return <span>{name}</span>;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
return (
|
||||
<OGDialogContent className="max-h-[90vh] w-11/12 max-w-lg overflow-y-auto">
|
||||
{/* Agent avatar */}
|
||||
<div className="mt-6 flex justify-center">{renderAgentAvatar(agent, { size: 'xl' })}</div>
|
||||
|
||||
{/* Agent name */}
|
||||
<div className="mt-3 text-center">
|
||||
<h2 className="text-2xl font-bold text-text-primary">
|
||||
{agent?.name || localize('com_agents_loading')}
|
||||
</h2>
|
||||
</div>
|
||||
|
||||
{/* Contact info */}
|
||||
{agent?.support_contact && formatContact() && (
|
||||
<div className="mt-1 text-center text-sm text-text-secondary">
|
||||
{localize('com_agents_contact')}: {formatContact()}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Agent description */}
|
||||
<div className="mt-4 whitespace-pre-wrap px-6 text-center text-base text-text-primary">
|
||||
{agent?.description}
|
||||
</div>
|
||||
|
||||
{/* Action button */}
|
||||
<div className="mb-4 mt-6 flex justify-center gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={handleFavoriteClick}
|
||||
title={isFavorite ? localize('com_ui_unpin') : localize('com_ui_pin')}
|
||||
aria-label={isFavorite ? localize('com_ui_unpin') : localize('com_ui_pin')}
|
||||
>
|
||||
{isFavorite ? <PinOff className="h-4 w-4" /> : <Pin className="h-4 w-4" />}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={handleCopyLink}
|
||||
title={localize('com_agents_copy_link')}
|
||||
aria-label={localize('com_agents_copy_link')}
|
||||
>
|
||||
<Link className="h-4 w-4" aria-hidden="true" />
|
||||
</Button>
|
||||
<Button className="w-full max-w-xs" onClick={handleStartChat} disabled={!agent}>
|
||||
{localize('com_agents_start_chat')}
|
||||
</Button>
|
||||
</div>
|
||||
</OGDialogContent>
|
||||
);
|
||||
};
|
||||
|
||||
export default AgentDetailContent;
|
||||
|
|
@ -10,10 +10,10 @@ import ErrorDisplay from './ErrorDisplay';
|
|||
import AgentCard from './AgentCard';
|
||||
|
||||
interface AgentGridProps {
|
||||
category: string; // Currently selected category
|
||||
searchQuery: string; // Current search query
|
||||
onSelectAgent: (agent: t.Agent) => void; // Callback when agent is selected
|
||||
scrollElementRef?: React.RefObject<HTMLElement>; // Parent scroll container ref for infinite scroll
|
||||
category: string;
|
||||
searchQuery: string;
|
||||
onSelectAgent: (agent: t.Agent) => void;
|
||||
scrollElementRef?: React.RefObject<HTMLElement>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -184,7 +184,7 @@ const AgentGrid: React.FC<AgentGridProps> = ({
|
|||
{/* Agent grid - 2 per row with proper semantic structure */}
|
||||
{currentAgents && currentAgents.length > 0 && (
|
||||
<div
|
||||
className="grid grid-cols-1 gap-6 md:grid-cols-2"
|
||||
className="mx-4 grid grid-cols-1 gap-6 md:grid-cols-2"
|
||||
role="grid"
|
||||
aria-label={localize('com_agents_grid_announcement', {
|
||||
count: currentAgents.length,
|
||||
|
|
@ -193,7 +193,7 @@ const AgentGrid: React.FC<AgentGridProps> = ({
|
|||
>
|
||||
{currentAgents.map((agent: t.Agent, index: number) => (
|
||||
<div key={`${agent.id}-${index}`} role="gridcell">
|
||||
<AgentCard agent={agent} onClick={() => onSelectAgent(agent)} />
|
||||
<AgentCard agent={agent} onSelect={onSelectAgent} />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import { SidePanelGroup } from '~/components/SidePanel';
|
|||
import { OpenSidebar } from '~/components/Chat/Menus';
|
||||
import { cn, clearMessagesCache } from '~/utils';
|
||||
import CategoryTabs from './CategoryTabs';
|
||||
import AgentDetail from './AgentDetail';
|
||||
import SearchBar from './SearchBar';
|
||||
import AgentGrid from './AgentGrid';
|
||||
import store from '~/store';
|
||||
|
|
@ -45,7 +44,6 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
|
|||
|
||||
// Get URL parameters
|
||||
const searchQuery = searchParams.get('q') || '';
|
||||
const selectedAgentId = searchParams.get('agent_id') || '';
|
||||
|
||||
// Animation state
|
||||
type Direction = 'left' | 'right';
|
||||
|
|
@ -58,10 +56,6 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
|
|||
// Ref for the scrollable container to enable infinite scroll
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Local state
|
||||
const [isDetailOpen, setIsDetailOpen] = useState(false);
|
||||
const [selectedAgent, setSelectedAgent] = useState<t.Agent | null>(null);
|
||||
|
||||
// Set page title
|
||||
useDocumentTitle(`${localize('com_agents_marketplace')} | LibreChat`);
|
||||
|
||||
|
|
@ -102,28 +96,12 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
|
|||
}, [category, categoriesQuery.data, displayCategory]);
|
||||
|
||||
/**
|
||||
* Handle agent card selection
|
||||
*
|
||||
* @param agent - The selected agent object
|
||||
* Handle agent card selection - updates URL for deep linking
|
||||
*/
|
||||
const handleAgentSelect = (agent: t.Agent) => {
|
||||
// Update URL with selected agent
|
||||
const newParams = new URLSearchParams(searchParams);
|
||||
newParams.set('agent_id', agent.id);
|
||||
setSearchParams(newParams);
|
||||
setSelectedAgent(agent);
|
||||
setIsDetailOpen(true);
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle closing the agent detail dialog
|
||||
*/
|
||||
const handleDetailClose = () => {
|
||||
const newParams = new URLSearchParams(searchParams);
|
||||
newParams.delete('agent_id');
|
||||
setSearchParams(newParams);
|
||||
setSelectedAgent(null);
|
||||
setIsDetailOpen(false);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -229,11 +207,6 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
|
|||
newConversation();
|
||||
};
|
||||
|
||||
// Check if a detail view should be open based on URL
|
||||
useEffect(() => {
|
||||
setIsDetailOpen(!!selectedAgentId);
|
||||
}, [selectedAgentId]);
|
||||
|
||||
// Layout configuration for SidePanelGroup
|
||||
const defaultLayout = useMemo(() => {
|
||||
const resizableLayout = localStorage.getItem('react-resizable-panels:layout');
|
||||
|
|
@ -295,7 +268,7 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
|
|||
variant="outline"
|
||||
data-testid="agents-new-chat-button"
|
||||
aria-label={localize('com_ui_new_chat')}
|
||||
className="rounded-xl border border-border-light bg-surface-secondary p-2 hover:bg-surface-hover max-md:hidden"
|
||||
className="rounded-xl border border-border-light bg-surface-secondary p-2 hover:bg-surface-active-alt max-md:hidden"
|
||||
onClick={handleNewChat}
|
||||
>
|
||||
<NewChatIcon />
|
||||
|
|
@ -512,14 +485,6 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
|
|||
{/* Note: Using Tailwind keyframes for slide in/out animations */}
|
||||
</div>
|
||||
</div>
|
||||
{/* Agent detail dialog */}
|
||||
{isDetailOpen && selectedAgent && (
|
||||
<AgentDetail
|
||||
agent={selectedAgent}
|
||||
isOpen={isDetailOpen}
|
||||
onClose={handleDetailClose}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</main>
|
||||
</SidePanelGroup>
|
||||
|
|
|
|||
|
|
@ -1,75 +1,20 @@
|
|||
import { useMemo, useEffect, useState } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ShieldEllipsis } from 'lucide-react';
|
||||
import { useForm, Controller } from 'react-hook-form';
|
||||
import { Permissions, SystemRoles, roleDefaults, PermissionTypes } from 'librechat-data-provider';
|
||||
import {
|
||||
Button,
|
||||
Switch,
|
||||
OGDialog,
|
||||
DropdownPopup,
|
||||
OGDialogTitle,
|
||||
OGDialogContent,
|
||||
OGDialogTrigger,
|
||||
useToastContext,
|
||||
} from '@librechat/client';
|
||||
import type { Control, UseFormSetValue, UseFormGetValues } from 'react-hook-form';
|
||||
import { Permissions, PermissionTypes } from 'librechat-data-provider';
|
||||
import { Button, useToastContext } from '@librechat/client';
|
||||
import { AdminSettingsDialog } from '~/components/ui';
|
||||
import { useUpdateMarketplacePermissionsMutation } from '~/data-provider';
|
||||
import { useLocalize, useAuthContext } from '~/hooks';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import type { PermissionConfig } from '~/components/ui';
|
||||
|
||||
type FormValues = {
|
||||
[Permissions.USE]: boolean;
|
||||
};
|
||||
|
||||
type LabelControllerProps = {
|
||||
label: string;
|
||||
marketplacePerm: Permissions.USE;
|
||||
control: Control<FormValues, unknown, FormValues>;
|
||||
setValue: UseFormSetValue<FormValues>;
|
||||
getValues: UseFormGetValues<FormValues>;
|
||||
};
|
||||
|
||||
const LabelController: React.FC<LabelControllerProps> = ({
|
||||
control,
|
||||
marketplacePerm,
|
||||
label,
|
||||
getValues,
|
||||
setValue,
|
||||
}) => (
|
||||
<div className="mb-4 flex items-center justify-between gap-2">
|
||||
<button
|
||||
className="cursor-pointer select-none"
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setValue(marketplacePerm, !getValues(marketplacePerm), {
|
||||
shouldDirty: true,
|
||||
})
|
||||
}
|
||||
tabIndex={0}
|
||||
>
|
||||
{label}
|
||||
</button>
|
||||
<Controller
|
||||
name={marketplacePerm}
|
||||
control={control}
|
||||
render={({ field }) => (
|
||||
<Switch
|
||||
{...field}
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
value={field.value.toString()}
|
||||
aria-label={label}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
const permissions: PermissionConfig[] = [
|
||||
{ permission: Permissions.USE, labelKey: 'com_ui_marketplace_allow_use' },
|
||||
];
|
||||
|
||||
const MarketplaceAdminSettings = () => {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const { user, roles } = useAuthContext();
|
||||
const { mutate, isLoading } = useUpdateMarketplacePermissionsMutation({
|
||||
|
||||
const mutation = useUpdateMarketplacePermissionsMutation({
|
||||
onSuccess: () => {
|
||||
showToast({ status: 'success', message: localize('com_ui_saved') });
|
||||
},
|
||||
|
|
@ -78,133 +23,27 @@ const MarketplaceAdminSettings = () => {
|
|||
},
|
||||
});
|
||||
|
||||
const [isRoleMenuOpen, setIsRoleMenuOpen] = useState(false);
|
||||
const [selectedRole, setSelectedRole] = useState<SystemRoles>(SystemRoles.USER);
|
||||
|
||||
const defaultValues = useMemo(() => {
|
||||
const rolePerms = roles?.[selectedRole]?.permissions;
|
||||
if (rolePerms) {
|
||||
return rolePerms[PermissionTypes.MARKETPLACE];
|
||||
}
|
||||
return roleDefaults[selectedRole].permissions[PermissionTypes.MARKETPLACE];
|
||||
}, [roles, selectedRole]);
|
||||
|
||||
const {
|
||||
reset,
|
||||
control,
|
||||
setValue,
|
||||
getValues,
|
||||
handleSubmit,
|
||||
formState: { isSubmitting },
|
||||
} = useForm<FormValues>({
|
||||
mode: 'onChange',
|
||||
defaultValues,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const value = roles?.[selectedRole]?.permissions?.[PermissionTypes.MARKETPLACE];
|
||||
if (value) {
|
||||
reset(value);
|
||||
} else {
|
||||
reset(roleDefaults[selectedRole].permissions[PermissionTypes.MARKETPLACE]);
|
||||
}
|
||||
}, [roles, selectedRole, reset]);
|
||||
|
||||
if (user?.role !== SystemRoles.ADMIN) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const labelControllerData: {
|
||||
marketplacePerm: Permissions.USE;
|
||||
label: string;
|
||||
}[] = [
|
||||
{
|
||||
marketplacePerm: Permissions.USE,
|
||||
label: localize('com_ui_marketplace_allow_use'),
|
||||
},
|
||||
];
|
||||
|
||||
const onSubmit = (data: FormValues) => {
|
||||
mutate({ roleName: selectedRole, updates: data });
|
||||
};
|
||||
|
||||
const roleDropdownItems = [
|
||||
{
|
||||
label: SystemRoles.USER,
|
||||
onClick: () => {
|
||||
setSelectedRole(SystemRoles.USER);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: SystemRoles.ADMIN,
|
||||
onClick: () => {
|
||||
setSelectedRole(SystemRoles.ADMIN);
|
||||
},
|
||||
},
|
||||
];
|
||||
const trigger = (
|
||||
<Button
|
||||
variant="outline"
|
||||
className="relative h-12 rounded-xl border-border-medium font-medium"
|
||||
aria-label={localize('com_ui_admin_settings')}
|
||||
>
|
||||
<ShieldEllipsis className="cursor-pointer" aria-hidden="true" />
|
||||
</Button>
|
||||
);
|
||||
|
||||
return (
|
||||
<OGDialog>
|
||||
<OGDialogTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="relative h-12 rounded-xl border-border-medium font-medium"
|
||||
>
|
||||
<ShieldEllipsis className="cursor-pointer" aria-hidden="true" />
|
||||
</Button>
|
||||
</OGDialogTrigger>
|
||||
<OGDialogContent className="w-11/12 max-w-md border-border-light bg-surface-primary text-text-primary">
|
||||
<OGDialogTitle>
|
||||
{localize('com_ui_admin_settings_section', { section: localize('com_ui_marketplace') })}
|
||||
</OGDialogTitle>
|
||||
<div className="p-2">
|
||||
{/* Role selection dropdown */}
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium">{localize('com_ui_role_select')}:</span>
|
||||
<DropdownPopup
|
||||
unmountOnHide={true}
|
||||
menuId="role-dropdown"
|
||||
isOpen={isRoleMenuOpen}
|
||||
setIsOpen={setIsRoleMenuOpen}
|
||||
trigger={
|
||||
<Ariakit.MenuButton className="inline-flex w-1/4 items-center justify-center rounded-lg border border-border-light bg-transparent px-2 py-1 text-text-primary transition-all ease-in-out hover:bg-surface-tertiary">
|
||||
{selectedRole}
|
||||
</Ariakit.MenuButton>
|
||||
}
|
||||
items={roleDropdownItems}
|
||||
itemClassName="items-center justify-center"
|
||||
sameWidth={true}
|
||||
/>
|
||||
</div>
|
||||
{/* Permissions form */}
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<div className="py-5">
|
||||
{labelControllerData.map(({ marketplacePerm, label }) => (
|
||||
<div key={marketplacePerm}>
|
||||
<LabelController
|
||||
control={control}
|
||||
marketplacePerm={marketplacePerm}
|
||||
label={label}
|
||||
getValues={getValues}
|
||||
setValue={setValue}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<div className="flex justify-end">
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
disabled={isSubmitting || isLoading}
|
||||
className="btn rounded bg-green-500 font-bold text-white transition-all hover:bg-green-600"
|
||||
>
|
||||
{localize('com_ui_save')}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</OGDialogContent>
|
||||
</OGDialog>
|
||||
<AdminSettingsDialog
|
||||
permissionType={PermissionTypes.MARKETPLACE}
|
||||
sectionKey="com_ui_marketplace"
|
||||
permissions={permissions}
|
||||
menuId="marketplace-role-dropdown"
|
||||
mutation={mutation}
|
||||
trigger={trigger}
|
||||
dialogContentClassName="w-11/12 max-w-md border-border-light bg-surface-primary text-text-primary"
|
||||
showAdminWarning={false}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,27 @@ jest.mock('~/hooks', () => ({
|
|||
useLocalize: () => mockLocalize,
|
||||
useDebounce: jest.fn(),
|
||||
useAgentCategories: jest.fn(),
|
||||
useDefaultConvo: jest.fn(() => jest.fn(() => ({}))),
|
||||
useFavorites: jest.fn(() => ({
|
||||
isFavoriteAgent: jest.fn(() => false),
|
||||
toggleFavoriteAgent: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock Providers
|
||||
jest.mock('~/Providers', () => ({
|
||||
useChatContext: jest.fn(() => ({
|
||||
conversation: null,
|
||||
newConversation: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock @librechat/client toast context
|
||||
jest.mock('@librechat/client', () => ({
|
||||
...jest.requireActual('@librechat/client'),
|
||||
useToastContext: jest.fn(() => ({
|
||||
showToast: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('~/data-provider/Agents', () => ({
|
||||
|
|
@ -115,6 +136,13 @@ jest.mock('../SmartLoader', () => ({
|
|||
useHasData: jest.fn(() => true),
|
||||
}));
|
||||
|
||||
// Mock AgentDetailContent to avoid testing dialog internals
|
||||
jest.mock('../AgentDetailContent', () => ({
|
||||
__esModule: true,
|
||||
// eslint-disable-next-line i18next/no-literal-string
|
||||
default: () => <div data-testid="agent-detail-content">Agent Detail Content</div>,
|
||||
}));
|
||||
|
||||
// Import the actual modules to get the mocked functions
|
||||
import { useMarketplaceAgentsInfiniteQuery } from '~/data-provider/Agents';
|
||||
import { useAgentCategories, useDebounce } from '~/hooks';
|
||||
|
|
@ -299,7 +327,12 @@ describe('Accessibility Improvements', () => {
|
|||
};
|
||||
|
||||
it('provides comprehensive ARIA labels', () => {
|
||||
render(<AgentCard agent={mockAgent as t.Agent} onClick={jest.fn()} />);
|
||||
const Wrapper = createWrapper();
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent as t.Agent} onSelect={jest.fn()} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
expect(card).toHaveAttribute('aria-label', 'Test Agent agent. A test agent for testing');
|
||||
|
|
@ -308,16 +341,19 @@ describe('Accessibility Improvements', () => {
|
|||
});
|
||||
|
||||
it('supports keyboard interaction', () => {
|
||||
const onClick = jest.fn();
|
||||
render(<AgentCard agent={mockAgent as t.Agent} onClick={onClick} />);
|
||||
const Wrapper = createWrapper();
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent as t.Agent} onSelect={jest.fn()} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
|
||||
fireEvent.keyDown(card, { key: 'Enter' });
|
||||
expect(onClick).toHaveBeenCalledTimes(1);
|
||||
|
||||
fireEvent.keyDown(card, { key: ' ' });
|
||||
expect(onClick).toHaveBeenCalledTimes(2);
|
||||
// Card should be keyboard accessible - actual dialog behavior is handled by Radix
|
||||
expect(card).toHaveAttribute('tabIndex', '0');
|
||||
expect(() => fireEvent.keyDown(card, { key: 'Enter' })).not.toThrow();
|
||||
expect(() => fireEvent.keyDown(card, { key: ' ' })).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import { render, screen, fireEvent } from '@testing-library/react';
|
|||
import '@testing-library/jest-dom';
|
||||
import AgentCard from '../AgentCard';
|
||||
import type t from 'librechat-data-provider';
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
|
||||
|
||||
// Mock useLocalize hook
|
||||
jest.mock('~/hooks/useLocalize', () => () => (key: string) => {
|
||||
|
|
@ -11,25 +12,32 @@ jest.mock('~/hooks/useLocalize', () => () => (key: string) => {
|
|||
com_agents_agent_card_label: '{{name}} agent. {{description}}',
|
||||
com_agents_category_general: 'General',
|
||||
com_agents_category_hr: 'Human Resources',
|
||||
com_ui_by_author: 'by {{0}}',
|
||||
com_agents_description_card: '{{description}}',
|
||||
};
|
||||
return mockTranslations[key] || key;
|
||||
});
|
||||
|
||||
// Mock useAgentCategories hook
|
||||
jest.mock('~/hooks', () => ({
|
||||
useLocalize: () => (key: string, values?: Record<string, string>) => {
|
||||
useLocalize: () => (key: string, values?: Record<string, string | number>) => {
|
||||
const mockTranslations: Record<string, string> = {
|
||||
com_agents_created_by: 'Created by',
|
||||
com_agents_agent_card_label: '{{name}} agent. {{description}}',
|
||||
com_agents_category_general: 'General',
|
||||
com_agents_category_hr: 'Human Resources',
|
||||
com_ui_by_author: 'by {{0}}',
|
||||
com_agents_description_card: '{{description}}',
|
||||
};
|
||||
let translation = mockTranslations[key] || key;
|
||||
|
||||
// Replace placeholders with actual values
|
||||
if (values) {
|
||||
Object.entries(values).forEach(([placeholder, value]) => {
|
||||
translation = translation.replace(new RegExp(`{{${placeholder}}}`, 'g'), value);
|
||||
translation = translation.replace(
|
||||
new RegExp(`\\{\\{${placeholder}\\}\\}`, 'g'),
|
||||
String(value),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -42,8 +50,81 @@ jest.mock('~/hooks', () => ({
|
|||
{ value: 'custom', label: 'Custom Category' }, // Non-localized custom category
|
||||
],
|
||||
}),
|
||||
useDefaultConvo: jest.fn(() => jest.fn(() => ({}))),
|
||||
useFavorites: jest.fn(() => ({
|
||||
isFavoriteAgent: jest.fn(() => false),
|
||||
toggleFavoriteAgent: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock AgentDetailContent to avoid testing dialog internals
|
||||
jest.mock('../AgentDetailContent', () => ({
|
||||
__esModule: true,
|
||||
// eslint-disable-next-line i18next/no-literal-string
|
||||
default: () => <div data-testid="agent-detail-content">Agent Detail Content</div>,
|
||||
}));
|
||||
|
||||
// Mock Providers
|
||||
jest.mock('~/Providers', () => ({
|
||||
useChatContext: jest.fn(() => ({
|
||||
conversation: null,
|
||||
newConversation: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock @librechat/client with proper Dialog behavior
|
||||
jest.mock('@librechat/client', () => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
const React = require('react');
|
||||
return {
|
||||
...jest.requireActual('@librechat/client'),
|
||||
useToastContext: jest.fn(() => ({
|
||||
showToast: jest.fn(),
|
||||
})),
|
||||
OGDialog: ({ children, open, onOpenChange }: any) => {
|
||||
// Store onOpenChange in context for trigger to call
|
||||
return (
|
||||
<div data-testid="dialog-wrapper" data-open={open}>
|
||||
{React.Children.map(children, (child: any) => {
|
||||
if (child?.type?.displayName === 'OGDialogTrigger' || child?.props?.['data-trigger']) {
|
||||
return React.cloneElement(child, { onOpenChange });
|
||||
}
|
||||
// Only render content when open
|
||||
if (child?.type?.displayName === 'OGDialogContent' && !open) {
|
||||
return null;
|
||||
}
|
||||
return child;
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
OGDialogTrigger: ({ children, asChild, onOpenChange }: any) => {
|
||||
if (asChild && React.isValidElement(children)) {
|
||||
return React.cloneElement(children as React.ReactElement<any>, {
|
||||
onClick: (e: any) => {
|
||||
(children as any).props?.onClick?.(e);
|
||||
onOpenChange?.(true);
|
||||
},
|
||||
});
|
||||
}
|
||||
return <div onClick={() => onOpenChange?.(true)}>{children}</div>;
|
||||
},
|
||||
OGDialogContent: ({ children }: any) => <div data-testid="dialog-content">{children}</div>,
|
||||
Label: ({ children, className }: any) => <span className={className}>{children}</span>,
|
||||
};
|
||||
});
|
||||
|
||||
// Create wrapper with QueryClient
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
describe('AgentCard', () => {
|
||||
const mockAgent: t.Agent = {
|
||||
id: '1',
|
||||
|
|
@ -69,22 +150,30 @@ describe('AgentCard', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const mockOnClick = jest.fn();
|
||||
const mockOnSelect = jest.fn();
|
||||
const Wrapper = createWrapper();
|
||||
|
||||
beforeEach(() => {
|
||||
mockOnClick.mockClear();
|
||||
mockOnSelect.mockClear();
|
||||
});
|
||||
|
||||
it('renders agent information correctly', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByText('Test Agent')).toBeInTheDocument();
|
||||
expect(screen.getByText('A test agent for testing purposes')).toBeInTheDocument();
|
||||
expect(screen.getByText('Test Support')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('displays avatar when provided as object', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const avatarImg = screen.getByAltText('Test Agent avatar');
|
||||
expect(avatarImg).toBeInTheDocument();
|
||||
|
|
@ -97,7 +186,11 @@ describe('AgentCard', () => {
|
|||
avatar: '/string-avatar.png' as any, // Legacy support for string avatars
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithStringAvatar} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithStringAvatar} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const avatarImg = screen.getByAltText('Test Agent avatar');
|
||||
expect(avatarImg).toBeInTheDocument();
|
||||
|
|
@ -110,51 +203,73 @@ describe('AgentCard', () => {
|
|||
avatar: undefined,
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithoutAvatar as any as t.Agent} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithoutAvatar as any as t.Agent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
// Check for Feather icon presence by looking for the svg with lucide-feather class
|
||||
const featherIcon = document.querySelector('.lucide-feather');
|
||||
expect(featherIcon).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onClick when card is clicked', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
it('card is clickable and has dialog trigger', () => {
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
fireEvent.click(card);
|
||||
|
||||
expect(mockOnClick).toHaveBeenCalledTimes(1);
|
||||
// Card should be clickable - the actual dialog behavior is handled by Radix
|
||||
expect(card).toBeInTheDocument();
|
||||
expect(() => fireEvent.click(card)).not.toThrow();
|
||||
});
|
||||
|
||||
it('calls onClick when Enter key is pressed', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
it('handles Enter key press', () => {
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
fireEvent.keyDown(card, { key: 'Enter' });
|
||||
|
||||
expect(mockOnClick).toHaveBeenCalledTimes(1);
|
||||
// Card should respond to keyboard - the actual dialog behavior is handled by Radix
|
||||
expect(() => fireEvent.keyDown(card, { key: 'Enter' })).not.toThrow();
|
||||
});
|
||||
|
||||
it('calls onClick when Space key is pressed', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
it('handles Space key press', () => {
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
fireEvent.keyDown(card, { key: ' ' });
|
||||
|
||||
expect(mockOnClick).toHaveBeenCalledTimes(1);
|
||||
// Card should respond to keyboard - the actual dialog behavior is handled by Radix
|
||||
expect(() => fireEvent.keyDown(card, { key: ' ' })).not.toThrow();
|
||||
});
|
||||
|
||||
it('does not call onClick for other keys', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
it('does not call onSelect for other keys', () => {
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
fireEvent.keyDown(card, { key: 'Escape' });
|
||||
|
||||
expect(mockOnClick).not.toHaveBeenCalled();
|
||||
expect(mockOnSelect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('applies additional className when provided', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} className="custom-class" />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} className="custom-class" />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
expect(card).toHaveClass('custom-class');
|
||||
|
|
@ -167,11 +282,14 @@ describe('AgentCard', () => {
|
|||
authorName: undefined,
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithoutContact} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithoutContact} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByText('Test Agent')).toBeInTheDocument();
|
||||
expect(screen.getByText('A test agent for testing purposes')).toBeInTheDocument();
|
||||
expect(screen.queryByText(/Created by/)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('displays authorName when support_contact is missing', () => {
|
||||
|
|
@ -181,54 +299,21 @@ describe('AgentCard', () => {
|
|||
authorName: 'John Doe',
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithAuthorName} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithAuthorName} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByText('John Doe')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('displays support_contact email when name is missing', () => {
|
||||
const agentWithEmailOnly = {
|
||||
...mockAgent,
|
||||
support_contact: { email: 'contact@example.com' },
|
||||
authorName: undefined,
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithEmailOnly} onClick={mockOnClick} />);
|
||||
|
||||
expect(screen.getByText('contact@example.com')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('prioritizes support_contact name over authorName', () => {
|
||||
const agentWithBoth = {
|
||||
...mockAgent,
|
||||
support_contact: { name: 'Support Team' },
|
||||
authorName: 'John Doe',
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithBoth} onClick={mockOnClick} />);
|
||||
|
||||
expect(screen.getByText('Support Team')).toBeInTheDocument();
|
||||
expect(screen.queryByText('John Doe')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('prioritizes name over email in support_contact', () => {
|
||||
const agentWithNameAndEmail = {
|
||||
...mockAgent,
|
||||
support_contact: {
|
||||
name: 'Support Team',
|
||||
email: 'support@example.com',
|
||||
},
|
||||
authorName: undefined,
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithNameAndEmail} onClick={mockOnClick} />);
|
||||
|
||||
expect(screen.getByText('Support Team')).toBeInTheDocument();
|
||||
expect(screen.queryByText('support@example.com')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('by John Doe')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('has proper accessibility attributes', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
expect(card).toHaveAttribute('tabIndex', '0');
|
||||
|
|
@ -244,7 +329,11 @@ describe('AgentCard', () => {
|
|||
category: 'general',
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithCategory} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithCategory} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByText('General')).toBeInTheDocument();
|
||||
});
|
||||
|
|
@ -255,7 +344,11 @@ describe('AgentCard', () => {
|
|||
category: 'custom',
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithCustomCategory} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithCustomCategory} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByText('Custom Category')).toBeInTheDocument();
|
||||
});
|
||||
|
|
@ -266,15 +359,35 @@ describe('AgentCard', () => {
|
|||
category: 'unknown',
|
||||
};
|
||||
|
||||
render(<AgentCard agent={agentWithUnknownCategory} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={agentWithUnknownCategory} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByText('Unknown')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not display category tag when category is not provided', () => {
|
||||
render(<AgentCard agent={mockAgent} onClick={mockOnClick} />);
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} onSelect={mockOnSelect} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
expect(screen.queryByText('General')).not.toBeInTheDocument();
|
||||
expect(screen.queryByText('Unknown')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('works without onSelect callback', () => {
|
||||
render(
|
||||
<Wrapper>
|
||||
<AgentCard agent={mockAgent} />
|
||||
</Wrapper>,
|
||||
);
|
||||
|
||||
const card = screen.getByRole('button');
|
||||
// Should not throw when clicking without onSelect
|
||||
expect(() => fireEvent.click(card)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ jest.mock('../ErrorDisplay', () => ({
|
|||
// Mock AgentCard component
|
||||
jest.mock('../AgentCard', () => ({
|
||||
__esModule: true,
|
||||
default: ({ agent, onClick }: { agent: t.Agent; onClick: () => void }) => (
|
||||
<div data-testid={`agent-card-${agent.id}`} onClick={onClick}>
|
||||
default: ({ agent, onSelect }: { agent: t.Agent; onSelect?: (agent: t.Agent) => void }) => (
|
||||
<div data-testid={`agent-card-${agent.id}`} onClick={() => onSelect?.(agent)}>
|
||||
<h3>{agent.name}</h3>
|
||||
<p>{agent.description}</p>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -11,9 +11,10 @@ function Footer({ startupConfig }: { startupConfig: TStartupConfig | null | unde
|
|||
|
||||
const privacyPolicyRender = privacyPolicy?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-600 dark:text-green-500"
|
||||
className="text-sm text-green-600 underline decoration-transparent transition-all duration-200 hover:text-green-700 hover:decoration-green-700 focus:text-green-700 focus:decoration-green-700 dark:text-green-500 dark:hover:text-green-400 dark:hover:decoration-green-400 dark:focus:text-green-400 dark:focus:decoration-green-400"
|
||||
href={privacyPolicy.externalUrl}
|
||||
target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
// Removed for WCAG compliance
|
||||
// target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_privacy_policy')}
|
||||
|
|
@ -22,9 +23,10 @@ function Footer({ startupConfig }: { startupConfig: TStartupConfig | null | unde
|
|||
|
||||
const termsOfServiceRender = termsOfService?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-600 dark:text-green-500"
|
||||
className="text-sm text-green-600 underline decoration-transparent transition-all duration-200 hover:text-green-700 hover:decoration-green-700 focus:text-green-700 focus:decoration-green-700 dark:text-green-500 dark:hover:text-green-400 dark:hover:decoration-green-400 dark:focus:text-green-400 dark:focus:decoration-green-400"
|
||||
href={termsOfService.externalUrl}
|
||||
target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
// Removed for WCAG compliance
|
||||
// target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_terms_of_service')}
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ function Login() {
|
|||
{localize('com_auth_no_account')}{' '}
|
||||
<a
|
||||
href={registerPage()}
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 underline decoration-transparent transition-all duration-200 hover:text-green-700 hover:decoration-green-700 focus:text-green-700 focus:decoration-green-700 dark:text-green-500 dark:hover:text-green-400 dark:hover:decoration-green-400 dark:focus:text-green-400 dark:focus:decoration-green-400"
|
||||
>
|
||||
{localize('com_auth_sign_up')}
|
||||
</a>
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ const LoginForm: React.FC<TLoginFormProps> = ({ onSubmit, startupConfig, error,
|
|||
{startupConfig.passwordResetEnabled && (
|
||||
<a
|
||||
href="/forgot-password"
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 underline decoration-transparent transition-all duration-200 hover:text-green-700 hover:decoration-green-700 focus:text-green-700 focus:decoration-green-700 dark:text-green-500 dark:hover:text-green-400 dark:hover:decoration-green-400 dark:focus:text-green-400 dark:focus:decoration-green-400"
|
||||
>
|
||||
{localize('com_auth_password_forgot')}
|
||||
</a>
|
||||
|
|
|
|||
|
|
@ -156,7 +156,6 @@ test('renders registration form', () => {
|
|||
);
|
||||
});
|
||||
|
||||
// eslint-disable-next-line jest/no-commented-out-tests
|
||||
// test('calls registerUser.mutate on registration', async () => {
|
||||
// const mutate = jest.fn();
|
||||
// const { getByTestId, getByRole, history } = setup({
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ const BookmarkEditDialog = ({
|
|||
<OGDialogTemplate
|
||||
title={bookmark ? localize('com_ui_bookmarks_edit') : localize('com_ui_bookmarks_new')}
|
||||
showCloseButton={false}
|
||||
className="w-11/12 md:max-w-2xl"
|
||||
className="w-11/12 md:max-w-lg"
|
||||
main={
|
||||
<BookmarkForm
|
||||
tags={tags}
|
||||
|
|
|
|||
|
|
@ -85,16 +85,11 @@ const BookmarkForm = ({
|
|||
};
|
||||
|
||||
return (
|
||||
<form
|
||||
ref={formRef}
|
||||
className="mt-6"
|
||||
aria-label="Bookmark form"
|
||||
method="POST"
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
>
|
||||
<div className="flex w-full flex-col items-center gap-2">
|
||||
<div className="grid w-full items-center gap-2">
|
||||
<Label htmlFor="bookmark-tag" className="text-left text-sm font-medium">
|
||||
<form ref={formRef} aria-label="Bookmark form" method="POST" onSubmit={handleSubmit(onSubmit)}>
|
||||
<div className="space-y-4">
|
||||
{/* Tag name input */}
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="bookmark-tag" className="text-sm font-medium text-text-primary">
|
||||
{localize('com_ui_bookmarks_title')}
|
||||
</Label>
|
||||
<Input
|
||||
|
|
@ -118,24 +113,24 @@ const BookmarkForm = ({
|
|||
);
|
||||
},
|
||||
})}
|
||||
className="w-full"
|
||||
aria-invalid={!!errors.tag}
|
||||
placeholder={
|
||||
bookmark ? localize('com_ui_bookmarks_edit') : localize('com_ui_bookmarks_new')
|
||||
}
|
||||
placeholder={localize('com_ui_enter_name')}
|
||||
aria-describedby={errors.tag ? 'bookmark-tag-error' : undefined}
|
||||
/>
|
||||
{errors.tag && (
|
||||
<span id="bookmark-tag-error" role="alert" className="text-sm font-bold text-red-500">
|
||||
<span id="bookmark-tag-error" role="alert" className="text-sm text-red-500">
|
||||
{errors.tag.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mt-4 grid w-full items-center gap-2">
|
||||
{/* Description textarea */}
|
||||
<div className="space-y-2">
|
||||
<Label
|
||||
id="bookmark-description-label"
|
||||
htmlFor="bookmark-description"
|
||||
className="text-left text-sm font-medium"
|
||||
className="text-sm font-medium text-text-primary"
|
||||
>
|
||||
{localize('com_ui_bookmarks_description')}
|
||||
</Label>
|
||||
|
|
@ -151,14 +146,20 @@ const BookmarkForm = ({
|
|||
})}
|
||||
id="bookmark-description"
|
||||
disabled={false}
|
||||
placeholder={localize('com_ui_enter_description')}
|
||||
className={cn(
|
||||
'flex h-10 max-h-[250px] min-h-[100px] w-full resize-none rounded-lg border border-input bg-transparent px-3 py-2 text-sm ring-offset-background focus-visible:outline-none',
|
||||
'min-h-[100px] w-full resize-none rounded-lg border border-border-light',
|
||||
'bg-transparent px-3 py-2 text-sm text-text-primary',
|
||||
'placeholder:text-text-tertiary',
|
||||
'focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-border-heavy',
|
||||
)}
|
||||
aria-labelledby="bookmark-description-label"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Add to conversation checkbox */}
|
||||
{conversationId != null && conversationId && (
|
||||
<div className="mt-2 flex w-full items-center">
|
||||
<div className="flex items-center gap-2">
|
||||
<Controller
|
||||
name="addToConversation"
|
||||
control={control}
|
||||
|
|
@ -167,7 +168,7 @@ const BookmarkForm = ({
|
|||
{...field}
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
className="relative float-left mr-2 inline-flex h-4 w-4 cursor-pointer"
|
||||
className="size-4 cursor-pointer"
|
||||
value={field.value?.toString()}
|
||||
aria-label={localize('com_ui_bookmarks_add_to_conversation')}
|
||||
/>
|
||||
|
|
@ -176,16 +177,14 @@ const BookmarkForm = ({
|
|||
<button
|
||||
type="button"
|
||||
aria-label={localize('com_ui_bookmarks_add_to_conversation')}
|
||||
className="form-check-label w-full cursor-pointer text-text-primary"
|
||||
className="cursor-pointer text-sm text-text-primary"
|
||||
onClick={() =>
|
||||
setValue('addToConversation', !(getValues('addToConversation') ?? false), {
|
||||
shouldDirty: true,
|
||||
})
|
||||
}
|
||||
>
|
||||
<div className="flex select-none items-center">
|
||||
{localize('com_ui_bookmarks_add_to_conversation')}
|
||||
</div>
|
||||
{localize('com_ui_bookmarks_add_to_conversation')}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ function AddMultiConvo() {
|
|||
setAddedConvo({
|
||||
...convo,
|
||||
title: '',
|
||||
});
|
||||
} as TConversation);
|
||||
|
||||
const textarea = document.getElementById(mainTextareaId);
|
||||
if (textarea) {
|
||||
|
|
@ -34,13 +34,12 @@ function AddMultiConvo() {
|
|||
|
||||
return (
|
||||
<TooltipAnchor
|
||||
id="add-multi-conversation-button"
|
||||
aria-label={localize('com_ui_add_multi_conversation')}
|
||||
description={localize('com_ui_add_multi_conversation')}
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
aria-label={localize('com_ui_add_multi_conversation')}
|
||||
onClick={clickHandler}
|
||||
data-testid="parameters-button"
|
||||
data-testid="add-multi-convo-button"
|
||||
className="inline-flex size-10 flex-shrink-0 items-center justify-center rounded-xl border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
|
||||
>
|
||||
<PlusCircle size={16} aria-hidden="true" />
|
||||
|
|
|
|||
|
|
@ -7,7 +7,13 @@ import { Constants, buildTree } from 'librechat-data-provider';
|
|||
import type { TMessage } from 'librechat-data-provider';
|
||||
import type { ChatFormValues } from '~/common';
|
||||
import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers';
|
||||
import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks';
|
||||
import {
|
||||
useResumableStreamToggle,
|
||||
useAddedResponse,
|
||||
useResumeOnLoad,
|
||||
useAdaptiveSSE,
|
||||
useChatHelpers,
|
||||
} from '~/hooks';
|
||||
import ConversationStarters from './Input/ConversationStarters';
|
||||
import { useGetMessagesByConvoId } from '~/data-provider';
|
||||
import MessagesView from './Messages/MessagesView';
|
||||
|
|
@ -32,7 +38,6 @@ function LoadingSpinner() {
|
|||
function ChatView({ index = 0 }: { index?: number }) {
|
||||
const { conversationId } = useParams();
|
||||
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
|
||||
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
|
||||
const centerFormOnLanding = useRecoilValue(store.centerFormOnLanding);
|
||||
|
||||
const fileMap = useFileMapContext();
|
||||
|
|
@ -49,10 +54,18 @@ function ChatView({ index = 0 }: { index?: number }) {
|
|||
});
|
||||
|
||||
const chatHelpers = useChatHelpers(index, conversationId);
|
||||
const addedChatHelpers = useAddedResponse({ rootIndex: index });
|
||||
const addedChatHelpers = useAddedResponse();
|
||||
|
||||
useSSE(rootSubmission, chatHelpers, false);
|
||||
useSSE(addedSubmission, addedChatHelpers, true);
|
||||
useResumableStreamToggle(
|
||||
chatHelpers.conversation?.endpoint,
|
||||
chatHelpers.conversation?.endpointType,
|
||||
);
|
||||
|
||||
useAdaptiveSSE(rootSubmission, chatHelpers, false, index);
|
||||
|
||||
// Auto-resume if navigating back to conversation with active job
|
||||
// Wait for messages to load before resuming to avoid race condition
|
||||
useResumeOnLoad(conversationId, chatHelpers.getMessages, index, !isLoading);
|
||||
|
||||
const methods = useForm<ChatFormValues>({
|
||||
defaultValues: { text: '' },
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ export default function ExportAndShareMenu({
|
|||
className="inline-flex size-10 flex-shrink-0 items-center justify-center rounded-xl border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
|
||||
>
|
||||
<Share2
|
||||
className="icon-md text-text-secondary"
|
||||
className="icon-lg text-text-primary"
|
||||
aria-hidden="true"
|
||||
focusable="false"
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -45,10 +45,10 @@ export default function Header() {
|
|||
{!navVisible && (
|
||||
<motion.div
|
||||
className="flex items-center gap-2"
|
||||
initial={{ width: 0, opacity: 0 }}
|
||||
animate={{ width: 'auto', opacity: 1 }}
|
||||
exit={{ width: 0, opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.15 }}
|
||||
key="header-buttons"
|
||||
>
|
||||
<OpenSidebar setNavVisible={setNavVisible} className="max-md:hidden" />
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import { useMemo } from 'react';
|
||||
import type { TConversation, TEndpointOption, TPreset } from 'librechat-data-provider';
|
||||
import { isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import type { SetterOrUpdater } from 'recoil';
|
||||
import useGetSender from '~/hooks/Conversations/useGetSender';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { EndpointIcon } from '~/components/Endpoints';
|
||||
import { getPresetTitle } from '~/utils';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
|
||||
export default function AddedConvo({
|
||||
addedConvo,
|
||||
|
|
@ -13,13 +13,23 @@ export default function AddedConvo({
|
|||
addedConvo: TConversation | null;
|
||||
setAddedConvo: SetterOrUpdater<TConversation | null>;
|
||||
}) {
|
||||
const getSender = useGetSender();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const title = useMemo(() => {
|
||||
const sender = getSender(addedConvo as TEndpointOption);
|
||||
const title = getPresetTitle(addedConvo as TPreset);
|
||||
return `+ ${sender}: ${title}`;
|
||||
}, [addedConvo, getSender]);
|
||||
// Priority: agent name > modelDisplayLabel > modelLabel > model
|
||||
if (isAgentsEndpoint(addedConvo?.endpoint) && addedConvo?.agent_id) {
|
||||
const agent = agentsMap?.[addedConvo.agent_id];
|
||||
if (agent?.name) {
|
||||
return `+ ${agent.name}`;
|
||||
}
|
||||
}
|
||||
|
||||
const endpointConfig = endpointsConfig?.[addedConvo?.endpoint ?? ''];
|
||||
const displayLabel =
|
||||
endpointConfig?.modelDisplayLabel || addedConvo?.modelLabel || addedConvo?.model || 'AI';
|
||||
|
||||
return `+ ${displayLabel}`;
|
||||
}, [addedConvo, agentsMap, endpointsConfig]);
|
||||
|
||||
if (!addedConvo) {
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -100,7 +100,8 @@ function Artifacts() {
|
|||
'ml-1 h-4 w-4 text-text-secondary transition-transform duration-300 md:ml-0.5',
|
||||
isButtonExpanded && 'rotate-180',
|
||||
)}
|
||||
aria-hidden="true" />
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</Ariakit.MenuButton>
|
||||
|
||||
<Ariakit.Menu
|
||||
|
|
|
|||
|
|
@ -78,14 +78,11 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
handleStopGenerating,
|
||||
} = useChatContext();
|
||||
const {
|
||||
addedIndex,
|
||||
generateConversation,
|
||||
conversation: addedConvo,
|
||||
setConversation: setAddedConvo,
|
||||
isSubmitting: isSubmittingAdded,
|
||||
} = useAddedChatContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
const showStopAdded = useRecoilValue(store.showStopButtonByIndex(addedIndex));
|
||||
|
||||
const endpoint = useMemo(
|
||||
() => conversation?.endpointType ?? conversation?.endpoint,
|
||||
|
|
@ -131,7 +128,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
setFiles,
|
||||
textAreaRef,
|
||||
conversationId,
|
||||
isSubmitting: isSubmitting || isSubmittingAdded,
|
||||
isSubmitting,
|
||||
});
|
||||
|
||||
const { submitMessage, submitPrompt } = useSubmitMessage();
|
||||
|
|
@ -327,7 +324,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
</div>
|
||||
<BadgeRow
|
||||
showEphemeralBadges={!isAgentsEndpoint(endpoint) && !isAssistantsEndpoint(endpoint)}
|
||||
isSubmitting={isSubmitting || isSubmittingAdded}
|
||||
isSubmitting={isSubmitting}
|
||||
conversationId={conversationId}
|
||||
onChange={setBadges}
|
||||
isInChat={
|
||||
|
|
@ -346,7 +343,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
|||
/>
|
||||
)}
|
||||
<div className={`${isRTL ? 'ml-2' : 'mr-2'}`}>
|
||||
{(isSubmitting || isSubmittingAdded) && (showStopButton || showStopAdded) ? (
|
||||
{isSubmitting && showStopButton ? (
|
||||
<StopButton stop={handleStopGenerating} setShowStopButton={setShowStopButton} />
|
||||
) : (
|
||||
endpoint && (
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ const AttachFile = ({ disabled }: { disabled?: boolean | null }) => {
|
|||
aria-label={localize('com_sidepanel_attach_files')}
|
||||
disabled={isUploadDisabled}
|
||||
className={cn(
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-primary focus-visible:ring-opacity-50',
|
||||
)}
|
||||
onKeyDownCapture={(e) => {
|
||||
if (!inputRef.current) {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ import {
|
|||
getEndpointFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import { useGetFileConfig, useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useGetFileConfig, useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
import AttachFile from './AttachFile';
|
||||
|
||||
|
|
@ -26,6 +27,28 @@ function AttachFileChat({
|
|||
const isAgents = useMemo(() => isAgentsEndpoint(endpoint), [endpoint]);
|
||||
const isAssistants = useMemo(() => isAssistantsEndpoint(endpoint), [endpoint]);
|
||||
|
||||
const agentsMap = useAgentsMapContext();
|
||||
|
||||
const needsAgentFetch = useMemo(() => {
|
||||
if (!isAgents || !conversation?.agent_id) {
|
||||
return false;
|
||||
}
|
||||
const agent = agentsMap?.[conversation.agent_id];
|
||||
return !agent?.model_parameters;
|
||||
}, [isAgents, conversation?.agent_id, agentsMap]);
|
||||
|
||||
const { data: agentData } = useGetAgentByIdQuery(conversation?.agent_id, {
|
||||
enabled: needsAgentFetch,
|
||||
});
|
||||
|
||||
const useResponsesApi = useMemo(() => {
|
||||
if (!isAgents || !conversation?.agent_id || conversation?.useResponsesApi) {
|
||||
return conversation?.useResponsesApi;
|
||||
}
|
||||
const agent = agentData || agentsMap?.[conversation.agent_id];
|
||||
return agent?.model_parameters?.useResponsesApi;
|
||||
}, [isAgents, conversation?.agent_id, conversation?.useResponsesApi, agentData, agentsMap]);
|
||||
|
||||
const { data: fileConfig = null } = useGetFileConfig({
|
||||
select: (data) => mergeFileConfig(data),
|
||||
});
|
||||
|
|
@ -68,6 +91,7 @@ function AttachFileChat({
|
|||
conversationId={conversationId}
|
||||
agentId={conversation?.agent_id}
|
||||
endpointFileConfig={endpointFileConfig}
|
||||
useResponsesApi={useResponsesApi}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import {
|
|||
TerminalSquareIcon,
|
||||
} from 'lucide-react';
|
||||
import {
|
||||
Providers,
|
||||
EToolResources,
|
||||
EModelEndpoint,
|
||||
defaultAgentCapabilities,
|
||||
|
|
@ -36,6 +37,8 @@ import { ephemeralAgentByConvoId } from '~/store';
|
|||
import { MenuItemProps } from '~/common';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
type FileUploadType = 'image' | 'document' | 'image_document' | 'image_document_video_audio';
|
||||
|
||||
interface AttachFileMenuProps {
|
||||
agentId?: string | null;
|
||||
endpoint?: string | null;
|
||||
|
|
@ -43,6 +46,7 @@ interface AttachFileMenuProps {
|
|||
conversationId: string;
|
||||
endpointType?: EModelEndpoint;
|
||||
endpointFileConfig?: EndpointFileConfig;
|
||||
useResponsesApi?: boolean;
|
||||
}
|
||||
|
||||
const AttachFileMenu = ({
|
||||
|
|
@ -52,6 +56,7 @@ const AttachFileMenu = ({
|
|||
endpointType,
|
||||
conversationId,
|
||||
endpointFileConfig,
|
||||
useResponsesApi,
|
||||
}: AttachFileMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
|
|
@ -83,9 +88,7 @@ const AttachFileMenu = ({
|
|||
ephemeralAgent,
|
||||
);
|
||||
|
||||
const handleUploadClick = (
|
||||
fileType?: 'image' | 'document' | 'multimodal' | 'google_multimodal',
|
||||
) => {
|
||||
const handleUploadClick = (fileType?: FileUploadType) => {
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -94,9 +97,9 @@ const AttachFileMenu = ({
|
|||
inputRef.current.accept = 'image/*';
|
||||
} else if (fileType === 'document') {
|
||||
inputRef.current.accept = '.pdf,application/pdf';
|
||||
} else if (fileType === 'multimodal') {
|
||||
} else if (fileType === 'image_document') {
|
||||
inputRef.current.accept = 'image/*,.pdf,application/pdf';
|
||||
} else if (fileType === 'google_multimodal') {
|
||||
} else if (fileType === 'image_document_video_audio') {
|
||||
inputRef.current.accept = 'image/*,.pdf,application/pdf,video/*,audio/*';
|
||||
} else {
|
||||
inputRef.current.accept = '';
|
||||
|
|
@ -106,23 +109,33 @@ const AttachFileMenu = ({
|
|||
};
|
||||
|
||||
const dropdownItems = useMemo(() => {
|
||||
const createMenuItems = (
|
||||
onAction: (fileType?: 'image' | 'document' | 'multimodal' | 'google_multimodal') => void,
|
||||
) => {
|
||||
const createMenuItems = (onAction: (fileType?: FileUploadType) => void) => {
|
||||
const items: MenuItemProps[] = [];
|
||||
|
||||
const currentProvider = provider || endpoint;
|
||||
let currentProvider = provider || endpoint;
|
||||
|
||||
// This will be removed in a future PR to formally normalize Providers comparisons to be case insensitive
|
||||
if (currentProvider?.toLowerCase() === Providers.OPENROUTER) {
|
||||
currentProvider = Providers.OPENROUTER;
|
||||
}
|
||||
|
||||
const isAzureWithResponsesApi =
|
||||
currentProvider === EModelEndpoint.azureOpenAI && useResponsesApi;
|
||||
|
||||
if (
|
||||
isDocumentSupportedProvider(endpointType) ||
|
||||
isDocumentSupportedProvider(currentProvider)
|
||||
isDocumentSupportedProvider(currentProvider) ||
|
||||
isAzureWithResponsesApi
|
||||
) {
|
||||
items.push({
|
||||
label: localize('com_ui_upload_provider'),
|
||||
onClick: () => {
|
||||
setToolResource(undefined);
|
||||
onAction(
|
||||
(provider || endpoint) === EModelEndpoint.google ? 'google_multimodal' : 'multimodal',
|
||||
);
|
||||
let fileType: Exclude<FileUploadType, 'image' | 'document'> = 'image_document';
|
||||
if (currentProvider === Providers.GOOGLE || currentProvider === Providers.OPENROUTER) {
|
||||
fileType = 'image_document_video_audio';
|
||||
}
|
||||
onAction(fileType);
|
||||
},
|
||||
icon: <FileImageIcon className="icon-md" />,
|
||||
});
|
||||
|
|
@ -204,6 +217,7 @@ const AttachFileMenu = ({
|
|||
provider,
|
||||
endpointType,
|
||||
capabilities,
|
||||
useResponsesApi,
|
||||
setToolResource,
|
||||
setEphemeralAgent,
|
||||
sharePointEnabled,
|
||||
|
|
@ -220,7 +234,7 @@ const AttachFileMenu = ({
|
|||
id="attach-file-menu-button"
|
||||
aria-label="Attach File Options"
|
||||
className={cn(
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
'flex size-9 items-center justify-center rounded-full p-1 hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-primary focus-visible:ring-opacity-50',
|
||||
isPopoverActive && 'bg-surface-hover',
|
||||
)}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import React, { useMemo } from 'react';
|
|||
import { useRecoilValue } from 'recoil';
|
||||
import { OGDialog, OGDialogTemplate } from '@librechat/client';
|
||||
import {
|
||||
Providers,
|
||||
inferMimeType,
|
||||
EToolResources,
|
||||
EModelEndpoint,
|
||||
|
|
@ -46,7 +47,7 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
|
|||
* Use definition for agents endpoint for ephemeral agents
|
||||
* */
|
||||
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
|
||||
const { conversationId, agentId, endpoint, endpointType } = useDragDropContext();
|
||||
const { conversationId, agentId, endpoint, endpointType, useResponsesApi } = useDragDropContext();
|
||||
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(conversationId ?? ''));
|
||||
const { fileSearchAllowedByAgent, codeAllowedByAgent, provider } = useAgentToolPermissions(
|
||||
agentId,
|
||||
|
|
@ -55,15 +56,28 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
|
|||
|
||||
const options = useMemo(() => {
|
||||
const _options: FileOption[] = [];
|
||||
const currentProvider = provider || endpoint;
|
||||
let currentProvider = provider || endpoint;
|
||||
|
||||
// This will be removed in a future PR to formally normalize Providers comparisons to be case insensitive
|
||||
if (currentProvider?.toLowerCase() === Providers.OPENROUTER) {
|
||||
currentProvider = Providers.OPENROUTER;
|
||||
}
|
||||
|
||||
/** Helper to get inferred MIME type for a file */
|
||||
const getFileType = (file: File) => inferMimeType(file.name, file.type);
|
||||
|
||||
const isAzureWithResponsesApi =
|
||||
currentProvider === EModelEndpoint.azureOpenAI && useResponsesApi;
|
||||
|
||||
// Check if provider supports document upload
|
||||
if (isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider)) {
|
||||
const isGoogleProvider = currentProvider === EModelEndpoint.google;
|
||||
const validFileTypes = isGoogleProvider
|
||||
if (
|
||||
isDocumentSupportedProvider(endpointType) ||
|
||||
isDocumentSupportedProvider(currentProvider) ||
|
||||
isAzureWithResponsesApi
|
||||
) {
|
||||
const supportsImageDocVideoAudio =
|
||||
currentProvider === EModelEndpoint.google || currentProvider === Providers.OPENROUTER;
|
||||
const validFileTypes = supportsImageDocVideoAudio
|
||||
? files.every((file) => {
|
||||
const type = getFileType(file);
|
||||
return (
|
||||
|
|
@ -123,6 +137,7 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
|
|||
endpoint,
|
||||
endpointType,
|
||||
capabilities,
|
||||
useResponsesApi,
|
||||
codeAllowedByAgent,
|
||||
fileSearchAllowedByAgent,
|
||||
]);
|
||||
|
|
|
|||
|
|
@ -5,7 +5,15 @@ import { useGetFiles } from '~/data-provider';
|
|||
import { DataTable, columns } from './Table';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
export default function Files({ open, onOpenChange }) {
|
||||
export function MyFilesModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
triggerRef,
|
||||
}: {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
triggerRef?: React.RefObject<HTMLButtonElement | HTMLDivElement | null>;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const { data: files = [] } = useGetFiles<TFile[]>({
|
||||
|
|
@ -18,7 +26,7 @@ export default function Files({ open, onOpenChange }) {
|
|||
});
|
||||
|
||||
return (
|
||||
<OGDialog open={open} onOpenChange={onOpenChange}>
|
||||
<OGDialog open={open} onOpenChange={onOpenChange} triggerRef={triggerRef}>
|
||||
<OGDialogContent
|
||||
title={localize('com_nav_my_files')}
|
||||
className="w-11/12 bg-background text-text-primary shadow-2xl"
|
||||
|
|
@ -92,7 +92,8 @@ export const columns: ColumnDef<TFile>[] = [
|
|||
className="px-2 py-0 text-xs hover:bg-surface-hover sm:px-2 sm:py-2 sm:text-sm"
|
||||
onClick={() => column.toggleSorting(column.getIsSorted() === 'asc')}
|
||||
aria-sort={ariaSort}
|
||||
aria-label={localize('com_ui_name_sort')} aria-hidden="true"
|
||||
aria-label={localize('com_ui_name_sort')}
|
||||
aria-hidden="true"
|
||||
aria-current={sortState ? 'true' : 'false'}
|
||||
>
|
||||
{localize('com_ui_name')}
|
||||
|
|
@ -150,7 +151,8 @@ export const columns: ColumnDef<TFile>[] = [
|
|||
onClick={() => column.toggleSorting(column.getIsSorted() === 'asc')}
|
||||
className="px-2 py-0 text-xs hover:bg-surface-hover sm:px-2 sm:py-2 sm:text-sm"
|
||||
aria-sort={ariaSort}
|
||||
aria-label={localize('com_ui_date_sort')} aria-hidden="true"
|
||||
aria-label={localize('com_ui_date_sort')}
|
||||
aria-hidden="true"
|
||||
aria-current={sortState ? 'true' : 'false'}
|
||||
>
|
||||
{localize('com_ui_date')}
|
||||
|
|
@ -268,7 +270,8 @@ export const columns: ColumnDef<TFile>[] = [
|
|||
className="px-2 py-0 text-xs hover:bg-surface-hover sm:px-2 sm:py-2 sm:text-sm"
|
||||
onClick={() => column.toggleSorting(column.getIsSorted() === 'asc')}
|
||||
aria-sort={ariaSort}
|
||||
aria-label={localize('com_ui_size_sort')} aria-hidden="true"
|
||||
aria-label={localize('com_ui_size_sort')}
|
||||
aria-hidden="true"
|
||||
aria-current={sortState ? 'true' : 'false'}
|
||||
>
|
||||
{localize('com_ui_size')}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import { useState } from 'react';
|
||||
import { Search } from 'lucide-react';
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
import {
|
||||
flexRender,
|
||||
|
|
@ -17,7 +16,6 @@ import type {
|
|||
} from '@tanstack/react-table';
|
||||
import { FileContext } from 'librechat-data-provider';
|
||||
import {
|
||||
Input,
|
||||
Table,
|
||||
Button,
|
||||
Spinner,
|
||||
|
|
@ -26,6 +24,7 @@ import {
|
|||
TableCell,
|
||||
TableHead,
|
||||
TrashIcon,
|
||||
FilterInput,
|
||||
TableHeader,
|
||||
useMediaQuery,
|
||||
} from '@librechat/client';
|
||||
|
|
@ -115,23 +114,13 @@ export default function DataTable<TData, TValue>({ columns, data }: DataTablePro
|
|||
)}
|
||||
{!isSmallScreen && <span className="ml-2">{localize('com_ui_delete')}</span>}
|
||||
</Button>
|
||||
<div className="relative flex-1">
|
||||
<Search className="absolute left-3 top-1/2 z-10 h-4 w-4 -translate-y-1/2 text-text-secondary" />
|
||||
<Input
|
||||
id="files-filter"
|
||||
placeholder=" "
|
||||
value={(table.getColumn('filename')?.getFilterValue() as string | undefined) ?? ''}
|
||||
onChange={(event) => table.getColumn('filename')?.setFilterValue(event.target.value)}
|
||||
className="peer w-full pl-10 text-sm focus-visible:ring-2 focus-visible:ring-ring"
|
||||
aria-label={localize('com_files_filter_input')}
|
||||
/>
|
||||
<label
|
||||
htmlFor="files-filter"
|
||||
className="pointer-events-none absolute left-10 top-1/2 -translate-y-1/2 text-sm text-text-secondary transition-all duration-200 peer-focus:top-0 peer-focus:bg-background peer-focus:px-1 peer-focus:text-xs peer-[:not(:placeholder-shown)]:top-0 peer-[:not(:placeholder-shown)]:bg-background peer-[:not(:placeholder-shown)]:px-1 peer-[:not(:placeholder-shown)]:text-xs"
|
||||
>
|
||||
{localize('com_files_filter')}
|
||||
</label>
|
||||
</div>
|
||||
<FilterInput
|
||||
inputId="files-filter"
|
||||
label={localize('com_files_filter')}
|
||||
value={(table.getColumn('filename')?.getFilterValue() as string | undefined) ?? ''}
|
||||
onChange={(event) => table.getColumn('filename')?.setFilterValue(event.target.value)}
|
||||
containerClassName="flex-1"
|
||||
/>
|
||||
<div className="relative focus-within:z-[100]">
|
||||
<ColumnVisibilityDropdown
|
||||
table={table}
|
||||
|
|
|
|||
|
|
@ -33,12 +33,12 @@ export function SortFilterHeader<TData, TValue>({
|
|||
{
|
||||
label: localize('com_ui_ascending'),
|
||||
onClick: () => column.toggleSorting(false),
|
||||
icon: <ArrowUpIcon className="h-3.5 w-3.5 text-muted-foreground/70" />,
|
||||
icon: <ArrowUpIcon className="icon-sm text-text-secondary" />,
|
||||
},
|
||||
{
|
||||
label: localize('com_ui_descending'),
|
||||
onClick: () => column.toggleSorting(true),
|
||||
icon: <ArrowDownIcon className="h-3.5 w-3.5 text-muted-foreground/70" />,
|
||||
icon: <ArrowDownIcon className="icon-sm text-text-secondary" />,
|
||||
},
|
||||
];
|
||||
|
||||
|
|
@ -56,9 +56,7 @@ export function SortFilterHeader<TData, TValue>({
|
|||
items.push({
|
||||
label: filterValue,
|
||||
onClick: () => column.setFilterValue(value),
|
||||
icon: (
|
||||
<ListFilter className="h-3.5 w-3.5 text-muted-foreground/70" aria-hidden="true" />
|
||||
),
|
||||
icon: <ListFilter className="icon-sm text-text-secondary" aria-hidden="true" />,
|
||||
show: true,
|
||||
className: isActive ? 'border-l-2 border-l-border-xheavy' : '',
|
||||
});
|
||||
|
|
@ -70,7 +68,7 @@ export function SortFilterHeader<TData, TValue>({
|
|||
items.push({
|
||||
label: localize('com_ui_show_all'),
|
||||
onClick: () => column.setFilterValue(undefined),
|
||||
icon: <FilterX className="h-3.5 w-3.5 text-muted-foreground/70" />,
|
||||
icon: <FilterX className="icon-sm text-text-secondary" />,
|
||||
show: true,
|
||||
});
|
||||
}
|
||||
|
|
@ -113,9 +111,9 @@ export function SortFilterHeader<TData, TValue>({
|
|||
>
|
||||
<span>{title}</span>
|
||||
{column.getIsFiltered() ? (
|
||||
<ListFilter className="icon-sm text-muted-foreground/70" aria-hidden="true" />
|
||||
<ListFilter className="icon-sm" aria-hidden="true" />
|
||||
) : (
|
||||
<ListFilter className="icon-sm opacity-30" aria-hidden="true" />
|
||||
<ListFilter className="icon-sm text-text-secondary" aria-hidden="true" />
|
||||
)}
|
||||
{(() => {
|
||||
const sortState = column.getIsSorted();
|
||||
|
|
|
|||
|
|
@ -278,7 +278,6 @@ describe('AttachFileMenu', () => {
|
|||
{ name: 'OpenAI', endpoint: EModelEndpoint.openAI },
|
||||
{ name: 'Anthropic', endpoint: EModelEndpoint.anthropic },
|
||||
{ name: 'Google', endpoint: EModelEndpoint.google },
|
||||
{ name: 'Azure OpenAI', endpoint: EModelEndpoint.azureOpenAI },
|
||||
{ name: 'Custom', endpoint: EModelEndpoint.custom },
|
||||
];
|
||||
|
||||
|
|
@ -301,6 +300,45 @@ describe('AttachFileMenu', () => {
|
|||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it('should show Upload to Provider for Azure OpenAI with useResponsesApi', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.azureOpenAI,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.azureOpenAI,
|
||||
endpointType: EModelEndpoint.azureOpenAI,
|
||||
useResponsesApi: true,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should NOT show Upload to Provider for Azure OpenAI without useResponsesApi', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.azureOpenAI,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.azureOpenAI,
|
||||
endpointType: EModelEndpoint.azureOpenAI,
|
||||
useResponsesApi: false,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('Upload Image')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Capabilities', () => {
|
||||
|
|
@ -512,7 +550,7 @@ describe('AttachFileMenu', () => {
|
|||
});
|
||||
|
||||
describe('Google Provider Special Case', () => {
|
||||
it('should use google_multimodal file type for Google provider', () => {
|
||||
it('should use image_document_video_audio file type for Google provider', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
|
|
@ -536,7 +574,7 @@ describe('AttachFileMenu', () => {
|
|||
// The file input should have been clicked (indirectly tested through the implementation)
|
||||
});
|
||||
|
||||
it('should use multimodal file type for non-Google providers', () => {
|
||||
it('should use image_document file type for non-Google providers', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
|
|
@ -555,7 +593,7 @@ describe('AttachFileMenu', () => {
|
|||
expect(uploadProviderButton).toBeInTheDocument();
|
||||
fireEvent.click(uploadProviderButton);
|
||||
|
||||
// Implementation detail - multimodal type is used
|
||||
// Implementation detail - image_document type is used
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -63,7 +63,6 @@ describe('DragDropModal - Provider Detection', () => {
|
|||
{ name: 'OpenAI', value: EModelEndpoint.openAI },
|
||||
{ name: 'Anthropic', value: EModelEndpoint.anthropic },
|
||||
{ name: 'Google', value: EModelEndpoint.google },
|
||||
{ name: 'Azure OpenAI', value: EModelEndpoint.azureOpenAI },
|
||||
{ name: 'Custom', value: EModelEndpoint.custom },
|
||||
];
|
||||
|
||||
|
|
@ -72,6 +71,10 @@ describe('DragDropModal - Provider Detection', () => {
|
|||
expect(isDocumentSupportedProvider(value)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should NOT recognize Azure OpenAI as supported (requires useResponsesApi)', () => {
|
||||
expect(isDocumentSupportedProvider(EModelEndpoint.azureOpenAI)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('real-world scenarios', () => {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import React, { memo, useCallback } from 'react';
|
||||
import React, { memo, useMemo, useCallback, useRef } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
import { PermissionTypes, Permissions } from 'librechat-data-provider';
|
||||
import { MultiSelect, MCPIcon } from '@librechat/client';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import { TooltipAnchor } from '@librechat/client';
|
||||
import MCPServerMenuItem from '~/components/MCP/MCPServerMenuItem';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import StackedMCPIcons from '~/components/MCP/StackedMCPIcons';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { useHasAccess } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
|
@ -13,96 +16,117 @@ function MCPSelectContent() {
|
|||
localize,
|
||||
isPinned,
|
||||
mcpValues,
|
||||
isInitializing,
|
||||
placeholderText,
|
||||
batchToggleServers,
|
||||
getConfigDialogProps,
|
||||
getServerStatusIconProps,
|
||||
selectableServers,
|
||||
connectionStatus,
|
||||
isInitializing,
|
||||
getConfigDialogProps,
|
||||
toggleServerSelection,
|
||||
getServerStatusIconProps,
|
||||
} = mcpServerManager;
|
||||
|
||||
const renderSelectedValues = useCallback(
|
||||
(
|
||||
values: string[],
|
||||
placeholder?: string,
|
||||
items?: (string | { label: string; value: string })[],
|
||||
) => {
|
||||
if (values.length === 0) {
|
||||
return placeholder || localize('com_ui_select_placeholder');
|
||||
}
|
||||
if (values.length === 1) {
|
||||
const selectedItem = items?.find((i) => typeof i !== 'string' && i.value == values[0]);
|
||||
return selectedItem && typeof selectedItem !== 'string' ? selectedItem.label : values[0];
|
||||
}
|
||||
return localize('com_ui_x_selected', { 0: values.length });
|
||||
const menuStore = Ariakit.useMenuStore({ focusLoop: true });
|
||||
const isOpen = menuStore.useState('open');
|
||||
const focusedElementRef = useRef<HTMLElement | null>(null);
|
||||
|
||||
const selectedCount = mcpValues?.length ?? 0;
|
||||
|
||||
// Wrap toggleServerSelection to preserve focus after state update
|
||||
const handleToggle = useCallback(
|
||||
(serverName: string) => {
|
||||
// Save currently focused element
|
||||
focusedElementRef.current = document.activeElement as HTMLElement;
|
||||
toggleServerSelection(serverName);
|
||||
// Restore focus after React re-renders
|
||||
requestAnimationFrame(() => {
|
||||
focusedElementRef.current?.focus();
|
||||
});
|
||||
},
|
||||
[localize],
|
||||
[toggleServerSelection],
|
||||
);
|
||||
|
||||
const renderItemContent = useCallback(
|
||||
(serverName: string, defaultContent: React.ReactNode) => {
|
||||
const statusIconProps = getServerStatusIconProps(serverName);
|
||||
const isServerInitializing = isInitializing(serverName);
|
||||
const selectedServers = useMemo(() => {
|
||||
if (!mcpValues || mcpValues.length === 0) {
|
||||
return [];
|
||||
}
|
||||
return selectableServers.filter((s) => mcpValues.includes(s.serverName));
|
||||
}, [selectableServers, mcpValues]);
|
||||
|
||||
/**
|
||||
Common wrapper for the main content (check mark + text).
|
||||
Ensures Check & Text are adjacent and the group takes available space.
|
||||
*/
|
||||
const mainContentWrapper = (
|
||||
<button
|
||||
type="button"
|
||||
className={`flex flex-grow items-center rounded bg-transparent p-0 text-left transition-colors focus:outline-none ${
|
||||
isServerInitializing ? 'opacity-50' : ''
|
||||
}`}
|
||||
tabIndex={0}
|
||||
disabled={isServerInitializing}
|
||||
>
|
||||
{defaultContent}
|
||||
</button>
|
||||
);
|
||||
|
||||
const statusIcon = statusIconProps && <MCPServerStatusIcon {...statusIconProps} />;
|
||||
|
||||
if (statusIcon) {
|
||||
return (
|
||||
<div className="flex w-full items-center justify-between">
|
||||
{mainContentWrapper}
|
||||
<div className="ml-2 flex items-center">{statusIcon}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return mainContentWrapper;
|
||||
},
|
||||
[getServerStatusIconProps, isInitializing],
|
||||
);
|
||||
const displayText = useMemo(() => {
|
||||
if (selectedCount === 0) {
|
||||
return null;
|
||||
}
|
||||
if (selectedCount === 1) {
|
||||
const server = selectableServers.find((s) => s.serverName === mcpValues?.[0]);
|
||||
return server?.config?.title || mcpValues?.[0];
|
||||
}
|
||||
return localize('com_ui_x_selected', { 0: selectedCount });
|
||||
}, [selectedCount, selectableServers, mcpValues, localize]);
|
||||
|
||||
if (!isPinned && mcpValues?.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const configDialogProps = getConfigDialogProps();
|
||||
|
||||
return (
|
||||
<>
|
||||
<MultiSelect
|
||||
items={selectableServers.map((s) => ({
|
||||
label: s.config.title || s.serverName,
|
||||
value: s.serverName,
|
||||
}))}
|
||||
selectedValues={mcpValues ?? []}
|
||||
setSelectedValues={batchToggleServers}
|
||||
renderSelectedValues={renderSelectedValues}
|
||||
renderItemContent={renderItemContent}
|
||||
placeholder={placeholderText}
|
||||
popoverClassName="min-w-fit"
|
||||
className="badge-icon min-w-fit"
|
||||
selectIcon={<MCPIcon className="icon-md text-text-primary" />}
|
||||
selectItemsClassName="border border-blue-600/50 bg-blue-500/10 hover:bg-blue-700/10"
|
||||
selectClassName={cn(
|
||||
'group relative inline-flex items-center justify-center md:justify-start gap-1.5 rounded-full border border-border-medium text-sm font-medium transition-all',
|
||||
'md:w-full size-9 p-2 md:p-3 bg-transparent shadow-sm hover:bg-surface-hover hover:shadow-md active:shadow-inner',
|
||||
)}
|
||||
/>
|
||||
<Ariakit.MenuProvider store={menuStore}>
|
||||
<TooltipAnchor
|
||||
description={placeholderText}
|
||||
disabled={isOpen}
|
||||
render={
|
||||
<Ariakit.MenuButton
|
||||
className={cn(
|
||||
'group relative inline-flex items-center justify-center gap-1.5',
|
||||
'border border-border-medium text-sm font-medium transition-all',
|
||||
'h-9 min-w-9 rounded-full bg-transparent px-2.5 shadow-sm',
|
||||
'hover:bg-surface-hover hover:shadow-md active:shadow-inner',
|
||||
'md:w-fit md:justify-start md:px-3',
|
||||
isOpen && 'bg-surface-hover',
|
||||
)}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<StackedMCPIcons selectedServers={selectedServers} maxIcons={3} iconSize="sm" />
|
||||
<span className="hidden truncate text-text-primary md:block">
|
||||
{displayText || placeholderText}
|
||||
</span>
|
||||
<ChevronDown
|
||||
className={cn(
|
||||
'hidden h-3 w-3 text-text-secondary transition-transform md:block',
|
||||
isOpen && 'rotate-180',
|
||||
)}
|
||||
/>
|
||||
</TooltipAnchor>
|
||||
|
||||
<Ariakit.Menu
|
||||
portal={true}
|
||||
gutter={8}
|
||||
aria-label={localize('com_ui_mcp_servers')}
|
||||
className={cn(
|
||||
'z-50 flex min-w-[260px] max-w-[320px] flex-col rounded-xl',
|
||||
'border border-border-light bg-presentation p-1.5 shadow-lg',
|
||||
'origin-top opacity-0 transition-[opacity,transform] duration-200 ease-out',
|
||||
'data-[enter]:scale-100 data-[enter]:opacity-100',
|
||||
'scale-95 data-[leave]:scale-95 data-[leave]:opacity-0',
|
||||
)}
|
||||
>
|
||||
<div className="flex max-h-[320px] flex-col gap-1 overflow-y-auto">
|
||||
{selectableServers.map((server) => (
|
||||
<MCPServerMenuItem
|
||||
key={server.serverName}
|
||||
server={server}
|
||||
isSelected={mcpValues?.includes(server.serverName) ?? false}
|
||||
connectionStatus={connectionStatus}
|
||||
isInitializing={isInitializing}
|
||||
statusIconProps={getServerStatusIconProps(server.serverName)}
|
||||
onToggle={handleToggle}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</Ariakit.Menu>
|
||||
</Ariakit.MenuProvider>
|
||||
{configDialogProps && (
|
||||
<MCPConfigDialog {...configDialogProps} conversationId={conversationId} />
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import React from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ChevronRight } from 'lucide-react';
|
||||
import { PinIcon, MCPIcon } from '@librechat/client';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import { MCPIcon, PinIcon } from '@librechat/client';
|
||||
import MCPServerMenuItem from '~/components/MCP/MCPServerMenuItem';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface MCPSubMenuProps {
|
||||
|
|
@ -13,14 +14,16 @@ interface MCPSubMenuProps {
|
|||
|
||||
const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
|
||||
({ placeholder, ...props }, ref) => {
|
||||
const localize = useLocalize();
|
||||
const { mcpServerManager } = useBadgeRowContext();
|
||||
const {
|
||||
isPinned,
|
||||
mcpValues,
|
||||
setIsPinned,
|
||||
isInitializing,
|
||||
placeholderText,
|
||||
availableMCPServers,
|
||||
selectableServers,
|
||||
connectionStatus,
|
||||
isInitializing,
|
||||
getConfigDialogProps,
|
||||
toggleServerSelection,
|
||||
getServerStatusIconProps,
|
||||
|
|
@ -33,7 +36,7 @@ const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
|
|||
});
|
||||
|
||||
// Don't render if no MCP servers are configured
|
||||
if (!availableMCPServers || availableMCPServers.length === 0) {
|
||||
if (!selectableServers || selectableServers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
@ -44,6 +47,7 @@ const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
|
|||
<Ariakit.MenuProvider store={menuStore}>
|
||||
<Ariakit.MenuItem
|
||||
{...props}
|
||||
hideOnClick={false}
|
||||
render={
|
||||
<Ariakit.MenuButton
|
||||
onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
|
|
@ -55,9 +59,9 @@ const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
|
|||
}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<MCPIcon className="icon-md" />
|
||||
<MCPIcon className="h-5 w-5 flex-shrink-0 text-text-primary" aria-hidden="true" />
|
||||
<span>{placeholder || placeholderText}</span>
|
||||
<ChevronRight className="ml-auto h-3 w-3" />
|
||||
<ChevronRight className="h-3 w-3 flex-shrink-0" aria-hidden="true" />
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
|
|
@ -70,55 +74,36 @@ const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
|
|||
'hover:bg-surface-tertiary hover:shadow-sm',
|
||||
!isPinned && 'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isPinned ? 'Unpin' : 'Pin'}
|
||||
aria-label={isPinned ? localize('com_ui_unpin') : localize('com_ui_pin')}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isPinned} />
|
||||
</div>
|
||||
</button>
|
||||
</Ariakit.MenuItem>
|
||||
|
||||
<Ariakit.Menu
|
||||
portal={true}
|
||||
unmountOnHide={true}
|
||||
aria-label={localize('com_ui_mcp_servers')}
|
||||
className={cn(
|
||||
'animate-popover-left z-50 ml-3 flex min-w-[200px] flex-col rounded-xl',
|
||||
'border border-border-light bg-surface-secondary p-1 shadow-lg',
|
||||
'animate-popover-left z-50 ml-3 flex min-w-[260px] max-w-[320px] flex-col rounded-xl',
|
||||
'border border-border-light bg-presentation p-1.5 shadow-lg',
|
||||
)}
|
||||
>
|
||||
{availableMCPServers.map((s) => {
|
||||
const statusIconProps = getServerStatusIconProps(s.serverName);
|
||||
const isSelected = mcpValues?.includes(s.serverName) ?? false;
|
||||
const isServerInitializing = isInitializing(s.serverName);
|
||||
|
||||
const statusIcon = statusIconProps && <MCPServerStatusIcon {...statusIconProps} />;
|
||||
|
||||
return (
|
||||
<Ariakit.MenuItem
|
||||
key={s.serverName}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
toggleServerSelection(s.serverName);
|
||||
}}
|
||||
disabled={isServerInitializing}
|
||||
className={cn(
|
||||
'flex items-center gap-2 rounded-lg px-2 py-1.5 text-text-primary hover:cursor-pointer',
|
||||
'scroll-m-1 outline-none transition-colors',
|
||||
'hover:bg-black/[0.075] dark:hover:bg-white/10',
|
||||
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
|
||||
'w-full min-w-0 justify-between text-sm',
|
||||
isServerInitializing &&
|
||||
'opacity-50 hover:bg-transparent dark:hover:bg-transparent',
|
||||
isSelected && 'bg-surface-active',
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-grow items-center gap-2">
|
||||
<Ariakit.MenuItemCheck checked={isSelected} />
|
||||
<span>{s.config.title || s.serverName}</span>
|
||||
</div>
|
||||
{statusIcon && <div className="ml-2 flex items-center">{statusIcon}</div>}
|
||||
</Ariakit.MenuItem>
|
||||
);
|
||||
})}
|
||||
<div className="flex max-h-[320px] flex-col gap-1 overflow-y-auto">
|
||||
{selectableServers.map((server) => (
|
||||
<MCPServerMenuItem
|
||||
key={server.serverName}
|
||||
server={server}
|
||||
isSelected={mcpValues?.includes(server.serverName) ?? false}
|
||||
connectionStatus={connectionStatus}
|
||||
isInitializing={isInitializing}
|
||||
statusIconProps={getServerStatusIconProps(server.serverName)}
|
||||
onToggle={toggleServerSelection}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</Ariakit.Menu>
|
||||
</Ariakit.MenuProvider>
|
||||
{configDialogProps && <MCPConfigDialog {...configDialogProps} />}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ export default function StreamAudio({ index = 0 }) {
|
|||
const { pauseGlobalAudio } = usePauseGlobalAudio();
|
||||
|
||||
const { conversationId: paramId } = useParams();
|
||||
const queryParam = paramId === 'new' ? paramId : latestMessage?.conversationId ?? paramId ?? '';
|
||||
const queryParam = paramId === 'new' ? paramId : (latestMessage?.conversationId ?? paramId ?? '');
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const getMessages = useCallback(
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue