Compare commits

..

No commits in common. "main" and "chart-1.9.9" have entirely different histories.

309 changed files with 7841 additions and 26086 deletions

View file

@ -677,8 +677,7 @@ AZURE_CONTAINER_NAME=files
#========================# #========================#
ALLOW_SHARED_LINKS=true ALLOW_SHARED_LINKS=true
# Allows unauthenticated access to shared links. Defaults to false (auth required) if not set. ALLOW_SHARED_LINKS_PUBLIC=true
ALLOW_SHARED_LINKS_PUBLIC=false
#==============================# #==============================#
# Static File Cache Control # # Static File Cache Control #
@ -850,24 +849,3 @@ OPENWEATHER_API_KEY=
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it) # Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration # When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
# MCP_SKIP_CODE_CHALLENGE_CHECK=false # MCP_SKIP_CODE_CHALLENGE_CHECK=false
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
# MCP_CB_MAX_CYCLES=7
# Circuit breaker: sliding window (ms) for counting cycles
# MCP_CB_CYCLE_WINDOW_MS=45000
# Circuit breaker: cooldown (ms) after the cycle breaker trips
# MCP_CB_CYCLE_COOLDOWN_MS=15000
# Circuit breaker: max consecutive failed connection rounds before backoff
# MCP_CB_MAX_FAILED_ROUNDS=3
# Circuit breaker: sliding window (ms) for counting failed rounds
# MCP_CB_FAILED_WINDOW_MS=120000
# Circuit breaker: base backoff (ms) after failed round threshold is reached
# MCP_CB_BASE_BACKOFF_MS=30000
# Circuit breaker: max backoff cap (ms) for exponential backoff
# MCP_CB_MAX_BACKOFF_MS=300000

View file

@ -9,159 +9,11 @@ on:
paths: paths:
- 'api/**' - 'api/**'
- 'packages/**' - 'packages/**'
env:
NODE_ENV: CI
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
jobs: jobs:
build: tests_Backend:
name: Build packages name: Run Backend unit tests
timeout-minutes: 60
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Restore data-provider build cache
id: cache-data-provider
uses: actions/cache@v4
with:
path: packages/data-provider/dist
key: build-data-provider-${{ runner.os }}-${{ hashFiles('packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
- name: Build data-provider
if: steps.cache-data-provider.outputs.cache-hit != 'true'
run: npm run build:data-provider
- name: Restore data-schemas build cache
id: cache-data-schemas
uses: actions/cache@v4
with:
path: packages/data-schemas/dist
key: build-data-schemas-${{ runner.os }}-${{ hashFiles('packages/data-schemas/src/**', 'packages/data-schemas/tsconfig*.json', 'packages/data-schemas/rollup.config.js', 'packages/data-schemas/package.json', 'packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
- name: Build data-schemas
if: steps.cache-data-schemas.outputs.cache-hit != 'true'
run: npm run build:data-schemas
- name: Restore api build cache
id: cache-api
uses: actions/cache@v4
with:
path: packages/api/dist
key: build-api-${{ runner.os }}-${{ hashFiles('packages/api/src/**', 'packages/api/tsconfig*.json', 'packages/api/server-rollup.config.js', 'packages/api/package.json', 'packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json', 'packages/data-schemas/src/**', 'packages/data-schemas/tsconfig*.json', 'packages/data-schemas/rollup.config.js', 'packages/data-schemas/package.json') }}
- name: Build api
if: steps.cache-api.outputs.cache-hit != 'true'
run: npm run build:api
- name: Upload data-provider build
uses: actions/upload-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
retention-days: 2
- name: Upload data-schemas build
uses: actions/upload-artifact@v4
with:
name: build-data-schemas
path: packages/data-schemas/dist
retention-days: 2
- name: Upload api build
uses: actions/upload-artifact@v4
with:
name: build-api
path: packages/api/dist
retention-days: 2
circular-deps:
name: Circular dependency checks
needs: build
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download data-schemas build
uses: actions/download-artifact@v4
with:
name: build-data-schemas
path: packages/data-schemas/dist
- name: Rebuild @librechat/api and check for circular dependencies
run: |
output=$(npm run build:api 2>&1)
echo "$output"
if echo "$output" | grep -q "Circular depend"; then
echo "Error: Circular dependency detected in @librechat/api!"
exit 1
fi
- name: Detect circular dependencies in rollup
working-directory: ./packages/data-provider
run: |
output=$(npm run rollup:api)
echo "$output"
if echo "$output" | grep -q "Circular dependency"; then
echo "Error: Circular dependency detected!"
exit 1
fi
test-api:
name: 'Tests: api'
needs: build
runs-on: ubuntu-latest
timeout-minutes: 15
env: env:
MONGO_URI: ${{ secrets.MONGO_URI }} MONGO_URI: ${{ secrets.MONGO_URI }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@ -171,187 +23,60 @@ jobs:
BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }} BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }}
BAN_DURATION: ${{ secrets.BAN_DURATION }} BAN_DURATION: ${{ secrets.BAN_DURATION }}
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }} BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
NODE_ENV: CI
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Use Node.js 20.x
- name: Use Node.js 20.19
uses: actions/setup-node@v4 uses: actions/setup-node@v4
with: with:
node-version: '20.19' node-version: 20
cache: 'npm'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies - name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci run: npm ci
- name: Download data-provider build - name: Install Data Provider Package
uses: actions/download-artifact@v4 run: npm run build:data-provider
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download data-schemas build - name: Install Data Schemas Package
uses: actions/download-artifact@v4 run: npm run build:data-schemas
with:
name: build-data-schemas
path: packages/data-schemas/dist
- name: Download api build - name: Build API Package & Detect Circular Dependencies
uses: actions/download-artifact@v4 run: |
with: output=$(npm run build:api 2>&1)
name: build-api echo "$output"
path: packages/api/dist if echo "$output" | grep -q "Circular depend"; then
echo "Error: Circular dependency detected in @librechat/api!"
exit 1
fi
- name: Create empty auth.json file - name: Create empty auth.json file
run: | run: |
mkdir -p api/data mkdir -p api/data
echo '{}' > api/data/auth.json echo '{}' > api/data/auth.json
- name: Check for Circular dependency in rollup
working-directory: ./packages/data-provider
run: |
output=$(npm run rollup:api)
echo "$output"
if echo "$output" | grep -q "Circular dependency"; then
echo "Error: Circular dependency detected!"
exit 1
fi
- name: Prepare .env.test file - name: Prepare .env.test file
run: cp api/test/.env.test.example api/test/.env.test run: cp api/test/.env.test.example api/test/.env.test
- name: Run unit tests - name: Run unit tests
run: cd api && npm run test:ci run: cd api && npm run test:ci
test-data-provider: - name: Run librechat-data-provider unit tests
name: 'Tests: data-provider'
needs: build
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Run unit tests
run: cd packages/data-provider && npm run test:ci run: cd packages/data-provider && npm run test:ci
test-data-schemas: - name: Run @librechat/data-schemas unit tests
name: 'Tests: data-schemas'
needs: build
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download data-schemas build
uses: actions/download-artifact@v4
with:
name: build-data-schemas
path: packages/data-schemas/dist
- name: Run unit tests
run: cd packages/data-schemas && npm run test:ci run: cd packages/data-schemas && npm run test:ci
test-packages-api: - name: Run @librechat/api unit tests
name: 'Tests: @librechat/api'
needs: build
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
api/node_modules
packages/api/node_modules
packages/data-provider/node_modules
packages/data-schemas/node_modules
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download data-schemas build
uses: actions/download-artifact@v4
with:
name: build-data-schemas
path: packages/data-schemas/dist
- name: Download api build
uses: actions/download-artifact@v4
with:
name: build-api
path: packages/api/dist
- name: Run unit tests
run: cd packages/api && npm run test:ci run: cd packages/api && npm run test:ci

View file

@ -11,200 +11,51 @@ on:
- 'client/**' - 'client/**'
- 'packages/data-provider/**' - 'packages/data-provider/**'
jobs:
tests_frontend_ubuntu:
name: Run frontend unit tests on Ubuntu
timeout-minutes: 60
runs-on: ubuntu-latest
env: env:
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}' NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
jobs:
build:
name: Build packages
runs-on: ubuntu-latest
timeout-minutes: 15
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Use Node.js 20.x
- name: Use Node.js 20.19
uses: actions/setup-node@v4 uses: actions/setup-node@v4
with: with:
node-version: '20.19' node-version: 20
cache: 'npm'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
client/node_modules
packages/client/node_modules
packages/data-provider/node_modules
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies - name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci run: npm ci
- name: Restore data-provider build cache - name: Build Client
id: cache-data-provider run: npm run frontend:ci
uses: actions/cache@v4
with:
path: packages/data-provider/dist
key: build-data-provider-${{ runner.os }}-${{ hashFiles('packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
- name: Build data-provider
if: steps.cache-data-provider.outputs.cache-hit != 'true'
run: npm run build:data-provider
- name: Restore client-package build cache
id: cache-client-package
uses: actions/cache@v4
with:
path: packages/client/dist
key: build-client-package-${{ runner.os }}-${{ hashFiles('packages/client/src/**', 'packages/client/tsconfig*.json', 'packages/client/rollup.config.js', 'packages/client/package.json', 'packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
- name: Build client-package
if: steps.cache-client-package.outputs.cache-hit != 'true'
run: npm run build:client-package
- name: Upload data-provider build
uses: actions/upload-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
retention-days: 2
- name: Upload client-package build
uses: actions/upload-artifact@v4
with:
name: build-client-package
path: packages/client/dist
retention-days: 2
test-ubuntu:
name: 'Tests: Ubuntu'
needs: build
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
client/node_modules
packages/client/node_modules
packages/data-provider/node_modules
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download client-package build
uses: actions/download-artifact@v4
with:
name: build-client-package
path: packages/client/dist
- name: Run unit tests - name: Run unit tests
run: npm run test:ci --verbose run: npm run test:ci --verbose
working-directory: client working-directory: client
test-windows: tests_frontend_windows:
name: 'Tests: Windows' name: Run frontend unit tests on Windows
needs: build timeout-minutes: 60
runs-on: windows-latest runs-on: windows-latest
timeout-minutes: 20 env:
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Use Node.js 20.x
- name: Use Node.js 20.19
uses: actions/setup-node@v4 uses: actions/setup-node@v4
with: with:
node-version: '20.19' node-version: 20
cache: 'npm'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
client/node_modules
packages/client/node_modules
packages/data-provider/node_modules
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies - name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci run: npm ci
- name: Download data-provider build - name: Build Client
uses: actions/download-artifact@v4 run: npm run frontend:ci
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download client-package build
uses: actions/download-artifact@v4
with:
name: build-client-package
path: packages/client/dist
- name: Run unit tests - name: Run unit tests
run: npm run test:ci --verbose run: npm run test:ci --verbose
working-directory: client working-directory: client
build-verify:
name: Vite build verification
needs: build
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v4
- name: Use Node.js 20.19
uses: actions/setup-node@v4
with:
node-version: '20.19'
- name: Restore node_modules cache
id: cache-node-modules
uses: actions/cache@v4
with:
path: |
node_modules
client/node_modules
packages/client/node_modules
packages/data-provider/node_modules
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
- name: Download data-provider build
uses: actions/download-artifact@v4
with:
name: build-data-provider
path: packages/data-provider/dist
- name: Download client-package build
uses: actions/download-artifact@v4
with:
name: build-client-package
path: packages/client/dist
- name: Build client
run: cd client && npm run build:ci

View file

@ -149,15 +149,7 @@ Multi-line imports count total character length across all lines. Consolidate va
- Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc. - Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc.
- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering. - Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering.
- Cover loading, success, and error states for UI/data flows. - Cover loading, success, and error states for UI/data flows.
- Mock data-provider hooks and external dependencies.
### Philosophy
- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort.
- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic.
- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls.
- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals.
- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls.
- Heavy mocking is a code smell, not a testing strategy.
--- ---

View file

@ -1,4 +1,4 @@
# v0.8.3 # v0.8.3-rc2
# Base node image # Base node image
FROM node:20-alpine AS node FROM node:20-alpine AS node

View file

@ -1,5 +1,5 @@
# Dockerfile.multi # Dockerfile.multi
# v0.8.3 # v0.8.3-rc2
# Set configurable max-old-space-size with default # Set configurable max-old-space-size with default
ARG NODE_MAX_OLD_SPACE_SIZE=6144 ARG NODE_MAX_OLD_SPACE_SIZE=6144

View file

@ -1,6 +1,7 @@
const DALLE3 = require('../DALLE3'); const DALLE3 = require('../DALLE3');
const { ProxyAgent } = require('undici'); const { ProxyAgent } = require('undici');
jest.mock('tiktoken');
const processFileURL = jest.fn(); const processFileURL = jest.fn();
describe('DALLE3 Proxy Configuration', () => { describe('DALLE3 Proxy Configuration', () => {

View file

@ -14,6 +14,15 @@ jest.mock('@librechat/data-schemas', () => {
}; };
}); });
jest.mock('tiktoken', () => {
return {
encoding_for_model: jest.fn().mockReturnValue({
encode: jest.fn(),
decode: jest.fn(),
}),
};
});
const processFileURL = jest.fn(); const processFileURL = jest.fn();
const generate = jest.fn(); const generate = jest.fn();

View file

@ -236,12 +236,8 @@ async function performSync(flowManager, flowId, flowType) {
const messageCount = messageProgress.totalDocuments; const messageCount = messageProgress.totalDocuments;
const messagesIndexed = messageProgress.totalProcessed; const messagesIndexed = messageProgress.totalProcessed;
const unindexedMessages = messageCount - messagesIndexed; const unindexedMessages = messageCount - messagesIndexed;
const noneIndexed = messagesIndexed === 0 && unindexedMessages > 0;
if (settingsUpdated || noneIndexed || unindexedMessages > syncThreshold) { if (settingsUpdated || unindexedMessages > syncThreshold) {
if (noneIndexed && !settingsUpdated) {
logger.info('[indexSync] No messages marked as indexed, forcing full sync');
}
logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`); logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`);
await Message.syncWithMeili(); await Message.syncWithMeili();
messagesSync = true; messagesSync = true;
@ -265,13 +261,9 @@ async function performSync(flowManager, flowId, flowType) {
const convoCount = convoProgress.totalDocuments; const convoCount = convoProgress.totalDocuments;
const convosIndexed = convoProgress.totalProcessed; const convosIndexed = convoProgress.totalProcessed;
const unindexedConvos = convoCount - convosIndexed;
const noneConvosIndexed = convosIndexed === 0 && unindexedConvos > 0;
if (settingsUpdated || noneConvosIndexed || unindexedConvos > syncThreshold) { const unindexedConvos = convoCount - convosIndexed;
if (noneConvosIndexed && !settingsUpdated) { if (settingsUpdated || unindexedConvos > syncThreshold) {
logger.info('[indexSync] No conversations marked as indexed, forcing full sync');
}
logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`); logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`);
await Conversation.syncWithMeili(); await Conversation.syncWithMeili();
convosSync = true; convosSync = true;

View file

@ -462,69 +462,4 @@ describe('performSync() - syncThreshold logic', () => {
); );
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)'); expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)');
}); });
test('forces sync when zero documents indexed (reset scenario) even if below threshold', async () => {
Message.getSyncProgress.mockResolvedValue({
totalProcessed: 0,
totalDocuments: 680,
isComplete: false,
});
Conversation.getSyncProgress.mockResolvedValue({
totalProcessed: 0,
totalDocuments: 76,
isComplete: false,
});
Message.syncWithMeili.mockResolvedValue(undefined);
Conversation.syncWithMeili.mockResolvedValue(undefined);
const indexSync = require('./indexSync');
await indexSync();
expect(Message.syncWithMeili).toHaveBeenCalledTimes(1);
expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] No messages marked as indexed, forcing full sync',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] Starting message sync (680 unindexed)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] No conversations marked as indexed, forcing full sync',
);
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (76 unindexed)');
});
test('does NOT force sync when some documents already indexed and below threshold', async () => {
Message.getSyncProgress.mockResolvedValue({
totalProcessed: 630,
totalDocuments: 680,
isComplete: false,
});
Conversation.getSyncProgress.mockResolvedValue({
totalProcessed: 70,
totalDocuments: 76,
isComplete: false,
});
const indexSync = require('./indexSync');
await indexSync();
expect(Message.syncWithMeili).not.toHaveBeenCalled();
expect(Conversation.syncWithMeili).not.toHaveBeenCalled();
expect(mockLogger.info).not.toHaveBeenCalledWith(
'[indexSync] No messages marked as indexed, forcing full sync',
);
expect(mockLogger.info).not.toHaveBeenCalledWith(
'[indexSync] No conversations marked as indexed, forcing full sync',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] 50 messages unindexed (below threshold: 1000, skipping)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] 6 convos unindexed (below threshold: 1000, skipping)',
);
});
}); });

View file

@ -3,13 +3,12 @@ module.exports = {
clearMocks: true, clearMocks: true,
roots: ['<rootDir>'], roots: ['<rootDir>'],
coverageDirectory: 'coverage', coverageDirectory: 'coverage',
maxWorkers: '50%',
testTimeout: 30000, // 30 seconds timeout for all tests testTimeout: 30000, // 30 seconds timeout for all tests
setupFiles: ['./test/jestSetup.js', './test/__mocks__/logger.js'], setupFiles: ['./test/jestSetup.js', './test/__mocks__/logger.js'],
moduleNameMapper: { moduleNameMapper: {
'~/(.*)': '<rootDir>/$1', '~/(.*)': '<rootDir>/$1',
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json', '~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js', '^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js', '^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
}, },
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'], transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],

View file

@ -4,7 +4,9 @@ const { Action } = require('~/db/models');
* Update an action with new data without overwriting existing properties, * Update an action with new data without overwriting existing properties,
* or create a new action if it doesn't exist. * or create a new action if it doesn't exist.
* *
* @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams * @param {Object} searchParams - The search parameters to find the action to update.
* @param {string} searchParams.action_id - The ID of the action to update.
* @param {string} searchParams.user - The user ID of the action's author.
* @param {Object} updateData - An object containing the properties to update. * @param {Object} updateData - An object containing the properties to update.
* @returns {Promise<Action>} The updated or newly created action document as a plain object. * @returns {Promise<Action>} The updated or newly created action document as a plain object.
*/ */
@ -45,8 +47,10 @@ const getActions = async (searchParams, includeSensitive = false) => {
/** /**
* Deletes an action by params. * Deletes an action by params.
* *
* @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams * @param {Object} searchParams - The search parameters to find the action to delete.
* @returns {Promise<Action|null>} The deleted action document as a plain object, or null if no match. * @param {string} searchParams.action_id - The ID of the action to delete.
* @param {string} searchParams.user - The user ID of the action's author.
* @returns {Promise<Action>} A promise that resolves to the deleted action document as a plain object, or null if no document was found.
*/ */
const deleteAction = async (searchParams) => { const deleteAction = async (searchParams) => {
return await Action.findOneAndDelete(searchParams).lean(); return await Action.findOneAndDelete(searchParams).lean();

View file

@ -1,250 +0,0 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { actionSchema } = require('@librechat/data-schemas');
const { updateAction, getActions, deleteAction } = require('./Action');
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
if (!mongoose.models.Action) {
mongoose.model('Action', actionSchema);
}
await mongoose.connect(mongoUri);
}, 20000);
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.models.Action.deleteMany({});
});
const userId = new mongoose.Types.ObjectId();
describe('Action ownership scoping', () => {
describe('updateAction', () => {
it('updates when action_id and agent_id both match', async () => {
await mongoose.models.Action.create({
user: userId,
action_id: 'act_1',
agent_id: 'agent_A',
metadata: { domain: 'example.com' },
});
const result = await updateAction(
{ action_id: 'act_1', agent_id: 'agent_A' },
{ metadata: { domain: 'updated.com' } },
);
expect(result).not.toBeNull();
expect(result.metadata.domain).toBe('updated.com');
expect(result.agent_id).toBe('agent_A');
});
it('does not update when agent_id does not match (creates a new doc via upsert)', async () => {
await mongoose.models.Action.create({
user: userId,
action_id: 'act_1',
agent_id: 'agent_B',
metadata: { domain: 'victim.com', api_key: 'secret' },
});
const result = await updateAction(
{ action_id: 'act_1', agent_id: 'agent_A' },
{ user: userId, metadata: { domain: 'attacker.com' } },
);
expect(result.metadata.domain).toBe('attacker.com');
const original = await mongoose.models.Action.findOne({
action_id: 'act_1',
agent_id: 'agent_B',
}).lean();
expect(original).not.toBeNull();
expect(original.metadata.domain).toBe('victim.com');
expect(original.metadata.api_key).toBe('secret');
});
it('updates when action_id and assistant_id both match', async () => {
await mongoose.models.Action.create({
user: userId,
action_id: 'act_2',
assistant_id: 'asst_X',
metadata: { domain: 'example.com' },
});
const result = await updateAction(
{ action_id: 'act_2', assistant_id: 'asst_X' },
{ metadata: { domain: 'updated.com' } },
);
expect(result).not.toBeNull();
expect(result.metadata.domain).toBe('updated.com');
});
it('does not overwrite when assistant_id does not match', async () => {
await mongoose.models.Action.create({
user: userId,
action_id: 'act_2',
assistant_id: 'asst_victim',
metadata: { domain: 'victim.com', api_key: 'secret' },
});
await updateAction(
{ action_id: 'act_2', assistant_id: 'asst_attacker' },
{ user: userId, metadata: { domain: 'attacker.com' } },
);
const original = await mongoose.models.Action.findOne({
action_id: 'act_2',
assistant_id: 'asst_victim',
}).lean();
expect(original).not.toBeNull();
expect(original.metadata.domain).toBe('victim.com');
expect(original.metadata.api_key).toBe('secret');
});
});
describe('deleteAction', () => {
  it('deletes when action_id and agent_id both match', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_del',
      agent_id: 'agent_A',
      metadata: { domain: 'example.com' },
    });

    const deleted = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' });

    // Matching scope: the deleted document is returned and the collection is empty.
    expect(deleted).not.toBeNull();
    expect(deleted.action_id).toBe('act_del');
    expect(await mongoose.models.Action.countDocuments()).toBe(0);
  });

  it('returns null and preserves the document when agent_id does not match', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_del',
      agent_id: 'agent_B',
      metadata: { domain: 'victim.com' },
    });

    // Mismatched agent scope: nothing is deleted.
    const deleted = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' });
    expect(deleted).toBeNull();
    expect(await mongoose.models.Action.countDocuments()).toBe(1);
  });

  it('deletes when action_id and assistant_id both match', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_del_asst',
      assistant_id: 'asst_X',
      metadata: { domain: 'example.com' },
    });

    const deleted = await deleteAction({ action_id: 'act_del_asst', assistant_id: 'asst_X' });
    expect(deleted).not.toBeNull();
    expect(await mongoose.models.Action.countDocuments()).toBe(0);
  });

  it('returns null and preserves the document when assistant_id does not match', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_del_asst',
      assistant_id: 'asst_victim',
      metadata: { domain: 'victim.com' },
    });

    // Mismatched assistant scope: nothing is deleted.
    const deleted = await deleteAction({
      action_id: 'act_del_asst',
      assistant_id: 'asst_attacker',
    });
    expect(deleted).toBeNull();
    expect(await mongoose.models.Action.countDocuments()).toBe(1);
  });
});
describe('getActions (unscoped baseline)', () => {
  it('returns actions by action_id regardless of agent_id', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_shared',
      agent_id: 'agent_B',
      metadata: { domain: 'example.com' },
    });

    const found = await getActions({ action_id: 'act_shared' }, true);

    expect(found).toHaveLength(1);
    expect(found[0].agent_id).toBe('agent_B');
  });

  it('returns actions scoped by agent_id when provided', async () => {
    // Seed two actions under different agents; only agent_A's should come back.
    const seed = (action_id, agent_id, domain) =>
      mongoose.models.Action.create({ user: userId, action_id, agent_id, metadata: { domain } });
    await seed('act_scoped', 'agent_A', 'a.com');
    await seed('act_other', 'agent_B', 'b.com');

    const found = await getActions({ agent_id: 'agent_A' });

    expect(found).toHaveLength(1);
    expect(found[0].action_id).toBe('act_scoped');
  });
});
describe('cross-type protection', () => {
  // An assistant-owned action must be invisible to agent-scoped mutations.
  it('updateAction with agent_id filter does not overwrite assistant-owned action', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_cross',
      assistant_id: 'asst_victim',
      metadata: { domain: 'victim.com', api_key: 'secret' },
    });

    await updateAction(
      { action_id: 'act_cross', agent_id: 'agent_attacker' },
      { user: userId, metadata: { domain: 'evil.com' } },
    );

    const untouched = await mongoose.models.Action.findOne({
      action_id: 'act_cross',
      assistant_id: 'asst_victim',
    }).lean();
    expect(untouched).not.toBeNull();
    expect(untouched.metadata.domain).toBe('victim.com');
    expect(untouched.metadata.api_key).toBe('secret');
  });

  it('deleteAction with agent_id filter does not delete assistant-owned action', async () => {
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_cross_del',
      assistant_id: 'asst_victim',
      metadata: { domain: 'victim.com' },
    });

    const deleted = await deleteAction({
      action_id: 'act_cross_del',
      agent_id: 'agent_attacker',
    });
    expect(deleted).toBeNull();
    expect(await mongoose.models.Action.countDocuments()).toBe(1);
  });
});
});

View file

@ -228,7 +228,7 @@ module.exports = {
}, },
], ],
}; };
} catch (_err) { } catch (err) {
logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning'); logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning');
} }
if (cursorFilter) { if (cursorFilter) {
@ -361,7 +361,6 @@ module.exports = {
const deleteMessagesResult = await deleteMessages({ const deleteMessagesResult = await deleteMessages({
conversationId: { $in: conversationIds }, conversationId: { $in: conversationIds },
user,
}); });
return { ...deleteConvoResult, messages: deleteMessagesResult }; return { ...deleteConvoResult, messages: deleteMessagesResult };

View file

@ -549,7 +549,6 @@ describe('Conversation Operations', () => {
expect(result.messages.deletedCount).toBe(5); expect(result.messages.deletedCount).toBe(5);
expect(deleteMessages).toHaveBeenCalledWith({ expect(deleteMessages).toHaveBeenCalledWith({
conversationId: { $in: [mockConversationData.conversationId] }, conversationId: { $in: [mockConversationData.conversationId] },
user: 'user123',
}); });
// Verify conversation was deleted // Verify conversation was deleted

View file

@ -152,11 +152,12 @@ describe('File Access Control', () => {
expect(accessMap.get(fileIds[3])).toBe(false); expect(accessMap.get(fileIds[3])).toBe(false);
}); });
it('should only grant author access to files attached to the agent', async () => { it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId(); const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4(); const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()]; const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create author user
await User.create({ await User.create({
_id: authorId, _id: authorId,
email: 'author@example.com', email: 'author@example.com',
@ -164,6 +165,7 @@ describe('File Access Control', () => {
provider: 'local', provider: 'local',
}); });
// Create agent
await createAgent({ await createAgent({
id: agentId, id: agentId,
name: 'Test Agent', name: 'Test Agent',
@ -172,83 +174,12 @@ describe('File Access Control', () => {
provider: 'openai', provider: 'openai',
tool_resources: { tool_resources: {
file_search: { file_search: {
file_ids: [fileIds[0]], file_ids: [fileIds[0]], // Only one file attached
},
},
});
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
const accessMap = await hasAccessToFilesViaAgent({
userId: authorId,
role: SystemRoles.USER,
fileIds,
agentId,
});
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(false);
expect(accessMap.get(fileIds[2])).toBe(false);
});
it('should deny all access when agent has no tool_resources', async () => {
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
const fileId = uuidv4();
await User.create({
_id: authorId,
email: 'author-no-resources@example.com',
emailVerified: true,
provider: 'local',
});
await createAgent({
id: agentId,
name: 'Bare Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
});
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
const accessMap = await hasAccessToFilesViaAgent({
userId: authorId,
role: SystemRoles.USER,
fileIds: [fileId],
agentId,
});
expect(accessMap.get(fileId)).toBe(false);
});
it('should grant access to files across multiple resource types', async () => {
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
await User.create({
_id: authorId,
email: 'author-multi@example.com',
emailVerified: true,
provider: 'local',
});
await createAgent({
id: agentId,
name: 'Multi Resource Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]],
},
execute_code: {
file_ids: [fileIds[1]],
}, },
}, },
}); });
// Check access as the author
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
const accessMap = await hasAccessToFilesViaAgent({ const accessMap = await hasAccessToFilesViaAgent({
userId: authorId, userId: authorId,
@ -257,48 +188,10 @@ describe('File Access Control', () => {
agentId, agentId,
}); });
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true); expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true); expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false); expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should grant author access to attached files when isDelete is true', async () => {
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
const attachedFileId = uuidv4();
const unattachedFileId = uuidv4();
await User.create({
_id: authorId,
email: 'author-delete@example.com',
emailVerified: true,
provider: 'local',
});
await createAgent({
id: agentId,
name: 'Delete Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [attachedFileId],
},
},
});
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
const accessMap = await hasAccessToFilesViaAgent({
userId: authorId,
role: SystemRoles.USER,
fileIds: [attachedFileId, unattachedFileId],
agentId,
isDelete: true,
});
expect(accessMap.get(attachedFileId)).toBe(true);
expect(accessMap.get(unattachedFileId)).toBe(false);
}); });
it('should handle non-existent agent gracefully', async () => { it('should handle non-existent agent gracefully', async () => {

View file

@ -48,14 +48,14 @@ const loadAddedAgent = async ({ req, conversation, primaryAgent }) => {
return null; return null;
} }
// If there's an agent_id, load the existing agent
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) { if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
let agent = req.resolvedAddedAgent;
if (!agent) {
if (!getAgent) { if (!getAgent) {
throw new Error('getAgent not initialized - call setGetAgent first'); throw new Error('getAgent not initialized - call setGetAgent first');
} }
agent = await getAgent({ id: conversation.agent_id }); const agent = await getAgent({
} id: conversation.agent_id,
});
if (!agent) { if (!agent) {
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`); logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);

View file

@ -4,18 +4,31 @@ const defaultRate = 6;
/** /**
* Token Pricing Configuration * Token Pricing Configuration
* *
* Pattern Matching * IMPORTANT: Key Ordering for Pattern Matching
* ================ * ============================================
* `findMatchingPattern` (from @librechat/api) uses `modelName.includes(key)` and selects * The `findMatchingPattern` function iterates through object keys in REVERSE order
* the LONGEST matching key. If a key's length equals the model name's length (exact match), * (last-defined keys are checked first) and uses `modelName.includes(key)` for matching.
* it returns immediately. Definition order does NOT affect correctness.
* *
* Key ordering matters only for: * This means:
* 1. Performance: list older/less common models first so newer/common models * 1. BASE PATTERNS must be defined FIRST (e.g., "kimi", "moonshot")
* are found earlier in the reverse scan. * 2. SPECIFIC PATTERNS must be defined AFTER their base patterns (e.g., "kimi-k2", "kimi-k2.5")
* 2. Same-length tie-breaking: the last-defined key wins on equal-length matches. *
* Example ordering for Kimi models:
* kimi: { prompt: 0.6, completion: 2.5 }, // Base pattern - checked last
* 'kimi-k2': { prompt: 0.6, completion: 2.5 }, // More specific - checked before "kimi"
* 'kimi-k2.5': { prompt: 0.6, completion: 3.0 }, // Most specific - checked first
*
* Why this matters:
* - Model name "kimi-k2.5" contains both "kimi" and "kimi-k2" as substrings
* - If "kimi" were checked first, it would incorrectly match and return wrong pricing
* - By defining specific patterns AFTER base patterns, they're checked first in reverse iteration
* *
* This applies to BOTH `tokenValues` and `cacheTokenValues` objects. * This applies to BOTH `tokenValues` and `cacheTokenValues` objects.
*
* When adding new model families:
* 1. Define the base/generic pattern first
* 2. Define increasingly specific patterns after
* 3. Ensure no pattern is a substring of another that should match differently
*/ */
/** /**
@ -138,9 +151,6 @@ const tokenValues = Object.assign(
'gpt-5.1': { prompt: 1.25, completion: 10 }, 'gpt-5.1': { prompt: 1.25, completion: 10 },
'gpt-5.2': { prompt: 1.75, completion: 14 }, 'gpt-5.2': { prompt: 1.75, completion: 14 },
'gpt-5.3': { prompt: 1.75, completion: 14 }, 'gpt-5.3': { prompt: 1.75, completion: 14 },
'gpt-5.4': { prompt: 2.5, completion: 15 },
// TODO: gpt-5.4-pro pricing not yet officially published — verify before release
'gpt-5.4-pro': { prompt: 5, completion: 30 },
'gpt-5-nano': { prompt: 0.05, completion: 0.4 }, 'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
'gpt-5-mini': { prompt: 0.25, completion: 2 }, 'gpt-5-mini': { prompt: 0.25, completion: 2 },
'gpt-5-pro': { prompt: 15, completion: 120 }, 'gpt-5-pro': { prompt: 15, completion: 120 },
@ -312,7 +322,7 @@ const cacheTokenValues = {
// gpt-4o (incl. mini), o1 (incl. mini/preview): 50% off // gpt-4o (incl. mini), o1 (incl. mini/preview): 50% off
// gpt-4.1 (incl. mini/nano), o3 (incl. mini), o4-mini: 75% off // gpt-4.1 (incl. mini/nano), o3 (incl. mini), o4-mini: 75% off
// gpt-5.x (excl. pro variants): 90% off // gpt-5.x (excl. pro variants): 90% off
// gpt-5-pro, gpt-5.2-pro, gpt-5.4-pro: no caching // gpt-5-pro, gpt-5.2-pro: no caching
'gpt-4o': { write: 2.5, read: 1.25 }, 'gpt-4o': { write: 2.5, read: 1.25 },
'gpt-4o-mini': { write: 0.15, read: 0.075 }, 'gpt-4o-mini': { write: 0.15, read: 0.075 },
'gpt-4.1': { write: 2, read: 0.5 }, 'gpt-4.1': { write: 2, read: 0.5 },
@ -322,7 +332,6 @@ const cacheTokenValues = {
'gpt-5.1': { write: 1.25, read: 0.125 }, 'gpt-5.1': { write: 1.25, read: 0.125 },
'gpt-5.2': { write: 1.75, read: 0.175 }, 'gpt-5.2': { write: 1.75, read: 0.175 },
'gpt-5.3': { write: 1.75, read: 0.175 }, 'gpt-5.3': { write: 1.75, read: 0.175 },
'gpt-5.4': { write: 2.5, read: 0.25 },
'gpt-5-mini': { write: 0.25, read: 0.025 }, 'gpt-5-mini': { write: 0.25, read: 0.025 },
'gpt-5-nano': { write: 0.05, read: 0.005 }, 'gpt-5-nano': { write: 0.05, read: 0.005 },
o1: { write: 15, read: 7.5 }, o1: { write: 15, read: 7.5 },

View file

@ -59,17 +59,6 @@ describe('getValueKey', () => {
expect(getValueKey('openai/gpt-5.3')).toBe('gpt-5.3'); expect(getValueKey('openai/gpt-5.3')).toBe('gpt-5.3');
}); });
it('should return "gpt-5.4" for model name containing "gpt-5.4"', () => {
expect(getValueKey('gpt-5.4')).toBe('gpt-5.4');
expect(getValueKey('gpt-5.4-thinking')).toBe('gpt-5.4');
expect(getValueKey('openai/gpt-5.4')).toBe('gpt-5.4');
});
it('should return "gpt-5.4-pro" for model name containing "gpt-5.4-pro"', () => {
expect(getValueKey('gpt-5.4-pro')).toBe('gpt-5.4-pro');
expect(getValueKey('openai/gpt-5.4-pro')).toBe('gpt-5.4-pro');
});
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => { it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106'); expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106');
expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106'); expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
@ -411,33 +400,6 @@ describe('getMultiplier', () => {
); );
}); });
it('should return the correct multiplier for gpt-5.4', () => {
expect(getMultiplier({ model: 'gpt-5.4', tokenType: 'prompt' })).toBe(
tokenValues['gpt-5.4'].prompt,
);
expect(getMultiplier({ model: 'gpt-5.4', tokenType: 'completion' })).toBe(
tokenValues['gpt-5.4'].completion,
);
expect(getMultiplier({ model: 'gpt-5.4-thinking', tokenType: 'prompt' })).toBe(
tokenValues['gpt-5.4'].prompt,
);
expect(getMultiplier({ model: 'openai/gpt-5.4', tokenType: 'completion' })).toBe(
tokenValues['gpt-5.4'].completion,
);
});
it('should return the correct multiplier for gpt-5.4-pro', () => {
expect(getMultiplier({ model: 'gpt-5.4-pro', tokenType: 'prompt' })).toBe(
tokenValues['gpt-5.4-pro'].prompt,
);
expect(getMultiplier({ model: 'gpt-5.4-pro', tokenType: 'completion' })).toBe(
tokenValues['gpt-5.4-pro'].completion,
);
expect(getMultiplier({ model: 'openai/gpt-5.4-pro', tokenType: 'prompt' })).toBe(
tokenValues['gpt-5.4-pro'].prompt,
);
});
it('should return the correct multiplier for gpt-4o', () => { it('should return the correct multiplier for gpt-4o', () => {
const valueKey = getValueKey('gpt-4o-2024-08-06'); const valueKey = getValueKey('gpt-4o-2024-08-06');
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt); expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
@ -1415,7 +1377,6 @@ describe('getCacheMultiplier', () => {
'gpt-5.1', 'gpt-5.1',
'gpt-5.2', 'gpt-5.2',
'gpt-5.3', 'gpt-5.3',
'gpt-5.4',
'gpt-5-mini', 'gpt-5-mini',
'gpt-5-nano', 'gpt-5-nano',
'o1', 'o1',
@ -1452,20 +1413,10 @@ describe('getCacheMultiplier', () => {
expect(getCacheMultiplier({ model: 'gpt-5-pro', cacheType: 'write' })).toBeNull(); expect(getCacheMultiplier({ model: 'gpt-5-pro', cacheType: 'write' })).toBeNull();
expect(getCacheMultiplier({ model: 'gpt-5.2-pro', cacheType: 'read' })).toBeNull(); expect(getCacheMultiplier({ model: 'gpt-5.2-pro', cacheType: 'read' })).toBeNull();
expect(getCacheMultiplier({ model: 'gpt-5.2-pro', cacheType: 'write' })).toBeNull(); expect(getCacheMultiplier({ model: 'gpt-5.2-pro', cacheType: 'write' })).toBeNull();
expect(getCacheMultiplier({ model: 'gpt-5.4-pro', cacheType: 'read' })).toBeNull();
expect(getCacheMultiplier({ model: 'gpt-5.4-pro', cacheType: 'write' })).toBeNull();
}); });
it('should have consistent 10% cache read pricing for gpt-5.x models', () => { it('should have consistent 10% cache read pricing for gpt-5.x models', () => {
const gpt5CacheModels = [ const gpt5CacheModels = ['gpt-5', 'gpt-5.1', 'gpt-5.2', 'gpt-5.3', 'gpt-5-mini', 'gpt-5-nano'];
'gpt-5',
'gpt-5.1',
'gpt-5.2',
'gpt-5.3',
'gpt-5.4',
'gpt-5-mini',
'gpt-5-nano',
];
for (const model of gpt5CacheModels) { for (const model of gpt5CacheModels) {
expect(cacheTokenValues[model].read).toBeCloseTo(cacheTokenValues[model].write * 0.1, 10); expect(cacheTokenValues[model].read).toBeCloseTo(cacheTokenValues[model].write * 0.1, 10);
} }

View file

@ -1,6 +1,6 @@
{ {
"name": "@librechat/backend", "name": "@librechat/backend",
"version": "v0.8.3", "version": "v0.8.3-rc2",
"description": "", "description": "",
"scripts": { "scripts": {
"start": "echo 'please run this from the root directory'", "start": "echo 'please run this from the root directory'",
@ -44,14 +44,13 @@
"@google/genai": "^1.19.0", "@google/genai": "^1.19.0",
"@keyv/redis": "^4.3.3", "@keyv/redis": "^4.3.3",
"@langchain/core": "^0.3.80", "@langchain/core": "^0.3.80",
"@librechat/agents": "^3.1.56", "@librechat/agents": "^3.1.55",
"@librechat/api": "*", "@librechat/api": "*",
"@librechat/data-schemas": "*", "@librechat/data-schemas": "*",
"@microsoft/microsoft-graph-client": "^3.0.7", "@microsoft/microsoft-graph-client": "^3.0.7",
"@modelcontextprotocol/sdk": "^1.27.1", "@modelcontextprotocol/sdk": "^1.27.1",
"@node-saml/passport-saml": "^5.1.0", "@node-saml/passport-saml": "^5.1.0",
"@smithy/node-http-handler": "^4.4.5", "@smithy/node-http-handler": "^4.4.5",
"ai-tokenizer": "^1.0.6",
"axios": "^1.13.5", "axios": "^1.13.5",
"bcryptjs": "^2.4.3", "bcryptjs": "^2.4.3",
"compression": "^1.8.1", "compression": "^1.8.1",
@ -64,10 +63,10 @@
"eventsource": "^3.0.2", "eventsource": "^3.0.2",
"express": "^5.2.1", "express": "^5.2.1",
"express-mongo-sanitize": "^2.2.0", "express-mongo-sanitize": "^2.2.0",
"express-rate-limit": "^8.3.0", "express-rate-limit": "^8.2.1",
"express-session": "^1.18.2", "express-session": "^1.18.2",
"express-static-gzip": "^2.2.0", "express-static-gzip": "^2.2.0",
"file-type": "^21.3.2", "file-type": "^18.7.0",
"firebase": "^11.0.2", "firebase": "^11.0.2",
"form-data": "^4.0.4", "form-data": "^4.0.4",
"handlebars": "^4.7.7", "handlebars": "^4.7.7",
@ -88,7 +87,7 @@
"mime": "^3.0.0", "mime": "^3.0.0",
"module-alias": "^2.2.3", "module-alias": "^2.2.3",
"mongoose": "^8.12.1", "mongoose": "^8.12.1",
"multer": "^2.1.1", "multer": "^2.1.0",
"nanoid": "^3.3.7", "nanoid": "^3.3.7",
"node-fetch": "^2.7.0", "node-fetch": "^2.7.0",
"nodemailer": "^7.0.11", "nodemailer": "^7.0.11",
@ -107,9 +106,10 @@
"pdfjs-dist": "^5.4.624", "pdfjs-dist": "^5.4.624",
"rate-limit-redis": "^4.2.0", "rate-limit-redis": "^4.2.0",
"sharp": "^0.33.5", "sharp": "^0.33.5",
"tiktoken": "^1.0.15",
"traverse": "^0.6.7", "traverse": "^0.6.7",
"ua-parser-js": "^1.0.36", "ua-parser-js": "^1.0.36",
"undici": "^7.24.1", "undici": "^7.18.2",
"winston": "^3.11.0", "winston": "^3.11.0",
"winston-daily-rotate-file": "^5.0.0", "winston-daily-rotate-file": "^5.0.0",
"xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz",

View file

@ -1,6 +1,5 @@
const { encryptV3, logger } = require('@librechat/data-schemas'); const { encryptV3, logger } = require('@librechat/data-schemas');
const { const {
verifyOTPOrBackupCode,
generateBackupCodes, generateBackupCodes,
generateTOTPSecret, generateTOTPSecret,
verifyBackupCode, verifyBackupCode,
@ -14,42 +13,24 @@ const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
/** /**
* Enable 2FA for the user by generating a new TOTP secret and backup codes. * Enable 2FA for the user by generating a new TOTP secret and backup codes.
* The secret is encrypted and stored, and 2FA is marked as disabled until confirmed. * The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
* If 2FA is already enabled, requires OTP or backup code verification to re-enroll.
*/ */
const enable2FA = async (req, res) => { const enable2FA = async (req, res) => {
try { try {
const userId = req.user.id; const userId = req.user.id;
const existingUser = await getUserById(
userId,
'+totpSecret +backupCodes _id twoFactorEnabled email',
);
if (existingUser && existingUser.twoFactorEnabled) {
const { token, backupCode } = req.body;
const result = await verifyOTPOrBackupCode({
user: existingUser,
token,
backupCode,
persistBackupUse: false,
});
if (!result.verified) {
const msg = result.message ?? 'TOTP token or backup code is required to re-enroll 2FA';
return res.status(result.status ?? 400).json({ message: msg });
}
}
const secret = generateTOTPSecret(); const secret = generateTOTPSecret();
const { plainCodes, codeObjects } = await generateBackupCodes(); const { plainCodes, codeObjects } = await generateBackupCodes();
// Encrypt the secret with v3 encryption before saving.
const encryptedSecret = encryptV3(secret); const encryptedSecret = encryptV3(secret);
// Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
const user = await updateUser(userId, { const user = await updateUser(userId, {
pendingTotpSecret: encryptedSecret, totpSecret: encryptedSecret,
pendingBackupCodes: codeObjects, backupCodes: codeObjects,
twoFactorEnabled: false,
}); });
const email = user.email || (existingUser && existingUser.email) || ''; const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`;
return res.status(200).json({ otpauthUrl, backupCodes: plainCodes }); return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
} catch (err) { } catch (err) {
@ -65,14 +46,13 @@ const verify2FA = async (req, res) => {
try { try {
const userId = req.user.id; const userId = req.user.id;
const { token, backupCode } = req.body; const { token, backupCode } = req.body;
const user = await getUserById(userId, '+totpSecret +pendingTotpSecret +backupCodes _id'); const user = await getUserById(userId, '_id totpSecret backupCodes');
const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
if (!user || !secretSource) { if (!user || !user.totpSecret) {
return res.status(400).json({ message: '2FA not initiated' }); return res.status(400).json({ message: '2FA not initiated' });
} }
const secret = await getTOTPSecret(secretSource); const secret = await getTOTPSecret(user.totpSecret);
let isVerified = false; let isVerified = false;
if (token) { if (token) {
@ -98,28 +78,15 @@ const confirm2FA = async (req, res) => {
try { try {
const userId = req.user.id; const userId = req.user.id;
const { token } = req.body; const { token } = req.body;
const user = await getUserById( const user = await getUserById(userId, '_id totpSecret');
userId,
'+totpSecret +pendingTotpSecret +pendingBackupCodes _id',
);
const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
if (!user || !secretSource) { if (!user || !user.totpSecret) {
return res.status(400).json({ message: '2FA not initiated' }); return res.status(400).json({ message: '2FA not initiated' });
} }
const secret = await getTOTPSecret(secretSource); const secret = await getTOTPSecret(user.totpSecret);
if (await verifyTOTP(secret, token)) { if (await verifyTOTP(secret, token)) {
const update = { await updateUser(userId, { twoFactorEnabled: true });
totpSecret: user.pendingTotpSecret ?? user.totpSecret,
twoFactorEnabled: true,
pendingTotpSecret: null,
pendingBackupCodes: [],
};
if (user.pendingBackupCodes?.length) {
update.backupCodes = user.pendingBackupCodes;
}
await updateUser(userId, update);
return res.status(200).json(); return res.status(200).json();
} }
return res.status(400).json({ message: 'Invalid token.' }); return res.status(400).json({ message: 'Invalid token.' });
@ -137,27 +104,31 @@ const disable2FA = async (req, res) => {
try { try {
const userId = req.user.id; const userId = req.user.id;
const { token, backupCode } = req.body; const { token, backupCode } = req.body;
const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled'); const user = await getUserById(userId, '_id totpSecret backupCodes');
if (!user || !user.totpSecret) { if (!user || !user.totpSecret) {
return res.status(400).json({ message: '2FA is not setup for this user' }); return res.status(400).json({ message: '2FA is not setup for this user' });
} }
if (user.twoFactorEnabled) { if (user.twoFactorEnabled) {
const result = await verifyOTPOrBackupCode({ user, token, backupCode }); const secret = await getTOTPSecret(user.totpSecret);
let isVerified = false;
if (!result.verified) { if (token) {
const msg = result.message ?? 'Either token or backup code is required to disable 2FA'; isVerified = await verifyTOTP(secret, token);
return res.status(result.status ?? 400).json({ message: msg }); } else if (backupCode) {
isVerified = await verifyBackupCode({ user, backupCode });
} else {
return res
.status(400)
.json({ message: 'Either token or backup code is required to disable 2FA' });
}
if (!isVerified) {
return res.status(401).json({ message: 'Invalid token or backup code' });
} }
} }
await updateUser(userId, { await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
totpSecret: null,
backupCodes: [],
twoFactorEnabled: false,
pendingTotpSecret: null,
pendingBackupCodes: [],
});
return res.status(200).json(); return res.status(200).json();
} catch (err) { } catch (err) {
logger.error('[disable2FA]', err); logger.error('[disable2FA]', err);
@ -167,28 +138,10 @@ const disable2FA = async (req, res) => {
/** /**
* Regenerate backup codes for the user. * Regenerate backup codes for the user.
* Requires OTP or backup code verification if 2FA is already enabled.
*/ */
const regenerateBackupCodes = async (req, res) => { const regenerateBackupCodes = async (req, res) => {
try { try {
const userId = req.user.id; const userId = req.user.id;
const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
if (!user) {
return res.status(404).json({ message: 'User not found' });
}
if (user.twoFactorEnabled) {
const { token, backupCode } = req.body;
const result = await verifyOTPOrBackupCode({ user, token, backupCode });
if (!result.verified) {
const msg =
result.message ?? 'TOTP token or backup code is required to regenerate backup codes';
return res.status(result.status ?? 400).json({ message: msg });
}
}
const { plainCodes, codeObjects } = await generateBackupCodes(); const { plainCodes, codeObjects } = await generateBackupCodes();
await updateUser(userId, { backupCodes: codeObjects }); await updateUser(userId, { backupCodes: codeObjects });
return res.status(200).json({ return res.status(200).json({

View file

@ -14,7 +14,6 @@ const {
deleteMessages, deleteMessages,
deletePresets, deletePresets,
deleteUserKey, deleteUserKey,
getUserById,
deleteConvos, deleteConvos,
deleteFiles, deleteFiles,
updateUser, updateUser,
@ -35,7 +34,6 @@ const {
User, User,
} = require('~/db/models'); } = require('~/db/models');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { verifyOTPOrBackupCode } = require('~/server/services/twoFactorService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config'); const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools'); const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
@ -243,22 +241,6 @@ const deleteUserController = async (req, res) => {
const { user } = req; const { user } = req;
try { try {
const existingUser = await getUserById(
user.id,
'+totpSecret +backupCodes _id twoFactorEnabled',
);
if (existingUser && existingUser.twoFactorEnabled) {
const { token, backupCode } = req.body;
const result = await verifyOTPOrBackupCode({ user: existingUser, token, backupCode });
if (!result.verified) {
const msg =
result.message ??
'TOTP token or backup code is required to delete account with 2FA enabled';
return res.status(result.status ?? 400).json({ message: msg });
}
}
await deleteMessages({ user: user.id }); // delete user messages await deleteMessages({ user: user.id }); // delete user messages
await deleteAllUserSessions({ userId: user.id }); // delete user sessions await deleteAllUserSessions({ userId: user.id }); // delete user sessions
await Transaction.deleteMany({ user: user.id }); // delete user transactions await Transaction.deleteMany({ user: user.id }); // delete user transactions
@ -370,7 +352,6 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ?? serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
clientMetadata.revocation_endpoint_auth_methods_supported; clientMetadata.revocation_endpoint_auth_methods_supported;
const oauthHeaders = serverConfig.oauth_headers ?? {}; const oauthHeaders = serverConfig.oauth_headers ?? {};
const allowedDomains = getMCPServersRegistry().getAllowedDomains();
if (tokens?.access_token) { if (tokens?.access_token) {
try { try {
@ -386,7 +367,6 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
revocationEndpointAuthMethodsSupported, revocationEndpointAuthMethodsSupported,
}, },
oauthHeaders, oauthHeaders,
allowedDomains,
); );
} catch (error) { } catch (error) {
logger.error(`Error revoking OAuth access token for ${serverName}:`, error); logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
@ -407,7 +387,6 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
revocationEndpointAuthMethodsSupported, revocationEndpointAuthMethodsSupported,
}, },
oauthHeaders, oauthHeaders,
allowedDomains,
); );
} catch (error) { } catch (error) {
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error); logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);

View file

@ -1,264 +0,0 @@
// Shared jest.fn() handles; the jest.mock() factories below delegate to these
// so individual tests can stub results (mockResolvedValue/mockReturnValue)
// and inspect invocations (mock.calls) without re-requiring the mocked modules.
const mockGetUserById = jest.fn();
const mockUpdateUser = jest.fn();
const mockVerifyOTPOrBackupCode = jest.fn();
const mockGenerateTOTPSecret = jest.fn();
const mockGenerateBackupCodes = jest.fn();
const mockEncryptV3 = jest.fn();
jest.mock('@librechat/data-schemas', () => ({
encryptV3: (...args) => mockEncryptV3(...args),
logger: { error: jest.fn() },
}));
jest.mock('~/server/services/twoFactorService', () => ({
verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
generateBackupCodes: (...args) => mockGenerateBackupCodes(...args),
generateTOTPSecret: (...args) => mockGenerateTOTPSecret(...args),
verifyBackupCode: jest.fn(),
getTOTPSecret: jest.fn(),
verifyTOTP: jest.fn(),
}));
jest.mock('~/models', () => ({
getUserById: (...args) => mockGetUserById(...args),
updateUser: (...args) => mockUpdateUser(...args),
}));
const { enable2FA, regenerateBackupCodes } = require('~/server/controllers/TwoFactorController');
/** Builds a minimal chainable Express-style response stub (status/json spies). */
function createRes() {
  const res = {
    status: jest.fn(),
    json: jest.fn(),
  };
  // Both spies return the response itself so calls can be chained
  // (res.status(200).json(...)), matching Express semantics.
  res.status.mockReturnValue(res);
  res.json.mockReturnValue(res);
  return res;
}
// Canonical fixtures: the plaintext backup codes returned to the client and
// the hashed code objects persisted on the user record.
const PLAIN_CODES = ['code1', 'code2', 'code3'];
const CODE_OBJECTS = [
  { codeHash: 'h1', used: false, usedAt: null },
  { codeHash: 'h2', used: false, usedAt: null },
  { codeHash: 'h3', used: false, usedAt: null },
];
// Reset all spies between tests and re-apply the default happy-path stubs
// for secret generation, backup-code generation, and encryption.
beforeEach(() => {
  jest.clearAllMocks();
  mockGenerateTOTPSecret.mockReturnValue('NEWSECRET');
  mockGenerateBackupCodes.mockResolvedValue({ plainCodes: PLAIN_CODES, codeObjects: CODE_OBJECTS });
  mockEncryptV3.mockReturnValue('encrypted-secret');
});
// enable2FA: setup/re-enrollment must stage the new secret and backup codes
// in pending* fields (not the live 2FA fields), and re-enrollment must first
// verify the current factor.
describe('enable2FA', () => {
  it('allows first-time setup without token — writes to pending fields', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false, email: 'a@b.com' });
    mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
    await enable2FA(req, res);
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith(
      expect.objectContaining({ otpauthUrl: expect.any(String), backupCodes: PLAIN_CODES }),
    );
    // No factor verification is required when 2FA is not yet enabled.
    expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
    // The update payload (second arg to updateUser) must touch only the
    // pending fields, never the live twoFactorEnabled/totpSecret/backupCodes.
    const updateCall = mockUpdateUser.mock.calls[0][1];
    expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
    expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
    expect(updateCall).not.toHaveProperty('twoFactorEnabled');
    expect(updateCall).not.toHaveProperty('totpSecret');
    expect(updateCall).not.toHaveProperty('backupCodes');
  });
  it('re-enrollment writes to pending fields, leaving live 2FA intact', async () => {
    const req = { user: { id: 'user1' }, body: { token: '123456' } };
    const res = createRes();
    const existingUser = {
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
      email: 'a@b.com',
    };
    mockGetUserById.mockResolvedValue(existingUser);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
    await enable2FA(req, res);
    // Re-enrollment verifies the current factor without consuming a backup
    // code (persistBackupUse: false).
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user: existingUser,
      token: '123456',
      backupCode: undefined,
      persistBackupUse: false,
    });
    expect(res.status).toHaveBeenCalledWith(200);
    const updateCall = mockUpdateUser.mock.calls[0][1];
    expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
    expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
    expect(updateCall).not.toHaveProperty('twoFactorEnabled');
    expect(updateCall).not.toHaveProperty('totpSecret');
  });
  it('allows re-enrollment with valid backup code (persistBackupUse: false)', async () => {
    const req = { user: { id: 'user1' }, body: { backupCode: 'backup123' } };
    const res = createRes();
    const existingUser = {
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
      email: 'a@b.com',
    };
    mockGetUserById.mockResolvedValue(existingUser);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
    await enable2FA(req, res);
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith(
      expect.objectContaining({ persistBackupUse: false }),
    );
    expect(res.status).toHaveBeenCalledWith(200);
  });
  it('returns error when no token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
    await enable2FA(req, res);
    expect(res.status).toHaveBeenCalledWith(400);
    // Failed verification must not mutate the user record.
    expect(mockUpdateUser).not.toHaveBeenCalled();
  });
  it('returns 401 when invalid token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });
    await enable2FA(req, res);
    expect(res.status).toHaveBeenCalledWith(401);
    expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
    expect(mockUpdateUser).not.toHaveBeenCalled();
  });
});
// regenerateBackupCodes: requires a verified factor only while 2FA is
// enabled, and always returns both the plaintext codes and their hashes.
describe('regenerateBackupCodes', () => {
  it('returns 404 when user not found', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue(null);
    await regenerateBackupCodes(req, res);
    expect(res.status).toHaveBeenCalledWith(404);
    expect(res.json).toHaveBeenCalledWith({ message: 'User not found' });
  });
  it('requires OTP when 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: { token: '123456' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({});
    await regenerateBackupCodes(req, res);
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalled();
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith({
      backupCodes: PLAIN_CODES,
      backupCodesHash: CODE_OBJECTS,
    });
  });
  it('returns error when no token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
    await regenerateBackupCodes(req, res);
    expect(res.status).toHaveBeenCalledWith(400);
  });
  it('returns 401 when invalid token provided and 2FA is enabled', async () => {
    const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });
    await regenerateBackupCodes(req, res);
    expect(res.status).toHaveBeenCalledWith(401);
    expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
  });
  it('includes backupCodesHash in response', async () => {
    const req = { user: { id: 'user1' }, body: { token: '123456' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    mockUpdateUser.mockResolvedValue({});
    await regenerateBackupCodes(req, res);
    // Inspect the raw response body rather than a full equality match.
    const responseBody = res.json.mock.calls[0][0];
    expect(responseBody).toHaveProperty('backupCodesHash', CODE_OBJECTS);
    expect(responseBody).toHaveProperty('backupCodes', PLAIN_CODES);
  });
  it('allows regeneration without token when 2FA is not enabled', async () => {
    const req = { user: { id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: false,
    });
    mockUpdateUser.mockResolvedValue({});
    await regenerateBackupCodes(req, res);
    // No factor check when 2FA is off.
    expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.json).toHaveBeenCalledWith({
      backupCodes: PLAIN_CODES,
      backupCodesHash: CODE_OBJECTS,
    });
  });
});

View file

@ -1,302 +0,0 @@
// Hoisted jest.fn() handles for every deletion/lookup dependency of
// deleteUserController; the jest.mock factories below delegate to them.
const mockGetUserById = jest.fn();
const mockDeleteMessages = jest.fn();
const mockDeleteAllUserSessions = jest.fn();
const mockDeleteUserById = jest.fn();
const mockDeleteAllSharedLinks = jest.fn();
const mockDeletePresets = jest.fn();
const mockDeleteUserKey = jest.fn();
const mockDeleteConvos = jest.fn();
const mockDeleteFiles = jest.fn();
const mockGetFiles = jest.fn();
const mockUpdateUserPlugins = jest.fn();
const mockUpdateUser = jest.fn();
const mockFindToken = jest.fn();
const mockVerifyOTPOrBackupCode = jest.fn();
const mockDeleteUserPluginAuth = jest.fn();
const mockProcessDeleteRequest = jest.fn();
const mockDeleteToolCalls = jest.fn();
const mockDeleteUserAgents = jest.fn();
const mockDeleteUserPrompts = jest.fn();
// Silence logging and stub package-level exports the controller imports.
jest.mock('@librechat/data-schemas', () => ({
  logger: { error: jest.fn(), info: jest.fn() },
  webSearchKeys: [],
}));
jest.mock('librechat-data-provider', () => ({
  Tools: {},
  CacheKeys: {},
  Constants: { mcp_delimiter: '::', mcp_prefix: 'mcp_' },
  FileSources: {},
}));
jest.mock('@librechat/api', () => ({
  MCPOAuthHandler: {},
  MCPTokenStorage: {},
  normalizeHttpError: jest.fn(),
  extractWebSearchEnvVars: jest.fn(),
}));
// Model layer: every call delegates to a hoisted handle so tests can
// assert invocation and control resolution per-case.
jest.mock('~/models', () => ({
  deleteAllUserSessions: (...args) => mockDeleteAllUserSessions(...args),
  deleteAllSharedLinks: (...args) => mockDeleteAllSharedLinks(...args),
  updateUserPlugins: (...args) => mockUpdateUserPlugins(...args),
  deleteUserById: (...args) => mockDeleteUserById(...args),
  deleteMessages: (...args) => mockDeleteMessages(...args),
  deletePresets: (...args) => mockDeletePresets(...args),
  deleteUserKey: (...args) => mockDeleteUserKey(...args),
  getUserById: (...args) => mockGetUserById(...args),
  deleteConvos: (...args) => mockDeleteConvos(...args),
  deleteFiles: (...args) => mockDeleteFiles(...args),
  updateUser: (...args) => mockUpdateUser(...args),
  findToken: (...args) => mockFindToken(...args),
  getFiles: (...args) => mockGetFiles(...args),
}));
// Mongoose models used directly by the controller via deleteMany/updateMany.
jest.mock('~/db/models', () => ({
  ConversationTag: { deleteMany: jest.fn() },
  AgentApiKey: { deleteMany: jest.fn() },
  Transaction: { deleteMany: jest.fn() },
  MemoryEntry: { deleteMany: jest.fn() },
  Assistant: { deleteMany: jest.fn() },
  AclEntry: { deleteMany: jest.fn() },
  Balance: { deleteMany: jest.fn() },
  Action: { deleteMany: jest.fn() },
  Group: { updateMany: jest.fn() },
  Token: { deleteMany: jest.fn() },
  User: {},
}));
jest.mock('~/server/services/PluginService', () => ({
  updateUserPluginAuth: jest.fn(),
  deleteUserPluginAuth: (...args) => mockDeleteUserPluginAuth(...args),
}));
jest.mock('~/server/services/twoFactorService', () => ({
  verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
}));
jest.mock('~/server/services/AuthService', () => ({
  verifyEmail: jest.fn(),
  resendVerificationEmail: jest.fn(),
}));
jest.mock('~/config', () => ({
  getMCPManager: jest.fn(),
  getFlowStateManager: jest.fn(),
  getMCPServersRegistry: jest.fn(),
}));
jest.mock('~/server/services/Config/getCachedTools', () => ({
  invalidateCachedTools: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
  needsRefresh: jest.fn(),
  getNewS3URL: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
  processDeleteRequest: (...args) => mockProcessDeleteRequest(...args),
}));
jest.mock('~/server/services/Config', () => ({
  getAppConfig: jest.fn(),
}));
jest.mock('~/models/ToolCall', () => ({
  deleteToolCalls: (...args) => mockDeleteToolCalls(...args),
}));
jest.mock('~/models/Prompt', () => ({
  deleteUserPrompts: (...args) => mockDeleteUserPrompts(...args),
}));
jest.mock('~/models/Agent', () => ({
  deleteUserAgents: (...args) => mockDeleteUserAgents(...args),
}));
jest.mock('~/cache', () => ({
  getLogStores: jest.fn(),
}));
// Require after all mocks are registered.
const { deleteUserController } = require('~/server/controllers/UserController');
/** Chainable Express-style response stub with status/json/send spies. */
function createRes() {
  const res = {};
  // Each spy returns the response itself so Express-style chaining works.
  for (const method of ['status', 'json', 'send']) {
    res[method] = jest.fn().mockReturnValue(res);
  }
  return res;
}
/**
 * Stubs every deletion-related mock to resolve successfully so the
 * controller's full deletion path can run without rejections.
 */
function stubDeletionMocks() {
  const resolvingMocks = [
    mockDeleteMessages,
    mockDeleteAllUserSessions,
    mockDeleteUserKey,
    mockDeletePresets,
    mockDeleteConvos,
    mockDeleteUserPluginAuth,
    mockDeleteUserById,
    mockDeleteAllSharedLinks,
    mockProcessDeleteRequest,
    mockDeleteFiles,
    mockDeleteToolCalls,
    mockDeleteUserAgents,
    mockDeleteUserPrompts,
  ];
  for (const mock of resolvingMocks) {
    mock.mockResolvedValue();
  }
  // File lookup resolves to "no files" so no storage cleanup is attempted.
  mockGetFiles.mockResolvedValue([]);
}

// Fresh spies plus default happy-path stubs before every test.
beforeEach(() => {
  jest.clearAllMocks();
  stubDeletionMocks();
});
// deleteUserController: deletion proceeds freely when 2FA is off or the user
// record is missing, but with 2FA enabled a valid TOTP token or backup code
// must be verified before any data is removed.
describe('deleteUserController - 2FA enforcement', () => {
  it('proceeds with deletion when 2FA is not enabled', async () => {
    const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false });
    await deleteUserController(req, res);
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
    expect(mockDeleteMessages).toHaveBeenCalled();
    // No factor check should occur when 2FA is off.
    expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
  });
  it('proceeds with deletion when user has no 2FA record', async () => {
    const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
    const res = createRes();
    // getUserById resolving null must not block deletion.
    mockGetUserById.mockResolvedValue(null);
    await deleteUserController(req, res);
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
  });
  it('returns error when 2FA is enabled and verification fails with 400', async () => {
    const req = { user: { id: 'user1', _id: 'user1' }, body: {} };
    const res = createRes();
    mockGetUserById.mockResolvedValue({
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    });
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
    await deleteUserController(req, res);
    expect(res.status).toHaveBeenCalledWith(400);
    // Nothing may be deleted when verification fails.
    expect(mockDeleteMessages).not.toHaveBeenCalled();
  });
  it('returns 401 when 2FA is enabled and invalid TOTP token provided', async () => {
    const existingUser = {
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    };
    const req = { user: { id: 'user1', _id: 'user1' }, body: { token: 'wrong' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(existingUser);
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });
    await deleteUserController(req, res);
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user: existingUser,
      token: 'wrong',
      backupCode: undefined,
    });
    expect(res.status).toHaveBeenCalledWith(401);
    expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
    expect(mockDeleteMessages).not.toHaveBeenCalled();
  });
  it('returns 401 when 2FA is enabled and invalid backup code provided', async () => {
    const existingUser = {
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
      backupCodes: [],
    };
    const req = { user: { id: 'user1', _id: 'user1' }, body: { backupCode: 'bad-code' } };
    const res = createRes();
    mockGetUserById.mockResolvedValue(existingUser);
    mockVerifyOTPOrBackupCode.mockResolvedValue({
      verified: false,
      status: 401,
      message: 'Invalid token or backup code',
    });
    await deleteUserController(req, res);
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user: existingUser,
      token: undefined,
      backupCode: 'bad-code',
    });
    expect(res.status).toHaveBeenCalledWith(401);
    expect(mockDeleteMessages).not.toHaveBeenCalled();
  });
  it('deletes account when valid TOTP token provided with 2FA enabled', async () => {
    const existingUser = {
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
    };
    const req = {
      user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
      body: { token: '123456' },
    };
    const res = createRes();
    mockGetUserById.mockResolvedValue(existingUser);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    await deleteUserController(req, res);
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user: existingUser,
      token: '123456',
      backupCode: undefined,
    });
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
    expect(mockDeleteMessages).toHaveBeenCalled();
  });
  it('deletes account when valid backup code provided with 2FA enabled', async () => {
    const existingUser = {
      _id: 'user1',
      twoFactorEnabled: true,
      totpSecret: 'enc-secret',
      backupCodes: [{ codeHash: 'h1', used: false }],
    };
    const req = {
      user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
      body: { backupCode: 'valid-code' },
    };
    const res = createRes();
    mockGetUserById.mockResolvedValue(existingUser);
    mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
    await deleteUserController(req, res);
    expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
      user: existingUser,
      token: undefined,
      backupCode: 'valid-code',
    });
    expect(res.status).toHaveBeenCalledWith(200);
    expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
    expect(mockDeleteMessages).toHaveBeenCalled();
  });
});

View file

@ -1,159 +0,0 @@
// Stub permission checks: granting always succeeds so duplication can run
// against an in-memory Mongo without the real ACL layer.
jest.mock('~/server/services/PermissionService', () => ({
  findPubliclyAccessibleResources: jest.fn(),
  findAccessibleResources: jest.fn(),
  hasPublicPermission: jest.fn(),
  grantPermission: jest.fn().mockResolvedValue({}),
}));
jest.mock('~/server/services/Config', () => ({
  getCachedTools: jest.fn(),
  getMCPServerTools: jest.fn(),
}));
const mongoose = require('mongoose');
const { actionDelimiter } = require('librechat-data-provider');
const { agentSchema, actionSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
// Handler under test (relative to this __tests__ directory).
const { duplicateAgent } = require('../v1');
// In-memory Mongo instance shared across the lifecycle hooks below.
let mongoServer;
// Spin up an in-memory MongoDB, register the Agent/Action models once
// (guarded so repeated runs don't re-register), and connect mongoose.
beforeAll(async () => {
  mongoServer = await MongoMemoryServer.create();
  const mongoUri = mongoServer.getUri();
  if (!mongoose.models.Agent) {
    mongoose.model('Agent', agentSchema);
  }
  if (!mongoose.models.Action) {
    mongoose.model('Action', actionSchema);
  }
  await mongoose.connect(mongoUri);
}, 20000); // generous timeout: first run downloads the mongod binary
afterAll(async () => {
  await mongoose.disconnect();
  await mongoServer.stop();
});
// Each test starts from empty collections.
beforeEach(async () => {
  await mongoose.models.Agent.deleteMany({});
  await mongoose.models.Action.deleteMany({});
});
// duplicateAgent: duplicated action entries must be keyed by
// metadata.domain (not the original action_id), and sensitive credential
// metadata must not be copied to the duplicate.
describe('duplicateAgentHandler — action domain extraction', () => {
  it('builds duplicated action entries using metadata.domain, not action_id', async () => {
    const userId = new mongoose.Types.ObjectId();
    const originalAgentId = `agent_original`;
    const agent = await mongoose.models.Agent.create({
      id: originalAgentId,
      name: 'Test Agent',
      author: userId.toString(),
      provider: 'openai',
      model: 'gpt-4',
      tools: [],
      // Agent-side action entry is "<domain><delimiter><action_id>".
      actions: [`api.example.com${actionDelimiter}act_original`],
      versions: [{ name: 'Test Agent', createdAt: new Date(), updatedAt: new Date() }],
    });
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_original',
      agent_id: originalAgentId,
      metadata: { domain: 'api.example.com' },
    });
    const req = {
      params: { id: agent.id },
      user: { id: userId.toString() },
    };
    const res = {
      status: jest.fn().mockReturnThis(),
      json: jest.fn(),
    };
    await duplicateAgent(req, res);
    expect(res.status).toHaveBeenCalledWith(201);
    const { agent: newAgent, actions: newActions } = res.json.mock.calls[0][0];
    expect(newAgent.id).not.toBe(originalAgentId);
    expect(String(newAgent.author)).toBe(userId.toString());
    expect(newActions).toHaveLength(1);
    expect(newActions[0].metadata.domain).toBe('api.example.com');
    expect(newActions[0].agent_id).toBe(newAgent.id);
    // Every duplicated action entry keeps the domain but gets a fresh
    // action_id distinct from the original's.
    for (const actionEntry of newAgent.actions) {
      const [domain, actionId] = actionEntry.split(actionDelimiter);
      expect(domain).toBe('api.example.com');
      expect(actionId).toBeTruthy();
      expect(actionId).not.toBe('act_original');
    }
    // Original action document is untouched; the duplicate points at the
    // new agent.
    const allActions = await mongoose.models.Action.find({}).lean();
    expect(allActions).toHaveLength(2);
    const originalAction = allActions.find((a) => a.action_id === 'act_original');
    expect(originalAction.agent_id).toBe(originalAgentId);
    const duplicatedAction = allActions.find((a) => a.action_id !== 'act_original');
    expect(duplicatedAction.agent_id).toBe(newAgent.id);
    expect(duplicatedAction.metadata.domain).toBe('api.example.com');
  });
  it('strips sensitive metadata fields from duplicated actions', async () => {
    const userId = new mongoose.Types.ObjectId();
    const originalAgentId = 'agent_sensitive';
    await mongoose.models.Agent.create({
      id: originalAgentId,
      name: 'Sensitive Agent',
      author: userId.toString(),
      provider: 'openai',
      model: 'gpt-4',
      tools: [],
      actions: [`secure.api.com${actionDelimiter}act_secret`],
      versions: [{ name: 'Sensitive Agent', createdAt: new Date(), updatedAt: new Date() }],
    });
    await mongoose.models.Action.create({
      user: userId,
      action_id: 'act_secret',
      agent_id: originalAgentId,
      metadata: {
        domain: 'secure.api.com',
        api_key: 'sk-secret-key-12345',
        oauth_client_id: 'client_id_xyz',
        oauth_client_secret: 'client_secret_xyz',
      },
    });
    const req = {
      params: { id: originalAgentId },
      user: { id: userId.toString() },
    };
    const res = {
      status: jest.fn().mockReturnThis(),
      json: jest.fn(),
    };
    await duplicateAgent(req, res);
    expect(res.status).toHaveBeenCalledWith(201);
    // The duplicate keeps the domain but must drop every credential field.
    const duplicatedAction = await mongoose.models.Action.findOne({
      agent_id: { $ne: originalAgentId },
    }).lean();
    expect(duplicatedAction.metadata.domain).toBe('secure.api.com');
    expect(duplicatedAction.metadata.api_key).toBeUndefined();
    expect(duplicatedAction.metadata.oauth_client_id).toBeUndefined();
    expect(duplicatedAction.metadata.oauth_client_secret).toBeUndefined();
    // The original's credentials remain intact.
    const originalAction = await mongoose.models.Action.findOne({
      action_id: 'act_secret',
    }).lean();
    expect(originalAction.metadata.api_key).toBe('sk-secret-key-12345');
  });
});

View file

@ -44,7 +44,6 @@ const {
isEphemeralAgentId, isEphemeralAgentId,
removeNullishValues, removeNullishValues,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { updateBalance, bulkInsertTransactions } = require('~/models'); const { updateBalance, bulkInsertTransactions } = require('~/models');
@ -480,7 +479,6 @@ class AgentClient extends BaseClient {
getUserKeyValues: db.getUserKeyValues, getUserKeyValues: db.getUserKeyValues,
getToolFilesByIds: db.getToolFilesByIds, getToolFilesByIds: db.getToolFilesByIds,
getCodeGeneratedFiles: db.getCodeGeneratedFiles, getCodeGeneratedFiles: db.getCodeGeneratedFiles,
filterFilesByAgentAccess,
}, },
); );
@ -1174,11 +1172,7 @@ class AgentClient extends BaseClient {
} }
} }
/** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */
getEncoding() { getEncoding() {
if (this.model && this.model.toLowerCase().includes('claude')) {
return 'claude';
}
return 'o200k_base'; return 'o200k_base';
} }

View file

@ -1,677 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { Constants } = require('librechat-data-provider');
const { agentSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
// Short alias for the MCP tool-key delimiter used throughout these tests.
const d = Constants.mcp_delimiter;
// Hoisted handle so tests can control which MCP servers the registry reports.
const mockGetAllServerConfigs = jest.fn();
// System tools always available to the handlers under test.
jest.mock('~/server/services/Config', () => ({
  getCachedTools: jest.fn().mockResolvedValue({
    web_search: true,
    execute_code: true,
    file_search: true,
  }),
}));
jest.mock('~/config', () => ({
  getMCPServersRegistry: jest.fn(() => ({
    getAllServerConfigs: mockGetAllServerConfigs,
  })),
}));
jest.mock('~/models/Project', () => ({
  getProjectByName: jest.fn().mockResolvedValue(null),
}));
jest.mock('~/server/services/Files/strategies', () => ({
  getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/services/Files/images/avatar', () => ({
  resizeAvatar: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
  refreshS3Url: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
  filterFile: jest.fn(),
}));
jest.mock('~/models/Action', () => ({
  updateAction: jest.fn(),
  getActions: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/models/File', () => ({
  deleteFileByFilter: jest.fn(),
}));
// Permissions default to "user can do everything" so tool authorization is
// the only gate being exercised.
jest.mock('~/server/services/PermissionService', () => ({
  findAccessibleResources: jest.fn().mockResolvedValue([]),
  findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
  grantPermission: jest.fn(),
  hasPublicPermission: jest.fn().mockResolvedValue(false),
  checkPermission: jest.fn().mockResolvedValue(true),
}));
jest.mock('~/models', () => ({
  getCategoriesWithCounts: jest.fn(),
}));
jest.mock('~/cache', () => ({
  getLogStores: jest.fn(() => ({
    get: jest.fn(),
    set: jest.fn(),
    delete: jest.fn(),
  })),
}));
// Handlers under test, required after all mocks are registered.
const {
  filterAuthorizedTools,
  createAgent: createAgentHandler,
  updateAgent: updateAgentHandler,
  duplicateAgent: duplicateAgentHandler,
  revertAgentVersion: revertAgentVersionHandler,
} = require('./v1');
const { getMCPServersRegistry } = require('~/config');
// Agent model, bound in beforeAll once the in-memory Mongo is up.
let Agent;
describe('MCP Tool Authorization', () => {
  let mongoServer;
  let mockReq;
  let mockRes;
  // In-memory MongoDB plus Agent model registration (guarded against
  // double registration across test files).
  beforeAll(async () => {
    mongoServer = await MongoMemoryServer.create();
    const mongoUri = mongoServer.getUri();
    await mongoose.connect(mongoUri);
    Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
  }, 20000); // generous: first run may download the mongod binary
  afterAll(async () => {
    await mongoose.disconnect();
    await mongoServer.stop();
  });
  beforeEach(async () => {
    await Agent.deleteMany({});
    jest.clearAllMocks();
    // Re-install the registry stub after clearAllMocks wipes implementations:
    // the user is authorized for exactly these two servers by default.
    getMCPServersRegistry.mockImplementation(() => ({
      getAllServerConfigs: mockGetAllServerConfigs,
    }));
    mockGetAllServerConfigs.mockResolvedValue({
      authorizedServer: { type: 'sse', url: 'https://authorized.example.com' },
      anotherServer: { type: 'sse', url: 'https://another.example.com' },
    });
    // Fresh request/response fixtures per test.
    mockReq = {
      user: {
        id: new mongoose.Types.ObjectId().toString(),
        role: 'USER',
      },
      body: {},
      params: {},
      query: {},
      app: { locals: { fileStrategy: 'local' } },
    };
    mockRes = {
      status: jest.fn().mockReturnThis(),
      json: jest.fn().mockReturnThis(),
    };
  });
describe('filterAuthorizedTools', () => {
const availableTools = { web_search: true, custom_tool: true };
const userId = 'test-user-123';
test('should keep authorized MCP tools and strip unauthorized ones', async () => {
const result = await filterAuthorizedTools({
tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`, 'web_search'],
userId,
availableTools,
});
expect(result).toContain(`toolA${d}authorizedServer`);
expect(result).toContain('web_search');
expect(result).not.toContain(`toolB${d}forbiddenServer`);
});
test('should keep system tools without querying MCP registry', async () => {
const result = await filterAuthorizedTools({
tools: ['execute_code', 'file_search', 'web_search'],
userId,
availableTools: {},
});
expect(result).toEqual(['execute_code', 'file_search', 'web_search']);
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
});
test('should not query MCP registry when no MCP tools are present', async () => {
const result = await filterAuthorizedTools({
tools: ['web_search', 'custom_tool'],
userId,
availableTools,
});
expect(result).toEqual(['web_search', 'custom_tool']);
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
});
test('should filter all MCP tools when registry is uninitialized', async () => {
getMCPServersRegistry.mockImplementation(() => {
throw new Error('MCPServersRegistry has not been initialized.');
});
const result = await filterAuthorizedTools({
tools: [`toolA${d}someServer`, 'web_search'],
userId,
availableTools,
});
expect(result).toEqual(['web_search']);
expect(result).not.toContain(`toolA${d}someServer`);
});
test('should handle mixed authorized and unauthorized MCP tools', async () => {
const result = await filterAuthorizedTools({
tools: [
'web_search',
`search${d}authorizedServer`,
`attack${d}victimServer`,
'execute_code',
`list${d}anotherServer`,
`steal${d}nonexistent`,
],
userId,
availableTools,
});
expect(result).toEqual([
'web_search',
`search${d}authorizedServer`,
'execute_code',
`list${d}anotherServer`,
]);
});
test('should handle empty tools array', async () => {
const result = await filterAuthorizedTools({
tools: [],
userId,
availableTools,
});
expect(result).toEqual([]);
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
});
test('should handle null/undefined tool entries gracefully', async () => {
const result = await filterAuthorizedTools({
tools: [null, undefined, '', 'web_search'],
userId,
availableTools,
});
expect(result).toEqual(['web_search']);
});
test('should call getAllServerConfigs with the correct userId', async () => {
await filterAuthorizedTools({
tools: [`tool${d}authorizedServer`],
userId: 'specific-user-id',
availableTools,
});
expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id');
});
test('should only call getAllServerConfigs once even with multiple MCP tools', async () => {
await filterAuthorizedTools({
tools: [`tool1${d}authorizedServer`, `tool2${d}anotherServer`, `tool3${d}unknownServer`],
userId,
availableTools,
});
expect(mockGetAllServerConfigs).toHaveBeenCalledTimes(1);
});
test('should preserve existing MCP tools when registry is unavailable', async () => {
getMCPServersRegistry.mockImplementation(() => {
throw new Error('MCPServersRegistry has not been initialized.');
});
const existingTools = [`toolA${d}serverA`, `toolB${d}serverB`];
const result = await filterAuthorizedTools({
tools: [...existingTools, `newTool${d}unknownServer`, 'web_search'],
userId,
availableTools,
existingTools,
});
expect(result).toContain(`toolA${d}serverA`);
expect(result).toContain(`toolB${d}serverB`);
expect(result).toContain('web_search');
expect(result).not.toContain(`newTool${d}unknownServer`);
});
test('should still reject all MCP tools when registry is unavailable and no existingTools', async () => {
getMCPServersRegistry.mockImplementation(() => {
throw new Error('MCPServersRegistry has not been initialized.');
});
const result = await filterAuthorizedTools({
tools: [`toolA${d}serverA`, 'web_search'],
userId,
availableTools,
});
expect(result).toEqual(['web_search']);
});
test('should not preserve malformed existing tools when registry is unavailable', async () => {
getMCPServersRegistry.mockImplementation(() => {
throw new Error('MCPServersRegistry has not been initialized.');
});
const malformedTool = `a${d}b${d}c`;
const result = await filterAuthorizedTools({
tools: [malformedTool, `legit${d}serverA`, 'web_search'],
userId,
availableTools,
existingTools: [malformedTool, `legit${d}serverA`],
});
expect(result).toContain(`legit${d}serverA`);
expect(result).toContain('web_search');
expect(result).not.toContain(malformedTool);
});
test('should reject malformed MCP tool keys with multiple delimiters', async () => {
const result = await filterAuthorizedTools({
tools: [
`attack${d}victimServer${d}authorizedServer`,
`legit${d}authorizedServer`,
`a${d}b${d}c${d}d`,
'web_search',
],
userId,
availableTools,
});
expect(result).toEqual([`legit${d}authorizedServer`, 'web_search']);
expect(result).not.toContainEqual(expect.stringContaining('victimServer'));
expect(result).not.toContainEqual(expect.stringContaining(`a${d}b`));
});
});
/**
 * Agent creation: MCP tools the requesting user is not authorized for must
 * never be persisted onto a newly created agent, and a missing/uninitialized
 * MCP registry must not turn creation into a 500.
 */
describe('createAgentHandler - MCP tool authorization', () => {
  test('should strip unauthorized MCP tools on create', async () => {
    mockReq.body = {
      provider: 'openai',
      model: 'gpt-4',
      name: 'MCP Test Agent',
      // One system tool, one tool on an accessible server, one on a forbidden server.
      tools: ['web_search', `validTool${d}authorizedServer`, `attack${d}forbiddenServer`],
    };
    await createAgentHandler(mockReq, mockRes);
    expect(mockRes.status).toHaveBeenCalledWith(201);
    const agent = mockRes.json.mock.calls[0][0];
    expect(agent.tools).toContain('web_search');
    expect(agent.tools).toContain(`validTool${d}authorizedServer`);
    // The forbidden-server tool is silently dropped, not an error.
    expect(agent.tools).not.toContain(`attack${d}forbiddenServer`);
  });
  test('should not 500 when MCP registry is uninitialized', async () => {
    // Registry throws (e.g. before startup finishes); create must degrade gracefully.
    getMCPServersRegistry.mockImplementation(() => {
      throw new Error('MCPServersRegistry has not been initialized.');
    });
    mockReq.body = {
      provider: 'openai',
      model: 'gpt-4',
      name: 'MCP Uninitialized Test',
      tools: [`tool${d}someServer`, 'web_search'],
    };
    await createAgentHandler(mockReq, mockRes);
    // Creation still succeeds; the unverifiable MCP tool is dropped.
    expect(mockRes.status).toHaveBeenCalledWith(201);
    const agent = mockRes.json.mock.calls[0][0];
    expect(agent.tools).toEqual(['web_search']);
  });
  test('should store mcpServerNames only for authorized servers', async () => {
    mockReq.body = {
      provider: 'openai',
      model: 'gpt-4',
      name: 'MCP Names Test',
      tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`],
    };
    await createAgentHandler(mockReq, mockRes);
    expect(mockRes.status).toHaveBeenCalledWith(201);
    const agent = mockRes.json.mock.calls[0][0];
    // Check the persisted document, not just the HTTP response payload.
    const agentInDb = await Agent.findOne({ id: agent.id });
    expect(agentInDb.mcpServerNames).toContain('authorizedServer');
    expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer');
  });
});
/**
 * Agent updates: tools already on the agent are grandfathered (they were
 * validated when added), while NEWLY added MCP tools must pass authorization.
 * Registry outages or disconnected servers must not strip existing tools.
 */
describe('updateAgentHandler - MCP tool authorization', () => {
  let existingAgentId;
  let existingAgentAuthorId;
  beforeEach(async () => {
    // Seed an agent that already carries one authorized MCP tool.
    existingAgentAuthorId = new mongoose.Types.ObjectId();
    const agent = await Agent.create({
      id: `agent_${uuidv4()}`,
      name: 'Original Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: existingAgentAuthorId,
      tools: ['web_search', `existingTool${d}authorizedServer`],
      mcpServerNames: ['authorizedServer'],
      versions: [
        {
          name: 'Original Agent',
          provider: 'openai',
          model: 'gpt-4',
          tools: ['web_search', `existingTool${d}authorizedServer`],
          createdAt: new Date(),
          updatedAt: new Date(),
        },
      ],
    });
    existingAgentId = agent.id;
  });
  test('should preserve existing MCP tools even if editor lacks access', async () => {
    // Editor has access to NO servers, yet the already-present tool stays.
    mockGetAllServerConfigs.mockResolvedValue({});
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = {
      tools: ['web_search', `existingTool${d}authorizedServer`],
    };
    await updateAgentHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const updatedAgent = mockRes.json.mock.calls[0][0];
    expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
    expect(updatedAgent.tools).toContain('web_search');
  });
  test('should reject newly added unauthorized MCP tools', async () => {
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = {
      // Same as existing tools plus one new tool on a forbidden server.
      tools: ['web_search', `existingTool${d}authorizedServer`, `attack${d}forbiddenServer`],
    };
    await updateAgentHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const updatedAgent = mockRes.json.mock.calls[0][0];
    expect(updatedAgent.tools).toContain('web_search');
    expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
    // Only the newly added, unauthorized tool is stripped.
    expect(updatedAgent.tools).not.toContain(`attack${d}forbiddenServer`);
  });
  test('should allow adding authorized MCP tools', async () => {
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = {
      tools: ['web_search', `existingTool${d}authorizedServer`, `newTool${d}anotherServer`],
    };
    await updateAgentHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const updatedAgent = mockRes.json.mock.calls[0][0];
    expect(updatedAgent.tools).toContain(`newTool${d}anotherServer`);
  });
  test('should not query MCP registry when no new MCP tools added', async () => {
    // Unchanged tool list: the (potentially expensive) registry lookup is skipped.
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = {
      tools: ['web_search', `existingTool${d}authorizedServer`],
    };
    await updateAgentHandler(mockReq, mockRes);
    expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
  });
  test('should preserve existing MCP tools when registry unavailable and user edits agent', async () => {
    // Registry throws mid-edit (server restart); unrelated edits still go through
    // and previously validated MCP tools are kept.
    getMCPServersRegistry.mockImplementation(() => {
      throw new Error('MCPServersRegistry has not been initialized.');
    });
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = {
      name: 'Renamed After Restart',
      tools: ['web_search', `existingTool${d}authorizedServer`],
    };
    await updateAgentHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const updatedAgent = mockRes.json.mock.calls[0][0];
    expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
    expect(updatedAgent.tools).toContain('web_search');
    expect(updatedAgent.name).toBe('Renamed After Restart');
  });
  test('should preserve existing MCP tools when server not in configs (disconnected)', async () => {
    // Server missing from configs (e.g. disconnected) is NOT grounds for stripping.
    mockGetAllServerConfigs.mockResolvedValue({});
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = {
      name: 'Edited While Disconnected',
      tools: ['web_search', `existingTool${d}authorizedServer`],
    };
    await updateAgentHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const updatedAgent = mockRes.json.mock.calls[0][0];
    expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
    expect(updatedAgent.name).toBe('Edited While Disconnected');
  });
});
/**
 * Agent duplication: the copy is re-filtered against the DUPLICATING user's
 * server access, but when the registry is unavailable the source agent's
 * tools are preserved rather than stripped (existingTools fallback).
 */
describe('duplicateAgentHandler - MCP tool authorization', () => {
  let sourceAgentId;
  let sourceAgentAuthorId;
  beforeEach(async () => {
    // Source agent carries tools on both an authorized and a forbidden server.
    sourceAgentAuthorId = new mongoose.Types.ObjectId();
    const agent = await Agent.create({
      id: `agent_${uuidv4()}`,
      name: 'Source Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: sourceAgentAuthorId,
      tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`],
      mcpServerNames: ['authorizedServer', 'forbiddenServer'],
      versions: [
        {
          name: 'Source Agent',
          provider: 'openai',
          model: 'gpt-4',
          tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`],
          createdAt: new Date(),
          updatedAt: new Date(),
        },
      ],
    });
    sourceAgentId = agent.id;
  });
  test('should strip unauthorized MCP tools from duplicated agent', async () => {
    // Duplicating user can only reach 'authorizedServer'.
    mockGetAllServerConfigs.mockResolvedValue({
      authorizedServer: { type: 'sse' },
    });
    mockReq.user.id = sourceAgentAuthorId.toString();
    mockReq.params.id = sourceAgentId;
    await duplicateAgentHandler(mockReq, mockRes);
    expect(mockRes.status).toHaveBeenCalledWith(201);
    const { agent: newAgent } = mockRes.json.mock.calls[0][0];
    expect(newAgent.id).not.toBe(sourceAgentId);
    expect(newAgent.tools).toContain('web_search');
    expect(newAgent.tools).toContain(`tool${d}authorizedServer`);
    expect(newAgent.tools).not.toContain(`tool${d}forbiddenServer`);
    // Persisted mcpServerNames mirror the filtered tool set.
    const agentInDb = await Agent.findOne({ id: newAgent.id });
    expect(agentInDb.mcpServerNames).toContain('authorizedServer');
    expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer');
  });
  test('should preserve source agent MCP tools when registry is unavailable', async () => {
    // Registry down: the duplicate keeps ALL source tools (existingTools path),
    // including the one on the (unverifiable) forbidden server.
    getMCPServersRegistry.mockImplementation(() => {
      throw new Error('MCPServersRegistry has not been initialized.');
    });
    mockReq.user.id = sourceAgentAuthorId.toString();
    mockReq.params.id = sourceAgentId;
    await duplicateAgentHandler(mockReq, mockRes);
    expect(mockRes.status).toHaveBeenCalledWith(201);
    const { agent: newAgent } = mockRes.json.mock.calls[0][0];
    expect(newAgent.tools).toContain('web_search');
    expect(newAgent.tools).toContain(`tool${d}authorizedServer`);
    expect(newAgent.tools).toContain(`tool${d}forbiddenServer`);
  });
});
/**
 * Version revert: restoring an older version must re-check its MCP tools
 * against CURRENT server access (a server may have been revoked since),
 * while a registry outage preserves the version's tools unchanged.
 */
describe('revertAgentVersionHandler - MCP tool authorization', () => {
  let existingAgentId;
  let existingAgentAuthorId;
  beforeEach(async () => {
    // Two versions: V1 references a since-revoked server; V2 (current) does not.
    existingAgentAuthorId = new mongoose.Types.ObjectId();
    const agent = await Agent.create({
      id: `agent_${uuidv4()}`,
      name: 'Reverted Agent V2',
      provider: 'openai',
      model: 'gpt-4',
      author: existingAgentAuthorId,
      tools: ['web_search'],
      versions: [
        {
          name: 'Reverted Agent V1',
          provider: 'openai',
          model: 'gpt-4',
          tools: ['web_search', `oldTool${d}revokedServer`],
          createdAt: new Date(Date.now() - 10000),
          updatedAt: new Date(Date.now() - 10000),
        },
        {
          name: 'Reverted Agent V2',
          provider: 'openai',
          model: 'gpt-4',
          tools: ['web_search'],
          createdAt: new Date(),
          updatedAt: new Date(),
        },
      ],
    });
    existingAgentId = agent.id;
  });
  test('should strip unauthorized MCP tools after reverting to a previous version', async () => {
    // User currently has access only to 'authorizedServer' — not 'revokedServer'.
    mockGetAllServerConfigs.mockResolvedValue({
      authorizedServer: { type: 'sse' },
    });
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = { version_index: 0 };
    await revertAgentVersionHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const result = mockRes.json.mock.calls[0][0];
    expect(result.tools).toContain('web_search');
    expect(result.tools).not.toContain(`oldTool${d}revokedServer`);
    // Filtering is persisted, not just reflected in the response.
    const agentInDb = await Agent.findOne({ id: existingAgentId });
    expect(agentInDb.tools).toContain('web_search');
    expect(agentInDb.tools).not.toContain(`oldTool${d}revokedServer`);
  });
  test('should keep authorized MCP tools after revert', async () => {
    // Rewrite V1's tools to reference a server the user can still access.
    await Agent.updateOne(
      { id: existingAgentId },
      { $set: { 'versions.0.tools': ['web_search', `tool${d}authorizedServer`] } },
    );
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = { version_index: 0 };
    await revertAgentVersionHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const result = mockRes.json.mock.calls[0][0];
    expect(result.tools).toContain('web_search');
    expect(result.tools).toContain(`tool${d}authorizedServer`);
  });
  test('should preserve version MCP tools when registry is unavailable on revert', async () => {
    await Agent.updateOne(
      { id: existingAgentId },
      {
        $set: {
          'versions.0.tools': [
            'web_search',
            `validTool${d}authorizedServer`,
            `otherTool${d}anotherServer`,
          ],
        },
      },
    );
    // Registry down: the reverted version's tools are treated as existingTools
    // and kept verbatim rather than stripped.
    getMCPServersRegistry.mockImplementation(() => {
      throw new Error('MCPServersRegistry has not been initialized.');
    });
    mockReq.user.id = existingAgentAuthorId.toString();
    mockReq.params.id = existingAgentId;
    mockReq.body = { version_index: 0 };
    await revertAgentVersionHandler(mockReq, mockRes);
    expect(mockRes.json).toHaveBeenCalled();
    const result = mockRes.json.mock.calls[0][0];
    expect(result.tools).toContain('web_search');
    expect(result.tools).toContain(`validTool${d}authorizedServer`);
    expect(result.tools).toContain(`otherTool${d}anotherServer`);
    const agentInDb = await Agent.findOne({ id: existingAgentId });
    expect(agentInDb.tools).toContain(`validTool${d}authorizedServer`);
    expect(agentInDb.tools).toContain(`otherTool${d}anotherServer`);
  });
});
});

View file

@ -265,7 +265,6 @@ const OpenAIChatCompletionController = async (req, res) => {
toolRegistry: primaryConfig.toolRegistry, toolRegistry: primaryConfig.toolRegistry,
userMCPAuthMap: primaryConfig.userMCPAuthMap, userMCPAuthMap: primaryConfig.userMCPAuthMap,
tool_resources: primaryConfig.tool_resources, tool_resources: primaryConfig.tool_resources,
actionsEnabled: primaryConfig.actionsEnabled,
}); });
}, },
toolEndCallback, toolEndCallback,

View file

@ -429,7 +429,6 @@ const createResponse = async (req, res) => {
toolRegistry: primaryConfig.toolRegistry, toolRegistry: primaryConfig.toolRegistry,
userMCPAuthMap: primaryConfig.userMCPAuthMap, userMCPAuthMap: primaryConfig.userMCPAuthMap,
tool_resources: primaryConfig.tool_resources, tool_resources: primaryConfig.tool_resources,
actionsEnabled: primaryConfig.actionsEnabled,
}); });
}, },
toolEndCallback, toolEndCallback,
@ -587,7 +586,6 @@ const createResponse = async (req, res) => {
toolRegistry: primaryConfig.toolRegistry, toolRegistry: primaryConfig.toolRegistry,
userMCPAuthMap: primaryConfig.userMCPAuthMap, userMCPAuthMap: primaryConfig.userMCPAuthMap,
tool_resources: primaryConfig.tool_resources, tool_resources: primaryConfig.tool_resources,
actionsEnabled: primaryConfig.actionsEnabled,
}); });
}, },
toolEndCallback, toolEndCallback,

View file

@ -6,7 +6,6 @@ const {
agentCreateSchema, agentCreateSchema,
agentUpdateSchema, agentUpdateSchema,
refreshListAvatars, refreshListAvatars,
collectEdgeAgentIds,
mergeAgentOcrConversion, mergeAgentOcrConversion,
MAX_AVATAR_REFRESH_AGENTS, MAX_AVATAR_REFRESH_AGENTS,
convertOcrToContextInPlace, convertOcrToContextInPlace,
@ -36,7 +35,6 @@ const {
} = require('~/models/Agent'); } = require('~/models/Agent');
const { const {
findPubliclyAccessibleResources, findPubliclyAccessibleResources,
getResourcePermissionsMap,
findAccessibleResources, findAccessibleResources,
hasPublicPermission, hasPublicPermission,
grantPermission, grantPermission,
@ -49,7 +47,6 @@ const { refreshS3Url } = require('~/server/services/Files/S3/crud');
const { filterFile } = require('~/server/services/Files/process'); const { filterFile } = require('~/server/services/Files/process');
const { updateAction, getActions } = require('~/models/Action'); const { updateAction, getActions } = require('~/models/Action');
const { getCachedTools } = require('~/server/services/Config'); const { getCachedTools } = require('~/server/services/Config');
const { getMCPServersRegistry } = require('~/config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
const systemTools = { const systemTools = {
@ -61,116 +58,6 @@ const systemTools = {
const MAX_SEARCH_LEN = 100; const MAX_SEARCH_LEN = 100;
const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
/**
* Validates that the requesting user has VIEW access to every agent referenced in edges.
* Agents that do not exist in the database are skipped at create time, the `from` field
* often references the agent being built, which has no DB record yet.
* @param {import('librechat-data-provider').GraphEdge[]} edges
* @param {string} userId
* @param {string} userRole - Used for group/role principal resolution
* @returns {Promise<string[]>} Agent IDs the user cannot VIEW (empty if all accessible)
*/
const validateEdgeAgentAccess = async (edges, userId, userRole) => {
const edgeAgentIds = collectEdgeAgentIds(edges);
if (edgeAgentIds.size === 0) {
return [];
}
const agents = (await Promise.all([...edgeAgentIds].map((id) => getAgent({ id })))).filter(
Boolean,
);
if (agents.length === 0) {
return [];
}
const permissionsMap = await getResourcePermissionsMap({
userId,
role: userRole,
resourceType: ResourceType.AGENT,
resourceIds: agents.map((a) => a._id),
});
return agents
.filter((a) => {
const bits = permissionsMap.get(a._id.toString()) ?? 0;
return (bits & PermissionBits.VIEW) === 0;
})
.map((a) => a.id);
};
/**
* Filters tools to only include those the user is authorized to use.
* MCP tools must match the exact format `{toolName}_mcp_{serverName}` (exactly 2 segments).
* Multi-delimiter keys are rejected to prevent authorization/execution mismatch.
* Non-MCP tools must appear in availableTools (global tool cache) or systemTools.
*
* When `existingTools` is provided and the MCP registry is unavailable (e.g. server restart),
* tools already present on the agent are preserved rather than stripped they were validated
* when originally added, and we cannot re-verify them without the registry.
* @param {object} params
* @param {string[]} params.tools - Raw tool strings from the request
* @param {string} params.userId - Requesting user ID for MCP server access check
* @param {Record<string, unknown>} params.availableTools - Global non-MCP tool cache
* @param {string[]} [params.existingTools] - Tools already persisted on the agent document
* @returns {Promise<string[]>} Only the authorized subset of tools
*/
const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => {
const filteredTools = [];
let mcpServerConfigs;
let registryUnavailable = false;
const existingToolSet = existingTools?.length ? new Set(existingTools) : null;
for (const tool of tools) {
if (availableTools[tool] || systemTools[tool]) {
filteredTools.push(tool);
continue;
}
if (!tool?.includes(Constants.mcp_delimiter)) {
continue;
}
if (mcpServerConfigs === undefined) {
try {
mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {};
} catch (e) {
logger.warn(
'[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools',
e.message,
);
mcpServerConfigs = {};
registryUnavailable = true;
}
}
const parts = tool.split(Constants.mcp_delimiter);
if (parts.length !== 2) {
logger.warn(
`[filterAuthorizedTools] Rejected malformed MCP tool key "${tool}" for user ${userId}`,
);
continue;
}
if (registryUnavailable && existingToolSet?.has(tool)) {
filteredTools.push(tool);
continue;
}
const [, serverName] = parts;
if (!serverName || !Object.hasOwn(mcpServerConfigs, serverName)) {
logger.warn(
`[filterAuthorizedTools] Rejected MCP tool "${tool}" — server "${serverName}" not accessible to user ${userId}`,
);
continue;
}
filteredTools.push(tool);
}
return filteredTools;
};
/** /**
* Creates an Agent. * Creates an Agent.
* @route POST /Agents * @route POST /Agents
@ -188,24 +75,22 @@ const createAgentHandler = async (req, res) => {
agentData.model_parameters = removeNullishValues(agentData.model_parameters, true); agentData.model_parameters = removeNullishValues(agentData.model_parameters, true);
} }
const { id: userId, role: userRole } = req.user; const { id: userId } = req.user;
if (agentData.edges?.length) {
const unauthorized = await validateEdgeAgentAccess(agentData.edges, userId, userRole);
if (unauthorized.length > 0) {
return res.status(403).json({
error: 'You do not have access to one or more agents referenced in edges',
agent_ids: unauthorized,
});
}
}
agentData.id = `agent_${nanoid()}`; agentData.id = `agent_${nanoid()}`;
agentData.author = userId; agentData.author = userId;
agentData.tools = []; agentData.tools = [];
const availableTools = (await getCachedTools()) ?? {}; const availableTools = (await getCachedTools()) ?? {};
agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools }); for (const tool of tools) {
if (availableTools[tool]) {
agentData.tools.push(tool);
} else if (systemTools[tool]) {
agentData.tools.push(tool);
} else if (tool.includes(Constants.mcp_delimiter)) {
agentData.tools.push(tool);
}
}
const agent = await createAgent(agentData); const agent = await createAgent(agentData);
@ -358,17 +243,6 @@ const updateAgentHandler = async (req, res) => {
updateData.avatar = avatarField; updateData.avatar = avatarField;
} }
if (updateData.edges?.length) {
const { id: userId, role: userRole } = req.user;
const unauthorized = await validateEdgeAgentAccess(updateData.edges, userId, userRole);
if (unauthorized.length > 0) {
return res.status(403).json({
error: 'You do not have access to one or more agents referenced in edges',
agent_ids: unauthorized,
});
}
}
// Convert OCR to context in incoming updateData // Convert OCR to context in incoming updateData
convertOcrToContextInPlace(updateData); convertOcrToContextInPlace(updateData);
@ -387,26 +261,6 @@ const updateAgentHandler = async (req, res) => {
updateData.tools = ocrConversion.tools; updateData.tools = ocrConversion.tools;
} }
if (updateData.tools) {
const existingToolSet = new Set(existingAgent.tools ?? []);
const newMCPTools = updateData.tools.filter(
(t) => !existingToolSet.has(t) && t?.includes(Constants.mcp_delimiter),
);
if (newMCPTools.length > 0) {
const availableTools = (await getCachedTools()) ?? {};
const approvedNew = await filterAuthorizedTools({
tools: newMCPTools,
userId: req.user.id,
availableTools,
});
const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t)));
if (rejectedSet.size > 0) {
updateData.tools = updateData.tools.filter((t) => !rejectedSet.has(t));
}
}
}
let updatedAgent = let updatedAgent =
Object.keys(updateData).length > 0 Object.keys(updateData).length > 0
? await updateAgent({ id }, updateData, { ? await updateAgent({ id }, updateData, {
@ -517,7 +371,7 @@ const duplicateAgentHandler = async (req, res) => {
*/ */
const duplicateAction = async (action) => { const duplicateAction = async (action) => {
const newActionId = nanoid(); const newActionId = nanoid();
const { domain } = action.metadata; const [domain] = action.action_id.split(actionDelimiter);
const fullActionId = `${domain}${actionDelimiter}${newActionId}`; const fullActionId = `${domain}${actionDelimiter}${newActionId}`;
// Sanitize sensitive metadata before persisting // Sanitize sensitive metadata before persisting
@ -527,7 +381,7 @@ const duplicateAgentHandler = async (req, res) => {
} }
const newAction = await updateAction( const newAction = await updateAction(
{ action_id: newActionId, agent_id: newAgentId }, { action_id: newActionId },
{ {
metadata: filteredMetadata, metadata: filteredMetadata,
agent_id: newAgentId, agent_id: newAgentId,
@ -549,17 +403,6 @@ const duplicateAgentHandler = async (req, res) => {
const agentActions = await Promise.all(promises); const agentActions = await Promise.all(promises);
newAgentData.actions = agentActions; newAgentData.actions = agentActions;
if (newAgentData.tools?.length) {
const availableTools = (await getCachedTools()) ?? {};
newAgentData.tools = await filterAuthorizedTools({
tools: newAgentData.tools,
userId,
availableTools,
existingTools: newAgentData.tools,
});
}
const newAgent = await createAgent(newAgentData); const newAgent = await createAgent(newAgentData);
try { try {
@ -888,24 +731,7 @@ const revertAgentVersionHandler = async (req, res) => {
// Permissions are enforced via route middleware (ACL EDIT) // Permissions are enforced via route middleware (ACL EDIT)
let updatedAgent = await revertAgentVersion({ id }, version_index); const updatedAgent = await revertAgentVersion({ id }, version_index);
if (updatedAgent.tools?.length) {
const availableTools = (await getCachedTools()) ?? {};
const filteredTools = await filterAuthorizedTools({
tools: updatedAgent.tools,
userId: req.user.id,
availableTools,
existingTools: updatedAgent.tools,
});
if (filteredTools.length !== updatedAgent.tools.length) {
updatedAgent = await updateAgent(
{ id },
{ tools: filteredTools },
{ updatingUserId: req.user.id },
);
}
}
if (updatedAgent.author) { if (updatedAgent.author) {
updatedAgent.author = updatedAgent.author.toString(); updatedAgent.author = updatedAgent.author.toString();
@ -973,5 +799,4 @@ module.exports = {
uploadAgentAvatar: uploadAgentAvatarHandler, uploadAgentAvatar: uploadAgentAvatarHandler,
revertAgentVersion: revertAgentVersionHandler, revertAgentVersion: revertAgentVersionHandler,
getAgentCategories, getAgentCategories,
filterAuthorizedTools,
}; };

View file

@ -2,7 +2,7 @@ const mongoose = require('mongoose');
const { nanoid } = require('nanoid'); const { nanoid } = require('nanoid');
const { v4: uuidv4 } = require('uuid'); const { v4: uuidv4 } = require('uuid');
const { agentSchema } = require('@librechat/data-schemas'); const { agentSchema } = require('@librechat/data-schemas');
const { FileSources, PermissionBits } = require('librechat-data-provider'); const { FileSources } = require('librechat-data-provider');
const { MongoMemoryServer } = require('mongodb-memory-server'); const { MongoMemoryServer } = require('mongodb-memory-server');
// Only mock the dependencies that are not database-related // Only mock the dependencies that are not database-related
@ -46,9 +46,9 @@ jest.mock('~/models/File', () => ({
jest.mock('~/server/services/PermissionService', () => ({ jest.mock('~/server/services/PermissionService', () => ({
findAccessibleResources: jest.fn().mockResolvedValue([]), findAccessibleResources: jest.fn().mockResolvedValue([]),
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]), findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
getResourcePermissionsMap: jest.fn().mockResolvedValue(new Map()),
grantPermission: jest.fn(), grantPermission: jest.fn(),
hasPublicPermission: jest.fn().mockResolvedValue(false), hasPublicPermission: jest.fn().mockResolvedValue(false),
checkPermission: jest.fn().mockResolvedValue(true),
})); }));
jest.mock('~/models', () => ({ jest.mock('~/models', () => ({
@ -74,7 +74,6 @@ const {
const { const {
findAccessibleResources, findAccessibleResources,
findPubliclyAccessibleResources, findPubliclyAccessibleResources,
getResourcePermissionsMap,
} = require('~/server/services/PermissionService'); } = require('~/server/services/PermissionService');
const { refreshS3Url } = require('~/server/services/Files/S3/crud'); const { refreshS3Url } = require('~/server/services/Files/S3/crud');
@ -1648,112 +1647,4 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
expect(agent.avatar.filepath).toBe('old-s3-path.jpg'); expect(agent.avatar.filepath).toBe('old-s3-path.jpg');
}); });
}); });
describe('Edge ACL validation', () => {
let targetAgent;
beforeEach(async () => {
targetAgent = await Agent.create({
id: `agent_${nanoid()}`,
author: new mongoose.Types.ObjectId().toString(),
name: 'Target Agent',
provider: 'openai',
model: 'gpt-4',
tools: [],
});
});
test('createAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => {
const permMap = new Map();
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
mockReq.body = {
name: 'Attacker Agent',
provider: 'openai',
model: 'gpt-4',
edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }],
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(403);
const response = mockRes.json.mock.calls[0][0];
expect(response.agent_ids).toContain(targetAgent.id);
});
test('createAgentHandler should succeed when user has VIEW on all edge-referenced agents', async () => {
const permMap = new Map([[targetAgent._id.toString(), 1]]);
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
mockReq.body = {
name: 'Legit Agent',
provider: 'openai',
model: 'gpt-4',
edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }],
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
});
test('createAgentHandler should allow edges referencing non-existent agents (self-reference at create time)', async () => {
mockReq.body = {
name: 'Self-Ref Agent',
provider: 'openai',
model: 'gpt-4',
edges: [{ from: 'agent_does_not_exist_yet', to: 'agent_also_new', edgeType: 'handoff' }],
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
});
test('updateAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => {
const ownedAgent = await Agent.create({
id: `agent_${nanoid()}`,
author: mockReq.user.id,
name: 'Owned Agent',
provider: 'openai',
model: 'gpt-4',
tools: [],
});
const permMap = new Map([[ownedAgent._id.toString(), PermissionBits.VIEW]]);
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
mockReq.params = { id: ownedAgent.id };
mockReq.body = {
edges: [{ from: ownedAgent.id, to: targetAgent.id, edgeType: 'handoff' }],
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(403);
const response = mockRes.json.mock.calls[0][0];
expect(response.agent_ids).toContain(targetAgent.id);
expect(response.agent_ids).not.toContain(ownedAgent.id);
});
test('updateAgentHandler should succeed when edges field is absent from payload', async () => {
const ownedAgent = await Agent.create({
id: `agent_${nanoid()}`,
author: mockReq.user.id,
name: 'Owned Agent',
provider: 'openai',
model: 'gpt-4',
tools: [],
});
mockReq.params = { id: ownedAgent.id };
mockReq.body = { name: 'Renamed Agent' };
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
const response = mockRes.json.mock.calls[0][0];
expect(response.name).toBe('Renamed Agent');
});
});
}); });

View file

@ -7,11 +7,9 @@
*/ */
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { const {
MCPErrorCodes,
redactServerSecrets,
redactAllServerSecrets,
isMCPDomainNotAllowedError, isMCPDomainNotAllowedError,
isMCPInspectionFailedError, isMCPInspectionFailedError,
MCPErrorCodes,
} = require('@librechat/api'); } = require('@librechat/api');
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
@ -183,8 +181,10 @@ const getMCPServersList = async (req, res) => {
return res.status(401).json({ message: 'Unauthorized' }); return res.status(401).json({ message: 'Unauthorized' });
} }
// 2. Get all server configs from registry (YAML + DB)
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
return res.json(redactAllServerSecrets(serverConfigs));
return res.json(serverConfigs);
} catch (error) { } catch (error) {
logger.error('[getMCPServersList]', error); logger.error('[getMCPServersList]', error);
res.status(500).json({ error: error.message }); res.status(500).json({ error: error.message });
@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => {
); );
res.status(201).json({ res.status(201).json({
serverName: result.serverName, serverName: result.serverName,
...redactServerSecrets(result.config), ...result.config,
}); });
} catch (error) { } catch (error) {
logger.error('[createMCPServer]', error); logger.error('[createMCPServer]', error);
@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => {
return res.status(404).json({ message: 'MCP server not found' }); return res.status(404).json({ message: 'MCP server not found' });
} }
res.status(200).json(redactServerSecrets(parsedConfig)); res.status(200).json(parsedConfig);
} catch (error) { } catch (error) {
logger.error('[getMCPServerById]', error); logger.error('[getMCPServerById]', error);
res.status(500).json({ message: error.message }); res.status(500).json({ message: error.message });
@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => {
userId, userId,
); );
res.status(200).json(redactServerSecrets(parsedConfig)); res.status(200).json(parsedConfig);
} catch (error) { } catch (error) {
logger.error('[updateMCPServer]', error); logger.error('[updateMCPServer]', error);
const mcpErrorResponse = handleMCPError(error, res); const mcpErrorResponse = handleMCPError(error, res);

View file

@ -1,144 +1,42 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { const {
Constants, Constants,
Permissions,
ResourceType, ResourceType,
SystemRoles,
PermissionTypes,
isAgentsEndpoint, isAgentsEndpoint,
isEphemeralAgentId, isEphemeralAgentId,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { checkPermission } = require('~/server/services/PermissionService');
const { canAccessResource } = require('./canAccessResource'); const { canAccessResource } = require('./canAccessResource');
const { getRoleByName } = require('~/models/Role');
const { getAgent } = require('~/models/Agent'); const { getAgent } = require('~/models/Agent');
/** /**
* Resolves custom agent ID (e.g., "agent_abc123") to a MongoDB document. * Agent ID resolver function for agent_id from request body
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
* This is used specifically for chat routes where agent_id comes from request body
*
* @param {string} agentCustomId - Custom agent ID from request body * @param {string} agentCustomId - Custom agent ID from request body
* @returns {Promise<Object|null>} Agent document with _id field, or null if ephemeral/not found * @returns {Promise<Object|null>} Agent document with _id field, or null if not found
*/ */
const resolveAgentIdFromBody = async (agentCustomId) => { const resolveAgentIdFromBody = async (agentCustomId) => {
// Handle ephemeral agents - they don't need permission checks
// Real agent IDs always start with "agent_", so anything else is ephemeral
if (isEphemeralAgentId(agentCustomId)) { if (isEphemeralAgentId(agentCustomId)) {
return null; return null; // No permission check needed for ephemeral agents
} }
return getAgent({ id: agentCustomId });
return await getAgent({ id: agentCustomId });
}; };
/** /**
* Creates a `canAccessResource` middleware for the given agent ID * Middleware factory that creates middleware to check agent access permissions from request body.
* and chains to the provided continuation on success. * This middleware is specifically designed for chat routes where the agent_id comes from req.body
* * instead of route parameters.
* @param {string} agentId - The agent's custom string ID (e.g., "agent_abc123")
* @param {number} requiredPermission - Permission bit(s) required
* @param {import('express').Request} req
* @param {import('express').Response} res - Written on deny; continuation called on allow
* @param {Function} continuation - Called when the permission check passes
* @returns {Promise<void>}
*/
const checkAgentResourceAccess = (agentId, requiredPermission, req, res, continuation) => {
const middleware = canAccessResource({
resourceType: ResourceType.AGENT,
requiredPermission,
resourceIdParam: 'agent_id',
idResolver: () => resolveAgentIdFromBody(agentId),
});
const tempReq = {
...req,
params: { ...req.params, agent_id: agentId },
};
return middleware(tempReq, res, continuation);
};
/**
* Middleware factory that validates MULTI_CONVO:USE role permission and, when
* addedConvo.agent_id is a non-ephemeral agent, the same resource-level permission
* required for the primary agent (`requiredPermission`). Caches the resolved agent
* document on `req.resolvedAddedAgent` to avoid a duplicate DB fetch in `loadAddedAgent`.
*
* @param {number} requiredPermission - Permission bit(s) to check on the added agent resource
* @returns {(req: import('express').Request, res: import('express').Response, next: Function) => Promise<void>}
*/
const checkAddedConvoAccess = (requiredPermission) => async (req, res, next) => {
const addedConvo = req.body?.addedConvo;
if (!addedConvo || typeof addedConvo !== 'object' || Array.isArray(addedConvo)) {
return next();
}
try {
if (!req.user?.role) {
return res.status(403).json({
error: 'Forbidden',
message: 'Insufficient permissions for multi-conversation',
});
}
if (req.user.role !== SystemRoles.ADMIN) {
const role = await getRoleByName(req.user.role);
const hasMultiConvo = role?.permissions?.[PermissionTypes.MULTI_CONVO]?.[Permissions.USE];
if (!hasMultiConvo) {
return res.status(403).json({
error: 'Forbidden',
message: 'Multi-conversation feature is not enabled',
});
}
}
const addedAgentId = addedConvo.agent_id;
if (!addedAgentId || typeof addedAgentId !== 'string' || isEphemeralAgentId(addedAgentId)) {
return next();
}
if (req.user.role === SystemRoles.ADMIN) {
return next();
}
const agent = await resolveAgentIdFromBody(addedAgentId);
if (!agent) {
return res.status(404).json({
error: 'Not Found',
message: `${ResourceType.AGENT} not found`,
});
}
const hasPermission = await checkPermission({
userId: req.user.id,
role: req.user.role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission,
});
if (!hasPermission) {
return res.status(403).json({
error: 'Forbidden',
message: `Insufficient permissions to access this ${ResourceType.AGENT}`,
});
}
req.resolvedAddedAgent = agent;
return next();
} catch (error) {
logger.error('Failed to validate addedConvo access permissions', error);
return res.status(500).json({
error: 'Internal Server Error',
message: 'Failed to validate addedConvo access permissions',
});
}
};
/**
* Middleware factory that checks agent access permissions from request body.
* Validates both the primary agent_id and, when present, addedConvo.agent_id
* (which also requires MULTI_CONVO:USE role permission).
* *
* @param {Object} options - Configuration options * @param {Object} options - Configuration options
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share) * @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
* @returns {Function} Express middleware function * @returns {Function} Express middleware function
* *
* @example * @example
* // Basic usage for agent chat (requires VIEW permission)
* router.post('/chat', * router.post('/chat',
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }), * canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
* buildEndpointOption, * buildEndpointOption,
@ -148,12 +46,11 @@ const checkAddedConvoAccess = (requiredPermission) => async (req, res, next) =>
const canAccessAgentFromBody = (options) => { const canAccessAgentFromBody = (options) => {
const { requiredPermission } = options; const { requiredPermission } = options;
// Validate required options
if (!requiredPermission || typeof requiredPermission !== 'number') { if (!requiredPermission || typeof requiredPermission !== 'number') {
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number'); throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
} }
const addedConvoMiddleware = checkAddedConvoAccess(requiredPermission);
return async (req, res, next) => { return async (req, res, next) => {
try { try {
const { endpoint, agent_id } = req.body; const { endpoint, agent_id } = req.body;
@ -170,13 +67,28 @@ const canAccessAgentFromBody = (options) => {
}); });
} }
const afterPrimaryCheck = () => addedConvoMiddleware(req, res, next); // Skip permission checks for ephemeral agents
// Real agent IDs always start with "agent_", so anything else is ephemeral
if (isEphemeralAgentId(agentId)) { if (isEphemeralAgentId(agentId)) {
return afterPrimaryCheck(); return next();
} }
return checkAgentResourceAccess(agentId, requiredPermission, req, res, afterPrimaryCheck); const agentAccessMiddleware = canAccessResource({
resourceType: ResourceType.AGENT,
requiredPermission,
resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver
idResolver: () => resolveAgentIdFromBody(agentId),
});
const tempReq = {
...req,
params: {
...req.params,
agent_id: agentId,
},
};
return agentAccessMiddleware(tempReq, res, next);
} catch (error) { } catch (error) {
logger.error('Failed to validate agent access permissions', error); logger.error('Failed to validate agent access permissions', error);
return res.status(500).json({ return res.status(500).json({

View file

@ -1,509 +0,0 @@
const mongoose = require('mongoose');
const {
ResourceType,
SystemRoles,
PrincipalType,
PrincipalModel,
} = require('librechat-data-provider');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { canAccessAgentFromBody } = require('./canAccessAgentFromBody');
const { User, Role, AclEntry } = require('~/db/models');
const { createAgent } = require('~/models/Agent');
/**
 * Integration tests for the canAccessAgentFromBody middleware, run against a real
 * in-memory MongoDB (mongodb-memory-server). Covers: factory validation, primary
 * agent_id checks, the MULTI_CONVO:USE role gate for addedConvo, ACL (IDOR)
 * enforcement for the added agent, and admin bypasses.
 */
describe('canAccessAgentFromBody middleware', () => {
  let mongoServer;
  let req, res, next;
  let testUser, otherUser;

  beforeAll(async () => {
    mongoServer = await MongoMemoryServer.create();
    await mongoose.connect(mongoServer.getUri());
  });

  afterAll(async () => {
    await mongoose.disconnect();
    await mongoServer.stop();
  });

  beforeEach(async () => {
    // Fresh DB per test so roles, users, agents, and ACL entries never leak between tests.
    await mongoose.connection.dropDatabase();
    // Baseline role: agents enabled AND multi-convo enabled.
    await Role.create({
      name: 'test-role',
      permissions: {
        AGENTS: { USE: true, CREATE: true, SHARE: true },
        MULTI_CONVO: { USE: true },
      },
    });
    // Same agent permissions but MULTI_CONVO disabled — exercises the 403 gate.
    await Role.create({
      name: 'no-multi-convo',
      permissions: {
        AGENTS: { USE: true, CREATE: true, SHARE: true },
        MULTI_CONVO: { USE: false },
      },
    });
    await Role.create({
      name: SystemRoles.ADMIN,
      permissions: {
        AGENTS: { USE: true, CREATE: true, SHARE: true },
        MULTI_CONVO: { USE: true },
      },
    });
    testUser = await User.create({
      email: 'test@example.com',
      name: 'Test User',
      username: 'testuser',
      role: 'test-role',
    });
    // Second user owns the "foreign" agents that testUser must NOT be able to reach.
    otherUser = await User.create({
      email: 'other@example.com',
      name: 'Other User',
      username: 'otheruser',
      role: 'test-role',
    });
    // Default request: agents endpoint with an ephemeral primary agent (no ACL check needed).
    req = {
      user: { id: testUser._id, role: testUser.role },
      params: {},
      body: {
        endpoint: 'agents',
        agent_id: 'ephemeral_primary',
      },
    };
    res = {
      status: jest.fn().mockReturnThis(),
      json: jest.fn(),
    };
    next = jest.fn();
    jest.clearAllMocks();
  });

  // Factory-level option validation (synchronous throws).
  describe('middleware factory', () => {
    test('throws if requiredPermission is missing', () => {
      expect(() => canAccessAgentFromBody({})).toThrow(
        'canAccessAgentFromBody: requiredPermission is required and must be a number',
      );
    });
    test('throws if requiredPermission is not a number', () => {
      expect(() => canAccessAgentFromBody({ requiredPermission: '1' })).toThrow(
        'canAccessAgentFromBody: requiredPermission is required and must be a number',
      );
    });
    test('returns a middleware function', () => {
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      expect(typeof middleware).toBe('function');
      // Express middleware arity: (req, res, next).
      expect(middleware.length).toBe(3);
    });
  });

  // Validation of the primary agent_id taken from req.body.
  describe('primary agent checks', () => {
    test('returns 400 when agent_id is missing on agents endpoint', async () => {
      req.body.agent_id = undefined;
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(400);
    });
    test('proceeds for ephemeral primary agent without addedConvo', async () => {
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      expect(res.status).not.toHaveBeenCalled();
    });
    test('proceeds for non-agents endpoint (ephemeral fallback)', async () => {
      req.body.endpoint = 'openAI';
      req.body.agent_id = undefined;
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
  });

  // addedConvo payloads that are absent or malformed are ignored (no checks run).
  describe('addedConvo — absent or invalid shape', () => {
    test('calls next when addedConvo is absent', async () => {
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
    test('calls next when addedConvo is a string', async () => {
      req.body.addedConvo = 'not-an-object';
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
    test('calls next when addedConvo is an array', async () => {
      req.body.addedConvo = [{ agent_id: 'agent_something' }];
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
  });

  // Role-level gate: addedConvo requires MULTI_CONVO:USE (admins bypass).
  describe('addedConvo — MULTI_CONVO permission gate', () => {
    test('returns 403 when user lacks MULTI_CONVO:USE', async () => {
      req.user.role = 'no-multi-convo';
      req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
      expect(res.json).toHaveBeenCalledWith(
        expect.objectContaining({ message: 'Multi-conversation feature is not enabled' }),
      );
    });
    test('returns 403 when user.role is missing', async () => {
      // No role at all → denied before any role lookup.
      req.user = { id: testUser._id };
      req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
    });
    test('ADMIN bypasses MULTI_CONVO check', async () => {
      req.user.role = SystemRoles.ADMIN;
      req.body.addedConvo = { agent_id: 'ephemeral_x', endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      expect(res.status).not.toHaveBeenCalled();
    });
  });

  // Shape checks on addedConvo.agent_id — ephemeral/absent/non-string IDs skip the ACL lookup.
  describe('addedConvo — agent_id shape validation', () => {
    test('calls next when agent_id is ephemeral', async () => {
      req.body.addedConvo = { agent_id: 'ephemeral_xyz', endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
    test('calls next when agent_id is absent', async () => {
      req.body.addedConvo = { endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
    test('calls next when agent_id is not a string (object injection)', async () => {
      // Mongo operator injection attempt must not reach the query layer.
      req.body.addedConvo = { agent_id: { $gt: '' }, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
    });
  });

  // Resource-level ACL on the added agent: prevents IDOR via addedConvo.agent_id.
  describe('addedConvo — agent resource ACL (IDOR prevention)', () => {
    let addedAgent;
    beforeEach(async () => {
      // Agent owned by otherUser; only otherUser holds an ACL entry (permBits 15 = full access).
      addedAgent = await createAgent({
        id: `agent_added_${Date.now()}`,
        name: 'Private Agent',
        provider: 'openai',
        model: 'gpt-4',
        author: otherUser._id,
      });
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: otherUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 15,
        grantedBy: otherUser._id,
      });
    });
    test('returns 403 when requester has no ACL for the added agent', async () => {
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
      expect(res.json).toHaveBeenCalledWith(
        expect.objectContaining({
          message: 'Insufficient permissions to access this agent',
        }),
      );
    });
    test('returns 404 when added agent does not exist', async () => {
      req.body.addedConvo = {
        agent_id: 'agent_nonexistent_999',
        endpoint: 'agents',
        model: 'gpt-4',
      };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(404);
    });
    test('proceeds when requester has ACL for the added agent', async () => {
      // Grant VIEW (permBits 1) to the requester.
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: testUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 1,
        grantedBy: otherUser._id,
      });
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      expect(res.status).not.toHaveBeenCalled();
    });
    test('denies when ACL permission bits are insufficient', async () => {
      // Requester holds only VIEW (1) but EDIT (2) is required → 403.
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: testUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 1,
        grantedBy: otherUser._id,
      });
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 2 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
    });
    test('caches resolved agent on req.resolvedAddedAgent', async () => {
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: testUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 1,
        grantedBy: otherUser._id,
      });
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      // Cached document lets loadAddedAgent skip a second DB fetch.
      expect(req.resolvedAddedAgent).toBeDefined();
      expect(req.resolvedAddedAgent._id.toString()).toBe(addedAgent._id.toString());
    });
    test('ADMIN bypasses agent resource ACL for addedConvo', async () => {
      req.user.role = SystemRoles.ADMIN;
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      expect(res.status).not.toHaveBeenCalled();
      // Admin short-circuits before resolution, so nothing is cached.
      expect(req.resolvedAddedAgent).toBeUndefined();
    });
  });

  // Combined flow: a real primary agent AND a real added agent in one request.
  describe('end-to-end: primary real agent + addedConvo real agent', () => {
    let primaryAgent, addedAgent;
    beforeEach(async () => {
      // Primary agent owned by (and fully granted to) the requester.
      primaryAgent = await createAgent({
        id: `agent_primary_${Date.now()}`,
        name: 'Primary Agent',
        provider: 'openai',
        model: 'gpt-4',
        author: testUser._id,
      });
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: testUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: primaryAgent._id,
        permBits: 15,
        grantedBy: testUser._id,
      });
      // Added agent owned by the other user; requester has no grant by default.
      addedAgent = await createAgent({
        id: `agent_added_${Date.now()}`,
        name: 'Added Agent',
        provider: 'openai',
        model: 'gpt-4',
        author: otherUser._id,
      });
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: otherUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 15,
        grantedBy: otherUser._id,
      });
      req.body.agent_id = primaryAgent.id;
    });
    test('both checks pass when user has ACL for both agents', async () => {
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: testUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 1,
        grantedBy: otherUser._id,
      });
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      expect(res.status).not.toHaveBeenCalled();
      expect(req.resolvedAddedAgent).toBeDefined();
    });
    test('primary passes but addedConvo denied → 403', async () => {
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
    });
    test('primary denied → 403 without reaching addedConvo check', async () => {
      // Swap primary for an agent the requester has no grant on.
      const foreignAgent = await createAgent({
        id: `agent_foreign_${Date.now()}`,
        name: 'Foreign Agent',
        provider: 'openai',
        model: 'gpt-4',
        author: otherUser._id,
      });
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: otherUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: foreignAgent._id,
        permBits: 15,
        grantedBy: otherUser._id,
      });
      req.body.agent_id = foreignAgent.id;
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
    });
  });

  // An ephemeral primary must NOT exempt the addedConvo agent from ACL checks.
  describe('ephemeral primary + real addedConvo agent', () => {
    let addedAgent;
    beforeEach(async () => {
      addedAgent = await createAgent({
        id: `agent_added_${Date.now()}`,
        name: 'Added Agent',
        provider: 'openai',
        model: 'gpt-4',
        author: otherUser._id,
      });
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: otherUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 15,
        grantedBy: otherUser._id,
      });
    });
    test('runs full addedConvo ACL check even when primary is ephemeral', async () => {
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).not.toHaveBeenCalled();
      expect(res.status).toHaveBeenCalledWith(403);
    });
    test('proceeds when user has ACL for added agent (ephemeral primary)', async () => {
      await AclEntry.create({
        principalType: PrincipalType.USER,
        principalId: testUser._id,
        principalModel: PrincipalModel.USER,
        resourceType: ResourceType.AGENT,
        resourceId: addedAgent._id,
        permBits: 1,
        grantedBy: otherUser._id,
      });
      req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
      const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
      await middleware(req, res, next);
      expect(next).toHaveBeenCalled();
      expect(res.status).not.toHaveBeenCalled();
    });
  });
});

View file

@ -48,7 +48,7 @@ const createForkHandler = (ip = true) => {
}; };
await logViolation(req, res, type, errorMessage, forkViolationScore); await logViolation(req, res, type, errorMessage, forkViolationScore);
res.status(429).json({ message: 'Too many requests. Try again later' }); res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
}; };
}; };

View file

@ -1,93 +0,0 @@
/**
 * Shared mock factories for convos-route jest suites. Each key returns a fresh
 * mock module body so test files can `jest.mock(path, () => require(MOCKS).x())`
 * without sharing mock state between files.
 */
module.exports = {
  // '@librechat/agents': only `sleep` is consumed by the routes under test.
  agents: () => ({ sleep: jest.fn() }),
  // '@librechat/api': accepts overrides so a suite can pin specific members
  // (e.g. limiterCache) while keeping the defaults below.
  api: (overrides = {}) => ({
    isEnabled: jest.fn(),
    resolveImportMaxFileSize: jest.fn(() => 262144000),
    createAxiosInstance: jest.fn(() => ({
      get: jest.fn(),
      post: jest.fn(),
      put: jest.fn(),
      delete: jest.fn(),
    })),
    logAxiosError: jest.fn(),
    ...overrides,
  }),
  // '@librechat/data-schemas': silent logger plus empty model constructors.
  dataSchemas: () => ({
    logger: {
      debug: jest.fn(),
      info: jest.fn(),
      warn: jest.fn(),
      error: jest.fn(),
    },
    createModels: jest.fn(() => ({
      User: {},
      Conversation: {},
      Message: {},
      SharedLink: {},
    })),
  }),
  // 'librechat-data-provider': minimal constants; overrides let suites add
  // extra enums (e.g. ViolationTypes) as needed.
  dataProvider: (overrides = {}) => ({
    CacheKeys: { GEN_TITLE: 'GEN_TITLE' },
    EModelEndpoint: {
      azureAssistants: 'azureAssistants',
      assistants: 'assistants',
    },
    ...overrides,
  }),
  // '~/models/Conversation' query/mutation surface used by the router.
  conversationModel: () => ({
    getConvosByCursor: jest.fn(),
    getConvo: jest.fn(),
    deleteConvos: jest.fn(),
    saveConvo: jest.fn(),
  }),
  toolCallModel: () => ({ deleteToolCalls: jest.fn() }),
  // '~/models' shared-link helpers.
  sharedModels: () => ({
    deleteAllSharedLinks: jest.fn(),
    deleteConvoSharedLink: jest.fn(),
  }),
  // Auth middleware replaced by a pass-through.
  requireJwtAuth: () => (req, res, next) => next(),
  // '~/server/middleware': every limiter/guard becomes a pass-through so route
  // handlers can be exercised without rate limiting or access checks.
  middlewarePassthrough: () => ({
    createImportLimiters: jest.fn(() => ({
      importIpLimiter: (req, res, next) => next(),
      importUserLimiter: (req, res, next) => next(),
    })),
    createForkLimiters: jest.fn(() => ({
      forkIpLimiter: (req, res, next) => next(),
      forkUserLimiter: (req, res, next) => next(),
    })),
    configMiddleware: (req, res, next) => next(),
    validateConvoAccess: (req, res, next) => next(),
  }),
  forkUtils: () => ({
    forkConversation: jest.fn(),
    duplicateConversation: jest.fn(),
  }),
  importUtils: () => ({ importConversations: jest.fn() }),
  logStores: () => jest.fn(),
  multerSetup: () => ({
    storage: {},
    importFileFilter: jest.fn(),
  }),
  // 'multer' mock: upload.single() injects a fake uploaded file and continues.
  multerLib: () =>
    jest.fn(() => ({
      single: jest.fn(() => (req, res, next) => {
        req.file = { path: '/tmp/test-file.json' };
        next();
      }),
    })),
  assistantEndpoint: () => ({ initializeClient: jest.fn() }),
};

View file

@ -1,135 +0,0 @@
const express = require('express');
const request = require('supertest');
// Path to the shared mock factories; passed to require() inside each jest.mock
// factory because jest.mock factories cannot reference outer imports directly.
const MOCKS = '../__test-utils__/convos-route-mocks';
jest.mock('@librechat/agents', () => require(MOCKS).agents());
// limiterCache is overridden to return undefined so limiters fall back to the
// in-memory store instead of a shared cache.
jest.mock('@librechat/api', () => require(MOCKS).api({ limiterCache: jest.fn(() => undefined) }));
jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas());
jest.mock('librechat-data-provider', () =>
  require(MOCKS).dataProvider({ ViolationTypes: { FILE_UPLOAD_LIMIT: 'file_upload_limit' } }),
);
jest.mock('~/cache/logViolation', () => jest.fn().mockResolvedValue(undefined));
jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores());
jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel());
jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel());
jest.mock('~/models', () => require(MOCKS).sharedModels());
jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth());
// Unlike the pass-through used elsewhere, this suite keeps the REAL fork
// limiters (jest.requireActual) — they are the code under test here.
jest.mock('~/server/middleware', () => {
  const { createForkLimiters } = jest.requireActual('~/server/middleware/limiters/forkLimiters');
  return {
    createImportLimiters: jest.fn(() => ({
      importIpLimiter: (req, res, next) => next(),
      importUserLimiter: (req, res, next) => next(),
    })),
    createForkLimiters,
    configMiddleware: (req, res, next) => next(),
    validateConvoAccess: (req, res, next) => next(),
  };
});
jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils());
jest.mock('~/server/utils/import', () => require(MOCKS).importUtils());
jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup());
jest.mock('multer', () => require(MOCKS).multerLib());
jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint());
jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint());
/**
 * Rate-limiting tests for POST /api/convos/duplicate using the REAL fork
 * limiters (see the '~/server/middleware' mock above). Limits are driven by
 * FORK_* env vars, which the limiter module reads at load time — hence the
 * jest.isolateModules re-require in setupApp after each env change.
 */
describe('POST /api/convos/duplicate - Rate Limiting', () => {
  let app;
  let duplicateConversation;
  const savedEnv = {};
  beforeAll(() => {
    // Snapshot the FORK_* env vars so the suite can restore them afterwards.
    savedEnv.FORK_USER_MAX = process.env.FORK_USER_MAX;
    savedEnv.FORK_USER_WINDOW = process.env.FORK_USER_WINDOW;
    savedEnv.FORK_IP_MAX = process.env.FORK_IP_MAX;
    savedEnv.FORK_IP_WINDOW = process.env.FORK_IP_WINDOW;
  });
  afterAll(() => {
    // Restore or remove each var exactly as it was before the suite ran.
    for (const key of Object.keys(savedEnv)) {
      if (savedEnv[key] === undefined) {
        delete process.env[key];
      } else {
        process.env[key] = savedEnv[key];
      }
    }
  });
  const setupApp = () => {
    jest.clearAllMocks();
    // isolateModules forces a fresh require of the router (and limiters) so
    // the current FORK_* env values take effect for this test.
    jest.isolateModules(() => {
      const convosRouter = require('../convos');
      ({ duplicateConversation } = require('~/server/utils/import/fork'));
      app = express();
      app.use(express.json());
      // Fixed user id so the per-user limiter keys consistently.
      app.use((req, res, next) => {
        req.user = { id: 'rate-limit-test-user' };
        next();
      });
      app.use('/api/convos', convosRouter);
    });
    duplicateConversation.mockResolvedValue({
      conversation: { conversationId: 'duplicated-conv' },
    });
  };
  describe('user limit', () => {
    beforeEach(() => {
      // Tight user limit, loose IP limit — isolates the per-user limiter.
      process.env.FORK_USER_MAX = '2';
      process.env.FORK_USER_WINDOW = '1';
      process.env.FORK_IP_MAX = '100';
      process.env.FORK_IP_WINDOW = '1';
      setupApp();
    });
    it('should return 429 after exceeding the user rate limit', async () => {
      const userMax = parseInt(process.env.FORK_USER_MAX, 10);
      // Exhaust the allowance — each request within the limit succeeds.
      for (let i = 0; i < userMax; i++) {
        const res = await request(app)
          .post('/api/convos/duplicate')
          .send({ conversationId: 'conv-123' });
        expect(res.status).toBe(201);
      }
      // The next request trips the limiter.
      const res = await request(app)
        .post('/api/convos/duplicate')
        .send({ conversationId: 'conv-123' });
      expect(res.status).toBe(429);
      expect(res.body.message).toMatch(/too many/i);
    });
  });
  describe('IP limit', () => {
    beforeEach(() => {
      // Tight IP limit, loose user limit — isolates the per-IP limiter.
      process.env.FORK_USER_MAX = '100';
      process.env.FORK_USER_WINDOW = '1';
      process.env.FORK_IP_MAX = '2';
      process.env.FORK_IP_WINDOW = '1';
      setupApp();
    });
    it('should return 429 after exceeding the IP rate limit', async () => {
      const ipMax = parseInt(process.env.FORK_IP_MAX, 10);
      for (let i = 0; i < ipMax; i++) {
        const res = await request(app)
          .post('/api/convos/duplicate')
          .send({ conversationId: 'conv-123' });
        expect(res.status).toBe(201);
      }
      const res = await request(app)
        .post('/api/convos/duplicate')
        .send({ conversationId: 'conv-123' });
      expect(res.status).toBe(429);
      expect(res.body.message).toMatch(/too many/i);
    });
  });
});

View file

@ -1,98 +0,0 @@
const express = require('express');
const request = require('supertest');
const multer = require('multer');
const importFileFilter = (req, file, cb) => {
if (file.mimetype === 'application/json') {
cb(null, true);
} else {
cb(new Error('Only JSON files are allowed'), false);
}
};
/**
 * Builds a minimal express app mirroring the production multer + error-handling
 * pattern for conversation imports.
 *
 * @param {number} fileSize - Maximum upload size in bytes for multer's limits.
 * @returns {import('express').Express} App exposing POST /import that answers
 *   201 on success, 413 on oversized uploads, and 400 for other upload errors.
 */
function createImportApp(fileSize) {
  const app = express();
  const uploadSingle = multer({
    storage: multer.memoryStorage(),
    fileFilter: importFileFilter,
    limits: { fileSize },
  }).single('file');

  // Wraps multer so the size-limit error maps to 413 instead of the generic handler.
  const handleUpload = (req, res, next) => {
    uploadSingle(req, res, (err) => {
      if (!err) {
        return next();
      }
      if (err.code === 'LIMIT_FILE_SIZE') {
        return res.status(413).json({ message: 'File exceeds the maximum allowed size' });
      }
      next(err);
    });
  };

  app.post('/import', handleUpload, (req, res) => {
    res.status(201).json({ message: 'success', size: req.file.size });
  });

  // Catch-all error handler: any non-size upload error becomes a 400.
  app.use((err, _req, res, _next) => {
    res.status(400).json({ error: err.message });
  });

  return app;
}
/**
 * Boundary tests for multer's fileSize limit as wired by createImportApp:
 * uploads over the configured byte limit must return 413, uploads at or under
 * the limit must succeed with 201.
 */
describe('Conversation Import - Multer File Size Limits', () => {
  describe('multer rejects files exceeding the configured limit', () => {
    it('returns 413 for files larger than the limit', async () => {
      const limit = 1024;
      const app = createImportApp(limit);
      // Well past the limit (limit + 512 bytes).
      const oversized = Buffer.alloc(limit + 512, 'x');
      const res = await request(app)
        .post('/import')
        .attach('file', oversized, { filename: 'import.json', contentType: 'application/json' });
      expect(res.status).toBe(413);
      expect(res.body.message).toBe('File exceeds the maximum allowed size');
    });
    it('accepts files within the limit', async () => {
      const limit = 4096;
      const app = createImportApp(limit);
      const valid = Buffer.from(JSON.stringify({ title: 'test' }));
      const res = await request(app)
        .post('/import')
        .attach('file', valid, { filename: 'import.json', contentType: 'application/json' });
      expect(res.status).toBe(201);
      expect(res.body.message).toBe('success');
    });
    it('rejects at the exact boundary (limit + 1 byte)', async () => {
      const limit = 512;
      const app = createImportApp(limit);
      // One byte over the limit — the smallest rejectable upload.
      const boundary = Buffer.alloc(limit + 1, 'a');
      const res = await request(app)
        .post('/import')
        .attach('file', boundary, { filename: 'import.json', contentType: 'application/json' });
      expect(res.status).toBe(413);
    });
    it('accepts a file just under the limit', async () => {
      const limit = 512;
      const app = createImportApp(limit);
      const underLimit = Buffer.alloc(limit - 1, 'b');
      const res = await request(app)
        .post('/import')
        .attach('file', underLimit, { filename: 'import.json', contentType: 'application/json' });
      expect(res.status).toBe(201);
    });
  });
});

View file

@ -1,24 +1,109 @@
const express = require('express'); const express = require('express');
const request = require('supertest'); const request = require('supertest');
const MOCKS = '../__test-utils__/convos-route-mocks'; jest.mock('@librechat/agents', () => ({
sleep: jest.fn(),
}));
jest.mock('@librechat/agents', () => require(MOCKS).agents()); jest.mock('@librechat/api', () => ({
jest.mock('@librechat/api', () => require(MOCKS).api()); isEnabled: jest.fn(),
jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas()); createAxiosInstance: jest.fn(() => ({
jest.mock('librechat-data-provider', () => require(MOCKS).dataProvider()); get: jest.fn(),
jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel()); post: jest.fn(),
jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel()); put: jest.fn(),
jest.mock('~/models', () => require(MOCKS).sharedModels()); delete: jest.fn(),
jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth()); })),
jest.mock('~/server/middleware', () => require(MOCKS).middlewarePassthrough()); logAxiosError: jest.fn(),
jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils()); }));
jest.mock('~/server/utils/import', () => require(MOCKS).importUtils());
jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores()); jest.mock('@librechat/data-schemas', () => ({
jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup()); logger: {
jest.mock('multer', () => require(MOCKS).multerLib()); debug: jest.fn(),
jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint()); info: jest.fn(),
jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint()); warn: jest.fn(),
error: jest.fn(),
},
createModels: jest.fn(() => ({
User: {},
Conversation: {},
Message: {},
SharedLink: {},
})),
}));
jest.mock('~/models/Conversation', () => ({
getConvosByCursor: jest.fn(),
getConvo: jest.fn(),
deleteConvos: jest.fn(),
saveConvo: jest.fn(),
}));
jest.mock('~/models/ToolCall', () => ({
deleteToolCalls: jest.fn(),
}));
jest.mock('~/models', () => ({
deleteAllSharedLinks: jest.fn(),
deleteConvoSharedLink: jest.fn(),
}));
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
jest.mock('~/server/middleware', () => ({
createImportLimiters: jest.fn(() => ({
importIpLimiter: (req, res, next) => next(),
importUserLimiter: (req, res, next) => next(),
})),
createForkLimiters: jest.fn(() => ({
forkIpLimiter: (req, res, next) => next(),
forkUserLimiter: (req, res, next) => next(),
})),
configMiddleware: (req, res, next) => next(),
validateConvoAccess: (req, res, next) => next(),
}));
jest.mock('~/server/utils/import/fork', () => ({
forkConversation: jest.fn(),
duplicateConversation: jest.fn(),
}));
jest.mock('~/server/utils/import', () => ({
importConversations: jest.fn(),
}));
jest.mock('~/cache/getLogStores', () => jest.fn());
jest.mock('~/server/routes/files/multer', () => ({
storage: {},
importFileFilter: jest.fn(),
}));
jest.mock('multer', () => {
return jest.fn(() => ({
single: jest.fn(() => (req, res, next) => {
req.file = { path: '/tmp/test-file.json' };
next();
}),
}));
});
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
GEN_TITLE: 'GEN_TITLE',
},
EModelEndpoint: {
azureAssistants: 'azureAssistants',
assistants: 'assistants',
},
}));
jest.mock('~/server/services/Endpoints/azureAssistants', () => ({
initializeClient: jest.fn(),
}));
jest.mock('~/server/services/Endpoints/assistants', () => ({
initializeClient: jest.fn(),
}));
describe('Convos Routes', () => { describe('Convos Routes', () => {
let app; let app;

View file

@ -32,9 +32,6 @@ jest.mock('@librechat/api', () => {
getFlowState: jest.fn(), getFlowState: jest.fn(),
completeOAuthFlow: jest.fn(), completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(), generateFlowId: jest.fn(),
resolveStateToFlowId: jest.fn(async (state) => state),
storeStateMapping: jest.fn(),
deleteStateMapping: jest.fn(),
}, },
MCPTokenStorage: { MCPTokenStorage: {
storeTokens: jest.fn(), storeTokens: jest.fn(),
@ -183,10 +180,7 @@ describe('MCP Routes', () => {
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth', authorizationUrl: 'https://oauth.example.com/auth',
flowId: 'test-user-id:test-server', flowId: 'test-user-id:test-server',
flowMetadata: { state: 'random-state-value' },
}); });
MCPOAuthHandler.storeStateMapping.mockResolvedValue();
mockFlowManager.initFlow = jest.fn().mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id', userId: 'test-user-id',
@ -373,121 +367,6 @@ describe('MCP Routes', () => {
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`); expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
}); });
// Exercises the OAuth callback's CSRF fallback: when no state cookie is
// available, the route may instead accept the request if a fresh server-side
// flow in PENDING status exists for the flowId carried in `state`.
describe('CSRF fallback via active PENDING flow', () => {
// Happy path: a PENDING flow created "now" allows the callback to complete
// the token exchange and redirect to the success page.
it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => {
const flowId = 'test-user-id:test-server';
// Flow manager reports a fresh PENDING flow for this flowId.
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'PENDING',
createdAt: Date.now(),
}),
completeFlow: jest.fn().mockResolvedValue(true),
deleteFlow: jest.fn().mockResolvedValue(true),
};
// Minimal per-flow OAuth state the handler needs to finish the exchange.
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: {},
clientInfo: {},
codeVerifier: 'test-verifier',
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
// Token endpoint exchange succeeds and tokens are persisted.
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({
access_token: 'test-token',
});
MCPTokenStorage.storeTokens.mockResolvedValue();
mockRegistryInstance.getServerConfig.mockResolvedValue({});
// Post-auth reconnect path: connection is established and tools are refetched.
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getOAuthReconnectionManager.mockReturnValue({
clearReconnection: jest.fn(),
});
require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue();
// No cookies sent: success depends entirely on the PENDING-flow fallback.
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
});
// Without any server-side flow record, a cookieless callback must be rejected.
it('should reject when no PENDING flow exists and no cookies are present', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue(null),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
// Only PENDING flows qualify for the fallback; a COMPLETED flow must not
// be reusable to bypass CSRF validation (replay protection).
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'COMPLETED',
createdAt: Date.now(),
}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
// A PENDING flow that is too old must not qualify either.
// NOTE(review): uses a 3-minute-old createdAt, so this assumes the route's
// PENDING_STALE_MS threshold is below 3 minutes — confirm against the handler.
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'PENDING',
createdAt: Date.now() - 3 * 60 * 1000,
}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
});
it('should handle OAuth callback successfully', async () => { it('should handle OAuth callback successfully', async () => {
// mockRegistryInstance is defined at the top of the file // mockRegistryInstance is defined at the top of the file
const mockFlowManager = { const mockFlowManager = {
@ -1693,14 +1572,12 @@ describe('MCP Routes', () => {
it('should return all server configs for authenticated user', async () => { it('should return all server configs for authenticated user', async () => {
const mockServerConfigs = { const mockServerConfigs = {
'server-1': { 'server-1': {
type: 'sse', endpoint: 'http://server1.com',
url: 'http://server1.com/sse', name: 'Server 1',
title: 'Server 1',
}, },
'server-2': { 'server-2': {
type: 'sse', endpoint: 'http://server2.com',
url: 'http://server2.com/sse', name: 'Server 2',
title: 'Server 2',
}, },
}; };
@ -1709,18 +1586,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/servers'); const response = await request(app).get('/api/mcp/servers');
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body['server-1']).toMatchObject({ expect(response.body).toEqual(mockServerConfigs);
type: 'sse',
url: 'http://server1.com/sse',
title: 'Server 1',
});
expect(response.body['server-2']).toMatchObject({
type: 'sse',
url: 'http://server2.com/sse',
title: 'Server 2',
});
expect(response.body['server-1'].headers).toBeUndefined();
expect(response.body['server-2'].headers).toBeUndefined();
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
}); });
@ -1775,10 +1641,10 @@ describe('MCP Routes', () => {
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
expect(response.status).toBe(201); expect(response.status).toBe(201);
expect(response.body.serverName).toBe('test-sse-server'); expect(response.body).toEqual({
expect(response.body.type).toBe('sse'); serverName: 'test-sse-server',
expect(response.body.url).toBe('https://mcp-server.example.com/sse'); ...validConfig,
expect(response.body.title).toBe('Test SSE Server'); });
expect(mockRegistryInstance.addServer).toHaveBeenCalledWith( expect(mockRegistryInstance.addServer).toHaveBeenCalledWith(
'temp_server_name', 'temp_server_name',
expect.objectContaining({ expect.objectContaining({
@ -1832,78 +1698,6 @@ describe('MCP Routes', () => {
expect(response.body.message).toBe('Invalid configuration'); expect(response.body.message).toBe('Invalid configuration');
}); });
// Injection guard: a user-supplied SSE URL must not be able to smuggle
// `${ENV_VAR}` references (which would otherwise be interpolated server-side,
// exfiltrating secrets such as JWT_SECRET to an attacker-controlled host).
it('should reject SSE URL containing env variable references', async () => {
const response = await request(app)
.post('/api/mcp/servers')
.send({
config: {
type: 'sse',
url: 'http://attacker.com/?secret=${JWT_SECRET}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
// Validation must fail before the registry is ever touched.
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
});
// Same guard for the streamable-http transport.
it('should reject streamable-http URL containing env variable references', async () => {
const response = await request(app)
.post('/api/mcp/servers')
.send({
config: {
type: 'streamable-http',
url: 'http://attacker.com/?key=${CREDS_KEY}&iv=${CREDS_IV}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
});
// Same guard for websocket URLs.
it('should reject websocket URL containing env variable references', async () => {
const response = await request(app)
.post('/api/mcp/servers')
.send({
config: {
type: 'websocket',
url: 'ws://attacker.com/?secret=${MONGO_URI}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
});
// Response hygiene: even when the stored config carries secrets (API key,
// OAuth client secret, raw headers), the create response must strip the
// secret values while keeping non-sensitive fields visible.
it('should redact secrets from create response', async () => {
const validConfig = {
type: 'sse',
url: 'https://mcp-server.example.com/sse',
title: 'Test Server',
};
// Registry echoes back a config enriched with secret material.
mockRegistryInstance.addServer.mockResolvedValue({
serverName: 'test-server',
config: {
...validConfig,
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'admin-secret-key' },
oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' },
headers: { Authorization: 'Bearer leaked-token' },
},
});
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
expect(response.status).toBe(201);
// Secret values removed…
expect(response.body.apiKey?.key).toBeUndefined();
expect(response.body.oauth?.client_secret).toBeUndefined();
expect(response.body.headers).toBeUndefined();
// …while non-sensitive metadata survives redaction.
expect(response.body.apiKey?.source).toBe('admin');
expect(response.body.oauth?.client_id).toBe('cid');
});
it('should return 500 when registry throws error', async () => { it('should return 500 when registry throws error', async () => {
const validConfig = { const validConfig = {
type: 'sse', type: 'sse',
@ -1933,9 +1727,7 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/servers/test-server'); const response = await request(app).get('/api/mcp/servers/test-server');
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body.type).toBe('sse'); expect(response.body).toEqual(mockConfig);
expect(response.body.url).toBe('https://mcp-server.example.com/sse');
expect(response.body.title).toBe('Test Server');
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
'test-server', 'test-server',
'test-user-id', 'test-user-id',
@ -1951,29 +1743,6 @@ describe('MCP Routes', () => {
expect(response.body).toEqual({ message: 'MCP server not found' }); expect(response.body).toEqual({ message: 'MCP server not found' });
}); });
it('should redact secrets from get response', async () => {
mockRegistryInstance.getServerConfig.mockResolvedValue({
type: 'sse',
url: 'https://mcp-server.example.com/sse',
title: 'Secret Server',
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'decrypted-admin-key' },
oauth: { client_id: 'cid', client_secret: 'decrypted-oauth-secret' },
headers: { Authorization: 'Bearer internal-token' },
oauth_headers: { 'X-OAuth': 'secret-value' },
});
const response = await request(app).get('/api/mcp/servers/secret-server');
expect(response.status).toBe(200);
expect(response.body.title).toBe('Secret Server');
expect(response.body.apiKey?.key).toBeUndefined();
expect(response.body.apiKey?.source).toBe('admin');
expect(response.body.oauth?.client_secret).toBeUndefined();
expect(response.body.oauth?.client_id).toBe('cid');
expect(response.body.headers).toBeUndefined();
expect(response.body.oauth_headers).toBeUndefined();
});
it('should return 500 when registry throws error', async () => { it('should return 500 when registry throws error', async () => {
mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error')); mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error'));
@ -2000,9 +1769,7 @@ describe('MCP Routes', () => {
.send({ config: updatedConfig }); .send({ config: updatedConfig });
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(response.body.type).toBe('sse'); expect(response.body).toEqual(updatedConfig);
expect(response.body.url).toBe('https://updated-mcp-server.example.com/sse');
expect(response.body.title).toBe('Updated Server');
expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith( expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith(
'test-server', 'test-server',
expect.objectContaining({ expect.objectContaining({
@ -2014,35 +1781,6 @@ describe('MCP Routes', () => {
); );
}); });
it('should redact secrets from update response', async () => {
const validConfig = {
type: 'sse',
url: 'https://mcp-server.example.com/sse',
title: 'Updated Server',
};
mockRegistryInstance.updateServer.mockResolvedValue({
...validConfig,
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'preserved-admin-key' },
oauth: { client_id: 'cid', client_secret: 'preserved-oauth-secret' },
headers: { Authorization: 'Bearer internal-token' },
env: { DATABASE_URL: 'postgres://admin:pass@localhost/db' },
});
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({ config: validConfig });
expect(response.status).toBe(200);
expect(response.body.title).toBe('Updated Server');
expect(response.body.apiKey?.key).toBeUndefined();
expect(response.body.apiKey?.source).toBe('admin');
expect(response.body.oauth?.client_secret).toBeUndefined();
expect(response.body.oauth?.client_id).toBe('cid');
expect(response.body.headers).toBeUndefined();
expect(response.body.env).toBeUndefined();
});
it('should return 400 for invalid configuration', async () => { it('should return 400 for invalid configuration', async () => {
const invalidConfig = { const invalidConfig = {
type: 'sse', type: 'sse',
@ -2059,51 +1797,6 @@ describe('MCP Routes', () => {
expect(response.body.errors).toBeDefined(); expect(response.body.errors).toBeDefined();
}); });
it('should reject SSE URL containing env variable references', async () => {
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({
config: {
type: 'sse',
url: 'http://attacker.com/?secret=${JWT_SECRET}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
});
it('should reject streamable-http URL containing env variable references', async () => {
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({
config: {
type: 'streamable-http',
url: 'http://attacker.com/?key=${CREDS_KEY}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
});
it('should reject websocket URL containing env variable references', async () => {
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({
config: {
type: 'websocket',
url: 'ws://attacker.com/?secret=${MONGO_URI}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
});
it('should return 500 when registry throws error', async () => { it('should return 500 when registry throws error', async () => {
const validConfig = { const validConfig = {
type: 'sse', type: 'sse',

View file

@ -1,200 +0,0 @@
const mongoose = require('mongoose');
const express = require('express');
const request = require('supertest');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
jest.mock('@librechat/agents', () => ({
sleep: jest.fn(),
}));
jest.mock('@librechat/api', () => ({
unescapeLaTeX: jest.fn((x) => x),
countTokens: jest.fn().mockResolvedValue(10),
}));
jest.mock('@librechat/data-schemas', () => ({
...jest.requireActual('@librechat/data-schemas'),
logger: {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
}));
jest.mock('librechat-data-provider', () => ({
...jest.requireActual('librechat-data-provider'),
}));
jest.mock('~/models', () => ({
saveConvo: jest.fn(),
getMessage: jest.fn(),
saveMessage: jest.fn(),
getMessages: jest.fn(),
updateMessage: jest.fn(),
deleteMessages: jest.fn(),
}));
jest.mock('~/server/services/Artifacts/update', () => ({
findAllArtifacts: jest.fn(),
replaceArtifactContent: jest.fn(),
}));
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
validateMessageReq: (req, res, next) => next(),
}));
jest.mock('~/models/Conversation', () => ({
getConvosQueried: jest.fn(),
}));
jest.mock('~/db/models', () => ({
Message: {
findOne: jest.fn(),
find: jest.fn(),
meiliSearch: jest.fn(),
},
}));
/* ─── Model-level tests: real MongoDB, proves cross-user deletion is prevented ─── */
const { messageSchema } = require('@librechat/data-schemas');
// Model-level IDOR regression tests run against a real in-memory MongoDB:
// they prove that `Message.deleteMany` filters scoped with `user` cannot be
// used by one user to delete another user's messages.
describe('deleteMessages model-level IDOR prevention', () => {
let mongoServer;
let Message;
const ownerUserId = 'user-owner-111';
const attackerUserId = 'user-attacker-222';
beforeAll(async () => {
// Spin up an ephemeral MongoDB instance and bind the real message schema.
mongoServer = await MongoMemoryServer.create();
// Reuse an already-registered model if a previous suite compiled it.
Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
await mongoose.connect(mongoServer.getUri());
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
// Isolate each test: start from an empty collection.
await Message.deleteMany({});
});
// Attacker supplies the victim's messageId but their own user id:
// the user-scoped filter must match nothing and leave the document intact.
it("should NOT delete another user's message when attacker supplies victim messageId", async () => {
const conversationId = uuidv4();
const victimMsgId = 'victim-msg-001';
await Message.create({
messageId: victimMsgId,
conversationId,
user: ownerUserId,
text: 'Sensitive owner data',
});
await Message.deleteMany({ messageId: victimMsgId, user: attackerUserId });
const victimMsg = await Message.findOne({ messageId: victimMsgId }).lean();
expect(victimMsg).not.toBeNull();
expect(victimMsg.user).toBe(ownerUserId);
expect(victimMsg.text).toBe('Sensitive owner data');
});
// Sanity check: the same filter shape does delete when the requester owns the message.
it("should delete the user's own message", async () => {
const conversationId = uuidv4();
const ownMsgId = 'own-msg-001';
await Message.create({
messageId: ownMsgId,
conversationId,
user: ownerUserId,
text: 'My message',
});
const result = await Message.deleteMany({ messageId: ownMsgId, user: ownerUserId });
expect(result.deletedCount).toBe(1);
const deleted = await Message.findOne({ messageId: ownMsgId }).lean();
expect(deleted).toBeNull();
});
// Full filter (conversationId + messageId + user): mismatched user must
// delete nothing even when conversation and message ids are correct.
it('should scope deletion by conversationId, messageId, and user together', async () => {
const convoA = uuidv4();
const convoB = uuidv4();
await Message.create([
{ messageId: 'msg-a1', conversationId: convoA, user: ownerUserId, text: 'A1' },
{ messageId: 'msg-b1', conversationId: convoB, user: ownerUserId, text: 'B1' },
]);
await Message.deleteMany({ messageId: 'msg-a1', conversationId: convoA, user: attackerUserId });
const remaining = await Message.find({ user: ownerUserId }).lean();
expect(remaining).toHaveLength(2);
});
});
/* ─── Route-level tests: supertest + mocked deleteMessages ─── */
// Route-level coverage for message deletion: the handler must always pass the
// authenticated user into the deleteMessages filter (deleteMessages is mocked).
describe('DELETE /:conversationId/:messageId route handler', () => {
  const { deleteMessages } = require('~/models');
  const requesterId = 'user-owner-123';
  let app;

  beforeAll(() => {
    app = express();
    app.use(express.json());
    // Stub authentication: every request is attributed to `requesterId`.
    app.use((req, res, next) => {
      req.user = { id: requesterId };
      next();
    });
    app.use('/api/messages', require('../messages'));
  });

  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should pass user and conversationId in the deleteMessages filter', async () => {
    deleteMessages.mockResolvedValue({ deletedCount: 1 });
    await request(app).delete('/api/messages/convo-1/msg-1');
    expect(deleteMessages).toHaveBeenCalledTimes(1);
    // The filter must be scoped by all three keys, including the requester.
    expect(deleteMessages).toHaveBeenCalledWith({
      messageId: 'msg-1',
      conversationId: 'convo-1',
      user: requesterId,
    });
  });

  it('should return 204 on successful deletion', async () => {
    deleteMessages.mockResolvedValue({ deletedCount: 1 });
    const response = await request(app).delete('/api/messages/convo-1/msg-owned');
    expect(response.status).toBe(204);
    expect(deleteMessages).toHaveBeenCalledWith({
      messageId: 'msg-owned',
      conversationId: 'convo-1',
      user: requesterId,
    });
  });

  it('should return 500 when deleteMessages throws', async () => {
    // Database failures surface as an opaque 500, never a partial delete.
    deleteMessages.mockRejectedValue(new Error('DB failure'));
    const response = await request(app).delete('/api/messages/convo-1/msg-1');
    expect(response.status).toBe(500);
    expect(response.body).toEqual({ error: 'Internal server error' });
  });
});

View file

@ -143,9 +143,6 @@ router.post(
if (actions_result && actions_result.length) { if (actions_result && actions_result.length) {
const action = actions_result[0]; const action = actions_result[0];
if (action.agent_id !== agent_id) {
return res.status(403).json({ message: 'Action does not belong to this agent' });
}
metadata = { ...action.metadata, ...metadata }; metadata = { ...action.metadata, ...metadata };
} }
@ -187,7 +184,7 @@ router.post(
} }
/** @type {[Action]} */ /** @type {[Action]} */
const updatedAction = await updateAction({ action_id, agent_id }, actionUpdateData); const updatedAction = await updateAction({ action_id }, actionUpdateData);
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
for (let field of sensitiveFields) { for (let field of sensitiveFields) {
@ -254,13 +251,7 @@ router.delete(
{ tools: updatedTools, actions: updatedActions }, { tools: updatedTools, actions: updatedActions },
{ updatingUserId: req.user.id, forceVersion: true }, { updatingUserId: req.user.id, forceVersion: true },
); );
const deleted = await deleteAction({ action_id, agent_id }); await deleteAction({ action_id });
if (!deleted) {
logger.warn('[Agent Action Delete] No matching action document found', {
action_id,
agent_id,
});
}
res.status(200).json({ message: 'Action deleted successfully' }); res.status(200).json({ message: 'Action deleted successfully' });
} catch (error) { } catch (error) {
const message = 'Trouble deleting the Agent Action'; const message = 'Trouble deleting the Agent Action';

View file

@ -76,21 +76,43 @@ router.get('/chat/stream/:streamId', async (req, res) => {
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`); logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
const writeEvent = (event) => { // Send sync event with resume state for ALL reconnecting clients
// This supports multi-tab scenarios where each tab needs run step data
if (isResume) {
const resumeState = await GenerationJobManager.getResumeState(streamId);
if (resumeState && !res.writableEnded) {
// Send sync event with run steps AND aggregatedContent
// Client will use aggregatedContent to initialize message state
res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
}
logger.debug(
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`,
);
}
}
const result = await GenerationJobManager.subscribe(
streamId,
(event) => {
if (!res.writableEnded) { if (!res.writableEnded) {
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
if (typeof res.flush === 'function') { if (typeof res.flush === 'function') {
res.flush(); res.flush();
} }
} }
}; },
(event) => {
const onDone = (event) => { if (!res.writableEnded) {
writeEvent(event); res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
}
res.end(); res.end();
}; }
},
const onError = (error) => { (error) => {
if (!res.writableEnded) { if (!res.writableEnded) {
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`); res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
if (typeof res.flush === 'function') { if (typeof res.flush === 'function') {
@ -98,40 +120,8 @@ router.get('/chat/stream/:streamId', async (req, res) => {
} }
res.end(); res.end();
} }
}; },
let result;
if (isResume) {
const { subscription, resumeState, pendingEvents } =
await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError);
if (!res.writableEnded) {
if (resumeState) {
res.write(
`event: message\ndata: ${JSON.stringify({ sync: true, resumeState, pendingEvents })}\n\n`,
); );
if (typeof res.flush === 'function') {
res.flush();
}
GenerationJobManager.markSyncSent(streamId);
logger.debug(
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps, ${pendingEvents.length} pending events`,
);
} else if (pendingEvents.length > 0) {
for (const event of pendingEvents) {
writeEvent(event);
}
logger.warn(
`[AgentStream] Resume state null for ${streamId}, replayed ${pendingEvents.length} gap events directly`,
);
}
}
result = subscription;
} else {
result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError);
}
if (!result) { if (!result) {
return res.status(404).json({ error: 'Failed to subscribe to stream' }); return res.status(404).json({ error: 'Failed to subscribe to stream' });

View file

@ -60,9 +60,6 @@ router.post('/:assistant_id', async (req, res) => {
if (actions_result && actions_result.length) { if (actions_result && actions_result.length) {
const action = actions_result[0]; const action = actions_result[0];
if (action.assistant_id !== assistant_id) {
return res.status(403).json({ message: 'Action does not belong to this assistant' });
}
metadata = { ...action.metadata, ...metadata }; metadata = { ...action.metadata, ...metadata };
} }
@ -120,7 +117,7 @@ router.post('/:assistant_id', async (req, res) => {
// For new actions, use the assistant owner's user ID // For new actions, use the assistant owner's user ID
actionUpdateData.user = assistant_user || req.user.id; actionUpdateData.user = assistant_user || req.user.id;
} }
promises.push(updateAction({ action_id, assistant_id }, actionUpdateData)); promises.push(updateAction({ action_id }, actionUpdateData));
/** @type {[AssistantDocument, Action]} */ /** @type {[AssistantDocument, Action]} */
let [assistantDocument, updatedAction] = await Promise.all(promises); let [assistantDocument, updatedAction] = await Promise.all(promises);
@ -199,15 +196,9 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
assistantUpdateData.user = req.user.id; assistantUpdateData.user = req.user.id;
} }
promises.push(updateAssistantDoc({ assistant_id }, assistantUpdateData)); promises.push(updateAssistantDoc({ assistant_id }, assistantUpdateData));
promises.push(deleteAction({ action_id, assistant_id })); promises.push(deleteAction({ action_id }));
const [, deletedAction] = await Promise.all(promises); await Promise.all(promises);
if (!deletedAction) {
logger.warn('[Assistant Action Delete] No matching action document found', {
action_id,
assistant_id,
});
}
res.status(200).json({ message: 'Action deleted successfully' }); res.status(200).json({ message: 'Action deleted successfully' });
} catch (error) { } catch (error) {
const message = 'Trouble deleting the Assistant Action'; const message = 'Trouble deleting the Assistant Action';

View file

@ -63,7 +63,7 @@ router.post(
resetPasswordController, resetPasswordController,
); );
router.post('/2fa/enable', middleware.requireJwtAuth, enable2FA); router.get('/2fa/enable', middleware.requireJwtAuth, enable2FA);
router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA); router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA);
router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken); router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken);
router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA); router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA);

View file

@ -16,7 +16,9 @@ const sharedLinksEnabled =
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS); process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
const publicSharedLinksEnabled = const publicSharedLinksEnabled =
sharedLinksEnabled && isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC); sharedLinksEnabled &&
(process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER); const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER);
const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS); const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);

View file

@ -1,7 +1,7 @@
const multer = require('multer'); const multer = require('multer');
const express = require('express'); const express = require('express');
const { sleep } = require('@librechat/agents'); const { sleep } = require('@librechat/agents');
const { isEnabled, resolveImportMaxFileSize } = require('@librechat/api'); const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { const {
@ -224,27 +224,8 @@ router.post('/update', validateConvoAccess, async (req, res) => {
}); });
const { importIpLimiter, importUserLimiter } = createImportLimiters(); const { importIpLimiter, importUserLimiter } = createImportLimiters();
/** Fork and duplicate share one rate-limit budget (same "clone" operation class) */
const { forkIpLimiter, forkUserLimiter } = createForkLimiters(); const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
const importMaxFileSize = resolveImportMaxFileSize(); const upload = multer({ storage: storage, fileFilter: importFileFilter });
const upload = multer({
storage,
fileFilter: importFileFilter,
limits: { fileSize: importMaxFileSize },
});
const uploadSingle = upload.single('file');
function handleUpload(req, res, next) {
uploadSingle(req, res, (err) => {
if (err && err.code === 'LIMIT_FILE_SIZE') {
return res.status(413).json({ message: 'File exceeds the maximum allowed size' });
}
if (err) {
return next(err);
}
next();
});
}
/** /**
* Imports a conversation from a JSON file and saves it to the database. * Imports a conversation from a JSON file and saves it to the database.
@ -257,7 +238,7 @@ router.post(
importIpLimiter, importIpLimiter,
importUserLimiter, importUserLimiter,
configMiddleware, configMiddleware,
handleUpload, upload.single('file'),
async (req, res) => { async (req, res) => {
try { try {
/* TODO: optimize to return imported conversations and add manually */ /* TODO: optimize to return imported conversations and add manually */
@ -299,7 +280,7 @@ router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
} }
}); });
router.post('/duplicate', forkIpLimiter, forkUserLimiter, async (req, res) => { router.post('/duplicate', async (req, res) => {
const { conversationId, title } = req.body; const { conversationId, title } = req.body;
try { try {

View file

@ -2,12 +2,12 @@ const fs = require('fs').promises;
const express = require('express'); const express = require('express');
const { EnvVar } = require('@librechat/agents'); const { EnvVar } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { verifyAgentUploadPermission } = require('@librechat/api');
const { const {
Time, Time,
isUUID, isUUID,
CacheKeys, CacheKeys,
FileSources, FileSources,
SystemRoles,
ResourceType, ResourceType,
EModelEndpoint, EModelEndpoint,
PermissionBits, PermissionBits,
@ -381,15 +381,48 @@ router.post('/', async (req, res) => {
return await processFileUpload({ req, res, metadata }); return await processFileUpload({ req, res, metadata });
} }
const denied = await verifyAgentUploadPermission({ /**
req, * Check agent permissions for permanent agent file uploads (not message attachments).
res, * Message attachments (message_file=true) are temporary files for a single conversation
metadata, * and should be allowed for users who can chat with the agent.
getAgent, * Permanent file uploads to tool_resources require EDIT permission.
checkPermission, */
const isMessageAttachment = metadata.message_file === true || metadata.message_file === 'true';
if (metadata.agent_id && metadata.tool_resource && !isMessageAttachment) {
const userId = req.user.id;
/** Admin users bypass permission checks */
if (req.user.role !== SystemRoles.ADMIN) {
const agent = await getAgent({ id: metadata.agent_id });
if (!agent) {
return res.status(404).json({
error: 'Not Found',
message: 'Agent not found',
}); });
if (denied) { }
return;
/** Check if user is the author or has edit permission */
if (agent.author.toString() !== userId) {
const hasEditPermission = await checkPermission({
userId,
role: req.user.role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
if (!hasEditPermission) {
logger.warn(
`[/files] User ${userId} denied upload to agent ${metadata.agent_id} (insufficient permissions)`,
);
return res.status(403).json({
error: 'Forbidden',
message: 'Insufficient permissions to upload files to this agent',
});
}
}
}
} }
return await processAgentFileUpload({ req, res, metadata }); return await processAgentFileUpload({ req, res, metadata });

View file

@ -1,376 +0,0 @@
// Integration-test setup: real Express + supertest against an in-memory MongoDB,
// with the file-processing layer and fs.promises.unlink mocked out.
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { createMethods } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const {
  SystemRoles,
  AccessRoleIds,
  ResourceType,
  PrincipalType,
} = require('librechat-data-provider');
const { createAgent } = require('~/models/Agent');

// Stub the file-processing service so tests only exercise the route's
// permission logic; each stub short-circuits with a canned 200 response.
jest.mock('~/server/services/Files/process', () => ({
  processAgentFileUpload: jest.fn().mockImplementation(async ({ res }) => {
    return res.status(200).json({ message: 'Agent file uploaded', file_id: 'test-file-id' });
  }),
  processImageFile: jest.fn().mockImplementation(async ({ res }) => {
    return res.status(200).json({ message: 'Image processed' });
  }),
  filterFile: jest.fn(),
}));

// Mock only fs.promises.unlink (keep the rest of fs real) so tests can assert
// that the route cleans up the uploaded temp file on permission failures
// without actually touching the filesystem.
jest.mock('fs', () => {
  const actualFs = jest.requireActual('fs');
  return {
    ...actualFs,
    promises: {
      ...actualFs.promises,
      unlink: jest.fn().mockResolvedValue(undefined),
    },
  };
});

// Required after the jest.mock calls above (jest hoists mocks), so these
// resolve to the mocked modules.
const fs = require('fs');
const { processAgentFileUpload } = require('~/server/services/Files/process');
const router = require('~/server/routes/files/images');
/**
 * Integration suite for the agent-upload permission check on POST /images.
 * Exercises the route with a real Express app (via supertest) backed by an
 * in-memory MongoDB; the actual file processing is mocked (see top of file).
 *
 * Scenarios covered: no-permission denial, owner access, admin bypass,
 * EDIT-permission access, VIEW-only denial, non-agent uploads skipping the
 * check, missing agent (404), and the message_file attachment exemption.
 */
describe('POST /images - Agent Upload Permission Check (Integration)', () => {
  let mongoServer;
  let authorId;
  let otherUserId;
  let agentCustomId;
  let User;
  let Agent;
  let AclEntry;
  let methods;
  // Names of models registered in beforeAll, removed again in afterAll so this
  // suite does not leak model definitions into other test files.
  let modelsToCleanup = [];

  beforeAll(async () => {
    // Spin up an in-memory MongoDB and register all data-schemas models on the
    // shared mongoose instance.
    mongoServer = await MongoMemoryServer.create();
    const mongoUri = mongoServer.getUri();
    await mongoose.connect(mongoUri);
    const { createModels } = require('@librechat/data-schemas');
    const models = createModels(mongoose);
    modelsToCleanup = Object.keys(models);
    Object.assign(mongoose.models, models);
    methods = createMethods(mongoose);
    User = models.User;
    Agent = models.Agent;
    AclEntry = models.AclEntry;
    // Roles (and their permission bits) must exist before ACL grants are made.
    await methods.seedDefaultRoles();
  });

  afterAll(async () => {
    // Wipe every collection, then deregister the models added in beforeAll.
    const collections = mongoose.connection.collections;
    for (const key in collections) {
      await collections[key].deleteMany({});
    }
    for (const modelName of modelsToCleanup) {
      if (mongoose.models[modelName]) {
        delete mongoose.models[modelName];
      }
    }
    await mongoose.disconnect();
    await mongoServer.stop();
  });

  beforeEach(async () => {
    // Fresh agents, users, and ACL entries for every test; new ObjectIds so
    // tests cannot accidentally share state.
    await Agent.deleteMany({});
    await User.deleteMany({});
    await AclEntry.deleteMany({});
    authorId = new mongoose.Types.ObjectId();
    otherUserId = new mongoose.Types.ObjectId();
    // Custom agent id in the app's `agent_<hex>` format (21 chars after prefix).
    agentCustomId = `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`;
    await User.create({ _id: authorId, username: 'author', email: 'author@test.com' });
    await User.create({ _id: otherUserId, username: 'other', email: 'other@test.com' });
    jest.clearAllMocks();
  });

  /**
   * Builds an Express app mounting the images router as the given user.
   * Two shim middlewares replace multer and auth: the first fakes an uploaded
   * file on POSTs, the second injects req.user / req.app.locals / req.config.
   *
   * @param {mongoose.Types.ObjectId} userId - acts as the authenticated user
   * @param {string} [userRole=SystemRoles.USER] - role placed on req.user
   * @returns {import('express').Express} app with the router at /images
   */
  const createAppWithUser = (userId, userRole = SystemRoles.USER) => {
    const app = express();
    app.use(express.json());
    app.use((req, _res, next) => {
      if (req.method === 'POST') {
        // Fake multer output; path '/tmp/t.png' is what unlink assertions check.
        req.file = {
          originalname: 'test.png',
          mimetype: 'image/png',
          size: 100,
          path: '/tmp/t.png',
          filename: 'test.png',
        };
        req.file_id = uuidv4();
      }
      next();
    });
    app.use((req, _res, next) => {
      req.user = { id: userId.toString(), role: userRole };
      req.app = { locals: {} };
      req.config = { fileStrategy: 'local', paths: { imageOutput: '/tmp/images' } };
      next();
    });
    app.use('/images', router);
    return app;
  };

  it('should return 403 when user has no permission on agent', async () => {
    await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    // otherUserId is neither author nor granted any ACL role on the agent.
    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(403);
    expect(response.body.error).toBe('Forbidden');
    expect(processAgentFileUpload).not.toHaveBeenCalled();
    // Route must remove the temp upload when the request is rejected.
    expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
  });

  it('should allow upload for agent owner', async () => {
    await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    const app = createAppWithUser(authorId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(200);
    expect(processAgentFileUpload).toHaveBeenCalled();
  });

  it('should allow upload for admin regardless of ownership', async () => {
    await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    // Non-author user, but ADMIN role bypasses the permission check.
    const app = createAppWithUser(otherUserId, SystemRoles.ADMIN);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(200);
    expect(processAgentFileUpload).toHaveBeenCalled();
  });

  it('should allow upload for user with EDIT permission', async () => {
    const agent = await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    // Grant the non-author an EDITOR role via the ACL; EDIT is the required
    // permission for permanent agent file uploads.
    const { grantPermission } = require('~/server/services/PermissionService');
    await grantPermission({
      principalType: PrincipalType.USER,
      principalId: otherUserId,
      resourceType: ResourceType.AGENT,
      resourceId: agent._id,
      accessRoleId: AccessRoleIds.AGENT_EDITOR,
      grantedBy: authorId,
    });

    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(200);
    expect(processAgentFileUpload).toHaveBeenCalled();
  });

  it('should deny upload for user with only VIEW permission', async () => {
    const agent = await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    // VIEWER role grants VIEW but not EDIT, so the upload must be rejected.
    const { grantPermission } = require('~/server/services/PermissionService');
    await grantPermission({
      principalType: PrincipalType.USER,
      principalId: otherUserId,
      resourceType: ResourceType.AGENT,
      resourceId: agent._id,
      accessRoleId: AccessRoleIds.AGENT_VIEWER,
      grantedBy: authorId,
    });

    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(403);
    expect(response.body.error).toBe('Forbidden');
    expect(processAgentFileUpload).not.toHaveBeenCalled();
    expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
  });

  it('should skip permission check for regular image uploads without agent_id/tool_resource', async () => {
    // No agent_id/tool_resource in the body: the permission gate should not
    // apply and the upload proceeds.
    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(200);
  });

  it('should return 404 for non-existent agent', async () => {
    // agent_id that matches no document: expect 404 plus temp-file cleanup.
    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: 'agent_nonexistent123456789',
      tool_resource: 'context',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(404);
    expect(response.body.error).toBe('Not Found');
    expect(processAgentFileUpload).not.toHaveBeenCalled();
    expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
  });

  it('should allow message_file attachment (boolean true) without EDIT permission', async () => {
    const agent = await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    const { grantPermission } = require('~/server/services/PermissionService');
    await grantPermission({
      principalType: PrincipalType.USER,
      principalId: otherUserId,
      resourceType: ResourceType.AGENT,
      resourceId: agent._id,
      accessRoleId: AccessRoleIds.AGENT_VIEWER,
      grantedBy: authorId,
    });

    // message_file=true marks a temporary chat attachment, which is exempt
    // from the EDIT requirement (VIEW-only user succeeds here).
    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      message_file: true,
      file_id: uuidv4(),
    });

    expect(response.status).toBe(200);
    expect(processAgentFileUpload).toHaveBeenCalled();
  });

  it('should allow message_file attachment (string "true") without EDIT permission', async () => {
    const agent = await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    const { grantPermission } = require('~/server/services/PermissionService');
    await grantPermission({
      principalType: PrincipalType.USER,
      principalId: otherUserId,
      resourceType: ResourceType.AGENT,
      resourceId: agent._id,
      accessRoleId: AccessRoleIds.AGENT_VIEWER,
      grantedBy: authorId,
    });

    // Same exemption as above, but with the form-data string 'true' variant.
    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      message_file: 'true',
      file_id: uuidv4(),
    });

    expect(response.status).toBe(200);
    expect(processAgentFileUpload).toHaveBeenCalled();
  });

  it('should deny upload when message_file is false (not a message attachment)', async () => {
    const agent = await createAgent({
      id: agentCustomId,
      name: 'Test Agent',
      provider: 'openai',
      model: 'gpt-4',
      author: authorId,
    });

    const { grantPermission } = require('~/server/services/PermissionService');
    await grantPermission({
      principalType: PrincipalType.USER,
      principalId: otherUserId,
      resourceType: ResourceType.AGENT,
      resourceId: agent._id,
      accessRoleId: AccessRoleIds.AGENT_VIEWER,
      grantedBy: authorId,
    });

    // Explicit message_file=false is a permanent upload, so the VIEW-only
    // user is rejected and the temp file removed.
    const app = createAppWithUser(otherUserId);
    const response = await request(app).post('/images').send({
      endpoint: 'agents',
      agent_id: agentCustomId,
      tool_resource: 'context',
      message_file: false,
      file_id: uuidv4(),
    });

    expect(response.status).toBe(403);
    expect(response.body.error).toBe('Forbidden');
    expect(processAgentFileUpload).not.toHaveBeenCalled();
    expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
  });
});

View file

@ -2,15 +2,12 @@ const path = require('path');
const fs = require('fs').promises; const fs = require('fs').promises;
const express = require('express'); const express = require('express');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { verifyAgentUploadPermission } = require('@librechat/api');
const { isAssistantsEndpoint } = require('librechat-data-provider'); const { isAssistantsEndpoint } = require('librechat-data-provider');
const { const {
processAgentFileUpload, processAgentFileUpload,
processImageFile, processImageFile,
filterFile, filterFile,
} = require('~/server/services/Files/process'); } = require('~/server/services/Files/process');
const { checkPermission } = require('~/server/services/PermissionService');
const { getAgent } = require('~/models/Agent');
const router = express.Router(); const router = express.Router();
@ -25,16 +22,6 @@ router.post('/', async (req, res) => {
metadata.file_id = req.file_id; metadata.file_id = req.file_id;
if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) { if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) {
const denied = await verifyAgentUploadPermission({
req,
res,
metadata,
getAgent,
checkPermission,
});
if (denied) {
return;
}
return await processAgentFileUpload({ req, res, metadata }); return await processAgentFileUpload({ req, res, metadata });
} }

View file

@ -13,7 +13,6 @@ const {
MCPOAuthHandler, MCPOAuthHandler,
MCPTokenStorage, MCPTokenStorage,
setOAuthSession, setOAuthSession,
PENDING_STALE_MS,
getUserMCPAuthMap, getUserMCPAuthMap,
validateOAuthCsrf, validateOAuthCsrf,
OAUTH_CSRF_COOKIE, OAUTH_CSRF_COOKIE,
@ -50,18 +49,6 @@ const router = Router();
const OAUTH_CSRF_COOKIE_PATH = '/api/mcp'; const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';
const checkMCPUsePermissions = generateCheckAccess({
permissionType: PermissionTypes.MCP_SERVERS,
permissions: [Permissions.USE],
getRoleByName,
});
const checkMCPCreate = generateCheckAccess({
permissionType: PermissionTypes.MCP_SERVERS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
/** /**
* Get all MCP tools available to the user * Get all MCP tools available to the user
* Returns only MCP tools, completely decoupled from regular LibreChat tools * Returns only MCP tools, completely decoupled from regular LibreChat tools
@ -104,11 +91,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
} }
const oauthHeaders = await getOAuthHeaders(serverName, userId); const oauthHeaders = await getOAuthHeaders(serverName, userId);
const { const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
authorizationUrl,
flowId: oauthFlowId,
flowMetadata,
} = await MCPOAuthHandler.initiateOAuthFlow(
serverName, serverName,
serverUrl, serverUrl,
userId, userId,
@ -118,7 +101,6 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager);
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH); setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
res.redirect(authorizationUrl); res.redirect(authorizationUrl);
} catch (error) { } catch (error) {
@ -161,53 +143,31 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
return res.redirect(`${basePath}/oauth/error?error=missing_state`); return res.redirect(`${basePath}/oauth/error?error=missing_state`);
} }
const flowsCache = getLogStores(CacheKeys.FLOWS); const flowId = state;
const flowManager = getFlowStateManager(flowsCache); logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager);
if (!flowId) {
logger.error('[MCP OAuth] Could not resolve state to flow ID', { state });
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId });
const flowParts = flowId.split(':'); const flowParts = flowId.split(':');
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) { if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
logger.error('[MCP OAuth] Invalid flow ID format', { flowId }); logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
return res.redirect(`${basePath}/oauth/error?error=invalid_state`); return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
} }
const [flowUserId] = flowParts; const [flowUserId] = flowParts;
if (
const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH); !validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId); !validateOAuthSession(req, flowUserId)
let hasActiveFlow = false; ) {
if (!hasCsrf && !hasSession) { logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS;
if (hasActiveFlow) {
logger.debug(
'[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow',
{
flowId,
},
);
}
}
if (!hasCsrf && !hasSession && !hasActiveFlow) {
logger.error(
'[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow',
{
flowId, flowId,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
}, });
);
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
} }
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId); logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager); const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
@ -321,13 +281,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const toolFlowId = flowState.metadata?.toolFlowId; const toolFlowId = flowState.metadata?.toolFlowId;
if (toolFlowId) { if (toolFlowId) {
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId }); logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
if (!completed) {
logger.warn(
'[MCP OAuth] Tool flow state not found during completion — waiter will time out',
{ toolFlowId },
);
}
} }
/** Redirect to success page with flowId and serverName */ /** Redirect to success page with flowId and serverName */
@ -482,12 +436,7 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
* Reinitialize MCP server * Reinitialize MCP server
* This endpoint allows reinitializing a specific MCP server * This endpoint allows reinitializing a specific MCP server
*/ */
router.post( router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => {
'/:serverName/reinitialize',
requireJwtAuth,
checkMCPUsePermissions,
setOAuthSession,
async (req, res) => {
try { try {
const { serverName } = req.params; const { serverName } = req.params;
const user = createSafeUser(req.user); const user = createSafeUser(req.user);
@ -549,8 +498,7 @@ router.post(
logger.error('[MCP Reinitialize] Unexpected error', error); logger.error('[MCP Reinitialize] Unexpected error', error);
res.status(500).json({ error: 'Internal server error' }); res.status(500).json({ error: 'Internal server error' });
} }
}, });
);
/** /**
* Get connection status for all MCP servers * Get connection status for all MCP servers
@ -657,7 +605,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
* Check which authentication values exist for a specific MCP server * Check which authentication values exist for a specific MCP server
* This endpoint returns only boolean flags indicating if values are set, not the actual values * This endpoint returns only boolean flags indicating if values are set, not the actual values
*/ */
router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, async (req, res) => { router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
try { try {
const { serverName } = req.params; const { serverName } = req.params;
const user = req.user; const user = req.user;
@ -714,6 +662,19 @@ async function getOAuthHeaders(serverName, userId) {
MCP Server CRUD Routes (User-Managed MCP Servers) MCP Server CRUD Routes (User-Managed MCP Servers)
*/ */
// Permission checkers for MCP server management
const checkMCPUsePermissions = generateCheckAccess({
permissionType: PermissionTypes.MCP_SERVERS,
permissions: [Permissions.USE],
getRoleByName,
});
const checkMCPCreate = generateCheckAccess({
permissionType: PermissionTypes.MCP_SERVERS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
/** /**
* Get list of accessible MCP servers * Get list of accessible MCP servers
* @route GET /api/mcp/servers * @route GET /api/mcp/servers

View file

@ -404,8 +404,8 @@ router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (re
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => { router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
try { try {
const { conversationId, messageId } = req.params; const { messageId } = req.params;
await deleteMessages({ messageId, conversationId, user: req.user.id }); await deleteMessages({ messageId });
res.status(204).send(); res.status(204).send();
} catch (error) { } catch (error) {
logger.error('Error deleting message:', error); logger.error('Error deleting message:', error);

View file

@ -19,7 +19,9 @@ const allowSharedLinks =
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS); process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
if (allowSharedLinks) { if (allowSharedLinks) {
const allowSharedLinksPublic = isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC); const allowSharedLinksPublic =
process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC);
router.get( router.get(
'/:shareId', '/:shareId',
allowSharedLinksPublic ? (req, res, next) => next() : requireJwtAuth, allowSharedLinksPublic ? (req, res, next) => next() : requireJwtAuth,

View file

@ -1,7 +1,6 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { initializeAgent, validateAgentModel } = require('@librechat/api'); const { initializeAgent, validateAgentModel } = require('@librechat/api');
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent'); const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
const { getConvoFiles } = require('~/models/Conversation'); const { getConvoFiles } = require('~/models/Conversation');
const { getAgent } = require('~/models/Agent'); const { getAgent } = require('~/models/Agent');
const db = require('~/models'); const db = require('~/models');
@ -109,7 +108,6 @@ const processAddedConvo = async ({
getUserKeyValues: db.getUserKeyValues, getUserKeyValues: db.getUserKeyValues,
getToolFilesByIds: db.getToolFilesByIds, getToolFilesByIds: db.getToolFilesByIds,
getCodeGeneratedFiles: db.getCodeGeneratedFiles, getCodeGeneratedFiles: db.getCodeGeneratedFiles,
filterFilesByAgentAccess,
}, },
); );

View file

@ -10,8 +10,6 @@ const {
createSequentialChainEdges, createSequentialChainEdges,
} = require('@librechat/api'); } = require('@librechat/api');
const { const {
ResourceType,
PermissionBits,
EModelEndpoint, EModelEndpoint,
isAgentsEndpoint, isAgentsEndpoint,
getResponseSender, getResponseSender,
@ -22,9 +20,7 @@ const {
getDefaultHandlers, getDefaultHandlers,
} = require('~/server/controllers/agents/callbacks'); } = require('~/server/controllers/agents/callbacks');
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
const { getModelsConfig } = require('~/server/controllers/ModelController'); const { getModelsConfig } = require('~/server/controllers/ModelController');
const { checkPermission } = require('~/server/services/PermissionService');
const AgentClient = require('~/server/controllers/agents/client'); const AgentClient = require('~/server/controllers/agents/client');
const { getConvoFiles } = require('~/models/Conversation'); const { getConvoFiles } = require('~/models/Conversation');
const { processAddedConvo } = require('./addedConvo'); const { processAddedConvo } = require('./addedConvo');
@ -129,7 +125,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
toolRegistry: ctx.toolRegistry, toolRegistry: ctx.toolRegistry,
userMCPAuthMap: ctx.userMCPAuthMap, userMCPAuthMap: ctx.userMCPAuthMap,
tool_resources: ctx.tool_resources, tool_resources: ctx.tool_resources,
actionsEnabled: ctx.actionsEnabled,
}); });
logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`); logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`);
@ -205,7 +200,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
getUserCodeFiles: db.getUserCodeFiles, getUserCodeFiles: db.getUserCodeFiles,
getToolFilesByIds: db.getToolFilesByIds, getToolFilesByIds: db.getToolFilesByIds,
getCodeGeneratedFiles: db.getCodeGeneratedFiles, getCodeGeneratedFiles: db.getCodeGeneratedFiles,
filterFilesByAgentAccess,
}, },
); );
@ -217,7 +211,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
toolRegistry: primaryConfig.toolRegistry, toolRegistry: primaryConfig.toolRegistry,
userMCPAuthMap: primaryConfig.userMCPAuthMap, userMCPAuthMap: primaryConfig.userMCPAuthMap,
tool_resources: primaryConfig.tool_resources, tool_resources: primaryConfig.tool_resources,
actionsEnabled: primaryConfig.actionsEnabled,
}); });
const agent_ids = primaryConfig.agent_ids; const agent_ids = primaryConfig.agent_ids;
@ -236,22 +229,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
return null; return null;
} }
const hasAccess = await checkPermission({
userId: req.user.id,
role: req.user.role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.VIEW,
});
if (!hasAccess) {
logger.warn(
`[processAgent] User ${req.user.id} lacks VIEW access to handoff agent ${agentId}, skipping`,
);
skippedAgentIds.add(agentId);
return null;
}
const validationResult = await validateAgentModel({ const validationResult = await validateAgentModel({
req, req,
res, res,
@ -286,7 +263,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
getUserCodeFiles: db.getUserCodeFiles, getUserCodeFiles: db.getUserCodeFiles,
getToolFilesByIds: db.getToolFilesByIds, getToolFilesByIds: db.getToolFilesByIds,
getCodeGeneratedFiles: db.getCodeGeneratedFiles, getCodeGeneratedFiles: db.getCodeGeneratedFiles,
filterFilesByAgentAccess,
}, },
); );
@ -302,7 +278,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
toolRegistry: config.toolRegistry, toolRegistry: config.toolRegistry,
userMCPAuthMap: config.userMCPAuthMap, userMCPAuthMap: config.userMCPAuthMap,
tool_resources: config.tool_resources, tool_resources: config.tool_resources,
actionsEnabled: config.actionsEnabled,
}); });
agentConfigs.set(agentId, config); agentConfigs.set(agentId, config);
@ -376,19 +351,6 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
userMCPAuthMap = updatedMCPAuthMap; userMCPAuthMap = updatedMCPAuthMap;
} }
for (const [agentId, config] of agentConfigs) {
if (agentToolContexts.has(agentId)) {
continue;
}
agentToolContexts.set(agentId, {
agent: config,
toolRegistry: config.toolRegistry,
userMCPAuthMap: config.userMCPAuthMap,
tool_resources: config.tool_resources,
actionsEnabled: config.actionsEnabled,
});
}
// Ensure edges is an array when we have multiple agents (multi-agent mode) // Ensure edges is an array when we have multiple agents (multi-agent mode)
// MultiAgentGraph.categorizeEdges requires edges to be iterable // MultiAgentGraph.categorizeEdges requires edges to be iterable
if (agentConfigs.size > 0 && !edges) { if (agentConfigs.size > 0 && !edges) {

View file

@ -1,201 +0,0 @@
const mongoose = require('mongoose');
const {
ResourceType,
PermissionBits,
PrincipalType,
PrincipalModel,
} = require('librechat-data-provider');
const { MongoMemoryServer } = require('mongodb-memory-server');
const mockInitializeAgent = jest.fn();
const mockValidateAgentModel = jest.fn();
jest.mock('@librechat/agents', () => ({
...jest.requireActual('@librechat/agents'),
createContentAggregator: jest.fn(() => ({
contentParts: [],
aggregateContent: jest.fn(),
})),
}));
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
initializeAgent: (...args) => mockInitializeAgent(...args),
validateAgentModel: (...args) => mockValidateAgentModel(...args),
GenerationJobManager: { setCollectedUsage: jest.fn() },
getCustomEndpointConfig: jest.fn(),
createSequentialChainEdges: jest.fn(),
}));
jest.mock('~/server/controllers/agents/callbacks', () => ({
createToolEndCallback: jest.fn(() => jest.fn()),
getDefaultHandlers: jest.fn(() => ({})),
}));
jest.mock('~/server/services/ToolService', () => ({
loadAgentTools: jest.fn(),
loadToolsForExecution: jest.fn(),
}));
jest.mock('~/server/controllers/ModelController', () => ({
getModelsConfig: jest.fn().mockResolvedValue({}),
}));
let agentClientArgs;
jest.mock('~/server/controllers/agents/client', () => {
return jest.fn().mockImplementation((args) => {
agentClientArgs = args;
return {};
});
});
jest.mock('./addedConvo', () => ({
processAddedConvo: jest.fn().mockResolvedValue({ userMCPAuthMap: undefined }),
}));
jest.mock('~/cache', () => ({
logViolation: jest.fn(),
}));
const { initializeClient } = require('./initialize');
const { createAgent } = require('~/models/Agent');
const { User, AclEntry } = require('~/db/models');
const PRIMARY_ID = 'agent_primary';
const TARGET_ID = 'agent_target';
const AUTHORIZED_ID = 'agent_authorized';
describe('initializeClient — processAgent ACL gate', () => {
let mongoServer;
let testUser;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.connection.dropDatabase();
jest.clearAllMocks();
agentClientArgs = undefined;
testUser = await User.create({
email: 'test@example.com',
name: 'Test User',
username: 'testuser',
role: 'USER',
});
mockValidateAgentModel.mockResolvedValue({ isValid: true });
});
const makeReq = () => ({
user: { id: testUser._id.toString(), role: 'USER' },
body: { conversationId: 'conv_1', files: [] },
config: { endpoints: {} },
_resumableStreamId: null,
});
const makeEndpointOption = () => ({
agent: Promise.resolve({
id: PRIMARY_ID,
name: 'Primary',
provider: 'openai',
model: 'gpt-4',
tools: [],
}),
model_parameters: { model: 'gpt-4' },
endpoint: 'agents',
});
const makePrimaryConfig = (edges) => ({
id: PRIMARY_ID,
endpoint: 'agents',
edges,
toolDefinitions: [],
toolRegistry: new Map(),
userMCPAuthMap: null,
tool_resources: {},
resendFiles: true,
maxContextTokens: 4096,
});
it('should skip handoff agent and filter its edge when user lacks VIEW access', async () => {
await createAgent({
id: TARGET_ID,
name: 'Target Agent',
provider: 'openai',
model: 'gpt-4',
author: new mongoose.Types.ObjectId(),
tools: [],
});
const edges = [{ from: PRIMARY_ID, to: TARGET_ID, edgeType: 'handoff' }];
mockInitializeAgent.mockResolvedValue(makePrimaryConfig(edges));
await initializeClient({
req: makeReq(),
res: {},
signal: new AbortController().signal,
endpointOption: makeEndpointOption(),
});
expect(mockInitializeAgent).toHaveBeenCalledTimes(1);
expect(agentClientArgs.agent.edges).toEqual([]);
});
it('should initialize handoff agent and keep its edge when user has VIEW access', async () => {
const authorizedAgent = await createAgent({
id: AUTHORIZED_ID,
name: 'Authorized Agent',
provider: 'openai',
model: 'gpt-4',
author: new mongoose.Types.ObjectId(),
tools: [],
});
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: testUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: authorizedAgent._id,
permBits: PermissionBits.VIEW,
grantedBy: testUser._id,
});
const edges = [{ from: PRIMARY_ID, to: AUTHORIZED_ID, edgeType: 'handoff' }];
const handoffConfig = {
id: AUTHORIZED_ID,
edges: [],
toolDefinitions: [],
toolRegistry: new Map(),
userMCPAuthMap: null,
tool_resources: {},
};
let callCount = 0;
mockInitializeAgent.mockImplementation(() => {
callCount++;
return callCount === 1
? Promise.resolve(makePrimaryConfig(edges))
: Promise.resolve(handoffConfig);
});
await initializeClient({
req: makeReq(),
res: {},
signal: new AbortController().signal,
endpointOption: makeEndpointOption(),
});
expect(mockInitializeAgent).toHaveBeenCalledTimes(2);
expect(agentClientArgs.agent.edges).toHaveLength(1);
expect(agentClientArgs.agent.edges[0].to).toBe(AUTHORIZED_ID);
});
});

View file

@ -1,124 +0,0 @@
jest.mock('uuid', () => ({ v4: jest.fn(() => 'mock-uuid') }));
jest.mock('@librechat/data-schemas', () => ({
logger: { warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
}));
jest.mock('@librechat/agents', () => ({
getCodeBaseURL: jest.fn(() => 'http://localhost:8000'),
}));
const mockSanitizeFilename = jest.fn();
jest.mock('@librechat/api', () => ({
logAxiosError: jest.fn(),
getBasePath: jest.fn(() => ''),
sanitizeFilename: mockSanitizeFilename,
}));
jest.mock('librechat-data-provider', () => ({
...jest.requireActual('librechat-data-provider'),
mergeFileConfig: jest.fn(() => ({ serverFileSizeLimit: 100 * 1024 * 1024 })),
getEndpointFileConfig: jest.fn(() => ({
fileSizeLimit: 100 * 1024 * 1024,
supportedMimeTypes: ['*/*'],
})),
fileConfig: { checkType: jest.fn(() => true) },
}));
jest.mock('~/models', () => ({
createFile: jest.fn().mockResolvedValue({}),
getFiles: jest.fn().mockResolvedValue([]),
updateFile: jest.fn(),
claimCodeFile: jest.fn().mockResolvedValue({ file_id: 'mock-uuid', usage: 0 }),
}));
const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/user123/mock-uuid__output.csv');
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({
saveBuffer: mockSaveBuffer,
})),
}));
jest.mock('~/server/services/Files/permissions', () => ({
filterFilesByAgentAccess: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/server/services/Files/images/convert', () => ({
convertImage: jest.fn(),
}));
jest.mock('~/server/utils', () => ({
determineFileType: jest.fn().mockResolvedValue({ mime: 'text/csv' }),
}));
jest.mock('axios', () =>
jest.fn().mockResolvedValue({
data: Buffer.from('file-content'),
}),
);
const { createFile } = require('~/models');
const { processCodeOutput } = require('../process');
const baseParams = {
req: {
user: { id: 'user123' },
config: {
fileStrategy: 'local',
imageOutputType: 'webp',
fileConfig: {},
},
},
id: 'code-file-id',
apiKey: 'test-key',
toolCallId: 'tool-1',
conversationId: 'conv-1',
messageId: 'msg-1',
session_id: 'session-1',
};
describe('processCodeOutput path traversal protection', () => {
beforeEach(() => {
jest.clearAllMocks();
});
test('sanitizeFilename is called with the raw artifact name', async () => {
mockSanitizeFilename.mockReturnValueOnce('output.csv');
await processCodeOutput({ ...baseParams, name: 'output.csv' });
expect(mockSanitizeFilename).toHaveBeenCalledWith('output.csv');
});
test('sanitized name is used in saveBuffer fileName', async () => {
mockSanitizeFilename.mockReturnValueOnce('sanitized-name.txt');
await processCodeOutput({ ...baseParams, name: '../../../tmp/poc.txt' });
expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../tmp/poc.txt');
const call = mockSaveBuffer.mock.calls[0][0];
expect(call.fileName).toBe('mock-uuid__sanitized-name.txt');
});
test('sanitized name is stored as filename in the file record', async () => {
mockSanitizeFilename.mockReturnValueOnce('safe-output.csv');
await processCodeOutput({ ...baseParams, name: 'unsafe/../../output.csv' });
const fileArg = createFile.mock.calls[0][0];
expect(fileArg.filename).toBe('safe-output.csv');
});
test('sanitized name is used for image file records', async () => {
const { convertImage } = require('~/server/services/Files/images/convert');
convertImage.mockResolvedValueOnce({
filepath: '/images/user123/mock-uuid.webp',
bytes: 100,
});
mockSanitizeFilename.mockReturnValueOnce('safe-chart.png');
await processCodeOutput({ ...baseParams, name: '../../../chart.png' });
expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../chart.png');
const fileArg = createFile.mock.calls[0][0];
expect(fileArg.filename).toBe('safe-chart.png');
});
});

View file

@ -3,7 +3,7 @@ const { v4 } = require('uuid');
const axios = require('axios'); const axios = require('axios');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { getCodeBaseURL } = require('@librechat/agents'); const { getCodeBaseURL } = require('@librechat/agents');
const { logAxiosError, getBasePath, sanitizeFilename } = require('@librechat/api'); const { logAxiosError, getBasePath } = require('@librechat/api');
const { const {
Tools, Tools,
megabyte, megabyte,
@ -146,13 +146,6 @@ const processCodeOutput = async ({
); );
} }
const safeName = sanitizeFilename(name);
if (safeName !== name) {
logger.warn(
`[processCodeOutput] Filename sanitized: "${name}" -> "${safeName}" | conv=${conversationId}`,
);
}
if (isImage) { if (isImage) {
const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1; const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1;
const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`);
@ -163,7 +156,7 @@ const processCodeOutput = async ({
file_id, file_id,
messageId, messageId,
usage, usage,
filename: safeName, filename: name,
conversationId, conversationId,
user: req.user.id, user: req.user.id,
type: `image/${appConfig.imageOutputType}`, type: `image/${appConfig.imageOutputType}`,
@ -207,7 +200,7 @@ const processCodeOutput = async ({
); );
} }
const fileName = `${file_id}__${safeName}`; const fileName = `${file_id}__${name}`;
const filepath = await saveBuffer({ const filepath = await saveBuffer({
userId: req.user.id, userId: req.user.id,
buffer, buffer,
@ -220,7 +213,7 @@ const processCodeOutput = async ({
filepath, filepath,
messageId, messageId,
object: 'file', object: 'file',
filename: safeName, filename: name,
type: mimeType, type: mimeType,
conversationId, conversationId,
user: req.user.id, user: req.user.id,
@ -236,11 +229,6 @@ const processCodeOutput = async ({
await createFile(file, true); await createFile(file, true);
return Object.assign(file, { messageId, toolCallId }); return Object.assign(file, { messageId, toolCallId });
} catch (error) { } catch (error) {
if (error?.message === 'Path traversal detected in filename') {
logger.warn(
`[processCodeOutput] Path traversal blocked for file "${name}" | conv=${conversationId}`,
);
}
logAxiosError({ logAxiosError({
message: 'Error downloading/processing code environment file', message: 'Error downloading/processing code environment file',
error, error,

View file

@ -58,7 +58,6 @@ jest.mock('@librechat/agents', () => ({
jest.mock('@librechat/api', () => ({ jest.mock('@librechat/api', () => ({
logAxiosError: jest.fn(), logAxiosError: jest.fn(),
getBasePath: jest.fn(() => ''), getBasePath: jest.fn(() => ''),
sanitizeFilename: jest.fn((name) => name),
})); }));
// Mock models // Mock models

View file

@ -1,69 +0,0 @@
jest.mock('@librechat/api', () => ({ deleteRagFile: jest.fn() }));
jest.mock('@librechat/data-schemas', () => ({
logger: { warn: jest.fn(), error: jest.fn() },
}));
const mockTmpBase = require('fs').mkdtempSync(
require('path').join(require('os').tmpdir(), 'crud-traversal-'),
);
jest.mock('~/config/paths', () => {
const path = require('path');
return {
publicPath: path.join(mockTmpBase, 'public'),
uploads: path.join(mockTmpBase, 'uploads'),
};
});
const fs = require('fs');
const path = require('path');
const { saveLocalBuffer } = require('../crud');
describe('saveLocalBuffer path containment', () => {
beforeAll(() => {
fs.mkdirSync(path.join(mockTmpBase, 'public', 'images'), { recursive: true });
fs.mkdirSync(path.join(mockTmpBase, 'uploads'), { recursive: true });
});
afterAll(() => {
fs.rmSync(mockTmpBase, { recursive: true, force: true });
});
test('rejects filenames with path traversal sequences', async () => {
await expect(
saveLocalBuffer({
userId: 'user1',
buffer: Buffer.from('malicious'),
fileName: '../../../etc/passwd',
basePath: 'uploads',
}),
).rejects.toThrow('Path traversal detected in filename');
});
test('rejects prefix-collision traversal (startsWith bypass)', async () => {
fs.mkdirSync(path.join(mockTmpBase, 'uploads', 'user10'), { recursive: true });
await expect(
saveLocalBuffer({
userId: 'user1',
buffer: Buffer.from('malicious'),
fileName: '../user10/evil',
basePath: 'uploads',
}),
).rejects.toThrow('Path traversal detected in filename');
});
test('allows normal filenames', async () => {
const result = await saveLocalBuffer({
userId: 'user1',
buffer: Buffer.from('safe content'),
fileName: 'file-id__output.csv',
basePath: 'uploads',
});
expect(result).toBe('/uploads/user1/file-id__output.csv');
const filePath = path.join(mockTmpBase, 'uploads', 'user1', 'file-id__output.csv');
expect(fs.existsSync(filePath)).toBe(true);
fs.unlinkSync(filePath);
});
});

View file

@ -78,13 +78,7 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' }
fs.mkdirSync(directoryPath, { recursive: true }); fs.mkdirSync(directoryPath, { recursive: true });
} }
const resolvedDir = path.resolve(directoryPath); fs.writeFileSync(path.join(directoryPath, fileName), buffer);
const resolvedPath = path.resolve(resolvedDir, fileName);
const rel = path.relative(resolvedDir, resolvedPath);
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
throw new Error('Path traversal detected in filename');
}
fs.writeFileSync(resolvedPath, buffer);
const filePath = path.posix.join('/', basePath, userId, fileName); const filePath = path.posix.join('/', basePath, userId, fileName);
@ -171,8 +165,9 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) {
} }
/** /**
* Validates that a filepath is strictly contained within a subdirectory under a base path, * Validates if a given filepath is within a specified subdirectory under a base path. This function constructs
* using path.relative to prevent prefix-collision bypasses. * the expected base path using the base, subfolder, and user id from the request, and then checks if the
* provided filepath starts with this constructed base path.
* *
* @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`. * @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`.
* @param {string} base - The base directory path. * @param {string} base - The base directory path.
@ -185,8 +180,7 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) {
const isValidPath = (req, base, subfolder, filepath) => { const isValidPath = (req, base, subfolder, filepath) => {
const normalizedBase = path.resolve(base, subfolder, req.user.id); const normalizedBase = path.resolve(base, subfolder, req.user.id);
const normalizedFilepath = path.resolve(filepath); const normalizedFilepath = path.resolve(filepath);
const rel = path.relative(normalizedBase, normalizedFilepath); return normalizedFilepath.startsWith(normalizedBase);
return !rel.startsWith('..') && !path.isAbsolute(rel) && !rel.includes(`..${path.sep}`);
}; };
/** /**

View file

@ -1,29 +1,10 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { PermissionBits, ResourceType, isEphemeralAgentId } = require('librechat-data-provider'); const { PermissionBits, ResourceType } = require('librechat-data-provider');
const { checkPermission } = require('~/server/services/PermissionService'); const { checkPermission } = require('~/server/services/PermissionService');
const { getAgent } = require('~/models/Agent'); const { getAgent } = require('~/models/Agent');
/** /**
* @param {Object} agent - The agent document (lean) * Checks if a user has access to multiple files through a shared agent (batch operation)
* @returns {Set<string>} All file IDs attached across all resource types
*/
function getAttachedFileIds(agent) {
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const resource of Object.values(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
for (const fileId of resource.file_ids) {
attachedFileIds.add(fileId);
}
}
}
}
return attachedFileIds;
}
/**
* Checks if a user has access to multiple files through a shared agent (batch operation).
* Access is always scoped to files actually attached to the agent's tool_resources.
* @param {Object} params - Parameters object * @param {Object} params - Parameters object
* @param {string} params.userId - The user ID to check access for * @param {string} params.userId - The user ID to check access for
* @param {string} [params.role] - Optional user role to avoid DB query * @param {string} [params.role] - Optional user role to avoid DB query
@ -35,6 +16,7 @@ function getAttachedFileIds(agent) {
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => { const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
const accessMap = new Map(); const accessMap = new Map();
// Initialize all files as no access
fileIds.forEach((fileId) => accessMap.set(fileId, false)); fileIds.forEach((fileId) => accessMap.set(fileId, false));
try { try {
@ -44,17 +26,13 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
return accessMap; return accessMap;
} }
const attachedFileIds = getAttachedFileIds(agent); // Check if user is the author - if so, grant access to all files
if (agent.author.toString() === userId.toString()) { if (agent.author.toString() === userId.toString()) {
fileIds.forEach((fileId) => { fileIds.forEach((fileId) => accessMap.set(fileId, true));
if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true);
}
});
return accessMap; return accessMap;
} }
// Check if user has at least VIEW permission on the agent
const hasViewPermission = await checkPermission({ const hasViewPermission = await checkPermission({
userId, userId,
role, role,
@ -68,6 +46,7 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
} }
if (isDelete) { if (isDelete) {
// Check if user has EDIT permission (which would indicate collaborative access)
const hasEditPermission = await checkPermission({ const hasEditPermission = await checkPermission({
userId, userId,
role, role,
@ -76,11 +55,23 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
requiredPermission: PermissionBits.EDIT, requiredPermission: PermissionBits.EDIT,
}); });
// If user only has VIEW permission, they can't access files
// Only users with EDIT permission or higher can access agent files
if (!hasEditPermission) { if (!hasEditPermission) {
return accessMap; return accessMap;
} }
} }
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
}
}
}
// Grant access only to files that are attached to this agent
fileIds.forEach((fileId) => { fileIds.forEach((fileId) => {
if (attachedFileIds.has(fileId)) { if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true); accessMap.set(fileId, true);
@ -104,7 +95,7 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
* @returns {Promise<Array<MongoFile>>} Filtered array of accessible files * @returns {Promise<Array<MongoFile>>} Filtered array of accessible files
*/ */
const filterFilesByAgentAccess = async ({ files, userId, role, agentId }) => { const filterFilesByAgentAccess = async ({ files, userId, role, agentId }) => {
if (!userId || !agentId || !files || files.length === 0 || isEphemeralAgentId(agentId)) { if (!userId || !agentId || !files || files.length === 0) {
return files; return files;
} }

View file

@ -1,409 +0,0 @@
jest.mock('@librechat/data-schemas', () => ({
logger: { error: jest.fn() },
}));
jest.mock('~/server/services/PermissionService', () => ({
checkPermission: jest.fn(),
}));
jest.mock('~/models/Agent', () => ({
getAgent: jest.fn(),
}));
const { logger } = require('@librechat/data-schemas');
const { Constants, PermissionBits, ResourceType } = require('librechat-data-provider');
const { checkPermission } = require('~/server/services/PermissionService');
const { getAgent } = require('~/models/Agent');
const { filterFilesByAgentAccess, hasAccessToFilesViaAgent } = require('./permissions');
const AUTHOR_ID = 'author-user-id';
const USER_ID = 'viewer-user-id';
const AGENT_ID = 'agent_test-abc123';
const AGENT_MONGO_ID = 'mongo-agent-id';
function makeFile(file_id, user) {
return { file_id, user, filename: `${file_id}.txt` };
}
function makeAgent(overrides = {}) {
return {
_id: AGENT_MONGO_ID,
id: AGENT_ID,
author: AUTHOR_ID,
tool_resources: {
file_search: { file_ids: ['attached-1', 'attached-2'] },
execute_code: { file_ids: ['attached-3'] },
},
...overrides,
};
}
beforeEach(() => {
jest.clearAllMocks();
});
describe('filterFilesByAgentAccess', () => {
describe('early returns (no DB calls)', () => {
it('should return files unfiltered for ephemeral agentId', async () => {
const files = [makeFile('f1', 'other-user')];
const result = await filterFilesByAgentAccess({
files,
userId: USER_ID,
agentId: Constants.EPHEMERAL_AGENT_ID,
});
expect(result).toBe(files);
expect(getAgent).not.toHaveBeenCalled();
});
it('should return files unfiltered for non-agent_ prefixed agentId', async () => {
const files = [makeFile('f1', 'other-user')];
const result = await filterFilesByAgentAccess({
files,
userId: USER_ID,
agentId: 'custom-memory-id',
});
expect(result).toBe(files);
expect(getAgent).not.toHaveBeenCalled();
});
it('should return files when userId is missing', async () => {
const files = [makeFile('f1', 'someone')];
const result = await filterFilesByAgentAccess({
files,
userId: undefined,
agentId: AGENT_ID,
});
expect(result).toBe(files);
expect(getAgent).not.toHaveBeenCalled();
});
it('should return files when agentId is missing', async () => {
const files = [makeFile('f1', 'someone')];
const result = await filterFilesByAgentAccess({
files,
userId: USER_ID,
agentId: undefined,
});
expect(result).toBe(files);
expect(getAgent).not.toHaveBeenCalled();
});
it('should return empty array when files is empty', async () => {
const result = await filterFilesByAgentAccess({
files: [],
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toEqual([]);
expect(getAgent).not.toHaveBeenCalled();
});
it('should return undefined when files is nullish', async () => {
const result = await filterFilesByAgentAccess({
files: null,
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toBeNull();
expect(getAgent).not.toHaveBeenCalled();
});
});
describe('all files owned by userId', () => {
it('should return all files without calling getAgent', async () => {
const files = [makeFile('f1', USER_ID), makeFile('f2', USER_ID)];
const result = await filterFilesByAgentAccess({
files,
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toEqual(files);
expect(getAgent).not.toHaveBeenCalled();
});
});
describe('mixed owned and non-owned files', () => {
const ownedFile = makeFile('owned-1', USER_ID);
const sharedFile = makeFile('attached-1', AUTHOR_ID);
const unattachedFile = makeFile('not-attached', AUTHOR_ID);
it('should return owned + accessible non-owned files when user has VIEW', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(true);
const result = await filterFilesByAgentAccess({
files: [ownedFile, sharedFile, unattachedFile],
userId: USER_ID,
role: 'USER',
agentId: AGENT_ID,
});
expect(result).toHaveLength(2);
expect(result.map((f) => f.file_id)).toContain('owned-1');
expect(result.map((f) => f.file_id)).toContain('attached-1');
expect(result.map((f) => f.file_id)).not.toContain('not-attached');
});
it('should return only owned files when user lacks VIEW permission', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(false);
const result = await filterFilesByAgentAccess({
files: [ownedFile, sharedFile],
userId: USER_ID,
role: 'USER',
agentId: AGENT_ID,
});
expect(result).toEqual([ownedFile]);
});
it('should return only owned files when agent is not found', async () => {
getAgent.mockResolvedValue(null);
const result = await filterFilesByAgentAccess({
files: [ownedFile, sharedFile],
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toEqual([ownedFile]);
});
it('should return only owned files on DB error (fail-closed)', async () => {
getAgent.mockRejectedValue(new Error('DB connection lost'));
const result = await filterFilesByAgentAccess({
files: [ownedFile, sharedFile],
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toEqual([ownedFile]);
expect(logger.error).toHaveBeenCalled();
});
});
describe('file with no user field', () => {
it('should treat file as non-owned and run through access check', async () => {
const noUserFile = makeFile('attached-1', undefined);
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(true);
const result = await filterFilesByAgentAccess({
files: [noUserFile],
userId: USER_ID,
role: 'USER',
agentId: AGENT_ID,
});
expect(getAgent).toHaveBeenCalled();
expect(result).toEqual([noUserFile]);
});
it('should exclude file with no user field when not attached to agent', async () => {
const noUserFile = makeFile('not-attached', null);
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(true);
const result = await filterFilesByAgentAccess({
files: [noUserFile],
userId: USER_ID,
role: 'USER',
agentId: AGENT_ID,
});
expect(result).toEqual([]);
});
});
describe('no owned files (all non-owned)', () => {
const file1 = makeFile('attached-1', AUTHOR_ID);
const file2 = makeFile('not-attached', AUTHOR_ID);
it('should return only attached files when user has VIEW', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(true);
const result = await filterFilesByAgentAccess({
files: [file1, file2],
userId: USER_ID,
role: 'USER',
agentId: AGENT_ID,
});
expect(result).toEqual([file1]);
});
it('should return empty array when no VIEW permission', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(false);
const result = await filterFilesByAgentAccess({
files: [file1, file2],
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toEqual([]);
});
it('should return empty array when agent not found', async () => {
getAgent.mockResolvedValue(null);
const result = await filterFilesByAgentAccess({
files: [file1],
userId: USER_ID,
agentId: AGENT_ID,
});
expect(result).toEqual([]);
});
});
});
describe('hasAccessToFilesViaAgent', () => {
describe('agent not found', () => {
it('should return all-false map', async () => {
getAgent.mockResolvedValue(null);
const result = await hasAccessToFilesViaAgent({
userId: USER_ID,
fileIds: ['f1', 'f2'],
agentId: AGENT_ID,
});
expect(result.get('f1')).toBe(false);
expect(result.get('f2')).toBe(false);
});
});
describe('author path', () => {
it('should grant access to attached files for the agent author', async () => {
getAgent.mockResolvedValue(makeAgent());
const result = await hasAccessToFilesViaAgent({
userId: AUTHOR_ID,
fileIds: ['attached-1', 'not-attached'],
agentId: AGENT_ID,
});
expect(result.get('attached-1')).toBe(true);
expect(result.get('not-attached')).toBe(false);
expect(checkPermission).not.toHaveBeenCalled();
});
});
describe('VIEW permission path', () => {
it('should grant access to attached files for viewer with VIEW permission', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(true);
const result = await hasAccessToFilesViaAgent({
userId: USER_ID,
role: 'USER',
fileIds: ['attached-1', 'attached-3', 'not-attached'],
agentId: AGENT_ID,
});
expect(result.get('attached-1')).toBe(true);
expect(result.get('attached-3')).toBe(true);
expect(result.get('not-attached')).toBe(false);
expect(checkPermission).toHaveBeenCalledWith({
userId: USER_ID,
role: 'USER',
resourceType: ResourceType.AGENT,
resourceId: AGENT_MONGO_ID,
requiredPermission: PermissionBits.VIEW,
});
});
it('should deny all when VIEW permission is missing', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValue(false);
const result = await hasAccessToFilesViaAgent({
userId: USER_ID,
fileIds: ['attached-1'],
agentId: AGENT_ID,
});
expect(result.get('attached-1')).toBe(false);
});
});
describe('delete path (EDIT permission required)', () => {
it('should grant access when both VIEW and EDIT pass', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(true);
const result = await hasAccessToFilesViaAgent({
userId: USER_ID,
fileIds: ['attached-1'],
agentId: AGENT_ID,
isDelete: true,
});
expect(result.get('attached-1')).toBe(true);
expect(checkPermission).toHaveBeenCalledTimes(2);
expect(checkPermission).toHaveBeenLastCalledWith(
expect.objectContaining({ requiredPermission: PermissionBits.EDIT }),
);
});
it('should deny all when VIEW passes but EDIT fails', async () => {
getAgent.mockResolvedValue(makeAgent());
checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(false);
const result = await hasAccessToFilesViaAgent({
userId: USER_ID,
fileIds: ['attached-1'],
agentId: AGENT_ID,
isDelete: true,
});
expect(result.get('attached-1')).toBe(false);
});
});
describe('error handling', () => {
it('should return all-false map on DB error (fail-closed)', async () => {
getAgent.mockRejectedValue(new Error('connection refused'));
const result = await hasAccessToFilesViaAgent({
userId: USER_ID,
fileIds: ['f1', 'f2'],
agentId: AGENT_ID,
});
expect(result.get('f1')).toBe(false);
expect(result.get('f2')).toBe(false);
expect(logger.error).toHaveBeenCalledWith(
'[hasAccessToFilesViaAgent] Error checking file access:',
expect.any(Error),
);
});
});
describe('agent with no tool_resources', () => {
it('should deny all files even for the author', async () => {
getAgent.mockResolvedValue(makeAgent({ tool_resources: undefined }));
const result = await hasAccessToFilesViaAgent({
userId: AUTHOR_ID,
fileIds: ['f1'],
agentId: AGENT_ID,
});
expect(result.get('f1')).toBe(false);
});
});
});

View file

@ -34,55 +34,6 @@ const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config'); const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
const MAX_CACHE_SIZE = 1000;
const lastReconnectAttempts = new Map();
const RECONNECT_THROTTLE_MS = 10_000;
const missingToolCache = new Map();
const MISSING_TOOL_TTL_MS = 10_000;
function evictStale(map, ttl) {
if (map.size <= MAX_CACHE_SIZE) {
return;
}
const now = Date.now();
for (const [key, timestamp] of map) {
if (now - timestamp >= ttl) {
map.delete(key);
}
if (map.size <= MAX_CACHE_SIZE) {
return;
}
}
}
const unavailableMsg =
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
/**
* @param {string} toolName
* @param {string} serverName
*/
function createUnavailableToolStub(toolName, serverName) {
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
const _call = async () => [unavailableMsg, null];
const toolInstance = tool(_call, {
schema: {
type: 'object',
properties: {
input: { type: 'string', description: 'Input for the tool' },
},
required: [],
},
name: normalizedToolKey,
description: unavailableMsg,
responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
});
toolInstance.mcp = true;
toolInstance.mcpRawServerName = serverName;
return toolInstance;
}
function isEmptyObjectSchema(jsonSchema) { function isEmptyObjectSchema(jsonSchema) {
return ( return (
jsonSchema != null && jsonSchema != null &&
@ -260,17 +211,6 @@ async function reconnectServer({
logger.debug( logger.debug(
`[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`, `[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
); );
const throttleKey = `${user.id}:${serverName}`;
const now = Date.now();
const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0;
if (now - lastAttempt < RECONNECT_THROTTLE_MS) {
logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`);
return null;
}
lastReconnectAttempts.set(throttleKey, now);
evictStale(lastReconnectAttempts, RECONNECT_THROTTLE_MS);
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${user.id}:${serverName}:${Date.now()}`; const flowId = `${user.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
@ -327,7 +267,7 @@ async function reconnectServer({
userMCPAuthMap, userMCPAuthMap,
forceNew: true, forceNew: true,
returnOnOAuth: false, returnOnOAuth: false,
connectionTimeout: Time.THIRTY_SECONDS, connectionTimeout: Time.TWO_MINUTES,
}); });
} finally { } finally {
// Clean up abort handler to prevent memory leaks // Clean up abort handler to prevent memory leaks
@ -390,13 +330,9 @@ async function createMCPTools({
userMCPAuthMap, userMCPAuthMap,
streamId, streamId,
}); });
if (result === null) {
logger.debug(`[MCP][${serverName}] Reconnect throttled, skipping tool creation.`);
return [];
}
if (!result || !result.tools) { if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return []; return;
} }
const serverTools = []; const serverTools = [];
@ -466,14 +402,6 @@ async function createMCPTool({
/** @type {LCTool | undefined} */ /** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function; let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) { if (!toolDefinition) {
const cachedAt = missingToolCache.get(toolKey);
if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) {
logger.debug(
`[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`,
);
return createUnavailableToolStub(toolName, serverName);
}
logger.warn( logger.warn(
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`, `[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
); );
@ -487,18 +415,11 @@ async function createMCPTool({
streamId, streamId,
}); });
toolDefinition = result?.availableTools?.[toolKey]?.function; toolDefinition = result?.availableTools?.[toolKey]?.function;
if (!toolDefinition) {
missingToolCache.set(toolKey, Date.now());
evictStale(missingToolCache, MISSING_TOOL_TTL_MS);
}
} }
if (!toolDefinition) { if (!toolDefinition) {
logger.warn( logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
`[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`, return;
);
return createUnavailableToolStub(toolName, serverName);
} }
return createToolInstance({ return createToolInstance({
@ -799,5 +720,4 @@ module.exports = {
getMCPSetupData, getMCPSetupData,
checkOAuthFlowStatus, checkOAuthFlowStatus,
getServerConnectionStatus, getServerConnectionStatus,
createUnavailableToolStub,
}; };

View file

@ -45,7 +45,6 @@ const {
getMCPSetupData, getMCPSetupData,
checkOAuthFlowStatus, checkOAuthFlowStatus,
getServerConnectionStatus, getServerConnectionStatus,
createUnavailableToolStub,
} = require('./MCP'); } = require('./MCP');
jest.mock('./Config', () => ({ jest.mock('./Config', () => ({
@ -1099,188 +1098,6 @@ describe('User parameter passing tests', () => {
}); });
}); });
describe('createUnavailableToolStub', () => {
it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => {
const stub = createUnavailableToolStub('myTool', 'myServer');
// invoke() goes through langchain's base tool, which checks responseFormat.
// CONTENT_AND_ARTIFACT requires [content, artifact] — a bare string would throw:
// "Tool response format is "content_and_artifact" but the output was not a two-tuple"
const result = await stub.invoke({});
// If we reach here without throwing, the two-tuple format is correct.
// invoke() returns the content portion of [content, artifact] as a string.
expect(result).toContain('temporarily unavailable');
});
});
describe('negative tool cache and throttle interaction', () => {
it('should cache tool as missing even when throttled (cross-user dedup)', async () => {
const mockUser = { id: 'throttle-test-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// First call: reconnect succeeds but tool not found
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: `missing-tool${D}cache-dedup-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
// Second call within 10s for DIFFERENT tool on same server:
// reconnect is throttled (returns null), tool is still cached as missing.
// This is intentional: the cache acts as cross-user dedup since the
// throttle is per-user-per-server and can't prevent N different users
// from each triggering their own reconnect.
const result2 = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: `other-tool${D}cache-dedup-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result2).toBeDefined();
expect(result2.name).toContain('other-tool');
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
it('should prevent user B from triggering reconnect when user A already cached the tool', async () => {
const userA = { id: 'cache-user-A' };
const userB = { id: 'cache-user-B' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// User A: real reconnect, tool not found → cached
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `shared-tool${D}cross-user-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// User B requests the SAME tool within 10s.
// The negative cache is keyed by toolKey (no user prefix), so user B
// gets a cache hit and no reconnect fires. This is the cross-user
// storm protection: without this, user B's unthrottled first request
// would trigger a second reconnect to the same server.
const result = await createMCPTool({
res: mockRes,
user: userB,
toolKey: `shared-tool${D}cross-user-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result).toBeDefined();
expect(result.name).toContain('shared-tool');
// reinitMCPServer still called only once — user B hit the cache
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
it('should prevent user B from triggering reconnect for throttle-cached tools', async () => {
const userA = { id: 'storm-user-A' };
const userB = { id: 'storm-user-B' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// User A: real reconnect for tool-1, tool not found → cached
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `tool-1${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
// User A: tool-2 on same server within 10s → throttled → cached from throttle
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `tool-2${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// User B requests tool-2 — gets cache hit from the throttle-cached entry.
// Without this caching, user B would trigger a real reconnect since
// user B has their own throttle key and hasn't reconnected yet.
const result = await createMCPTool({
res: mockRes,
user: userB,
toolKey: `tool-2${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result).toBeDefined();
expect(result.name).toContain('tool-2');
// Still only 1 real reconnect — user B was protected by the cache
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
});
describe('createMCPTools throttle handling', () => {
it('should return empty array with debug log when reconnect is throttled', async () => {
const mockUser = { id: 'throttle-tools-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// First call: real reconnect
mockReinitMCPServer.mockResolvedValueOnce({
tools: [{ name: 'tool1' }],
availableTools: {
[`tool1${D}throttle-tools-server`]: {
function: { description: 'Tool 1', parameters: {} },
},
},
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'throttle-tools-server',
provider: 'openai',
userMCPAuthMap: {},
});
// Second call within 10s — throttled
const result = await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'throttle-tools-server',
provider: 'openai',
userMCPAuthMap: {},
});
expect(result).toEqual([]);
// reinitMCPServer called only once — second was throttled
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// Should log at debug level (not warn) for throttled case
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled'));
});
});
describe('User parameter integrity', () => { describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => { it('should preserve user object properties through the call chain', async () => {
const complexUser = { const complexUser = {

View file

@ -64,26 +64,6 @@ const { redactMessage } = require('~/config/parsers');
const { findPluginAuthsByKeys } = require('~/models'); const { findPluginAuthsByKeys } = require('~/models');
const { getFlowStateManager } = require('~/config'); const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
/**
* Resolves the set of enabled agent capabilities from endpoints config,
* falling back to app-level or default capabilities for ephemeral agents.
* @param {ServerRequest} req
* @param {Object} appConfig
* @param {string} agentId
* @returns {Promise<Set<string>>}
*/
async function resolveAgentCapabilities(req, appConfig, agentId) {
const endpointsConfig = await getEndpointsConfig(req);
let capabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
if (capabilities.size === 0 && isEphemeralAgentId(agentId)) {
capabilities = new Set(
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
);
}
return capabilities;
}
/** /**
* Processes the required actions by calling the appropriate tools and returning the outputs. * Processes the required actions by calling the appropriate tools and returning the outputs.
* @param {OpenAIClient} client - OpenAI or StreamRunManager Client. * @param {OpenAIClient} client - OpenAI or StreamRunManager Client.
@ -465,11 +445,17 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
} }
const appConfig = req.config; const appConfig = req.config;
const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id); const endpointsConfig = await getEndpointsConfig(req);
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
enabledCapabilities = new Set(
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
);
}
const checkCapability = (capability) => enabledCapabilities.has(capability); const checkCapability = (capability) => enabledCapabilities.has(capability);
const areToolsEnabled = checkCapability(AgentCapabilities.tools); const areToolsEnabled = checkCapability(AgentCapabilities.tools);
const actionsEnabled = checkCapability(AgentCapabilities.actions);
const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools); const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools);
const filteredTools = agent.tools?.filter((tool) => { const filteredTools = agent.tools?.filter((tool) => {
@ -482,10 +468,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
if (tool === Tools.web_search) { if (tool === Tools.web_search) {
return checkCapability(AgentCapabilities.web_search); return checkCapability(AgentCapabilities.web_search);
} }
if (tool.includes(actionDelimiter)) { if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
return actionsEnabled;
}
if (!areToolsEnabled) {
return false; return false;
} }
return true; return true;
@ -782,7 +765,6 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
toolContextMap, toolContextMap,
toolDefinitions, toolDefinitions,
hasDeferredTools, hasDeferredTools,
actionsEnabled,
}; };
} }
@ -826,7 +808,14 @@ async function loadAgentTools({
} }
const appConfig = req.config; const appConfig = req.config;
const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id); const endpointsConfig = await getEndpointsConfig(req);
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
/** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
enabledCapabilities = new Set(
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
);
}
const checkCapability = (capability) => { const checkCapability = (capability) => {
const enabled = enabledCapabilities.has(capability); const enabled = enabledCapabilities.has(capability);
if (!enabled) { if (!enabled) {
@ -843,7 +832,6 @@ async function loadAgentTools({
return enabled; return enabled;
}; };
const areToolsEnabled = checkCapability(AgentCapabilities.tools); const areToolsEnabled = checkCapability(AgentCapabilities.tools);
const actionsEnabled = checkCapability(AgentCapabilities.actions);
let includesWebSearch = false; let includesWebSearch = false;
const _agentTools = agent.tools?.filter((tool) => { const _agentTools = agent.tools?.filter((tool) => {
@ -854,9 +842,7 @@ async function loadAgentTools({
} else if (tool === Tools.web_search) { } else if (tool === Tools.web_search) {
includesWebSearch = checkCapability(AgentCapabilities.web_search); includesWebSearch = checkCapability(AgentCapabilities.web_search);
return includesWebSearch; return includesWebSearch;
} else if (tool.includes(actionDelimiter)) { } else if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
return actionsEnabled;
} else if (!areToolsEnabled) {
return false; return false;
} }
return true; return true;
@ -961,15 +947,13 @@ async function loadAgentTools({
agentTools.push(...additionalTools); agentTools.push(...additionalTools);
const hasActionTools = _agentTools.some((t) => t.includes(actionDelimiter)); if (!checkCapability(AgentCapabilities.actions)) {
if (!hasActionTools) {
return { return {
toolRegistry, toolRegistry,
userMCPAuthMap, userMCPAuthMap,
toolContextMap, toolContextMap,
toolDefinitions, toolDefinitions,
hasDeferredTools, hasDeferredTools,
actionsEnabled,
tools: agentTools, tools: agentTools,
}; };
} }
@ -985,7 +969,6 @@ async function loadAgentTools({
toolContextMap, toolContextMap,
toolDefinitions, toolDefinitions,
hasDeferredTools, hasDeferredTools,
actionsEnabled,
tools: agentTools, tools: agentTools,
}; };
} }
@ -1118,7 +1101,6 @@ async function loadAgentTools({
userMCPAuthMap, userMCPAuthMap,
toolDefinitions, toolDefinitions,
hasDeferredTools, hasDeferredTools,
actionsEnabled,
tools: agentTools, tools: agentTools,
}; };
} }
@ -1136,11 +1118,9 @@ async function loadAgentTools({
* @param {AbortSignal} [params.signal] - Abort signal * @param {AbortSignal} [params.signal] - Abort signal
* @param {Object} params.agent - The agent object * @param {Object} params.agent - The agent object
* @param {string[]} params.toolNames - Names of tools to load * @param {string[]} params.toolNames - Names of tools to load
* @param {Map} [params.toolRegistry] - Tool registry
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap] - User MCP auth map * @param {Record<string, Record<string, string>>} [params.userMCPAuthMap] - User MCP auth map
* @param {Object} [params.tool_resources] - Tool resources * @param {Object} [params.tool_resources] - Tool resources
* @param {string|null} [params.streamId] - Stream ID for web search callbacks * @param {string|null} [params.streamId] - Stream ID for web search callbacks
* @param {boolean} [params.actionsEnabled] - Whether the actions capability is enabled
* @returns {Promise<{ loadedTools: Array, configurable: Object }>} * @returns {Promise<{ loadedTools: Array, configurable: Object }>}
*/ */
async function loadToolsForExecution({ async function loadToolsForExecution({
@ -1153,17 +1133,11 @@ async function loadToolsForExecution({
userMCPAuthMap, userMCPAuthMap,
tool_resources, tool_resources,
streamId = null, streamId = null,
actionsEnabled,
}) { }) {
const appConfig = req.config; const appConfig = req.config;
const allLoadedTools = []; const allLoadedTools = [];
const configurable = { userMCPAuthMap }; const configurable = { userMCPAuthMap };
if (actionsEnabled === undefined) {
const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent?.id);
actionsEnabled = enabledCapabilities.has(AgentCapabilities.actions);
}
const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH); const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH);
const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING); const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING);
@ -1220,6 +1194,7 @@ async function loadToolsForExecution({
const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter)); const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter));
const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter)); const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter));
/** @type {Record<string, unknown>} */
if (regularToolNames.length > 0) { if (regularToolNames.length > 0) {
const includesWebSearch = regularToolNames.includes(Tools.web_search); const includesWebSearch = regularToolNames.includes(Tools.web_search);
const webSearchCallbacks = includesWebSearch ? createOnSearchResults(res, streamId) : undefined; const webSearchCallbacks = includesWebSearch ? createOnSearchResults(res, streamId) : undefined;
@ -1250,7 +1225,7 @@ async function loadToolsForExecution({
} }
} }
if (actionToolNames.length > 0 && agent && actionsEnabled) { if (actionToolNames.length > 0 && agent) {
const actionTools = await loadActionToolsForExecution({ const actionTools = await loadActionToolsForExecution({
req, req,
res, res,
@ -1260,11 +1235,6 @@ async function loadToolsForExecution({
actionToolNames, actionToolNames,
}); });
allLoadedTools.push(...actionTools); allLoadedTools.push(...actionTools);
} else if (actionToolNames.length > 0 && agent && !actionsEnabled) {
logger.warn(
`[loadToolsForExecution] Capability "${AgentCapabilities.actions}" disabled. ` +
`Skipping action tool execution. User: ${req.user.id} | Agent: ${agent.id} | Tools: ${actionToolNames.join(', ')}`,
);
} }
if (isPTC && allLoadedTools.length > 0) { if (isPTC && allLoadedTools.length > 0) {
@ -1425,5 +1395,4 @@ module.exports = {
loadAgentTools, loadAgentTools,
loadToolsForExecution, loadToolsForExecution,
processRequiredActions, processRequiredActions,
resolveAgentCapabilities,
}; };

View file

@ -1,8 +1,8 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider'); const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getMCPServersRegistry, getFlowStateManager } = require('~/config');
const { findToken, createToken, updateToken, deleteTokens } = require('~/models'); const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
const { updateMCPServerTools } = require('~/server/services/Config'); const { updateMCPServerTools } = require('~/server/services/Config');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
/** /**
@ -41,33 +41,6 @@ async function reinitMCPServer({
let oauthUrl = null; let oauthUrl = null;
try { try {
const registry = getMCPServersRegistry();
const serverConfig = await registry.getServerConfig(serverName, user?.id);
if (serverConfig?.inspectionFailed) {
logger.info(
`[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`,
);
try {
const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE';
await registry.reinspectServer(serverName, storageLocation, user?.id);
logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`);
} catch (reinspectError) {
logger.error(
`[MCP Reinitialize] Reinspection failed for server ${serverName}:`,
reinspectError,
);
return {
availableTools: null,
success: false,
message: `MCP server '${serverName}' is still unreachable`,
oauthRequired: false,
serverName,
oauthUrl: null,
tools: null,
};
}
}
const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`];
const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS)); const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const mcpManager = getMCPManager(); const mcpManager = getMCPManager();

View file

@ -1,304 +1,19 @@
const { const {
Tools,
Constants, Constants,
EModelEndpoint,
actionDelimiter,
AgentCapabilities, AgentCapabilities,
defaultAgentCapabilities, defaultAgentCapabilities,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const mockGetEndpointsConfig = jest.fn(); /**
const mockGetMCPServerTools = jest.fn(); * Tests for ToolService capability checking logic.
const mockGetCachedTools = jest.fn(); * The actual loadAgentTools function has many dependencies, so we test
jest.mock('~/server/services/Config', () => ({ * the capability checking logic in isolation.
getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), */
getMCPServerTools: (...args) => mockGetMCPServerTools(...args), describe('ToolService - Capability Checking', () => {
getCachedTools: (...args) => mockGetCachedTools(...args),
}));
const mockLoadToolDefinitions = jest.fn();
const mockGetUserMCPAuthMap = jest.fn();
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
loadToolDefinitions: (...args) => mockLoadToolDefinitions(...args),
getUserMCPAuthMap: (...args) => mockGetUserMCPAuthMap(...args),
}));
const mockLoadToolsUtil = jest.fn();
jest.mock('~/app/clients/tools/util', () => ({
loadTools: (...args) => mockLoadToolsUtil(...args),
}));
const mockLoadActionSets = jest.fn();
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn().mockResolvedValue({}),
}));
jest.mock('~/server/services/Tools/search', () => ({
createOnSearchResults: jest.fn(),
}));
jest.mock('~/server/services/Tools/mcp', () => ({
reinitMCPServer: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
processFileURL: jest.fn(),
uploadImageBuffer: jest.fn(),
}));
jest.mock('~/app/clients/tools/util/fileSearch', () => ({
primeFiles: jest.fn().mockResolvedValue({}),
}));
jest.mock('~/server/services/Files/Code/process', () => ({
primeFiles: jest.fn().mockResolvedValue({}),
}));
jest.mock('../ActionService', () => ({
loadActionSets: (...args) => mockLoadActionSets(...args),
decryptMetadata: jest.fn(),
createActionTool: jest.fn(),
domainParser: jest.fn(),
}));
jest.mock('~/server/services/Threads', () => ({
recordUsage: jest.fn(),
}));
jest.mock('~/models', () => ({
findPluginAuthsByKeys: jest.fn(),
}));
jest.mock('~/config', () => ({
getFlowStateManager: jest.fn(() => ({})),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({})),
}));
const {
loadAgentTools,
loadToolsForExecution,
resolveAgentCapabilities,
} = require('../ToolService');
function createMockReq(capabilities) {
return {
user: { id: 'user_123' },
config: {
endpoints: {
[EModelEndpoint.agents]: {
capabilities,
},
},
},
};
}
function createEndpointsConfig(capabilities) {
return {
[EModelEndpoint.agents]: { capabilities },
};
}
describe('ToolService - Action Capability Gating', () => {
beforeEach(() => {
jest.clearAllMocks();
mockLoadToolDefinitions.mockResolvedValue({
toolDefinitions: [],
toolRegistry: new Map(),
hasDeferredTools: false,
});
mockLoadToolsUtil.mockResolvedValue({ loadedTools: [], toolContextMap: {} });
mockLoadActionSets.mockResolvedValue([]);
});
describe('resolveAgentCapabilities', () => {
it('should return capabilities from endpoints config', async () => {
const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
const result = await resolveAgentCapabilities(req, req.config, 'agent_123');
expect(result).toBeInstanceOf(Set);
expect(result.has(AgentCapabilities.tools)).toBe(true);
expect(result.has(AgentCapabilities.actions)).toBe(true);
expect(result.has(AgentCapabilities.web_search)).toBe(false);
});
it('should fall back to default capabilities for ephemeral agents with empty config', async () => {
const req = createMockReq(defaultAgentCapabilities);
mockGetEndpointsConfig.mockResolvedValue({});
const result = await resolveAgentCapabilities(req, req.config, Constants.EPHEMERAL_AGENT_ID);
for (const cap of defaultAgentCapabilities) {
expect(result.has(cap)).toBe(true);
}
});
it('should return empty set when no capabilities and not ephemeral', async () => {
const req = createMockReq([]);
mockGetEndpointsConfig.mockResolvedValue({});
const result = await resolveAgentCapabilities(req, req.config, 'agent_123');
expect(result.size).toBe(0);
});
});
describe('loadAgentTools (definitionsOnly=true) — action tool filtering', () => {
const actionToolName = `get_weather${actionDelimiter}api_example_com`;
const regularTool = 'calculator';
it('should exclude action tools from definitions when actions capability is disabled', async () => {
const capabilities = [AgentCapabilities.tools, AgentCapabilities.web_search];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
await loadAgentTools({
req,
res: {},
agent: { id: 'agent_123', tools: [regularTool, actionToolName] },
definitionsOnly: true,
});
expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1);
const [callArgs] = mockLoadToolDefinitions.mock.calls[0];
expect(callArgs.tools).toContain(regularTool);
expect(callArgs.tools).not.toContain(actionToolName);
});
it('should include action tools in definitions when actions capability is enabled', async () => {
const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
await loadAgentTools({
req,
res: {},
agent: { id: 'agent_123', tools: [regularTool, actionToolName] },
definitionsOnly: true,
});
expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1);
const [callArgs] = mockLoadToolDefinitions.mock.calls[0];
expect(callArgs.tools).toContain(regularTool);
expect(callArgs.tools).toContain(actionToolName);
});
it('should return actionsEnabled in the result', async () => {
const capabilities = [AgentCapabilities.tools];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
const result = await loadAgentTools({
req,
res: {},
agent: { id: 'agent_123', tools: [regularTool] },
definitionsOnly: true,
});
expect(result.actionsEnabled).toBe(false);
});
});
describe('loadAgentTools (definitionsOnly=false) — action tool filtering', () => {
const actionToolName = `get_weather${actionDelimiter}api_example_com`;
const regularTool = 'calculator';
it('should not load action sets when actions capability is disabled', async () => {
const capabilities = [AgentCapabilities.tools, AgentCapabilities.web_search];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
await loadAgentTools({
req,
res: {},
agent: { id: 'agent_123', tools: [regularTool, actionToolName] },
definitionsOnly: false,
});
expect(mockLoadActionSets).not.toHaveBeenCalled();
});
it('should load action sets when actions capability is enabled and action tools present', async () => {
const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
await loadAgentTools({
req,
res: {},
agent: { id: 'agent_123', tools: [regularTool, actionToolName] },
definitionsOnly: false,
});
expect(mockLoadActionSets).toHaveBeenCalledWith({ agent_id: 'agent_123' });
});
});
describe('loadToolsForExecution — action tool gating', () => {
const actionToolName = `get_weather${actionDelimiter}api_example_com`;
const regularTool = Tools.web_search;
it('should skip action tool loading when actionsEnabled=false', async () => {
const req = createMockReq([]);
req.config = {};
const result = await loadToolsForExecution({
req,
res: {},
agent: { id: 'agent_123' },
toolNames: [regularTool, actionToolName],
actionsEnabled: false,
});
expect(mockLoadActionSets).not.toHaveBeenCalled();
expect(result.loadedTools).toBeDefined();
});
it('should load action tools when actionsEnabled=true', async () => {
const req = createMockReq([AgentCapabilities.actions]);
req.config = {};
await loadToolsForExecution({
req,
res: {},
agent: { id: 'agent_123' },
toolNames: [actionToolName],
actionsEnabled: true,
});
expect(mockLoadActionSets).toHaveBeenCalledWith({ agent_id: 'agent_123' });
});
it('should resolve actionsEnabled from capabilities when not explicitly provided', async () => {
const capabilities = [AgentCapabilities.tools];
const req = createMockReq(capabilities);
mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities));
await loadToolsForExecution({
req,
res: {},
agent: { id: 'agent_123' },
toolNames: [actionToolName],
});
expect(mockGetEndpointsConfig).toHaveBeenCalled();
expect(mockLoadActionSets).not.toHaveBeenCalled();
});
it('should not call loadActionSets when there are no action tools', async () => {
const req = createMockReq([AgentCapabilities.actions]);
req.config = {};
await loadToolsForExecution({
req,
res: {},
agent: { id: 'agent_123' },
toolNames: [regularTool],
actionsEnabled: true,
});
expect(mockLoadActionSets).not.toHaveBeenCalled();
});
});
describe('checkCapability logic', () => { describe('checkCapability logic', () => {
/**
* Simulates the checkCapability function from loadAgentTools
*/
const createCheckCapability = (enabledCapabilities, logger = { warn: jest.fn() }) => { const createCheckCapability = (enabledCapabilities, logger = { warn: jest.fn() }) => {
return (capability) => { return (capability) => {
const enabled = enabledCapabilities.has(capability); const enabled = enabledCapabilities.has(capability);
@ -409,6 +124,10 @@ describe('ToolService - Action Capability Gating', () => {
}); });
describe('userMCPAuthMap gating', () => { describe('userMCPAuthMap gating', () => {
/**
* Simulates the guard condition used in both loadToolDefinitionsWrapper
* and loadAgentTools to decide whether getUserMCPAuthMap should be called.
*/
const shouldFetchMCPAuth = (tools) => const shouldFetchMCPAuth = (tools) =>
tools?.some((t) => t.includes(Constants.mcp_delimiter)) ?? false; tools?.some((t) => t.includes(Constants.mcp_delimiter)) ?? false;
@ -459,17 +178,20 @@ describe('ToolService - Action Capability Gating', () => {
return (capability) => enabledCapabilities.has(capability); return (capability) => enabledCapabilities.has(capability);
}; };
// When deferred_tools is in capabilities
const withDeferred = new Set([AgentCapabilities.deferred_tools, AgentCapabilities.tools]); const withDeferred = new Set([AgentCapabilities.deferred_tools, AgentCapabilities.tools]);
const checkWithDeferred = createCheckCapability(withDeferred); const checkWithDeferred = createCheckCapability(withDeferred);
expect(checkWithDeferred(AgentCapabilities.deferred_tools)).toBe(true); expect(checkWithDeferred(AgentCapabilities.deferred_tools)).toBe(true);
// When deferred_tools is NOT in capabilities
const withoutDeferred = new Set([AgentCapabilities.tools, AgentCapabilities.actions]); const withoutDeferred = new Set([AgentCapabilities.tools, AgentCapabilities.actions]);
const checkWithoutDeferred = createCheckCapability(withoutDeferred); const checkWithoutDeferred = createCheckCapability(withoutDeferred);
expect(checkWithoutDeferred(AgentCapabilities.deferred_tools)).toBe(false); expect(checkWithoutDeferred(AgentCapabilities.deferred_tools)).toBe(false);
}); });
it('should use defaultAgentCapabilities when no capabilities configured', () => { it('should use defaultAgentCapabilities when no capabilities configured', () => {
const endpointsConfig = {}; // Simulates the fallback behavior in loadAgentTools
const endpointsConfig = {}; // No capabilities configured
const enabledCapabilities = new Set( const enabledCapabilities = new Set(
endpointsConfig?.capabilities ?? defaultAgentCapabilities, endpointsConfig?.capabilities ?? defaultAgentCapabilities,
); );

View file

@ -153,11 +153,9 @@ const generateBackupCodes = async (count = 10) => {
* @param {Object} params * @param {Object} params
* @param {Object} params.user * @param {Object} params.user
* @param {string} params.backupCode * @param {string} params.backupCode
* @param {boolean} [params.persist=true] - Whether to persist the used-mark to the database.
* Pass `false` when the caller will immediately overwrite `backupCodes` (e.g. re-enrollment).
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
const verifyBackupCode = async ({ user, backupCode, persist = true }) => { const verifyBackupCode = async ({ user, backupCode }) => {
if (!backupCode || !user || !Array.isArray(user.backupCodes)) { if (!backupCode || !user || !Array.isArray(user.backupCodes)) {
return false; return false;
} }
@ -167,50 +165,17 @@ const verifyBackupCode = async ({ user, backupCode, persist = true }) => {
(codeObj) => codeObj.codeHash === hashedInput && !codeObj.used, (codeObj) => codeObj.codeHash === hashedInput && !codeObj.used,
); );
if (!matchingCode) { if (matchingCode) {
return false;
}
if (persist) {
const updatedBackupCodes = user.backupCodes.map((codeObj) => const updatedBackupCodes = user.backupCodes.map((codeObj) =>
codeObj.codeHash === hashedInput && !codeObj.used codeObj.codeHash === hashedInput && !codeObj.used
? { ...codeObj, used: true, usedAt: new Date() } ? { ...codeObj, used: true, usedAt: new Date() }
: codeObj, : codeObj,
); );
// Update the user record with the marked backup code.
await updateUser(user._id, { backupCodes: updatedBackupCodes }); await updateUser(user._id, { backupCodes: updatedBackupCodes });
}
return true; return true;
};
/**
* Verifies a user's identity via TOTP token or backup code.
* @param {Object} params
* @param {Object} params.user - The user document (must include totpSecret and backupCodes).
* @param {string} [params.token] - A 6-digit TOTP token.
* @param {string} [params.backupCode] - An 8-character backup code.
* @param {boolean} [params.persistBackupUse=true] - Whether to mark the backup code as used in the DB.
* @returns {Promise<{ verified: boolean, status?: number, message?: string }>}
*/
const verifyOTPOrBackupCode = async ({ user, token, backupCode, persistBackupUse = true }) => {
if (!token && !backupCode) {
return { verified: false, status: 400 };
} }
return false;
if (token) {
const secret = await getTOTPSecret(user.totpSecret);
if (!secret) {
return { verified: false, status: 400, message: '2FA secret is missing or corrupted' };
}
const ok = await verifyTOTP(secret, token);
return ok
? { verified: true }
: { verified: false, status: 401, message: 'Invalid token or backup code' };
}
const ok = await verifyBackupCode({ user, backupCode, persist: persistBackupUse });
return ok
? { verified: true }
: { verified: false, status: 401, message: 'Invalid token or backup code' };
}; };
/** /**
@ -248,12 +213,11 @@ const generate2FATempToken = (userId) => {
}; };
module.exports = { module.exports = {
verifyOTPOrBackupCode,
generate2FATempToken,
generateBackupCodes,
generateTOTPSecret, generateTOTPSecret,
verifyBackupCode,
getTOTPSecret,
generateTOTP, generateTOTP,
verifyTOTP, verifyTOTP,
generateBackupCodes,
verifyBackupCode,
getTOTPSecret,
generate2FATempToken,
}; };

View file

@ -358,15 +358,16 @@ function splitAtTargetLevel(messages, targetMessageId) {
* @param {object} params - The parameters for duplicating the conversation. * @param {object} params - The parameters for duplicating the conversation.
* @param {string} params.userId - The ID of the user duplicating the conversation. * @param {string} params.userId - The ID of the user duplicating the conversation.
* @param {string} params.conversationId - The ID of the conversation to duplicate. * @param {string} params.conversationId - The ID of the conversation to duplicate.
* @param {string} [params.title] - Optional title override for the duplicate.
* @returns {Promise<{ conversation: TConversation, messages: TMessage[] }>} The duplicated conversation and messages. * @returns {Promise<{ conversation: TConversation, messages: TMessage[] }>} The duplicated conversation and messages.
*/ */
async function duplicateConversation({ userId, conversationId, title }) { async function duplicateConversation({ userId, conversationId }) {
// Get original conversation
const originalConvo = await getConvo(userId, conversationId); const originalConvo = await getConvo(userId, conversationId);
if (!originalConvo) { if (!originalConvo) {
throw new Error('Conversation not found'); throw new Error('Conversation not found');
} }
// Get original messages
const originalMessages = await getMessages({ const originalMessages = await getMessages({
user: userId, user: userId,
conversationId, conversationId,
@ -382,11 +383,14 @@ async function duplicateConversation({ userId, conversationId, title }) {
cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder); cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder);
const duplicateTitle = title || originalConvo.title; const result = importBatchBuilder.finishConversation(
const result = importBatchBuilder.finishConversation(duplicateTitle, new Date(), originalConvo); originalConvo.title,
new Date(),
originalConvo,
);
await importBatchBuilder.saveBatch(); await importBatchBuilder.saveBatch();
logger.debug( logger.debug(
`user: ${userId} | New conversation "${duplicateTitle}" duplicated from conversation ID ${conversationId}`, `user: ${userId} | New conversation "${originalConvo.title}" duplicated from conversation ID ${conversationId}`,
); );
const conversation = await getConvo(userId, result.conversation.conversationId); const conversation = await getConvo(userId, result.conversation.conversationId);

View file

@ -1,10 +1,7 @@
const fs = require('fs').promises; const fs = require('fs').promises;
const { resolveImportMaxFileSize } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { getImporter } = require('./importers'); const { getImporter } = require('./importers');
const maxFileSize = resolveImportMaxFileSize();
/** /**
* Job definition for importing a conversation. * Job definition for importing a conversation.
* @param {{ filepath, requestUserId }} job - The job object. * @param {{ filepath, requestUserId }} job - The job object.
@ -14,10 +11,11 @@ const importConversations = async (job) => {
try { try {
logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`); logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`);
/* error if file is too large */
const fileInfo = await fs.stat(filepath); const fileInfo = await fs.stat(filepath);
if (fileInfo.size > maxFileSize) { if (fileInfo.size > process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES) {
throw new Error( throw new Error(
`File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${maxFileSize} bytes.`, `File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES} bytes.`,
); );
} }

View file

@ -315,85 +315,24 @@ function convertToUsername(input, defaultValue = '') {
return defaultValue; return defaultValue;
} }
/**
* Exchange the access token for a Graph-scoped token using the On-Behalf-Of (OBO) flow.
*
* The original access token has the app's own audience (api://<client-id>), which Microsoft Graph
* rejects. This exchange produces a token with audience https://graph.microsoft.com and the
* minimum delegated scope (User.Read) required by /me/getMemberObjects.
*
* Uses a dedicated cache key (`${sub}:overage`) to avoid collisions with other OBO exchanges
* in the codebase (userinfo, Graph principal search).
*
* @param {string} accessToken - The original access token from the OpenID tokenset
* @param {string} sub - The subject identifier for cache keying
* @returns {Promise<string>} A Graph-scoped access token
* @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow
*/
async function exchangeTokenForOverage(accessToken, sub) {
if (!openidConfig) {
throw new Error('[openidStrategy] OpenID config not initialized; cannot exchange OBO token');
}
const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS);
const cacheKey = `${sub}:overage`;
const cached = await tokensCache.get(cacheKey);
if (cached?.access_token) {
logger.debug('[openidStrategy] Using cached Graph token for overage resolution');
return cached.access_token;
}
const grantResponse = await client.genericGrantRequest(
openidConfig,
'urn:ietf:params:oauth:grant-type:jwt-bearer',
{
scope: 'https://graph.microsoft.com/User.Read',
assertion: accessToken,
requested_token_use: 'on_behalf_of',
},
);
if (!grantResponse.access_token) {
throw new Error(
'[openidStrategy] OBO exchange succeeded but returned no access_token; cannot call Graph API',
);
}
const ttlMs =
Number.isFinite(grantResponse.expires_in) && grantResponse.expires_in > 0
? grantResponse.expires_in * 1000
: 3600 * 1000;
await tokensCache.set(cacheKey, { access_token: grantResponse.access_token }, ttlMs);
return grantResponse.access_token;
}
/** /**
* Resolve Azure AD groups when group overage is in effect (groups moved to _claim_names/_claim_sources). * Resolve Azure AD groups when group overage is in effect (groups moved to _claim_names/_claim_sources).
* *
* NOTE: Microsoft recommends treating _claim_names/_claim_sources as a signal only and using Microsoft Graph * NOTE: Microsoft recommends treating _claim_names/_claim_sources as a signal only and using Microsoft Graph
* to resolve group membership instead of calling the endpoint in _claim_sources directly. * to resolve group membership instead of calling the endpoint in _claim_sources directly.
* *
* Before calling Graph, the access token is exchanged via the OBO flow to obtain a token with the * @param {string} accessToken - Access token with Microsoft Graph permissions
* correct audience (https://graph.microsoft.com) and User.Read scope.
*
* @param {string} accessToken - Access token from the OpenID tokenset (app audience)
* @param {string} sub - The subject identifier of the user (for OBO exchange and cache keying)
* @returns {Promise<string[] | null>} Resolved group IDs or null on failure * @returns {Promise<string[] | null>} Resolved group IDs or null on failure
* @see https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#groups-overage-claim * @see https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#groups-overage-claim
* @see https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects * @see https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects
*/ */
async function resolveGroupsFromOverage(accessToken, sub) { async function resolveGroupsFromOverage(accessToken) {
try { try {
if (!accessToken) { if (!accessToken) {
logger.error('[openidStrategy] Access token missing; cannot resolve group overage'); logger.error('[openidStrategy] Access token missing; cannot resolve group overage');
return null; return null;
} }
const graphToken = await exchangeTokenForOverage(accessToken, sub);
// Use /me/getMemberObjects so least-privileged delegated permission User.Read is sufficient // Use /me/getMemberObjects so least-privileged delegated permission User.Read is sufficient
// when resolving the signed-in user's group membership. // when resolving the signed-in user's group membership.
const url = 'https://graph.microsoft.com/v1.0/me/getMemberObjects'; const url = 'https://graph.microsoft.com/v1.0/me/getMemberObjects';
@ -405,7 +344,7 @@ async function resolveGroupsFromOverage(accessToken, sub) {
const fetchOptions = { const fetchOptions = {
method: 'POST', method: 'POST',
headers: { headers: {
Authorization: `Bearer ${graphToken}`, Authorization: `Bearer ${accessToken}`,
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
body: JSON.stringify({ securityEnabledOnly: false }), body: JSON.stringify({ securityEnabledOnly: false }),
@ -425,7 +364,6 @@ async function resolveGroupsFromOverage(accessToken, sub) {
} }
const data = await response.json(); const data = await response.json();
const values = Array.isArray(data?.value) ? data.value : null; const values = Array.isArray(data?.value) ? data.value : null;
if (!values) { if (!values) {
logger.error( logger.error(
@ -494,8 +432,6 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
const fullName = getFullName(userinfo); const fullName = getFullName(userinfo);
const requiredRole = process.env.OPENID_REQUIRED_ROLE; const requiredRole = process.env.OPENID_REQUIRED_ROLE;
let resolvedOverageGroups = null;
if (requiredRole) { if (requiredRole) {
const requiredRoles = requiredRole const requiredRoles = requiredRole
.split(',') .split(',')
@ -515,21 +451,19 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
// Handle Azure AD group overage for ID token groups: when hasgroups or _claim_* indicate overage, // Handle Azure AD group overage for ID token groups: when hasgroups or _claim_* indicate overage,
// resolve groups via Microsoft Graph instead of relying on token group values. // resolve groups via Microsoft Graph instead of relying on token group values.
const hasOverage =
decodedToken?.hasgroups ||
(decodedToken?._claim_names?.groups &&
decodedToken?._claim_sources?.[decodedToken._claim_names.groups]);
if ( if (
!Array.isArray(roles) &&
typeof roles !== 'string' &&
requiredRoleTokenKind === 'id' && requiredRoleTokenKind === 'id' &&
requiredRoleParameterPath === 'groups' && requiredRoleParameterPath === 'groups' &&
decodedToken && decodedToken &&
hasOverage (decodedToken.hasgroups ||
(decodedToken._claim_names?.groups &&
decodedToken._claim_sources?.[decodedToken._claim_names.groups]))
) { ) {
const overageGroups = await resolveGroupsFromOverage(tokenset.access_token, claims.sub); const overageGroups = await resolveGroupsFromOverage(tokenset.access_token);
if (overageGroups) { if (overageGroups) {
roles = overageGroups; roles = overageGroups;
resolvedOverageGroups = overageGroups;
} }
} }
@ -616,25 +550,7 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
throw new Error('Invalid admin role token kind'); throw new Error('Invalid admin role token kind');
} }
let adminRoles = get(adminRoleObject, adminRoleParameterPath); const adminRoles = get(adminRoleObject, adminRoleParameterPath);
// Handle Azure AD group overage for admin role when using ID token groups
if (adminRoleTokenKind === 'id' && adminRoleParameterPath === 'groups' && adminRoleObject) {
const hasAdminOverage =
adminRoleObject.hasgroups ||
(adminRoleObject._claim_names?.groups &&
adminRoleObject._claim_sources?.[adminRoleObject._claim_names.groups]);
if (hasAdminOverage) {
const overageGroups =
resolvedOverageGroups ||
(await resolveGroupsFromOverage(tokenset.access_token, claims.sub));
if (overageGroups) {
adminRoles = overageGroups;
}
}
}
let adminRoleValues = []; let adminRoleValues = [];
if (Array.isArray(adminRoles)) { if (Array.isArray(adminRoles)) {
adminRoleValues = adminRoles; adminRoleValues = adminRoles;

View file

@ -64,10 +64,6 @@ jest.mock('openid-client', () => {
// Only return additional properties, but don't override any claims // Only return additional properties, but don't override any claims
return Promise.resolve({}); return Promise.resolve({});
}), }),
genericGrantRequest: jest.fn().mockResolvedValue({
access_token: 'exchanged_graph_token',
expires_in: 3600,
}),
customFetch: Symbol('customFetch'), customFetch: Symbol('customFetch'),
}; };
}); });
@ -734,7 +730,7 @@ describe('setupOpenId', () => {
expect.objectContaining({ expect.objectContaining({
method: 'POST', method: 'POST',
headers: expect.objectContaining({ headers: expect.objectContaining({
Authorization: 'Bearer exchanged_graph_token', Authorization: `Bearer ${tokenset.access_token}`,
}), }),
}), }),
); );
@ -749,313 +745,6 @@ describe('setupOpenId', () => {
); );
}); });
describe('OBO token exchange for overage', () => {
it('exchanges access token via OBO before calling Graph API', async () => {
const openidClient = require('openid-client');
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['group-required'] }),
});
await validate(tokenset);
expect(openidClient.genericGrantRequest).toHaveBeenCalledWith(
expect.anything(),
'urn:ietf:params:oauth:grant-type:jwt-bearer',
expect.objectContaining({
scope: 'https://graph.microsoft.com/User.Read',
assertion: tokenset.access_token,
requested_token_use: 'on_behalf_of',
}),
);
expect(undici.fetch).toHaveBeenCalledWith(
'https://graph.microsoft.com/v1.0/me/getMemberObjects',
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer exchanged_graph_token',
}),
}),
);
});
it('caches the exchanged token and reuses it on subsequent calls', async () => {
const openidClient = require('openid-client');
const getLogStores = require('~/cache/getLogStores');
const mockSet = jest.fn();
const mockGet = jest
.fn()
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce({ access_token: 'exchanged_graph_token' });
getLogStores.mockReturnValue({ get: mockGet, set: mockSet });
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['group-required'] }),
});
// First call: cache miss → OBO exchange → cache set
await validate(tokenset);
expect(mockSet).toHaveBeenCalledWith(
'1234:overage',
{ access_token: 'exchanged_graph_token' },
3600000,
);
expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1);
// Second call: cache hit → no new OBO exchange
openidClient.genericGrantRequest.mockClear();
await validate(tokenset);
expect(openidClient.genericGrantRequest).not.toHaveBeenCalled();
});
});
describe('admin role group overage', () => {
it('resolves admin groups via Graph when overage is detected for admin role', async () => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
process.env.OPENID_ADMIN_ROLE = 'admin-group-id';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['group-required', 'admin-group-id'] }),
});
const { user } = await validate(tokenset);
expect(user.role).toBe('ADMIN');
});
it('does not grant admin when overage groups do not contain admin role', async () => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
process.env.OPENID_ADMIN_ROLE = 'admin-group-id';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['group-required', 'other-group'] }),
});
const { user } = await validate(tokenset);
expect(user).toBeTruthy();
expect(user.role).toBeUndefined();
});
it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
process.env.OPENID_ADMIN_ROLE = 'admin-group-id';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['group-required', 'admin-group-id'] }),
});
await validate(tokenset);
// Graph API should be called only once (for required role), admin role reuses the result
expect(undici.fetch).toHaveBeenCalledTimes(1);
});
it('demotes existing admin when overage groups no longer contain admin role', async () => {
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
process.env.OPENID_ADMIN_ROLE = 'admin-group-id';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id';
const existingAdminUser = {
_id: 'existingAdminId',
provider: 'openid',
email: tokenset.claims().email,
openidId: tokenset.claims().sub,
username: 'adminuser',
name: 'Admin User',
role: 'ADMIN',
};
findUser.mockImplementation(async (query) => {
if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) {
return existingAdminUser;
}
return null;
});
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['group-required'] }),
});
const { user } = await validate(tokenset);
expect(user.role).toBe('USER');
});
it('does not attempt overage for admin role when token kind is not id', async () => {
process.env.OPENID_REQUIRED_ROLE = 'requiredRole';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
process.env.OPENID_ADMIN_ROLE = 'admin';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access';
jwtDecode.mockReturnValue({
roles: ['requiredRole'],
hasgroups: true,
});
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
const { user } = await validate(tokenset);
// No Graph call since admin uses access token (not id)
expect(undici.fetch).not.toHaveBeenCalled();
expect(user.role).toBeUndefined();
});
it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => {
delete process.env.OPENID_REQUIRED_ROLE;
process.env.OPENID_ADMIN_ROLE = 'admin-group-id';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['admin-group-id'] }),
});
const { user } = await validate(tokenset);
expect(user.role).toBe('ADMIN');
expect(undici.fetch).toHaveBeenCalledTimes(1);
});
it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => {
delete process.env.OPENID_REQUIRED_ROLE;
process.env.OPENID_ADMIN_ROLE = 'admin-group-id';
process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
undici.fetch.mockResolvedValue({
ok: true,
status: 200,
statusText: 'OK',
json: async () => ({ value: ['other-group'] }),
});
const { user } = await validate(tokenset);
expect(user).toBeTruthy();
expect(user.role).toBeUndefined();
});
it('denies login and logs error when OBO exchange throws', async () => {
const openidClient = require('openid-client');
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected'));
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
const { user, details } = await validate(tokenset);
expect(user).toBe(false);
expect(details.message).toBe('You must have "group-required" role to log in.');
expect(undici.fetch).not.toHaveBeenCalled();
});
it('denies login when OBO exchange returns no access_token', async () => {
const openidClient = require('openid-client');
process.env.OPENID_REQUIRED_ROLE = 'group-required';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
jwtDecode.mockReturnValue({ hasgroups: true });
openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 });
await setupOpenId();
verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid');
const { user, details } = await validate(tokenset);
expect(user).toBe(false);
expect(details.message).toBe('You must have "group-required" role to log in.');
expect(undici.fetch).not.toHaveBeenCalled();
});
});
it('should attempt to download and save the avatar if picture is provided', async () => { it('should attempt to download and save the avatar if picture is provided', async () => {
// Act // Act
const { user } = await validate(tokenset); const { user } = await validate(tokenset);

View file

@ -1,4 +1,5 @@
// --- Mocks --- // --- Mocks ---
jest.mock('tiktoken');
jest.mock('fs'); jest.mock('fs');
jest.mock('path'); jest.mock('path');
jest.mock('node-fetch'); jest.mock('node-fetch');

View file

@ -214,25 +214,6 @@ describe('getModelMaxTokens', () => {
); );
}); });
test('should return correct tokens for gpt-5.4 matches', () => {
expect(getModelMaxTokens('gpt-5.4')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-5.4']);
expect(getModelMaxTokens('gpt-5.4-thinking')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5.4'],
);
expect(getModelMaxTokens('openai/gpt-5.4')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5.4'],
);
});
test('should return correct tokens for gpt-5.4-pro matches', () => {
expect(getModelMaxTokens('gpt-5.4-pro')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5.4-pro'],
);
expect(getModelMaxTokens('openai/gpt-5.4-pro')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5.4-pro'],
);
});
test('should return correct tokens for Anthropic models', () => { test('should return correct tokens for Anthropic models', () => {
const models = [ const models = [
'claude-2.1', 'claude-2.1',
@ -270,6 +251,16 @@ describe('getModelMaxTokens', () => {
}); });
}); });
// Tests for Google models
test('should return correct tokens for exact match - Google models', () => {
expect(getModelMaxTokens('text-bison-32k', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['text-bison-32k'],
);
expect(getModelMaxTokens('codechat-bison-32k', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['codechat-bison-32k'],
);
});
test('should return undefined for no match - Google models', () => { test('should return undefined for no match - Google models', () => {
expect(getModelMaxTokens('unknown-google-model', EModelEndpoint.google)).toBeUndefined(); expect(getModelMaxTokens('unknown-google-model', EModelEndpoint.google)).toBeUndefined();
}); });
@ -326,6 +317,12 @@ describe('getModelMaxTokens', () => {
expect(getModelMaxTokens('gemini-pro', EModelEndpoint.google)).toBe( expect(getModelMaxTokens('gemini-pro', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['gemini'], maxTokensMap[EModelEndpoint.google]['gemini'],
); );
expect(getModelMaxTokens('code-', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['code-'],
);
expect(getModelMaxTokens('chat-', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['chat-'],
);
}); });
test('should return correct tokens for partial match - Cohere models', () => { test('should return correct tokens for partial match - Cohere models', () => {
@ -514,8 +511,6 @@ describe('getModelMaxTokens', () => {
'gpt-5.1', 'gpt-5.1',
'gpt-5.2', 'gpt-5.2',
'gpt-5.3', 'gpt-5.3',
'gpt-5.4',
'gpt-5.4-pro',
'gpt-5-mini', 'gpt-5-mini',
'gpt-5-nano', 'gpt-5-nano',
'gpt-5-pro', 'gpt-5-pro',
@ -546,184 +541,6 @@ describe('getModelMaxTokens', () => {
}); });
}); });
describe('findMatchingPattern - longest match wins', () => {
test('should prefer longer matching key over shorter cross-provider pattern', () => {
const result = findMatchingPattern(
'gpt-5.2-chat-2025-12-11',
maxTokensMap[EModelEndpoint.openAI],
);
expect(result).toBe('gpt-5.2');
});
test('should match gpt-5.2 tokens for date-suffixed chat variant', () => {
expect(getModelMaxTokens('gpt-5.2-chat-2025-12-11')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5.2'],
);
});
test('should match gpt-5.2-pro over shorter patterns', () => {
expect(getModelMaxTokens('gpt-5.2-pro-chat-2025-12-11')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5.2-pro'],
);
});
test('should match gpt-5-mini over gpt-5 for mini variants', () => {
expect(getModelMaxTokens('gpt-5-mini-chat-2025-01-01')).toBe(
maxTokensMap[EModelEndpoint.openAI]['gpt-5-mini'],
);
});
test('should prefer gpt-4-1106 over gpt-4 for versioned model names', () => {
const result = findMatchingPattern('gpt-4-1106-preview', maxTokensMap[EModelEndpoint.openAI]);
expect(result).toBe('gpt-4-1106');
});
test('should prefer gpt-4-32k-0613 over gpt-4-32k for exact versioned names', () => {
const result = findMatchingPattern('gpt-4-32k-0613', maxTokensMap[EModelEndpoint.openAI]);
expect(result).toBe('gpt-4-32k-0613');
});
test('should prefer claude-3-5-sonnet over claude-3', () => {
const result = findMatchingPattern(
'claude-3-5-sonnet-20241022',
maxTokensMap[EModelEndpoint.anthropic],
);
expect(result).toBe('claude-3-5-sonnet');
});
test('should prefer gemini-2.0-flash-lite over gemini-2.0-flash', () => {
const result = findMatchingPattern(
'gemini-2.0-flash-lite-preview',
maxTokensMap[EModelEndpoint.google],
);
expect(result).toBe('gemini-2.0-flash-lite');
});
});
describe('findMatchingPattern - bestLength selection', () => {
test('should return the longest matching key when multiple keys match', () => {
const tokensMap = { short: 100, 'short-med': 200, 'short-med-long': 300 };
expect(findMatchingPattern('short-med-long-extra', tokensMap)).toBe('short-med-long');
});
test('should return the longest match regardless of key insertion order', () => {
const tokensMap = { 'a-b-c': 300, a: 100, 'a-b': 200 };
expect(findMatchingPattern('a-b-c-d', tokensMap)).toBe('a-b-c');
});
test('should return null when no key matches', () => {
const tokensMap = { alpha: 100, beta: 200 };
expect(findMatchingPattern('gamma-delta', tokensMap)).toBeNull();
});
test('should return the single matching key when only one matches', () => {
const tokensMap = { alpha: 100, beta: 200, gamma: 300 };
expect(findMatchingPattern('beta-extended', tokensMap)).toBe('beta');
});
test('should match case-insensitively against model name', () => {
const tokensMap = { 'gpt-5': 400000 };
expect(findMatchingPattern('GPT-5-turbo', tokensMap)).toBe('gpt-5');
});
test('should select the longest key among overlapping substring matches', () => {
const tokensMap = { 'gpt-': 100, 'gpt-5': 200, 'gpt-5.2': 300, 'gpt-5.2-pro': 400 };
expect(findMatchingPattern('gpt-5.2-pro-2025-01-01', tokensMap)).toBe('gpt-5.2-pro');
expect(findMatchingPattern('gpt-5.2-chat-2025-01-01', tokensMap)).toBe('gpt-5.2');
expect(findMatchingPattern('gpt-5.1-preview', tokensMap)).toBe('gpt-5');
expect(findMatchingPattern('gpt-unknown', tokensMap)).toBe('gpt-');
});
test('should not be confused by a short key that appears later in the model name', () => {
const tokensMap = { 'model-v2': 200, v2: 100 };
expect(findMatchingPattern('model-v2-extended', tokensMap)).toBe('model-v2');
});
test('should handle exact-length match as the best match', () => {
const tokensMap = { 'exact-model': 500, exact: 100 };
expect(findMatchingPattern('exact-model', tokensMap)).toBe('exact-model');
});
test('should return null for empty model name', () => {
expect(findMatchingPattern('', { 'gpt-5': 400000 })).toBeNull();
});
test('should prefer last-defined key on same-length ties', () => {
const tokensMap = { 'aa-bb': 100, 'cc-dd': 200 };
// model name contains both 5-char keys; last-defined wins in reverse iteration
expect(findMatchingPattern('aa-bb-cc-dd', tokensMap)).toBe('cc-dd');
});
test('longest match beats short cross-provider pattern even when both present', () => {
const tokensMap = { 'gpt-5.2': 400000, 'chat-': 8187 };
expect(findMatchingPattern('gpt-5.2-chat-2025-12-11', tokensMap)).toBe('gpt-5.2');
});
test('should match case-insensitively against keys', () => {
const tokensMap = { 'GPT-5': 400000 };
expect(findMatchingPattern('gpt-5-turbo', tokensMap)).toBe('GPT-5');
});
});
describe('findMatchingPattern - iteration performance', () => {
let includesSpy;
beforeEach(() => {
includesSpy = jest.spyOn(String.prototype, 'includes');
});
afterEach(() => {
includesSpy.mockRestore();
});
test('exact match early-exits with minimal includes() checks', () => {
const openAIMap = maxTokensMap[EModelEndpoint.openAI];
const keys = Object.keys(openAIMap);
const lastKey = keys[keys.length - 1];
includesSpy.mockClear();
const result = findMatchingPattern(lastKey, openAIMap);
const exactCalls = includesSpy.mock.calls.length;
expect(result).toBe(lastKey);
expect(exactCalls).toBe(1);
});
test('bestLength check skips includes() for shorter keys after a long match', () => {
const openAIMap = maxTokensMap[EModelEndpoint.openAI];
includesSpy.mockClear();
findMatchingPattern('gpt-3.5-turbo-0301-test', openAIMap);
const longKeyCalls = includesSpy.mock.calls.length;
includesSpy.mockClear();
findMatchingPattern('gpt-5.3-chat-latest', openAIMap);
const shortKeyCalls = includesSpy.mock.calls.length;
// gpt-3.5-turbo-0301 (20 chars) matches early, then bestLength prunes most keys
// gpt-5.3 (7 chars) is short, so fewer keys are pruned by the length check
expect(longKeyCalls).toBeLessThan(shortKeyCalls);
});
test('last-defined keys are checked first in reverse iteration', () => {
const tokensMap = { first: 100, second: 200, third: 300 };
includesSpy.mockClear();
const result = findMatchingPattern('third', tokensMap);
const calls = includesSpy.mock.calls.length;
// 'third' is last key, found on first reverse check, exact match exits immediately
expect(result).toBe('third');
expect(calls).toBe(1);
});
});
describe('deprecated PaLM2/Codey model removal', () => {
test('deprecated PaLM2/Codey models no longer have token entries', () => {
expect(getModelMaxTokens('text-bison-32k', EModelEndpoint.google)).toBeUndefined();
expect(getModelMaxTokens('codechat-bison-32k', EModelEndpoint.google)).toBeUndefined();
expect(getModelMaxTokens('code-bison', EModelEndpoint.google)).toBeUndefined();
expect(getModelMaxTokens('chat-bison', EModelEndpoint.google)).toBeUndefined();
});
});
describe('matchModelName', () => { describe('matchModelName', () => {
it('should return the exact model name if it exists in maxTokensMap', () => { it('should return the exact model name if it exists in maxTokensMap', () => {
expect(matchModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613'); expect(matchModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
@ -825,10 +642,10 @@ describe('matchModelName', () => {
expect(matchModelName('gpt-5.3-2025-03-01')).toBe('gpt-5.3'); expect(matchModelName('gpt-5.3-2025-03-01')).toBe('gpt-5.3');
}); });
it('should return the closest matching key for gpt-5.4 matches', () => { // Tests for Google models
expect(matchModelName('openai/gpt-5.4')).toBe('gpt-5.4'); it('should return the exact model name if it exists in maxTokensMap - Google models', () => {
expect(matchModelName('gpt-5.4-thinking')).toBe('gpt-5.4'); expect(matchModelName('text-bison-32k', EModelEndpoint.google)).toBe('text-bison-32k');
expect(matchModelName('gpt-5.4-pro')).toBe('gpt-5.4-pro'); expect(matchModelName('codechat-bison-32k', EModelEndpoint.google)).toBe('codechat-bison-32k');
}); });
it('should return the input model name if no match is found - Google models', () => { it('should return the input model name if no match is found - Google models', () => {
@ -836,6 +653,11 @@ describe('matchModelName', () => {
'unknown-google-model', 'unknown-google-model',
); );
}); });
it('should return the closest matching key for partial matches - Google models', () => {
expect(matchModelName('code-', EModelEndpoint.google)).toBe('code-');
expect(matchModelName('chat-', EModelEndpoint.google)).toBe('chat-');
});
}); });
describe('Meta Models Tests', () => { describe('Meta Models Tests', () => {

4216
bun.lock

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
/** v0.8.3 */ /** v0.8.3-rc2 */
module.exports = { module.exports = {
roots: ['<rootDir>/src'], roots: ['<rootDir>/src'],
testEnvironment: 'jsdom', testEnvironment: 'jsdom',
@ -32,7 +32,6 @@ module.exports = {
'^librechat-data-provider/react-query$': '^librechat-data-provider/react-query$':
'<rootDir>/../node_modules/librechat-data-provider/src/react-query', '<rootDir>/../node_modules/librechat-data-provider/src/react-query',
}, },
maxWorkers: '50%',
restoreMocks: true, restoreMocks: true,
testResultsProcessor: 'jest-junit', testResultsProcessor: 'jest-junit',
coverageReporters: ['text', 'cobertura', 'lcov'], coverageReporters: ['text', 'cobertura', 'lcov'],

View file

@ -1,6 +1,6 @@
{ {
"name": "@librechat/frontend", "name": "@librechat/frontend",
"version": "v0.8.3", "version": "v0.8.3-rc2",
"description": "", "description": "",
"type": "module", "type": "module",
"scripts": { "scripts": {
@ -38,7 +38,6 @@
"@librechat/client": "*", "@librechat/client": "*",
"@marsidev/react-turnstile": "^1.1.0", "@marsidev/react-turnstile": "^1.1.0",
"@mcp-ui/client": "^5.7.0", "@mcp-ui/client": "^5.7.0",
"@monaco-editor/react": "^4.7.0",
"@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "1.0.2", "@radix-ui/react-alert-dialog": "1.0.2",
"@radix-ui/react-checkbox": "^1.0.3", "@radix-ui/react-checkbox": "^1.0.3",
@ -81,7 +80,7 @@
"lodash": "^4.17.23", "lodash": "^4.17.23",
"lucide-react": "^0.394.0", "lucide-react": "^0.394.0",
"match-sorter": "^8.1.0", "match-sorter": "^8.1.0",
"mermaid": "^11.13.0", "mermaid": "^11.12.3",
"micromark-extension-llm-math": "^3.1.0", "micromark-extension-llm-math": "^3.1.0",
"qrcode.react": "^4.2.0", "qrcode.react": "^4.2.0",
"rc-input-number": "^7.4.2", "rc-input-number": "^7.4.2",
@ -94,6 +93,7 @@
"react-gtm-module": "^2.0.11", "react-gtm-module": "^2.0.11",
"react-hook-form": "^7.43.9", "react-hook-form": "^7.43.9",
"react-i18next": "^15.4.0", "react-i18next": "^15.4.0",
"react-lazy-load-image-component": "^1.6.0",
"react-markdown": "^9.0.1", "react-markdown": "^9.0.1",
"react-resizable-panels": "^3.0.6", "react-resizable-panels": "^3.0.6",
"react-router-dom": "^6.30.3", "react-router-dom": "^6.30.3",
@ -122,7 +122,6 @@
"@babel/preset-env": "^7.22.15", "@babel/preset-env": "^7.22.15",
"@babel/preset-react": "^7.22.15", "@babel/preset-react": "^7.22.15",
"@babel/preset-typescript": "^7.22.15", "@babel/preset-typescript": "^7.22.15",
"@happy-dom/jest-environment": "^20.8.3",
"@tanstack/react-query-devtools": "^4.29.0", "@tanstack/react-query-devtools": "^4.29.0",
"@testing-library/dom": "^9.3.0", "@testing-library/dom": "^9.3.0",
"@testing-library/jest-dom": "^5.16.5", "@testing-library/jest-dom": "^5.16.5",
@ -145,10 +144,9 @@
"identity-obj-proxy": "^3.0.0", "identity-obj-proxy": "^3.0.0",
"jest": "^30.2.0", "jest": "^30.2.0",
"jest-canvas-mock": "^2.5.2", "jest-canvas-mock": "^2.5.2",
"jest-environment-jsdom": "^30.2.0", "jest-environment-jsdom": "^29.7.0",
"jest-file-loader": "^1.0.3", "jest-file-loader": "^1.0.3",
"jest-junit": "^16.0.0", "jest-junit": "^16.0.0",
"monaco-editor": "^0.55.1",
"postcss": "^8.4.31", "postcss": "^8.4.31",
"postcss-preset-env": "^11.2.0", "postcss-preset-env": "^11.2.0",
"tailwindcss": "^3.4.1", "tailwindcss": "^3.4.1",

View file

@ -1,8 +1,7 @@
import React, { createContext, useContext, useMemo } from 'react'; import React, { createContext, useContext, useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import type { TMessage } from 'librechat-data-provider'; import type { TMessage } from 'librechat-data-provider';
import { useChatContext } from './ChatContext';
import { getLatestText } from '~/utils'; import { getLatestText } from '~/utils';
import store from '~/store';
export interface ArtifactsContextValue { export interface ArtifactsContextValue {
isSubmitting: boolean; isSubmitting: boolean;
@ -19,28 +18,27 @@ interface ArtifactsProviderProps {
} }
export function ArtifactsProvider({ children, value }: ArtifactsProviderProps) { export function ArtifactsProvider({ children, value }: ArtifactsProviderProps) {
const isSubmitting = useRecoilValue(store.isSubmittingFamily(0)); const { isSubmitting, latestMessage, conversation } = useChatContext();
const latestMessage = useRecoilValue(store.latestMessageFamily(0));
const conversationId = useRecoilValue(store.conversationIdByIndex(0));
const chatLatestMessageText = useMemo(() => { const chatLatestMessageText = useMemo(() => {
return getLatestText({ return getLatestText({
messageId: latestMessage?.messageId ?? null,
text: latestMessage?.text ?? null, text: latestMessage?.text ?? null,
content: latestMessage?.content ?? null, content: latestMessage?.content ?? null,
messageId: latestMessage?.messageId ?? null,
} as TMessage); } as TMessage);
}, [latestMessage?.messageId, latestMessage?.text, latestMessage?.content]); }, [latestMessage?.messageId, latestMessage?.text, latestMessage?.content]);
const defaultContextValue = useMemo<ArtifactsContextValue>( const defaultContextValue = useMemo<ArtifactsContextValue>(
() => ({ () => ({
isSubmitting, isSubmitting,
conversationId: conversationId ?? null,
latestMessageText: chatLatestMessageText, latestMessageText: chatLatestMessageText,
latestMessageId: latestMessage?.messageId ?? null, latestMessageId: latestMessage?.messageId ?? null,
conversationId: conversation?.conversationId ?? null,
}), }),
[isSubmitting, chatLatestMessageText, latestMessage?.messageId, conversationId], [isSubmitting, chatLatestMessageText, latestMessage?.messageId, conversation?.conversationId],
); );
/** Context value only created when relevant values change */
const contextValue = useMemo<ArtifactsContextValue>( const contextValue = useMemo<ArtifactsContextValue>(
() => (value ? { ...defaultContextValue, ...value } : defaultContextValue), () => (value ? { ...defaultContextValue, ...value } : defaultContextValue),
[defaultContextValue, value], [defaultContextValue, value],

View file

@ -1,5 +1,5 @@
import React, { createContext, useContext, useMemo } from 'react'; import React, { createContext, useContext, useMemo } from 'react';
import { isAgentsEndpoint, resolveEndpointType } from 'librechat-data-provider'; import { getEndpointField, isAgentsEndpoint } from 'librechat-data-provider';
import type { EModelEndpoint } from 'librechat-data-provider'; import type { EModelEndpoint } from 'librechat-data-provider';
import { useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider'; import { useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider';
import { useAgentsMapContext } from './AgentsMapContext'; import { useAgentsMapContext } from './AgentsMapContext';
@ -9,7 +9,7 @@ interface DragDropContextValue {
conversationId: string | null | undefined; conversationId: string | null | undefined;
agentId: string | null | undefined; agentId: string | null | undefined;
endpoint: string | null | undefined; endpoint: string | null | undefined;
endpointType?: EModelEndpoint | string | undefined; endpointType?: EModelEndpoint | undefined;
useResponsesApi?: boolean; useResponsesApi?: boolean;
} }
@ -20,6 +20,13 @@ export function DragDropProvider({ children }: { children: React.ReactNode }) {
const { data: endpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig } = useGetEndpointsQuery();
const agentsMap = useAgentsMapContext(); const agentsMap = useAgentsMapContext();
const endpointType = useMemo(() => {
return (
getEndpointField(endpointsConfig, conversation?.endpoint, 'type') ||
(conversation?.endpoint as EModelEndpoint | undefined)
);
}, [conversation?.endpoint, endpointsConfig]);
const needsAgentFetch = useMemo(() => { const needsAgentFetch = useMemo(() => {
const isAgents = isAgentsEndpoint(conversation?.endpoint); const isAgents = isAgentsEndpoint(conversation?.endpoint);
if (!isAgents || !conversation?.agent_id) { if (!isAgents || !conversation?.agent_id) {
@ -33,20 +40,6 @@ export function DragDropProvider({ children }: { children: React.ReactNode }) {
enabled: needsAgentFetch, enabled: needsAgentFetch,
}); });
const agentProvider = useMemo(() => {
const isAgents = isAgentsEndpoint(conversation?.endpoint);
if (!isAgents || !conversation?.agent_id) {
return undefined;
}
const agent = agentData || agentsMap?.[conversation.agent_id];
return agent?.provider;
}, [conversation?.endpoint, conversation?.agent_id, agentData, agentsMap]);
const endpointType = useMemo(
() => resolveEndpointType(endpointsConfig, conversation?.endpoint, agentProvider),
[endpointsConfig, conversation?.endpoint, agentProvider],
);
const useResponsesApi = useMemo(() => { const useResponsesApi = useMemo(() => {
const isAgents = isAgentsEndpoint(conversation?.endpoint); const isAgents = isAgentsEndpoint(conversation?.endpoint);
if (!isAgents || !conversation?.agent_id || conversation?.useResponsesApi) { if (!isAgents || !conversation?.agent_id || conversation?.useResponsesApi) {

View file

@ -18,8 +18,7 @@ interface MessagesViewContextValue {
/** Message state management */ /** Message state management */
index: ReturnType<typeof useChatContext>['index']; index: ReturnType<typeof useChatContext>['index'];
latestMessageId: ReturnType<typeof useChatContext>['latestMessageId']; latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
latestMessageDepth: ReturnType<typeof useChatContext>['latestMessageDepth'];
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage']; setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
getMessages: ReturnType<typeof useChatContext>['getMessages']; getMessages: ReturnType<typeof useChatContext>['getMessages'];
setMessages: ReturnType<typeof useChatContext>['setMessages']; setMessages: ReturnType<typeof useChatContext>['setMessages'];
@ -40,8 +39,7 @@ export function MessagesViewProvider({ children }: { children: React.ReactNode }
regenerate, regenerate,
isSubmitting, isSubmitting,
conversation, conversation,
latestMessageId, latestMessage,
latestMessageDepth,
setAbortScroll, setAbortScroll,
handleContinue, handleContinue,
setLatestMessage, setLatestMessage,
@ -85,11 +83,10 @@ export function MessagesViewProvider({ children }: { children: React.ReactNode }
const messageState = useMemo( const messageState = useMemo(
() => ({ () => ({
index, index,
latestMessageId, latestMessage,
latestMessageDepth,
setLatestMessage, setLatestMessage,
}), }),
[index, latestMessageId, latestMessageDepth, setLatestMessage], [index, latestMessage, setLatestMessage],
); );
/** Combine all values into final context value */ /** Combine all values into final context value */
@ -142,9 +139,9 @@ export function useMessagesOperations() {
/** Hook for components that only need message state */ /** Hook for components that only need message state */
export function useMessagesState() { export function useMessagesState() {
const { index, latestMessageId, latestMessageDepth, setLatestMessage } = useMessagesViewContext(); const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
return useMemo( return useMemo(
() => ({ index, latestMessageId, latestMessageDepth, setLatestMessage }), () => ({ index, latestMessage, setLatestMessage }),
[index, latestMessageId, latestMessageDepth, setLatestMessage], [index, latestMessage, setLatestMessage],
); );
} }

View file

@ -1,134 +0,0 @@
import React from 'react';
import { renderHook } from '@testing-library/react';
import { EModelEndpoint } from 'librechat-data-provider';
import type { TEndpointsConfig, Agent } from 'librechat-data-provider';
import { DragDropProvider, useDragDropContext } from '../DragDropContext';
const mockEndpointsConfig: TEndpointsConfig = {
[EModelEndpoint.openAI]: { userProvide: false, order: 0 },
[EModelEndpoint.agents]: { userProvide: false, order: 1 },
[EModelEndpoint.anthropic]: { userProvide: false, order: 6 },
Moonshot: { type: EModelEndpoint.custom, userProvide: false, order: 9999 },
'Some Endpoint': { type: EModelEndpoint.custom, userProvide: false, order: 9999 },
};
let mockConversation: Record<string, unknown> | null = null;
let mockAgentsMap: Record<string, Partial<Agent>> = {};
let mockAgentQueryData: Partial<Agent> | undefined;
jest.mock('~/data-provider', () => ({
useGetEndpointsQuery: () => ({ data: mockEndpointsConfig }),
useGetAgentByIdQuery: () => ({ data: mockAgentQueryData }),
}));
jest.mock('../AgentsMapContext', () => ({
useAgentsMapContext: () => mockAgentsMap,
}));
jest.mock('../ChatContext', () => ({
useChatContext: () => ({ conversation: mockConversation }),
}));
function wrapper({ children }: { children: React.ReactNode }) {
return <DragDropProvider>{children}</DragDropProvider>;
}
describe('DragDropContext endpointType resolution', () => {
beforeEach(() => {
mockConversation = null;
mockAgentsMap = {};
mockAgentQueryData = undefined;
});
describe('non-agents endpoints', () => {
it('resolves custom endpoint type for a custom endpoint', () => {
mockConversation = { endpoint: 'Moonshot' };
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.custom);
});
it('resolves endpoint name for a standard endpoint', () => {
mockConversation = { endpoint: EModelEndpoint.openAI };
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.openAI);
});
});
describe('agents endpoint with provider from agentsMap', () => {
it('resolves to custom for agent with Moonshot provider', () => {
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' };
mockAgentsMap = {
'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial<Agent>,
};
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.custom);
});
it('resolves to custom for agent with custom provider with spaces', () => {
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' };
mockAgentsMap = {
'agent-1': { provider: 'Some Endpoint', model_parameters: {} } as Partial<Agent>,
};
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.custom);
});
it('resolves to openAI for agent with openAI provider', () => {
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' };
mockAgentsMap = {
'agent-1': { provider: EModelEndpoint.openAI, model_parameters: {} } as Partial<Agent>,
};
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.openAI);
});
it('resolves to anthropic for agent with anthropic provider', () => {
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' };
mockAgentsMap = {
'agent-1': { provider: EModelEndpoint.anthropic, model_parameters: {} } as Partial<Agent>,
};
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.anthropic);
});
});
describe('agents endpoint with provider from agentData query', () => {
it('uses agentData when agent is not in agentsMap', () => {
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-2' };
mockAgentsMap = {};
mockAgentQueryData = { provider: 'Moonshot' } as Partial<Agent>;
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.custom);
});
});
describe('agents endpoint without provider', () => {
it('falls back to agents when no agent_id', () => {
mockConversation = { endpoint: EModelEndpoint.agents };
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.agents);
});
it('falls back to agents when agent has no provider', () => {
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' };
mockAgentsMap = { 'agent-1': { model_parameters: {} } as Partial<Agent> };
const { result } = renderHook(() => useDragDropContext(), { wrapper });
expect(result.current.endpointType).toBe(EModelEndpoint.agents);
});
});
describe('consistency: same endpoint type whether used directly or through agents', () => {
it('Moonshot resolves to the same type as direct endpoint and as agent provider', () => {
mockConversation = { endpoint: 'Moonshot' };
const { result: directResult } = renderHook(() => useDragDropContext(), { wrapper });
mockConversation = { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' };
mockAgentsMap = {
'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial<Agent>,
};
const { result: agentResult } = renderHook(() => useDragDropContext(), { wrapper });
expect(directResult.current.endpointType).toBe(agentResult.current.endpointType);
});
});
});

View file

@ -56,13 +56,10 @@ const LiveAnnouncer: React.FC<LiveAnnouncerProps> = ({ children }) => {
const announceAssertive = announcePolite; const announceAssertive = announcePolite;
const contextValue = useMemo( const contextValue = {
() => ({
announcePolite, announcePolite,
announceAssertive, announceAssertive,
}), };
[announcePolite, announceAssertive],
);
useEffect(() => { useEffect(() => {
return () => { return () => {

View file

@ -1,75 +1,37 @@
import React, { useMemo, useState, useEffect, useRef, useCallback } from 'react'; import React, { useMemo, useState, useEffect, useRef, memo } from 'react';
import debounce from 'lodash/debounce'; import debounce from 'lodash/debounce';
import MonacoEditor from '@monaco-editor/react'; import { KeyBinding } from '@codemirror/view';
import type { Monaco } from '@monaco-editor/react'; import { autocompletion, completionKeymap } from '@codemirror/autocomplete';
import type { editor } from 'monaco-editor'; import {
import type { Artifact } from '~/common'; useSandpack,
SandpackCodeEditor,
SandpackProvider as StyledProvider,
} from '@codesandbox/sandpack-react';
import type { SandpackProviderProps } from '@codesandbox/sandpack-react/unstyled';
import type { SandpackBundlerFile } from '@codesandbox/sandpack-client';
import type { CodeEditorRef } from '@codesandbox/sandpack-react';
import type { ArtifactFiles, Artifact } from '~/common';
import { useEditArtifact, useGetStartupConfig } from '~/data-provider';
import { useMutationState, useCodeState } from '~/Providers/EditorContext'; import { useMutationState, useCodeState } from '~/Providers/EditorContext';
import { useArtifactsContext } from '~/Providers'; import { useArtifactsContext } from '~/Providers';
import { useEditArtifact } from '~/data-provider'; import { sharedFiles, sharedOptions } from '~/utils/artifacts';
const LANG_MAP: Record<string, string> = { const CodeEditor = memo(
javascript: 'javascript', ({
typescript: 'typescript', fileKey,
python: 'python', readOnly,
css: 'css',
json: 'json',
markdown: 'markdown',
html: 'html',
xml: 'xml',
sql: 'sql',
yaml: 'yaml',
shell: 'shell',
bash: 'shell',
tsx: 'typescript',
jsx: 'javascript',
c: 'c',
cpp: 'cpp',
java: 'java',
go: 'go',
rust: 'rust',
kotlin: 'kotlin',
swift: 'swift',
php: 'php',
ruby: 'ruby',
r: 'r',
lua: 'lua',
scala: 'scala',
perl: 'perl',
};
const TYPE_MAP: Record<string, string> = {
'text/html': 'html',
'application/vnd.code-html': 'html',
'application/vnd.react': 'typescript',
'application/vnd.ant.react': 'typescript',
'text/markdown': 'markdown',
'text/md': 'markdown',
'text/plain': 'plaintext',
'application/vnd.mermaid': 'markdown',
};
function getMonacoLanguage(type?: string, language?: string): string {
if (language && LANG_MAP[language]) {
return LANG_MAP[language];
}
return TYPE_MAP[type ?? ''] ?? 'plaintext';
}
export const ArtifactCodeEditor = function ArtifactCodeEditor({
artifact, artifact,
monacoRef, editorRef,
readOnly: externalReadOnly,
}: { }: {
artifact: Artifact; fileKey: string;
monacoRef: React.MutableRefObject<editor.IStandaloneCodeEditor | null>;
readOnly?: boolean; readOnly?: boolean;
}) { artifact: Artifact;
const { isSubmitting } = useArtifactsContext(); editorRef: React.MutableRefObject<CodeEditorRef>;
const readOnly = (externalReadOnly ?? false) || isSubmitting; }) => {
const { setCurrentCode } = useCodeState(); const { sandpack } = useSandpack();
const [currentUpdate, setCurrentUpdate] = useState<string | null>(null); const [currentUpdate, setCurrentUpdate] = useState<string | null>(null);
const { isMutating, setIsMutating } = useMutationState(); const { isMutating, setIsMutating } = useMutationState();
const { setCurrentCode } = useCodeState();
const editArtifact = useEditArtifact({ const editArtifact = useEditArtifact({
onMutate: (vars) => { onMutate: (vars) => {
setIsMutating(true); setIsMutating(true);
@ -84,38 +46,68 @@ export const ArtifactCodeEditor = function ArtifactCodeEditor({
}, },
}); });
/**
* Create stable debounced mutation that doesn't depend on changing callbacks
* Use refs to always access the latest values without recreating the debounce
*/
const artifactRef = useRef(artifact); const artifactRef = useRef(artifact);
const isMutatingRef = useRef(isMutating); const isMutatingRef = useRef(isMutating);
const currentUpdateRef = useRef(currentUpdate); const currentUpdateRef = useRef(currentUpdate);
const editArtifactRef = useRef(editArtifact); const editArtifactRef = useRef(editArtifact);
const setCurrentCodeRef = useRef(setCurrentCode); const setCurrentCodeRef = useRef(setCurrentCode);
const prevContentRef = useRef(artifact.content ?? '');
const prevArtifactId = useRef(artifact.id);
const prevReadOnly = useRef(readOnly);
useEffect(() => {
artifactRef.current = artifact; artifactRef.current = artifact;
isMutatingRef.current = isMutating; }, [artifact]);
currentUpdateRef.current = currentUpdate;
editArtifactRef.current = editArtifact;
setCurrentCodeRef.current = setCurrentCode;
useEffect(() => {
isMutatingRef.current = isMutating;
}, [isMutating]);
useEffect(() => {
currentUpdateRef.current = currentUpdate;
}, [currentUpdate]);
useEffect(() => {
editArtifactRef.current = editArtifact;
}, [editArtifact]);
useEffect(() => {
setCurrentCodeRef.current = setCurrentCode;
}, [setCurrentCode]);
/**
* Create debounced mutation once - never recreate it
* All values are accessed via refs so they're always current
*/
const debouncedMutation = useMemo( const debouncedMutation = useMemo(
() => () =>
debounce((code: string) => { debounce((code: string) => {
if (readOnly || isMutatingRef.current || artifactRef.current.index == null) { if (readOnly) {
return;
}
if (isMutatingRef.current) {
return;
}
if (artifactRef.current.index == null) {
return; return;
} }
const art = artifactRef.current;
const isNotOriginal = art.content != null && code.trim() !== art.content.trim();
const isNotRepeated =
currentUpdateRef.current == null ? true : code.trim() !== currentUpdateRef.current.trim();
if (art.content != null && isNotOriginal && isNotRepeated && art.index != null) { const artifact = artifactRef.current;
const artifactIndex = artifact.index;
const isNotOriginal =
code && artifact.content != null && code.trim() !== artifact.content.trim();
const isNotRepeated =
currentUpdateRef.current == null
? true
: code != null && code.trim() !== currentUpdateRef.current.trim();
if (artifact.content && isNotOriginal && isNotRepeated && artifactIndex != null) {
setCurrentCodeRef.current(code); setCurrentCodeRef.current(code);
editArtifactRef.current.mutate({ editArtifactRef.current.mutate({
index: art.index, index: artifactIndex,
messageId: art.messageId ?? '', messageId: artifact.messageId ?? '',
original: art.content, original: artifact.content,
updated: code, updated: code,
}); });
} }
@ -123,204 +115,92 @@ export const ArtifactCodeEditor = function ArtifactCodeEditor({
[readOnly], [readOnly],
); );
/**
* Listen to Sandpack file changes and trigger debounced mutation
*/
useEffect(() => { useEffect(() => {
return () => debouncedMutation.cancel(); const currentCode = (sandpack.files['/' + fileKey] as SandpackBundlerFile | undefined)?.code;
if (currentCode) {
debouncedMutation(currentCode);
}
}, [sandpack.files, fileKey, debouncedMutation]);
/**
* Cleanup: cancel pending mutations when component unmounts or artifact changes
*/
useEffect(() => {
return () => {
debouncedMutation.cancel();
};
}, [artifact.id, debouncedMutation]); }, [artifact.id, debouncedMutation]);
/** return (
* Streaming: use model.applyEdits() to append new content. <SandpackCodeEditor
* Unlike setValue/pushEditOperations, applyEdits preserves existing ref={editorRef}
* tokens so syntax highlighting doesn't flash during updates. showTabs={false}
*/ showRunButton={false}
useEffect(() => { showLineNumbers={true}
const ed = monacoRef.current; showInlineErrors={true}
if (!ed || !readOnly) { readOnly={readOnly === true}
return; extensions={[autocompletion()]}
} extensionsKeymap={Array.from<KeyBinding>(completionKeymap)}
const newContent = artifact.content ?? ''; className="hljs language-javascript bg-black"
const prev = prevContentRef.current; />
);
if (newContent === prev) {
return;
}
const model = ed.getModel();
if (!model) {
return;
}
if (newContent.startsWith(prev) && prev.length > 0) {
const appended = newContent.slice(prev.length);
const endPos = model.getPositionAt(model.getValueLength());
model.applyEdits([
{
range: {
startLineNumber: endPos.lineNumber,
startColumn: endPos.column,
endLineNumber: endPos.lineNumber,
endColumn: endPos.column,
}, },
text: appended,
},
]);
} else {
model.setValue(newContent);
}
prevContentRef.current = newContent;
ed.revealLine(model.getLineCount());
}, [artifact.content, readOnly, monacoRef]);
useEffect(() => {
if (artifact.id === prevArtifactId.current) {
return;
}
prevArtifactId.current = artifact.id;
prevContentRef.current = artifact.content ?? '';
const ed = monacoRef.current;
if (ed && artifact.content != null) {
ed.getModel()?.setValue(artifact.content);
}
}, [artifact.id, artifact.content, monacoRef]);
useEffect(() => {
if (prevReadOnly.current && !readOnly && artifact.content != null) {
const ed = monacoRef.current;
if (ed) {
ed.getModel()?.setValue(artifact.content);
prevContentRef.current = artifact.content;
}
}
prevReadOnly.current = readOnly;
}, [readOnly, artifact.content, monacoRef]);
const handleChange = useCallback(
(value: string | undefined) => {
if (value === undefined || readOnly) {
return;
}
prevContentRef.current = value;
setCurrentCode(value);
if (value.length > 0) {
debouncedMutation(value);
}
},
[readOnly, debouncedMutation, setCurrentCode],
); );
/** export const ArtifactCodeEditor = function ({
* Disable all validation this is an artifact viewer/editor, not an IDE. files,
* Note: these are global Monaco settings that affect all editor instances on the page. fileKey,
* The `as unknown` cast is required because monaco-editor v0.55 types `.languages.typescript` template,
* as `{ deprecated: true }` while the runtime API is fully functional. artifact,
*/ editorRef,
const handleBeforeMount = useCallback((monaco: Monaco) => { sharedProps,
const { typescriptDefaults, javascriptDefaults, JsxEmit } = monaco.languages readOnly: externalReadOnly,
.typescript as unknown as { }: {
typescriptDefaults: { fileKey: string;
setDiagnosticsOptions: (o: { artifact: Artifact;
noSemanticValidation: boolean; files: ArtifactFiles;
noSyntaxValidation: boolean; template: SandpackProviderProps['template'];
}) => void; sharedProps: Partial<SandpackProviderProps>;
setCompilerOptions: (o: { editorRef: React.MutableRefObject<CodeEditorRef>;
allowNonTsExtensions: boolean; readOnly?: boolean;
allowJs: boolean; }) {
jsx: number; const { data: config } = useGetStartupConfig();
}) => void; const { isSubmitting } = useArtifactsContext();
}; const options: typeof sharedOptions = useMemo(() => {
javascriptDefaults: { if (!config) {
setDiagnosticsOptions: (o: { return sharedOptions;
noSemanticValidation: boolean;
noSyntaxValidation: boolean;
}) => void;
setCompilerOptions: (o: {
allowNonTsExtensions: boolean;
allowJs: boolean;
jsx: number;
}) => void;
};
JsxEmit: { React: number };
};
const diagnosticsOff = { noSemanticValidation: true, noSyntaxValidation: true };
const compilerBase = { allowNonTsExtensions: true, allowJs: true, jsx: JsxEmit.React };
typescriptDefaults.setDiagnosticsOptions(diagnosticsOff);
javascriptDefaults.setDiagnosticsOptions(diagnosticsOff);
typescriptDefaults.setCompilerOptions(compilerBase);
javascriptDefaults.setCompilerOptions(compilerBase);
}, []);
const handleMount = useCallback(
(ed: editor.IStandaloneCodeEditor) => {
monacoRef.current = ed;
prevContentRef.current = ed.getModel()?.getValue() ?? artifact.content ?? '';
if (readOnly) {
const model = ed.getModel();
if (model) {
ed.revealLine(model.getLineCount());
} }
} return {
}, ...sharedOptions,
// eslint-disable-next-line react-hooks/exhaustive-deps activeFile: '/' + fileKey,
[monacoRef], bundlerURL: template === 'static' ? config.staticBundlerURL : config.bundlerURL,
); };
}, [config, template, fileKey]);
const initialReadOnly = (externalReadOnly ?? false) || (isSubmitting ?? false);
const [readOnly, setReadOnly] = useState(initialReadOnly);
useEffect(() => {
setReadOnly((externalReadOnly ?? false) || (isSubmitting ?? false));
}, [isSubmitting, externalReadOnly]);
const language = getMonacoLanguage(artifact.type, artifact.language); if (Object.keys(files).length === 0) {
const editorOptions = useMemo<editor.IStandaloneEditorConstructionOptions>(
() => ({
readOnly,
minimap: { enabled: false },
lineNumbers: 'on',
scrollBeyondLastLine: false,
fontSize: 13,
tabSize: 2,
wordWrap: 'on',
automaticLayout: true,
padding: { top: 8 },
renderLineHighlight: readOnly ? 'none' : 'line',
cursorStyle: readOnly ? 'underline-thin' : 'line',
scrollbar: {
vertical: 'visible',
horizontal: 'auto',
verticalScrollbarSize: 8,
horizontalScrollbarSize: 8,
useShadows: false,
alwaysConsumeMouseWheel: false,
},
overviewRulerLanes: 0,
hideCursorInOverviewRuler: true,
overviewRulerBorder: false,
folding: false,
glyphMargin: false,
colorDecorators: !readOnly,
occurrencesHighlight: readOnly ? 'off' : 'singleFile',
selectionHighlight: !readOnly,
renderValidationDecorations: readOnly ? 'off' : 'editable',
quickSuggestions: !readOnly,
suggestOnTriggerCharacters: !readOnly,
parameterHints: { enabled: !readOnly },
hover: { enabled: !readOnly },
matchBrackets: readOnly ? 'never' : 'always',
}),
[readOnly],
);
if (!artifact.content) {
return null; return null;
} }
return ( return (
<div className="h-full w-full bg-[#1e1e1e]"> <StyledProvider
<MonacoEditor theme="dark"
height="100%" files={{
language={readOnly ? 'plaintext' : language} ...files,
theme="vs-dark" ...sharedFiles,
defaultValue={artifact.content} }}
onChange={handleChange} options={options}
beforeMount={handleBeforeMount} {...sharedProps}
onMount={handleMount} template={template}
options={editorOptions} >
/> <CodeEditor fileKey={fileKey} artifact={artifact} editorRef={editorRef} readOnly={readOnly} />
</div> </StyledProvider>
); );
}; };

View file

@ -1,26 +1,30 @@
import { useRef, useEffect } from 'react'; import { useRef, useEffect } from 'react';
import * as Tabs from '@radix-ui/react-tabs'; import * as Tabs from '@radix-ui/react-tabs';
import type { SandpackPreviewRef } from '@codesandbox/sandpack-react/unstyled'; import type { SandpackPreviewRef } from '@codesandbox/sandpack-react/unstyled';
import type { editor } from 'monaco-editor'; import type { CodeEditorRef } from '@codesandbox/sandpack-react';
import type { Artifact } from '~/common'; import type { Artifact } from '~/common';
import { useCodeState } from '~/Providers/EditorContext'; import { useCodeState } from '~/Providers/EditorContext';
import { useArtifactsContext } from '~/Providers';
import useArtifactProps from '~/hooks/Artifacts/useArtifactProps'; import useArtifactProps from '~/hooks/Artifacts/useArtifactProps';
import { useAutoScroll } from '~/hooks/Artifacts/useAutoScroll';
import { ArtifactCodeEditor } from './ArtifactCodeEditor'; import { ArtifactCodeEditor } from './ArtifactCodeEditor';
import { useGetStartupConfig } from '~/data-provider'; import { useGetStartupConfig } from '~/data-provider';
import { ArtifactPreview } from './ArtifactPreview'; import { ArtifactPreview } from './ArtifactPreview';
export default function ArtifactTabs({ export default function ArtifactTabs({
artifact, artifact,
editorRef,
previewRef, previewRef,
isSharedConvo, isSharedConvo,
}: { }: {
artifact: Artifact; artifact: Artifact;
editorRef: React.MutableRefObject<CodeEditorRef>;
previewRef: React.MutableRefObject<SandpackPreviewRef>; previewRef: React.MutableRefObject<SandpackPreviewRef>;
isSharedConvo?: boolean; isSharedConvo?: boolean;
}) { }) {
const { isSubmitting } = useArtifactsContext();
const { currentCode, setCurrentCode } = useCodeState(); const { currentCode, setCurrentCode } = useCodeState();
const { data: startupConfig } = useGetStartupConfig(); const { data: startupConfig } = useGetStartupConfig();
const monacoRef = useRef<editor.IStandaloneCodeEditor | null>(null);
const lastIdRef = useRef<string | null>(null); const lastIdRef = useRef<string | null>(null);
useEffect(() => { useEffect(() => {
@ -30,24 +34,33 @@ export default function ArtifactTabs({
lastIdRef.current = artifact.id; lastIdRef.current = artifact.id;
}, [setCurrentCode, artifact.id]); }, [setCurrentCode, artifact.id]);
const content = artifact.content ?? '';
const contentRef = useRef<HTMLDivElement>(null);
useAutoScroll({ ref: contentRef, content, isSubmitting });
const { files, fileKey, template, sharedProps } = useArtifactProps({ artifact }); const { files, fileKey, template, sharedProps } = useArtifactProps({ artifact });
return ( return (
<div className="flex h-full w-full flex-col"> <div className="flex h-full w-full flex-col">
<Tabs.Content <Tabs.Content
ref={contentRef}
value="code" value="code"
id="artifacts-code" id="artifacts-code"
className="h-full w-full flex-grow overflow-auto" className="h-full w-full flex-grow overflow-auto"
tabIndex={-1} tabIndex={-1}
> >
<ArtifactCodeEditor artifact={artifact} monacoRef={monacoRef} readOnly={isSharedConvo} /> <ArtifactCodeEditor
files={files}
fileKey={fileKey}
template={template}
artifact={artifact}
editorRef={editorRef}
sharedProps={sharedProps}
readOnly={isSharedConvo}
/>
</Tabs.Content> </Tabs.Content>
<Tabs.Content <Tabs.Content value="preview" className="h-full w-full flex-grow overflow-auto" tabIndex={-1}>
value="preview"
className="h-full w-full flex-grow overflow-hidden"
tabIndex={-1}
>
<ArtifactPreview <ArtifactPreview
files={files} files={files}
fileKey={fileKey} fileKey={fileKey}

View file

@ -3,7 +3,7 @@ import * as Tabs from '@radix-ui/react-tabs';
import { Code, Play, RefreshCw, X } from 'lucide-react'; import { Code, Play, RefreshCw, X } from 'lucide-react';
import { useSetRecoilState, useResetRecoilState } from 'recoil'; import { useSetRecoilState, useResetRecoilState } from 'recoil';
import { Button, Spinner, useMediaQuery, Radio } from '@librechat/client'; import { Button, Spinner, useMediaQuery, Radio } from '@librechat/client';
import type { SandpackPreviewRef } from '@codesandbox/sandpack-react'; import type { SandpackPreviewRef, CodeEditorRef } from '@codesandbox/sandpack-react';
import { useShareContext, useMutationState } from '~/Providers'; import { useShareContext, useMutationState } from '~/Providers';
import useArtifacts from '~/hooks/Artifacts/useArtifacts'; import useArtifacts from '~/hooks/Artifacts/useArtifacts';
import DownloadArtifact from './DownloadArtifact'; import DownloadArtifact from './DownloadArtifact';
@ -22,6 +22,7 @@ export default function Artifacts() {
const { isMutating } = useMutationState(); const { isMutating } = useMutationState();
const { isSharedConvo } = useShareContext(); const { isSharedConvo } = useShareContext();
const isMobile = useMediaQuery('(max-width: 868px)'); const isMobile = useMediaQuery('(max-width: 868px)');
const editorRef = useRef<CodeEditorRef>();
const previewRef = useRef<SandpackPreviewRef>(); const previewRef = useRef<SandpackPreviewRef>();
const [isVisible, setIsVisible] = useState(false); const [isVisible, setIsVisible] = useState(false);
const [isClosing, setIsClosing] = useState(false); const [isClosing, setIsClosing] = useState(false);
@ -296,6 +297,7 @@ export default function Artifacts() {
<div className="absolute inset-0 flex flex-col"> <div className="absolute inset-0 flex flex-col">
<ArtifactTabs <ArtifactTabs
artifact={currentArtifact} artifact={currentArtifact}
editorRef={editorRef as React.MutableRefObject<CodeEditorRef>}
previewRef={previewRef as React.MutableRefObject<SandpackPreviewRef>} previewRef={previewRef as React.MutableRefObject<SandpackPreviewRef>}
isSharedConvo={isSharedConvo} isSharedConvo={isSharedConvo}
/> />

View file

@ -1,8 +1,11 @@
import React, { memo, useState } from 'react'; import React, { memo, useEffect, useRef, useState } from 'react';
import copy from 'copy-to-clipboard'; import copy from 'copy-to-clipboard';
import rehypeKatex from 'rehype-katex';
import ReactMarkdown from 'react-markdown';
import { Button } from '@librechat/client'; import { Button } from '@librechat/client';
import rehypeHighlight from 'rehype-highlight';
import { Copy, CircleCheckBig } from 'lucide-react'; import { Copy, CircleCheckBig } from 'lucide-react';
import { handleDoubleClick } from '~/utils'; import { handleDoubleClick, langSubset } from '~/utils';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
type TCodeProps = { type TCodeProps = {
@ -26,6 +29,74 @@ export const code: React.ElementType = memo(({ inline, className, children }: TC
return <code className={`hljs language-${lang} !whitespace-pre`}>{children}</code>; return <code className={`hljs language-${lang} !whitespace-pre`}>{children}</code>;
}); });
export const CodeMarkdown = memo(
({ content = '', isSubmitting }: { content: string; isSubmitting: boolean }) => {
const scrollRef = useRef<HTMLDivElement>(null);
const [userScrolled, setUserScrolled] = useState(false);
const currentContent = content;
const rehypePlugins = [
[rehypeKatex],
[
rehypeHighlight,
{
detect: true,
ignoreMissing: true,
subset: langSubset,
},
],
];
useEffect(() => {
const scrollContainer = scrollRef.current;
if (!scrollContainer) {
return;
}
const handleScroll = () => {
const { scrollTop, scrollHeight, clientHeight } = scrollContainer;
const isNearBottom = scrollHeight - scrollTop - clientHeight < 50;
if (!isNearBottom) {
setUserScrolled(true);
} else {
setUserScrolled(false);
}
};
scrollContainer.addEventListener('scroll', handleScroll);
return () => {
scrollContainer.removeEventListener('scroll', handleScroll);
};
}, []);
useEffect(() => {
const scrollContainer = scrollRef.current;
if (!scrollContainer || !isSubmitting || userScrolled) {
return;
}
scrollContainer.scrollTop = scrollContainer.scrollHeight;
}, [content, isSubmitting, userScrolled]);
return (
<div ref={scrollRef} className="max-h-full overflow-y-auto">
<ReactMarkdown
/* @ts-ignore */
rehypePlugins={rehypePlugins}
components={
{ code } as {
[key: string]: React.ElementType;
}
}
>
{currentContent}
</ReactMarkdown>
</div>
);
},
);
export const CopyCodeButton: React.FC<{ content: string }> = ({ content }) => { export const CopyCodeButton: React.FC<{ content: string }> = ({ content }) => {
const localize = useLocalize(); const localize = useLocalize();
const [isCopied, setIsCopied] = useState(false); const [isCopied, setIsCopied] = useState(false);

View file

@ -1,21 +1,17 @@
import { useCallback } from 'react';
import { useSetRecoilState, useRecoilValue } from 'recoil';
import { PlusCircle } from 'lucide-react'; import { PlusCircle } from 'lucide-react';
import { TooltipAnchor } from '@librechat/client'; import { TooltipAnchor } from '@librechat/client';
import { isAssistantsEndpoint } from 'librechat-data-provider'; import { isAssistantsEndpoint } from 'librechat-data-provider';
import type { TConversation } from 'librechat-data-provider'; import type { TConversation } from 'librechat-data-provider';
import { useGetConversation, useLocalize } from '~/hooks'; import { useChatContext, useAddedChatContext } from '~/Providers';
import { mainTextareaId } from '~/common'; import { mainTextareaId } from '~/common';
import store from '~/store'; import { useLocalize } from '~/hooks';
function AddMultiConvo() { function AddMultiConvo() {
const { conversation } = useChatContext();
const { setConversation: setAddedConvo } = useAddedChatContext();
const localize = useLocalize(); const localize = useLocalize();
const getConversation = useGetConversation(0);
const endpoint = useRecoilValue(store.conversationEndpointByIndex(0));
const setAddedConvo = useSetRecoilState(store.conversationByIndex(1));
const clickHandler = useCallback(() => { const clickHandler = () => {
const conversation = getConversation();
const { title: _t, ...convo } = conversation ?? ({} as TConversation); const { title: _t, ...convo } = conversation ?? ({} as TConversation);
setAddedConvo({ setAddedConvo({
...convo, ...convo,
@ -26,13 +22,13 @@ function AddMultiConvo() {
if (textarea) { if (textarea) {
textarea.focus(); textarea.focus();
} }
}, [getConversation, setAddedConvo]); };
if (!endpoint) { if (!conversation) {
return null; return null;
} }
if (isAssistantsEndpoint(endpoint)) { if (isAssistantsEndpoint(conversation.endpoint)) {
return null; return null;
} }

View file

@ -1,11 +1,11 @@
import React, { useEffect, memo } from 'react'; import React, { useEffect } from 'react';
import TagManager from 'react-gtm-module';
import ReactMarkdown from 'react-markdown'; import ReactMarkdown from 'react-markdown';
import TagManager from 'react-gtm-module';
import { Constants } from 'librechat-data-provider'; import { Constants } from 'librechat-data-provider';
import { useGetStartupConfig } from '~/data-provider'; import { useGetStartupConfig } from '~/data-provider';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
function Footer({ className }: { className?: string }) { export default function Footer({ className }: { className?: string }) {
const { data: config } = useGetStartupConfig(); const { data: config } = useGetStartupConfig();
const localize = useLocalize(); const localize = useLocalize();
@ -98,8 +98,3 @@ function Footer({ className }: { className?: string }) {
</div> </div>
); );
} }
const MemoizedFooter = memo(Footer);
MemoizedFooter.displayName = 'Footer';
export default MemoizedFooter;

View file

@ -1,4 +1,4 @@
import { memo, useMemo } from 'react'; import { useMemo } from 'react';
import { useMediaQuery } from '@librechat/client'; import { useMediaQuery } from '@librechat/client';
import { useOutletContext } from 'react-router-dom'; import { useOutletContext } from 'react-router-dom';
import { AnimatePresence, motion } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion';
@ -16,7 +16,7 @@ import { cn } from '~/utils';
const defaultInterface = getConfigDefaults().interface; const defaultInterface = getConfigDefaults().interface;
function Header() { export default function Header() {
const { data: startupConfig } = useGetStartupConfig(); const { data: startupConfig } = useGetStartupConfig();
const { navVisible, setNavVisible } = useOutletContext<ContextType>(); const { navVisible, setNavVisible } = useOutletContext<ContextType>();
@ -35,11 +35,6 @@ function Header() {
permission: Permissions.USE, permission: Permissions.USE,
}); });
const hasAccessToTemporaryChat = useHasAccess({
permissionType: PermissionTypes.TEMPORARY_CHAT,
permission: Permissions.USE,
});
const isSmallScreen = useMediaQuery('(max-width: 768px)'); const isSmallScreen = useMediaQuery('(max-width: 768px)');
return ( return (
@ -78,7 +73,7 @@ function Header() {
<ExportAndShareMenu <ExportAndShareMenu
isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false} isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false}
/> />
{hasAccessToTemporaryChat === true && <TemporaryChat />} <TemporaryChat />
</> </>
)} )}
</div> </div>
@ -90,7 +85,7 @@ function Header() {
<ExportAndShareMenu <ExportAndShareMenu
isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false} isSharedButtonEnabled={startupConfig?.sharedLinksEnabled ?? false}
/> />
{hasAccessToTemporaryChat === true && <TemporaryChat />} <TemporaryChat />
</div> </div>
)} )}
</div> </div>
@ -99,8 +94,3 @@ function Header() {
</div> </div>
); );
} }
const MemoizedHeader = memo(Header);
MemoizedHeader.displayName = 'Header';
export default MemoizedHeader;

View file

@ -219,6 +219,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
<div className={cn('flex w-full items-center', isRTL && 'flex-row-reverse')}> <div className={cn('flex w-full items-center', isRTL && 'flex-row-reverse')}>
{showPlusPopover && !isAssistantsEndpoint(endpoint) && ( {showPlusPopover && !isAssistantsEndpoint(endpoint) && (
<Mention <Mention
conversation={conversation}
setShowMentionPopover={setShowPlusPopover} setShowMentionPopover={setShowPlusPopover}
newConversation={generateConversation} newConversation={generateConversation}
textAreaRef={textAreaRef} textAreaRef={textAreaRef}
@ -229,6 +230,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
)} )}
{showMentionPopover && ( {showMentionPopover && (
<Mention <Mention
conversation={conversation}
setShowMentionPopover={setShowMentionPopover} setShowMentionPopover={setShowMentionPopover}
newConversation={newConversation} newConversation={newConversation}
textAreaRef={textAreaRef} textAreaRef={textAreaRef}

View file

@ -2,9 +2,10 @@ import { memo, useMemo } from 'react';
import { import {
Constants, Constants,
supportsFiles, supportsFiles,
EModelEndpoint,
mergeFileConfig, mergeFileConfig,
isAgentsEndpoint, isAgentsEndpoint,
resolveEndpointType, getEndpointField,
isAssistantsEndpoint, isAssistantsEndpoint,
getEndpointFileConfig, getEndpointFileConfig,
} from 'librechat-data-provider'; } from 'librechat-data-provider';
@ -54,31 +55,21 @@ function AttachFileChat({
const { data: endpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig } = useGetEndpointsQuery();
const agentProvider = useMemo(() => { const endpointType = useMemo(() => {
if (!isAgents || !conversation?.agent_id) { return (
return undefined; getEndpointField(endpointsConfig, endpoint, 'type') ||
} (endpoint as EModelEndpoint | undefined)
const agent = agentData || agentsMap?.[conversation.agent_id];
return agent?.provider;
}, [isAgents, conversation?.agent_id, agentData, agentsMap]);
const endpointType = useMemo(
() => resolveEndpointType(endpointsConfig, endpoint, agentProvider),
[endpointsConfig, endpoint, agentProvider],
); );
}, [endpoint, endpointsConfig]);
const fileConfigEndpoint = useMemo(
() => (isAgents && agentProvider ? agentProvider : endpoint),
[isAgents, agentProvider, endpoint],
);
const endpointFileConfig = useMemo( const endpointFileConfig = useMemo(
() => () =>
getEndpointFileConfig({ getEndpointFileConfig({
endpoint,
fileConfig, fileConfig,
endpointType, endpointType,
endpoint: fileConfigEndpoint,
}), }),
[fileConfigEndpoint, fileConfig, endpointType], [endpoint, fileConfig, endpointType],
); );
const endpointSupportsFiles: boolean = useMemo( const endpointSupportsFiles: boolean = useMemo(
() => supportsFiles[endpointType ?? endpoint ?? ''] ?? false, () => supportsFiles[endpointType ?? endpoint ?? ''] ?? false,
@ -91,7 +82,7 @@ function AttachFileChat({
if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { if (isAssistants && endpointSupportsFiles && !isUploadDisabled) {
return <AttachFile disabled={disableInputs} />; return <AttachFile disabled={disableInputs} />;
} else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) { } else if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) {
return ( return (
<AttachFileMenu <AttachFileMenu
endpoint={endpoint} endpoint={endpoint}

View file

@ -50,7 +50,7 @@ interface AttachFileMenuProps {
endpoint?: string | null; endpoint?: string | null;
disabled?: boolean | null; disabled?: boolean | null;
conversationId: string; conversationId: string;
endpointType?: EModelEndpoint | string; endpointType?: EModelEndpoint;
endpointFileConfig?: EndpointFileConfig; endpointFileConfig?: EndpointFileConfig;
useResponsesApi?: boolean; useResponsesApi?: boolean;
} }

View file

@ -3,10 +3,10 @@ import { useToastContext } from '@librechat/client';
import { EToolResources } from 'librechat-data-provider'; import { EToolResources } from 'librechat-data-provider';
import type { ExtendedFile } from '~/common'; import type { ExtendedFile } from '~/common';
import { useDeleteFilesMutation } from '~/data-provider'; import { useDeleteFilesMutation } from '~/data-provider';
import { logger, getCachedPreview } from '~/utils';
import { useFileDeletion } from '~/hooks/Files'; import { useFileDeletion } from '~/hooks/Files';
import FileContainer from './FileContainer'; import FileContainer from './FileContainer';
import { useLocalize } from '~/hooks'; import { useLocalize } from '~/hooks';
import { logger } from '~/utils';
import Image from './Image'; import Image from './Image';
export default function FileRow({ export default function FileRow({
@ -24,7 +24,7 @@ export default function FileRow({
files: Map<string, ExtendedFile> | undefined; files: Map<string, ExtendedFile> | undefined;
abortUpload?: () => void; abortUpload?: () => void;
setFiles: React.Dispatch<React.SetStateAction<Map<string, ExtendedFile>>>; setFiles: React.Dispatch<React.SetStateAction<Map<string, ExtendedFile>>>;
setFilesLoading?: React.Dispatch<React.SetStateAction<boolean>>; setFilesLoading: React.Dispatch<React.SetStateAction<boolean>>;
fileFilter?: (file: ExtendedFile) => boolean; fileFilter?: (file: ExtendedFile) => boolean;
assistant_id?: string; assistant_id?: string;
agent_id?: string; agent_id?: string;
@ -58,7 +58,6 @@ export default function FileRow({
const { deleteFile } = useFileDeletion({ mutateAsync, agent_id, assistant_id, tool_resource }); const { deleteFile } = useFileDeletion({ mutateAsync, agent_id, assistant_id, tool_resource });
useEffect(() => { useEffect(() => {
if (!setFilesLoading) return;
if (files.length === 0) { if (files.length === 0) {
setFilesLoading(false); setFilesLoading(false);
return; return;
@ -112,14 +111,12 @@ export default function FileRow({
) )
.uniqueFiles.map((file: ExtendedFile, index: number) => { .uniqueFiles.map((file: ExtendedFile, index: number) => {
const handleDelete = () => { const handleDelete = () => {
if (abortUpload && file.progress < 1) {
abortUpload();
}
if (file.progress >= 1) {
showToast({ showToast({
message: localize('com_ui_deleting_file'), message: localize('com_ui_deleting_file'),
status: 'info', status: 'info',
}); });
if (abortUpload && file.progress < 1) {
abortUpload();
} }
deleteFile({ file, setFiles }); deleteFile({ file, setFiles });
}; };
@ -136,7 +133,7 @@ export default function FileRow({
> >
{isImage ? ( {isImage ? (
<Image <Image
url={getCachedPreview(file.file_id) ?? file.preview ?? file.filepath} url={file.progress === 1 ? file.filepath : (file.preview ?? file.filepath)}
onDelete={handleDelete} onDelete={handleDelete}
progress={file.progress} progress={file.progress}
source={file.source} source={file.source}

View file

@ -1,233 +0,0 @@
import React from 'react';
import { render, screen } from '@testing-library/react';
import { RecoilRoot } from 'recoil';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { EModelEndpoint, mergeFileConfig } from 'librechat-data-provider';
import type { TEndpointsConfig, Agent } from 'librechat-data-provider';
import AttachFileChat from '../AttachFileChat';
const mockEndpointsConfig: TEndpointsConfig = {
[EModelEndpoint.openAI]: { userProvide: false, order: 0 },
[EModelEndpoint.agents]: { userProvide: false, order: 1 },
[EModelEndpoint.assistants]: { userProvide: false, order: 2 },
Moonshot: { type: EModelEndpoint.custom, userProvide: false, order: 9999 },
};
const defaultFileConfig = mergeFileConfig({
endpoints: {
Moonshot: { fileLimit: 5 },
[EModelEndpoint.agents]: { fileLimit: 20 },
default: { fileLimit: 10 },
},
});
let mockFileConfig = defaultFileConfig;
let mockAgentsMap: Record<string, Partial<Agent>> = {};
let mockAgentQueryData: Partial<Agent> | undefined;
jest.mock('~/data-provider', () => ({
useGetEndpointsQuery: () => ({ data: mockEndpointsConfig }),
useGetFileConfig: ({ select }: { select?: (data: unknown) => unknown }) => ({
data: select != null ? select(mockFileConfig) : mockFileConfig,
}),
useGetAgentByIdQuery: () => ({ data: mockAgentQueryData }),
}));
jest.mock('~/Providers', () => ({
useAgentsMapContext: () => mockAgentsMap,
}));
/** Capture the props passed to AttachFileMenu */
let mockAttachFileMenuProps: Record<string, unknown> = {};
jest.mock('../AttachFileMenu', () => {
return function MockAttachFileMenu(props: Record<string, unknown>) {
mockAttachFileMenuProps = props;
return <div data-testid="attach-file-menu" data-endpoint-type={String(props.endpointType)} />;
};
});
jest.mock('../AttachFile', () => {
return function MockAttachFile() {
return <div data-testid="attach-file" />;
};
});
const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } });
function renderComponent(conversation: Record<string, unknown> | null, disableInputs = false) {
return render(
<QueryClientProvider client={queryClient}>
<RecoilRoot>
<AttachFileChat conversation={conversation as never} disableInputs={disableInputs} />
</RecoilRoot>
</QueryClientProvider>,
);
}
describe('AttachFileChat', () => {
beforeEach(() => {
mockFileConfig = defaultFileConfig;
mockAgentsMap = {};
mockAgentQueryData = undefined;
mockAttachFileMenuProps = {};
});
describe('rendering decisions', () => {
it('renders AttachFileMenu for agents endpoint', () => {
renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
expect(screen.getByTestId('attach-file-menu')).toBeInTheDocument();
});
it('renders AttachFileMenu for custom endpoint with file support', () => {
renderComponent({ endpoint: 'Moonshot' });
expect(screen.getByTestId('attach-file-menu')).toBeInTheDocument();
});
it('renders null for null conversation', () => {
const { container } = renderComponent(null);
expect(container.innerHTML).toBe('');
});
});
describe('endpointType resolution for agents', () => {
it('passes custom endpointType when agent provider is a custom endpoint', () => {
mockAgentsMap = {
'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial<Agent>,
};
renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.custom);
});
it('passes openAI endpointType when agent provider is openAI', () => {
mockAgentsMap = {
'agent-1': { provider: EModelEndpoint.openAI, model_parameters: {} } as Partial<Agent>,
};
renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.openAI);
});
it('passes agents endpointType when no agent provider', () => {
renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.agents);
});
it('passes agents endpointType when no agent_id', () => {
renderComponent({ endpoint: EModelEndpoint.agents });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.agents);
});
it('uses agentData query when agent not in agentsMap', () => {
mockAgentQueryData = { provider: 'Moonshot' } as Partial<Agent>;
renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-2' });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.custom);
});
});
describe('endpointType resolution for non-agents', () => {
it('passes custom endpointType for a custom endpoint', () => {
renderComponent({ endpoint: 'Moonshot' });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.custom);
});
it('passes openAI endpointType for openAI endpoint', () => {
renderComponent({ endpoint: EModelEndpoint.openAI });
expect(mockAttachFileMenuProps.endpointType).toBe(EModelEndpoint.openAI);
});
});
// Regression guard: a provider should resolve to the same endpointType whether
// it is selected directly as the endpoint or indirectly via an agent's
// `provider` field. The test renders twice and compares the captured props.
describe('consistency: same endpoint type for direct vs agent usage', () => {
  it('resolves Moonshot the same way whether used directly or through an agent', () => {
    // First render: 'Moonshot' used directly as the endpoint.
    renderComponent({ endpoint: 'Moonshot' });
    const directType = mockAttachFileMenuProps.endpointType;
    // Second render: same provider reached through an agent lookup.
    // NOTE(review): mockAttachFileMenuProps appears to capture the latest
    // render's props — confirm the capture mock overwrites rather than appends.
    mockAgentsMap = {
      'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial<Agent>,
    };
    renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
    const agentType = mockAttachFileMenuProps.endpointType;
    expect(directType).toBe(agentType);
  });
});
// Verifies the component renders nothing (empty container) when uploads are
// disabled — either via per-endpoint `fileConfig` flags or the `disableInputs`
// argument — and renders the attach-file UI when a more specific
// provider-level config re-enables uploads.
describe('upload disabled rendering', () => {
  it('renders null for agents endpoint when fileConfig.agents.disabled is true', () => {
    // Disable uploads for the agents endpoint only.
    mockFileConfig = mergeFileConfig({
      endpoints: {
        [EModelEndpoint.agents]: { disabled: true },
      },
    });
    const { container } = renderComponent({
      endpoint: EModelEndpoint.agents,
      agent_id: 'agent-1',
    });
    // Component should bail out entirely, leaving the container empty.
    expect(container.innerHTML).toBe('');
  });
  it('renders null for agents endpoint when disableInputs is true', () => {
    // Second argument to renderComponent is the disableInputs flag
    // (presumably — its signature is defined earlier in this file; verify).
    const { container } = renderComponent(
      { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' },
      true,
    );
    expect(container.innerHTML).toBe('');
  });
  it('renders AttachFile for assistants endpoint when not disabled', () => {
    renderComponent({ endpoint: EModelEndpoint.assistants });
    expect(screen.getByTestId('attach-file')).toBeInTheDocument();
  });
  it('renders AttachFileMenu when provider-specific config overrides agents disabled', () => {
    // Agents endpoint is globally disabled, but the agent's resolved provider
    // ('Moonshot') has its own config with disabled: false, which should win.
    mockFileConfig = mergeFileConfig({
      endpoints: {
        Moonshot: { disabled: false, fileLimit: 5 },
        [EModelEndpoint.agents]: { disabled: true },
      },
    });
    mockAgentsMap = {
      'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial<Agent>,
    };
    renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
    expect(screen.getByTestId('attach-file-menu')).toBeInTheDocument();
  });
  it('renders null for assistants endpoint when fileConfig.assistants.disabled is true', () => {
    mockFileConfig = mergeFileConfig({
      endpoints: {
        [EModelEndpoint.assistants]: { disabled: true },
      },
    });
    const { container } = renderComponent({
      endpoint: EModelEndpoint.assistants,
    });
    expect(container.innerHTML).toBe('');
  });
});
// Verifies which per-endpoint file config is forwarded to AttachFileMenu:
// the agent's resolved provider config when one exists, otherwise the agents
// endpoint config. The expected fileLimit values (5 / 10 / 20) come from the
// default mockFileConfig established earlier in this file — TODO confirm
// those defaults if this test starts failing.
describe('endpointFileConfig resolution', () => {
  it('passes Moonshot-specific file config for agent with Moonshot provider', () => {
    mockAgentsMap = {
      'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial<Agent>,
    };
    renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
    // Provider-specific config (Moonshot) should be selected.
    const config = mockAttachFileMenuProps.endpointFileConfig as { fileLimit?: number };
    expect(config?.fileLimit).toBe(5);
  });
  it('passes agents file config when agent has no specific provider config', () => {
    mockAgentsMap = {
      'agent-1': { provider: EModelEndpoint.openAI, model_parameters: {} } as Partial<Agent>,
    };
    renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' });
    // No dedicated config for this provider — falls back to the agents config.
    const config = mockAttachFileMenuProps.endpointFileConfig as { fileLimit?: number };
    expect(config?.fileLimit).toBe(10);
  });
  it('passes agents file config when no agent provider', () => {
    // No agent_id at all — the agents endpoint default config applies.
    renderComponent({ endpoint: EModelEndpoint.agents });
    const config = mockAttachFileMenuProps.endpointFileConfig as { fileLimit?: number };
    expect(config?.fileLimit).toBe(20);
  });
});
});

View file

@ -1,10 +1,12 @@
import React from 'react'; import React from 'react';
import { render, screen, fireEvent } from '@testing-library/react'; import { render, screen, fireEvent } from '@testing-library/react';
import '@testing-library/jest-dom';
import { RecoilRoot } from 'recoil'; import { RecoilRoot } from 'recoil';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { EModelEndpoint, Providers } from 'librechat-data-provider'; import { EModelEndpoint } from 'librechat-data-provider';
import AttachFileMenu from '../AttachFileMenu'; import AttachFileMenu from '../AttachFileMenu';
// Mock all the hooks
jest.mock('~/hooks', () => ({ jest.mock('~/hooks', () => ({
useAgentToolPermissions: jest.fn(), useAgentToolPermissions: jest.fn(),
useAgentCapabilities: jest.fn(), useAgentCapabilities: jest.fn(),
@ -23,45 +25,53 @@ jest.mock('~/data-provider', () => ({
})); }));
jest.mock('~/components/SharePoint', () => ({ jest.mock('~/components/SharePoint', () => ({
SharePointPickerDialog: () => null, SharePointPickerDialog: jest.fn(() => null),
})); }));
jest.mock('@librechat/client', () => { jest.mock('@librechat/client', () => {
// eslint-disable-next-line @typescript-eslint/no-require-imports const React = jest.requireActual('react');
const R = require('react');
return { return {
FileUpload: (props) => R.createElement('div', { 'data-testid': 'file-upload' }, props.children), FileUpload: React.forwardRef(({ children, handleFileChange }: any, ref: any) => (
TooltipAnchor: (props) => props.render, <div data-testid="file-upload">
DropdownPopup: (props) => <input ref={ref} type="file" onChange={handleFileChange} data-testid="file-input" />
R.createElement( {children}
'div', </div>
null, )),
R.createElement('div', { onClick: () => props.setIsOpen(!props.isOpen) }, props.trigger), TooltipAnchor: ({ render }: any) => render,
props.isOpen && DropdownPopup: ({ trigger, items, isOpen, setIsOpen }: any) => {
R.createElement( const handleTriggerClick = () => {
'div', if (setIsOpen) {
{ 'data-testid': 'dropdown-menu' }, setIsOpen(!isOpen);
props.items.map((item, idx) => }
R.createElement( };
'button',
{ key: idx, onClick: item.onClick, 'data-testid': `menu-item-${idx}` }, return (
item.label, <div>
), <div onClick={handleTriggerClick}>{trigger}</div>
), {isOpen && (
), <div data-testid="dropdown-menu">
), {items.map((item: any, idx: number) => (
AttachmentIcon: () => R.createElement('span', { 'data-testid': 'attachment-icon' }), <button key={idx} onClick={item.onClick} data-testid={`menu-item-${idx}`}>
SharePointIcon: () => R.createElement('span', { 'data-testid': 'sharepoint-icon' }), {item.label}
</button>
))}
</div>
)}
</div>
);
},
AttachmentIcon: () => <span data-testid="attachment-icon">📎</span>,
SharePointIcon: () => <span data-testid="sharepoint-icon">SP</span>,
}; };
}); });
jest.mock('@ariakit/react', () => { jest.mock('@ariakit/react', () => ({
// eslint-disable-next-line @typescript-eslint/no-require-imports MenuButton: ({ children, onClick, disabled, ...props }: any) => (
const R = require('react'); <button onClick={onClick} disabled={disabled} {...props}>
return { {children}
MenuButton: (props) => R.createElement('button', props, props.children), </button>
}; ),
}); }));
const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions; const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions;
const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities; const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities;
@ -73,283 +83,558 @@ const mockUseSharePointFileHandling = jest.requireMock(
).default; ).default;
const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig; const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig;
const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }); describe('AttachFileMenu', () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
});
function setupMocks(overrides: { provider?: string } = {}) { const mockHandleFileChange = jest.fn();
beforeEach(() => {
jest.clearAllMocks();
// Default mock implementations
mockUseLocalize.mockReturnValue((key: string) => {
const translations: Record<string, string> = { const translations: Record<string, string> = {
com_ui_upload_provider: 'Upload to Provider', com_ui_upload_provider: 'Upload to Provider',
com_ui_upload_image_input: 'Upload Image', com_ui_upload_image_input: 'Upload Image',
com_ui_upload_ocr_text: 'Upload as Text', com_ui_upload_ocr_text: 'Upload OCR Text',
com_ui_upload_file_search: 'Upload for File Search', com_ui_upload_file_search: 'Upload for File Search',
com_ui_upload_code_files: 'Upload Code Files', com_ui_upload_code_files: 'Upload Code Files',
com_sidepanel_attach_files: 'Attach Files', com_sidepanel_attach_files: 'Attach Files',
com_files_upload_sharepoint: 'Upload from SharePoint', com_files_upload_sharepoint: 'Upload from SharePoint',
}; };
mockUseLocalize.mockReturnValue((key: string) => translations[key] || key); return translations[key] || key;
});
mockUseAgentCapabilities.mockReturnValue({ mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false, contextEnabled: false,
fileSearchEnabled: false, fileSearchEnabled: false,
codeEnabled: false, codeEnabled: false,
}); });
mockUseGetAgentsConfig.mockReturnValue({ agentsConfig: {} });
mockUseFileHandling.mockReturnValue({ handleFileChange: jest.fn() }); mockUseGetAgentsConfig.mockReturnValue({
agentsConfig: {
capabilities: {
contextEnabled: false,
fileSearchEnabled: false,
codeEnabled: false,
},
},
});
mockUseFileHandling.mockReturnValue({
handleFileChange: mockHandleFileChange,
});
mockUseSharePointFileHandling.mockReturnValue({ mockUseSharePointFileHandling.mockReturnValue({
handleSharePointFiles: jest.fn(), handleSharePointFiles: jest.fn(),
isProcessing: false, isProcessing: false,
downloadProgress: 0, downloadProgress: 0,
}); });
mockUseGetStartupConfig.mockReturnValue({ data: { sharePointFilePickerEnabled: false } });
mockUseGetStartupConfig.mockReturnValue({
data: {
sharePointFilePickerEnabled: false,
},
});
mockUseAgentToolPermissions.mockReturnValue({ mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false, fileSearchAllowedByAgent: false,
codeAllowedByAgent: false, codeAllowedByAgent: false,
provider: overrides.provider ?? undefined, provider: undefined,
});
}); });
}
function renderMenu(props: Record<string, unknown> = {}) { const renderAttachFileMenu = (props: any = {}) => {
return render( return render(
<QueryClientProvider client={queryClient}> <QueryClientProvider client={queryClient}>
<RecoilRoot> <RecoilRoot>
<AttachFileMenu conversationId="test-convo" {...props} /> <AttachFileMenu conversationId="test-conversation" {...props} />
</RecoilRoot> </RecoilRoot>
</QueryClientProvider>, </QueryClientProvider>,
); );
} };
function openMenu() { describe('Basic Rendering', () => {
fireEvent.click(screen.getByRole('button', { name: /attach file options/i })); it('should render the attachment button', () => {
} renderAttachFileMenu();
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
describe('AttachFileMenu', () => { it('should be disabled when disabled prop is true', () => {
beforeEach(jest.clearAllMocks); renderAttachFileMenu({ disabled: true });
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeDisabled();
});
describe('Upload to Provider vs Upload Image', () => { it('should not be disabled when disabled prop is false', () => {
it('shows "Upload to Provider" when endpointType is custom (resolved from agent provider)', () => { renderAttachFileMenu({ disabled: false });
setupMocks({ provider: 'Moonshot' }); const button = screen.getByRole('button', { name: /attach file options/i });
renderMenu({ endpointType: EModelEndpoint.custom }); expect(button).not.toBeDisabled();
openMenu(); });
});
describe('Provider Detection Fix - endpointType Priority', () => {
it('should prioritize endpointType over currentProvider for LiteLLM gateway', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: 'litellm', // Custom gateway name NOT in documentSupportedProviders
});
renderAttachFileMenu({
endpoint: 'litellm',
endpointType: EModelEndpoint.openAI, // Backend override IS in documentSupportedProviders
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
// With the fix, should show "Upload to Provider" because endpointType is checked first
expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
expect(screen.queryByText('Upload Image')).not.toBeInTheDocument(); expect(screen.queryByText('Upload Image')).not.toBeInTheDocument();
}); });
it('shows "Upload to Provider" when endpointType is openAI', () => { it('should show Upload to Provider for custom endpoints with OpenAI endpointType', () => {
setupMocks({ provider: EModelEndpoint.openAI }); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ endpointType: EModelEndpoint.openAI }); fileSearchAllowedByAgent: false,
openMenu(); codeAllowedByAgent: false,
provider: 'my-custom-gateway',
});
renderAttachFileMenu({
endpoint: 'my-custom-gateway',
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
}); });
it('shows "Upload to Provider" when endpointType is anthropic', () => { it('should show Upload Image when neither endpointType nor provider support documents', () => {
setupMocks({ provider: EModelEndpoint.anthropic }); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ endpointType: EModelEndpoint.anthropic }); fileSearchAllowedByAgent: false,
openMenu(); codeAllowedByAgent: false,
expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); provider: 'unsupported-provider',
}); });
it('shows "Upload to Provider" when endpointType is google', () => { renderAttachFileMenu({
setupMocks({ provider: Providers.GOOGLE }); endpoint: 'unsupported-provider',
renderMenu({ endpointType: EModelEndpoint.google }); endpointType: 'unsupported-endpoint' as any,
openMenu();
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
}); });
it('shows "Upload Image" when endpointType is agents (no provider resolution)', () => { const button = screen.getByRole('button', { name: /attach file options/i });
setupMocks(); fireEvent.click(button);
renderMenu({ endpointType: EModelEndpoint.agents });
openMenu();
expect(screen.getByText('Upload Image')).toBeInTheDocument(); expect(screen.getByText('Upload Image')).toBeInTheDocument();
expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument(); expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument();
}); });
it('shows "Upload Image" when neither endpointType nor provider supports documents', () => { it('should fallback to currentProvider when endpointType is undefined', () => {
setupMocks({ provider: 'unknown-provider' }); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ endpointType: 'unknown-type' }); fileSearchAllowedByAgent: false,
openMenu(); codeAllowedByAgent: false,
expect(screen.getByText('Upload Image')).toBeInTheDocument(); provider: EModelEndpoint.openAI,
}); });
it('shows "Upload to Provider" for azureOpenAI with useResponsesApi', () => { renderAttachFileMenu({
setupMocks({ provider: EModelEndpoint.azureOpenAI }); endpoint: EModelEndpoint.openAI,
renderMenu({ endpointType: EModelEndpoint.azureOpenAI, useResponsesApi: true }); endpointType: undefined,
openMenu(); });
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
}); });
it('shows "Upload Image" for azureOpenAI without useResponsesApi', () => { it('should fallback to currentProvider when endpointType is null', () => {
setupMocks({ provider: EModelEndpoint.azureOpenAI }); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ endpointType: EModelEndpoint.azureOpenAI, useResponsesApi: false }); fileSearchAllowedByAgent: false,
openMenu(); codeAllowedByAgent: false,
expect(screen.getByText('Upload Image')).toBeInTheDocument(); provider: EModelEndpoint.anthropic,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.anthropic,
endpointType: null,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
}); });
}); });
describe('agent provider resolution scenario', () => { describe('Supported Providers', () => {
it('shows "Upload to Provider" when agents endpoint has custom endpointType from provider', () => { const supportedProviders = [
setupMocks({ provider: 'Moonshot' }); { name: 'OpenAI', endpoint: EModelEndpoint.openAI },
renderMenu({ { name: 'Anthropic', endpoint: EModelEndpoint.anthropic },
endpoint: EModelEndpoint.agents, { name: 'Google', endpoint: EModelEndpoint.google },
endpointType: EModelEndpoint.custom, { name: 'Custom', endpoint: EModelEndpoint.custom },
];
supportedProviders.forEach(({ name, endpoint }) => {
it(`should show Upload to Provider for ${name}`, () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: endpoint,
}); });
openMenu();
renderAttachFileMenu({
endpoint,
endpointType: endpoint,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
});
it('should show Upload to Provider for Azure OpenAI with useResponsesApi', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.azureOpenAI,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.azureOpenAI,
endpointType: EModelEndpoint.azureOpenAI,
useResponsesApi: true,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
}); });
it('shows "Upload Image" when agents endpoint has no resolved provider type', () => { it('should NOT show Upload to Provider for Azure OpenAI without useResponsesApi', () => {
setupMocks(); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ fileSearchAllowedByAgent: false,
endpoint: EModelEndpoint.agents, codeAllowedByAgent: false,
endpointType: EModelEndpoint.agents, provider: EModelEndpoint.azureOpenAI,
}); });
openMenu();
renderAttachFileMenu({
endpoint: EModelEndpoint.azureOpenAI,
endpointType: EModelEndpoint.azureOpenAI,
useResponsesApi: false,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument();
expect(screen.getByText('Upload Image')).toBeInTheDocument(); expect(screen.getByText('Upload Image')).toBeInTheDocument();
}); });
}); });
describe('Basic Rendering', () => {
it('renders the attachment button', () => {
setupMocks();
renderMenu();
expect(screen.getByRole('button', { name: /attach file options/i })).toBeInTheDocument();
});
it('is disabled when disabled prop is true', () => {
setupMocks();
renderMenu({ disabled: true });
expect(screen.getByRole('button', { name: /attach file options/i })).toBeDisabled();
});
it('is not disabled when disabled prop is false', () => {
setupMocks();
renderMenu({ disabled: false });
expect(screen.getByRole('button', { name: /attach file options/i })).not.toBeDisabled();
});
});
describe('Agent Capabilities', () => { describe('Agent Capabilities', () => {
it('shows OCR Text option when context is enabled', () => { it('should show OCR Text option when context is enabled', () => {
setupMocks();
mockUseAgentCapabilities.mockReturnValue({ mockUseAgentCapabilities.mockReturnValue({
contextEnabled: true, contextEnabled: true,
fileSearchEnabled: false, fileSearchEnabled: false,
codeEnabled: false, codeEnabled: false,
}); });
renderMenu({ endpointType: EModelEndpoint.openAI });
openMenu(); renderAttachFileMenu({
expect(screen.getByText('Upload as Text')).toBeInTheDocument(); endpointType: EModelEndpoint.openAI,
}); });
it('shows File Search option when enabled and allowed by agent', () => { const button = screen.getByRole('button', { name: /attach file options/i });
setupMocks(); fireEvent.click(button);
expect(screen.getByText('Upload OCR Text')).toBeInTheDocument();
});
it('should show File Search option when enabled and allowed by agent', () => {
mockUseAgentCapabilities.mockReturnValue({ mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false, contextEnabled: false,
fileSearchEnabled: true, fileSearchEnabled: true,
codeEnabled: false, codeEnabled: false,
}); });
mockUseAgentToolPermissions.mockReturnValue({ mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: true, fileSearchAllowedByAgent: true,
codeAllowedByAgent: false, codeAllowedByAgent: false,
provider: undefined, provider: undefined,
}); });
renderMenu({ endpointType: EModelEndpoint.openAI });
openMenu(); renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload for File Search')).toBeInTheDocument(); expect(screen.getByText('Upload for File Search')).toBeInTheDocument();
}); });
it('does NOT show File Search when enabled but not allowed by agent', () => { it('should NOT show File Search when enabled but not allowed by agent', () => {
setupMocks();
mockUseAgentCapabilities.mockReturnValue({ mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false, contextEnabled: false,
fileSearchEnabled: true, fileSearchEnabled: true,
codeEnabled: false, codeEnabled: false,
}); });
renderMenu({ endpointType: EModelEndpoint.openAI });
openMenu(); mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: undefined,
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.queryByText('Upload for File Search')).not.toBeInTheDocument(); expect(screen.queryByText('Upload for File Search')).not.toBeInTheDocument();
}); });
it('shows Code Files option when enabled and allowed by agent', () => { it('should show Code Files option when enabled and allowed by agent', () => {
setupMocks();
mockUseAgentCapabilities.mockReturnValue({ mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false, contextEnabled: false,
fileSearchEnabled: false, fileSearchEnabled: false,
codeEnabled: true, codeEnabled: true,
}); });
mockUseAgentToolPermissions.mockReturnValue({ mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false, fileSearchAllowedByAgent: false,
codeAllowedByAgent: true, codeAllowedByAgent: true,
provider: undefined, provider: undefined,
}); });
renderMenu({ endpointType: EModelEndpoint.openAI });
openMenu(); renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload Code Files')).toBeInTheDocument(); expect(screen.getByText('Upload Code Files')).toBeInTheDocument();
}); });
it('shows all options when all capabilities are enabled', () => { it('should show all options when all capabilities are enabled', () => {
setupMocks();
mockUseAgentCapabilities.mockReturnValue({ mockUseAgentCapabilities.mockReturnValue({
contextEnabled: true, contextEnabled: true,
fileSearchEnabled: true, fileSearchEnabled: true,
codeEnabled: true, codeEnabled: true,
}); });
mockUseAgentToolPermissions.mockReturnValue({ mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: true, fileSearchAllowedByAgent: true,
codeAllowedByAgent: true, codeAllowedByAgent: true,
provider: undefined, provider: undefined,
}); });
renderMenu({ endpointType: EModelEndpoint.openAI });
openMenu(); renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
expect(screen.getByText('Upload as Text')).toBeInTheDocument(); expect(screen.getByText('Upload OCR Text')).toBeInTheDocument();
expect(screen.getByText('Upload for File Search')).toBeInTheDocument(); expect(screen.getByText('Upload for File Search')).toBeInTheDocument();
expect(screen.getByText('Upload Code Files')).toBeInTheDocument(); expect(screen.getByText('Upload Code Files')).toBeInTheDocument();
}); });
}); });
describe('SharePoint Integration', () => { describe('SharePoint Integration', () => {
it('shows SharePoint option when enabled', () => { it('should show SharePoint option when enabled', () => {
setupMocks();
mockUseGetStartupConfig.mockReturnValue({ mockUseGetStartupConfig.mockReturnValue({
data: { sharePointFilePickerEnabled: true }, data: {
sharePointFilePickerEnabled: true,
},
}); });
renderMenu({ endpointType: EModelEndpoint.openAI });
openMenu(); renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload from SharePoint')).toBeInTheDocument(); expect(screen.getByText('Upload from SharePoint')).toBeInTheDocument();
}); });
it('does NOT show SharePoint option when disabled', () => { it('should NOT show SharePoint option when disabled', () => {
setupMocks(); mockUseGetStartupConfig.mockReturnValue({
renderMenu({ endpointType: EModelEndpoint.openAI }); data: {
openMenu(); sharePointFilePickerEnabled: false,
},
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.queryByText('Upload from SharePoint')).not.toBeInTheDocument(); expect(screen.queryByText('Upload from SharePoint')).not.toBeInTheDocument();
}); });
}); });
describe('Edge Cases', () => { describe('Edge Cases', () => {
it('handles undefined endpoint and provider gracefully', () => { it('should handle undefined endpoint and provider gracefully', () => {
setupMocks(); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ endpoint: undefined, endpointType: undefined }); fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: undefined,
});
renderAttachFileMenu({
endpoint: undefined,
endpointType: undefined,
});
const button = screen.getByRole('button', { name: /attach file options/i }); const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument(); expect(button).toBeInTheDocument();
fireEvent.click(button); fireEvent.click(button);
// Should show Upload Image as fallback
expect(screen.getByText('Upload Image')).toBeInTheDocument(); expect(screen.getByText('Upload Image')).toBeInTheDocument();
}); });
it('handles null endpoint and provider gracefully', () => { it('should handle null endpoint and provider gracefully', () => {
setupMocks(); mockUseAgentToolPermissions.mockReturnValue({
renderMenu({ endpoint: null, endpointType: null }); fileSearchAllowedByAgent: false,
expect(screen.getByRole('button', { name: /attach file options/i })).toBeInTheDocument(); codeAllowedByAgent: false,
provider: null,
}); });
it('handles missing agentId gracefully', () => { renderAttachFileMenu({
setupMocks(); endpoint: null,
renderMenu({ agentId: undefined, endpointType: EModelEndpoint.openAI }); endpointType: null,
expect(screen.getByRole('button', { name: /attach file options/i })).toBeInTheDocument();
}); });
it('handles empty string agentId', () => { const button = screen.getByRole('button', { name: /attach file options/i });
setupMocks(); expect(button).toBeInTheDocument();
renderMenu({ agentId: '', endpointType: EModelEndpoint.openAI }); });
expect(screen.getByRole('button', { name: /attach file options/i })).toBeInTheDocument();
it('should handle missing agentId gracefully', () => {
renderAttachFileMenu({
agentId: undefined,
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
it('should handle empty string agentId', () => {
renderAttachFileMenu({
agentId: '',
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
});
describe('Google Provider Special Case', () => {
it('should use image_document_video_audio file type for Google provider', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.google,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.google,
endpointType: EModelEndpoint.google,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
const uploadProviderButton = screen.getByText('Upload to Provider');
expect(uploadProviderButton).toBeInTheDocument();
// Click the upload to provider option
fireEvent.click(uploadProviderButton);
// The file input should have been clicked (indirectly tested through the implementation)
});
it('should use image_document file type for non-Google providers', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.openAI,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.openAI,
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
const uploadProviderButton = screen.getByText('Upload to Provider');
expect(uploadProviderButton).toBeInTheDocument();
fireEvent.click(uploadProviderButton);
// Implementation detail - image_document type is used
});
});
describe('Regression Tests', () => {
it('should not break the previous behavior for direct provider attachments', () => {
// When using a direct supported provider (not through a gateway)
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.anthropic,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.anthropic,
endpointType: EModelEndpoint.anthropic,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
it('should maintain correct priority when both are supported', () => {
// Both endpointType and provider are supported, endpointType should be checked first
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.google,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.google,
endpointType: EModelEndpoint.openAI, // Different but both supported
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
// Should still work because endpointType (openAI) is supported
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
}); });
}); });
}); });

View file

@ -21,7 +21,6 @@ jest.mock('~/utils', () => ({
logger: { logger: {
log: jest.fn(), log: jest.fn(),
}, },
getCachedPreview: jest.fn(() => undefined),
})); }));
jest.mock('../Image', () => { jest.mock('../Image', () => {
@ -96,7 +95,7 @@ describe('FileRow', () => {
}; };
describe('Image URL Selection Logic', () => { describe('Image URL Selection Logic', () => {
it('should prefer cached preview over filepath when upload is complete', () => { it('should use filepath instead of preview when progress is 1 (upload complete)', () => {
const file = createMockFile({ const file = createMockFile({
file_id: 'uploaded-file', file_id: 'uploaded-file',
preview: 'blob:http://localhost:3080/temp-preview', preview: 'blob:http://localhost:3080/temp-preview',
@ -110,7 +109,8 @@ describe('FileRow', () => {
renderFileRow(filesMap); renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent; const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('blob:http://localhost:3080/temp-preview'); expect(imageUrl).toBe('/images/user123/uploaded-file__image.png');
expect(imageUrl).not.toContain('blob:');
}); });
it('should use preview when progress is less than 1 (uploading)', () => { it('should use preview when progress is less than 1 (uploading)', () => {
@ -147,7 +147,7 @@ describe('FileRow', () => {
expect(imageUrl).toBe('/images/user123/file-without-preview__image.png'); expect(imageUrl).toBe('/images/user123/file-without-preview__image.png');
}); });
it('should prefer preview over filepath when both exist and progress is 1', () => { it('should use filepath when both preview and filepath exist and progress is exactly 1', () => {
const file = createMockFile({ const file = createMockFile({
file_id: 'complete-file', file_id: 'complete-file',
preview: 'blob:http://localhost:3080/old-blob', preview: 'blob:http://localhost:3080/old-blob',
@ -161,7 +161,7 @@ describe('FileRow', () => {
renderFileRow(filesMap); renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent; const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('blob:http://localhost:3080/old-blob'); expect(imageUrl).toBe('/images/user123/complete-file__image.png');
}); });
}); });
@ -284,7 +284,7 @@ describe('FileRow', () => {
const urls = screen.getAllByTestId('image-url').map((el) => el.textContent); const urls = screen.getAllByTestId('image-url').map((el) => el.textContent);
expect(urls).toContain('blob:http://localhost:3080/preview-1'); expect(urls).toContain('blob:http://localhost:3080/preview-1');
expect(urls).toContain('blob:http://localhost:3080/preview-2'); expect(urls).toContain('/images/user123/file-2__image.png');
}); });
it('should deduplicate files with the same file_id', () => { it('should deduplicate files with the same file_id', () => {
@ -321,10 +321,10 @@ describe('FileRow', () => {
}); });
}); });
describe('Preview Cache Integration', () => { describe('Regression: Blob URL Bug Fix', () => {
it('should prefer preview blob URL over filepath for zero-flicker rendering', () => { it('should NOT use revoked blob URL after upload completes', () => {
const file = createMockFile({ const file = createMockFile({
file_id: 'cache-test', file_id: 'regression-test',
preview: 'blob:http://localhost:3080/d25f730c-152d-41f7-8d79-c9fa448f606b', preview: 'blob:http://localhost:3080/d25f730c-152d-41f7-8d79-c9fa448f606b',
filepath: filepath:
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png', '/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
@ -337,24 +337,8 @@ describe('FileRow', () => {
renderFileRow(filesMap); renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent; const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('blob:http://localhost:3080/d25f730c-152d-41f7-8d79-c9fa448f606b');
});
it('should fall back to filepath when no preview exists', () => { expect(imageUrl).not.toContain('blob:');
const file = createMockFile({
file_id: 'no-preview',
preview: undefined,
filepath:
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
progress: 1,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe( expect(imageUrl).toBe(
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png', '/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
); );

View file

@ -2,10 +2,11 @@ import { useState, useRef, useEffect } from 'react';
import { useCombobox } from '@librechat/client'; import { useCombobox } from '@librechat/client';
import { AutoSizer, List } from 'react-virtualized'; import { AutoSizer, List } from 'react-virtualized';
import { EModelEndpoint } from 'librechat-data-provider'; import { EModelEndpoint } from 'librechat-data-provider';
import type { TConversation } from 'librechat-data-provider';
import type { MentionOption, ConvoGenerator } from '~/common'; import type { MentionOption, ConvoGenerator } from '~/common';
import type { SetterOrUpdater } from 'recoil'; import type { SetterOrUpdater } from 'recoil';
import { useGetConversation, useLocalize, TranslationKeys } from '~/hooks';
import useSelectMention from '~/hooks/Input/useSelectMention'; import useSelectMention from '~/hooks/Input/useSelectMention';
import { useLocalize, TranslationKeys } from '~/hooks';
import { useAssistantsMapContext } from '~/Providers'; import { useAssistantsMapContext } from '~/Providers';
import useMentions from '~/hooks/Input/useMentions'; import useMentions from '~/hooks/Input/useMentions';
import { removeCharIfLast } from '~/utils'; import { removeCharIfLast } from '~/utils';
@ -14,6 +15,7 @@ import MentionItem from './MentionItem';
const ROW_HEIGHT = 44; const ROW_HEIGHT = 44;
export default function Mention({ export default function Mention({
conversation,
setShowMentionPopover, setShowMentionPopover,
newConversation, newConversation,
textAreaRef, textAreaRef,
@ -21,6 +23,7 @@ export default function Mention({
placeholder = 'com_ui_mention', placeholder = 'com_ui_mention',
includeAssistants = true, includeAssistants = true,
}: { }: {
conversation: TConversation | null;
setShowMentionPopover: SetterOrUpdater<boolean>; setShowMentionPopover: SetterOrUpdater<boolean>;
newConversation: ConvoGenerator; newConversation: ConvoGenerator;
textAreaRef: React.MutableRefObject<HTMLTextAreaElement | null>; textAreaRef: React.MutableRefObject<HTMLTextAreaElement | null>;
@ -29,7 +32,6 @@ export default function Mention({
includeAssistants?: boolean; includeAssistants?: boolean;
}) { }) {
const localize = useLocalize(); const localize = useLocalize();
const getConversation = useGetConversation(0);
const assistantsMap = useAssistantsMapContext(); const assistantsMap = useAssistantsMapContext();
const { const {
options, options,
@ -43,9 +45,9 @@ export default function Mention({
const { onSelectMention } = useSelectMention({ const { onSelectMention } = useSelectMention({
presets, presets,
modelSpecs, modelSpecs,
conversation,
assistantsMap, assistantsMap,
endpointsConfig, endpointsConfig,
getConversation,
newConversation, newConversation,
}); });

View file

@ -1,9 +1,6 @@
import React, { createContext, useCallback, useContext, useMemo, useRef } from 'react'; import React, { createContext, useContext, useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import type { EModelEndpoint, TConversation } from 'librechat-data-provider'; import type { EModelEndpoint, TConversation } from 'librechat-data-provider';
import type { ConvoGenerator } from '~/common'; import { useChatContext } from '~/Providers/ChatContext';
import { useGetConversation, useNewConvo } from '~/hooks';
import store from '~/store';
interface ModelSelectorChatContextValue { interface ModelSelectorChatContextValue {
endpoint?: EModelEndpoint | null; endpoint?: EModelEndpoint | null;
@ -11,8 +8,8 @@ interface ModelSelectorChatContextValue {
spec?: string | null; spec?: string | null;
agent_id?: string | null; agent_id?: string | null;
assistant_id?: string | null; assistant_id?: string | null;
getConversation: () => TConversation | null; conversation: TConversation | null;
newConversation: ConvoGenerator; newConversation: ReturnType<typeof useChatContext>['newConversation'];
} }
const ModelSelectorChatContext = createContext<ModelSelectorChatContextValue | undefined>( const ModelSelectorChatContext = createContext<ModelSelectorChatContextValue | undefined>(
@ -20,34 +17,20 @@ const ModelSelectorChatContext = createContext<ModelSelectorChatContextValue | u
); );
export function ModelSelectorChatProvider({ children }: { children: React.ReactNode }) { export function ModelSelectorChatProvider({ children }: { children: React.ReactNode }) {
const getConversation = useGetConversation(0); const { conversation, newConversation } = useChatContext();
const { newConversation: nextNewConversation } = useNewConvo();
const spec = useRecoilValue(store.conversationSpecByIndex(0));
const model = useRecoilValue(store.conversationModelByIndex(0));
const agent_id = useRecoilValue(store.conversationAgentIdByIndex(0));
const endpoint = useRecoilValue(store.conversationEndpointByIndex(0));
const assistant_id = useRecoilValue(store.conversationAssistantIdByIndex(0));
const newConversationRef = useRef(nextNewConversation);
newConversationRef.current = nextNewConversation;
const newConversation = useCallback<ConvoGenerator>(
(params) => newConversationRef.current(params),
[],
);
/** Context value only created when relevant conversation properties change */ /** Context value only created when relevant conversation properties change */
const contextValue = useMemo<ModelSelectorChatContextValue>( const contextValue = useMemo<ModelSelectorChatContextValue>(
() => ({ () => ({
model, endpoint: conversation?.endpoint,
spec, model: conversation?.model,
agent_id, spec: conversation?.spec,
endpoint, agent_id: conversation?.agent_id,
assistant_id, assistant_id: conversation?.assistant_id,
getConversation, conversation,
newConversation, newConversation,
}), }),
[endpoint, model, spec, agent_id, assistant_id, getConversation, newConversation], [conversation, newConversation],
); );
return ( return (

View file

@ -58,7 +58,7 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
const agentsMap = useAgentsMapContext(); const agentsMap = useAgentsMapContext();
const assistantsMap = useAssistantsMapContext(); const assistantsMap = useAssistantsMapContext();
const { data: endpointsConfig } = useGetEndpointsQuery(); const { data: endpointsConfig } = useGetEndpointsQuery();
const { endpoint, model, spec, agent_id, assistant_id, getConversation, newConversation } = const { endpoint, model, spec, agent_id, assistant_id, conversation, newConversation } =
useModelSelectorChatContext(); useModelSelectorChatContext();
const localize = useLocalize(); const localize = useLocalize();
const { announcePolite } = useLiveAnnouncer(); const { announcePolite } = useLiveAnnouncer();
@ -114,7 +114,7 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
const { onSelectEndpoint, onSelectSpec } = useSelectMention({ const { onSelectEndpoint, onSelectSpec } = useSelectMention({
// presets, // presets,
modelSpecs, modelSpecs,
getConversation, conversation,
assistantsMap, assistantsMap,
endpointsConfig, endpointsConfig,
newConversation, newConversation,
@ -171,15 +171,14 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
}, 200), }, 200),
[], [],
); );
const setEndpointSearchValue = useCallback((endpoint: string, value: string) => { const setEndpointSearchValue = (endpoint: string, value: string) => {
setEndpointSearchValues((prev) => ({ setEndpointSearchValues((prev) => ({
...prev, ...prev,
[endpoint]: value, [endpoint]: value,
})); }));
}, []); };
const handleSelectSpec = useCallback( const handleSelectSpec = (spec: t.TModelSpec) => {
(spec: t.TModelSpec) => {
let model = spec.preset.model ?? null; let model = spec.preset.model ?? null;
onSelectSpec?.(spec); onSelectSpec?.(spec);
if (isAgentsEndpoint(spec.preset.endpoint)) { if (isAgentsEndpoint(spec.preset.endpoint)) {
@ -192,12 +191,9 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
model, model,
modelSpec: spec.name, modelSpec: spec.name,
}); });
}, };
[onSelectSpec],
);
const handleSelectEndpoint = useCallback( const handleSelectEndpoint = (endpoint: Endpoint) => {
(endpoint: Endpoint) => {
if (!endpoint.hasModels) { if (!endpoint.hasModels) {
if (endpoint.value) { if (endpoint.value) {
onSelectEndpoint?.(endpoint.value); onSelectEndpoint?.(endpoint.value);
@ -208,12 +204,9 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
modelSpec: '', modelSpec: '',
}); });
} }
}, };
[onSelectEndpoint],
);
const handleSelectModel = useCallback( const handleSelectModel = (endpoint: Endpoint, model: string) => {
(endpoint: Endpoint, model: string) => {
if (isAgentsEndpoint(endpoint.value)) { if (isAgentsEndpoint(endpoint.value)) {
onSelectEndpoint?.(endpoint.value, { onSelectEndpoint?.(endpoint.value, {
agent_id: model, agent_id: model,
@ -236,21 +229,22 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
const modelDisplayName = getModelDisplayName(endpoint, model); const modelDisplayName = getModelDisplayName(endpoint, model);
const announcement = localize('com_ui_model_selected', { 0: modelDisplayName }); const announcement = localize('com_ui_model_selected', { 0: modelDisplayName });
announcePolite({ message: announcement, isStatus: true }); announcePolite({ message: announcement, isStatus: true });
}, };
[agentsMap, announcePolite, assistantsMap, getModelDisplayName, localize, onSelectEndpoint],
);
const value = useMemo( const value = {
() => ({ // State
searchValue, searchValue,
searchResults, searchResults,
selectedValues, selectedValues,
endpointSearchValues, endpointSearchValues,
// LibreChat
agentsMap, agentsMap,
modelSpecs, modelSpecs,
assistantsMap, assistantsMap,
mappedEndpoints, mappedEndpoints,
endpointsConfig, endpointsConfig,
// Functions
handleSelectSpec, handleSelectSpec,
handleSelectModel, handleSelectModel,
setSelectedValues, setSelectedValues,
@ -258,28 +252,9 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
setEndpointSearchValue, setEndpointSearchValue,
endpointRequiresUserKey, endpointRequiresUserKey,
setSearchValue: setDebouncedSearchValue, setSearchValue: setDebouncedSearchValue,
// Dialog
...keyProps, ...keyProps,
}), };
[
searchValue,
searchResults,
selectedValues,
endpointSearchValues,
agentsMap,
modelSpecs,
assistantsMap,
mappedEndpoints,
endpointsConfig,
handleSelectSpec,
handleSelectModel,
setSelectedValues,
handleSelectEndpoint,
setEndpointSearchValue,
endpointRequiresUserKey,
setDebouncedSearchValue,
keyProps,
],
);
return <ModelSelectorContext.Provider value={value}>{children}</ModelSelectorContext.Provider>; return <ModelSelectorContext.Provider value={value}>{children}</ModelSelectorContext.Provider>;
} }

Some files were not shown because too many files have changed in this diff Show more