diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..725ac8b6bd --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +# Force LF line endings for shell scripts and git hooks (required for cross-platform compatibility) +.husky/* text eol=lf +*.sh text eol=lf diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 038c90627e..9dd3905c0e 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -97,6 +97,65 @@ jobs: path: packages/api/dist retention-days: 2 + typecheck: + name: TypeScript type checks + needs: build + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + + - name: Use Node.js 20.19 + uses: actions/setup-node@v4 + with: + node-version: '20.19' + + - name: Restore node_modules cache + id: cache-node-modules + uses: actions/cache@v4 + with: + path: | + node_modules + api/node_modules + packages/api/node_modules + packages/data-provider/node_modules + packages/data-schemas/node_modules + key: node-modules-backend-${{ runner.os }}-20.19-${{ hashFiles('package-lock.json') }} + + - name: Install dependencies + if: steps.cache-node-modules.outputs.cache-hit != 'true' + run: npm ci + + - name: Download data-provider build + uses: actions/download-artifact@v4 + with: + name: build-data-provider + path: packages/data-provider/dist + + - name: Download data-schemas build + uses: actions/download-artifact@v4 + with: + name: build-data-schemas + path: packages/data-schemas/dist + + - name: Download api build + uses: actions/download-artifact@v4 + with: + name: build-api + path: packages/api/dist + + - name: Type check data-provider + run: npx tsc --noEmit -p packages/data-provider/tsconfig.json + + - name: Type check data-schemas + run: npx tsc --noEmit -p packages/data-schemas/tsconfig.json + + - name: Type check @librechat/api + run: npx tsc --noEmit -p packages/api/tsconfig.json + + - name: Type check @librechat/client + run: npx tsc --noEmit -p packages/client/tsconfig.json + circular-deps: name: Circular dependency checks needs: build diff --git a/.gitignore b/.gitignore index ff2ae59633..e302d15a46 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,7 @@ bower_components/ .clineignore .cursor .aider* +.bg-shell/ # Floobits .floo @@ -129,6 +130,7 @@ helm/**/charts/ helm/**/.values.yaml !/client/src/@types/i18next.d.ts +!/client/src/@types/react.d.ts # SAML Idp cert *.cert @@ -143,7 +145,6 @@ helm/**/.values.yaml /.codeium *.local.md - # Removed Windows wrapper files per user request hive-mind-prompt-*.txt @@ -175,3 +176,4 @@ claude-flow # Removed Windows wrapper files per user request hive-mind-prompt-*.txt CLAUDE.md +.gsd diff --git a/AGENTS.md b/AGENTS.md index ec44607aa7..ceb2b988dc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,166 +1 @@ -# LibreChat - -## Project Overview - -LibreChat is a monorepo with the following key workspaces: - -| Workspace | Language | Side | Dependency | Purpose | -|---|---|---|---|---| -| `/api` | JS (legacy) | Backend | `packages/api`, `packages/data-schemas`, `packages/data-provider`, `@librechat/agents` | Express server — minimize changes here | -| `/packages/api` | **TypeScript** | Backend | `packages/data-schemas`, `packages/data-provider` | New backend code lives here (TS only, consumed by `/api`) | -| `/packages/data-schemas` | TypeScript | Backend | `packages/data-provider` | Database models/schemas, shareable across backend projects | -| `/packages/data-provider` | TypeScript | Shared | — | Shared API types, endpoints, data-service — used by both frontend and backend | -| `/client` | TypeScript/React | Frontend | `packages/data-provider`, `packages/client` | Frontend SPA | -| `/packages/client` | TypeScript | Frontend | `packages/data-provider` | Shared frontend utilities | - -The source code for `@librechat/agents` (major backend dependency, same team) is at `/home/danny/agentus`. - ---- - -## Workspace Boundaries - -- **All new backend code must be TypeScript** in `/packages/api`. -- Keep `/api` changes to the absolute minimum (thin JS wrappers calling into `/packages/api`). -- Database-specific shared logic goes in `/packages/data-schemas`. -- Frontend/backend shared API logic (endpoints, types, data-service) goes in `/packages/data-provider`. -- Build data-provider from project root: `npm run build:data-provider`. - ---- - -## Code Style - -### Structure and Clarity - -- **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers. -- **Functional first**: pure functions, immutable data, `map`/`filter`/`reduce` over imperative loops. Only reach for OOP when it clearly improves domain modeling or state encapsulation. -- **No dynamic imports** unless absolutely necessary. - -### DRY - -- Extract repeated logic into utility functions. -- Reusable hooks / higher-order components for UI patterns. -- Parameterized helpers instead of near-duplicate functions. -- Constants for repeated values; configuration objects over duplicated init code. -- Shared validators, centralized error handling, single source of truth for business rules. -- Shared typing system with interfaces/types extending common base definitions. -- Abstraction layers for external API interactions. - -### Iteration and Performance - -- **Minimize looping** — especially over shared data structures like message arrays, which are iterated frequently throughout the codebase. Every additional pass adds up at scale. -- Consolidate sequential O(n) operations into a single pass whenever possible; never loop over the same collection twice if the work can be combined. -- Choose data structures that reduce the need to iterate (e.g., `Map`/`Set` for lookups instead of `Array.find`/`Array.includes`). -- Avoid unnecessary object creation; consider space-time tradeoffs. -- Prevent memory leaks: careful with closures, dispose resources/event listeners, no circular references. - -### Type Safety - -- **Never use `any`**. Explicit types for all parameters, return values, and variables. -- **Limit `unknown`** — avoid `unknown`, `Record`, and `as unknown as T` assertions. A `Record` almost always signals a missing explicit type definition. -- **Don't duplicate types** — before defining a new type, check whether it already exists in the project (especially `packages/data-provider`). Reuse and extend existing types rather than creating redundant definitions. -- Use union types, generics, and interfaces appropriately. -- All TypeScript and ESLint warnings/errors must be addressed — do not leave unresolved diagnostics. - -### Comments and Documentation - -- Write self-documenting code; no inline comments narrating what code does. -- JSDoc only for complex/non-obvious logic or intellisense on public APIs. -- Single-line JSDoc for brief docs, multi-line for complex cases. -- Avoid standalone `//` comments unless absolutely necessary. - -### Import Order - -Imports are organized into three sections: - -1. **Package imports** — sorted shortest to longest line length (`react` always first). -2. **`import type` imports** — sorted longest to shortest (package types first, then local types; length resets between sub-groups). -3. **Local/project imports** — sorted longest to shortest. - -Multi-line imports count total character length across all lines. Consolidate value imports from the same module. Always use standalone `import type { ... }` — never inline `type` inside value imports. - -### JS/TS Loop Preferences - -- **Limit looping as much as possible.** Prefer single-pass transformations and avoid re-iterating the same data. -- `for (let i = 0; ...)` for performance-critical or index-dependent operations. -- `for...of` for simple array iteration. -- `for...in` only for object property enumeration. - ---- - -## Frontend Rules (`client/src/**/*`) - -### Localization - -- All user-facing text must use `useLocalize()`. -- Only update English keys in `client/src/locales/en/translation.json` (other languages are automated externally). -- Semantic key prefixes: `com_ui_`, `com_assistants_`, etc. - -### Components - -- TypeScript for all React components with proper type imports. -- Semantic HTML with ARIA labels (`role`, `aria-label`) for accessibility. -- Group related components in feature directories (e.g., `SidePanel/Memories/`). -- Use index files for clean exports. - -### Data Management - -- Feature hooks: `client/src/data-provider/[Feature]/queries.ts` → `[Feature]/index.ts` → `client/src/data-provider/index.ts`. -- React Query (`@tanstack/react-query`) for all API interactions; proper query invalidation on mutations. -- QueryKeys and MutationKeys in `packages/data-provider/src/keys.ts`. - -### Data-Provider Integration - -- Endpoints: `packages/data-provider/src/api-endpoints.ts` -- Data service: `packages/data-provider/src/data-service.ts` -- Types: `packages/data-provider/src/types/queries.ts` -- Use `encodeURIComponent` for dynamic URL parameters. - -### Performance - -- Prioritize memory and speed efficiency at scale. -- Cursor pagination for large datasets. -- Proper dependency arrays to avoid unnecessary re-renders. -- Leverage React Query caching and background refetching. - ---- - -## Development Commands - -| Command | Purpose | -|---|---| -| `npm run smart-reinstall` | Install deps (if lockfile changed) + build via Turborepo | -| `npm run reinstall` | Clean install — wipe `node_modules` and reinstall from scratch | -| `npm run backend` | Start the backend server | -| `npm run backend:dev` | Start backend with file watching (development) | -| `npm run build` | Build all compiled code via Turborepo (parallel, cached) | -| `npm run frontend` | Build all compiled code sequentially (legacy fallback) | -| `npm run frontend:dev` | Start frontend dev server with HMR (port 3090, requires backend running) | -| `npm run build:data-provider` | Rebuild `packages/data-provider` after changes | - -- Node.js: v20.19.0+ or ^22.12.0 or >= 23.0.0 -- Database: MongoDB -- Backend runs on `http://localhost:3080/`; frontend dev server on `http://localhost:3090/` - ---- - -## Testing - -- Framework: **Jest**, run per-workspace. -- Run tests from their workspace directory: `cd api && npx jest `, `cd packages/api && npx jest `, etc. -- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering. -- Cover loading, success, and error states for UI/data flows. - -### Philosophy - -- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort. -- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic. -- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls. -- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals. -- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls. -- Heavy mocking is a code smell, not a testing strategy. - ---- - -## Formatting - -Fix all formatting lint errors (trailing spaces, tabs, newlines, indentation) using auto-fix when available. All TypeScript/ESLint warnings and errors **must** be resolved. +CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 120000 index 47dc3e3d86..0000000000 --- a/CLAUDE.md +++ /dev/null @@ -1 +0,0 @@ -AGENTS.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..81362cfc57 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,172 @@ +# LibreChat + +## Project Overview + +LibreChat is a monorepo with the following key workspaces: + +| Workspace | Language | Side | Dependency | Purpose | +|---|---|---|---|---| +| `/api` | JS (legacy) | Backend | `packages/api`, `packages/data-schemas`, `packages/data-provider`, `@librechat/agents` | Express server — minimize changes here | +| `/packages/api` | **TypeScript** | Backend | `packages/data-schemas`, `packages/data-provider` | New backend code lives here (TS only, consumed by `/api`) | +| `/packages/data-schemas` | TypeScript | Backend | `packages/data-provider` | Database models/schemas, shareable across backend projects | +| `/packages/data-provider` | TypeScript | Shared | — | Shared API types, endpoints, data-service — used by both frontend and backend | +| `/client` | TypeScript/React | Frontend | `packages/data-provider`, `packages/client` | Frontend SPA | +| `/packages/client` | TypeScript | Frontend | `packages/data-provider` | Shared frontend utilities | + +The source code for `@librechat/agents` (major backend dependency, same team) is at `/home/danny/agentus`. + +--- + +## Workspace Boundaries + +- **All new backend code must be TypeScript** in `/packages/api`. +- Keep `/api` changes to the absolute minimum (thin JS wrappers calling into `/packages/api`). +- Database-specific shared logic goes in `/packages/data-schemas`. +- Frontend/backend shared API logic (endpoints, types, data-service) goes in `/packages/data-provider`. +- Build data-provider from project root: `npm run build:data-provider`. + +--- + +## Code Style + +### Naming and File Organization + +- **Single-word file names** whenever possible (e.g., `permissions.ts`, `capabilities.ts`, `service.ts`). +- When multiple words are needed, prefer grouping related modules under a **single-word directory** rather than using multi-word file names (e.g., `admin/capabilities.ts` not `adminCapabilities.ts`). +- The directory already provides context — `app/service.ts` not `app/appConfigService.ts`. + +### Structure and Clarity + +- **Never-nesting**: early returns, flat code, minimal indentation. Break complex operations into well-named helpers. +- **Functional first**: pure functions, immutable data, `map`/`filter`/`reduce` over imperative loops. Only reach for OOP when it clearly improves domain modeling or state encapsulation. +- **No dynamic imports** unless absolutely necessary. + +### DRY + +- Extract repeated logic into utility functions. +- Reusable hooks / higher-order components for UI patterns. +- Parameterized helpers instead of near-duplicate functions. +- Constants for repeated values; configuration objects over duplicated init code. +- Shared validators, centralized error handling, single source of truth for business rules. +- Shared typing system with interfaces/types extending common base definitions. +- Abstraction layers for external API interactions. + +### Iteration and Performance + +- **Minimize looping** — especially over shared data structures like message arrays, which are iterated frequently throughout the codebase. Every additional pass adds up at scale. +- Consolidate sequential O(n) operations into a single pass whenever possible; never loop over the same collection twice if the work can be combined. +- Choose data structures that reduce the need to iterate (e.g., `Map`/`Set` for lookups instead of `Array.find`/`Array.includes`). +- Avoid unnecessary object creation; consider space-time tradeoffs. +- Prevent memory leaks: careful with closures, dispose resources/event listeners, no circular references. + +### Type Safety + +- **Never use `any`**. Explicit types for all parameters, return values, and variables. +- **Limit `unknown`** — avoid `unknown`, `Record`, and `as unknown as T` assertions. A `Record` almost always signals a missing explicit type definition. +- **Don't duplicate types** — before defining a new type, check whether it already exists in the project (especially `packages/data-provider`). Reuse and extend existing types rather than creating redundant definitions. +- Use union types, generics, and interfaces appropriately. +- All TypeScript and ESLint warnings/errors must be addressed — do not leave unresolved diagnostics. + +### Comments and Documentation + +- Write self-documenting code; no inline comments narrating what code does. +- JSDoc only for complex/non-obvious logic or intellisense on public APIs. +- Single-line JSDoc for brief docs, multi-line for complex cases. +- Avoid standalone `//` comments unless absolutely necessary. + +### Import Order + +Imports are organized into three sections: + +1. **Package imports** — sorted shortest to longest line length (`react` always first). +2. **`import type` imports** — sorted longest to shortest (package types first, then local types; length resets between sub-groups). +3. **Local/project imports** — sorted longest to shortest. + +Multi-line imports count total character length across all lines. Consolidate value imports from the same module. Always use standalone `import type { ... }` — never inline `type` inside value imports. + +### JS/TS Loop Preferences + +- **Limit looping as much as possible.** Prefer single-pass transformations and avoid re-iterating the same data. +- `for (let i = 0; ...)` for performance-critical or index-dependent operations. +- `for...of` for simple array iteration. +- `for...in` only for object property enumeration. + +--- + +## Frontend Rules (`client/src/**/*`) + +### Localization + +- All user-facing text must use `useLocalize()`. +- Only update English keys in `client/src/locales/en/translation.json` (other languages are automated externally). +- Semantic key prefixes: `com_ui_`, `com_assistants_`, etc. + +### Components + +- TypeScript for all React components with proper type imports. +- Semantic HTML with ARIA labels (`role`, `aria-label`) for accessibility. +- Group related components in feature directories (e.g., `SidePanel/Memories/`). +- Use index files for clean exports. + +### Data Management + +- Feature hooks: `client/src/data-provider/[Feature]/queries.ts` → `[Feature]/index.ts` → `client/src/data-provider/index.ts`. +- React Query (`@tanstack/react-query`) for all API interactions; proper query invalidation on mutations. +- QueryKeys and MutationKeys in `packages/data-provider/src/keys.ts`. + +### Data-Provider Integration + +- Endpoints: `packages/data-provider/src/api-endpoints.ts` +- Data service: `packages/data-provider/src/data-service.ts` +- Types: `packages/data-provider/src/types/queries.ts` +- Use `encodeURIComponent` for dynamic URL parameters. + +### Performance + +- Prioritize memory and speed efficiency at scale. +- Cursor pagination for large datasets. +- Proper dependency arrays to avoid unnecessary re-renders. +- Leverage React Query caching and background refetching. + +--- + +## Development Commands + +| Command | Purpose | +|---|---| +| `npm run smart-reinstall` | Install deps (if lockfile changed) + build via Turborepo | +| `npm run reinstall` | Clean install — wipe `node_modules` and reinstall from scratch | +| `npm run backend` | Start the backend server | +| `npm run backend:dev` | Start backend with file watching (development) | +| `npm run build` | Build all compiled code via Turborepo (parallel, cached) | +| `npm run frontend` | Build all compiled code sequentially (legacy fallback) | +| `npm run frontend:dev` | Start frontend dev server with HMR (port 3090, requires backend running) | +| `npm run build:data-provider` | Rebuild `packages/data-provider` after changes | + +- Node.js: v20.19.0+ or ^22.12.0 or >= 23.0.0 +- Database: MongoDB +- Backend runs on `http://localhost:3080/`; frontend dev server on `http://localhost:3090/` + +--- + +## Testing + +- Framework: **Jest**, run per-workspace. +- Run tests from their workspace directory: `cd api && npx jest `, `cd packages/api && npx jest `, etc. +- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering. +- Cover loading, success, and error states for UI/data flows. + +### Philosophy + +- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort. +- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic. +- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls. +- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals. +- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls. +- Heavy mocking is a code smell, not a testing strategy. + +--- + +## Formatting + +Fix all formatting lint errors (trailing spaces, tabs, newlines, indentation) using auto-fix when available. All TypeScript/ESLint warnings and errors **must** be resolved. diff --git a/README.md b/README.md index 7da34974e3..a7f68d9a92 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@

- + Deploy on Railway diff --git a/README.zh.md b/README.zh.md index cc9cb5a205..7f74057413 100644 --- a/README.zh.md +++ b/README.zh.md @@ -1,4 +1,4 @@ - +

@@ -34,7 +34,7 @@

- + Deploy on Railway diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index ae2d362773..905cadfd23 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -17,11 +17,13 @@ const { ContentTypes, excludedKeys, EModelEndpoint, + mergeFileConfig, isParamEndpoint, isAgentsEndpoint, isEphemeralAgentId, supportsBalanceCheck, isBedrockDocumentType, + getEndpointFileConfig, } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { logViolation } = require('~/cache'); @@ -32,7 +34,6 @@ class BaseClient { constructor(apiKey, options = {}) { this.apiKey = apiKey; this.sender = options.sender ?? 'AI'; - this.contextStrategy = null; this.currentDateString = new Date().toLocaleDateString('en-us', { year: 'numeric', month: 'long', @@ -72,6 +73,10 @@ class BaseClient { this.currentMessages = []; /** @type {import('librechat-data-provider').VisionModes | undefined} */ this.visionMode; + /** @type {import('librechat-data-provider').FileConfig | undefined} */ + this._mergedFileConfig; + /** @type {import('librechat-data-provider').EndpointFileConfig | undefined} */ + this._endpointFileConfig; } setOptions() { @@ -487,7 +492,12 @@ class BaseClient { } delete userMessage.image_urls; } - userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); + userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user).catch( + (err) => { + logger.error('[BaseClient] Failed to save user message:', err); + return {}; + }, + ); this.savedMessageIds.add(userMessage.messageId); if (typeof opts?.getReqData === 'function') { opts.getReqData({ @@ -519,6 +529,8 @@ class BaseClient { getMultiplier: db.getMultiplier, findBalanceByUser: db.findBalanceByUser, createAutoRefillTransaction: db.createAutoRefillTransaction, + balanceConfig, + upsertBalanceFields: db.upsertBalanceFields, }, ); } @@ -727,21 +739,30 @@ class BaseClient { * @param {string | null} user */ async saveMessageToDatabase(message, endpointOptions, user = null) { + // Snapshot options before any await; disposeClient may set client.options = null + // while this method is suspended at an I/O boundary, but the local reference + // remains valid (disposeClient nulls the property, not the object itself). + const options = this.options; + if (!options) { + logger.error('[BaseClient] saveMessageToDatabase: client disposed before save, skipping'); + return {}; + } + if (this.user && user !== this.user) { throw new Error('User mismatch.'); } - const hasAddedConvo = this.options?.req?.body?.addedConvo != null; + const hasAddedConvo = options?.req?.body?.addedConvo != null; const reqCtx = { - userId: this.options?.req?.user?.id, - isTemporary: this.options?.req?.body?.isTemporary, - interfaceConfig: this.options?.req?.config?.interfaceConfig, + userId: options?.req?.user?.id, + isTemporary: options?.req?.body?.isTemporary, + interfaceConfig: options?.req?.config?.interfaceConfig, }; const savedMessage = await db.saveMessage( reqCtx, { ...message, - endpoint: this.options.endpoint, + endpoint: options.endpoint, unfinished: false, user, ...(hasAddedConvo && { addedConvo: true }), @@ -755,20 +776,20 @@ class BaseClient { const fieldsToKeep = { conversationId: message.conversationId, - endpoint: this.options.endpoint, - endpointType: this.options.endpointType, + endpoint: options.endpoint, + endpointType: options.endpointType, ...endpointOptions, }; const existingConvo = this.fetchedConvo === true ? null - : await db.getConvo(this.options?.req?.user?.id, message.conversationId); + : await db.getConvo(options?.req?.user?.id, message.conversationId); const unsetFields = {}; const exceptions = new Set(['spec', 'iconURL']); const hasNonEphemeralAgent = - isAgentsEndpoint(this.options.endpoint) && + isAgentsEndpoint(options.endpoint) && endpointOptions?.agent_id && !isEphemeralAgentId(endpointOptions.agent_id); if (hasNonEphemeralAgent) { @@ -1072,6 +1093,7 @@ class BaseClient { provider: this.options.agent?.provider ?? this.options.endpoint, endpoint: this.options.agent?.endpoint ?? this.options.endpoint, useResponsesApi: this.options.agent?.model_parameters?.useResponsesApi, + model: this.modelOptions?.model ?? this.model, }, getStrategyFunctions, ); @@ -1144,6 +1166,16 @@ class BaseClient { const provider = this.options.agent?.provider ?? this.options.endpoint; const isBedrock = provider === EModelEndpoint.bedrock; + if (!this._mergedFileConfig && this.options.req?.config?.fileConfig) { + this._mergedFileConfig = mergeFileConfig(this.options.req.config.fileConfig); + const endpoint = this.options.agent?.endpoint ?? this.options.endpoint; + this._endpointFileConfig = getEndpointFileConfig({ + fileConfig: this._mergedFileConfig, + endpoint, + endpointType: this.options.endpointType, + }); + } + for (const file of attachments) { /** @type {FileSources} */ const source = file.source ?? FileSources.local; @@ -1170,6 +1202,14 @@ class BaseClient { } else if (file.type.startsWith('audio/')) { categorizedAttachments.audios.push(file); allFiles.push(file); + } else if ( + file.type && + this._mergedFileConfig && + this._endpointFileConfig?.supportedMimeTypes && + this._mergedFileConfig.checkType(file.type, this._endpointFileConfig.supportedMimeTypes) + ) { + categorizedAttachments.documents.push(file); + allFiles.push(file); } } diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index edbbcaa87b..3ce910948c 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -38,7 +38,7 @@ jest.mock('~/models', () => ({ updateFileUsage: jest.fn(), })); -const { getConvo, saveConvo } = require('~/models'); +const { getConvo, saveConvo, saveMessage } = require('~/models'); jest.mock('@librechat/agents', () => { const actual = jest.requireActual('@librechat/agents'); @@ -906,6 +906,52 @@ describe('BaseClient', () => { ); }); + test('saveMessageToDatabase returns early when this.options is null (client disposed)', async () => { + const savedOptions = TestClient.options; + TestClient.options = null; + saveMessage.mockClear(); + + const result = await TestClient.saveMessageToDatabase( + { messageId: 'msg-1', conversationId: 'conv-1', isCreatedByUser: true, text: 'hi' }, + {}, + null, + ); + + expect(result).toEqual({}); + expect(saveMessage).not.toHaveBeenCalled(); + + TestClient.options = savedOptions; + }); + + test('saveMessageToDatabase uses snapshot of options, immune to mid-await disposal', async () => { + const savedOptions = TestClient.options; + saveMessage.mockClear(); + saveConvo.mockClear(); + + // Make db.saveMessage yield, simulating I/O suspension during which disposal occurs + saveMessage.mockImplementation(async (_reqCtx, msgData) => { + // Simulate disposeClient nullifying client.options while awaiting + TestClient.options = null; + return msgData; + }); + saveConvo.mockResolvedValue({ conversationId: 'conv-1' }); + + const result = await TestClient.saveMessageToDatabase( + { messageId: 'msg-1', conversationId: 'conv-1', isCreatedByUser: true, text: 'hi' }, + { endpoint: 'openAI' }, + null, + ); + + // Should complete without TypeError, using the snapshotted options + expect(result).toHaveProperty('message'); + expect(result).toHaveProperty('conversation'); + expect(saveMessage).toHaveBeenCalled(); + + TestClient.options = savedOptions; + saveMessage.mockReset(); + saveConvo.mockReset(); + }); + test('userMessagePromise is awaited before saving response message', async () => { // Mock the saveMessageToDatabase method TestClient.saveMessageToDatabase = jest.fn().mockImplementation(() => { diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 4b86101425..8adb43f945 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -14,7 +14,6 @@ const { buildImageToolContext, buildWebSearchContext, } = require('@librechat/api'); -const { getMCPServersRegistry } = require('~/config'); const { Tools, Constants, @@ -39,12 +38,13 @@ const { createGeminiImageTool, createOpenAIImageTools, } = require('../'); -const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); +const { createMCPTool, createMCPTools, resolveConfigServers } = require('~/server/services/MCP'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); -const { createMCPTool, createMCPTools } = require('~/server/services/MCP'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getMCPServerTools } = require('~/server/services/Config'); +const { getMCPServersRegistry } = require('~/config'); const { getRoleByName } = require('~/models'); /** @@ -256,6 +256,12 @@ const loadTools = async ({ const toolContextMap = {}; const requestedMCPTools = {}; + /** Resolve config-source servers for the current user/tenant context */ + let configServers; + if (tools.some((tool) => tool && mcpToolPattern.test(tool))) { + configServers = await resolveConfigServers(options.req); + } + for (const tool of tools) { if (tool === Tools.execute_code) { requestedTools[tool] = async () => { @@ -341,7 +347,7 @@ const loadTools = async ({ continue; } const serverConfig = serverName - ? await getMCPServersRegistry().getServerConfig(serverName, user) + ? await getMCPServersRegistry().getServerConfig(serverName, user, configServers) : null; if (!serverConfig) { logger.warn( @@ -419,6 +425,7 @@ const loadTools = async ({ let index = -1; const failedMCPServers = new Set(); const safeUser = createSafeUser(options.req?.user); + for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) { index++; /** @type {LCAvailableTools} */ @@ -433,6 +440,7 @@ const loadTools = async ({ signal, user: safeUser, userMCPAuthMap, + configServers, res: options.res, streamId: options.req?._resumableStreamId || null, model: agent?.model ?? model, diff --git a/api/db/index.js b/api/db/index.js index 5c29902f69..f4359c8adf 100644 --- a/api/db/index.js +++ b/api/db/index.js @@ -1,8 +1,13 @@ const mongoose = require('mongoose'); const { createModels } = require('@librechat/data-schemas'); const { connectDb } = require('./connect'); -const indexSync = require('./indexSync'); +// createModels MUST run before requiring indexSync. +// indexSync.js captures mongoose.models.Message and mongoose.models.Conversation +// at module load time. If those models are not registered first, all MeiliSearch +// sync operations will silently fail on every startup. createModels(mongoose); +const indexSync = require('./indexSync'); + module.exports = { connectDb, indexSync }; diff --git a/api/db/index.spec.js b/api/db/index.spec.js new file mode 100644 index 0000000000..e1ebe176dc --- /dev/null +++ b/api/db/index.spec.js @@ -0,0 +1,26 @@ +describe('api/db/index.js', () => { + test('createModels is called before indexSync is loaded', () => { + jest.resetModules(); + + const callOrder = []; + + jest.mock('@librechat/data-schemas', () => ({ + createModels: jest.fn((m) => { + callOrder.push('createModels'); + m.models.Message = { name: 'Message' }; + m.models.Conversation = { name: 'Conversation' }; + }), + })); + + jest.mock('./indexSync', () => { + callOrder.push('indexSync'); + return jest.fn(); + }); + + jest.mock('./connect', () => ({ connectDb: jest.fn() })); + + require('./index'); + + expect(callOrder).toEqual(['createModels', 'indexSync']); + }); +}); diff --git a/api/db/indexSync.js b/api/db/indexSync.js index 130cde77b8..13059033fb 100644 --- a/api/db/indexSync.js +++ b/api/db/indexSync.js @@ -6,9 +6,6 @@ const { isEnabled, FlowStateManager } = require('@librechat/api'); const { getLogStores } = require('~/cache'); const { batchResetMeiliFlags } = require('./utils'); -const Conversation = mongoose.models.Conversation; -const Message = mongoose.models.Message; - const searchEnabled = isEnabled(process.env.SEARCH); const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); let currentTimeout = null; @@ -200,6 +197,14 @@ async function performSync(flowManager, flowId, flowType) { return { messagesSync: false, convosSync: false }; } + const Message = mongoose.models.Message; + const Conversation = mongoose.models.Conversation; + if (!Message || !Conversation) { + throw new Error( + '[indexSync] Models not registered. Ensure createModels() has been called before indexSync.', + ); + } + const client = MeiliSearchClient.getInstance(); const { status } = await client.health(); @@ -349,6 +354,13 @@ async function indexSync() { logger.debug('[indexSync] Creating indices...'); currentTimeout = setTimeout(async () => { try { + const Message = mongoose.models.Message; + const Conversation = mongoose.models.Conversation; + if (!Message || !Conversation) { + throw new Error( + '[indexSync] Models not registered. Ensure createModels() has been called before indexSync.', + ); + } await Message.syncWithMeili(); await Conversation.syncWithMeili(); } catch (err) { diff --git a/api/package.json b/api/package.json index 8b2f156cd3..86b0f22c0b 100644 --- a/api/package.json +++ b/api/package.json @@ -35,7 +35,7 @@ "homepage": "https://librechat.ai", "dependencies": { "@anthropic-ai/vertex-sdk": "^0.14.3", - "@aws-sdk/client-bedrock-runtime": "^3.980.0", + "@aws-sdk/client-bedrock-runtime": "^3.1013.0", "@aws-sdk/client-s3": "^3.980.0", "@aws-sdk/s3-request-presigner": "^3.758.0", "@azure/identity": "^4.7.0", @@ -44,7 +44,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.62", + "@librechat/agents": "^3.1.63", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -52,7 +52,7 @@ "@node-saml/passport-saml": "^5.1.0", "@smithy/node-http-handler": "^4.4.5", "ai-tokenizer": "^1.0.6", - "axios": "^1.13.5", + "axios": "1.13.6", "bcryptjs": "^2.4.3", "compression": "^1.8.1", "connect-redis": "^8.1.0", @@ -70,7 +70,7 @@ "file-type": "^21.3.2", "firebase": "^11.0.2", "form-data": "^4.0.4", - "handlebars": "^4.7.7", + "handlebars": "^4.7.9", "https-proxy-agent": "^7.0.6", "ioredis": "^5.3.2", "js-yaml": "^4.1.1", @@ -91,7 +91,7 @@ "multer": "^2.1.1", "nanoid": "^3.3.7", "node-fetch": "^2.7.0", - "nodemailer": "^7.0.11", + "nodemailer": "^8.0.4", "ollama": "^0.5.0", "openai": "5.8.2", "openid-client": "^6.5.0", diff --git a/api/server/cleanup.js b/api/server/cleanup.js index 364c02cd8a..c27814292d 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -123,9 +123,6 @@ function disposeClient(client) { if (client.maxContextTokens) { client.maxContextTokens = null; } - if (client.contextStrategy) { - client.contextStrategy = null; - } if (client.currentDateString) { client.currentDateString = null; } diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js index 805d9eef27..4738d45111 100644 --- a/api/server/controllers/ModelController.js +++ b/api/server/controllers/ModelController.js @@ -1,40 +1,12 @@ const { logger } = require('@librechat/data-schemas'); -const { CacheKeys } = require('librechat-data-provider'); const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config'); -const { getLogStores } = require('~/cache'); -/** - * @param {ServerRequest} req - * @returns {Promise} The models config. - */ -const getModelsConfig = async (req) => { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - let modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); - if (!modelsConfig) { - modelsConfig = await loadModels(req); - } +const getModelsConfig = (req) => loadModels(req); - return modelsConfig; -}; - -/** - * Loads the models from the config. - * @param {ServerRequest} req - The Express request object. - * @returns {Promise} The models config. - */ async function loadModels(req) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedModelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); - if (cachedModelsConfig) { - return cachedModelsConfig; - } const defaultModelsConfig = await loadDefaultModels(req); const customModelsConfig = await loadConfigModels(req); - - const modelConfig = { ...defaultModelsConfig, ...customModelsConfig }; - - await cache.set(CacheKeys.MODELS_CONFIG, modelConfig); - return modelConfig; + return { ...defaultModelsConfig, ...customModelsConfig }; } async function modelController(req, res) { diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 279ffb15fd..c5d5c5b888 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,61 +1,37 @@ const { logger } = require('@librechat/data-schemas'); -const { CacheKeys } = require('librechat-data-provider'); const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api'); const { getCachedTools, setCachedTools } = require('~/server/services/Config'); const { availableTools, toolkits } = require('~/app/clients/tools'); const { getAppConfig } = require('~/server/services/Config'); -const { getLogStores } = require('~/cache'); const getAvailablePluginsController = async (req, res) => { try { - const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedPlugins = await cache.get(CacheKeys.PLUGINS); - if (cachedPlugins) { - res.status(200).json(cachedPlugins); - return; - } - - const appConfig = await getAppConfig({ role: req.user?.role }); - /** @type {{ filteredTools: string[], includedTools: string[] }} */ + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); const { filteredTools = [], includedTools = [] } = appConfig; - /** @type {import('@librechat/api').LCManifestTool[]} */ - const pluginManifest = availableTools; - const uniquePlugins = filterUniquePlugins(pluginManifest); - let authenticatedPlugins = []; + const uniquePlugins = filterUniquePlugins(availableTools); + const includeSet = new Set(includedTools); + const filterSet = new Set(filteredTools); + + /** includedTools takes precedence — filteredTools ignored when both are set. */ + const plugins = []; for (const plugin of uniquePlugins) { - authenticatedPlugins.push( - checkPluginAuth(plugin) ? { ...plugin, authenticated: true } : plugin, - ); + if (includeSet.size > 0) { + if (!includeSet.has(plugin.pluginKey)) { + continue; + } + } else if (filterSet.has(plugin.pluginKey)) { + continue; + } + plugins.push(checkPluginAuth(plugin) ? { ...plugin, authenticated: true } : plugin); } - let plugins = authenticatedPlugins; - - if (includedTools.length > 0) { - plugins = plugins.filter((plugin) => includedTools.includes(plugin.pluginKey)); - } else { - plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey)); - } - - await cache.set(CacheKeys.PLUGINS, plugins); res.status(200).json(plugins); } catch (error) { res.status(500).json({ message: error.message }); } }; -/** - * Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file. - * - * This function first attempts to retrieve the list of tools from a cache. If the tools are not found in the cache, - * it reads a plugin manifest file, filters for unique plugins, and determines if each plugin is authenticated. - * Only plugins that are marked as available in the application's local state are included in the final list. - * The resulting list of tools is then cached and sent to the client. - * - * @param {object} req - The request object, containing information about the HTTP request. - * @param {object} res - The response object, used to send back the desired HTTP response. - * @returns {Promise} A promise that resolves when the function has completed. - */ const getAvailableTools = async (req, res) => { try { const userId = req.user?.id; @@ -63,18 +39,10 @@ const getAvailableTools = async (req, res) => { logger.warn('[getAvailableTools] User ID not found in request'); return res.status(401).json({ message: 'Unauthorized' }); } - const cache = getLogStores(CacheKeys.TOOL_CACHE); - const cachedToolsArray = await cache.get(CacheKeys.TOOLS); - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); - // Return early if we have cached tools - if (cachedToolsArray != null) { - res.status(200).json(cachedToolsArray); - return; - } - - /** @type {Record | null} Get tool definitions to filter which tools are actually available */ let toolDefinitions = await getCachedTools(); if (toolDefinitions == null && appConfig?.availableTools != null) { @@ -83,26 +51,17 @@ const getAvailableTools = async (req, res) => { toolDefinitions = appConfig.availableTools; } - /** @type {import('@librechat/api').LCManifestTool[]} */ - let pluginManifest = availableTools; + const uniquePlugins = filterUniquePlugins(availableTools); + const toolDefKeysList = toolDefinitions ? Object.keys(toolDefinitions) : null; + const toolDefKeys = toolDefKeysList ? new Set(toolDefKeysList) : null; - /** @type {TPlugin[]} Deduplicate and authenticate plugins */ - const uniquePlugins = filterUniquePlugins(pluginManifest); - const authenticatedPlugins = uniquePlugins.map((plugin) => { - if (checkPluginAuth(plugin)) { - return { ...plugin, authenticated: true }; - } else { - return plugin; - } - }); - - /** Filter plugins based on availability */ const toolsOutput = []; - for (const plugin of authenticatedPlugins) { - const isToolDefined = toolDefinitions?.[plugin.pluginKey] !== undefined; + for (const plugin of uniquePlugins) { + const isToolDefined = toolDefKeys?.has(plugin.pluginKey) === true; const isToolkit = plugin.toolkit === true && - Object.keys(toolDefinitions ?? {}).some( + toolDefKeysList != null && + toolDefKeysList.some( (key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey, ); @@ -110,13 +69,10 @@ const getAvailableTools = async (req, res) => { continue; } - toolsOutput.push(plugin); + toolsOutput.push(checkPluginAuth(plugin) ? { ...plugin, authenticated: true } : plugin); } - const finalTools = filterUniquePlugins(toolsOutput); - await cache.set(CacheKeys.TOOLS, finalTools); - - res.status(200).json(finalTools); + res.status(200).json(toolsOutput); } catch (error) { logger.error('[getAvailableTools]', error); res.status(500).json({ message: error.message }); diff --git a/api/server/controllers/PluginController.spec.js b/api/server/controllers/PluginController.spec.js index 06a51a3bd6..9288680567 100644 --- a/api/server/controllers/PluginController.spec.js +++ b/api/server/controllers/PluginController.spec.js @@ -1,6 +1,4 @@ -const { CacheKeys } = require('librechat-data-provider'); const { getCachedTools, getAppConfig } = require('~/server/services/Config'); -const { getLogStores } = require('~/cache'); jest.mock('@librechat/data-schemas', () => ({ logger: { @@ -19,22 +17,15 @@ jest.mock('~/server/services/Config', () => ({ setCachedTools: jest.fn(), })); -// loadAndFormatTools mock removed - no longer used in PluginController -// getMCPManager mock removed - no longer used in PluginController - jest.mock('~/app/clients/tools', () => ({ availableTools: [], toolkits: [], })); -jest.mock('~/cache', () => ({ - getLogStores: jest.fn(), -})); - const { getAvailableTools, getAvailablePluginsController } = require('./PluginController'); describe('PluginController', () => { - let mockReq, mockRes, mockCache; + let mockReq, mockRes; beforeEach(() => { jest.clearAllMocks(); @@ -46,17 +37,12 @@ describe('PluginController', () => { }, }; mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() }; - mockCache = { get: jest.fn(), set: jest.fn() }; - getLogStores.mockReturnValue(mockCache); - // Clear availableTools and toolkits arrays before each test require('~/app/clients/tools').availableTools.length = 0; require('~/app/clients/tools').toolkits.length = 0; - // Reset getCachedTools mock to ensure clean state getCachedTools.mockReset(); - // Reset getAppConfig mock to ensure clean state with default values getAppConfig.mockReset(); getAppConfig.mockResolvedValue({ filteredTools: [], @@ -64,31 +50,8 @@ describe('PluginController', () => { }); }); - describe('cache namespace', () => { - it('getAvailablePluginsController should use TOOL_CACHE namespace', async () => { - mockCache.get.mockResolvedValue([]); - await getAvailablePluginsController(mockReq, mockRes); - expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); - }); - - it('getAvailableTools should use TOOL_CACHE namespace', async () => { - mockCache.get.mockResolvedValue([]); - await getAvailableTools(mockReq, mockRes); - expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); - }); - - it('should NOT use CONFIG_STORE namespace for tool/plugin operations', async () => { - mockCache.get.mockResolvedValue([]); - await getAvailablePluginsController(mockReq, mockRes); - await getAvailableTools(mockReq, mockRes); - const allCalls = getLogStores.mock.calls.flat(); - expect(allCalls).not.toContain(CacheKeys.CONFIG_STORE); - }); - }); - describe('getAvailablePluginsController', () => { it('should use filterUniquePlugins to remove duplicate plugins', async () => { - // Add plugins with duplicates to availableTools const mockPlugins = [ { name: 'Plugin1', pluginKey: 'key1', description: 'First' }, { name: 'Plugin1', pluginKey: 'key1', description: 'First duplicate' }, @@ -97,9 +60,6 @@ describe('PluginController', () => { require('~/app/clients/tools').availableTools.push(...mockPlugins); - mockCache.get.mockResolvedValue(null); - - // Configure getAppConfig to return the expected config getAppConfig.mockResolvedValueOnce({ filteredTools: [], includedTools: [], @@ -109,21 +69,16 @@ describe('PluginController', () => { expect(mockRes.status).toHaveBeenCalledWith(200); const responseData = mockRes.json.mock.calls[0][0]; - // The real filterUniquePlugins should have removed the duplicate expect(responseData).toHaveLength(2); expect(responseData[0].pluginKey).toBe('key1'); expect(responseData[1].pluginKey).toBe('key2'); }); it('should use checkPluginAuth to verify plugin authentication', async () => { - // checkPluginAuth returns false for plugins without authConfig - // so authenticated property won't be added const mockPlugin = { name: 'Plugin1', pluginKey: 'key1', description: 'First' }; require('~/app/clients/tools').availableTools.push(mockPlugin); - mockCache.get.mockResolvedValue(null); - // Configure getAppConfig to return the expected config getAppConfig.mockResolvedValueOnce({ filteredTools: [], includedTools: [], @@ -132,23 +87,9 @@ describe('PluginController', () => { await getAvailablePluginsController(mockReq, mockRes); const responseData = mockRes.json.mock.calls[0][0]; - // The real checkPluginAuth returns false for plugins without authConfig, so authenticated property is not added expect(responseData[0].authenticated).toBeUndefined(); }); - it('should return cached plugins when available', async () => { - const cachedPlugins = [ - { name: 'CachedPlugin', pluginKey: 'cached', description: 'Cached plugin' }, - ]; - - mockCache.get.mockResolvedValue(cachedPlugins); - - await getAvailablePluginsController(mockReq, mockRes); - - // When cache is hit, we return immediately without processing - expect(mockRes.json).toHaveBeenCalledWith(cachedPlugins); - }); - it('should filter plugins based on includedTools', async () => { const mockPlugins = [ { name: 'Plugin1', pluginKey: 'key1', description: 'First' }, @@ -156,9 +97,7 @@ describe('PluginController', () => { ]; require('~/app/clients/tools').availableTools.push(...mockPlugins); - mockCache.get.mockResolvedValue(null); - // Configure getAppConfig to return config with includedTools getAppConfig.mockResolvedValueOnce({ filteredTools: [], includedTools: ['key1'], @@ -170,6 +109,47 @@ describe('PluginController', () => { expect(responseData).toHaveLength(1); expect(responseData[0].pluginKey).toBe('key1'); }); + + it('should exclude plugins in filteredTools', async () => { + const mockPlugins = [ + { name: 'Plugin1', pluginKey: 'key1', description: 'First' }, + { name: 'Plugin2', pluginKey: 'key2', description: 'Second' }, + ]; + + require('~/app/clients/tools').availableTools.push(...mockPlugins); + + getAppConfig.mockResolvedValueOnce({ + filteredTools: ['key2'], + includedTools: [], + }); + + await getAvailablePluginsController(mockReq, mockRes); + + const responseData = mockRes.json.mock.calls[0][0]; + expect(responseData).toHaveLength(1); + expect(responseData[0].pluginKey).toBe('key1'); + }); + + it('should ignore filteredTools when includedTools is set', async () => { + const mockPlugins = [ + { name: 'Plugin1', pluginKey: 'key1', description: 'First' }, + { name: 'Plugin2', pluginKey: 'key2', description: 'Second' }, + { name: 'Plugin3', pluginKey: 'key3', description: 'Third' }, + ]; + + require('~/app/clients/tools').availableTools.push(...mockPlugins); + + getAppConfig.mockResolvedValueOnce({ + includedTools: ['key1', 'key2'], + filteredTools: ['key2'], + }); + + await getAvailablePluginsController(mockReq, mockRes); + + const responseData = mockRes.json.mock.calls[0][0]; + expect(responseData).toHaveLength(2); + expect(responseData.map((p) => p.pluginKey)).toEqual(['key1', 'key2']); + }); }); describe('getAvailableTools', () => { @@ -185,12 +165,11 @@ describe('PluginController', () => { }, }; - const mockCachedPlugins = [ + require('~/app/clients/tools').availableTools.push( { name: 'user-tool', pluginKey: 'user-tool', description: 'Duplicate user tool' }, { name: 'ManifestTool', pluginKey: 'manifest-tool', description: 'Manifest tool' }, - ]; + ); - mockCache.get.mockResolvedValue(mockCachedPlugins); getCachedTools.mockResolvedValueOnce(mockUserTools); mockReq.config = { mcpConfig: null, @@ -202,24 +181,19 @@ describe('PluginController', () => { expect(mockRes.status).toHaveBeenCalledWith(200); const responseData = mockRes.json.mock.calls[0][0]; expect(Array.isArray(responseData)).toBe(true); - // The real filterUniquePlugins should have deduplicated tools with same pluginKey const userToolCount = responseData.filter((tool) => tool.pluginKey === 'user-tool').length; expect(userToolCount).toBe(1); }); it('should use checkPluginAuth to verify authentication status', async () => { - // Add a plugin to availableTools that will be checked const mockPlugin = { name: 'Tool1', pluginKey: 'tool1', description: 'Tool 1', - // No authConfig means checkPluginAuth returns false }; require('~/app/clients/tools').availableTools.push(mockPlugin); - mockCache.get.mockResolvedValue(null); - // getCachedTools returns the tool definitions getCachedTools.mockResolvedValueOnce({ tool1: { type: 'function', @@ -242,7 +216,6 @@ describe('PluginController', () => { expect(Array.isArray(responseData)).toBe(true); const tool = responseData.find((t) => t.pluginKey === 'tool1'); expect(tool).toBeDefined(); - // The real checkPluginAuth returns false for plugins without authConfig, so authenticated property is not added expect(tool.authenticated).toBeUndefined(); }); @@ -256,15 +229,12 @@ describe('PluginController', () => { require('~/app/clients/tools').availableTools.push(mockToolkit); - // Mock toolkits to have a mapping require('~/app/clients/tools').toolkits.push({ name: 'Toolkit1', pluginKey: 'toolkit1', tools: ['toolkit1_function'], }); - mockCache.get.mockResolvedValue(null); - // getCachedTools returns the tool definitions getCachedTools.mockResolvedValueOnce({ toolkit1_function: { type: 'function', @@ -292,7 +262,7 @@ describe('PluginController', () => { describe('helper function integration', () => { it('should handle error cases gracefully', async () => { - mockCache.get.mockRejectedValue(new Error('Cache error')); + getCachedTools.mockRejectedValue(new Error('Cache error')); await getAvailableTools(mockReq, mockRes); @@ -302,17 +272,7 @@ describe('PluginController', () => { }); describe('edge cases with undefined/null values', () => { - it('should handle undefined cache gracefully', async () => { - getLogStores.mockReturnValue(undefined); - - await getAvailableTools(mockReq, mockRes); - - expect(mockRes.status).toHaveBeenCalledWith(500); - }); - - it('should handle null cachedTools and cachedUserTools', async () => { - mockCache.get.mockResolvedValue(null); - // getCachedTools returns empty object instead of null + it('should handle null cachedTools', async () => { getCachedTools.mockResolvedValueOnce({}); mockReq.config = { mcpConfig: null, @@ -321,51 +281,40 @@ describe('PluginController', () => { await getAvailableTools(mockReq, mockRes); - // Should handle null values gracefully expect(mockRes.status).toHaveBeenCalledWith(200); expect(mockRes.json).toHaveBeenCalledWith([]); }); it('should handle when getCachedTools returns undefined', async () => { - mockCache.get.mockResolvedValue(null); mockReq.config = { mcpConfig: null, paths: { structuredTools: '/mock/path' }, }; - // Mock getCachedTools to return undefined getCachedTools.mockReset(); getCachedTools.mockResolvedValueOnce(undefined); await getAvailableTools(mockReq, mockRes); - // Should handle undefined values gracefully expect(mockRes.status).toHaveBeenCalledWith(200); expect(mockRes.json).toHaveBeenCalledWith([]); }); it('should handle empty toolDefinitions object', async () => { - mockCache.get.mockResolvedValue(null); - // Reset getCachedTools to ensure clean state getCachedTools.mockReset(); getCachedTools.mockResolvedValue({}); - mockReq.config = {}; // No mcpConfig at all + mockReq.config = {}; - // Ensure no plugins are available require('~/app/clients/tools').availableTools.length = 0; await getAvailableTools(mockReq, mockRes); - // With empty tool definitions, no tools should be in the final output expect(mockRes.json).toHaveBeenCalledWith([]); }); it('should handle undefined filteredTools and includedTools', async () => { mockReq.config = {}; - mockCache.get.mockResolvedValue(null); - // Configure getAppConfig to return config with undefined properties - // The controller will use default values [] for filteredTools and includedTools getAppConfig.mockResolvedValueOnce({}); await getAvailablePluginsController(mockReq, mockRes); @@ -382,13 +331,8 @@ describe('PluginController', () => { toolkit: true, }; - // No need to mock app.locals anymore as it's not used - - // Add the toolkit to availableTools require('~/app/clients/tools').availableTools.push(mockToolkit); - mockCache.get.mockResolvedValue(null); - // getCachedTools returns empty object to avoid null reference error getCachedTools.mockResolvedValueOnce({}); mockReq.config = { mcpConfig: null, @@ -397,43 +341,32 @@ describe('PluginController', () => { await getAvailableTools(mockReq, mockRes); - // Should handle null toolDefinitions gracefully expect(mockRes.status).toHaveBeenCalledWith(200); }); - it('should handle undefined toolDefinitions when checking isToolDefined (traversaal_search bug)', async () => { - // This test reproduces the bug where toolDefinitions is undefined - // and accessing toolDefinitions[plugin.pluginKey] causes a TypeError + it('should handle undefined toolDefinitions when checking isToolDefined', async () => { const mockPlugin = { name: 'Traversaal Search', pluginKey: 'traversaal_search', description: 'Search plugin', }; - // Add the plugin to availableTools require('~/app/clients/tools').availableTools.push(mockPlugin); - mockCache.get.mockResolvedValue(null); - mockReq.config = { mcpConfig: null, paths: { structuredTools: '/mock/path' }, }; - // CRITICAL: getCachedTools returns undefined - // This is what causes the bug when trying to access toolDefinitions[plugin.pluginKey] getCachedTools.mockResolvedValueOnce(undefined); - // This should not throw an error with the optional chaining fix await getAvailableTools(mockReq, mockRes); - // Should handle undefined toolDefinitions gracefully and return empty array expect(mockRes.status).toHaveBeenCalledWith(200); expect(mockRes.json).toHaveBeenCalledWith([]); }); it('should re-initialize tools from appConfig when cache returns null', async () => { - // Setup: Initial state with tools in appConfig const mockAppTools = { tool1: { type: 'function', @@ -453,15 +386,12 @@ describe('PluginController', () => { }, }; - // Add matching plugins to availableTools require('~/app/clients/tools').availableTools.push( { name: 'Tool 1', pluginKey: 'tool1', description: 'Tool 1' }, { name: 'Tool 2', pluginKey: 'tool2', description: 'Tool 2' }, ); - // Simulate cache cleared state (returns null) - mockCache.get.mockResolvedValue(null); - getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared) + getCachedTools.mockResolvedValueOnce(null); mockReq.config = { filteredTools: [], @@ -469,15 +399,12 @@ describe('PluginController', () => { availableTools: mockAppTools, }; - // Mock setCachedTools to verify it's called to re-initialize const { setCachedTools } = require('~/server/services/Config'); await getAvailableTools(mockReq, mockRes); - // Should have re-initialized the cache with tools from appConfig expect(setCachedTools).toHaveBeenCalledWith(mockAppTools); - // Should still return tools successfully expect(mockRes.status).toHaveBeenCalledWith(200); const responseData = mockRes.json.mock.calls[0][0]; expect(responseData).toHaveLength(2); @@ -486,29 +413,22 @@ describe('PluginController', () => { }); it('should handle cache clear without appConfig.availableTools gracefully', async () => { - // Setup: appConfig without availableTools getAppConfig.mockResolvedValue({ filteredTools: [], includedTools: [], - // No availableTools property }); - // Clear availableTools array require('~/app/clients/tools').availableTools.length = 0; - // Cache returns null (cleared state) - mockCache.get.mockResolvedValue(null); - getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared) + getCachedTools.mockResolvedValueOnce(null); mockReq.config = { filteredTools: [], includedTools: [], - // No availableTools }; await getAvailableTools(mockReq, mockRes); - // Should handle gracefully without crashing expect(mockRes.status).toHaveBeenCalledWith(200); expect(mockRes.json).toHaveBeenCalledWith([]); }); diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 301c6d2f76..16b68968d9 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -26,7 +26,7 @@ const { getLogStores } = require('~/cache'); const db = require('~/models'); const getUserController = async (req, res) => { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); /** @type {IUser} */ const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user }; /** @@ -165,7 +165,7 @@ const deleteUserMcpServers = async (userId) => { }; const updateUserPluginsController = async (req, res) => { - const appConfig = await getAppConfig({ role: req.user?.role }); + const appConfig = await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId }); const { user } = req; const { pluginKey, action, auth, isEntityTool } = req.body; try { diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js index c2f13f7837..c959be6cf4 100644 --- a/api/server/controllers/agents/__tests__/openai.spec.js +++ b/api/server/controllers/agents/__tests__/openai.spec.js @@ -3,6 +3,7 @@ * Tests that recordCollectedUsage is called correctly for token spending */ +const mockProcessStream = jest.fn().mockResolvedValue(undefined); const mockSpendTokens = jest.fn().mockResolvedValue({}); const mockSpendStructuredTokens = jest.fn().mockResolvedValue({}); const mockRecordCollectedUsage = jest @@ -35,7 +36,7 @@ jest.mock('@librechat/agents', () => ({ jest.mock('@librechat/api', () => ({ writeSSE: jest.fn(), createRun: jest.fn().mockResolvedValue({ - processStream: jest.fn().mockResolvedValue(undefined), + processStream: mockProcessStream, }), createChunk: jest.fn().mockReturnValue({}), buildToolSet: jest.fn().mockReturnValue(new Set()), @@ -68,6 +69,7 @@ jest.mock('@librechat/api', () => ({ toolCalls: new Map(), usage: { promptTokens: 100, completionTokens: 50, reasoningTokens: 0 }, }), + resolveRecursionLimit: jest.fn().mockReturnValue(50), createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }), isChatCompletionValidationFailure: jest.fn().mockReturnValue(false), })); @@ -286,4 +288,36 @@ describe('OpenAIChatCompletionController', () => { ); }); }); + + describe('recursionLimit resolution', () => { + it('should pass resolveRecursionLimit result to processStream config', async () => { + const { resolveRecursionLimit } = require('@librechat/api'); + resolveRecursionLimit.mockReturnValueOnce(75); + + await OpenAIChatCompletionController(req, res); + + expect(mockProcessStream).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ recursionLimit: 75 }), + expect.anything(), + ); + }); + + it('should call resolveRecursionLimit with agentsEConfig and agent', async () => { + const { resolveRecursionLimit } = require('@librechat/api'); + const { getAgent } = require('~/models'); + const mockAgent = { id: 'agent-123', name: 'Test', recursion_limit: 200 }; + getAgent.mockResolvedValueOnce(mockAgent); + + req.config = { + endpoints: { + agents: { recursionLimit: 100, maxRecursionLimit: 150, allowedProviders: [] }, + }, + }; + + await OpenAIChatCompletionController(req, res); + + expect(resolveRecursionLimit).toHaveBeenCalledWith(req.config.endpoints.agents, mockAgent); + }); + }); }); diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 47a10165e3..3c1f91bd60 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -21,6 +21,7 @@ const { recordCollectedUsage, GenerationJobManager, getTransactionsConfig, + resolveRecursionLimit, createMemoryProcessor, loadAgent: loadAgentFn, createMultiAgentMapper, @@ -50,6 +51,7 @@ const { const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { createContextHandlers } = require('~/app/clients/prompts'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { getMCPServerTools } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getMCPManager } = require('~/config'); @@ -377,6 +379,9 @@ class AgentClient extends BaseClient { */ const ephemeralAgent = this.options.req.body.ephemeralAgent; const mcpManager = getMCPManager(); + + const configServers = await resolveConfigServers(this.options.req); + await Promise.all( allAgents.map(({ agent, agentId }) => applyContextToAgent({ @@ -384,6 +389,7 @@ class AgentClient extends BaseClient { agentId, logger, mcpManager, + configServers, sharedRunContext, ephemeralAgent: agentId === this.options.agent.id ? ephemeralAgent : undefined, }), @@ -728,7 +734,7 @@ class AgentClient extends BaseClient { }, user: createSafeUser(this.options.req.user), }, - recursionLimit: agentsEConfig?.recursionLimit ?? 50, + recursionLimit: resolveRecursionLimit(agentsEConfig, this.options.agent), signal: abortController.signal, streamMode: 'values', version: 'v2', @@ -776,17 +782,6 @@ class AgentClient extends BaseClient { agents.push(...this.agentConfigs.values()); } - if (agents[0].recursion_limit && typeof agents[0].recursion_limit === 'number') { - config.recursionLimit = agents[0].recursion_limit; - } - - if ( - agentsEConfig?.maxRecursionLimit && - config.recursionLimit > agentsEConfig?.maxRecursionLimit - ) { - config.recursionLimit = agentsEConfig?.maxRecursionLimit; - } - // TODO: needs to be added as part of AgentContext initialization // const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; // const noSystemMessages = noSystemModelRegex.some((regex) => diff --git a/api/server/controllers/agents/client.test.js b/api/server/controllers/agents/client.test.js index 41a806f66d..1595f652f7 100644 --- a/api/server/controllers/agents/client.test.js +++ b/api/server/controllers/agents/client.test.js @@ -22,6 +22,10 @@ jest.mock('~/server/services/Config', () => ({ getMCPServerTools: jest.fn(), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); + jest.mock('~/models', () => ({ getAgent: jest.fn(), getRoleByName: jest.fn(), @@ -1315,7 +1319,7 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with correct server names - expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2']); + expect(mockFormatInstructions).toHaveBeenCalledWith(['server1', 'server2'], {}); // Verify the instructions do NOT contain [object Promise] expect(client.options.agent.instructions).not.toContain('[object Promise]'); @@ -1355,10 +1359,10 @@ describe('AgentClient - titleConvo', () => { }); // Verify formatInstructionsForContext was called with ephemeral server names - expect(mockFormatInstructions).toHaveBeenCalledWith([ - 'ephemeral-server1', - 'ephemeral-server2', - ]); + expect(mockFormatInstructions).toHaveBeenCalledWith( + ['ephemeral-server1', 'ephemeral-server2'], + {}, + ); // Verify no [object Promise] in instructions expect(client.options.agent.instructions).not.toContain('[object Promise]'); diff --git a/api/server/controllers/agents/filterAuthorizedTools.spec.js b/api/server/controllers/agents/filterAuthorizedTools.spec.js index e215fdc1fc..e6b41aef16 100644 --- a/api/server/controllers/agents/filterAuthorizedTools.spec.js +++ b/api/server/controllers/agents/filterAuthorizedTools.spec.js @@ -22,6 +22,10 @@ jest.mock('~/config', () => ({ })), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); + jest.mock('~/server/services/Files/strategies', () => ({ getStrategyFunctions: jest.fn(), })); @@ -223,7 +227,27 @@ describe('MCP Tool Authorization', () => { availableTools, }); - expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id'); + expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id', undefined); + }); + + test('should pass configServers to getAllServerConfigs and allow config-override servers', async () => { + const configServers = { + 'config-override-server': { type: 'sse', url: 'https://override.example.com' }, + }; + mockGetAllServerConfigs.mockResolvedValue({ + 'config-override-server': configServers['config-override-server'], + }); + + const result = await filterAuthorizedTools({ + tools: [`tool${d}config-override-server`, `tool${d}unauthorizedServer`], + userId, + availableTools, + configServers, + }); + + expect(mockGetAllServerConfigs).toHaveBeenCalledWith(userId, configServers); + expect(result).toContain(`tool${d}config-override-server`); + expect(result).not.toContain(`tool${d}unauthorizedServer`); }); test('should only call getAllServerConfigs once even with multiple MCP tools', async () => { diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js index b649058806..9fa3af82c3 100644 --- a/api/server/controllers/agents/openai.js +++ b/api/server/controllers/agents/openai.js @@ -15,6 +15,7 @@ const { createErrorResponse, recordCollectedUsage, getTransactionsConfig, + resolveRecursionLimit, createToolExecuteHandler, buildNonStreamingResponse, createOpenAIStreamTracker, @@ -194,10 +195,8 @@ const OpenAIChatCompletionController = async (req, res) => { const conversationId = request.conversation_id ?? nanoid(); const parentMessageId = request.parent_message_id ?? null; - // Build allowed providers set - const allowedProviders = new Set( - appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, - ); + const agentsEConfig = appConfig?.endpoints?.[EModelEndpoint.agents]; + const allowedProviders = new Set(agentsEConfig?.allowedProviders); // Create tool loader const loadTools = createToolLoader(abortController.signal); @@ -491,7 +490,6 @@ const OpenAIChatCompletionController = async (req, res) => { throw new Error('Failed to create agent run'); } - // Process the stream const config = { runName: 'AgentRun', configurable: { @@ -504,6 +502,7 @@ const OpenAIChatCompletionController = async (req, res) => { }, ...(userMCPAuthMap != null && { userMCPAuthMap }), }, + recursionLimit: resolveRecursionLimit(agentsEConfig, agent), signal: abortController.signal, streamMode: 'values', version: 'v2', diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 17985f97ce..e365b232e4 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -38,6 +38,7 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const { getFileStrategy } = require('~/server/utils/getFileStrategy'); const { filterFile } = require('~/server/services/Files/process'); const { getCachedTools } = require('~/server/services/Config'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { getMCPServersRegistry } = require('~/config'); const { getLogStores } = require('~/cache'); const db = require('~/models'); @@ -101,9 +102,16 @@ const validateEdgeAgentAccess = async (edges, userId, userRole) => { * @param {string} params.userId - Requesting user ID for MCP server access check * @param {Record} params.availableTools - Global non-MCP tool cache * @param {string[]} [params.existingTools] - Tools already persisted on the agent document + * @param {Record} [params.configServers] - Config-source MCP servers resolved from appConfig overrides * @returns {Promise} Only the authorized subset of tools */ -const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => { +const filterAuthorizedTools = async ({ + tools, + userId, + availableTools, + existingTools, + configServers, +}) => { const filteredTools = []; let mcpServerConfigs; let registryUnavailable = false; @@ -121,7 +129,8 @@ const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTo if (mcpServerConfigs === undefined) { try { - mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {}; + mcpServerConfigs = + (await getMCPServersRegistry().getAllServerConfigs(userId, configServers)) ?? {}; } catch (e) { logger.warn( '[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools', @@ -192,8 +201,17 @@ const createAgentHandler = async (req, res) => { agentData.author = userId; agentData.tools = []; - const availableTools = (await getCachedTools()) ?? {}; - agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools }); + const hasMCPTools = tools.some((t) => t?.includes(Constants.mcp_delimiter)); + const [availableTools, configServers] = await Promise.all([ + getCachedTools().then((t) => t ?? {}), + hasMCPTools ? resolveConfigServers(req) : Promise.resolve(undefined), + ]); + agentData.tools = await filterAuthorizedTools({ + tools, + userId, + availableTools, + configServers, + }); const agent = await db.createAgent(agentData); @@ -376,11 +394,15 @@ const updateAgentHandler = async (req, res) => { ); if (newMCPTools.length > 0) { - const availableTools = (await getCachedTools()) ?? {}; + const [availableTools, configServers] = await Promise.all([ + getCachedTools().then((t) => t ?? {}), + resolveConfigServers(req), + ]); const approvedNew = await filterAuthorizedTools({ tools: newMCPTools, userId: req.user.id, availableTools, + configServers, }); const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t))); if (rejectedSet.size > 0) { @@ -533,12 +555,16 @@ const duplicateAgentHandler = async (req, res) => { newAgentData.actions = agentActions; if (newAgentData.tools?.length) { - const availableTools = (await getCachedTools()) ?? {}; + const [availableTools, configServers] = await Promise.all([ + getCachedTools().then((t) => t ?? {}), + resolveConfigServers(req), + ]); newAgentData.tools = await filterAuthorizedTools({ tools: newAgentData.tools, userId, availableTools, existingTools: newAgentData.tools, + configServers, }); } @@ -873,12 +899,16 @@ const revertAgentVersionHandler = async (req, res) => { let updatedAgent = await db.revertAgentVersion({ id }, version_index); if (updatedAgent.tools?.length) { - const availableTools = (await getCachedTools()) ?? {}; + const [availableTools, configServers] = await Promise.all([ + getCachedTools().then((t) => t ?? {}), + resolveConfigServers(req), + ]); const filteredTools = await filterAuthorizedTools({ tools: updatedAgent.tools, userId: req.user.id, availableTools, existingTools: updatedAgent.tools, + configServers, }); if (filteredTools.length !== updatedAgent.tools.length) { updatedAgent = await db.updateAgent( diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index e4a20c2a5e..631831e617 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -40,6 +40,7 @@ const { sendResponse } = require('~/server/middleware/error'); const { createAutoRefillTransaction, findBalanceByUser, + upsertBalanceFields, getTransactions, getMultiplier, getConvo, @@ -296,7 +297,14 @@ const chatV1 = async (req, res) => { amount: promptTokens, }, }, - { findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation }, + { + findBalanceByUser, + getMultiplier, + createAutoRefillTransaction, + logViolation, + balanceConfig, + upsertBalanceFields, + }, ); }; diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 559d9d8953..237af1b11a 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -37,6 +37,7 @@ const { getMultiplier, getTransactions, findBalanceByUser, + upsertBalanceFields, createAutoRefillTransaction, } = require('~/models'); const { logViolation, getLogStores } = require('~/cache'); @@ -169,7 +170,14 @@ const chatV2 = async (req, res) => { amount: promptTokens, }, }, - { findBalanceByUser, getMultiplier, createAutoRefillTransaction, logViolation }, + { + findBalanceByUser, + getMultiplier, + createAutoRefillTransaction, + logViolation, + balanceConfig, + upsertBalanceFields, + }, ); }; diff --git a/api/server/controllers/auth/oauth.js b/api/server/controllers/auth/oauth.js index 80c2ced002..917e9e2bef 100644 --- a/api/server/controllers/auth/oauth.js +++ b/api/server/controllers/auth/oauth.js @@ -47,9 +47,15 @@ function createOAuthHandler(redirectUri = domains.client) { const refreshToken = req.user.tokenset?.refresh_token || req.user.federatedTokens?.refresh_token; - const exchangeCode = await generateAdminExchangeCode(cache, req.user, token, refreshToken); - const callbackUrl = new URL(redirectUri); + const exchangeCode = await generateAdminExchangeCode( + cache, + req.user, + token, + refreshToken, + callbackUrl.origin, + req.pkceChallenge, + ); callbackUrl.searchParams.set('code', exchangeCode); logger.info(`[OAuth] Admin panel redirect with exchange code for user: ${req.user.email}`); return res.redirect(callbackUrl.toString()); diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 729f01da9d..e31bb93bc6 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -14,6 +14,7 @@ const { isMCPInspectionFailedError, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('~/server/services/MCP'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); const { getMCPManager, getMCPServersRegistry } = require('~/config'); @@ -57,7 +58,7 @@ function handleMCPError(error, res) { } /** - * Get all MCP tools available to the user + * Get all MCP tools available to the user. */ const getMCPTools = async (req, res) => { try { @@ -67,10 +68,10 @@ const getMCPTools = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - const configuredServers = mcpConfig ? Object.keys(mcpConfig) : []; + const mcpConfig = await resolveAllMcpConfigs(userId, req.user); + const configuredServers = Object.keys(mcpConfig); - if (!mcpConfig || Object.keys(mcpConfig).length == 0) { + if (!configuredServers.length) { return res.status(200).json({ servers: {} }); } @@ -115,14 +116,11 @@ const getMCPTools = async (req, res) => { try { const serverTools = serverToolsMap.get(serverName); - // Get server config once const serverConfig = mcpConfig[serverName]; - const rawServerConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); - // Initialize server object with all server-level data const server = { name: serverName, - icon: rawServerConfig?.iconPath || '', + icon: serverConfig?.iconPath || '', authenticated: true, authConfig: [], tools: [], @@ -183,7 +181,7 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); + const serverConfigs = await resolveAllMcpConfigs(userId, req.user); return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); @@ -237,7 +235,12 @@ const getMCPServerById = async (req, res) => { if (!serverName) { return res.status(400).json({ message: 'Server name is required' }); } - const parsedConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); + const configServers = await resolveConfigServers(req); + const parsedConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); if (!parsedConfig) { return res.status(404).json({ message: 'MCP server not found' }); diff --git a/api/server/experimental.js b/api/server/experimental.js index 8982b69afb..ff023b4504 100644 --- a/api/server/experimental.js +++ b/api/server/experimental.js @@ -19,6 +19,7 @@ const { performStartupChecks, handleJsonParseError, initializeFileStorage, + preAuthTenantMiddleware, } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); @@ -31,6 +32,7 @@ const initializeMCPs = require('./services/initializeMCPs'); const configureSocialLogins = require('./socialLogins'); const { getAppConfig } = require('./services/Config'); const staticCache = require('./utils/staticCache'); +const optionalJwtAuth = require('./middleware/optionalJwtAuth'); const noIndex = require('./middleware/noIndex'); const routes = require('./routes'); @@ -312,7 +314,7 @@ if (cluster.isMaster) { app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); - app.use('/api/config', routes.config); + app.use('/api/config', preAuthTenantMiddleware, optionalJwtAuth, routes.config); app.use('/api/assistants', routes.assistants); app.use('/api/files', await routes.files.initialize()); app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute); diff --git a/api/server/index.js b/api/server/index.js index ba376ab335..d26a203c0a 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -8,8 +8,8 @@ const express = require('express'); const passport = require('passport'); const compression = require('compression'); const cookieParser = require('cookie-parser'); -const { logger } = require('@librechat/data-schemas'); const mongoSanitize = require('express-mongo-sanitize'); +const { logger, runAsSystem } = require('@librechat/data-schemas'); const { isEnabled, apiNotFound, @@ -21,6 +21,7 @@ const { createStreamServices, initializeFileStorage, updateInterfacePermissions, + preAuthTenantMiddleware, } = require('@librechat/api'); const { connectDb, indexSync } = require('~/db'); const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager'); @@ -33,6 +34,7 @@ const initializeMCPs = require('./services/initializeMCPs'); const configureSocialLogins = require('./socialLogins'); const { getAppConfig } = require('./services/Config'); const staticCache = require('./utils/staticCache'); +const optionalJwtAuth = require('./middleware/optionalJwtAuth'); const noIndex = require('./middleware/noIndex'); const routes = require('./routes'); @@ -59,11 +61,20 @@ const startServer = async () => { app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); - await seedDatabase(); - const appConfig = await getAppConfig(); + if (isEnabled(process.env.TENANT_ISOLATION_STRICT)) { + logger.warn( + '[Security] TENANT_ISOLATION_STRICT is active. Ensure your reverse proxy strips or sets ' + + 'the X-Tenant-Id header — untrusted clients must not be able to set it directly.', + ); + } + + await runAsSystem(seedDatabase); + const appConfig = await getAppConfig({ baseOnly: true }); initializeFileStorage(appConfig); - await performStartupChecks(appConfig); - await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }); + await runAsSystem(async () => { + await performStartupChecks(appConfig); + await updateInterfacePermissions({ appConfig, getRoleByName, updateAccessPermissions }); + }); const indexPath = path.join(appConfig.paths.dist, 'index.html'); let indexHTML = fs.readFileSync(indexPath, 'utf8'); @@ -137,10 +148,17 @@ const startServer = async () => { /* Per-request capability cache — must be registered before any route that calls hasCapability */ app.use(capabilityContextMiddleware); - app.use('/oauth', routes.oauth); + /* Pre-auth tenant context for unauthenticated routes that need tenant scoping. + * The reverse proxy / auth gateway sets `X-Tenant-Id` header for multi-tenant deployments. */ + app.use('/oauth', preAuthTenantMiddleware, routes.oauth); /* API Endpoints */ - app.use('/api/auth', routes.auth); + app.use('/api/auth', preAuthTenantMiddleware, routes.auth); app.use('/api/admin', routes.adminAuth); + app.use('/api/admin/config', routes.adminConfig); + app.use('/api/admin/grants', routes.adminGrants); + app.use('/api/admin/groups', routes.adminGroups); + app.use('/api/admin/roles', routes.adminRoles); + app.use('/api/admin/users', routes.adminUsers); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/api-keys', routes.apiKeys); @@ -154,11 +172,11 @@ const startServer = async () => { app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); app.use('/api/models', routes.models); - app.use('/api/config', routes.config); + app.use('/api/config', preAuthTenantMiddleware, optionalJwtAuth, routes.config); app.use('/api/assistants', routes.assistants); app.use('/api/files', await routes.files.initialize()); app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute); - app.use('/api/share', routes.share); + app.use('/api/share', preAuthTenantMiddleware, routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); @@ -204,8 +222,10 @@ const startServer = async () => { logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`); } - await initializeMCPs(); - await initializeOAuthReconnectManager(); + await runAsSystem(async () => { + await initializeMCPs(); + await initializeOAuthReconnectManager(); + }); await checkMigrations(); // Configure stream services (auto-detects Redis from USE_REDIS env var) diff --git a/api/server/middleware/__tests__/requireJwtAuth.spec.js b/api/server/middleware/__tests__/requireJwtAuth.spec.js new file mode 100644 index 0000000000..bc288e5dab --- /dev/null +++ b/api/server/middleware/__tests__/requireJwtAuth.spec.js @@ -0,0 +1,116 @@ +/** + * Integration test: verifies that requireJwtAuth chains tenantContextMiddleware + * after successful passport authentication, so ALS tenant context is set for + * all downstream middleware and route handlers. + * + * requireJwtAuth must chain tenantContextMiddleware after passport populates + * req.user (not at global app.use() scope where req.user is undefined). + * If the chaining is removed, these tests fail. + */ + +const { getTenantId } = require('@librechat/data-schemas'); + +// ── Mocks ────────────────────────────────────────────────────────────── + +let mockPassportError = null; + +jest.mock('passport', () => ({ + authenticate: jest.fn(() => { + return (req, _res, done) => { + if (mockPassportError) { + return done(mockPassportError); + } + if (req._mockUser) { + req.user = req._mockUser; + } + done(); + }; + }), +})); + +// Mock @librechat/api — the real tenantContextMiddleware is TS and cannot be +// required directly from CJS tests. This thin wrapper mirrors the real logic +// (read req.user.tenantId, call tenantStorage.run) using the same data-schemas +// primitives. The real implementation is covered by packages/api tenant.spec.ts. +jest.mock('@librechat/api', () => { + const { tenantStorage } = require('@librechat/data-schemas'); + return { + isEnabled: jest.fn(() => false), + tenantContextMiddleware: (req, res, next) => { + const tenantId = req.user?.tenantId; + if (!tenantId) { + return next(); + } + return tenantStorage.run({ tenantId }, async () => next()); + }, + }; +}); + +// ── Helpers ───────────────────────────────────────────────────────────── + +const requireJwtAuth = require('../requireJwtAuth'); + +function mockReq(user) { + return { headers: {}, _mockUser: user }; +} + +function mockRes() { + return { status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis() }; +} + +/** Runs requireJwtAuth and returns the tenantId observed inside next(). */ +function runAuth(user) { + return new Promise((resolve) => { + const req = mockReq(user); + const res = mockRes(); + requireJwtAuth(req, res, () => { + resolve(getTenantId()); + }); + }); +} + +// ── Tests ────────────────────────────────────────────────────────────── + +describe('requireJwtAuth tenant context chaining', () => { + afterEach(() => { + mockPassportError = null; + }); + + it('forwards passport errors to next() without entering tenant middleware', async () => { + mockPassportError = new Error('JWT signature invalid'); + const req = mockReq(undefined); + const res = mockRes(); + const err = await new Promise((resolve) => { + requireJwtAuth(req, res, (e) => resolve(e)); + }); + expect(err).toBeInstanceOf(Error); + expect(err.message).toBe('JWT signature invalid'); + expect(getTenantId()).toBeUndefined(); + }); + + it('sets ALS tenant context after passport auth succeeds', async () => { + const tenantId = await runAuth({ tenantId: 'tenant-abc', role: 'user' }); + expect(tenantId).toBe('tenant-abc'); + }); + + it('ALS tenant context is NOT set when user has no tenantId', async () => { + const tenantId = await runAuth({ role: 'user' }); + expect(tenantId).toBeUndefined(); + }); + + it('ALS tenant context is NOT set when user is undefined', async () => { + const tenantId = await runAuth(undefined); + expect(tenantId).toBeUndefined(); + }); + + it('concurrent requests get isolated tenant contexts', async () => { + const results = await Promise.all( + ['tenant-1', 'tenant-2', 'tenant-3'].map((tid) => runAuth({ tenantId: tid, role: 'user' })), + ); + expect(results).toEqual(['tenant-1', 'tenant-2', 'tenant-3']); + }); + + it('ALS context is not set at top-level scope (outside any request)', () => { + expect(getTenantId()).toBeUndefined(); + }); +}); diff --git a/api/server/middleware/__tests__/validateModel.spec.js b/api/server/middleware/__tests__/validateModel.spec.js new file mode 100644 index 0000000000..634baeed11 --- /dev/null +++ b/api/server/middleware/__tests__/validateModel.spec.js @@ -0,0 +1,178 @@ +const { ViolationTypes } = require('librechat-data-provider'); + +jest.mock('@librechat/api', () => ({ + handleError: jest.fn(), +})); + +jest.mock('~/server/controllers/ModelController', () => ({ + getModelsConfig: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + getEndpointsConfig: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + logViolation: jest.fn(), +})); + +const { handleError } = require('@librechat/api'); +const { getModelsConfig } = require('~/server/controllers/ModelController'); +const { getEndpointsConfig } = require('~/server/services/Config'); +const { logViolation } = require('~/cache'); +const validateModel = require('../validateModel'); + +describe('validateModel', () => { + let req, res, next; + + beforeEach(() => { + jest.clearAllMocks(); + req = { body: { model: 'gpt-4o', endpoint: 'openAI' } }; + res = {}; + next = jest.fn(); + getEndpointsConfig.mockResolvedValue({ + openAI: { userProvide: false }, + }); + getModelsConfig.mockResolvedValue({ + openAI: ['gpt-4o', 'gpt-4o-mini'], + }); + }); + + describe('format validation', () => { + it('rejects missing model', async () => { + req.body.model = undefined; + await validateModel(req, res, next); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Model not provided' }); + expect(next).not.toHaveBeenCalled(); + }); + + it('rejects non-string model', async () => { + req.body.model = 12345; + await validateModel(req, res, next); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Model not provided' }); + expect(next).not.toHaveBeenCalled(); + }); + + it('rejects model exceeding 256 chars', async () => { + req.body.model = 'a'.repeat(257); + await validateModel(req, res, next); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' }); + }); + + it('rejects model with leading special character', async () => { + req.body.model = '.bad-model'; + await validateModel(req, res, next); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' }); + }); + + it('rejects model with script injection', async () => { + req.body.model = ''; + await validateModel(req, res, next); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' }); + }); + + it('trims whitespace before validation', async () => { + req.body.model = ' gpt-4o '; + getModelsConfig.mockResolvedValue({ openAI: ['gpt-4o'] }); + await validateModel(req, res, next); + expect(next).toHaveBeenCalled(); + expect(handleError).not.toHaveBeenCalled(); + }); + + it('rejects model with spaces in the middle', async () => { + req.body.model = 'gpt 4o'; + await validateModel(req, res, next); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Invalid model identifier' }); + }); + + it('accepts standard model IDs', async () => { + const validModels = [ + 'gpt-4o', + 'claude-3-5-sonnet-20241022', + 'us.amazon.nova-pro-v1:0', + 'qwen/qwen3.6-plus-preview:free', + 'Meta-Llama-3-8B-Instruct-4bit', + ]; + for (const model of validModels) { + jest.clearAllMocks(); + req.body.model = model; + getEndpointsConfig.mockResolvedValue({ openAI: { userProvide: false } }); + getModelsConfig.mockResolvedValue({ openAI: [model] }); + next.mockClear(); + + await validateModel(req, res, next); + expect(next).toHaveBeenCalled(); + expect(handleError).not.toHaveBeenCalled(); + } + }); + }); + + describe('userProvide early-return', () => { + it('calls next() immediately for userProvide endpoints without checking model list', async () => { + getEndpointsConfig.mockResolvedValue({ + openAI: { userProvide: true }, + }); + req.body.model = 'any-model-from-user-key'; + + await validateModel(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(getModelsConfig).not.toHaveBeenCalled(); + }); + + it('does not call getModelsConfig for userProvide endpoints', async () => { + getEndpointsConfig.mockResolvedValue({ + CustomEndpoint: { userProvide: true }, + }); + req.body = { model: 'custom-model', endpoint: 'CustomEndpoint' }; + + await validateModel(req, res, next); + + expect(getModelsConfig).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalled(); + }); + }); + + describe('system endpoint list validation', () => { + it('rejects a model not in the available list', async () => { + req.body.model = 'not-in-list'; + + await validateModel(req, res, next); + + expect(logViolation).toHaveBeenCalledWith( + req, + res, + ViolationTypes.ILLEGAL_MODEL_REQUEST, + expect.any(Object), + expect.anything(), + ); + expect(handleError).toHaveBeenCalledWith(res, { text: 'Illegal model request' }); + expect(next).not.toHaveBeenCalled(); + }); + + it('accepts a model in the available list', async () => { + req.body.model = 'gpt-4o'; + + await validateModel(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(handleError).not.toHaveBeenCalled(); + }); + + it('rejects when endpoint has no models loaded', async () => { + getModelsConfig.mockResolvedValue({ openAI: undefined }); + + await validateModel(req, res, next); + + expect(handleError).toHaveBeenCalledWith(res, { text: 'Endpoint models not loaded' }); + }); + + it('rejects when modelsConfig is null', async () => { + getModelsConfig.mockResolvedValue(null); + + await validateModel(req, res, next); + + expect(handleError).toHaveBeenCalledWith(res, { text: 'Models not loaded' }); + }); + }); +}); diff --git a/api/server/middleware/checkDomainAllowed.js b/api/server/middleware/checkDomainAllowed.js index 754eb9c127..f7a3f00e68 100644 --- a/api/server/middleware/checkDomainAllowed.js +++ b/api/server/middleware/checkDomainAllowed.js @@ -18,6 +18,7 @@ const checkDomainAllowed = async (req, res, next) => { const email = req?.user?.email; const appConfig = await getAppConfig({ role: req?.user?.role, + tenantId: req?.user?.tenantId, }); if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { diff --git a/api/server/middleware/config/app.js b/api/server/middleware/config/app.js index bca3c8f71d..fb5f89b229 100644 --- a/api/server/middleware/config/app.js +++ b/api/server/middleware/config/app.js @@ -4,7 +4,9 @@ const { getAppConfig } = require('~/server/services/Config'); const configMiddleware = async (req, res, next) => { try { const userRole = req.user?.role; - req.config = await getAppConfig({ role: userRole }); + const userId = req.user?.id; + const tenantId = req.user?.tenantId; + req.config = await getAppConfig({ role: userRole, userId, tenantId }); next(); } catch (error) { diff --git a/api/server/middleware/optionalJwtAuth.js b/api/server/middleware/optionalJwtAuth.js index 2f59fdda4a..d46478d36e 100644 --- a/api/server/middleware/optionalJwtAuth.js +++ b/api/server/middleware/optionalJwtAuth.js @@ -1,9 +1,10 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); // This middleware does not require authentication, -// but if the user is authenticated, it will set the user object. +// but if the user is authenticated, it will set the user object +// and establish tenant ALS context. const optionalJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; @@ -13,6 +14,7 @@ const optionalJwtAuth = (req, res, next) => { } if (user) { req.user = user; + return tenantContextMiddleware(req, res, next); } next(); }; diff --git a/api/server/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js index 16b107aefc..b13e991b23 100644 --- a/api/server/middleware/requireJwtAuth.js +++ b/api/server/middleware/requireJwtAuth.js @@ -1,20 +1,29 @@ const cookies = require('cookie'); const passport = require('passport'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); /** - * Custom Middleware to handle JWT authentication, with support for OpenID token reuse - * Switches between JWT and OpenID authentication based on cookies and environment settings + * Custom Middleware to handle JWT authentication, with support for OpenID token reuse. + * Switches between JWT and OpenID authentication based on cookies and environment settings. + * + * After successful authentication (req.user populated), automatically chains into + * `tenantContextMiddleware` to propagate `req.user.tenantId` into AsyncLocalStorage + * for downstream Mongoose tenant isolation. */ const requireJwtAuth = (req, res, next) => { const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? cookies.parse(cookieHeader).token_provider : null; - if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) { - return passport.authenticate('openidJwt', { session: false })(req, res, next); - } + const strategy = + tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS) ? 'openidJwt' : 'jwt'; - return passport.authenticate('jwt', { session: false })(req, res, next); + passport.authenticate(strategy, { session: false })(req, res, (err) => { + if (err) { + return next(err); + } + // req.user is now populated by passport — set up tenant ALS context + tenantContextMiddleware(req, res, next); + }); }; module.exports = requireJwtAuth; diff --git a/api/server/middleware/validateModel.js b/api/server/middleware/validateModel.js index 40f6e67bfb..71a931f0d1 100644 --- a/api/server/middleware/validateModel.js +++ b/api/server/middleware/validateModel.js @@ -1,7 +1,12 @@ const { handleError } = require('@librechat/api'); const { ViolationTypes } = require('librechat-data-provider'); const { getModelsConfig } = require('~/server/controllers/ModelController'); +const { getEndpointsConfig } = require('~/server/services/Config'); const { logViolation } = require('~/cache'); + +const MAX_MODEL_STRING_LENGTH = 256; +const MODEL_PATTERN = /^[a-zA-Z0-9][a-zA-Z0-9_.:/@+-]*$/; + /** * Validates the model of the request. * @@ -11,11 +16,27 @@ const { logViolation } = require('~/cache'); * @param {Function} next - The Express next function. */ const validateModel = async (req, res, next) => { - const { model, endpoint } = req.body; - if (!model) { + const { endpoint } = req.body; + const rawModel = req.body.model; + + if (!rawModel || typeof rawModel !== 'string') { return handleError(res, { text: 'Model not provided' }); } + const model = rawModel.trim(); + if (!model || model.length > MAX_MODEL_STRING_LENGTH || !MODEL_PATTERN.test(model)) { + return handleError(res, { text: 'Invalid model identifier' }); + } + + req.body.model = model; + + const endpointsConfig = await getEndpointsConfig(req); + const endpointConfig = endpointsConfig?.[endpoint]; + + if (endpointConfig?.userProvide) { + return next(); + } + const modelsConfig = await getModelsConfig(req); if (!modelsConfig) { diff --git a/api/server/routes/__tests__/config.spec.js b/api/server/routes/__tests__/config.spec.js index 7d7d3ea13a..54315a7798 100644 --- a/api/server/routes/__tests__/config.spec.js +++ b/api/server/routes/__tests__/config.spec.js @@ -1,25 +1,73 @@ jest.mock('~/cache/getLogStores'); + +const mockGetAppConfig = jest.fn(); +jest.mock('~/server/services/Config/app', () => ({ + getAppConfig: (...args) => mockGetAppConfig(...args), +})); + +jest.mock('~/server/services/Config/ldap', () => ({ + getLdapConfig: jest.fn(() => null), +})); + +const mockGetTenantId = jest.fn(() => undefined); +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + getTenantId: (...args) => mockGetTenantId(...args), +})); + const request = require('supertest'); const express = require('express'); const configRoute = require('../config'); -// file deepcode ignore UseCsurfForExpress/test: test -const app = express(); -app.disable('x-powered-by'); -app.use('/api/config', configRoute); + +function createApp(user) { + const app = express(); + app.disable('x-powered-by'); + if (user) { + app.use((req, _res, next) => { + req.user = user; + next(); + }); + } + app.use('/api/config', configRoute); + return app; +} + +const baseAppConfig = { + registration: { socialLogins: ['google', 'github'] }, + interfaceConfig: { + privacyPolicy: { externalUrl: 'https://example.com/privacy' }, + termsOfService: { externalUrl: 'https://example.com/tos' }, + modelSelect: true, + }, + turnstileConfig: { siteKey: 'test-key' }, + modelSpecs: { list: [{ name: 'test-spec' }] }, + webSearch: { searchProvider: 'tavily' }, +}; + +const mockUser = { + id: 'user123', + role: 'USER', + tenantId: undefined, +}; afterEach(() => { + jest.resetAllMocks(); delete process.env.APP_TITLE; + delete process.env.CHECK_BALANCE; + delete process.env.START_BALANCE; + delete process.env.SANDPACK_BUNDLER_URL; + delete process.env.SANDPACK_STATIC_BUNDLER_URL; + delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + delete process.env.ALLOW_REGISTRATION; + delete process.env.ALLOW_SOCIAL_LOGIN; + delete process.env.ALLOW_PASSWORD_RESET; + delete process.env.DOMAIN_SERVER; delete process.env.GOOGLE_CLIENT_ID; delete process.env.GOOGLE_CLIENT_SECRET; - delete process.env.FACEBOOK_CLIENT_ID; - delete process.env.FACEBOOK_CLIENT_SECRET; delete process.env.OPENID_CLIENT_ID; delete process.env.OPENID_CLIENT_SECRET; delete process.env.OPENID_ISSUER; delete process.env.OPENID_SESSION_SECRET; - delete process.env.OPENID_BUTTON_LABEL; - delete process.env.OPENID_AUTO_REDIRECT; - delete process.env.OPENID_AUTH_URL; delete process.env.GITHUB_CLIENT_ID; delete process.env.GITHUB_CLIENT_SECRET; delete process.env.DISCORD_CLIENT_ID; @@ -28,78 +76,215 @@ afterEach(() => { delete process.env.SAML_ISSUER; delete process.env.SAML_CERT; delete process.env.SAML_SESSION_SECRET; - delete process.env.SAML_BUTTON_LABEL; - delete process.env.SAML_IMAGE_URL; - delete process.env.DOMAIN_SERVER; - delete process.env.ALLOW_REGISTRATION; - delete process.env.ALLOW_SOCIAL_LOGIN; - delete process.env.ALLOW_PASSWORD_RESET; - delete process.env.LDAP_URL; - delete process.env.LDAP_BIND_DN; - delete process.env.LDAP_BIND_CREDENTIALS; - delete process.env.LDAP_USER_SEARCH_BASE; - delete process.env.LDAP_SEARCH_FILTER; }); -//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why. +describe('GET /api/config', () => { + describe('unauthenticated (no req.user)', () => { + it('should call getAppConfig with baseOnly when no tenant context', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetTenantId.mockReturnValue(undefined); + const app = createApp(null); -describe.skip('GET /', () => { - it('should return 200 and the correct body', async () => { - process.env.APP_TITLE = 'Test Title'; - process.env.GOOGLE_CLIENT_ID = 'Test Google Client Id'; - process.env.GOOGLE_CLIENT_SECRET = 'Test Google Client Secret'; - process.env.FACEBOOK_CLIENT_ID = 'Test Facebook Client Id'; - process.env.FACEBOOK_CLIENT_SECRET = 'Test Facebook Client Secret'; - process.env.OPENID_CLIENT_ID = 'Test OpenID Id'; - process.env.OPENID_CLIENT_SECRET = 'Test OpenID Secret'; - process.env.OPENID_ISSUER = 'Test OpenID Issuer'; - process.env.OPENID_SESSION_SECRET = 'Test Secret'; - process.env.OPENID_BUTTON_LABEL = 'Test OpenID'; - process.env.OPENID_AUTH_URL = 'http://test-server.com'; - process.env.GITHUB_CLIENT_ID = 'Test Github client Id'; - process.env.GITHUB_CLIENT_SECRET = 'Test Github client Secret'; - process.env.DISCORD_CLIENT_ID = 'Test Discord client Id'; - process.env.DISCORD_CLIENT_SECRET = 'Test Discord client Secret'; - process.env.SAML_ENTRY_POINT = 'http://test-server.com'; - process.env.SAML_ISSUER = 'Test SAML Issuer'; - process.env.SAML_CERT = 'saml.pem'; - process.env.SAML_SESSION_SECRET = 'Test Secret'; - process.env.SAML_BUTTON_LABEL = 'Test SAML'; - process.env.SAML_IMAGE_URL = 'http://test-server.com'; - process.env.DOMAIN_SERVER = 'http://test-server.com'; - process.env.ALLOW_REGISTRATION = 'true'; - process.env.ALLOW_SOCIAL_LOGIN = 'true'; - process.env.ALLOW_PASSWORD_RESET = 'true'; - process.env.LDAP_URL = 'Test LDAP URL'; - process.env.LDAP_BIND_DN = 'Test LDAP Bind DN'; - process.env.LDAP_BIND_CREDENTIALS = 'Test LDAP Bind Credentials'; - process.env.LDAP_USER_SEARCH_BASE = 'Test LDAP User Search Base'; - process.env.LDAP_SEARCH_FILTER = 'Test LDAP Search Filter'; + await request(app).get('/api/config'); - const response = await request(app).get('/'); + expect(mockGetAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); - expect(response.statusCode).toBe(200); - expect(response.body).toEqual({ - appTitle: 'Test Title', - socialLogins: ['google', 'facebook', 'openid', 'github', 'discord', 'saml'], - discordLoginEnabled: true, - facebookLoginEnabled: true, - githubLoginEnabled: true, - googleLoginEnabled: true, - openidLoginEnabled: true, - openidLabel: 'Test OpenID', - openidImageUrl: 'http://test-server.com', - samlLoginEnabled: true, - samlLabel: 'Test SAML', - samlImageUrl: 'http://test-server.com', - ldap: { - enabled: true, - }, - serverDomain: 'http://test-server.com', - emailLoginEnabled: 'true', - registrationEnabled: 'true', - passwordResetEnabled: 'true', - socialLoginEnabled: 'true', + it('should call getAppConfig with tenantId when tenant context is present', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetTenantId.mockReturnValue('tenant-abc'); + const app = createApp(null); + + await request(app).get('/api/config'); + + expect(mockGetAppConfig).toHaveBeenCalledWith({ tenantId: 'tenant-abc' }); + }); + + it('should map tenant-scoped config fields in unauthenticated response', async () => { + const tenantConfig = { + ...baseAppConfig, + registration: { socialLogins: ['saml'] }, + turnstileConfig: { siteKey: 'tenant-key' }, + }; + mockGetAppConfig.mockResolvedValue(tenantConfig); + mockGetTenantId.mockReturnValue('tenant-abc'); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.statusCode).toBe(200); + expect(response.body.socialLogins).toEqual(['saml']); + expect(response.body.turnstile).toEqual({ siteKey: 'tenant-key' }); + expect(response.body).not.toHaveProperty('modelSpecs'); + }); + + it('should return minimal payload without authenticated-only fields', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.statusCode).toBe(200); + expect(response.body).not.toHaveProperty('modelSpecs'); + expect(response.body).not.toHaveProperty('balance'); + expect(response.body).not.toHaveProperty('webSearch'); + expect(response.body).not.toHaveProperty('bundlerURL'); + expect(response.body).not.toHaveProperty('staticBundlerURL'); + expect(response.body).not.toHaveProperty('sharePointFilePickerEnabled'); + expect(response.body).not.toHaveProperty('conversationImportMaxFileSize'); + }); + + it('should include socialLogins and turnstile from base config', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body.socialLogins).toEqual(['google', 'github']); + expect(response.body.turnstile).toEqual({ siteKey: 'test-key' }); + }); + + it('should include only privacyPolicy and termsOfService from interface config', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body.interface).toEqual({ + privacyPolicy: { externalUrl: 'https://example.com/privacy' }, + termsOfService: { externalUrl: 'https://example.com/tos' }, + }); + expect(response.body.interface).not.toHaveProperty('modelSelect'); + }); + + it('should not include interface if no privacyPolicy or termsOfService', async () => { + mockGetAppConfig.mockResolvedValue({ + ...baseAppConfig, + interfaceConfig: { modelSelect: true }, + }); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body).not.toHaveProperty('interface'); + }); + + it('should include shared env var fields', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + process.env.APP_TITLE = 'Test App'; + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body.appTitle).toBe('Test App'); + expect(response.body).toHaveProperty('emailLoginEnabled'); + expect(response.body).toHaveProperty('serverDomain'); + }); + + it('should return 500 when getAppConfig throws', async () => { + mockGetAppConfig.mockRejectedValue(new Error('Config service failure')); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.statusCode).toBe(500); + expect(response.body).toHaveProperty('error'); + }); + }); + + describe('authenticated (req.user exists)', () => { + it('should call getAppConfig with role, userId, and tenantId', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetTenantId.mockReturnValue('fallback-tenant'); + const app = createApp(mockUser); + + await request(app).get('/api/config'); + + expect(mockGetAppConfig).toHaveBeenCalledWith({ + role: 'USER', + userId: 'user123', + tenantId: 'fallback-tenant', + }); + }); + + it('should prefer user tenantId over getTenantId fallback', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetTenantId.mockReturnValue('fallback-tenant'); + const app = createApp({ ...mockUser, tenantId: 'user-tenant' }); + + await request(app).get('/api/config'); + + expect(mockGetAppConfig).toHaveBeenCalledWith({ + role: 'USER', + userId: 'user123', + tenantId: 'user-tenant', + }); + }); + + it('should include modelSpecs, balance, and webSearch', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + process.env.CHECK_BALANCE = 'true'; + process.env.START_BALANCE = '10000'; + const app = createApp(mockUser); + + const response = await request(app).get('/api/config'); + + expect(response.body.modelSpecs).toEqual({ list: [{ name: 'test-spec' }] }); + expect(response.body.balance).toEqual({ enabled: true, startBalance: 10000 }); + expect(response.body.webSearch).toEqual({ searchProvider: 'tavily' }); + }); + + it('should include full interface config', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + const app = createApp(mockUser); + + const response = await request(app).get('/api/config'); + + expect(response.body.interface).toEqual(baseAppConfig.interfaceConfig); + }); + + it('should include authenticated-only env var fields', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + process.env.SANDPACK_BUNDLER_URL = 'https://bundler.test'; + process.env.SANDPACK_STATIC_BUNDLER_URL = 'https://static-bundler.test'; + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '5000000'; + const app = createApp(mockUser); + + const response = await request(app).get('/api/config'); + + expect(response.body.bundlerURL).toBe('https://bundler.test'); + expect(response.body.staticBundlerURL).toBe('https://static-bundler.test'); + expect(response.body.conversationImportMaxFileSize).toBe(5000000); + }); + + it('should merge per-user balance override into config', async () => { + mockGetAppConfig.mockResolvedValue({ + ...baseAppConfig, + balance: { + enabled: true, + startBalance: 50000, + }, + }); + const app = createApp(mockUser); + + const response = await request(app).get('/api/config'); + + expect(response.body.balance).toEqual( + expect.objectContaining({ + enabled: true, + startBalance: 50000, + }), + ); + }); + + it('should return 500 when getAppConfig throws', async () => { + mockGetAppConfig.mockRejectedValue(new Error('Config service failure')); + const app = createApp(mockUser); + + const response = await request(app).get('/api/config'); + + expect(response.statusCode).toBe(500); + expect(response.body).toHaveProperty('error'); }); }); }); diff --git a/api/server/routes/__tests__/grants.spec.js b/api/server/routes/__tests__/grants.spec.js new file mode 100644 index 0000000000..c7b5b6bdda --- /dev/null +++ b/api/server/routes/__tests__/grants.spec.js @@ -0,0 +1,185 @@ +const express = require('express'); +const request = require('supertest'); +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { createModels, createMethods } = require('@librechat/data-schemas'); +const { PrincipalType, SystemRoles } = require('librechat-data-provider'); + +/** + * Integration test for the admin grants routes. + * + * Validates the full Express wiring: route registration → middleware → + * handler → real MongoDB. Auth middleware is injected (matching the repo + * pattern in keys.spec.js) so we can control the caller identity without + * a real JWT, while the handler DI deps use real DB methods. + */ + +jest.mock('~/server/middleware', () => ({ + requireJwtAuth: (_req, _res, next) => next(), +})); + +jest.mock('~/server/middleware/roles/capabilities', () => ({ + requireCapability: () => (_req, _res, next) => next(), +})); + +let mongoServer; +let db; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + createModels(mongoose); + db = createMethods(mongoose); + await db.seedSystemGrants(); + await db.initializeRoles(); + await db.seedDefaultRoles(); +}); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +afterEach(async () => { + const SystemGrant = mongoose.models.SystemGrant; + // Clean non-seed grants (keep admin seed) + await SystemGrant.deleteMany({ + $or: [ + { principalId: { $ne: SystemRoles.ADMIN } }, + { principalType: { $ne: PrincipalType.ROLE } }, + ], + }); +}); + +function createApp(user) { + const { createAdminGrantsHandlers, getCachedPrincipals } = require('@librechat/api'); + + const handlers = createAdminGrantsHandlers({ + listGrants: db.listGrants, + countGrants: db.countGrants, + getCapabilitiesForPrincipal: db.getCapabilitiesForPrincipal, + getCapabilitiesForPrincipals: db.getCapabilitiesForPrincipals, + grantCapability: db.grantCapability, + revokeCapability: db.revokeCapability, + getUserPrincipals: db.getUserPrincipals, + hasCapabilityForPrincipals: db.hasCapabilityForPrincipals, + getHeldCapabilities: db.getHeldCapabilities, + getCachedPrincipals, + checkRoleExists: async (name) => (await db.getRoleByName(name)) != null, + }); + + const app = express(); + app.use(express.json()); + app.use((req, _res, next) => { + req.user = user; + next(); + }); + + const router = express.Router(); + router.get('/', handlers.listGrants); + router.get('/effective', handlers.getEffectiveCapabilities); + router.get('/:principalType/:principalId', handlers.getPrincipalGrants); + router.post('/', handlers.assignGrant); + router.delete('/:principalType/:principalId/:capability', handlers.revokeGrant); + app.use('/api/admin/grants', router); + + return app; +} + +describe('Admin Grants Routes — Integration', () => { + const adminUserId = new mongoose.Types.ObjectId(); + const adminUser = { + _id: adminUserId, + id: adminUserId.toString(), + role: SystemRoles.ADMIN, + }; + + it('GET / returns seeded admin grants', async () => { + const app = createApp(adminUser); + const res = await request(app).get('/api/admin/grants').expect(200); + + expect(res.body).toHaveProperty('grants'); + expect(res.body).toHaveProperty('total'); + expect(res.body.grants.length).toBeGreaterThan(0); + // Seeded grants are for the ADMIN role + expect(res.body.grants[0].principalType).toBe(PrincipalType.ROLE); + }); + + it('GET /effective returns capabilities for admin', async () => { + const app = createApp(adminUser); + const res = await request(app).get('/api/admin/grants/effective').expect(200); + + expect(res.body).toHaveProperty('capabilities'); + expect(res.body.capabilities).toContain('access:admin'); + expect(res.body.capabilities).toContain('manage:roles'); + }); + + it('POST / assigns a grant and DELETE / revokes it', async () => { + const app = createApp(adminUser); + + // Assign + const assignRes = await request(app) + .post('/api/admin/grants') + .send({ + principalType: PrincipalType.ROLE, + principalId: SystemRoles.USER, + capability: 'read:users', + }) + .expect(201); + + expect(assignRes.body.grant).toMatchObject({ + principalType: PrincipalType.ROLE, + principalId: SystemRoles.USER, + capability: 'read:users', + }); + + // Verify via GET + const getRes = await request(app) + .get(`/api/admin/grants/${PrincipalType.ROLE}/${SystemRoles.USER}`) + .expect(200); + + expect(getRes.body.grants.some((g) => g.capability === 'read:users')).toBe(true); + + // Revoke + await request(app) + .delete(`/api/admin/grants/${PrincipalType.ROLE}/${SystemRoles.USER}/read:users`) + .expect(200); + + // Verify revoked + const afterRes = await request(app) + .get(`/api/admin/grants/${PrincipalType.ROLE}/${SystemRoles.USER}`) + .expect(200); + + expect(afterRes.body.grants.some((g) => g.capability === 'read:users')).toBe(false); + }); + + it('POST / returns 400 for non-existent role when checkRoleExists is wired', async () => { + const app = createApp(adminUser); + + const res = await request(app) + .post('/api/admin/grants') + .send({ + principalType: PrincipalType.ROLE, + principalId: 'nonexistent-role', + capability: 'read:users', + }) + .expect(400); + + expect(res.body.error).toBe('Role not found'); + }); + + it('POST / returns 401 without authenticated user', async () => { + const app = createApp(undefined); + + const res = await request(app) + .post('/api/admin/grants') + .send({ + principalType: PrincipalType.ROLE, + principalId: SystemRoles.USER, + capability: 'read:users', + }) + .expect(401); + + expect(res.body).toHaveProperty('error', 'Authentication required'); + }); +}); diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 1ad8cac087..f194f361d3 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -18,6 +18,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getOAuthServers: jest.fn(), getAllServerConfigs: jest.fn(), + ensureConfigServers: jest.fn().mockResolvedValue({}), addServer: jest.fn(), updateServer: jest.fn(), removeServer: jest.fn(), @@ -58,6 +59,7 @@ jest.mock('@librechat/api', () => { }); jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(), logger: { debug: jest.fn(), info: jest.fn(), @@ -93,14 +95,18 @@ jest.mock('~/server/services/Config', () => ({ getCachedTools: jest.fn(), getMCPServerTools: jest.fn(), loadCustomConfig: jest.fn(), + getAppConfig: jest.fn().mockResolvedValue({ mcpConfig: {} }), })); jest.mock('~/server/services/Config/mcp', () => ({ updateMCPServerTools: jest.fn(), })); +const mockResolveAllMcpConfigs = jest.fn().mockResolvedValue({}); jest.mock('~/server/services/MCP', () => ({ getMCPSetupData: jest.fn(), + resolveConfigServers: jest.fn().mockResolvedValue({}), + resolveAllMcpConfigs: (...args) => mockResolveAllMcpConfigs(...args), getServerConnectionStatus: jest.fn(), })); @@ -579,6 +585,112 @@ describe('MCP Routes', () => { ); }); + it('should use oauthHeaders from flow state when present', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + oauthHeaders: { 'X-Custom-Auth': 'header-value' }, + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Custom-Auth': 'header-value' }, + ); + expect(mockRegistryInstance.getServerConfig).not.toHaveBeenCalled(); + }); + + it('should fall back to registry oauth_headers when flow state lacks them', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING' }), + completeFlow: jest.fn().mockResolvedValue(), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: { toolFlowId: 'tool-flow-123' }, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + const mockTokens = { access_token: 'tok', refresh_token: 'ref' }; + + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); + MCPTokenStorage.storeTokens.mockResolvedValue(); + mockRegistryInstance.getServerConfig.mockResolvedValue({ + oauth_headers: { 'X-Registry-Header': 'from-registry' }, + }); + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/config').getMCPManager.mockReturnValue({ + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }); + const { getCachedTools, setCachedTools } = require('~/server/services/Config'); + getCachedTools.mockResolvedValue({}); + setCachedTools.mockResolvedValue(); + + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ code: 'auth-code', state: flowId }); + + expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( + flowId, + 'auth-code', + mockFlowManager, + { 'X-Registry-Header': 'from-registry' }, + ); + expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( + 'test-server', + 'test-user-id', + undefined, + ); + }); + it('should redirect to error page when callback processing fails', async () => { MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error')); const flowId = 'test-user-id:test-server'; @@ -1350,19 +1462,10 @@ describe('MCP Routes', () => { }, }); - expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id'); + expect(getMCPSetupData).toHaveBeenCalledWith('test-user-id', expect.any(Object)); expect(getServerConnectionStatus).toHaveBeenCalledTimes(2); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database error')); @@ -1437,15 +1540,6 @@ describe('MCP Routes', () => { }); }); - it('should return 404 when MCP config is not found', async () => { - getMCPSetupData.mockRejectedValue(new Error('MCP config not found')); - - const response = await request(app).get('/api/mcp/connection/status/test-server'); - - expect(response.status).toBe(404); - expect(response.body).toEqual({ error: 'MCP config not found' }); - }); - it('should return 500 when connection status check fails', async () => { getMCPSetupData.mockRejectedValue(new Error('Database connection failed')); @@ -1704,7 +1798,7 @@ describe('MCP Routes', () => { }, }; - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockServerConfigs); + mockResolveAllMcpConfigs.mockResolvedValue(mockServerConfigs); const response = await request(app).get('/api/mcp/servers'); @@ -1721,11 +1815,14 @@ describe('MCP Routes', () => { }); expect(response.body['server-1'].headers).toBeUndefined(); expect(response.body['server-2'].headers).toBeUndefined(); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); + expect(mockResolveAllMcpConfigs).toHaveBeenCalledWith( + 'test-user-id', + expect.objectContaining({ id: 'test-user-id' }), + ); }); it('should return empty object when no servers are configured', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + mockResolveAllMcpConfigs.mockResolvedValue({}); const response = await request(app).get('/api/mcp/servers'); @@ -1749,7 +1846,7 @@ describe('MCP Routes', () => { }); it('should return 500 when server config retrieval fails', async () => { - mockRegistryInstance.getAllServerConfigs.mockRejectedValue(new Error('Database error')); + mockResolveAllMcpConfigs.mockRejectedValue(new Error('Database error')); const response = await request(app).get('/api/mcp/servers'); @@ -1939,11 +2036,12 @@ describe('MCP Routes', () => { expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', + {}, ); }); it('should return 404 when server not found', async () => { - mockRegistryInstance.getServerConfig.mockResolvedValue(null); + mockRegistryInstance.getServerConfig.mockResolvedValue(undefined); const response = await request(app).get('/api/mcp/servers/non-existent-server'); diff --git a/api/server/routes/admin/auth.js b/api/server/routes/admin/auth.js index 530764852b..72f23b7d52 100644 --- a/api/server/routes/admin/auth.js +++ b/api/server/routes/admin/auth.js @@ -1,9 +1,8 @@ const express = require('express'); const passport = require('passport'); -const { randomState } = require('openid-client'); -const { logger } = require('@librechat/data-schemas'); +const crypto = require('node:crypto'); const { CacheKeys } = require('librechat-data-provider'); -const { SystemCapabilities } = require('@librechat/data-schemas'); +const { logger, SystemCapabilities } = require('@librechat/data-schemas'); const { getAdminPanelUrl, exchangeAdminCode, createSetBalanceConfig } = require('@librechat/api'); const { loginController } = require('~/server/controllers/auth/LoginController'); const { requireCapability } = require('~/server/middleware/roles/capabilities'); @@ -24,6 +23,28 @@ const setBalanceConfig = createSetBalanceConfig({ const router = express.Router(); +function resolveRequestOrigin(req) { + const originHeader = req.get('origin'); + if (originHeader) { + try { + return new URL(originHeader).origin; + } catch { + return undefined; + } + } + + const refererHeader = req.get('referer'); + if (!refererHeader) { + return undefined; + } + + try { + return new URL(refererHeader).origin; + } catch { + return undefined; + } +} + router.post( '/login/local', middleware.logHeaders, @@ -52,28 +73,340 @@ router.get('/oauth/openid/check', (req, res) => { res.status(200).json({ message: 'OpenID check successful' }); }); -router.get('/oauth/openid', (req, res, next) => { +/** PKCE challenge cache TTL: 5 minutes (enough for user to authenticate with IdP) */ +const PKCE_CHALLENGE_TTL = 5 * 60 * 1000; +/** Regex pattern for valid PKCE challenges: 64 hex characters (SHA-256 hex digest) */ +const PKCE_CHALLENGE_PATTERN = /^[a-f0-9]{64}$/; + +/** + * Generates a random hex state string for OAuth flows. + * @returns {string} A 32-byte random hex string. + */ +function generateState() { + return crypto.randomBytes(32).toString('hex'); +} + +/** + * Stores a PKCE challenge in cache keyed by state. + * @param {string} state - The OAuth state value. + * @param {string | undefined} codeChallenge - The PKCE code_challenge from query params. + * @param {string} provider - Provider name for logging. + * @returns {Promise} True if stored successfully or no challenge provided. + */ +async function storePkceChallenge(state, codeChallenge, provider) { + if (typeof codeChallenge !== 'string' || !PKCE_CHALLENGE_PATTERN.test(codeChallenge)) { + return true; + } + try { + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); + await cache.set(`pkce:${state}`, codeChallenge, PKCE_CHALLENGE_TTL); + return true; + } catch (err) { + logger.error(`[admin/oauth/${provider}] Failed to store PKCE challenge:`, err); + return false; + } +} + +/** + * Middleware to retrieve PKCE challenge from cache using the OAuth state. + * Reads state from req.oauthState (set by a preceding middleware). + * @param {string} provider - Provider name for logging. + * @returns {Function} Express middleware. + */ +function retrievePkceChallenge(provider) { + return async (req, res, next) => { + if (!req.oauthState) { + return next(); + } + try { + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); + const challenge = await cache.get(`pkce:${req.oauthState}`); + if (challenge) { + req.pkceChallenge = challenge; + await cache.delete(`pkce:${req.oauthState}`); + } else { + logger.warn( + `[admin/oauth/${provider}/callback] State present but no PKCE challenge found; PKCE will not be enforced for this request`, + ); + } + } catch (err) { + logger.error( + `[admin/oauth/${provider}/callback] Failed to retrieve PKCE challenge, aborting:`, + err, + ); + return res.redirect( + `${getAdminPanelUrl()}/auth/${provider}/callback?error=pkce_retrieval_failed&error_description=Failed+to+retrieve+PKCE+challenge`, + ); + } + next(); + }; +} + +/* ────────────────────────────────────────────── + * OpenID Admin Routes + * ────────────────────────────────────────────── */ + +router.get('/oauth/openid', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'openid'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/openid/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + return passport.authenticate('openidAdmin', { session: false, - state: randomState(), + state, })(req, res, next); }); router.get( '/oauth/openid/callback', + (req, res, next) => { + req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined; + next(); + }, passport.authenticate('openidAdmin', { failureRedirect: `${getAdminPanelUrl()}/auth/openid/callback?error=auth_failed&error_description=Authentication+failed`, failureMessage: true, session: false, }), + retrievePkceChallenge('openid'), requireAdminAccess, setBalanceConfig, middleware.checkDomainAllowed, createOAuthHandler(`${getAdminPanelUrl()}/auth/openid/callback`), ); +/* ────────────────────────────────────────────── + * SAML Admin Routes + * ────────────────────────────────────────────── */ + +router.get('/oauth/saml', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'saml'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/saml/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + + return passport.authenticate('samlAdmin', { + session: false, + additionalParams: { RelayState: state }, + })(req, res, next); +}); + +router.post( + '/oauth/saml/callback', + (req, res, next) => { + req.oauthState = typeof req.body.RelayState === 'string' ? req.body.RelayState : undefined; + next(); + }, + passport.authenticate('samlAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/saml/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + retrievePkceChallenge('saml'), + requireAdminAccess, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/saml/callback`), +); + +/* ────────────────────────────────────────────── + * Google Admin Routes + * ────────────────────────────────────────────── */ + +router.get('/oauth/google', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'google'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/google/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + + return passport.authenticate('googleAdmin', { + scope: ['openid', 'profile', 'email'], + session: false, + state, + })(req, res, next); +}); + +router.get( + '/oauth/google/callback', + (req, res, next) => { + req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined; + next(); + }, + passport.authenticate('googleAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/google/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + retrievePkceChallenge('google'), + requireAdminAccess, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/google/callback`), +); + +/* ────────────────────────────────────────────── + * GitHub Admin Routes + * ────────────────────────────────────────────── */ + +router.get('/oauth/github', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'github'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/github/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + + return passport.authenticate('githubAdmin', { + scope: ['user:email', 'read:user'], + session: false, + state, + })(req, res, next); +}); + +router.get( + '/oauth/github/callback', + (req, res, next) => { + req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined; + next(); + }, + passport.authenticate('githubAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/github/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + retrievePkceChallenge('github'), + requireAdminAccess, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/github/callback`), +); + +/* ────────────────────────────────────────────── + * Discord Admin Routes + * ────────────────────────────────────────────── */ + +router.get('/oauth/discord', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'discord'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/discord/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + + return passport.authenticate('discordAdmin', { + scope: ['identify', 'email'], + session: false, + state, + })(req, res, next); +}); + +router.get( + '/oauth/discord/callback', + (req, res, next) => { + req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined; + next(); + }, + passport.authenticate('discordAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/discord/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + retrievePkceChallenge('discord'), + requireAdminAccess, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/discord/callback`), +); + +/* ────────────────────────────────────────────── + * Facebook Admin Routes + * ────────────────────────────────────────────── */ + +router.get('/oauth/facebook', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'facebook'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/facebook/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + + return passport.authenticate('facebookAdmin', { + scope: ['public_profile'], + session: false, + state, + })(req, res, next); +}); + +router.get( + '/oauth/facebook/callback', + (req, res, next) => { + req.oauthState = typeof req.query.state === 'string' ? req.query.state : undefined; + next(); + }, + passport.authenticate('facebookAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/facebook/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + retrievePkceChallenge('facebook'), + requireAdminAccess, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/facebook/callback`), +); + +/* ────────────────────────────────────────────── + * Apple Admin Routes (POST callback) + * ────────────────────────────────────────────── */ + +router.get('/oauth/apple', async (req, res, next) => { + const state = generateState(); + const stored = await storePkceChallenge(state, req.query.code_challenge, 'apple'); + if (!stored) { + return res.redirect( + `${getAdminPanelUrl()}/auth/apple/callback?error=pkce_store_failed&error_description=Failed+to+store+PKCE+challenge`, + ); + } + + return passport.authenticate('appleAdmin', { + session: false, + state, + })(req, res, next); +}); + +router.post( + '/oauth/apple/callback', + (req, res, next) => { + req.oauthState = typeof req.body.state === 'string' ? req.body.state : undefined; + next(); + }, + passport.authenticate('appleAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/apple/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + retrievePkceChallenge('apple'), + requireAdminAccess, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/apple/callback`), +); + /** Regex pattern for valid exchange codes: 64 hex characters */ -const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/i; +const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/; /** * Exchange OAuth authorization code for tokens. @@ -81,12 +414,12 @@ const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/i; * The code is one-time-use and expires in 30 seconds. * * POST /api/admin/oauth/exchange - * Body: { code: string } + * Body: { code: string, code_verifier?: string } * Response: { token: string, refreshToken: string, user: object } */ router.post('/oauth/exchange', middleware.loginLimiter, async (req, res) => { try { - const { code } = req.body; + const { code, code_verifier: codeVerifier } = req.body; if (!code) { logger.warn('[admin/oauth/exchange] Missing authorization code'); @@ -104,8 +437,20 @@ router.post('/oauth/exchange', middleware.loginLimiter, async (req, res) => { }); } + if ( + codeVerifier !== undefined && + (typeof codeVerifier !== 'string' || codeVerifier.length < 1 || codeVerifier.length > 512) + ) { + logger.warn('[admin/oauth/exchange] Invalid code_verifier format'); + return res.status(400).json({ + error: 'Invalid code_verifier', + error_code: 'INVALID_VERIFIER', + }); + } + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); - const result = await exchangeAdminCode(cache, code); + const requestOrigin = resolveRequestOrigin(req); + const result = await exchangeAdminCode(cache, code, requestOrigin, codeVerifier); if (!result) { return res.status(401).json({ diff --git a/api/server/routes/admin/config.js b/api/server/routes/admin/config.js new file mode 100644 index 0000000000..0632077ea9 --- /dev/null +++ b/api/server/routes/admin/config.js @@ -0,0 +1,40 @@ +const express = require('express'); +const { createAdminConfigHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { + hasConfigCapability, + requireCapability, +} = require('~/server/middleware/roles/capabilities'); +const { getAppConfig, invalidateConfigCaches } = require('~/server/services/Config'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); + +const handlers = createAdminConfigHandlers({ + listAllConfigs: db.listAllConfigs, + findConfigByPrincipal: db.findConfigByPrincipal, + upsertConfig: db.upsertConfig, + patchConfigFields: db.patchConfigFields, + unsetConfigField: db.unsetConfigField, + deleteConfig: db.deleteConfig, + toggleConfigActive: db.toggleConfigActive, + hasConfigCapability, + getAppConfig, + invalidateConfigCaches, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', handlers.listConfigs); +router.get('/base', handlers.getBaseConfig); +router.get('/:principalType/:principalId', handlers.getConfig); +router.put('/:principalType/:principalId', handlers.upsertConfigOverrides); +router.patch('/:principalType/:principalId/fields', handlers.patchConfigField); +router.delete('/:principalType/:principalId/fields', handlers.deleteConfigField); +router.delete('/:principalType/:principalId', handlers.deleteConfigOverrides); +router.patch('/:principalType/:principalId/active', handlers.toggleConfig); + +module.exports = router; diff --git a/api/server/routes/admin/grants.js b/api/server/routes/admin/grants.js new file mode 100644 index 0000000000..a0fa73dc43 --- /dev/null +++ b/api/server/routes/admin/grants.js @@ -0,0 +1,35 @@ +const express = require('express'); +const { createAdminGrantsHandlers, getCachedPrincipals } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); + +const handlers = createAdminGrantsHandlers({ + listGrants: db.listGrants, + countGrants: db.countGrants, + getCapabilitiesForPrincipal: db.getCapabilitiesForPrincipal, + getCapabilitiesForPrincipals: db.getCapabilitiesForPrincipals, + grantCapability: db.grantCapability, + revokeCapability: db.revokeCapability, + getUserPrincipals: db.getUserPrincipals, + hasCapabilityForPrincipals: db.hasCapabilityForPrincipals, + getHeldCapabilities: db.getHeldCapabilities, + getCachedPrincipals, + checkRoleExists: async (name) => (await db.getRoleByName(name)) != null, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', handlers.listGrants); +router.get('/effective', handlers.getEffectiveCapabilities); +router.get('/:principalType/:principalId', handlers.getPrincipalGrants); +router.post('/', handlers.assignGrant); +/** Callers should encodeURIComponent the capability for client compatibility (e.g. manage%3Aconfigs%3Aendpoints). */ +router.delete('/:principalType/:principalId/:capability', handlers.revokeGrant); + +module.exports = router; diff --git a/api/server/routes/admin/groups.js b/api/server/routes/admin/groups.js new file mode 100644 index 0000000000..11ed59737e --- /dev/null +++ b/api/server/routes/admin/groups.js @@ -0,0 +1,40 @@ +const express = require('express'); +const { createAdminGroupsHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadGroups = requireCapability(SystemCapabilities.READ_GROUPS); +const requireManageGroups = requireCapability(SystemCapabilities.MANAGE_GROUPS); + +const handlers = createAdminGroupsHandlers({ + listGroups: db.listGroups, + countGroups: db.countGroups, + findGroupById: db.findGroupById, + createGroup: db.createGroup, + updateGroupById: db.updateGroupById, + deleteGroup: db.deleteGroup, + addUserToGroup: db.addUserToGroup, + removeUserFromGroup: db.removeUserFromGroup, + removeMemberById: db.removeMemberById, + findUsers: db.findUsers, + deleteConfig: db.deleteConfig, + deleteAclEntries: db.deleteAclEntries, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadGroups, handlers.listGroups); +router.post('/', requireManageGroups, handlers.createGroup); +router.get('/:id', requireReadGroups, handlers.getGroup); +router.patch('/:id', requireManageGroups, handlers.updateGroup); +router.delete('/:id', requireManageGroups, handlers.deleteGroup); +router.get('/:id/members', requireReadGroups, handlers.getGroupMembers); +router.post('/:id/members', requireManageGroups, handlers.addGroupMember); +router.delete('/:id/members/:userId', requireManageGroups, handlers.removeGroupMember); + +module.exports = router; diff --git a/api/server/routes/admin/roles.js b/api/server/routes/admin/roles.js new file mode 100644 index 0000000000..f2bbd7f7ea --- /dev/null +++ b/api/server/routes/admin/roles.js @@ -0,0 +1,46 @@ +const express = require('express'); +const { createAdminRolesHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadRoles = requireCapability(SystemCapabilities.READ_ROLES); +const requireManageRoles = requireCapability(SystemCapabilities.MANAGE_ROLES); + +const handlers = createAdminRolesHandlers({ + listRoles: db.listRoles, + countRoles: db.countRoles, + getRoleByName: db.getRoleByName, + createRoleByName: db.createRoleByName, + updateRoleByName: db.updateRoleByName, + updateAccessPermissions: db.updateAccessPermissions, + deleteRoleByName: db.deleteRoleByName, + findUser: db.findUser, + updateUser: db.updateUser, + updateUsersByRole: db.updateUsersByRole, + findUserIdsByRole: db.findUserIdsByRole, + updateUsersRoleByIds: db.updateUsersRoleByIds, + listUsersByRole: db.listUsersByRole, + countUsersByRole: db.countUsersByRole, + deleteConfig: db.deleteConfig, + deleteAclEntries: db.deleteAclEntries, + deleteGrantsForPrincipal: db.deleteGrantsForPrincipal, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadRoles, handlers.listRoles); +router.post('/', requireManageRoles, handlers.createRole); +router.get('/:name', requireReadRoles, handlers.getRole); +router.patch('/:name', requireManageRoles, handlers.updateRole); +router.delete('/:name', requireManageRoles, handlers.deleteRole); +router.patch('/:name/permissions', requireManageRoles, handlers.updateRolePermissions); +router.get('/:name/members', requireReadRoles, handlers.getRoleMembers); +router.post('/:name/members', requireManageRoles, handlers.addRoleMember); +router.delete('/:name/members/:userId', requireManageRoles, handlers.removeRoleMember); + +module.exports = router; diff --git a/api/server/routes/admin/users.js b/api/server/routes/admin/users.js new file mode 100644 index 0000000000..20d4eb1797 --- /dev/null +++ b/api/server/routes/admin/users.js @@ -0,0 +1,28 @@ +const express = require('express'); +const { createAdminUsersHandlers } = require('@librechat/api'); +const { SystemCapabilities } = require('@librechat/data-schemas'); +const { requireCapability } = require('~/server/middleware/roles/capabilities'); +const { requireJwtAuth } = require('~/server/middleware'); +const db = require('~/models'); + +const router = express.Router(); + +const requireAdminAccess = requireCapability(SystemCapabilities.ACCESS_ADMIN); +const requireReadUsers = requireCapability(SystemCapabilities.READ_USERS); +// const requireManageUsers = requireCapability(SystemCapabilities.MANAGE_USERS); + +const handlers = createAdminUsersHandlers({ + findUsers: db.findUsers, + countUsers: db.countUsers, + deleteUserById: db.deleteUserById, + deleteConfig: db.deleteConfig, + deleteAclEntries: db.deleteAclEntries, +}); + +router.use(requireJwtAuth, requireAdminAccess); + +router.get('/', requireReadUsers, handlers.listUsers); +router.get('/search', requireReadUsers, handlers.searchUsers); +// router.delete('/:id', requireManageUsers, handlers.deleteUser); + +module.exports = router; diff --git a/api/server/routes/agents/__tests__/streamTenant.spec.js b/api/server/routes/agents/__tests__/streamTenant.spec.js new file mode 100644 index 0000000000..1f89953186 --- /dev/null +++ b/api/server/routes/agents/__tests__/streamTenant.spec.js @@ -0,0 +1,186 @@ +const express = require('express'); +const request = require('supertest'); + +const mockGenerationJobManager = { + getJob: jest.fn(), + subscribe: jest.fn(), + getResumeState: jest.fn(), + abortJob: jest.fn(), + getActiveJobIdsForUser: jest.fn().mockResolvedValue([]), +}; + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, +})); + +jest.mock('~/models', () => ({ + saveMessage: jest.fn(), +})); + +let mockUserId = 'user-123'; +let mockTenantId; + +jest.mock('~/server/middleware', () => ({ + uaParser: (req, res, next) => next(), + checkBan: (req, res, next) => next(), + requireJwtAuth: (req, res, next) => { + req.user = { id: mockUserId, tenantId: mockTenantId }; + next(); + }, + messageIpLimiter: (req, res, next) => next(), + configMiddleware: (req, res, next) => next(), + messageUserLimiter: (req, res, next) => next(), +})); + +jest.mock('~/server/routes/agents/chat', () => require('express').Router()); +jest.mock('~/server/routes/agents/v1', () => ({ + v1: require('express').Router(), +})); +jest.mock('~/server/routes/agents/openai', () => require('express').Router()); +jest.mock('~/server/routes/agents/responses', () => require('express').Router()); + +const agentsRouter = require('../index'); +const app = express(); +app.use(express.json()); +app.use('/agents', agentsRouter); + +function mockSubscribeSuccess() { + mockGenerationJobManager.subscribe.mockImplementation((_streamId, _writeEvent, onDone) => { + process.nextTick(() => onDone({ done: true })); + return { unsubscribe: jest.fn() }; + }); +} + +describe('SSE stream tenant isolation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockUserId = 'user-123'; + mockTenantId = undefined; + }); + + describe('GET /chat/stream/:streamId', () => { + it('returns 403 when a user from a different tenant accesses a stream', async () => { + mockUserId = 'user-456'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-456', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns 404 when stream does not exist', async () => { + mockGenerationJobManager.getJob.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/stream/nonexistent'); + expect(res.status).toBe(404); + }); + + it('proceeds past tenant guard when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('proceeds past tenant guard when job has no tenantId (single-tenant mode)', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + mockSubscribeSuccess(); + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(200); + expect(mockGenerationJobManager.subscribe).toHaveBeenCalledTimes(1); + }); + + it('returns 403 when job has tenantId but user has no tenantId', async () => { + mockUserId = 'user-123'; + mockTenantId = undefined; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'some-tenant' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/stream/stream-123'); + expect(res.status).toBe(403); + }); + }); + + describe('GET /chat/status/:conversationId', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + + it('returns status when tenant matches', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-a'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + createdAt: Date.now(), + }); + mockGenerationJobManager.getResumeState.mockResolvedValue(null); + + const res = await request(app).get('/agents/chat/status/conv-123'); + expect(res.status).toBe(200); + expect(res.body.active).toBe(true); + }); + }); + + describe('POST /chat/abort', () => { + it('returns 403 when tenant does not match', async () => { + mockUserId = 'user-123'; + mockTenantId = 'tenant-b'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'user-123', tenantId: 'tenant-a' }, + status: 'running', + }); + + const res = await request(app).post('/agents/chat/abort').send({ streamId: 'stream-123' }); + expect(res.status).toBe(403); + expect(res.body.error).toBe('Unauthorized'); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index 86966a3f3e..eb42046bed 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -17,6 +17,11 @@ const chat = require('./chat'); const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; +/** Untenanted jobs (pre-multi-tenancy) remain accessible if the userId check passes. */ +function hasTenantMismatch(job, user) { + return job.metadata?.tenantId != null && job.metadata.tenantId !== user.tenantId; +} + const router = express.Router(); /** @@ -67,6 +72,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + res.setHeader('Content-Encoding', 'identity'); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache, no-transform'); @@ -150,7 +159,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { * @returns { activeJobIds: string[] } */ router.get('/chat/active', async (req, res) => { - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + req.user.id, + req.user.tenantId, + ); res.json({ activeJobIds }); }); @@ -174,6 +186,10 @@ router.get('/chat/status/:conversationId', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + // Get resume state which contains aggregatedContent // Avoid calling both getStreamInfo and getResumeState (both fetch content) const resumeState = await GenerationJobManager.getResumeState(conversationId); @@ -213,7 +229,10 @@ router.post('/chat/abort', async (req, res) => { // This handles the case where frontend sends "new" but job was created with a UUID if (!job && userId) { logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`); - const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId); + const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser( + userId, + req.user.tenantId, + ); if (activeJobIds.length > 0) { // Abort the most recent active job for this user jobStreamId = activeJobIds[0]; @@ -230,6 +249,10 @@ router.post('/chat/abort', async (req, res) => { return res.status(403).json({ error: 'Unauthorized' }); } + if (hasTenantMismatch(job, req.user)) { + return res.status(403).json({ error: 'Unauthorized' }); + } + logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`); const abortResult = await GenerationJobManager.abortJob(jobStreamId); logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, { diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 12961d3ff5..e63812f5ba 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,10 +1,9 @@ const express = require('express'); -const { logger } = require('@librechat/data-schemas'); const { isEnabled, getBalanceConfig } = require('@librechat/api'); -const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { defaultSocialLogins } = require('librechat-data-provider'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getAppConfig } = require('~/server/services/Config/app'); -const { getLogStores } = require('~/cache'); const router = express.Router(); const emailLoginEnabled = @@ -20,128 +19,159 @@ const publicSharedLinksEnabled = const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER); const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS); -router.get('/', async function (req, res) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); +function isBirthday() { + const today = new Date(); + return today.getMonth() === 1 && today.getDate() === 11; +} - const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); - if (cachedStartupConfig) { - res.send(cachedStartupConfig); - return; - } +function buildSharedPayload() { + const isOpenIdEnabled = + !!process.env.OPENID_CLIENT_ID && + (isEnabled(process.env.OPENID_USE_PKCE) || !!process.env.OPENID_CLIENT_SECRET?.trim()) && + !!process.env.OPENID_ISSUER && + !!process.env.OPENID_SESSION_SECRET; - const isBirthday = () => { - const today = new Date(); - return today.getMonth() === 1 && today.getDate() === 11; - }; + const isSamlEnabled = + !!process.env.SAML_ENTRY_POINT && + !!process.env.SAML_ISSUER && + !!process.env.SAML_CERT && + !!process.env.SAML_SESSION_SECRET; const ldap = getLdapConfig(); + /** @type {Partial} */ + const payload = { + appTitle: process.env.APP_TITLE || 'LibreChat', + discordLoginEnabled: !!process.env.DISCORD_CLIENT_ID && !!process.env.DISCORD_CLIENT_SECRET, + facebookLoginEnabled: !!process.env.FACEBOOK_CLIENT_ID && !!process.env.FACEBOOK_CLIENT_SECRET, + githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET, + googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET, + appleLoginEnabled: + !!process.env.APPLE_CLIENT_ID && + !!process.env.APPLE_TEAM_ID && + !!process.env.APPLE_KEY_ID && + !!process.env.APPLE_PRIVATE_KEY_PATH, + openidLoginEnabled: isOpenIdEnabled, + openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID', + openidImageUrl: process.env.OPENID_IMAGE_URL, + openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT), + samlLoginEnabled: !isOpenIdEnabled && isSamlEnabled, + samlLabel: process.env.SAML_BUTTON_LABEL, + samlImageUrl: process.env.SAML_IMAGE_URL, + serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080', + emailLoginEnabled, + registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION), + socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN), + emailEnabled: + (!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) && + !!process.env.EMAIL_USERNAME && + !!process.env.EMAIL_PASSWORD && + !!process.env.EMAIL_FROM, + passwordResetEnabled, + showBirthdayIcon: + isBirthday() || + isEnabled(process.env.SHOW_BIRTHDAY_ICON) || + process.env.SHOW_BIRTHDAY_ICON === '', + helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai', + sharedLinksEnabled, + publicSharedLinksEnabled, + analyticsGtmId: process.env.ANALYTICS_GTM_ID, + openidReuseTokens, + }; + + const minPasswordLength = parseInt(process.env.MIN_PASSWORD_LENGTH, 10); + if (minPasswordLength && !isNaN(minPasswordLength)) { + payload.minPasswordLength = minPasswordLength; + } + + if (ldap) { + payload.ldap = ldap; + } + + if (typeof process.env.CUSTOM_FOOTER === 'string') { + payload.customFooter = process.env.CUSTOM_FOOTER; + } + + return payload; +} + +function buildWebSearchConfig(appConfig) { + const ws = appConfig?.webSearch; + if (!ws) { + return undefined; + } + const { searchProvider, scraperProvider, rerankerType } = ws; + if (!searchProvider && !scraperProvider && !rerankerType) { + return undefined; + } + return { + ...(searchProvider && { searchProvider }), + ...(scraperProvider && { scraperProvider }), + ...(rerankerType && { rerankerType }), + }; +} + +router.get('/', async function (req, res) { try { - const appConfig = await getAppConfig({ role: req.user?.role }); + const sharedPayload = buildSharedPayload(); - const isOpenIdEnabled = - !!process.env.OPENID_CLIENT_ID && - (isEnabled(process.env.OPENID_USE_PKCE) || !!process.env.OPENID_CLIENT_SECRET?.trim()) && - !!process.env.OPENID_ISSUER && - !!process.env.OPENID_SESSION_SECRET; + if (!req.user) { + const tenantId = getTenantId(); + const baseConfig = await getAppConfig(tenantId ? { tenantId } : { baseOnly: true }); - const isSamlEnabled = - !!process.env.SAML_ENTRY_POINT && - !!process.env.SAML_ISSUER && - !!process.env.SAML_CERT && - !!process.env.SAML_SESSION_SECRET; + /** @type {Partial} */ + const payload = { + ...sharedPayload, + socialLogins: baseConfig?.registration?.socialLogins ?? defaultSocialLogins, + turnstile: baseConfig?.turnstileConfig, + }; + + const interfaceConfig = baseConfig?.interfaceConfig; + if (interfaceConfig?.privacyPolicy || interfaceConfig?.termsOfService) { + payload.interface = {}; + if (interfaceConfig.privacyPolicy) { + payload.interface.privacyPolicy = interfaceConfig.privacyPolicy; + } + if (interfaceConfig.termsOfService) { + payload.interface.termsOfService = interfaceConfig.termsOfService; + } + } + + return res.status(200).send(payload); + } + + const appConfig = await getAppConfig({ + role: req.user.role, + userId: req.user.id, + tenantId: req.user.tenantId || getTenantId(), + }); const balanceConfig = getBalanceConfig(appConfig); /** @type {TStartupConfig} */ const payload = { - appTitle: process.env.APP_TITLE || 'LibreChat', + ...sharedPayload, socialLogins: appConfig?.registration?.socialLogins ?? defaultSocialLogins, - discordLoginEnabled: !!process.env.DISCORD_CLIENT_ID && !!process.env.DISCORD_CLIENT_SECRET, - facebookLoginEnabled: - !!process.env.FACEBOOK_CLIENT_ID && !!process.env.FACEBOOK_CLIENT_SECRET, - githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET, - googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET, - appleLoginEnabled: - !!process.env.APPLE_CLIENT_ID && - !!process.env.APPLE_TEAM_ID && - !!process.env.APPLE_KEY_ID && - !!process.env.APPLE_PRIVATE_KEY_PATH, - openidLoginEnabled: isOpenIdEnabled, - openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID', - openidImageUrl: process.env.OPENID_IMAGE_URL, - openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT), - samlLoginEnabled: !isOpenIdEnabled && isSamlEnabled, - samlLabel: process.env.SAML_BUTTON_LABEL, - samlImageUrl: process.env.SAML_IMAGE_URL, - serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080', - emailLoginEnabled, - registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION), - socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN), - emailEnabled: - (!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) && - !!process.env.EMAIL_USERNAME && - !!process.env.EMAIL_PASSWORD && - !!process.env.EMAIL_FROM, - passwordResetEnabled, - showBirthdayIcon: - isBirthday() || - isEnabled(process.env.SHOW_BIRTHDAY_ICON) || - process.env.SHOW_BIRTHDAY_ICON === '', - helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai', interface: appConfig?.interfaceConfig, turnstile: appConfig?.turnstileConfig, modelSpecs: appConfig?.modelSpecs, balance: balanceConfig, - sharedLinksEnabled, - publicSharedLinksEnabled, - analyticsGtmId: process.env.ANALYTICS_GTM_ID, bundlerURL: process.env.SANDPACK_BUNDLER_URL, staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL, sharePointFilePickerEnabled, sharePointBaseUrl: process.env.SHAREPOINT_BASE_URL, sharePointPickerGraphScope: process.env.SHAREPOINT_PICKER_GRAPH_SCOPE, sharePointPickerSharePointScope: process.env.SHAREPOINT_PICKER_SHAREPOINT_SCOPE, - openidReuseTokens, conversationImportMaxFileSize: process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES ? parseInt(process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES, 10) : 0, }; - const minPasswordLength = parseInt(process.env.MIN_PASSWORD_LENGTH, 10); - if (minPasswordLength && !isNaN(minPasswordLength)) { - payload.minPasswordLength = minPasswordLength; + const webSearch = buildWebSearchConfig(appConfig); + if (webSearch) { + payload.webSearch = webSearch; } - const webSearchConfig = appConfig?.webSearch; - if ( - webSearchConfig != null && - (webSearchConfig.searchProvider || - webSearchConfig.scraperProvider || - webSearchConfig.rerankerType) - ) { - payload.webSearch = {}; - } - - if (webSearchConfig?.searchProvider) { - payload.webSearch.searchProvider = webSearchConfig.searchProvider; - } - if (webSearchConfig?.scraperProvider) { - payload.webSearch.scraperProvider = webSearchConfig.scraperProvider; - } - if (webSearchConfig?.rerankerType) { - payload.webSearch.rerankerType = webSearchConfig.rerankerType; - } - - if (ldap) { - payload.ldap = ldap; - } - - if (typeof process.env.CUSTOM_FOOTER === 'string') { - payload.customFooter = process.env.CUSTOM_FOOTER; - } - - await cache.set(CacheKeys.STARTUP_CONFIG, payload); return res.status(200).send(payload); } catch (err) { logger.error('Error in startup config', err); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 1964075ed3..ded7d835d7 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -267,7 +267,11 @@ router.post( async (req, res) => { try { /* TODO: optimize to return imported conversations and add manually */ - await importConversations({ filepath: req.file.path, requestUserId: req.user.id }); + await importConversations({ + filepath: req.file.path, + requestUserId: req.user.id, + userRole: req.user.role, + }); res.status(201).json({ message: 'Conversation(s) imported successfully' }); } catch (error) { logger.error('Error processing file', error); diff --git a/api/server/routes/endpoints.js b/api/server/routes/endpoints.js index 794abde0c2..e7ff1c7000 100644 --- a/api/server/routes/endpoints.js +++ b/api/server/routes/endpoints.js @@ -1,7 +1,9 @@ const express = require('express'); +const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); const endpointController = require('~/server/controllers/EndpointController'); const router = express.Router(); -router.get('/', endpointController); +/** Auth required for role/tenant-scoped endpoint config resolution. */ +router.get('/', requireJwtAuth, endpointController); module.exports = router; diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 6a48919db3..1feaf63fdb 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -2,6 +2,11 @@ const accessPermissions = require('./accessPermissions'); const assistants = require('./assistants'); const categories = require('./categories'); const adminAuth = require('./admin/auth'); +const adminConfig = require('./admin/config'); +const adminGrants = require('./admin/grants'); +const adminGroups = require('./admin/groups'); +const adminRoles = require('./admin/roles'); +const adminUsers = require('./admin/users'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -31,6 +36,11 @@ module.exports = { mcp, auth, adminAuth, + adminConfig, + adminGrants, + adminGroups, + adminRoles, + adminUsers, keys, apiKeys, user, diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index d6d7ed5ea0..c6496ad4b4 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -1,5 +1,5 @@ const { Router } = require('express'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { CacheKeys, Constants, @@ -36,7 +36,11 @@ const { getFlowStateManager, getMCPManager, } = require('~/config'); -const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); +const { + getServerConnectionStatus, + resolveConfigServers, + getMCPSetupData, +} = require('~/server/services/MCP'); const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { updateMCPServerTools } = require('~/server/services/Config/mcp'); @@ -101,7 +105,8 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async return res.status(400).json({ error: 'Invalid flow state' }); } - const oauthHeaders = await getOAuthHeaders(serverName, userId); + const configServers = await resolveConfigServers(req); + const oauthHeaders = await getOAuthHeaders(serverName, userId, configServers); const { authorizationUrl, flowId: oauthFlowId, @@ -233,7 +238,14 @@ router.get('/:serverName/oauth/callback', async (req, res) => { } logger.debug('[MCP OAuth] Completing OAuth flow'); - const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId); + if (!flowState.oauthHeaders) { + logger.warn( + '[MCP OAuth] oauthHeaders absent from flow state — config-source server oauth_headers will be empty', + { serverName, flowId }, + ); + } + const oauthHeaders = + flowState.oauthHeaders ?? (await getOAuthHeaders(serverName, flowState.userId)); const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); @@ -497,7 +509,12 @@ router.post( logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -522,6 +539,8 @@ router.post( const result = await reinitMCPServer({ user, serverName, + serverConfig, + configServers, userMCPAuthMap, }); @@ -564,6 +583,7 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); const connectionStatus = {}; @@ -593,9 +613,6 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => { connectionStatus, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error('[MCP Connection Status] Failed to get connection status', error); res.status(500).json({ error: 'Failed to get connection status' }); } @@ -616,6 +633,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData( user.id, + { role: user.role, tenantId: getTenantId() }, ); if (!mcpConfig[serverName]) { @@ -640,9 +658,6 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => requiresOAuth: serverStatus.requiresOAuth, }); } catch (error) { - if (error.message === 'MCP config not found') { - return res.status(404).json({ error: error.message }); - } logger.error( `[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`, error, @@ -664,7 +679,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a return res.status(401).json({ error: 'User not authenticated' }); } - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + const configServers = await resolveConfigServers(req); + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + user.id, + configServers, + ); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -703,8 +723,12 @@ router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, a } }); -async function getOAuthHeaders(serverName, userId) { - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, userId); +async function getOAuthHeaders(serverName, userId, configServers) { + const serverConfig = await getMCPServersRegistry().getServerConfig( + serverName, + userId, + configServers, + ); return serverConfig?.oauth_headers ?? {}; } diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index ef50a365b9..816a0eac5b 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -13,6 +13,7 @@ const { checkEmailConfig, isEmailDomainAllowed, shouldUseSecureCookie, + resolveAppConfigForUser, } = require('@librechat/api'); const { findUser, @@ -189,7 +190,7 @@ const registerUser = async (user, additionalData = {}) => { let newUserId; try { - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { const errorMessage = 'The email address provided cannot be used. Please use a different email address.'; @@ -255,19 +256,52 @@ const registerUser = async (user, additionalData = {}) => { }; /** - * Request password reset + * Request password reset. + * + * Uses a two-phase domain check: fast-fail with the memory-cached base config + * (zero DB queries) to block globally denied domains before user lookup, then + * re-check with tenant-scoped config after user lookup so tenant-specific + * restrictions are enforced. + * + * Phase 1 (base check) returns an Error (HTTP 400) — this intentionally reveals + * that the domain is globally blocked, but fires before any DB lookup so it + * cannot confirm user existence. Phase 2 (tenant check) returns the generic + * success message (HTTP 200) to prevent user-enumeration via status codes. + * * @param {ServerRequest} req */ const requestPasswordReset = async (req) => { const { email } = req.body; - const appConfig = await getAppConfig(); - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { + logger.warn( + `[requestPasswordReset] Blocked - email domain not allowed [Email: ${email}] [IP: ${req.ip}]`, + ); const error = new Error(ErrorTypes.AUTH_FAILED); error.code = ErrorTypes.AUTH_FAILED; error.message = 'Email domain not allowed'; return error; } - const user = await findUser({ email }, 'email _id'); + + const user = await findUser({ email }, 'email _id role tenantId'); + let appConfig = baseConfig; + if (user?.tenantId) { + try { + appConfig = await resolveAppConfigForUser(getAppConfig, user); + } catch (err) { + logger.error('[requestPasswordReset] Failed to resolve tenant config, using base:', err); + } + } + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.warn( + `[requestPasswordReset] Tenant config blocked domain [Email: ${email}] [IP: ${req.ip}]`, + ); + return { + message: 'If an account with that email exists, a password reset link has been sent to it.', + }; + } const emailEnabled = checkEmailConfig(); logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`); diff --git a/api/server/services/AuthService.spec.js b/api/server/services/AuthService.spec.js index da78f8d775..c8abafdbe5 100644 --- a/api/server/services/AuthService.spec.js +++ b/api/server/services/AuthService.spec.js @@ -14,6 +14,7 @@ jest.mock('@librechat/api', () => ({ isEmailDomainAllowed: jest.fn(), math: jest.fn((val, fallback) => (val ? Number(val) : fallback)), shouldUseSecureCookie: jest.fn(() => false), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ findUser: jest.fn(), @@ -35,8 +36,14 @@ jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn() jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() })); jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() })); -const { shouldUseSecureCookie } = require('@librechat/api'); -const { setOpenIDAuthTokens } = require('./AuthService'); +const { + shouldUseSecureCookie, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); +const { findUser } = require('~/models'); +const { getAppConfig } = require('~/server/services/Config'); +const { setOpenIDAuthTokens, requestPasswordReset } = require('./AuthService'); /** Helper to build a mock Express response */ function mockResponse() { @@ -267,3 +274,68 @@ describe('setOpenIDAuthTokens', () => { }); }); }); + +describe('requestPasswordReset', () => { + beforeEach(() => { + jest.clearAllMocks(); + isEmailDomainAllowed.mockReturnValue(true); + getAppConfig.mockResolvedValue({ + registration: { allowedDomains: ['example.com'] }, + }); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['example.com'] }, + }); + }); + + it('should fast-fail with base config before DB lookup for blocked domains', async () => { + isEmailDomainAllowed.mockReturnValue(false); + + const req = { body: { email: 'blocked@evil.com' }, ip: '127.0.0.1' }; + const result = await requestPasswordReset(req); + + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + expect(findUser).not.toHaveBeenCalled(); + expect(result).toBeInstanceOf(Error); + }); + + it('should call resolveAppConfigForUser for tenant user', async () => { + const user = { + _id: 'user-tenant', + email: 'user@example.com', + tenantId: 'tenant-x', + role: 'USER', + }; + findUser.mockResolvedValue(user); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + await requestPasswordReset(req); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, user); + }); + + it('should reuse baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue({ _id: 'user-no-tenant', email: 'user@example.com' }); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + await requestPasswordReset(req); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + }); + + it('should return generic response when tenant config blocks the domain (non-enumerable)', async () => { + const user = { + _id: 'user-tenant', + email: 'user@example.com', + tenantId: 'tenant-x', + role: 'USER', + }; + findUser.mockResolvedValue(user); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const req = { body: { email: 'user@example.com' }, ip: '127.0.0.1' }; + const result = await requestPasswordReset(req); + + expect(result).not.toBeInstanceOf(Error); + expect(result.message).toContain('If an account with that email exists'); + }); +}); diff --git a/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js new file mode 100644 index 0000000000..ddc97042b9 --- /dev/null +++ b/api/server/services/Config/__tests__/invalidateConfigCaches.spec.js @@ -0,0 +1,122 @@ +// ── Mocks ────────────────────────────────────────────────────────────── + +const mockClearAppConfigCache = jest.fn().mockResolvedValue(undefined); +const mockClearOverrideCache = jest.fn().mockResolvedValue(undefined); + +jest.mock('~/cache/getLogStores', () => { + return jest.fn(() => ({})); +}); + +jest.mock('~/server/services/start/tools', () => ({ + loadAndFormatTools: jest.fn(() => ({})), +})); + +jest.mock('../loadCustomConfig', () => jest.fn().mockResolvedValue({})); + +jest.mock('@librechat/data-schemas', () => { + const actual = jest.requireActual('@librechat/data-schemas'); + return { ...actual, AppService: jest.fn(() => ({ availableTools: {} })) }; +}); + +jest.mock('~/models', () => ({ + getApplicableConfigs: jest.fn().mockResolvedValue([]), + getUserPrincipals: jest.fn().mockResolvedValue([]), +})); + +const mockInvalidateCachedTools = jest.fn().mockResolvedValue(undefined); +jest.mock('../getCachedTools', () => ({ + setCachedTools: jest.fn().mockResolvedValue(undefined), + invalidateCachedTools: mockInvalidateCachedTools, +})); + +const mockClearMcpConfigCache = jest.fn().mockResolvedValue(undefined); +jest.mock('@librechat/api', () => ({ + createAppConfigService: jest.fn(() => ({ + getAppConfig: jest.fn().mockResolvedValue({ availableTools: {} }), + clearAppConfigCache: mockClearAppConfigCache, + clearOverrideCache: mockClearOverrideCache, + })), + clearMcpConfigCache: mockClearMcpConfigCache, +})); + +// ── Tests ────────────────────────────────────────────────────────────── + +const { invalidateConfigCaches } = require('../app'); + +describe('invalidateConfigCaches', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('clears all caches', async () => { + await invalidateConfigCaches(); + + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + expect(mockClearMcpConfigCache).toHaveBeenCalledTimes(1); + }); + + it('passes tenantId through to clearOverrideCache', async () => { + await invalidateConfigCaches('tenant-a'); + + expect(mockClearOverrideCache).toHaveBeenCalledWith('tenant-a'); + expect(mockClearAppConfigCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); + + it('all operations run in parallel (not sequentially)', async () => { + const order = []; + + mockClearAppConfigCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('base'); + r(); + }, 10), + ), + ); + mockClearOverrideCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('override'); + r(); + }, 10), + ), + ); + mockInvalidateCachedTools.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('tools'); + r(); + }, 10), + ), + ); + mockClearMcpConfigCache.mockImplementation( + () => + new Promise((r) => + setTimeout(() => { + order.push('mcp'); + r(); + }, 10), + ), + ); + + await invalidateConfigCaches(); + + expect(order).toHaveLength(4); + expect(new Set(order)).toEqual(new Set(['base', 'override', 'tools', 'mcp'])); + }); + + it('resolves even when clearAppConfigCache throws (partial failure)', async () => { + mockClearAppConfigCache.mockRejectedValueOnce(new Error('cache connection lost')); + + await expect(invalidateConfigCaches()).resolves.not.toThrow(); + + expect(mockClearOverrideCache).toHaveBeenCalledTimes(1); + expect(mockInvalidateCachedTools).toHaveBeenCalledWith({ invalidateGlobal: true }); + }); +}); diff --git a/api/server/services/Config/app.js b/api/server/services/Config/app.js index 75a5cbe56d..7aa913e636 100644 --- a/api/server/services/Config/app.js +++ b/api/server/services/Config/app.js @@ -1,12 +1,12 @@ const { CacheKeys } = require('librechat-data-provider'); -const { logger, AppService } = require('@librechat/data-schemas'); +const { AppService, logger } = require('@librechat/data-schemas'); +const { createAppConfigService, clearMcpConfigCache } = require('@librechat/api'); +const { setCachedTools, invalidateCachedTools } = require('./getCachedTools'); const { loadAndFormatTools } = require('~/server/services/start/tools'); const loadCustomConfig = require('./loadCustomConfig'); -const { setCachedTools } = require('./getCachedTools'); const getLogStores = require('~/cache/getLogStores'); const paths = require('~/config/paths'); - -const BASE_CONFIG_KEY = '_BASE_'; +const db = require('~/models'); const loadBaseConfig = async () => { /** @type {TCustomConfig} */ @@ -20,65 +20,43 @@ const loadBaseConfig = async () => { return AppService({ config, paths, systemTools }); }; -/** - * Get the app configuration based on user context - * @param {Object} [options] - * @param {string} [options.role] - User role for role-based config - * @param {boolean} [options.refresh] - Force refresh the cache - * @returns {Promise} - */ -async function getAppConfig(options = {}) { - const { role, refresh } = options; - - const cache = getLogStores(CacheKeys.APP_CONFIG); - const cacheKey = role ? role : BASE_CONFIG_KEY; - - if (!refresh) { - const cached = await cache.get(cacheKey); - if (cached) { - return cached; - } - } - - let baseConfig = await cache.get(BASE_CONFIG_KEY); - if (!baseConfig) { - logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...'); - baseConfig = await loadBaseConfig(); - - if (!baseConfig) { - throw new Error('Failed to initialize app configuration through AppService.'); - } - - if (baseConfig.availableTools) { - await setCachedTools(baseConfig.availableTools); - } - - await cache.set(BASE_CONFIG_KEY, baseConfig); - } - - // For now, return the base config - // In the future, this is where we'll apply role-based modifications - if (role) { - // TODO: Apply role-based config modifications - // const roleConfig = await applyRoleBasedConfig(baseConfig, role); - // await cache.set(cacheKey, roleConfig); - // return roleConfig; - } - - return baseConfig; -} +const { getAppConfig, clearAppConfigCache, clearOverrideCache } = createAppConfigService({ + loadBaseConfig, + setCachedTools, + getCache: getLogStores, + cacheKeys: CacheKeys, + getApplicableConfigs: db.getApplicableConfigs, + getUserPrincipals: db.getUserPrincipals, +}); /** - * Clear the app configuration cache - * @returns {Promise} + * Invalidate all config-related caches after an admin config mutation. + * Clears the base config, per-principal override caches, tool caches, + * and the MCP config-source server cache. + * @param {string} [tenantId] - Optional tenant ID to scope override cache clearing. */ -async function clearAppConfigCache() { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cacheKey = CacheKeys.APP_CONFIG; - return await cache.delete(cacheKey); +async function invalidateConfigCaches(tenantId) { + const results = await Promise.allSettled([ + clearAppConfigCache(), + clearOverrideCache(tenantId), + invalidateCachedTools({ invalidateGlobal: true }), + clearMcpConfigCache(), + ]); + const labels = [ + 'clearAppConfigCache', + 'clearOverrideCache', + 'invalidateCachedTools', + 'clearMcpConfigCache', + ]; + for (let i = 0; i < results.length; i++) { + if (results[i].status === 'rejected') { + logger.error(`[invalidateConfigCaches] ${labels[i]} failed:`, results[i].reason); + } + } } module.exports = { getAppConfig, clearAppConfigCache, + invalidateConfigCaches, }; diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js index bb22584851..d09b45626c 100644 --- a/api/server/services/Config/getEndpointsConfig.js +++ b/api/server/services/Config/getEndpointsConfig.js @@ -1,133 +1,10 @@ -const { loadCustomEndpointsConfig } = require('@librechat/api'); -const { - CacheKeys, - EModelEndpoint, - isAgentsEndpoint, - orderEndpointsConfig, - defaultAgentCapabilities, -} = require('librechat-data-provider'); +const { createEndpointsConfigService } = require('@librechat/api'); const loadDefaultEndpointsConfig = require('./loadDefaultEConfig'); -const getLogStores = require('~/cache/getLogStores'); const { getAppConfig } = require('./app'); -/** - * - * @param {ServerRequest} req - * @returns {Promise} - */ -async function getEndpointsConfig(req) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); - const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); - if (cachedEndpointsConfig) { - if (cachedEndpointsConfig.gptPlugins) { - await cache.delete(CacheKeys.ENDPOINT_CONFIG); - } else { - return cachedEndpointsConfig; - } - } - - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); - const defaultEndpointsConfig = await loadDefaultEndpointsConfig(appConfig); - const customEndpointsConfig = loadCustomEndpointsConfig(appConfig?.endpoints?.custom); - - /** @type {TEndpointsConfig} */ - const mergedConfig = { - ...defaultEndpointsConfig, - ...customEndpointsConfig, - }; - - if (appConfig.endpoints?.[EModelEndpoint.azureOpenAI]) { - /** @type {Omit} */ - mergedConfig[EModelEndpoint.azureOpenAI] = { - userProvide: false, - }; - } - - // Enable Anthropic endpoint when Vertex AI is configured in YAML - if (appConfig.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig?.enabled) { - /** @type {Omit} */ - mergedConfig[EModelEndpoint.anthropic] = { - userProvide: false, - }; - } - - if (appConfig.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) { - /** @type {Omit} */ - mergedConfig[EModelEndpoint.azureAssistants] = { - userProvide: false, - }; - } - - if ( - mergedConfig[EModelEndpoint.assistants] && - appConfig?.endpoints?.[EModelEndpoint.assistants] - ) { - const { disableBuilder, retrievalModels, capabilities, version, ..._rest } = - appConfig.endpoints[EModelEndpoint.assistants]; - - mergedConfig[EModelEndpoint.assistants] = { - ...mergedConfig[EModelEndpoint.assistants], - version, - retrievalModels, - disableBuilder, - capabilities, - }; - } - if (mergedConfig[EModelEndpoint.agents] && appConfig?.endpoints?.[EModelEndpoint.agents]) { - const { disableBuilder, capabilities, allowedProviders, ..._rest } = - appConfig.endpoints[EModelEndpoint.agents]; - - mergedConfig[EModelEndpoint.agents] = { - ...mergedConfig[EModelEndpoint.agents], - allowedProviders, - disableBuilder, - capabilities, - }; - } - - if ( - mergedConfig[EModelEndpoint.azureAssistants] && - appConfig?.endpoints?.[EModelEndpoint.azureAssistants] - ) { - const { disableBuilder, retrievalModels, capabilities, version, ..._rest } = - appConfig.endpoints[EModelEndpoint.azureAssistants]; - - mergedConfig[EModelEndpoint.azureAssistants] = { - ...mergedConfig[EModelEndpoint.azureAssistants], - version, - retrievalModels, - disableBuilder, - capabilities, - }; - } - - if (mergedConfig[EModelEndpoint.bedrock] && appConfig?.endpoints?.[EModelEndpoint.bedrock]) { - const { availableRegions } = appConfig.endpoints[EModelEndpoint.bedrock]; - mergedConfig[EModelEndpoint.bedrock] = { - ...mergedConfig[EModelEndpoint.bedrock], - availableRegions, - }; - } - - const endpointsConfig = orderEndpointsConfig(mergedConfig); - - await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); - return endpointsConfig; -} - -/** - * @param {ServerRequest} req - * @param {import('librechat-data-provider').AgentCapabilities} capability - * @returns {Promise} - */ -const checkCapability = async (req, capability) => { - const isAgents = isAgentsEndpoint(req.body?.endpointType || req.body?.endpoint); - const endpointsConfig = await getEndpointsConfig(req); - const capabilities = - isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null - ? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []) - : defaultAgentCapabilities; - return capabilities.includes(capability); -}; +const { getEndpointsConfig, checkCapability } = createEndpointsConfigService({ + getAppConfig, + loadDefaultEndpointsConfig, +}); module.exports = { getEndpointsConfig, checkCapability }; diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js index 2bc83ecc3a..93212cd030 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -1,117 +1,11 @@ -const { isUserProvided, fetchModels } = require('@librechat/api'); -const { - EModelEndpoint, - extractEnvVariable, - normalizeEndpointName, -} = require('librechat-data-provider'); +const { createLoadConfigModels, fetchModels } = require('@librechat/api'); const { getAppConfig } = require('./app'); +const db = require('~/models'); -/** - * Load config endpoints from the cached configuration object - * @function loadConfigModels - * @param {ServerRequest} req - The Express request object. - */ -async function loadConfigModels(req) { - const appConfig = await getAppConfig({ role: req.user?.role }); - if (!appConfig) { - return {}; - } - const modelsConfig = {}; - const azureConfig = appConfig.endpoints?.[EModelEndpoint.azureOpenAI]; - const { modelNames } = azureConfig ?? {}; - - if (modelNames && azureConfig) { - modelsConfig[EModelEndpoint.azureOpenAI] = modelNames; - } - - if (azureConfig?.assistants && azureConfig.assistantModels) { - modelsConfig[EModelEndpoint.azureAssistants] = azureConfig.assistantModels; - } - - const bedrockConfig = appConfig.endpoints?.[EModelEndpoint.bedrock]; - if (bedrockConfig?.models && Array.isArray(bedrockConfig.models)) { - modelsConfig[EModelEndpoint.bedrock] = bedrockConfig.models; - } - - if (!Array.isArray(appConfig.endpoints?.[EModelEndpoint.custom])) { - return modelsConfig; - } - - const customEndpoints = appConfig.endpoints[EModelEndpoint.custom].filter( - (endpoint) => - endpoint.baseURL && - endpoint.apiKey && - endpoint.name && - endpoint.models && - (endpoint.models.fetch || endpoint.models.default), - ); - - /** - * @type {Record>} - * Map for promises keyed by unique combination of baseURL and apiKey */ - const fetchPromisesMap = {}; - /** - * @type {Record} - * Map to associate unique keys with endpoint names; note: one key may can correspond to multiple endpoints */ - const uniqueKeyToEndpointsMap = {}; - /** - * @type {Record>} - * Map to associate endpoint names to their configurations */ - const endpointsMap = {}; - - for (let i = 0; i < customEndpoints.length; i++) { - const endpoint = customEndpoints[i]; - const { models, name: configName, baseURL, apiKey, headers: endpointHeaders } = endpoint; - const name = normalizeEndpointName(configName); - endpointsMap[name] = endpoint; - - const API_KEY = extractEnvVariable(apiKey); - const BASE_URL = extractEnvVariable(baseURL); - - const uniqueKey = `${BASE_URL}__${API_KEY}`; - - modelsConfig[name] = []; - - if (models.fetch && !isUserProvided(API_KEY) && !isUserProvided(BASE_URL)) { - fetchPromisesMap[uniqueKey] = - fetchPromisesMap[uniqueKey] || - fetchModels({ - name, - apiKey: API_KEY, - baseURL: BASE_URL, - user: req.user.id, - userObject: req.user, - headers: endpointHeaders, - direct: endpoint.directEndpoint, - userIdQuery: models.userIdQuery, - }); - uniqueKeyToEndpointsMap[uniqueKey] = uniqueKeyToEndpointsMap[uniqueKey] || []; - uniqueKeyToEndpointsMap[uniqueKey].push(name); - continue; - } - - if (Array.isArray(models.default)) { - modelsConfig[name] = models.default.map((model) => - typeof model === 'string' ? model : model.name, - ); - } - } - - const fetchedData = await Promise.all(Object.values(fetchPromisesMap)); - const uniqueKeys = Object.keys(fetchPromisesMap); - - for (let i = 0; i < fetchedData.length; i++) { - const currentKey = uniqueKeys[i]; - const modelData = fetchedData[i]; - const associatedNames = uniqueKeyToEndpointsMap[currentKey]; - - for (const name of associatedNames) { - const endpoint = endpointsMap[name]; - modelsConfig[name] = !modelData?.length ? (endpoint.models.default ?? []) : modelData; - } - } - - return modelsConfig; -} +const loadConfigModels = createLoadConfigModels({ + getAppConfig, + getUserKeyValues: db.getUserKeyValues, + fetchModels, +}); module.exports = loadConfigModels; diff --git a/api/server/services/Config/loadConfigModels.spec.js b/api/server/services/Config/loadConfigModels.spec.js index 6ffb8ba522..d3ec0309ae 100644 --- a/api/server/services/Config/loadConfigModels.spec.js +++ b/api/server/services/Config/loadConfigModels.spec.js @@ -7,6 +7,13 @@ jest.mock('@librechat/api', () => ({ fetchModels: jest.fn(), })); jest.mock('./app'); +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { debug: jest.fn(), error: jest.fn(), warn: jest.fn() }, +})); +jest.mock('~/models', () => ({ + getUserKeyValues: jest.fn(), +})); const exampleConfig = { endpoints: { @@ -68,11 +75,11 @@ describe('loadConfigModels', () => { const originalEnv = process.env; beforeEach(() => { - jest.resetAllMocks(); - jest.resetModules(); + jest.clearAllMocks(); + fetchModels.mockReset(); + require('~/models').getUserKeyValues.mockReset(); process.env = { ...originalEnv }; - // Default mock for getAppConfig getAppConfig.mockResolvedValue({}); }); @@ -337,6 +344,168 @@ describe('loadConfigModels', () => { expect(result.FalsyFetchModel).toEqual(['defaultModel1', 'defaultModel2']); }); + describe('user-provided API key model fetching', () => { + it('fetches models using user-provided API key when key is stored', async () => { + const { getUserKeyValues } = require('~/models'); + getUserKeyValues.mockResolvedValueOnce({ + apiKey: 'sk-user-key', + baseURL: 'https://api.x.com/v1', + }); + getAppConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'UserEndpoint', + apiKey: 'user_provided', + baseURL: 'user_provided', + models: { fetch: true, default: ['fallback-model'] }, + }, + ], + }, + }); + fetchModels.mockResolvedValue(['fetched-model-a', 'fetched-model-b']); + + const result = await loadConfigModels(mockRequest); + + expect(getUserKeyValues).toHaveBeenCalledWith({ userId: 'testUserId', name: 'UserEndpoint' }); + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: 'sk-user-key', + baseURL: 'https://api.x.com/v1', + skipCache: true, + }), + ); + expect(result.UserEndpoint).toEqual(['fetched-model-a', 'fetched-model-b']); + }); + + it('falls back to defaults when getUserKeyValues returns no apiKey', async () => { + const { getUserKeyValues } = require('~/models'); + getUserKeyValues.mockResolvedValueOnce({ baseURL: 'https://api.x.com/v1' }); + getAppConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'NoKeyEndpoint', + apiKey: 'user_provided', + baseURL: 'https://api.x.com/v1', + models: { fetch: true, default: ['default-model'] }, + }, + ], + }, + }); + + const result = await loadConfigModels(mockRequest); + + expect(fetchModels).not.toHaveBeenCalled(); + expect(result.NoKeyEndpoint).toEqual(['default-model']); + }); + + it('falls back to defaults and logs warn when getUserKeyValues throws infra error', async () => { + const { getUserKeyValues } = require('~/models'); + const { logger } = require('@librechat/data-schemas'); + getUserKeyValues.mockRejectedValueOnce(new Error('DB connection timeout')); + getAppConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'ErrorEndpoint', + apiKey: 'user_provided', + baseURL: 'https://api.example.com/v1', + models: { fetch: true, default: ['fallback'] }, + }, + ], + }, + }); + + const result = await loadConfigModels(mockRequest); + + expect(fetchModels).not.toHaveBeenCalled(); + expect(result.ErrorEndpoint).toEqual(['fallback']); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining( + 'Failed to retrieve user key for "ErrorEndpoint": DB connection timeout', + ), + ); + expect(logger.debug).not.toHaveBeenCalledWith(expect.stringContaining('No user key stored')); + }); + + it('logs debug (not warn) for NO_USER_KEY errors', async () => { + const { getUserKeyValues } = require('~/models'); + const { logger } = require('@librechat/data-schemas'); + getUserKeyValues.mockRejectedValueOnce(new Error(JSON.stringify({ type: 'no_user_key' }))); + getAppConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'MissingKeyEndpoint', + apiKey: 'user_provided', + baseURL: 'https://api.example.com/v1', + models: { fetch: true, default: ['default-model'] }, + }, + ], + }, + }); + + const result = await loadConfigModels(mockRequest); + + expect(result.MissingKeyEndpoint).toEqual(['default-model']); + expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('No user key stored')); + expect(logger.warn).not.toHaveBeenCalledWith( + expect.stringContaining('Failed to retrieve user key'), + ); + }); + + it('skips user key lookup when req.user.id is undefined', async () => { + const { getUserKeyValues } = require('~/models'); + getAppConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'NoUserEndpoint', + apiKey: 'user_provided', + baseURL: 'https://api.x.com/v1', + models: { fetch: true, default: ['anon-model'] }, + }, + ], + }, + }); + + const result = await loadConfigModels({ user: {} }); + + expect(getUserKeyValues).not.toHaveBeenCalled(); + expect(result.NoUserEndpoint).toEqual(['anon-model']); + }); + + it('uses stored baseURL only when baseURL is user_provided', async () => { + const { getUserKeyValues } = require('~/models'); + getUserKeyValues.mockResolvedValueOnce({ apiKey: 'sk-key' }); + getAppConfig.mockResolvedValue({ + endpoints: { + custom: [ + { + name: 'KeyOnly', + apiKey: 'user_provided', + baseURL: 'https://fixed-base.com/v1', + models: { fetch: true, default: ['default'] }, + }, + ], + }, + }); + fetchModels.mockResolvedValue(['model-from-fixed-base']); + + const result = await loadConfigModels(mockRequest); + + expect(fetchModels).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: 'sk-key', + baseURL: 'https://fixed-base.com/v1', + skipCache: true, + }), + ); + expect(result.KeyOnly).toEqual(['model-from-fixed-base']); + }); + }); + it('normalizes Ollama endpoint name to lowercase', async () => { const testCases = [ { diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index 31aa831a70..85f2c42a33 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -16,7 +16,8 @@ const { getAppConfig } = require('./app'); */ async function loadDefaultModels(req) { try { - const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role })); + const appConfig = + req.config ?? (await getAppConfig({ role: req.user?.role, tenantId: req.user?.tenantId })); const vertexConfig = appConfig?.endpoints?.[EModelEndpoint.anthropic]?.vertexConfig; const [openAI, anthropic, azureOpenAI, assistants, azureAssistants, google, bedrock] = diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index cc4e98b59e..fa37e223f5 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -1,97 +1,10 @@ -const { logger } = require('@librechat/data-schemas'); -const { CacheKeys, Constants } = require('librechat-data-provider'); +const { createMCPToolCacheService } = require('@librechat/api'); const { getCachedTools, setCachedTools } = require('./getCachedTools'); -const { getLogStores } = require('~/cache'); -/** - * Updates MCP tools in the cache for a specific server - * @param {Object} params - Parameters for updating MCP tools - * @param {string} params.userId - User ID for user-specific caching - * @param {string} params.serverName - MCP server name - * @param {Array} params.tools - Array of tool objects from MCP server - * @returns {Promise} - */ -async function updateMCPServerTools({ userId, serverName, tools }) { - try { - const serverTools = {}; - const mcpDelimiter = Constants.mcp_delimiter; - - if (tools == null || tools.length === 0) { - logger.debug(`[MCP Cache] No tools to update for server ${serverName} (user: ${userId})`); - return serverTools; - } - - for (const tool of tools) { - const name = `${tool.name}${mcpDelimiter}${serverName}`; - serverTools[name] = { - type: 'function', - ['function']: { - name, - description: tool.description, - parameters: tool.inputSchema, - }, - }; - } - - await setCachedTools(serverTools, { userId, serverName }); - - const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); - logger.debug( - `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, - ); - return serverTools; - } catch (error) { - logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error); - throw error; - } -} - -/** - * Merges app-level tools with global tools - * @param {import('@librechat/api').LCAvailableTools} appTools - * @returns {Promise} - */ -async function mergeAppTools(appTools) { - try { - const count = Object.keys(appTools).length; - if (!count) { - return; - } - const cachedTools = await getCachedTools(); - const mergedTools = { ...cachedTools, ...appTools }; - await setCachedTools(mergedTools); - const cache = getLogStores(CacheKeys.TOOL_CACHE); - await cache.delete(CacheKeys.TOOLS); - logger.debug(`Merged ${count} app-level tools`); - } catch (error) { - logger.error('Failed to merge app-level tools:', error); - throw error; - } -} - -/** - * Caches MCP server tools (no longer merges with global) - * @param {object} params - * @param {string} params.userId - User ID for user-specific caching - * @param {string} params.serverName - * @param {import('@librechat/api').LCAvailableTools} params.serverTools - * @returns {Promise} - */ -async function cacheMCPServerTools({ userId, serverName, serverTools }) { - try { - const count = Object.keys(serverTools).length; - if (!count) { - return; - } - // Only cache server-specific tools, no merging with global - await setCachedTools(serverTools, { userId, serverName }); - logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`); - } catch (error) { - logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error); - throw error; - } -} +const { mergeAppTools, cacheMCPServerTools, updateMCPServerTools } = createMCPToolCacheService({ + getCachedTools, + setCachedTools, +}); module.exports = { mergeAppTools, diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index 4ba62a7eeb..c9a35c35ea 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -142,6 +142,7 @@ class STTService { req.config ?? (await getAppConfig({ role: req?.user?.role, + tenantId: req?.user?.tenantId, })); const sttSchema = appConfig?.speech?.stt; if (!sttSchema) { diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index 2c932968c6..1125dd74ed 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -297,6 +297,7 @@ class TTSService { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); try { res.setHeader('Content-Type', 'audio/mpeg'); @@ -365,6 +366,7 @@ class TTSService { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); const provider = this.getProvider(appConfig); const ttsSchema = appConfig?.speech?.tts?.[provider]; diff --git a/api/server/services/Files/Audio/getCustomConfigSpeech.js b/api/server/services/Files/Audio/getCustomConfigSpeech.js index d0d0b51ac2..b438771ec1 100644 --- a/api/server/services/Files/Audio/getCustomConfigSpeech.js +++ b/api/server/services/Files/Audio/getCustomConfigSpeech.js @@ -17,6 +17,7 @@ async function getCustomConfigSpeech(req, res) { try { const appConfig = await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, }); if (!appConfig) { diff --git a/api/server/services/Files/Audio/getVoices.js b/api/server/services/Files/Audio/getVoices.js index f2f8e100c3..22bd7cea6e 100644 --- a/api/server/services/Files/Audio/getVoices.js +++ b/api/server/services/Files/Audio/getVoices.js @@ -18,6 +18,7 @@ async function getVoices(req, res) { req.config ?? (await getAppConfig({ role: req.user?.role, + tenantId: req.user?.tenantId, })); const ttsSchema = appConfig?.speech?.tts; diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index c28a96edff..7120399b5e 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,3 +1,4 @@ +const { scopedCacheKey } = require('@librechat/data-schemas'); const { Time, CacheKeys, @@ -67,6 +68,8 @@ function createChunkProcessor(user, messageId) { } const messageCache = getLogStores(CacheKeys.MESSAGES); + // Captured at creation time — must be called within an active request ALS scope + const cacheKey = scopedCacheKey(messageId); /** * @returns {Promise<{ text: string, isFinished: boolean }[] | string>} @@ -81,7 +84,7 @@ function createChunkProcessor(user, messageId) { } /** @type { string | { text: string; complete: boolean } } */ - let message = await messageCache.get(messageId); + let message = await messageCache.get(cacheKey); if (!message) { message = await getMessage({ user, messageId }); } @@ -92,7 +95,7 @@ function createChunkProcessor(user, messageId) { } else { const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text; messageCache.set( - messageId, + cacheKey, { text, complete: true, diff --git a/api/server/services/Files/Citations/index.js b/api/server/services/Files/Citations/index.js index 008e21d7c4..a1d9322467 100644 --- a/api/server/services/Files/Citations/index.js +++ b/api/server/services/Files/Citations/index.js @@ -47,7 +47,10 @@ async function processFileCitations({ user, appConfig, toolArtifact, toolCallId, logger.error( `[processFileCitations] Permission check failed for FILE_CITATIONS: ${error.message}`, ); - logger.debug(`[processFileCitations] Proceeding with citations due to permission error`); + logger.warn( + '[processFileCitations] Returning null citations due to permission check error — citations will not be shown for this message', + ); + return null; } } @@ -145,6 +148,8 @@ async function enhanceSourcesWithMetadata(sources, appConfig) { metadata: { ...source.metadata, storageType: configuredStorageType, + fileType: fileRecord.type || undefined, + fileBytes: fileRecord.bytes || undefined, }, }; }); diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 5d97891c55..dbb44740a9 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,5 +1,5 @@ const { tool } = require('@langchain/core/tools'); -const { logger } = require('@librechat/data-schemas'); +const { logger, getTenantId } = require('@librechat/data-schemas'); const { Providers, StepTypes, @@ -14,6 +14,7 @@ const { normalizeJsonSchema, GenerationJobManager, resolveJsonSchemaRefs, + buildOAuthToolCallName, } = require('@librechat/api'); const { Time, CacheKeys, Constants, isAssistantsEndpoint } = require('librechat-data-provider'); const { @@ -53,6 +54,53 @@ function evictStale(map, ttl) { const unavailableMsg = "This tool's MCP server is temporarily unavailable. Please try again shortly."; +/** + * Resolves config-source MCP servers from admin Config overrides for the current + * request context. Returns the parsed configs keyed by server name. + * @param {import('express').Request} req - Express request with user context + * @returns {Promise>} + */ +async function resolveConfigServers(req) { + try { + const registry = getMCPServersRegistry(); + const user = req?.user; + const appConfig = await getAppConfig({ + role: user?.role, + tenantId: getTenantId(), + userId: user?.id, + }); + return await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveConfigServers] Failed to resolve config servers, degrading to empty:', + error, + ); + return {}; + } +} + +/** + * Resolves config-source servers and merges all server configs (YAML + config + user DB) + * for the given user context. Shared helper for controllers needing the full merged config. + * @param {string} userId + * @param {{ id?: string, role?: string }} [user] + * @returns {Promise>} + */ +async function resolveAllMcpConfigs(userId, user) { + const registry = getMCPServersRegistry(); + const appConfig = await getAppConfig({ role: user?.role, tenantId: getTenantId(), userId }); + let configServers = {}; + try { + configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + } catch (error) { + logger.warn( + '[resolveAllMcpConfigs] Config server resolution failed, continuing without:', + error, + ); + } + return await registry.getAllServerConfigs(userId, configServers); +} + /** * @param {string} toolName * @param {string} serverName @@ -248,6 +296,7 @@ async function reconnectServer({ index, signal, serverName, + configServers, userMCPAuthMap, streamId = null, }) { @@ -271,7 +320,7 @@ async function reconnectServer({ const stepId = 'step_oauth_login_' + serverName; const toolCall = { id: flowId, - name: serverName, + name: buildOAuthToolCallName(serverName), type: 'tool_call_chunk', }; @@ -316,6 +365,7 @@ async function reconnectServer({ user, signal, serverName, + configServers, oauthStart, flowManager, userMCPAuthMap, @@ -358,15 +408,14 @@ async function createMCPTools({ config, provider, serverName, + configServers, userMCPAuthMap, streamId = null, }) { - // Early domain validation before reconnecting server (avoid wasted work on disallowed domains) - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); if (!isDomainAllowed) { @@ -381,6 +430,7 @@ async function createMCPTools({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -400,6 +450,7 @@ async function createMCPTools({ user, provider, userMCPAuthMap, + configServers, streamId, availableTools: result.availableTools, toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`, @@ -439,16 +490,15 @@ async function createMCPTool({ userMCPAuthMap, availableTools, config, + configServers, streamId = null, }) { const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); - // Runtime domain validation: check if the server's domain is still allowed - // Use getAppConfig() to support per-user/role domain restrictions const serverConfig = - config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id)); + config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.url) { - const appConfig = await getAppConfig({ role: user?.role }); + const appConfig = await getAppConfig({ role: user?.role, tenantId: user?.tenantId }); const allowedDomains = appConfig?.mcpSettings?.allowedDomains; const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains); if (!isDomainAllowed) { @@ -477,6 +527,7 @@ async function createMCPTool({ index, signal, serverName, + configServers, userMCPAuthMap, streamId, }); @@ -500,6 +551,7 @@ async function createMCPTool({ provider, toolName, serverName, + serverConfig, toolDefinition, streamId, }); @@ -509,13 +561,14 @@ function createToolInstance({ res, toolName, serverName, + serverConfig: capturedServerConfig, toolDefinition, - provider: _provider, + provider: capturedProvider, streamId = null, }) { /** @type {LCTool} */ const { description, parameters } = toolDefinition; - const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; + const isGoogle = capturedProvider === Providers.VERTEXAI || capturedProvider === Providers.GOOGLE; let schema = parameters ? normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null; @@ -544,7 +597,7 @@ function createToolInstance({ const flowManager = getFlowStateManager(flowsCache); derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; const mcpManager = getMCPManager(userId); - const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + const provider = (config?.metadata?.provider || capturedProvider)?.toLowerCase(); const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; const flowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; @@ -576,6 +629,7 @@ function createToolInstance({ const result = await mcpManager.callTool({ serverName, + serverConfig: capturedServerConfig, toolName, provider, toolArguments, @@ -643,30 +697,36 @@ function createToolInstance({ } /** - * Get MCP setup data including config, connections, and OAuth servers + * Get MCP setup data including config, connections, and OAuth servers. + * Resolves config-source servers from admin Config overrides when tenant context is available. * @param {string} userId - The user ID + * @param {{ role?: string, tenantId?: string }} [options] - Optional role/tenant context * @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers */ -async function getMCPSetupData(userId) { - const mcpConfig = await getMCPServersRegistry().getAllServerConfigs(userId); - - if (!mcpConfig) { - throw new Error('MCP config not found'); - } +async function getMCPSetupData(userId, options = {}) { + const registry = getMCPServersRegistry(); + const { role, tenantId } = options; + const appConfig = await getAppConfig({ role, tenantId, userId }); + const configServers = await registry.ensureConfigServers(appConfig?.mcpConfig || {}); + const mcpConfig = await registry.getAllServerConfigs(userId, configServers); const mcpManager = getMCPManager(userId); /** @type {Map} */ let appConnections = new Map(); try { - // Use getLoaded() instead of getAll() to avoid forcing connection creation + // Use getLoaded() instead of getAll() to avoid forcing connection creation. // getAll() creates connections for all servers, which is problematic for servers - // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders) + // that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders). appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map(); } catch (error) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = await getMCPServersRegistry().getOAuthServers(userId); + const oauthServers = new Set( + Object.entries(mcpConfig) + .filter(([, config]) => config.requiresOAuth) + .map(([name]) => name), + ); return { mcpConfig, @@ -788,6 +848,8 @@ module.exports = { createMCPTool, createMCPTools, getMCPSetupData, + resolveConfigServers, + resolveAllMcpConfigs, checkOAuthFlowStatus, getServerConnectionStatus, createUnavailableToolStub, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 14a9ef90ed..c9925827f8 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -14,6 +14,7 @@ const mockRegistryInstance = { getOAuthServers: jest.fn(() => Promise.resolve(new Set())), getAllServerConfigs: jest.fn(() => Promise.resolve({})), getServerConfig: jest.fn(() => Promise.resolve(null)), + ensureConfigServers: jest.fn(() => Promise.resolve({})), }; // Create isMCPDomainAllowed mock that can be configured per-test @@ -113,38 +114,43 @@ describe('tests for the new helper functions used by the MCP connection status e }); it('should successfully return MCP setup data', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig); + const mockConfigWithOAuth = { + server1: { type: 'stdio' }, + server2: { type: 'http', requiresOAuth: true }, + }; + mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfigWithOAuth); const mockAppConnections = new Map([['server1', { status: 'connected' }]]); const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]); - const mockOAuthServers = new Set(['server2']); const mockMCPManager = { appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) }, getUserConnections: jest.fn(() => mockUserConnections), }; mockGetMCPManager.mockReturnValue(mockMCPManager); - mockRegistryInstance.getOAuthServers.mockResolvedValue(mockOAuthServers); const result = await getMCPSetupData(mockUserId); - expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId); + expect(mockRegistryInstance.ensureConfigServers).toHaveBeenCalled(); + expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith( + mockUserId, + expect.any(Object), + ); expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); - expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId); - expect(result).toEqual({ - mcpConfig: mockConfig, - appConnections: mockAppConnections, - userConnections: mockUserConnections, - oauthServers: mockOAuthServers, - }); + expect(result.mcpConfig).toEqual(mockConfigWithOAuth); + expect(result.appConnections).toEqual(mockAppConnections); + expect(result.userConnections).toEqual(mockUserConnections); + expect(result.oauthServers).toEqual(new Set(['server2'])); }); - it('should throw error when MCP config not found', async () => { - mockRegistryInstance.getAllServerConfigs.mockResolvedValue(null); - await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found'); + it('should return empty data when no servers are configured', async () => { + mockRegistryInstance.getAllServerConfigs.mockResolvedValue({}); + const result = await getMCPSetupData(mockUserId); + expect(result.mcpConfig).toEqual({}); + expect(result.oauthServers).toEqual(new Set()); }); it('should handle null values from MCP manager gracefully', async () => { diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index ca75e7eb4f..b4d948eda4 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -19,6 +19,7 @@ const { buildWebSearchContext, buildImageToolContext, buildToolClassification, + buildOAuthToolCallName, } = require('@librechat/api'); const { Time, @@ -30,6 +31,7 @@ const { imageGenTools, EModelEndpoint, EToolResources, + isActionTool, actionDelimiter, ImageVisionTool, openapiToFunction, @@ -59,6 +61,7 @@ const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); +const { resolveConfigServers } = require('~/server/services/MCP'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); @@ -488,7 +491,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to if (tool === Tools.web_search) { return checkCapability(AgentCapabilities.web_search); } - if (tool.includes(actionDelimiter)) { + if (isActionTool(tool)) { return actionsEnabled; } if (!areToolsEnabled) { @@ -513,6 +516,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); + const configServers = await resolveConfigServers(req); const pendingOAuthServers = new Set(); const createOAuthEmitter = (serverName) => { @@ -521,7 +525,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const stepId = 'step_oauth_login_' + serverName; const toolCall = { id: flowId, - name: serverName, + name: buildOAuthToolCallName(serverName), type: 'tool_call_chunk', }; @@ -578,6 +582,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to oauthStart, flowManager, serverName, + configServers, userMCPAuthMap, }); @@ -665,6 +670,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const result = await reinitMCPServer({ user: req.user, serverName, + configServers, userMCPAuthMap, flowManager, returnOnOAuth: false, @@ -866,7 +872,7 @@ async function loadAgentTools({ } else if (tool === Tools.web_search) { includesWebSearch = checkCapability(AgentCapabilities.web_search); return includesWebSearch; - } else if (tool.includes(actionDelimiter)) { + } else if (isActionTool(tool)) { return actionsEnabled; } else if (!areToolsEnabled) { return false; @@ -973,7 +979,7 @@ async function loadAgentTools({ agentTools.push(...additionalTools); - const hasActionTools = _agentTools.some((t) => t.includes(actionDelimiter)); + const hasActionTools = _agentTools.some((t) => isActionTool(t)); if (!hasActionTools) { return { toolRegistry, @@ -1232,8 +1238,11 @@ async function loadToolsForExecution({ ? [...new Set([...requestedNonSpecialToolNames, ...ptcOrchestratedToolNames])] : requestedNonSpecialToolNames; - const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter)); - const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter)); + const actionToolNames = []; + const regularToolNames = []; + for (const name of allToolNamesToLoad) { + (isActionTool(name) ? actionToolNames : regularToolNames).push(name); + } if (regularToolNames.length > 0) { const includesWebSearch = regularToolNames.includes(Tools.web_search); diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 7589043e10..f1ebcf9796 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -25,11 +25,13 @@ async function reinitMCPServer({ signal, forceNew, serverName, + configServers, userMCPAuthMap, connectionTimeout, returnOnOAuth = true, oauthStart: _oauthStart, flowManager: _flowManager, + serverConfig: providedConfig, }) { /** @type {MCPConnection | null} */ let connection = null; @@ -42,13 +44,28 @@ async function reinitMCPServer({ try { const registry = getMCPServersRegistry(); - const serverConfig = await registry.getServerConfig(serverName, user?.id); + const serverConfig = + providedConfig ?? (await registry.getServerConfig(serverName, user?.id, configServers)); if (serverConfig?.inspectionFailed) { + if (serverConfig.source === 'config') { + logger.info( + `[MCP Reinitialize] Config-source server ${serverName} has inspectionFailed — retry handled by config cache`, + ); + return { + availableTools: null, + success: false, + message: `MCP server '${serverName}' is still unreachable`, + oauthRequired: false, + serverName, + oauthUrl: null, + tools: null, + }; + } logger.info( `[MCP Reinitialize] Server ${serverName} had failed inspection, attempting reinspection`, ); try { - const storageLocation = serverConfig.dbId ? 'DB' : 'CACHE'; + const storageLocation = serverConfig.source === 'user' ? 'DB' : 'CACHE'; await registry.reinspectServer(serverName, storageLocation, user?.id); logger.info(`[MCP Reinitialize] Reinspection succeeded for server: ${serverName}`); } catch (reinspectError) { @@ -93,6 +110,7 @@ async function reinitMCPServer({ returnOnOAuth, customUserVars, connectionTimeout, + serverConfig, }); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); @@ -125,6 +143,7 @@ async function reinitMCPServer({ oauthStart, customUserVars, connectionTimeout, + configServers, }); if (discoveryResult.tools && discoveryResult.tools.length > 0) { diff --git a/api/server/services/__tests__/MCP.spec.js b/api/server/services/__tests__/MCP.spec.js new file mode 100644 index 0000000000..39e99d54ac --- /dev/null +++ b/api/server/services/__tests__/MCP.spec.js @@ -0,0 +1,131 @@ +const mockRegistry = { + ensureConfigServers: jest.fn(), + getAllServerConfigs: jest.fn(), +}; + +jest.mock('~/config', () => ({ + getMCPServersRegistry: jest.fn(() => mockRegistry), + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), + getOAuthReconnectionManager: jest.fn(), +})); + +jest.mock('@librechat/data-schemas', () => ({ + getTenantId: jest.fn(() => 'tenant-1'), + logger: { debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn() }, +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), + setCachedTools: jest.fn(), + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), + loadCustomConfig: jest.fn(), +})); + +jest.mock('~/cache', () => ({ getLogStores: jest.fn() })); +jest.mock('~/models', () => ({ + findToken: jest.fn(), + createToken: jest.fn(), + updateToken: jest.fn(), +})); +jest.mock('~/server/services/GraphTokenService', () => ({ + getGraphApiToken: jest.fn(), +})); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); + +const { getAppConfig } = require('~/server/services/Config'); +const { resolveConfigServers, resolveAllMcpConfigs } = require('../MCP'); + +describe('resolveConfigServers', () => { + beforeEach(() => jest.clearAllMocks()); + + it('resolves config servers for the current request context', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: { url: 'http://a' } } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ srv: { name: 'srv' } }); + + const result = await resolveConfigServers({ user: { id: 'u1', role: 'admin' } }); + + expect(result).toEqual({ srv: { name: 'srv' } }); + expect(getAppConfig).toHaveBeenCalledWith( + expect.objectContaining({ role: 'admin', userId: 'u1' }), + ); + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({ srv: { url: 'http://a' } }); + }); + + it('returns {} when ensureConfigServers throws', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('returns {} when getAppConfig throws', async () => { + getAppConfig.mockRejectedValue(new Error('db timeout')); + + const result = await resolveConfigServers({ user: { id: 'u1' } }); + + expect(result).toEqual({}); + }); + + it('passes empty mcpConfig when appConfig has none', async () => { + getAppConfig.mockResolvedValue({}); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + + await resolveConfigServers({ user: { id: 'u1' } }); + + expect(mockRegistry.ensureConfigServers).toHaveBeenCalledWith({}); + }); +}); + +describe('resolveAllMcpConfigs', () => { + beforeEach(() => jest.clearAllMocks()); + + it('merges config servers with base servers', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { cfg_srv: {} } }); + mockRegistry.ensureConfigServers.mockResolvedValue({ cfg_srv: { name: 'cfg_srv' } }); + mockRegistry.getAllServerConfigs.mockResolvedValue({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1', role: 'user' }); + + expect(result).toEqual({ + cfg_srv: { name: 'cfg_srv' }, + yaml_srv: { name: 'yaml_srv' }, + }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', { + cfg_srv: { name: 'cfg_srv' }, + }); + }); + + it('continues with empty configServers when ensureConfigServers fails', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: { srv: {} } }); + mockRegistry.ensureConfigServers.mockRejectedValue(new Error('inspect failed')); + mockRegistry.getAllServerConfigs.mockResolvedValue({ yaml_srv: { name: 'yaml_srv' } }); + + const result = await resolveAllMcpConfigs('u1', { id: 'u1' }); + + expect(result).toEqual({ yaml_srv: { name: 'yaml_srv' } }); + expect(mockRegistry.getAllServerConfigs).toHaveBeenCalledWith('u1', {}); + }); + + it('propagates getAllServerConfigs failures', async () => { + getAppConfig.mockResolvedValue({ mcpConfig: {} }); + mockRegistry.ensureConfigServers.mockResolvedValue({}); + mockRegistry.getAllServerConfigs.mockRejectedValue(new Error('redis down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('redis down'); + }); + + it('propagates getAppConfig failures', async () => { + getAppConfig.mockRejectedValue(new Error('mongo down')); + + await expect(resolveAllMcpConfigs('u1', { id: 'u1' })).rejects.toThrow('mongo down'); + }); +}); diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index a468a88eb3..740bb06e5a 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -2,6 +2,7 @@ const { Tools, Constants, EModelEndpoint, + isActionTool, actionDelimiter, AgentCapabilities, defaultAgentCapabilities, @@ -64,6 +65,9 @@ jest.mock('~/models', () => ({ jest.mock('~/config', () => ({ getFlowStateManager: jest.fn(() => ({})), })); +jest.mock('~/server/services/MCP', () => ({ + resolveConfigServers: jest.fn().mockResolvedValue({}), +})); jest.mock('~/cache', () => ({ getLogStores: jest.fn(() => ({})), })); @@ -140,6 +144,42 @@ describe('ToolService - Action Capability Gating', () => { }); }); + describe('isActionTool — cross-delimiter collision guard', () => { + it('should identify real action tools', () => { + expect(isActionTool(`get_weather${actionDelimiter}api_example_com`)).toBe(true); + expect(isActionTool(`fetch_data${actionDelimiter}my---domain---com`)).toBe(true); + }); + + it('should identify action tools whose operationId contains _mcp_', () => { + expect(isActionTool(`sync_mcp_state${actionDelimiter}api---example---com`)).toBe(true); + expect(isActionTool(`get_mcp_config${actionDelimiter}internal---api---com`)).toBe(true); + }); + + it('should reject MCP tools whose name ends with _action', () => { + expect(isActionTool(`get_action${Constants.mcp_delimiter}myserver`)).toBe(false); + expect(isActionTool(`fetch_action${Constants.mcp_delimiter}server_name`)).toBe(false); + expect(isActionTool(`retrieve_action${Constants.mcp_delimiter}srv`)).toBe(false); + }); + + it('should reject MCP tools with _action_ in the middle of their name', () => { + expect(isActionTool(`get_action_data${Constants.mcp_delimiter}myserver`)).toBe(false); + expect(isActionTool(`create_action_item${Constants.mcp_delimiter}server`)).toBe(false); + }); + + it('should reject tools without the action delimiter', () => { + expect(isActionTool('calculator')).toBe(false); + expect(isActionTool(`web_search${Constants.mcp_delimiter}myserver`)).toBe(false); + }); + + it('known limitation: non-RFC domain with _mcp_ substring yields false negative', () => { + // RFC 952/1123 prohibit underscores in hostnames, so this is not expected in practice. + // Encoded domain `api_mcp_internal_com` places `_mcp_` after `_action_`, which + // the guard interprets as the MCP suffix. + const edgeCaseTool = `getData${actionDelimiter}api_mcp_internal_com`; + expect(isActionTool(edgeCaseTool)).toBe(false); + }); + }); + describe('loadAgentTools (definitionsOnly=true) — action tool filtering', () => { const actionToolName = `get_weather${actionDelimiter}api_example_com`; const regularTool = 'calculator'; @@ -180,6 +220,25 @@ describe('ToolService - Action Capability Gating', () => { expect(callArgs.tools).toContain(actionToolName); }); + it('should not filter MCP tools whose name contains _action (cross-delimiter collision)', async () => { + const mcpToolWithAction = `get_action${Constants.mcp_delimiter}myserver`; + const capabilities = [AgentCapabilities.tools]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, mcpToolWithAction] }, + definitionsOnly: true, + }); + + expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1); + const [callArgs] = mockLoadToolDefinitions.mock.calls[0]; + expect(callArgs.tools).toContain(mcpToolWithAction); + expect(callArgs.tools).toContain(regularTool); + }); + it('should return actionsEnabled in the result', async () => { const capabilities = [AgentCapabilities.tools]; const req = createMockReq(capabilities); diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index c7f27acd0e..5728730131 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -7,7 +7,7 @@ const { createMCPServersRegistry, createMCPManager } = require('~/config'); * Initialize MCP servers */ async function initializeMCPs() { - const appConfig = await getAppConfig(); + const appConfig = await getAppConfig({ baseOnly: true }); const mcpServers = appConfig.mcpConfig; try { diff --git a/api/server/socialLogins.js b/api/server/socialLogins.js index 79ede61778..78f0e82a32 100644 --- a/api/server/socialLogins.js +++ b/api/server/socialLogins.js @@ -6,11 +6,16 @@ const { logger, DEFAULT_SESSION_EXPIRY } = require('@librechat/data-schemas'); const { openIdJwtLogin, facebookLogin, + facebookAdminLogin, discordLogin, + discordAdminLogin, setupOpenId, googleLogin, + googleAdminLogin, githubLogin, + githubAdminLogin, appleLogin, + appleAdminLogin, setupSaml, } = require('~/strategies'); const { getLogStores } = require('~/cache'); @@ -58,18 +63,23 @@ const configureSocialLogins = async (app) => { if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) { passport.use(googleLogin()); + passport.use('googleAdmin', googleAdminLogin()); } if (process.env.FACEBOOK_CLIENT_ID && process.env.FACEBOOK_CLIENT_SECRET) { passport.use(facebookLogin()); + passport.use('facebookAdmin', facebookAdminLogin()); } if (process.env.GITHUB_CLIENT_ID && process.env.GITHUB_CLIENT_SECRET) { passport.use(githubLogin()); + passport.use('githubAdmin', githubAdminLogin()); } if (process.env.DISCORD_CLIENT_ID && process.env.DISCORD_CLIENT_SECRET) { passport.use(discordLogin()); + passport.use('discordAdmin', discordAdminLogin()); } if (process.env.APPLE_CLIENT_ID && process.env.APPLE_PRIVATE_KEY_PATH) { passport.use(appleLogin()); + passport.use('appleAdmin', appleAdminLogin()); } if ( process.env.OPENID_CLIENT_ID && diff --git a/api/server/utils/import/importConversations.js b/api/server/utils/import/importConversations.js index e56176c609..ad2d743f01 100644 --- a/api/server/utils/import/importConversations.js +++ b/api/server/utils/import/importConversations.js @@ -7,10 +7,10 @@ const maxFileSize = resolveImportMaxFileSize(); /** * Job definition for importing a conversation. - * @param {{ filepath, requestUserId }} job - The job object. + * @param {{ filepath: string, requestUserId: string, userRole?: string }} job */ const importConversations = async (job) => { - const { filepath, requestUserId } = job; + const { filepath, requestUserId, userRole } = job; try { logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`); @@ -24,7 +24,7 @@ const importConversations = async (job) => { const fileData = await fs.readFile(filepath, 'utf8'); const jsonData = JSON.parse(fileData); const importer = getImporter(jsonData); - await importer(jsonData, requestUserId); + await importer(jsonData, requestUserId, undefined, userRole); logger.debug(`user: ${requestUserId} | Finished importing conversations`); } catch (error) { logger.error(`user: ${requestUserId} | Failed to import conversation: `, error); diff --git a/api/server/utils/import/importers-timestamp.spec.js b/api/server/utils/import/importers-timestamp.spec.js index 09021a9ccd..e12c099abb 100644 --- a/api/server/utils/import/importers-timestamp.spec.js +++ b/api/server/utils/import/importers-timestamp.spec.js @@ -8,17 +8,16 @@ jest.mock('~/models', () => ({ bulkSaveConvos: jest.fn(), bulkSaveMessages: jest.fn(), })); -jest.mock('~/cache/getLogStores'); -const getLogStores = require('~/cache/getLogStores'); -const mockedCacheGet = jest.fn(); -getLogStores.mockImplementation(() => ({ - get: mockedCacheGet, + +const mockGetEndpointsConfig = jest.fn().mockResolvedValue(null); +jest.mock('~/server/services/Config', () => ({ + getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), })); describe('Import Timestamp Ordering', () => { beforeEach(() => { jest.clearAllMocks(); - mockedCacheGet.mockResolvedValue(null); + mockGetEndpointsConfig.mockResolvedValue(null); }); describe('LibreChat Import - Timestamp Issues', () => { diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index 39734c181c..7bcca41e04 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -1,9 +1,9 @@ const { v4: uuidv4 } = require('uuid'); -const { logger } = require('@librechat/data-schemas'); -const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider'); +const { logger, getTenantId } = require('@librechat/data-schemas'); +const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider'); +const { getEndpointsConfig } = require('~/server/services/Config'); const { createImportBatchBuilder } = require('./importBatchBuilder'); const { cloneMessagesWithTimestamps } = require('./fork'); -const getLogStores = require('~/cache/getLogStores'); /** * Returns the appropriate importer function based on the provided JSON data. @@ -194,6 +194,7 @@ async function importLibreChatConvo( jsonData, requestUserId, builderFactory = createImportBatchBuilder, + userRole, ) { try { /** @type {ImportBatchBuilder} */ @@ -202,8 +203,9 @@ async function importLibreChatConvo( /* Endpoint configuration */ let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI; - const cache = getLogStores(CacheKeys.CONFIG_STORE); - const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG); + const endpointsConfig = await getEndpointsConfig({ + user: { id: requestUserId, role: userRole, tenantId: getTenantId() }, + }); const endpointConfig = endpointsConfig?.[endpoint]; if (!endpointConfig && endpointsConfig) { endpoint = Object.keys(endpointsConfig)[0]; diff --git a/api/server/utils/import/importers.spec.js b/api/server/utils/import/importers.spec.js index 7984144cbc..6e712881fc 100644 --- a/api/server/utils/import/importers.spec.js +++ b/api/server/utils/import/importers.spec.js @@ -4,12 +4,13 @@ const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-pr const { getImporter, processAssistantMessage } = require('./importers'); const { ImportBatchBuilder } = require('./importBatchBuilder'); const { bulkSaveMessages, bulkSaveConvos: _bulkSaveConvos } = require('~/models'); -const getLogStores = require('~/cache/getLogStores'); -jest.mock('~/cache/getLogStores'); -const mockedCacheGet = jest.fn(); -getLogStores.mockImplementation(() => ({ - get: mockedCacheGet, +const mockGetEndpointsConfig = jest.fn().mockResolvedValue({ + [EModelEndpoint.openAI]: { userProvide: false }, +}); + +jest.mock('~/server/services/Config', () => ({ + getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), })); // Mock the database methods @@ -758,7 +759,7 @@ describe('importLibreChatConvo', () => { ); it('should import conversation correctly', async () => { - mockedCacheGet.mockResolvedValue({ + mockGetEndpointsConfig.mockResolvedValue({ [EModelEndpoint.openAI]: {}, }); const expectedNumberOfMessages = 6; @@ -784,7 +785,7 @@ describe('importLibreChatConvo', () => { }); it('should import linear, non-recursive thread correctly with correct endpoint', async () => { - mockedCacheGet.mockResolvedValue({ + mockGetEndpointsConfig.mockResolvedValue({ [EModelEndpoint.azureOpenAI]: {}, }); @@ -924,7 +925,7 @@ describe('importLibreChatConvo', () => { }); it('should retain properties from the original conversation as well as new settings', async () => { - mockedCacheGet.mockResolvedValue({ + mockGetEndpointsConfig.mockResolvedValue({ [EModelEndpoint.azureOpenAI]: {}, }); const requestUserId = 'user-123'; diff --git a/api/strategies/appleStrategy.js b/api/strategies/appleStrategy.js index fbba2a1f41..6eace87bae 100644 --- a/api/strategies/appleStrategy.js +++ b/api/strategies/appleStrategy.js @@ -34,16 +34,28 @@ const getProfileDetails = ({ idToken, profile }) => { // Initialize the social login handler for Apple const appleLogin = socialLogin('apple', getProfileDetails); +const appleAdminLogin = socialLogin('apple', getProfileDetails, { existingUsersOnly: true }); -module.exports = () => +const getAppleConfig = (callbackURL) => ({ + clientID: process.env.APPLE_CLIENT_ID, + teamID: process.env.APPLE_TEAM_ID, + callbackURL, + keyID: process.env.APPLE_KEY_ID, + privateKeyLocation: process.env.APPLE_PRIVATE_KEY_PATH, + passReqToCallback: false, +}); + +const appleStrategy = () => new AppleStrategy( - { - clientID: process.env.APPLE_CLIENT_ID, - teamID: process.env.APPLE_TEAM_ID, - callbackURL: `${process.env.DOMAIN_SERVER}${process.env.APPLE_CALLBACK_URL}`, - keyID: process.env.APPLE_KEY_ID, - privateKeyLocation: process.env.APPLE_PRIVATE_KEY_PATH, - passReqToCallback: false, // Set to true if you need to access the request in the callback - }, + getAppleConfig(`${process.env.DOMAIN_SERVER}${process.env.APPLE_CALLBACK_URL}`), appleLogin, ); + +const appleAdminStrategy = () => + new AppleStrategy( + getAppleConfig(`${process.env.DOMAIN_SERVER}/api/admin/oauth/apple/callback`), + appleAdminLogin, + ); + +module.exports = appleStrategy; +module.exports.appleAdminLogin = appleAdminStrategy; diff --git a/api/strategies/discordStrategy.js b/api/strategies/discordStrategy.js index dc7cb05ac6..7fb68280d5 100644 --- a/api/strategies/discordStrategy.js +++ b/api/strategies/discordStrategy.js @@ -22,15 +22,27 @@ const getProfileDetails = ({ profile }) => { }; const discordLogin = socialLogin('discord', getProfileDetails); +const discordAdminLogin = socialLogin('discord', getProfileDetails, { existingUsersOnly: true }); -module.exports = () => +const getDiscordConfig = (callbackURL) => ({ + clientID: process.env.DISCORD_CLIENT_ID, + clientSecret: process.env.DISCORD_CLIENT_SECRET, + callbackURL, + scope: ['identify', 'email'], + authorizationURL: 'https://discord.com/api/oauth2/authorize?prompt=none', +}); + +const discordStrategy = () => new DiscordStrategy( - { - clientID: process.env.DISCORD_CLIENT_ID, - clientSecret: process.env.DISCORD_CLIENT_SECRET, - callbackURL: `${process.env.DOMAIN_SERVER}${process.env.DISCORD_CALLBACK_URL}`, - scope: ['identify', 'email'], - authorizationURL: 'https://discord.com/api/oauth2/authorize?prompt=none', - }, + getDiscordConfig(`${process.env.DOMAIN_SERVER}${process.env.DISCORD_CALLBACK_URL}`), discordLogin, ); + +const discordAdminStrategy = () => + new DiscordStrategy( + getDiscordConfig(`${process.env.DOMAIN_SERVER}/api/admin/oauth/discord/callback`), + discordAdminLogin, + ); + +module.exports = discordStrategy; +module.exports.discordAdminLogin = discordAdminStrategy; diff --git a/api/strategies/facebookStrategy.js b/api/strategies/facebookStrategy.js index e5d1b054db..f638c3bfdb 100644 --- a/api/strategies/facebookStrategy.js +++ b/api/strategies/facebookStrategy.js @@ -11,16 +11,28 @@ const getProfileDetails = ({ profile }) => ({ }); const facebookLogin = socialLogin('facebook', getProfileDetails); +const facebookAdminLogin = socialLogin('facebook', getProfileDetails, { existingUsersOnly: true }); -module.exports = () => +const getFacebookConfig = (callbackURL) => ({ + clientID: process.env.FACEBOOK_CLIENT_ID, + clientSecret: process.env.FACEBOOK_CLIENT_SECRET, + callbackURL, + proxy: true, + scope: ['public_profile'], + profileFields: ['id', 'email', 'name'], +}); + +const facebookStrategy = () => new FacebookStrategy( - { - clientID: process.env.FACEBOOK_CLIENT_ID, - clientSecret: process.env.FACEBOOK_CLIENT_SECRET, - callbackURL: `${process.env.DOMAIN_SERVER}${process.env.FACEBOOK_CALLBACK_URL}`, - proxy: true, - scope: ['public_profile'], - profileFields: ['id', 'email', 'name'], - }, + getFacebookConfig(`${process.env.DOMAIN_SERVER}${process.env.FACEBOOK_CALLBACK_URL}`), facebookLogin, ); + +const facebookAdminStrategy = () => + new FacebookStrategy( + getFacebookConfig(`${process.env.DOMAIN_SERVER}/api/admin/oauth/facebook/callback`), + facebookAdminLogin, + ); + +module.exports = facebookStrategy; +module.exports.facebookAdminLogin = facebookAdminStrategy; diff --git a/api/strategies/githubStrategy.js b/api/strategies/githubStrategy.js index 1c3937381e..363acbfcdb 100644 --- a/api/strategies/githubStrategy.js +++ b/api/strategies/githubStrategy.js @@ -11,24 +11,36 @@ const getProfileDetails = ({ profile }) => ({ }); const githubLogin = socialLogin('github', getProfileDetails); +const githubAdminLogin = socialLogin('github', getProfileDetails, { existingUsersOnly: true }); -module.exports = () => +const getGitHubConfig = (callbackURL) => ({ + clientID: process.env.GITHUB_CLIENT_ID, + clientSecret: process.env.GITHUB_CLIENT_SECRET, + callbackURL, + proxy: false, + scope: ['user:email'], + ...(process.env.GITHUB_ENTERPRISE_BASE_URL && { + authorizationURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/login/oauth/authorize`, + tokenURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/login/oauth/access_token`, + userProfileURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/api/v3/user`, + userEmailURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/api/v3/user/emails`, + ...(process.env.GITHUB_ENTERPRISE_USER_AGENT && { + userAgent: process.env.GITHUB_ENTERPRISE_USER_AGENT, + }), + }), +}); + +const githubStrategy = () => new GitHubStrategy( - { - clientID: process.env.GITHUB_CLIENT_ID, - clientSecret: process.env.GITHUB_CLIENT_SECRET, - callbackURL: `${process.env.DOMAIN_SERVER}${process.env.GITHUB_CALLBACK_URL}`, - proxy: false, - scope: ['user:email'], - ...(process.env.GITHUB_ENTERPRISE_BASE_URL && { - authorizationURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/login/oauth/authorize`, - tokenURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/login/oauth/access_token`, - userProfileURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/api/v3/user`, - userEmailURL: `${process.env.GITHUB_ENTERPRISE_BASE_URL}/api/v3/user/emails`, - ...(process.env.GITHUB_ENTERPRISE_USER_AGENT && { - userAgent: process.env.GITHUB_ENTERPRISE_USER_AGENT, - }), - }), - }, + getGitHubConfig(`${process.env.DOMAIN_SERVER}${process.env.GITHUB_CALLBACK_URL}`), githubLogin, ); + +const githubAdminStrategy = () => + new GitHubStrategy( + getGitHubConfig(`${process.env.DOMAIN_SERVER}/api/admin/oauth/github/callback`), + githubAdminLogin, + ); + +module.exports = githubStrategy; +module.exports.githubAdminLogin = githubAdminStrategy; diff --git a/api/strategies/googleStrategy.js b/api/strategies/googleStrategy.js index fd65823327..bee9a061a2 100644 --- a/api/strategies/googleStrategy.js +++ b/api/strategies/googleStrategy.js @@ -11,14 +11,26 @@ const getProfileDetails = ({ profile }) => ({ }); const googleLogin = socialLogin('google', getProfileDetails); +const googleAdminLogin = socialLogin('google', getProfileDetails, { existingUsersOnly: true }); -module.exports = () => +const getGoogleConfig = (callbackURL) => ({ + clientID: process.env.GOOGLE_CLIENT_ID, + clientSecret: process.env.GOOGLE_CLIENT_SECRET, + callbackURL, + proxy: true, +}); + +const googleStrategy = () => new GoogleStrategy( - { - clientID: process.env.GOOGLE_CLIENT_ID, - clientSecret: process.env.GOOGLE_CLIENT_SECRET, - callbackURL: `${process.env.DOMAIN_SERVER}${process.env.GOOGLE_CALLBACK_URL}`, - proxy: true, - }, + getGoogleConfig(`${process.env.DOMAIN_SERVER}${process.env.GOOGLE_CALLBACK_URL}`), googleLogin, ); + +const googleAdminStrategy = () => + new GoogleStrategy( + getGoogleConfig(`${process.env.DOMAIN_SERVER}/api/admin/oauth/google/callback`), + googleAdminLogin, + ); + +module.exports = googleStrategy; +module.exports.googleAdminLogin = googleAdminStrategy; diff --git a/api/strategies/index.js b/api/strategies/index.js index 9a1c58ad38..c15bbc4ce5 100644 --- a/api/strategies/index.js +++ b/api/strategies/index.js @@ -1,23 +1,33 @@ const { setupOpenId, getOpenIdConfig, getOpenIdEmail } = require('./openidStrategy'); const openIdJwtLogin = require('./openIdJwtStrategy'); const facebookLogin = require('./facebookStrategy'); +const { facebookAdminLogin } = facebookLogin; const discordLogin = require('./discordStrategy'); +const { discordAdminLogin } = discordLogin; const passportLogin = require('./localStrategy'); const googleLogin = require('./googleStrategy'); +const { googleAdminLogin } = googleLogin; const githubLogin = require('./githubStrategy'); +const { githubAdminLogin } = githubLogin; const { setupSaml } = require('./samlStrategy'); const appleLogin = require('./appleStrategy'); +const { appleAdminLogin } = appleLogin; const ldapLogin = require('./ldapStrategy'); const jwtLogin = require('./jwtStrategy'); module.exports = { appleLogin, + appleAdminLogin, passportLogin, googleLogin, + googleAdminLogin, githubLogin, + githubAdminLogin, discordLogin, + discordAdminLogin, jwtLogin, facebookLogin, + facebookAdminLogin, setupOpenId, getOpenIdConfig, getOpenIdEmail, diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index dcadc26a45..9253f54196 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -2,7 +2,12 @@ const fs = require('fs'); const LdapStrategy = require('passport-ldapauth'); const { logger } = require('@librechat/data-schemas'); const { SystemRoles, ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api'); +const { + isEnabled, + getBalanceConfig, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); const { createUser, findUser, updateUser, countUsers } = require('~/models'); const { getAppConfig } = require('~/server/services/Config'); @@ -89,16 +94,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { const ldapId = (LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail; - let user = await findUser({ ldapId }); - if (user && user.provider !== 'ldap') { - logger.info( - `[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`, - ); - return done(null, false, { - message: ErrorTypes.AUTH_FAILED, - }); - } - const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(','); const fullName = fullNameAttributes && fullNameAttributes.length > 0 @@ -122,7 +117,31 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { ); } - const appConfig = await getAppConfig(); + // Domain check before findUser for two-phase fast-fail (consistent with SAML/OpenID/social). + // This means cross-provider users from blocked domains get 'Email domain not allowed' + // instead of AUTH_FAILED — both deny access. + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(mail, baseConfig?.registration?.allowedDomains)) { + logger.error( + `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + + let user = await findUser({ ldapId }); + if (user && user.provider !== 'ldap') { + logger.info( + `[ldapStrategy] User ${user.email} already exists with provider ${user.provider}`, + ); + return done(null, false, { + message: ErrorTypes.AUTH_FAILED, + }); + } + + const appConfig = user?.tenantId + ? await resolveAppConfigForUser(getAppConfig, user) + : baseConfig; + if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) { logger.error( `[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`, diff --git a/api/strategies/ldapStrategy.spec.js b/api/strategies/ldapStrategy.spec.js index a00e9b14b7..876d70f845 100644 --- a/api/strategies/ldapStrategy.spec.js +++ b/api/strategies/ldapStrategy.spec.js @@ -9,10 +9,10 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('@librechat/api', () => ({ - // isEnabled used for TLS flags isEnabled: jest.fn(() => false), isEmailDomainAllowed: jest.fn(() => true), getBalanceConfig: jest.fn(() => ({ enabled: false })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ @@ -30,14 +30,15 @@ jest.mock('~/server/services/Config', () => ({ let verifyCallback; jest.mock('passport-ldapauth', () => { return jest.fn().mockImplementation((options, verify) => { - verifyCallback = verify; // capture the strategy verify function + verifyCallback = verify; return { name: 'ldap', options, verify }; }); }); const { ErrorTypes } = require('librechat-data-provider'); -const { isEmailDomainAllowed } = require('@librechat/api'); +const { isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api'); const { findUser, createUser, updateUser, countUsers } = require('~/models'); +const { getAppConfig } = require('~/server/services/Config'); // Helper to call the verify callback and wrap in a Promise for convenience const callVerify = (userinfo) => @@ -117,6 +118,7 @@ describe('ldapStrategy', () => { expect(user).toBe(false); expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED }); expect(createUser).not.toHaveBeenCalled(); + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); }); it('updates an existing ldap user with current LDAP info', async () => { @@ -158,7 +160,6 @@ describe('ldapStrategy', () => { uid: 'uid999', givenName: 'John', cn: 'John Doe', - // no mail and no custom LDAP_EMAIL }; const { user } = await callVerify(userinfo); @@ -180,4 +181,66 @@ describe('ldapStrategy', () => { expect(user).toBe(false); expect(info).toEqual({ message: 'Email domain not allowed' }); }); + + it('passes getAppConfig and found user to resolveAppConfigForUser', async () => { + const existing = { + _id: 'u3', + provider: 'ldap', + email: 'tenant@example.com', + ldapId: 'uid-tenant', + username: 'tenantuser', + name: 'Tenant User', + tenantId: 'tenant-a', + role: 'USER', + }; + findUser.mockResolvedValue(existing); + + const userinfo = { + uid: 'uid-tenant', + mail: 'tenant@example.com', + givenName: 'Tenant', + cn: 'Tenant User', + }; + + await callVerify(userinfo); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existing); + }); + + it('uses baseConfig for new user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue(null); + + const userinfo = { + uid: 'uid-new', + mail: 'newuser@example.com', + givenName: 'New', + cn: 'New User', + }; + + await callVerify(userinfo); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const existing = { + _id: 'u-blocked', + provider: 'ldap', + ldapId: 'uid-tenant', + tenantId: 'tenant-strict', + role: 'USER', + }; + findUser.mockResolvedValue(existing); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const userinfo = { uid: 'uid-tenant', mail: 'user@example.com', givenName: 'Test', cn: 'Test' }; + const { user, info } = await callVerify(userinfo); + + expect(user).toBe(false); + expect(info).toEqual({ message: 'Email domain not allowed' }); + }); }); diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index b8f9b17b7c..b2d942618f 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -15,6 +15,7 @@ const { findOpenIDUser, getBalanceConfig, isEmailDomainAllowed, + resolveAppConfigForUser, } = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); @@ -468,9 +469,10 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { Object.assign(userinfo, providerUserinfo); } - const appConfig = await getAppConfig(); const email = getOpenIdEmail(userinfo); - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { logger.error( `[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`, ); @@ -491,6 +493,15 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { throw new Error(ErrorTypes.AUTH_FAILED); } + const appConfig = user?.tenantId ? await resolveAppConfigForUser(getAppConfig, user) : baseConfig; + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[OpenID Strategy] Authentication blocked - email domain not allowed [Identifier: ${email}]`, + ); + throw new Error('Email domain not allowed'); + } + const fullName = getFullName(userinfo); const requiredRole = process.env.OPENID_REQUIRED_ROLE; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 7edd4fbd62..af8a430697 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -3,6 +3,8 @@ const fetch = require('node-fetch'); const jwtDecode = require('jsonwebtoken/decode'); const { ErrorTypes } = require('librechat-data-provider'); const { findUser, createUser, updateUser } = require('~/models'); +const { resolveAppConfigForUser, isEnabled } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); const { setupOpenId } = require('./openidStrategy'); // --- Mocks --- @@ -28,6 +30,7 @@ jest.mock('@librechat/api', () => ({ getBalanceConfig: jest.fn(() => ({ enabled: false, })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/models', () => ({ findUser: jest.fn(), @@ -140,7 +143,6 @@ describe('setupOpenId', () => { beforeEach(async () => { // Clear previous mock calls and reset implementations jest.clearAllMocks(); - const { isEnabled } = require('@librechat/api'); isEnabled.mockImplementation(jest.requireActual('@librechat/api').isEnabled); // Reset environment variables needed by the strategy @@ -194,22 +196,26 @@ describe('setupOpenId', () => { }); describe('clientMetadata construction in setupOpenId', () => { - it('sets token_endpoint_auth_method to none for PKCE without a client secret', async () => { - const openidClient = require('openid-client'); + let openidClient; + + beforeEach(() => { + openidClient = require('openid-client'); openidClient.discovery.mockClear(); + }); + + it('sets token_endpoint_auth_method to none for PKCE without a client secret', async () => { process.env.OPENID_USE_PKCE = 'true'; delete process.env.OPENID_CLIENT_SECRET; await setupOpenId(); const [, , metadata] = openidClient.discovery.mock.calls.at(-1); - expect(metadata.client_secret).toBeUndefined(); expect(metadata.token_endpoint_auth_method).toBe('none'); + expect(metadata.client_secret).toBeUndefined(); }); it('leaves token_endpoint_auth_method unset for secret-based clients without nonce', async () => { - const openidClient = require('openid-client'); - openidClient.discovery.mockClear(); + process.env.OPENID_USE_PKCE = 'false'; process.env.OPENID_CLIENT_SECRET = 'my-secret'; await setupOpenId(); @@ -219,9 +225,8 @@ describe('setupOpenId', () => { expect(metadata.token_endpoint_auth_method).toBeUndefined(); }); - it('sets client_secret_post when nonce generation is enabled for secret-based clients', async () => { - const openidClient = require('openid-client'); - openidClient.discovery.mockClear(); + it('sets client_secret and client_secret_post when nonce generation is enabled', async () => { + process.env.OPENID_USE_PKCE = 'false'; process.env.OPENID_GENERATE_NONCE = 'true'; process.env.OPENID_CLIENT_SECRET = 'my-secret'; @@ -232,9 +237,7 @@ describe('setupOpenId', () => { expect(metadata.token_endpoint_auth_method).toBe('client_secret_post'); }); - it('treats a whitespace-only client secret as absent', async () => { - const openidClient = require('openid-client'); - openidClient.discovery.mockClear(); + it('treats whitespace-only secret as absent', async () => { process.env.OPENID_USE_PKCE = 'true'; process.env.OPENID_CLIENT_SECRET = ' '; @@ -245,9 +248,7 @@ describe('setupOpenId', () => { expect(metadata.token_endpoint_auth_method).toBe('none'); }); - it('does not force a public-client auth method when PKCE and a client secret are both configured', async () => { - const openidClient = require('openid-client'); - openidClient.discovery.mockClear(); + it('does not force an auth method when PKCE and a client secret are both configured without nonce', async () => { process.env.OPENID_USE_PKCE = 'true'; process.env.OPENID_CLIENT_SECRET = 'my-secret'; @@ -1888,4 +1889,52 @@ describe('setupOpenId', () => { ); }); }); + + describe('Tenant-scoped config', () => { + it('should call resolveAppConfigForUser for tenant user', async () => { + const existingUser = { + _id: 'openid-tenant-user', + provider: 'openid', + openidId: '1234', + email: 'test@example.com', + tenantId: 'tenant-d', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + + await validate(tokenset); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for new user without calling resolveAppConfigForUser', async () => { + findUser.mockResolvedValue(null); + + await validate(tokenset); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const existingUser = { + _id: 'openid-tenant-blocked', + provider: 'openid', + openidId: '1234', + email: 'test@example.com', + tenantId: 'tenant-restrict', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details).toEqual({ message: 'Email domain not allowed' }); + }); + }); }); diff --git a/api/strategies/samlStrategy.js b/api/strategies/samlStrategy.js index 843baf8a64..4f4bfac158 100644 --- a/api/strategies/samlStrategy.js +++ b/api/strategies/samlStrategy.js @@ -5,7 +5,11 @@ const passport = require('passport'); const { ErrorTypes } = require('librechat-data-provider'); const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); -const { getBalanceConfig, isEmailDomainAllowed } = require('@librechat/api'); +const { + getBalanceConfig, + isEmailDomainAllowed, + resolveAppConfigForUser, +} = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); const { getAppConfig } = require('~/server/services/Config'); @@ -174,126 +178,179 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Creates a SAML authentication callback. + * @param {boolean} [existingUsersOnly=false] - If true, only existing users will be authenticated. + * @returns {Function} The SAML callback function for passport. + */ +function createSamlCallback(existingUsersOnly = false) { + return async (profile, done) => { + try { + logger.info(`[samlStrategy] SAML authentication received for NameID: ${profile.nameID}`); + logger.debug('[samlStrategy] SAML profile:', profile); + + const userEmail = getEmail(profile) || ''; + + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(userEmail, baseConfig?.registration?.allowedDomains)) { + logger.error( + `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + + let user = await findUser({ samlId: profile.nameID }); + logger.info( + `[samlStrategy] User ${user ? 'found' : 'not found'} with SAML ID: ${profile.nameID}`, + ); + + if (!user) { + user = await findUser({ email: userEmail }); + logger.info(`[samlStrategy] User ${user ? 'found' : 'not found'} with email: ${userEmail}`); + } + + if (user && user.provider !== 'saml') { + logger.info( + `[samlStrategy] User ${user.email} already exists with provider ${user.provider}`, + ); + return done(null, false, { + message: ErrorTypes.AUTH_FAILED, + }); + } + + const appConfig = user?.tenantId + ? await resolveAppConfigForUser(getAppConfig, user) + : baseConfig; + + if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { + logger.error( + `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, + ); + return done(null, false, { message: 'Email domain not allowed' }); + } + + const fullName = getFullName(profile); + + const username = convertToUsername( + getUserName(profile) || getGivenName(profile) || getEmail(profile), + ); + + if (!user) { + if (existingUsersOnly) { + logger.error( + `[samlStrategy] Admin auth blocked - user does not exist [Email: ${userEmail}]`, + ); + return done(null, false, { message: 'User does not exist' }); + } + + user = { + provider: 'saml', + samlId: profile.nameID, + username, + email: userEmail, + emailVerified: true, + name: fullName, + }; + const balanceConfig = getBalanceConfig(appConfig); + user = await createUser(user, balanceConfig, true, true); + } else { + user.provider = 'saml'; + user.samlId = profile.nameID; + user.username = username; + user.name = fullName; + } + + const picture = getPicture(profile); + if (picture && !user.avatar?.includes('manual=true')) { + const imageBuffer = await downloadImage(profile.picture); + if (imageBuffer) { + let fileName; + if (crypto) { + fileName = (await hashToken(profile.nameID)) + '.png'; + } else { + fileName = profile.nameID + '.png'; + } + + const { saveBuffer } = getStrategyFunctions( + appConfig?.fileStrategy ?? process.env.CDN_PROVIDER, + ); + const imagePath = await saveBuffer({ + fileName, + userId: user._id.toString(), + buffer: imageBuffer, + }); + user.avatar = imagePath ?? ''; + } + } + + user = await updateUser(user._id, user); + + logger.info( + `[samlStrategy] Login success SAML ID: ${user.samlId} | email: ${user.email} | username: ${user.username}`, + { + user: { + samlId: user.samlId, + username: user.username, + email: user.email, + name: user.name, + }, + }, + ); + + done(null, user); + } catch (err) { + logger.error('[samlStrategy] Login failed', err); + done(err); + } + }; +} + +/** + * Returns the base SAML configuration shared by both regular and admin strategies. + * @returns {object} The SAML configuration object. + */ +function getBaseSamlConfig() { + return { + entryPoint: process.env.SAML_ENTRY_POINT, + issuer: process.env.SAML_ISSUER, + idpCert: getCertificateContent(process.env.SAML_CERT), + wantAssertionsSigned: process.env.SAML_USE_AUTHN_RESPONSE_SIGNED === 'true' ? false : true, + wantAuthnResponseSigned: process.env.SAML_USE_AUTHN_RESPONSE_SIGNED === 'true' ? true : false, + }; +} + async function setupSaml() { try { + const baseConfig = getBaseSamlConfig(); const samlConfig = { - entryPoint: process.env.SAML_ENTRY_POINT, - issuer: process.env.SAML_ISSUER, + ...baseConfig, callbackUrl: process.env.SAML_CALLBACK_URL, - idpCert: getCertificateContent(process.env.SAML_CERT), - wantAssertionsSigned: process.env.SAML_USE_AUTHN_RESPONSE_SIGNED === 'true' ? false : true, - wantAuthnResponseSigned: process.env.SAML_USE_AUTHN_RESPONSE_SIGNED === 'true' ? true : false, }; - passport.use( - 'saml', - new SamlStrategy(samlConfig, async (profile, done) => { - try { - logger.info(`[samlStrategy] SAML authentication received for NameID: ${profile.nameID}`); - logger.debug('[samlStrategy] SAML profile:', profile); - - const userEmail = getEmail(profile) || ''; - const appConfig = await getAppConfig(); - - if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) { - logger.error( - `[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`, - ); - return done(null, false, { message: 'Email domain not allowed' }); - } - - let user = await findUser({ samlId: profile.nameID }); - logger.info( - `[samlStrategy] User ${user ? 'found' : 'not found'} with SAML ID: ${profile.nameID}`, - ); - - if (!user) { - user = await findUser({ email: userEmail }); - logger.info( - `[samlStrategy] User ${user ? 'found' : 'not found'} with email: ${userEmail}`, - ); - } - - if (user && user.provider !== 'saml') { - logger.info( - `[samlStrategy] User ${user.email} already exists with provider ${user.provider}`, - ); - return done(null, false, { - message: ErrorTypes.AUTH_FAILED, - }); - } - - const fullName = getFullName(profile); - - const username = convertToUsername( - getUserName(profile) || getGivenName(profile) || getEmail(profile), - ); - - if (!user) { - user = { - provider: 'saml', - samlId: profile.nameID, - username, - email: userEmail, - emailVerified: true, - name: fullName, - }; - const balanceConfig = getBalanceConfig(appConfig); - user = await createUser(user, balanceConfig, true, true); - } else { - user.provider = 'saml'; - user.samlId = profile.nameID; - user.username = username; - user.name = fullName; - } - - const picture = getPicture(profile); - if (picture && !user.avatar?.includes('manual=true')) { - const imageBuffer = await downloadImage(profile.picture); - if (imageBuffer) { - let fileName; - if (crypto) { - fileName = (await hashToken(profile.nameID)) + '.png'; - } else { - fileName = profile.nameID + '.png'; - } - - const { saveBuffer } = getStrategyFunctions( - appConfig?.fileStrategy ?? process.env.CDN_PROVIDER, - ); - const imagePath = await saveBuffer({ - fileName, - userId: user._id.toString(), - buffer: imageBuffer, - }); - user.avatar = imagePath ?? ''; - } - } - - user = await updateUser(user._id, user); - - logger.info( - `[samlStrategy] Login success SAML ID: ${user.samlId} | email: ${user.email} | username: ${user.username}`, - { - user: { - samlId: user.samlId, - username: user.username, - email: user.email, - name: user.name, - }, - }, - ); - - done(null, user); - } catch (err) { - logger.error('[samlStrategy] Login failed', err); - done(err); - } - }), - ); + passport.use('saml', new SamlStrategy(samlConfig, createSamlCallback(false))); + setupSamlAdmin(baseConfig); } catch (err) { logger.error('[samlStrategy]', err); } } +/** + * Sets up the SAML strategy specifically for admin authentication. + * Rejects users that don't already exist. + * @param {object} [baseConfig] - Pre-parsed base SAML config to avoid redundant cert parsing. + */ +function setupSamlAdmin(baseConfig) { + try { + const samlAdminConfig = { + ...(baseConfig ?? getBaseSamlConfig()), + callbackUrl: `${process.env.DOMAIN_SERVER}/api/admin/oauth/saml/callback`, + }; + + passport.use('samlAdmin', new SamlStrategy(samlAdminConfig, createSamlCallback(true))); + logger.info('[samlStrategy] Admin SAML strategy registered.'); + } catch (err) { + logger.error('[samlStrategy] setupSamlAdmin', err); + } +} + module.exports = { setupSaml, getCertificateContent }; diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index 1d16719b87..965fb157ef 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -30,6 +30,7 @@ jest.mock('@librechat/api', () => ({ tokenCredits: 1000, startBalance: 1000, })), + resolveAppConfigForUser: jest.fn(async (_getAppConfig, _user) => ({})), })); jest.mock('~/server/services/Config/EndpointService', () => ({ config: {}, @@ -47,6 +48,9 @@ const fs = require('fs'); const path = require('path'); const fetch = require('node-fetch'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { findUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); const { setupSaml, getCertificateContent } = require('./samlStrategy'); // Configure fs mock @@ -54,10 +58,14 @@ jest.mocked(fs).existsSync = jest.fn(); jest.mocked(fs).statSync = jest.fn(); jest.mocked(fs).readFileSync = jest.fn(); -// To capture the verify callback from the strategy, we grab it from the mock constructor +// To capture the verify callback from the strategy, we grab it from the mock constructor. +// setupSaml() registers both 'saml' (regular) and 'samlAdmin' strategies, so we capture +// only the first callback per setupSaml() call (the regular one). let verifyCallback; SamlStrategy.mockImplementation((options, verify) => { - verifyCallback = verify; + if (!verifyCallback) { + verifyCallback = verify; + } return { name: 'saml', options, verify }; }); @@ -215,6 +223,8 @@ describe('setupSaml', () => { beforeEach(async () => { jest.clearAllMocks(); + // Reset so the mock captures the regular (non-admin) callback on next setupSaml() call + verifyCallback = null; // Configure mocks const { findUser, createUser, updateUser } = require('~/models'); @@ -440,4 +450,50 @@ u7wlOSk+oFzDIO/UILIA expect(fetch).not.toHaveBeenCalled(); }); + + it('should pass the found user to resolveAppConfigForUser', async () => { + const existingUser = { + _id: 'tenant-user-id', + provider: 'saml', + samlId: 'saml-1234', + email: 'test@example.com', + tenantId: 'tenant-c', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + + const profile = { ...baseProfile }; + await validate(profile); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for new SAML user without calling resolveAppConfigForUser', async () => { + const profile = { ...baseProfile }; + await validate(profile); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const existingUser = { + _id: 'tenant-blocked', + provider: 'saml', + samlId: 'saml-1234', + email: 'test@example.com', + tenantId: 'tenant-restrict', + role: 'USER', + }; + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const profile = { ...baseProfile }; + const { user } = await validate(profile); + expect(user).toBe(false); + }); }); diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index 88fb347042..580e4f3d7e 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -1,21 +1,21 @@ const { logger } = require('@librechat/data-schemas'); const { ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, isEmailDomainAllowed } = require('@librechat/api'); +const { isEnabled, isEmailDomainAllowed, resolveAppConfigForUser } = require('@librechat/api'); const { createSocialUser, handleExistingUser } = require('./process'); const { getAppConfig } = require('~/server/services/Config'); const { findUser } = require('~/models'); const socialLogin = - (provider, getProfileDetails) => async (accessToken, refreshToken, idToken, profile, cb) => { + (provider, getProfileDetails, options = {}) => + async (accessToken, refreshToken, idToken, profile, cb) => { try { const { email, id, avatarUrl, username, name, emailVerified } = getProfileDetails({ idToken, profile, }); - const appConfig = await getAppConfig(); - - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + const baseConfig = await getAppConfig({ baseOnly: true }); + if (!isEmailDomainAllowed(email, baseConfig?.registration?.allowedDomains)) { logger.error( `[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`, ); @@ -41,6 +41,20 @@ const socialLogin = } } + const appConfig = existingUser?.tenantId + ? await resolveAppConfigForUser(getAppConfig, existingUser) + : baseConfig; + + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`, + ); + const error = new Error(ErrorTypes.AUTH_FAILED); + error.code = ErrorTypes.AUTH_FAILED; + error.message = 'Email domain not allowed'; + return cb(error); + } + if (existingUser?.provider === provider) { await handleExistingUser(existingUser, avatarUrl, appConfig, email); return cb(null, existingUser); @@ -54,6 +68,13 @@ const socialLogin = return cb(error); } + if (options.existingUsersOnly) { + logger.error( + `[${provider}Login] Admin auth blocked - user does not exist [Email: ${email}]`, + ); + return cb(null, false, { message: 'User does not exist' }); + } + const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION); if (!ALLOW_SOCIAL_REGISTRATION) { logger.error( diff --git a/api/strategies/socialLogin.test.js b/api/strategies/socialLogin.test.js index ba4778c8b1..4fde397d55 100644 --- a/api/strategies/socialLogin.test.js +++ b/api/strategies/socialLogin.test.js @@ -3,6 +3,8 @@ const { ErrorTypes } = require('librechat-data-provider'); const { createSocialUser, handleExistingUser } = require('./process'); const socialLogin = require('./socialLogin'); const { findUser } = require('~/models'); +const { resolveAppConfigForUser } = require('@librechat/api'); +const { getAppConfig } = require('~/server/services/Config'); jest.mock('@librechat/data-schemas', () => { const actualModule = jest.requireActual('@librechat/data-schemas'); @@ -25,6 +27,10 @@ jest.mock('@librechat/api', () => ({ ...jest.requireActual('@librechat/api'), isEnabled: jest.fn().mockReturnValue(true), isEmailDomainAllowed: jest.fn().mockReturnValue(true), + resolveAppConfigForUser: jest.fn().mockResolvedValue({ + fileStrategy: 'local', + balance: { enabled: false }, + }), })); jest.mock('~/models', () => ({ @@ -66,10 +72,7 @@ describe('socialLogin', () => { googleId: googleId, }; - /** Mock findUser to return user on first call (by googleId), null on second call */ - findUser - .mockResolvedValueOnce(existingUser) // First call: finds by googleId - .mockResolvedValueOnce(null); // Second call would be by email, but won't be reached + findUser.mockResolvedValueOnce(existingUser).mockResolvedValueOnce(null); const mockProfile = { id: googleId, @@ -83,13 +86,9 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify it searched by googleId first */ expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); - - /** Verify it did NOT search by email (because it found user by googleId) */ expect(findUser).toHaveBeenCalledTimes(1); - /** Verify handleExistingUser was called with the new email */ expect(handleExistingUser).toHaveBeenCalledWith( existingUser, 'https://example.com/avatar.png', @@ -97,7 +96,6 @@ describe('socialLogin', () => { newEmail, ); - /** Verify callback was called with success */ expect(callback).toHaveBeenCalledWith(null, existingUser); }); @@ -113,7 +111,7 @@ describe('socialLogin', () => { facebookId: facebookId, }; - findUser.mockResolvedValue(existingUser); // Always returns user + findUser.mockResolvedValue(existingUser); const mockProfile = { id: facebookId, @@ -127,7 +125,6 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify it searched by facebookId first */ expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId }); expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]); @@ -150,13 +147,10 @@ describe('socialLogin', () => { _id: 'user789', email: email, provider: 'google', - googleId: 'old-google-id', // Different googleId (edge case) + googleId: 'old-google-id', }; - /** First call (by googleId) returns null, second call (by email) returns user */ - findUser - .mockResolvedValueOnce(null) // By googleId - .mockResolvedValueOnce(existingUser); // By email + findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser); const mockProfile = { id: googleId, @@ -170,13 +164,10 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify both searches happened */ expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); - /** Email passed as-is; findUser implementation handles case normalization */ expect(findUser).toHaveBeenNthCalledWith(2, { email: email }); expect(findUser).toHaveBeenCalledTimes(2); - /** Verify warning log */ expect(logger.warn).toHaveBeenCalledWith( `[${provider}Login] User found by email: ${email} but not by ${provider}Id`, ); @@ -197,7 +188,6 @@ describe('socialLogin', () => { googleId: googleId, }; - /** Both searches return null */ findUser.mockResolvedValue(null); createSocialUser.mockResolvedValue(newUser); @@ -213,10 +203,8 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify both searches happened */ expect(findUser).toHaveBeenCalledTimes(2); - /** Verify createSocialUser was called */ expect(createSocialUser).toHaveBeenCalledWith({ email: email, avatarUrl: 'https://example.com/avatar.png', @@ -242,12 +230,10 @@ describe('socialLogin', () => { const existingUser = { _id: 'user123', email: email, - provider: 'local', // Different provider + provider: 'local', }; - findUser - .mockResolvedValueOnce(null) // By googleId - .mockResolvedValueOnce(existingUser); // By email + findUser.mockResolvedValueOnce(null).mockResolvedValueOnce(existingUser); const mockProfile = { id: googleId, @@ -261,7 +247,6 @@ describe('socialLogin', () => { await loginFn(null, null, null, mockProfile, callback); - /** Verify error callback */ expect(callback).toHaveBeenCalledWith( expect.objectContaining({ code: ErrorTypes.AUTH_FAILED, @@ -274,4 +259,104 @@ describe('socialLogin', () => { ); }); }); + + describe('Tenant-scoped config', () => { + it('should call resolveAppConfigForUser for tenant user', async () => { + const provider = 'google'; + const googleId = 'google-tenant-user'; + const email = 'tenant@example.com'; + + const existingUser = { + _id: 'userTenant', + email, + provider: 'google', + googleId, + tenantId: 'tenant-b', + role: 'USER', + }; + + findUser.mockResolvedValue(existingUser); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Tenant', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(resolveAppConfigForUser).toHaveBeenCalledWith(getAppConfig, existingUser); + }); + + it('should use baseConfig for non-tenant user without calling resolveAppConfigForUser', async () => { + const provider = 'google'; + const googleId = 'google-new-tenant'; + const email = 'new@example.com'; + + findUser.mockResolvedValue(null); + createSocialUser.mockResolvedValue({ + _id: 'newUser', + email, + provider: 'google', + googleId, + }); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'New', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(resolveAppConfigForUser).not.toHaveBeenCalled(); + expect(getAppConfig).toHaveBeenCalledWith({ baseOnly: true }); + }); + + it('should block login when tenant config restricts the domain', async () => { + const { isEmailDomainAllowed } = require('@librechat/api'); + const provider = 'google'; + const googleId = 'google-tenant-blocked'; + const email = 'blocked@example.com'; + + const existingUser = { + _id: 'userBlocked', + email, + provider: 'google', + googleId, + tenantId: 'tenant-restrict', + role: 'USER', + }; + + findUser.mockResolvedValue(existingUser); + resolveAppConfigForUser.mockResolvedValue({ + registration: { allowedDomains: ['other.com'] }, + }); + isEmailDomainAllowed.mockReturnValueOnce(true).mockReturnValueOnce(false); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Blocked', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + expect(callback).toHaveBeenCalledWith( + expect.objectContaining({ message: 'Email domain not allowed' }), + ); + }); + }); }); diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 6cecdb95c8..dfa6762ee5 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -1813,3 +1813,57 @@ describe('GLM Model Tests (Zhipu AI)', () => { }); }); }); + +describe('Mistral Model Tests', () => { + describe('getModelMaxTokens', () => { + test('should return correct tokens for mistral-large-3 (256k context)', () => { + expect(getModelMaxTokens('mistral-large-3', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large-3'], + ); + }); + + test('should match mistral-large-3 for suffixed variants', () => { + expect(getModelMaxTokens('mistral-large-3-instruct', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large-3'], + ); + }); + + test('should not match mistral-large-3 for generic mistral-large', () => { + expect(getModelMaxTokens('mistral-large', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large'], + ); + expect(getModelMaxTokens('mistral-large-latest', EModelEndpoint.custom)).toBe( + maxTokensMap[EModelEndpoint.custom]['mistral-large'], + ); + }); + }); + + describe('matchModelName', () => { + test('should match mistral-large-3 exactly', () => { + expect(matchModelName('mistral-large-3', EModelEndpoint.custom)).toBe('mistral-large-3'); + }); + + test('should match mistral-large-3 for prefixed/suffixed variants', () => { + expect(matchModelName('mistral/mistral-large-3', EModelEndpoint.custom)).toBe( + 'mistral-large-3', + ); + expect(matchModelName('mistral-large-3-instruct', EModelEndpoint.custom)).toBe( + 'mistral-large-3', + ); + }); + + test('should match generic mistral-large for non-3 variants', () => { + expect(matchModelName('mistral-large-latest', EModelEndpoint.custom)).toBe('mistral-large'); + }); + }); + + describe('findMatchingPattern', () => { + test('should prefer mistral-large-3 over mistral-large for mistral-large-3 variants', () => { + const result = findMatchingPattern( + 'mistral-large-3-instruct', + maxTokensMap[EModelEndpoint.custom], + ); + expect(result).toBe('mistral-large-3'); + }); + }); +}); diff --git a/client/jest.config.cjs b/client/jest.config.cjs index 4d9087bff7..375e4418a7 100644 --- a/client/jest.config.cjs +++ b/client/jest.config.cjs @@ -41,7 +41,9 @@ module.exports = { '\\.(jpg|jpeg|png|gif|eot|otf|webp|svg|ttf|woff|woff2|mp4|webm|wav|mp3|m4a|aac|oga)$': 'jest-file-loader', }, - transformIgnorePatterns: ['node_modules/?!@zattoo/use-double-click'], + transformIgnorePatterns: [ + '/node_modules/(?!(@zattoo/use-double-click|@dicebear|@react-dnd|react-dnd.*|dnd-core|filenamify|filename-reserved-regex|heic-to|lowlight|highlight\\.js|fault|react-markdown|unified|bail|trough|devlop|is-.*|parse-entities|stringify-entities|character-.*|trim-lines|style-to-object|inline-style-parser|html-url-attributes|escape-string-regexp|longest-streak|zwitch|ccount|markdown-table|comma-separated-tokens|space-separated-tokens|web-namespaces|property-information|remark-.*|rehype-.*|recma-.*|hast.*|mdast-.*|unist-.*|vfile.*|micromark.*|estree-util-.*|decode-named-character-reference)/)/', + ], setupFilesAfterEnv: ['@testing-library/jest-dom/extend-expect', '/test/setupTests.js'], clearMocks: true, }; diff --git a/client/nginx.conf b/client/nginx.conf index c91c47a23f..906b3af128 100644 --- a/client/nginx.conf +++ b/client/nginx.conf @@ -86,9 +86,15 @@ server { # location /api { # proxy_pass http://api:3080/api; +# proxy_set_header X-Forwarded-Proto $scheme; +# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +# proxy_set_header Host $host; # } # location / { # proxy_pass http://api:3080; +# proxy_set_header X-Forwarded-Proto $scheme; +# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +# proxy_set_header Host $host; # } #} diff --git a/client/package.json b/client/package.json index 85d5a176f4..5fe9cddcc7 100644 --- a/client/package.json +++ b/client/package.json @@ -122,7 +122,7 @@ "@babel/preset-env": "^7.22.15", "@babel/preset-react": "^7.22.15", "@babel/preset-typescript": "^7.22.15", - "@happy-dom/jest-environment": "^20.8.3", + "@happy-dom/jest-environment": "^20.8.9", "@tanstack/react-query-devtools": "^4.29.0", "@testing-library/dom": "^9.3.0", "@testing-library/jest-dom": "^5.16.5", diff --git a/client/src/@types/react.d.ts b/client/src/@types/react.d.ts new file mode 100644 index 0000000000..edf0b7af3f --- /dev/null +++ b/client/src/@types/react.d.ts @@ -0,0 +1,8 @@ +import 'react'; + +declare module 'react' { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + interface HTMLAttributes { + inert?: boolean | '' | undefined; + } +} diff --git a/client/src/Providers/ActivePanelContext.tsx b/client/src/Providers/ActivePanelContext.tsx index 9d6082d4e4..46b2a189b7 100644 --- a/client/src/Providers/ActivePanelContext.tsx +++ b/client/src/Providers/ActivePanelContext.tsx @@ -35,3 +35,11 @@ export function useActivePanel() { } return context; } + +/** Returns `active` when it matches a known link, otherwise the first link's id. */ +export function resolveActivePanel(active: string, links: { id: string }[]): string { + if (links.length > 0 && links.some((l) => l.id === active)) { + return active; + } + return links[0]?.id ?? active; +} diff --git a/client/src/Providers/MessagesViewContext.tsx b/client/src/Providers/MessagesViewContext.tsx index f1cae204a4..c44972918c 100644 --- a/client/src/Providers/MessagesViewContext.tsx +++ b/client/src/Providers/MessagesViewContext.tsx @@ -140,6 +140,55 @@ export function useMessagesOperations() { ); } +type OptionalMessagesOps = Pick< + MessagesViewContextValue, + 'ask' | 'regenerate' | 'handleContinue' | 'getMessages' | 'setMessages' +>; + +const NOOP_OPS: OptionalMessagesOps = { + ask: () => {}, + regenerate: () => {}, + handleContinue: () => {}, + getMessages: () => undefined, + setMessages: () => {}, +}; + +/** + * Hook for components that need message operations but may render outside MessagesViewProvider + * (e.g. the /search route). Returns no-op stubs when the provider is absent — UI actions will + * be silently discarded rather than crashing. Callers must use optional chaining on + * `getMessages()` results, as it returns `undefined` outside the provider. + */ +export function useOptionalMessagesOperations(): OptionalMessagesOps { + const context = useContext(MessagesViewContext); + const ask = context?.ask; + const regenerate = context?.regenerate; + const handleContinue = context?.handleContinue; + const getMessages = context?.getMessages; + const setMessages = context?.setMessages; + return useMemo( + () => ({ + ask: ask ?? NOOP_OPS.ask, + regenerate: regenerate ?? NOOP_OPS.regenerate, + handleContinue: handleContinue ?? NOOP_OPS.handleContinue, + getMessages: getMessages ?? NOOP_OPS.getMessages, + setMessages: setMessages ?? NOOP_OPS.setMessages, + }), + [ask, regenerate, handleContinue, getMessages, setMessages], + ); +} + +/** + * Hook for components that need conversation data but may render outside MessagesViewProvider + * (e.g. the /search route). Returns `undefined` for both fields when the provider is absent. + */ +export function useOptionalMessagesConversation() { + const context = useContext(MessagesViewContext); + const conversation = context?.conversation; + const conversationId = context?.conversationId; + return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]); +} + /** Hook for components that only need message state */ export function useMessagesState() { const { index, latestMessageId, latestMessageDepth, setLatestMessage } = useMessagesViewContext(); diff --git a/client/src/Providers/__tests__/ActivePanelContext.spec.tsx b/client/src/Providers/__tests__/ActivePanelContext.spec.tsx index 6a6059c9b4..0f2f89e8f7 100644 --- a/client/src/Providers/__tests__/ActivePanelContext.spec.tsx +++ b/client/src/Providers/__tests__/ActivePanelContext.spec.tsx @@ -1,7 +1,11 @@ import React from 'react'; import { render, fireEvent, screen } from '@testing-library/react'; import '@testing-library/jest-dom/extend-expect'; -import { ActivePanelProvider, useActivePanel } from '~/Providers/ActivePanelContext'; +import { + ActivePanelProvider, + resolveActivePanel, + useActivePanel, +} from '~/Providers/ActivePanelContext'; const STORAGE_KEY = 'side:active-panel'; @@ -58,3 +62,23 @@ describe('ActivePanelContext', () => { spy.mockRestore(); }); }); + +describe('resolveActivePanel', () => { + const links = [{ id: 'conversations' }, { id: 'prompts' }, { id: 'files' }]; + + it('returns active when it matches a link', () => { + expect(resolveActivePanel('prompts', links)).toBe('prompts'); + }); + + it('falls back to first link when active does not match', () => { + expect(resolveActivePanel('hide-panel', links)).toBe('conversations'); + }); + + it('returns active unchanged when links is empty', () => { + expect(resolveActivePanel('agents', [])).toBe('agents'); + }); + + it('falls back to the only link when active is stale', () => { + expect(resolveActivePanel('agents', [{ id: 'conversations' }])).toBe('conversations'); + }); +}); diff --git a/client/src/Providers/__tests__/MessagesViewContext.spec.tsx b/client/src/Providers/__tests__/MessagesViewContext.spec.tsx new file mode 100644 index 0000000000..88cd6f702d --- /dev/null +++ b/client/src/Providers/__tests__/MessagesViewContext.spec.tsx @@ -0,0 +1,53 @@ +import { renderHook } from '@testing-library/react'; +import { + useOptionalMessagesOperations, + useOptionalMessagesConversation, +} from '../MessagesViewContext'; + +describe('useOptionalMessagesOperations', () => { + it('returns noop stubs when rendered outside MessagesViewProvider', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + + expect(result.current.ask).toBeInstanceOf(Function); + expect(result.current.regenerate).toBeInstanceOf(Function); + expect(result.current.handleContinue).toBeInstanceOf(Function); + expect(result.current.getMessages).toBeInstanceOf(Function); + expect(result.current.setMessages).toBeInstanceOf(Function); + }); + + it('noop stubs do not throw when called', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + + expect(() => result.current.ask({} as never)).not.toThrow(); + expect(() => result.current.regenerate({} as never)).not.toThrow(); + expect(() => result.current.handleContinue({} as never)).not.toThrow(); + expect(() => result.current.setMessages([])).not.toThrow(); + }); + + it('getMessages returns undefined outside the provider', () => { + const { result } = renderHook(() => useOptionalMessagesOperations()); + expect(result.current.getMessages()).toBeUndefined(); + }); + + it('returns stable references across re-renders', () => { + const { result, rerender } = renderHook(() => useOptionalMessagesOperations()); + const first = result.current; + rerender(); + expect(result.current).toBe(first); + }); +}); + +describe('useOptionalMessagesConversation', () => { + it('returns undefined fields when rendered outside MessagesViewProvider', () => { + const { result } = renderHook(() => useOptionalMessagesConversation()); + expect(result.current.conversation).toBeUndefined(); + expect(result.current.conversationId).toBeUndefined(); + }); + + it('returns stable references across re-renders', () => { + const { result, rerender } = renderHook(() => useOptionalMessagesConversation()); + const first = result.current; + rerender(); + expect(result.current).toBe(first); + }); +}); diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 85044bb2bc..6ca408685f 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -355,6 +355,28 @@ export type TOptions = { export type TAskFunction = (props: TAskProps, options?: TOptions) => void; +/** + * Stable context object passed from non-memo'd wrapper components (Message, MessageContent) + * to memo'd inner components (MessageRender, ContentRender) via props. + * + * This avoids subscribing to ChatContext inside memo'd components, which would bypass React.memo + * and cause unnecessary re-renders when `isSubmitting` changes during streaming. + * + * The `isSubmitting` property should use a getter backed by a ref so it returns the current + * value at call-time (for callback guards) without being a reactive dependency. + */ +export type TMessageChatContext = { + ask: (...args: Parameters) => void; + index: number; + regenerate: (message: t.TMessage, options?: { addedConvo?: t.TConversation | null }) => void; + conversation: t.TConversation | null; + latestMessageId: string | undefined; + latestMessageDepth: number | undefined; + handleContinue: (e: React.MouseEvent) => void; + /** Should be a getter backed by a ref — reads current value without triggering re-renders */ + readonly isSubmitting: boolean; +}; + export type TMessageProps = { conversation?: t.TConversation | null; messageId?: string | null; diff --git a/client/src/components/Agents/Marketplace.tsx b/client/src/components/Agents/Marketplace.tsx index 0c9c9fb4cc..816705a0db 100644 --- a/client/src/components/Agents/Marketplace.tsx +++ b/client/src/components/Agents/Marketplace.tsx @@ -7,11 +7,10 @@ import { useDocumentTitle, useHasAccess, useLocalize, TranslationKeys } from '~/ import { useGetEndpointsQuery, useGetAgentCategoriesQuery } from '~/data-provider'; import MarketplaceAdminSettings from './MarketplaceAdminSettings'; import { SidePanelGroup } from '~/components/SidePanel'; -import { NewChat } from '~/components/Nav'; -import { cn } from '~/utils'; import CategoryTabs from './CategoryTabs'; import SearchBar from './SearchBar'; import AgentGrid from './AgentGrid'; +import { cn } from '~/utils'; interface AgentMarketplaceProps { className?: string; @@ -202,12 +201,6 @@ const AgentMarketplace: React.FC = ({ className = '' }) = ref={scrollContainerRef} className="scrollbar-gutter-stable relative flex h-full flex-col overflow-y-auto overflow-x-hidden" > - {/* Simplified header for agents marketplace - only show nav controls when needed */} - {!isSmallScreen && ( -

- )} {/* Hero Section - scrolls away */} {!isSmallScreen && (
@@ -222,9 +215,7 @@ const AgentMarketplace: React.FC = ({ className = '' }) =
)} {/* Sticky wrapper for search bar and categories */} -
+
{/* Search bar */}
diff --git a/client/src/components/Artifacts/ArtifactPreview.tsx b/client/src/components/Artifacts/ArtifactPreview.tsx index c125889c88..8257f76887 100644 --- a/client/src/components/Artifacts/ArtifactPreview.tsx +++ b/client/src/components/Artifacts/ArtifactPreview.tsx @@ -6,7 +6,7 @@ import type { } from '@codesandbox/sandpack-react/unstyled'; import type { TStartupConfig } from 'librechat-data-provider'; import type { ArtifactFiles } from '~/common'; -import { sharedFiles, sharedOptions } from '~/utils/artifacts'; +import { sharedFiles, buildSandpackOptions } from '~/utils/artifacts'; export const ArtifactPreview = memo(function ({ files, @@ -39,15 +39,10 @@ export const ArtifactPreview = memo(function ({ }; }, [currentCode, files, fileKey]); - const options: typeof sharedOptions = useMemo(() => { - if (!startupConfig) { - return sharedOptions; - } - return { - ...sharedOptions, - bundlerURL: template === 'static' ? startupConfig.staticBundlerURL : startupConfig.bundlerURL, - }; - }, [startupConfig, template]); + const options: SandpackProviderProps['options'] = useMemo( + () => buildSandpackOptions(template, startupConfig), + [startupConfig, template], + ); if (Object.keys(artifactFiles).length === 0) { return null; diff --git a/client/src/components/Artifacts/Artifacts.tsx b/client/src/components/Artifacts/Artifacts.tsx index 776f689f08..e2a322b1ad 100644 --- a/client/src/components/Artifacts/Artifacts.tsx +++ b/client/src/components/Artifacts/Artifacts.tsx @@ -1,15 +1,16 @@ -import { useRef, useState, useEffect } from 'react'; +import { useRef, useState, useEffect, useCallback } from 'react'; +import copy from 'copy-to-clipboard'; import * as Tabs from '@radix-ui/react-tabs'; import { Code, Play, RefreshCw, X } from 'lucide-react'; import { useSetRecoilState, useResetRecoilState } from 'recoil'; import { Button, Spinner, useMediaQuery, Radio } from '@librechat/client'; import type { SandpackPreviewRef } from '@codesandbox/sandpack-react'; +import CopyButton from '~/components/Messages/Content/CopyButton'; import { useShareContext, useMutationState } from '~/Providers'; import useArtifacts from '~/hooks/Artifacts/useArtifacts'; import DownloadArtifact from './DownloadArtifact'; import ArtifactVersion from './ArtifactVersion'; import ArtifactTabs from './ArtifactTabs'; -import { CopyCodeButton } from './Code'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; import store from '~/store'; @@ -30,6 +31,7 @@ export default function Artifacts() { const [height, setHeight] = useState(90); const [isDragging, setIsDragging] = useState(false); const [blurAmount, setBlurAmount] = useState(0); + const [isCopied, setIsCopied] = useState(false); const dragStartY = useRef(0); const dragStartHeight = useRef(90); const setArtifactsVisible = useSetRecoilState(store.artifactsVisibility); @@ -86,6 +88,16 @@ export default function Artifacts() { setCurrentArtifactId, } = useArtifacts(); + const handleCopyArtifact = useCallback(() => { + const content = currentArtifact?.content ?? ''; + if (!content) { + return; + } + copy(content, { format: 'text/plain' }); + setIsCopied(true); + setTimeout(() => setIsCopied(false), 3000); + }, [currentArtifact?.content]); + const handleDragStart = (e: React.PointerEvent) => { setIsDragging(true); dragStartY.current = e.clientY; @@ -281,7 +293,7 @@ export default function Artifacts() { }} /> )} - + - ); -}; + useEffect(() => { + const scrollContainer = scrollRef.current; + if (!scrollContainer) { + return; + } + + const handleScroll = () => { + const { scrollTop, scrollHeight, clientHeight } = scrollContainer; + const isNearBottom = scrollHeight - scrollTop - clientHeight < 50; + + if (!isNearBottom) { + setUserScrolled(true); + } else { + setUserScrolled(false); + } + }; + + scrollContainer.addEventListener('scroll', handleScroll); + + return () => { + scrollContainer.removeEventListener('scroll', handleScroll); + }; + }, []); + + useEffect(() => { + const scrollContainer = scrollRef.current; + if (!scrollContainer || !isSubmitting || userScrolled) { + return; + } + + scrollContainer.scrollTop = scrollContainer.scrollHeight; + }, [content, isSubmitting, userScrolled]); + + return ( +
+ + {content} + +
+ ); + }, +); diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index dbf2c29d83..e9e19d0904 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -1,4 +1,4 @@ -import { useCallback, useRef } from 'react'; +import { memo, useCallback, useRef } from 'react'; import { MicOff } from 'lucide-react'; import { useToastContext, TooltipAnchor, ListeningIcon, Spinner } from '@librechat/client'; import { useLocalize, useSpeechToText, useGetAudioSettings } from '~/hooks'; @@ -7,7 +7,7 @@ import { globalAudioId } from '~/common'; import { cn } from '~/utils'; const isExternalSTT = (speechToTextEndpoint: string) => speechToTextEndpoint === 'external'; -export default function AudioRecorder({ +export default memo(function AudioRecorder({ disabled, ask, methods, @@ -26,10 +26,12 @@ export default function AudioRecorder({ const { speechToTextEndpoint } = useGetAudioSettings(); const existingTextRef = useRef(''); + const isSubmittingRef = useRef(isSubmitting); + isSubmittingRef.current = isSubmitting; const onTranscriptionComplete = useCallback( (text: string) => { - if (isSubmitting) { + if (isSubmittingRef.current) { showToast({ message: localize('com_ui_speech_while_submitting'), status: 'error', @@ -52,7 +54,7 @@ export default function AudioRecorder({ existingTextRef.current = ''; } }, - [ask, reset, showToast, localize, isSubmitting, speechToTextEndpoint], + [ask, reset, showToast, localize, speechToTextEndpoint], ); const setText = useCallback( @@ -125,4 +127,4 @@ export default function AudioRecorder({ } /> ); -} +}); diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index fed355dcb3..9e0ad7f382 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -3,6 +3,8 @@ import { useWatch } from 'react-hook-form'; import { TextareaAutosize } from '@librechat/client'; import { useRecoilState, useRecoilValue } from 'recoil'; import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider'; +import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter, ConvoGenerator } from '~/common'; import { useChatContext, useChatFormContext, @@ -35,7 +37,30 @@ import BadgeRow from './BadgeRow'; import Mention from './Mention'; import store from '~/store'; -const ChatForm = memo(({ index = 0 }: { index?: number }) => { +interface ChatFormProps { + index: number; + /** From ChatContext — individual values so memo can compare them */ + files: Map; + setFiles: FileSetter; + conversation: TConversation | null; + isSubmitting: boolean; + filesLoading: boolean; + setFilesLoading: React.Dispatch>; + newConversation: ConvoGenerator; + handleStopGenerating: (e: React.MouseEvent) => void; +} + +const ChatForm = memo(function ChatForm({ + index, + files, + setFiles, + conversation, + isSubmitting, + filesLoading, + setFilesLoading, + newConversation, + handleStopGenerating, +}: ChatFormProps) { const submitButtonRef = useRef(null); const textAreaRef = useRef(null); useFocusChatEffect(textAreaRef); @@ -65,15 +90,6 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { const { requiresKey } = useRequiresKey(); const methods = useChatFormContext(); - const { - files, - setFiles, - conversation, - isSubmitting, - filesLoading, - newConversation, - handleStopGenerating, - } = useChatContext(); const { generateConversation, conversation: addedConvo, @@ -120,6 +136,15 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { } }, [isCollapsed]); + const handleTextareaFocus = useCallback(() => { + handleFocusOrClick(); + setIsTextAreaFocused(true); + }, [handleFocusOrClick]); + + const handleTextareaBlur = useCallback(() => { + setIsTextAreaFocused(false); + }, []); + useAutoSave({ files, setFiles, @@ -253,7 +278,12 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { handleSaveBadges={handleSaveBadges} setBadges={setBadges} /> - + {endpoint && (
{ tabIndex={0} data-testid="text-input" rows={1} - onFocus={() => { - handleFocusOrClick(); - setIsTextAreaFocused(true); - }} - onBlur={setIsTextAreaFocused.bind(null, false)} + onFocus={handleTextareaFocus} + onBlur={handleTextareaBlur} aria-label={localize('com_ui_message_input')} onClick={handleFocusOrClick} style={{ height: 44, overflowY: 'auto' }} @@ -315,7 +342,13 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { )} >
- +
{ ); }); +ChatForm.displayName = 'ChatForm'; -export default ChatForm; +/** + * Wrapper that subscribes to ChatContext and passes stable individual values + * to the memo'd ChatForm. This prevents ChatForm from re-rendering on every + * streaming chunk — it only re-renders when the specific values it uses change. + */ +function ChatFormWrapper({ index = 0 }: { index?: number }) { + const { + files, + setFiles, + conversation, + isSubmitting, + filesLoading, + setFilesLoading, + newConversation, + handleStopGenerating, + } = useChatContext(); + + /** + * Stabilize conversation reference: only update when rendering-relevant fields change, + * not on every metadata update (e.g., title generation during streaming). + */ + const hasMessages = (conversation?.messages?.length ?? 0) > 0; + const stableConversation = useMemo( + () => conversation, + // eslint-disable-next-line react-hooks/exhaustive-deps + [ + conversation?.conversationId, + conversation?.endpoint, + conversation?.endpointType, + conversation?.agent_id, + conversation?.assistant_id, + conversation?.spec, + conversation?.useResponsesApi, + conversation?.model, + hasMessages, + ], + ); + + /** Stabilize function refs so they never trigger ChatForm re-renders */ + const handleStopRef = useRef(handleStopGenerating); + handleStopRef.current = handleStopGenerating; + const stableHandleStop = useCallback( + (e: React.MouseEvent) => handleStopRef.current(e), + [], + ); + + const newConvoRef = useRef(newConversation); + newConvoRef.current = newConversation; + const stableNewConversation: ConvoGenerator = useCallback( + (...args: Parameters): ReturnType => + newConvoRef.current(...args), + [], + ); + + return ( + + ); +} + +ChatFormWrapper.displayName = 'ChatFormWrapper'; + +export default ChatFormWrapper; diff --git a/client/src/components/Chat/Input/CollapseChat.tsx b/client/src/components/Chat/Input/CollapseChat.tsx index ea099ed69b..7efe52dc8d 100644 --- a/client/src/components/Chat/Input/CollapseChat.tsx +++ b/client/src/components/Chat/Input/CollapseChat.tsx @@ -52,4 +52,4 @@ const CollapseChat = ({ ); }; -export default CollapseChat; +export default React.memo(CollapseChat); diff --git a/client/src/components/Chat/Input/Files/AttachFile.tsx b/client/src/components/Chat/Input/Files/AttachFile.tsx index 38a3fa8c6f..098fa2c4c3 100644 --- a/client/src/components/Chat/Input/Files/AttachFile.tsx +++ b/client/src/components/Chat/Input/Files/AttachFile.tsx @@ -1,14 +1,33 @@ import React, { useRef } from 'react'; import { FileUpload, TooltipAnchor, AttachmentIcon } from '@librechat/client'; -import { useLocalize, useFileHandling } from '~/hooks'; +import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; +import { useFileHandlingNoChatContext, useLocalize } from '~/hooks'; import { cn } from '~/utils'; -const AttachFile = ({ disabled }: { disabled?: boolean | null }) => { +const AttachFile = ({ + disabled, + files, + setFiles, + setFilesLoading, + conversation, +}: { + disabled?: boolean | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; + conversation: TConversation | null; +}) => { const localize = useLocalize(); const inputRef = useRef(null); const isUploadDisabled = disabled ?? false; - const { handleFileChange } = useFileHandling(); + const { handleFileChange } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, + }); return ( diff --git a/client/src/components/Chat/Input/Files/AttachFileChat.tsx b/client/src/components/Chat/Input/Files/AttachFileChat.tsx index 2f954d01d5..7eb9b0c474 100644 --- a/client/src/components/Chat/Input/Files/AttachFileChat.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileChat.tsx @@ -9,6 +9,7 @@ import { getEndpointFileConfig, } from 'librechat-data-provider'; import type { TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; import { useGetFileConfig, useGetEndpointsQuery, useGetAgentByIdQuery } from '~/data-provider'; import { useAgentsMapContext } from '~/Providers'; import AttachFileMenu from './AttachFileMenu'; @@ -17,9 +18,15 @@ import AttachFile from './AttachFile'; function AttachFileChat({ disableInputs, conversation, + files, + setFiles, + setFilesLoading, }: { disableInputs: boolean; conversation: TConversation | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; }) { const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO; const { endpoint } = conversation ?? { endpoint: null }; @@ -90,7 +97,15 @@ function AttachFileChat({ ); if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { - return ; + return ( + + ); } else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) { return ( ); } diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index 62072e49e5..181d219c08 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -23,15 +23,16 @@ import { bedrockDocumentExtensions, isDocumentSupportedProvider, } from 'librechat-data-provider'; -import type { EndpointFileConfig } from 'librechat-data-provider'; +import type { EndpointFileConfig, TConversation } from 'librechat-data-provider'; +import type { ExtendedFile, FileSetter } from '~/common'; import { useAgentToolPermissions, useAgentCapabilities, useGetAgentsConfig, - useFileHandling, + useFileHandlingNoChatContext, useLocalize, } from '~/hooks'; -import useSharePointFileHandling from '~/hooks/Files/useSharePointFileHandling'; +import { useSharePointFileHandlingNoChatContext } from '~/hooks/Files/useSharePointFileHandling'; import { SharePointPickerDialog } from '~/components/SharePoint'; import { useGetStartupConfig } from '~/data-provider'; import { ephemeralAgentByConvoId } from '~/store'; @@ -53,6 +54,10 @@ interface AttachFileMenuProps { endpointType?: EModelEndpoint | string; endpointFileConfig?: EndpointFileConfig; useResponsesApi?: boolean; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; + conversation: TConversation | null; } const AttachFileMenu = ({ @@ -63,6 +68,10 @@ const AttachFileMenu = ({ conversationId, endpointFileConfig, useResponsesApi, + files, + setFiles, + setFilesLoading, + conversation, }: AttachFileMenuProps) => { const localize = useLocalize(); const isUploadDisabled = disabled ?? false; @@ -72,10 +81,17 @@ const AttachFileMenu = ({ ephemeralAgentByConvoId(conversationId), ); const [toolResource, setToolResource] = useState(); - const { handleFileChange } = useFileHandling(); - const { handleSharePointFiles, isProcessing, downloadProgress } = useSharePointFileHandling({ - toolResource, + const { handleFileChange } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, }); + const { handleSharePointFiles, isProcessing, downloadProgress } = + useSharePointFileHandlingNoChatContext( + { toolResource }, + { files, setFiles, setFilesLoading, conversation }, + ); const { agentsConfig } = useGetAgentsConfig(); const { data: startupConfig } = useGetStartupConfig(); diff --git a/client/src/components/Chat/Input/Files/FileFormChat.tsx b/client/src/components/Chat/Input/Files/FileFormChat.tsx index 3c01f2d642..4d37938c5d 100644 --- a/client/src/components/Chat/Input/Files/FileFormChat.tsx +++ b/client/src/components/Chat/Input/Files/FileFormChat.tsx @@ -1,16 +1,30 @@ import { memo } from 'react'; import { useRecoilValue } from 'recoil'; import type { TConversation } from 'librechat-data-provider'; -import { useChatContext } from '~/Providers'; -import { useFileHandling } from '~/hooks'; +import type { ExtendedFile, FileSetter } from '~/common'; +import { useFileHandlingNoChatContext } from '~/hooks'; import FileRow from './FileRow'; import store from '~/store'; -function FileFormChat({ conversation }: { conversation: TConversation | null }) { - const { files, setFiles, setFilesLoading } = useChatContext(); +function FileFormChat({ + conversation, + files, + setFiles, + setFilesLoading, +}: { + conversation: TConversation | null; + files: Map; + setFiles: FileSetter; + setFilesLoading: React.Dispatch>; +}) { const chatDirection = useRecoilValue(store.chatDirection).toLowerCase(); const { endpoint: _endpoint } = conversation ?? { endpoint: null }; - const { abortUpload } = useFileHandling(); + const { abortUpload } = useFileHandlingNoChatContext(undefined, { + files, + setFiles, + setFilesLoading, + conversation, + }); const isRTL = chatDirection === 'rtl'; diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx index cea55f5ce8..80f06a1b89 100644 --- a/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileChat.spec.tsx @@ -59,7 +59,13 @@ function renderComponent(conversation: Record | null, disableIn return render( - + {}} + setFilesLoading={() => {}} + /> , ); diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx index cf08721207..c2710d4ef8 100644 --- a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx @@ -9,13 +9,14 @@ jest.mock('~/hooks', () => ({ useAgentToolPermissions: jest.fn(), useAgentCapabilities: jest.fn(), useGetAgentsConfig: jest.fn(), - useFileHandling: jest.fn(), + useFileHandlingNoChatContext: jest.fn(), useLocalize: jest.fn(), })); jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({ __esModule: true, default: jest.fn(), + useSharePointFileHandlingNoChatContext: jest.fn(), })); jest.mock('~/data-provider', () => ({ @@ -52,6 +53,7 @@ jest.mock('@librechat/client', () => { ), AttachmentIcon: () => R.createElement('span', { 'data-testid': 'attachment-icon' }), SharePointIcon: () => R.createElement('span', { 'data-testid': 'sharepoint-icon' }), + useToastContext: () => ({ showToast: jest.fn() }), }; }); @@ -66,11 +68,14 @@ jest.mock('@ariakit/react', () => { const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions; const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities; const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig; -const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling; +const mockUseFileHandlingNoChatContext = jest.requireMock('~/hooks').useFileHandlingNoChatContext; const mockUseLocalize = jest.requireMock('~/hooks').useLocalize; const mockUseSharePointFileHandling = jest.requireMock( '~/hooks/Files/useSharePointFileHandling', ).default; +const mockUseSharePointFileHandlingNoChatContext = jest.requireMock( + '~/hooks/Files/useSharePointFileHandling', +).useSharePointFileHandlingNoChatContext; const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig; const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }); @@ -92,12 +97,15 @@ function setupMocks(overrides: { provider?: string } = {}) { codeEnabled: false, }); mockUseGetAgentsConfig.mockReturnValue({ agentsConfig: {} }); - mockUseFileHandling.mockReturnValue({ handleFileChange: jest.fn() }); - mockUseSharePointFileHandling.mockReturnValue({ + mockUseFileHandlingNoChatContext.mockReturnValue({ handleFileChange: jest.fn() }); + const sharePointReturnValue = { handleSharePointFiles: jest.fn(), isProcessing: false, downloadProgress: 0, - }); + error: null, + }; + mockUseSharePointFileHandling.mockReturnValue(sharePointReturnValue); + mockUseSharePointFileHandlingNoChatContext.mockReturnValue(sharePointReturnValue); mockUseGetStartupConfig.mockReturnValue({ data: { sharePointFilePickerEnabled: false } }); mockUseAgentToolPermissions.mockReturnValue({ fileSearchAllowedByAgent: false, @@ -110,7 +118,14 @@ function renderMenu(props: Record = {}) { return render( - + {}} + setFilesLoading={() => {}} + conversation={null} + {...props} + /> , ); diff --git a/client/src/components/Chat/Input/StopButton.tsx b/client/src/components/Chat/Input/StopButton.tsx index 4a058777f1..fd94ba806c 100644 --- a/client/src/components/Chat/Input/StopButton.tsx +++ b/client/src/components/Chat/Input/StopButton.tsx @@ -1,8 +1,15 @@ +import { memo } from 'react'; import { TooltipAnchor } from '@librechat/client'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; -export default function StopButton({ stop, setShowStopButton }) { +export default memo(function StopButton({ + stop, + setShowStopButton, +}: { + stop: (e: React.MouseEvent) => void; + setShowStopButton: (value: boolean) => void; +}) { const localize = useLocalize(); return ( @@ -34,4 +41,4 @@ export default function StopButton({ stop, setShowStopButton }) { } > ); -} +}); diff --git a/client/src/components/Chat/Input/TextareaHeader.tsx b/client/src/components/Chat/Input/TextareaHeader.tsx index 9e67252efe..06c1802585 100644 --- a/client/src/components/Chat/Input/TextareaHeader.tsx +++ b/client/src/components/Chat/Input/TextareaHeader.tsx @@ -1,8 +1,9 @@ +import { memo } from 'react'; import AddedConvo from './AddedConvo'; import type { TConversation } from 'librechat-data-provider'; import type { SetterOrUpdater } from 'recoil'; -export default function TextareaHeader({ +export default memo(function TextareaHeader({ addedConvo, setAddedConvo, }: { @@ -17,4 +18,4 @@ export default function TextareaHeader({
); -} +}); diff --git a/client/src/components/Chat/Menus/Endpoints/ModelSelector.tsx b/client/src/components/Chat/Menus/Endpoints/ModelSelector.tsx index 2c90f57598..b59b718743 100644 --- a/client/src/components/Chat/Menus/Endpoints/ModelSelector.tsx +++ b/client/src/components/Chat/Menus/Endpoints/ModelSelector.tsx @@ -15,6 +15,8 @@ import { CustomMenu as Menu } from './CustomMenu'; import DialogManager from './DialogManager'; import { useLocalize } from '~/hooks'; +const defaultInterface = getConfigDefaults().interface; + function ModelSelectorContent() { const localize = useLocalize(); @@ -122,7 +124,7 @@ function ModelSelectorContent() { } export default function ModelSelector({ startupConfig }: ModelSelectorProps) { - const interfaceConfig = startupConfig?.interface ?? getConfigDefaults().interface; + const interfaceConfig = startupConfig?.interface ?? defaultInterface; const modelSpecs = startupConfig?.modelSpecs?.list ?? []; // Hide the selector when modelSelect is false and there are no model specs to show diff --git a/client/src/components/Chat/Messages/Content/AgentHandoff.tsx b/client/src/components/Chat/Messages/Content/AgentHandoff.tsx index f5fa162ff2..5a5505ee60 100644 --- a/client/src/components/Chat/Messages/Content/AgentHandoff.tsx +++ b/client/src/components/Chat/Messages/Content/AgentHandoff.tsx @@ -1,24 +1,23 @@ import React, { useMemo, useState } from 'react'; -import { EModelEndpoint, Constants } from 'librechat-data-provider'; import { ChevronDown } from 'lucide-react'; +import { EModelEndpoint, Constants } from 'librechat-data-provider'; import type { TMessage } from 'librechat-data-provider'; import MessageIcon from '~/components/Share/MessageIcon'; +import { useLocalize, useExpandCollapse } from '~/hooks'; import { useAgentsMapContext } from '~/Providers'; -import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; interface AgentHandoffProps { name: string; args: string | Record; - output?: string | null; } const AgentHandoff: React.FC = ({ name, args: _args = '' }) => { const localize = useLocalize(); const agentsMap = useAgentsMapContext(); const [showInfo, setShowInfo] = useState(false); + const { style: expandStyle, ref: expandRef } = useExpandCollapse(showInfo); - /** Extracted agent ID from tool name (e.g., "lc_transfer_to_agent_gUV0wMb7zHt3y3Xjz-8_4" -> "agent_gUV0wMb7zHt3y3Xjz-8_4") */ const targetAgentId = useMemo(() => { if (typeof name !== 'string' || !name.startsWith(Constants.LC_TRANSFER_TO_)) { return null; @@ -44,19 +43,24 @@ const AgentHandoff: React.FC = ({ name, args: _args = '' }) = } }, [_args]) as string; - /** Requires more than 2 characters as can be an empty object: `{}` */ const hasInfo = useMemo(() => (args?.trim()?.length ?? 0) > 2, [args]); return ( -
-
+ +
+
+ {hasInfo && ( +
+
+ {localize('com_ui_handoff_instructions')}: +
+
{args}
+
+ )}
- )} +
); }; diff --git a/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx b/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx index 139496c621..3d4fdee1c9 100644 --- a/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx +++ b/client/src/components/Chat/Messages/Content/CodeAnalyze.tsx @@ -1,8 +1,10 @@ -import { useState } from 'react'; +import { useState, useEffect } from 'react'; import { useRecoilValue } from 'recoil'; +import { Terminal } from 'lucide-react'; import { useProgress, useLocalize } from '~/hooks'; import ProgressText from './ProgressText'; import MarkdownLite from './MarkdownLite'; +import { cn } from '~/utils'; import store from '~/store'; export default function CodeAnalyze({ @@ -16,8 +18,14 @@ export default function CodeAnalyze({ }) { const localize = useLocalize(); const progress = useProgress(initialProgress); - const showAnalysisCode = useRecoilValue(store.showCode); - const [showCode, setShowCode] = useState(showAnalysisCode); + const autoExpand = useRecoilValue(store.autoExpandTools); + const [showCode, setShowCode] = useState(autoExpand); + + useEffect(() => { + if (autoExpand) { + setShowCode(true); + } + }, [autoExpand]); const logs = outputs.reduce((acc, output) => { if (output['logs']) { @@ -28,7 +36,10 @@ export default function CodeAnalyze({ return ( <> -
+ + {progress < 1 ? localize('com_ui_analyzing') : localize('com_ui_analyzing_finished')} + +
setShowCode((prev) => !prev)} @@ -36,6 +47,12 @@ export default function CodeAnalyze({ finishedText={localize('com_ui_analyzing_finished')} hasInput={!!code.length} isExpanded={showCode} + icon={ +
{showCode && ( diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index 4b431d7a98..65ebc66908 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -6,12 +6,12 @@ import type { TAttachment, Agents, } from 'librechat-data-provider'; -import { MessageContext, SearchContext } from '~/Providers'; import { ParallelContentRenderer, type PartWithIndex } from './ParallelContent'; -import { mapAttachments } from '~/utils'; +import { mapAttachments, groupSequentialToolCalls } from '~/utils'; +import { MessageContext, SearchContext } from '~/Providers'; import { EditTextPart, EmptyText } from './Parts'; import MemoryArtifacts from './MemoryArtifacts'; -import Sources from '~/components/Web/Sources'; +import ToolCallGroup from './ToolCallGroup'; import Container from './Container'; import Part from './Part'; @@ -160,10 +160,10 @@ const ContentParts = memo(function ContentParts({ } const isTextPart = part?.type === ContentTypes.TEXT || - typeof (part as unknown as Agents.MessageContentText)?.text !== 'string'; + typeof (part as unknown as Agents.MessageContentText)?.text === 'string'; const isThinkPart = part?.type === ContentTypes.THINK || - typeof (part as unknown as Agents.ReasoningDeltaUpdate)?.think !== 'string'; + typeof (part as unknown as Agents.ReasoningDeltaUpdate)?.think === 'string'; if (!isTextPart && !isThinkPart) { return null; } @@ -216,17 +216,32 @@ const ContentParts = memo(function ContentParts({ sequentialParts.push({ part, idx }); } }); + const groupedParts = groupSequentialToolCalls(sequentialParts); return ( - {showEmptyCursor && ( )} - {sequentialParts.map(({ part, idx }) => renderPart(part, idx, idx === lastContentIdx))} + {groupedParts.map((group) => { + if (group.type === 'single') { + const { part, idx } = group.part; + return renderPart(part, idx, idx === lastContentIdx); + } + return ( + p.idx === lastContentIdx)} + renderPart={renderPart} + lastContentIdx={lastContentIdx} + /> + ); + })} ); }); diff --git a/client/src/components/Chat/Messages/Content/FilePreviewDialog.tsx b/client/src/components/Chat/Messages/Content/FilePreviewDialog.tsx new file mode 100644 index 0000000000..c02e2fee4b --- /dev/null +++ b/client/src/components/Chat/Messages/Content/FilePreviewDialog.tsx @@ -0,0 +1,344 @@ +import { useState, useEffect, useCallback, useMemo, useRef } from 'react'; +import copy from 'copy-to-clipboard'; +import { useRecoilValue } from 'recoil'; +import { Download } from 'lucide-react'; +import { OGDialog, OGDialogContent, OGDialogTitle, OGDialogDescription } from '@librechat/client'; +import CopyButton from '~/components/Messages/Content/CopyButton'; +import { logger, sortPagesByRelevance } from '~/utils'; +import { useFileDownload } from '~/data-provider'; +import { useLocalize } from '~/hooks'; +import store from '~/store'; + +interface FilePreviewDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + fileName: string; + fileId?: string; + relevance?: number; + pages?: number[]; + pageRelevance?: Record; + fileType?: string; + fileSize?: number; +} + +function getFileExtension(filename: string): string { + const dot = filename.lastIndexOf('.'); + return dot > 0 ? filename.slice(dot + 1).toLowerCase() : ''; +} + +function canPreviewByMime(mime?: string): 'pdf' | 'text' | false { + if (!mime) { + return false; + } + if (mime.includes('pdf')) { + return 'pdf'; + } + if ( + mime.startsWith('text/') || + mime.includes('json') || + mime.includes('xml') || + mime.includes('javascript') || + mime.includes('typescript') || + mime.includes('yaml') || + mime.includes('csv') + ) { + return 'text'; + } + return false; +} + +function canPreviewByExt(filename: string): 'pdf' | 'text' | false { + const ext = getFileExtension(filename); + if (ext === 'pdf') { + return 'pdf'; + } + const textExts = new Set([ + 'txt', + 'md', + 'csv', + 'json', + 'xml', + 'yaml', + 'yml', + 'html', + 'css', + 'js', + 'ts', + 'jsx', + 'tsx', + 'py', + 'rb', + 'java', + 'c', + 'cpp', + 'h', + 'go', + 'rs', + 'sh', + 'sql', + 'log', + ]); + return textExts.has(ext) ? 'text' : false; +} + +/** Formats bytes with unit suffix (differs from ~/utils/formatBytes which returns a raw number). */ +function formatBytes(bytes: number): string { + if (bytes >= 1048576) { + return `${(bytes / 1048576).toFixed(1)} MB`; + } + if (bytes >= 1024) { + return `${(bytes / 1024).toFixed(1)} KB`; + } + return `${bytes} B`; +} + +function getDisplayType(fileType?: string, fileName?: string): string { + if (fileType) { + if (fileType.includes('pdf')) { + return 'PDF'; + } + if (fileType.includes('word') || fileType.includes('document')) { + return 'Document'; + } + if (fileType.includes('spreadsheet') || fileType.includes('excel')) { + return 'Spreadsheet'; + } + if (fileType.includes('presentation') || fileType.includes('powerpoint')) { + return 'Presentation'; + } + if (fileType.includes('image')) { + return 'Image'; + } + if (fileType.startsWith('text/')) { + return fileType.split('/')[1]?.toUpperCase() || 'Text'; + } + if (fileType.includes('json')) { + return 'JSON'; + } + if (fileType.includes('xml')) { + return 'XML'; + } + } + const ext = fileName ? getFileExtension(fileName) : ''; + return ext ? ext.toUpperCase() : 'File'; +} + +export default function FilePreviewDialog({ + open, + onOpenChange, + fileName, + fileId, + relevance, + pages, + pageRelevance, + fileType, + fileSize, +}: FilePreviewDialogProps) { + const localize = useLocalize(); + const user = useRecoilValue(store.user); + const { refetch: downloadFile } = useFileDownload(user?.id ?? '', fileId); + + const [fileContent, setFileContent] = useState(null); + const [fileBlobUrl, setFileBlobUrl] = useState(null); + const [loading, setLoading] = useState(false); + const [previewError, setPreviewError] = useState(false); + const [isCopied, setIsCopied] = useState(false); + const loadingRef = useRef(false); + + const previewKind = canPreviewByMime(fileType) || canPreviewByExt(fileName); + + const cancelledRef = useRef(false); + + const loadPreview = useCallback(async () => { + if (!fileId || !previewKind || loadingRef.current) { + return; + } + loadingRef.current = true; + cancelledRef.current = false; + setLoading(true); + setPreviewError(false); + + try { + const result = await downloadFile(); + if (cancelledRef.current || !result.data) { + if (!cancelledRef.current) { + setPreviewError(true); + } + return; + } + + const resp = await fetch(result.data); + const blob = await resp.blob(); + + if (cancelledRef.current) { + return; + } + + if (previewKind === 'text') { + setFileContent(await blob.text()); + } else { + const typed = new Blob([blob], { type: 'application/pdf' }); + setFileBlobUrl(URL.createObjectURL(typed)); + } + } catch { + if (!cancelledRef.current) { + setPreviewError(true); + } + } finally { + loadingRef.current = false; + if (!cancelledRef.current) { + setLoading(false); + } + } + }, [fileId, previewKind, downloadFile]); + + const handleDownload = useCallback(async () => { + if (!fileId) { + return; + } + try { + const result = await downloadFile(); + if (!result.data) { + return; + } + const a = document.createElement('a'); + a.href = result.data; + a.setAttribute('download', fileName); + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + setTimeout(() => URL.revokeObjectURL(result.data), 1000); + } catch (err) { + logger.error('[FilePreviewDialog] Download failed:', err); + } + }, [downloadFile, fileId, fileName]); + + useEffect(() => { + if (open && previewKind && !fileContent && !fileBlobUrl) { + loadPreview(); + } + }, [open, previewKind, fileContent, fileBlobUrl, loadPreview]); + + useEffect(() => { + return () => { + if (fileBlobUrl) { + URL.revokeObjectURL(fileBlobUrl); + } + }; + }, [fileBlobUrl]); + + useEffect(() => { + if (!open) { + cancelledRef.current = true; + setFileContent(null); + setFileBlobUrl(null); + setPreviewError(false); + setLoading(false); + setIsCopied(false); + } + }, [open]); + + const handleCopy = useCallback(() => { + if (!fileContent) { + return; + } + copy(fileContent, { format: 'text/plain' }); + setIsCopied(true); + setTimeout(() => setIsCopied(false), 3000); + }, [fileContent]); + + const displayType = useMemo(() => getDisplayType(fileType, fileName), [fileType, fileName]); + const sortedPages = useMemo( + () => (pages && pageRelevance ? sortPagesByRelevance(pages, pageRelevance) : pages), + [pages, pageRelevance], + ); + + const metaParts: string[] = [displayType]; + if (relevance != null && relevance > 0) { + metaParts.push(`${localize('com_ui_relevance')}: ${Math.round(relevance * 100)}%`); + } + if (fileSize != null && fileSize > 0) { + metaParts.push(formatBytes(fileSize)); + } + if (sortedPages && sortedPages.length > 0) { + metaParts.push(localize('com_file_pages', { pages: sortedPages.join(', ') })); + } + + return ( + + +
+ {fileName} +
+ + {metaParts.join(' · ')} + + {fileId && ( + + )} +
+
+ +
+ {loading && ( +
+ + {localize('com_ui_loading')} + +
+ )} + {previewError && ( +
+ + {localize('com_ui_preview_unavailable')} + +
+ )} + {fileBlobUrl && ( +