From ad5c51f62b321bd6f714e946ab08190616c09ecf Mon Sep 17 00:00:00 2001 From: matt burnett Date: Tue, 10 Mar 2026 11:21:36 -0700 Subject: [PATCH 01/39] =?UTF-8?q?=E2=9B=88=EF=B8=8F=20fix:=20MCP=20Reconne?= =?UTF-8?q?ction=20Storm=20Prevention=20with=20Circuit=20Breaker,=20Backof?= =?UTF-8?q?f,=20and=20Tool=20Stubs=20(#12162)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: MCP reconnection stability - circuit breaker, throttling, and cooldown retry * Comment and logging cleanup * fix broken tests --- api/server/services/MCP.js | 66 ++++- packages/api/src/mcp/UserConnectionManager.ts | 3 + .../src/mcp/__tests__/MCPConnection.test.ts | 239 ++++++++++++++++++ .../MCPConnectionAgentLifecycle.test.ts | 4 + packages/api/src/mcp/connection.ts | 117 ++++++++- .../oauth/OAuthReconnectionManager.test.ts | 162 +++++++++++- .../src/mcp/oauth/OAuthReconnectionManager.ts | 44 +++- .../oauth/OAuthReconnectionTracker.test.ts | 95 +++++++ .../src/mcp/oauth/OAuthReconnectionTracker.ts | 44 +++- 9 files changed, 736 insertions(+), 38 deletions(-) diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index ad1f9f5cc3..4f8cdc8195 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -34,6 +34,39 @@ const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); +const lastReconnectAttempts = new Map(); +const RECONNECT_THROTTLE_MS = 10_000; + +const missingToolCache = new Map(); +const MISSING_TOOL_TTL_MS = 10_000; + +const unavailableMsg = + "This tool's MCP server is temporarily unavailable. Please try again shortly."; + +/** + * @param {string} toolName + * @param {string} serverName + */ +function createUnavailableToolStub(toolName, serverName) { + const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; + const _call = async () => unavailableMsg; + const toolInstance = tool(_call, { + schema: { + type: 'object', + properties: { + input: { type: 'string', description: 'Input for the tool' }, + }, + required: [], + }, + name: normalizedToolKey, + description: unavailableMsg, + responseFormat: AgentConstants.CONTENT_AND_ARTIFACT, + }); + toolInstance.mcp = true; + toolInstance.mcpRawServerName = serverName; + return toolInstance; +} + function isEmptyObjectSchema(jsonSchema) { return ( jsonSchema != null && @@ -211,6 +244,16 @@ async function reconnectServer({ logger.debug( `[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`, ); + + const throttleKey = `${user.id}:${serverName}`; + const now = Date.now(); + const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0; + if (now - lastAttempt < RECONNECT_THROTTLE_MS) { + logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`); + return null; + } + lastReconnectAttempts.set(throttleKey, now); + const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; const flowId = `${user.id}:${serverName}:${Date.now()}`; const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); @@ -267,7 +310,7 @@ async function reconnectServer({ userMCPAuthMap, forceNew: true, returnOnOAuth: false, - connectionTimeout: Time.TWO_MINUTES, + connectionTimeout: Time.THIRTY_SECONDS, }); } finally { // Clean up abort handler to prevent memory leaks @@ -332,7 +375,7 @@ async function createMCPTools({ }); if (!result || !result.tools) { logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); - return; + return []; } const serverTools = []; @@ -402,6 +445,14 @@ async function createMCPTool({ /** @type {LCTool | undefined} */ let toolDefinition = availableTools?.[toolKey]?.function; if (!toolDefinition) { + const cachedAt = missingToolCache.get(toolKey); + if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) { + logger.debug( + `[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`, + ); + return createUnavailableToolStub(toolName, serverName); + } + logger.warn( `[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`, ); @@ -415,11 +466,17 @@ async function createMCPTool({ streamId, }); toolDefinition = result?.availableTools?.[toolKey]?.function; + + if (!toolDefinition) { + missingToolCache.set(toolKey, Date.now()); + } } if (!toolDefinition) { - logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`); - return; + logger.warn( + `[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`, + ); + return createUnavailableToolStub(toolName, serverName); } return createToolInstance({ @@ -720,4 +777,5 @@ module.exports = { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus, + createUnavailableToolStub, }; diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index c0ecd18fe2..0828b1720a 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -65,6 +65,9 @@ export abstract class UserConnectionManager { const userServerMap = this.userConnections.get(userId); let connection = forceNew ? undefined : userServerMap?.get(serverName); + if (forceNew) { + MCPConnection.clearCooldown(serverName); + } const now = Date.now(); // Check if user is idle diff --git a/packages/api/src/mcp/__tests__/MCPConnection.test.ts b/packages/api/src/mcp/__tests__/MCPConnection.test.ts index 4cca1b3316..5cb5606d57 100644 --- a/packages/api/src/mcp/__tests__/MCPConnection.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnection.test.ts @@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => { }); }); }); + +/** + * Tests for circuit breaker logic. + * + * Uses standalone implementations that mirror the static/private circuit breaker + * methods in MCPConnection. Same approach as the error detection tests above. + */ +describe('MCPConnection Circuit Breaker', () => { + /** 5 cycles within 60s triggers a 30s cooldown */ + const CB_MAX_CYCLES = 5; + const CB_CYCLE_WINDOW_MS = 60_000; + const CB_CYCLE_COOLDOWN_MS = 30_000; + + /** 3 failed rounds within 120s triggers exponential backoff (30s - 300s) */ + const CB_MAX_FAILED_ROUNDS = 3; + const CB_FAILED_WINDOW_MS = 120_000; + const CB_BASE_BACKOFF_MS = 30_000; + const CB_MAX_BACKOFF_MS = 300_000; + + interface CircuitBreakerState { + cycleCount: number; + cycleWindowStart: number; + cooldownUntil: number; + failedRounds: number; + failedWindowStart: number; + failedBackoffUntil: number; + } + + function createCB(): CircuitBreakerState { + return { + cycleCount: 0, + cycleWindowStart: Date.now(), + cooldownUntil: 0, + failedRounds: 0, + failedWindowStart: Date.now(), + failedBackoffUntil: 0, + }; + } + + function isCircuitOpen(cb: CircuitBreakerState): boolean { + const now = Date.now(); + return now < cb.cooldownUntil || now < cb.failedBackoffUntil; + } + + function recordCycle(cb: CircuitBreakerState): void { + const now = Date.now(); + if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) { + cb.cycleCount = 0; + cb.cycleWindowStart = now; + } + cb.cycleCount++; + if (cb.cycleCount >= CB_MAX_CYCLES) { + cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS; + cb.cycleCount = 0; + cb.cycleWindowStart = now; + } + } + + function recordFailedRound(cb: CircuitBreakerState): void { + const now = Date.now(); + if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) { + cb.failedRounds = 0; + cb.failedWindowStart = now; + } + cb.failedRounds++; + if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) { + const backoff = Math.min( + CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS), + CB_MAX_BACKOFF_MS, + ); + cb.failedBackoffUntil = now + backoff; + } + } + + function resetFailedRounds(cb: CircuitBreakerState): void { + cb.failedRounds = 0; + cb.failedWindowStart = Date.now(); + cb.failedBackoffUntil = 0; + } + + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + describe('cycle tracking', () => { + it('should not trigger cooldown for fewer than 5 cycles', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_CYCLES - 1; i++) { + recordCycle(cb); + } + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should trigger 30s cooldown after 5 cycles within 60s', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_CYCLES; i++) { + recordCycle(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + jest.advanceTimersByTime(29_000); + expect(isCircuitOpen(cb)).toBe(true); + + jest.advanceTimersByTime(1_000); + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should reset cycle count when window expires', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_CYCLES - 1; i++) { + recordCycle(cb); + } + + jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1); + + recordCycle(cb); + expect(isCircuitOpen(cb)).toBe(false); + }); + }); + + describe('failed round tracking', () => { + it('should not trigger backoff for fewer than 3 failures', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) { + recordFailedRound(cb); + } + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should trigger 30s backoff after 3 failures within 120s', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) { + recordFailedRound(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + jest.advanceTimersByTime(CB_BASE_BACKOFF_MS); + expect(isCircuitOpen(cb)).toBe(false); + }); + + it('should use exponential backoff based on failure count', () => { + jest.setSystemTime(Date.now()); + + const cb = createCB(); + + for (let i = 0; i < 3; i++) { + recordFailedRound(cb); + } + expect(cb.failedBackoffUntil - Date.now()).toBe(30_000); + + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(60_000); + + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(120_000); + + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(240_000); + + // capped at 300s + recordFailedRound(cb); + expect(cb.failedBackoffUntil - Date.now()).toBe(300_000); + }); + + it('should reset failed window when window expires', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + recordFailedRound(cb); + recordFailedRound(cb); + + jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1); + + recordFailedRound(cb); + expect(isCircuitOpen(cb)).toBe(false); + }); + }); + + describe('resetFailedRounds', () => { + it('should clear failed round state on successful connection', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const cb = createCB(); + for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) { + recordFailedRound(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + resetFailedRounds(cb); + expect(isCircuitOpen(cb)).toBe(false); + expect(cb.failedRounds).toBe(0); + expect(cb.failedBackoffUntil).toBe(0); + }); + }); + + describe('clearCooldown (registry deletion)', () => { + it('should allow connections after clearing circuit breaker state', () => { + const now = Date.now(); + jest.setSystemTime(now); + + const registry = new Map(); + const serverName = 'test-server'; + + const cb = createCB(); + registry.set(serverName, cb); + + for (let i = 0; i < CB_MAX_CYCLES; i++) { + recordCycle(cb); + } + expect(isCircuitOpen(cb)).toBe(true); + + registry.delete(serverName); + + const newCb = createCB(); + expect(isCircuitOpen(newCb)).toBe(false); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts index 14e0694558..281bd590db 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionAgentLifecycle.test.ts @@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle – streamable-http', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); @@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test-sse'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); @@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test-regression'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); @@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery – integration', () => { }); afterEach(async () => { + MCPConnection.clearCooldown('test-sse-recovery'); await safeDisconnect(conn); conn = null; jest.restoreAllMocks(); diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 83f1af1824..cac0a4afc5 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -71,6 +71,25 @@ const FIVE_MINUTES = 5 * 60 * 1000; const DEFAULT_TIMEOUT = 60000; /** SSE connections through proxies may need longer initial handshake time */ const SSE_CONNECT_TIMEOUT = 120000; +const DEFAULT_INIT_TIMEOUT = 30000; + +interface CircuitBreakerState { + cycleCount: number; + cycleWindowStart: number; + cooldownUntil: number; + failedRounds: number; + failedWindowStart: number; + failedBackoffUntil: number; +} + +const CB_MAX_CYCLES = 5; +const CB_CYCLE_WINDOW_MS = 60_000; +const CB_CYCLE_COOLDOWN_MS = 30_000; + +const CB_MAX_FAILED_ROUNDS = 3; +const CB_FAILED_WINDOW_MS = 120_000; +const CB_BASE_BACKOFF_MS = 30_000; +const CB_MAX_BACKOFF_MS = 300_000; /** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */ const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES; @@ -274,6 +293,80 @@ export class MCPConnection extends EventEmitter { */ public readonly createdAt: number; + private static circuitBreakers: Map = new Map(); + + public static clearCooldown(serverName: string): void { + MCPConnection.circuitBreakers.delete(serverName); + logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`); + } + + private getCircuitBreaker(): CircuitBreakerState { + let cb = MCPConnection.circuitBreakers.get(this.serverName); + if (!cb) { + cb = { + cycleCount: 0, + cycleWindowStart: Date.now(), + cooldownUntil: 0, + failedRounds: 0, + failedWindowStart: Date.now(), + failedBackoffUntil: 0, + }; + MCPConnection.circuitBreakers.set(this.serverName, cb); + } + return cb; + } + + private isCircuitOpen(): boolean { + const cb = this.getCircuitBreaker(); + const now = Date.now(); + return now < cb.cooldownUntil || now < cb.failedBackoffUntil; + } + + private recordCycle(): void { + const cb = this.getCircuitBreaker(); + const now = Date.now(); + if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) { + cb.cycleCount = 0; + cb.cycleWindowStart = now; + } + cb.cycleCount++; + if (cb.cycleCount >= CB_MAX_CYCLES) { + cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS; + cb.cycleCount = 0; + cb.cycleWindowStart = now; + logger.warn( + `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`, + ); + } + } + + private recordFailedRound(): void { + const cb = this.getCircuitBreaker(); + const now = Date.now(); + if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) { + cb.failedRounds = 0; + cb.failedWindowStart = now; + } + cb.failedRounds++; + if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) { + const backoff = Math.min( + CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS), + CB_MAX_BACKOFF_MS, + ); + cb.failedBackoffUntil = now + backoff; + logger.warn( + `${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`, + ); + } + } + + private resetFailedRounds(): void { + const cb = this.getCircuitBreaker(); + cb.failedRounds = 0; + cb.failedWindowStart = Date.now(); + cb.failedBackoffUntil = 0; + } + setRequestHeaders(headers: Record | null): void { if (!headers) { return; @@ -686,6 +779,12 @@ export class MCPConnection extends EventEmitter { return; } + if (this.isCircuitOpen()) { + this.connectionState = 'error'; + this.emit('connectionChange', 'error'); + throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`); + } + this.emit('connectionChange', 'connecting'); this.connectPromise = (async () => { @@ -703,7 +802,7 @@ export class MCPConnection extends EventEmitter { this.transport = await runOutsideTracing(() => this.constructTransport(this.options)); this.patchTransportSend(); - const connectTimeout = this.options.initTimeout ?? 120000; + const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT; await runOutsideTracing(() => withTimeout( this.client.connect(this.transport!), @@ -716,6 +815,7 @@ export class MCPConnection extends EventEmitter { this.connectionState = 'connected'; this.emit('connectionChange', 'connected'); this.reconnectAttempts = 0; + this.resetFailedRounds(); } catch (error) { // Check if it's a rate limit error - stop immediately to avoid making it worse if (this.isRateLimitError(error)) { @@ -817,6 +917,7 @@ export class MCPConnection extends EventEmitter { this.connectionState = 'error'; this.emit('connectionChange', 'error'); + this.recordFailedRound(); throw error; } finally { this.connectPromise = null; @@ -866,7 +967,8 @@ export class MCPConnection extends EventEmitter { async connect(): Promise { try { - await this.disconnect(); + // preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling + await this.disconnect(false); await this.connectClient(); if (!(await this.isConnected())) { throw new Error('Connection not established'); @@ -906,7 +1008,7 @@ export class MCPConnection extends EventEmitter { isTransient, } = extractSSEErrorMessage(error); - if (errorCode === 404) { + if (errorCode === 400 || errorCode === 404 || errorCode === 405) { const hasSession = 'sessionId' in transport && (transport as { sessionId?: string }).sessionId != null && @@ -914,14 +1016,14 @@ export class MCPConnection extends EventEmitter { if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) { logger.warn( - `${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`, + `${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`, ); return; } if (hasSession) { logger.warn( - `${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`, + `${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`, ); } } @@ -992,7 +1094,7 @@ export class MCPConnection extends EventEmitter { await Promise.all(closing); } - public async disconnect(): Promise { + public async disconnect(resetCycleTracking = true): Promise { try { if (this.transport) { await this.client.close(); @@ -1006,6 +1108,9 @@ export class MCPConnection extends EventEmitter { this.emit('connectionChange', 'disconnected'); } finally { this.connectPromise = null; + if (!resetCycleTracking) { + this.recordCycle(); + } } } diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts index d3447eaeb8..d889da4f2f 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts @@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => { expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); }); - it('should not reconnect servers with expired tokens', async () => { + it('should not reconnect servers with expired tokens and no refresh token', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); - // server1: has expired token - tokenMethods.findToken.mockResolvedValue({ - userId, - identifier: 'mcp:server1', - expiresAt: new Date(Date.now() - 3600000), // 1 hour ago - } as unknown as MCPOAuthTokens); + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server1') { + return { + userId, + identifier, + expiresAt: new Date(Date.now() - 3600000), + } as unknown as MCPOAuthTokens; + } + return null; + }); await reconnectionManager.reconnectServers(userId); @@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => { expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled(); }); + it('should reconnect servers with expired access token but valid refresh token', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); + + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server1') { + return { + userId, + identifier, + expiresAt: new Date(Date.now() - 3600000), + } as unknown as MCPOAuthTokens; + } + if (identifier === 'mcp:server1:refresh') { + return { + userId, + identifier, + } as unknown as MCPOAuthTokens; + } + return null; + }); + + const mockNewConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockNewConnection as unknown as MCPConnection, + ); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + await reconnectionManager.reconnectServers(userId); + + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith( + expect.objectContaining({ serverName: 'server1' }), + ); + }); + + it('should reconnect when access token is TTL-deleted but refresh token exists', async () => { + const userId = 'user-123'; + const oauthServers = new Set(['server1']); + (mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); + + tokenMethods.findToken.mockImplementation(async ({ identifier }) => { + if (identifier === 'mcp:server1:refresh') { + return { + userId, + identifier, + } as unknown as MCPOAuthTokens; + } + return null; + }); + + const mockNewConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockNewConnection as unknown as MCPConnection, + ); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + await reconnectionManager.reconnectServers(userId); + + expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith( + expect.objectContaining({ serverName: 'server1' }), + ); + }); + it('should handle connection that returns but is not connected', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); @@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => { }); }); + describe('reconnectServer', () => { + let reconnectionTracker: OAuthReconnectionTracker; + beforeEach(async () => { + reconnectionTracker = new OAuthReconnectionTracker(); + reconnectionManager = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + }); + + it('should return true on successful reconnection', async () => { + const userId = 'user-123'; + const serverName = 'server1'; + + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn(), + }; + mockMCPManager.getUserConnection.mockResolvedValue( + mockConnection as unknown as MCPConnection, + ); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + const result = await reconnectionManager.reconnectServer(userId, serverName); + expect(result).toBe(true); + }); + + it('should return false on failed reconnection', async () => { + const userId = 'user-123'; + const serverName = 'server1'; + + mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed')); + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); + + const result = await reconnectionManager.reconnectServer(userId, serverName); + expect(result).toBe(false); + }); + + it('should return false when MCPManager is not available', async () => { + const userId = 'user-123'; + const serverName = 'server1'; + + (OAuthReconnectionManager as unknown as { instance: null }).instance = null; + (MCPManager.getInstance as jest.Mock).mockImplementation(() => { + throw new Error('MCPManager has not been initialized.'); + }); + + const managerWithoutMCP = await OAuthReconnectionManager.createInstance( + flowManager, + tokenMethods, + reconnectionTracker, + ); + + const result = await managerWithoutMCP.reconnectServer(userId, serverName); + expect(result).toBe(false); + }); + }); + describe('reconnection staggering', () => { let reconnectionTracker: OAuthReconnectionTracker; diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index f14c4abf15..7afe992772 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -96,6 +96,24 @@ export class OAuthReconnectionManager { } } + /** + * Attempts to reconnect a single OAuth MCP server. + * @returns true if reconnection succeeded, false otherwise. + */ + public async reconnectServer(userId: string, serverName: string): Promise { + if (this.mcpManager == null) { + return false; + } + + this.reconnectionsTracker.setActive(userId, serverName); + try { + await this.tryReconnect(userId, serverName); + return !this.reconnectionsTracker.isFailed(userId, serverName); + } catch { + return false; + } + } + public clearReconnection(userId: string, serverName: string) { this.reconnectionsTracker.removeFailed(userId, serverName); this.reconnectionsTracker.removeActive(userId, serverName); @@ -174,23 +192,31 @@ export class OAuthReconnectionManager { } } - // if the server has no tokens for the user, don't attempt to reconnect + // if the server has a valid (non-expired) access token, allow reconnect const accessToken = await this.tokenMethods.findToken({ userId, type: 'mcp_oauth', identifier: `mcp:${serverName}`, }); - if (accessToken == null) { + + if (accessToken != null) { + const now = new Date(); + if (!accessToken.expiresAt || accessToken.expiresAt >= now) { + return true; + } + } + + // if the access token is expired or TTL-deleted, fall back to refresh token + const refreshToken = await this.tokenMethods.findToken({ + userId, + type: 'mcp_oauth', + identifier: `mcp:${serverName}:refresh`, + }); + + if (refreshToken == null) { return false; } - // if the token has expired, don't attempt to reconnect - const now = new Date(); - if (accessToken.expiresAt && accessToken.expiresAt < now) { - return false; - } - - // …otherwise, we're good to go with the reconnect attempt return true; } } diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts index 68ac1d027e..206fe96ef1 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts @@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => { }); }); + describe('cooldown-based retry', () => { + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should return true from isFailed within first cooldown period (5 min)', () => { + const now = Date.now(); + jest.setSystemTime(now); + + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + jest.advanceTimersByTime(4 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + }); + + it('should return false from isFailed after first cooldown elapses (5 min)', () => { + const now = Date.now(); + jest.setSystemTime(now); + + tracker.setFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(true); + + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => { + const now = Date.now(); + jest.setSystemTime(now); + + // First failure: 5 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Second failure: 10 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(9 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Third failure: 20 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(19 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + // Fourth failure: 30 min cooldown + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(29 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should cap cooldown at 30 min for attempts beyond 4', () => { + const now = Date.now(); + jest.setSystemTime(now); + + for (let i = 0; i < 5; i++) { + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(30 * 60 * 1000); + } + + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(29 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + + it('should fully reset metadata on removeFailed', () => { + const now = Date.now(); + jest.setSystemTime(now); + + tracker.setFailed(userId, serverName); + tracker.setFailed(userId, serverName); + tracker.setFailed(userId, serverName); + + tracker.removeFailed(userId, serverName); + expect(tracker.isFailed(userId, serverName)).toBe(false); + + tracker.setFailed(userId, serverName); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed(userId, serverName)).toBe(false); + }); + }); + describe('timestamp tracking edge cases', () => { beforeEach(() => { jest.useFakeTimers(); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts index 9f6ef4abd3..504ea7d43a 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts @@ -1,6 +1,12 @@ +interface FailedMeta { + attempts: number; + lastFailedAt: number; +} + +const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000]; + export class OAuthReconnectionTracker { - /** Map of userId -> Set of serverNames that have failed reconnection */ - private failed: Map> = new Map(); + private failedMeta: Map> = new Map(); /** Map of userId -> Set of serverNames that are actively reconnecting */ private active: Map> = new Map(); /** Map of userId:serverName -> timestamp when reconnection started */ @@ -9,7 +15,17 @@ export class OAuthReconnectionTracker { private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes public isFailed(userId: string, serverName: string): boolean { - return this.failed.get(userId)?.has(serverName) ?? false; + const meta = this.failedMeta.get(userId)?.get(serverName); + if (!meta) { + return false; + } + const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1); + const cooldown = COOLDOWN_SCHEDULE_MS[idx]; + const elapsed = Date.now() - meta.lastFailedAt; + if (elapsed >= cooldown) { + return false; + } + return true; } /** Check if server is in the active set (original simple check) */ @@ -48,11 +64,15 @@ export class OAuthReconnectionTracker { } public setFailed(userId: string, serverName: string): void { - if (!this.failed.has(userId)) { - this.failed.set(userId, new Set()); + if (!this.failedMeta.has(userId)) { + this.failedMeta.set(userId, new Map()); } - - this.failed.get(userId)?.add(serverName); + const userMap = this.failedMeta.get(userId)!; + const existing = userMap.get(serverName); + userMap.set(serverName, { + attempts: (existing?.attempts ?? 0) + 1, + lastFailedAt: Date.now(), + }); } public setActive(userId: string, serverName: string): void { @@ -68,10 +88,10 @@ export class OAuthReconnectionTracker { } public removeFailed(userId: string, serverName: string): void { - const userServers = this.failed.get(userId); - userServers?.delete(serverName); - if (userServers?.size === 0) { - this.failed.delete(userId); + const userMap = this.failedMeta.get(userId); + userMap?.delete(serverName); + if (userMap?.size === 0) { + this.failedMeta.delete(userId); } } @@ -94,7 +114,7 @@ export class OAuthReconnectionTracker { activeTimestamps: number; } { return { - usersWithFailedServers: this.failed.size, + usersWithFailedServers: this.failedMeta.size, usersWithActiveReconnections: this.active.size, activeTimestamps: this.activeTimestamps.size, }; From eb6328c1d980274b4db6d6af1b5184ec67d56636 Mon Sep 17 00:00:00 2001 From: Oreon Lothamer <73498677+oreonl@users.noreply.github.com> Date: Tue, 10 Mar 2026 09:04:35 -1000 Subject: [PATCH 02/39] =?UTF-8?q?=F0=9F=9B=A4=EF=B8=8F=20fix:=20Base=20URL?= =?UTF-8?q?=20Fallback=20for=20Path-based=20OAuth=20Discovery=20in=20Token?= =?UTF-8?q?=20Refresh=20(#12164)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: add base URL fallback for path-based OAuth discovery in token refresh The two `refreshOAuthTokens` paths in `MCPOAuthHandler` were missing the origin-URL fallback that `initiateOAuthFlow` already had. With MCP SDK 1.27.1, `buildDiscoveryUrls` appends the server path to the `.well-known` URL (e.g. `/.well-known/oauth-authorization-server/mcp`), which returns 404 for servers like Sentry that only expose the root discovery endpoint (`/.well-known/oauth-authorization-server`). Without the fallback, discovery returns null during refresh, the token endpoint resolves to the wrong URL, and users are prompted to re-authenticate every time their access token expires instead of the refresh token being exchanged silently. Both refresh paths now mirror the `initiateOAuthFlow` pattern: if discovery fails and the server URL has a non-root path, retry with just the origin URL. Co-Authored-By: Claude Sonnet 4.6 * refactor: extract discoverWithOriginFallback helper; add tests Extract the duplicated path-based URL retry logic from both `refreshOAuthTokens` branches into a single private static helper `discoverWithOriginFallback`, reducing the risk of the two paths drifting in the future. Add three tests covering the new behaviour: - stored clientInfo path: asserts discovery is called twice (path then origin) and that the token endpoint from the origin discovery is used - auto-discovered path: same assertions for the branchless path - root URL: asserts discovery is called only once when the server URL already has no path component Co-Authored-By: Claude Sonnet 4.6 * refactor: use discoverWithOriginFallback in discoverMetadata too Remove the inline duplicate of the origin-fallback logic from `discoverMetadata` and replace it with a call to the shared `discoverWithOriginFallback` helper, giving all three discovery sites a single implementation. Co-Authored-By: Claude Sonnet 4.6 * test: use mock.calls + .href/.toString() for URL assertions Replace brittle `toHaveBeenNthCalledWith(new URL(...))` comparisons with `expect.any(URL)` matchers and explicit `.href`/`.toString()` checks on the captured call args, consistent with the existing mock.calls pattern used throughout handler.test.ts. Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 --- .../api/src/mcp/__tests__/handler.test.ts | 162 ++++++++++++++++++ packages/api/src/mcp/oauth/handler.ts | 49 +++--- 2 files changed, 191 insertions(+), 20 deletions(-) diff --git a/packages/api/src/mcp/__tests__/handler.test.ts b/packages/api/src/mcp/__tests__/handler.test.ts index db88afe581..e5d94b23e3 100644 --- a/packages/api/src/mcp/__tests__/handler.test.ts +++ b/packages/api/src/mcp/__tests__/handler.test.ts @@ -1439,5 +1439,167 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }), ); }); + + describe('path-based URL origin fallback', () => { + it('retries with origin URL when path-based discovery fails (stored clientInfo path)', async () => { + const metadata = { + serverName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + grant_types: ['authorization_code', 'refresh_token'], + }, + }; + + const originMetadata = { + issuer: 'https://mcp.sentry.dev/', + authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize', + token_endpoint: 'https://mcp.sentry.dev/oauth/token', + token_endpoint_auth_methods_supported: ['client_secret_post'], + response_types_supported: ['code'], + jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata; + + // First call (path-based URL) fails, second call (origin URL) succeeds + mockDiscoverAuthorizationServerMetadata + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce(originMetadata); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + }), + } as Response); + + const result = await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + metadata, + {}, + {}, + ); + + // Discovery attempted twice: once with path URL, once with origin URL + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2); + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith( + 1, + expect.any(URL), + expect.any(Object), + ); + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith( + 2, + expect.any(URL), + expect.any(Object), + ); + const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL; + const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL; + expect(firstDiscoveryUrl).toBeInstanceOf(URL); + expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp'); + expect(secondDiscoveryUrl).toBeInstanceOf(URL); + expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/'); + + // Token endpoint from origin discovery metadata is used + expect(mockFetch).toHaveBeenCalled(); + const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0]; + expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/oauth/token'); + expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' })); + expect(result.access_token).toBe('new-access-token'); + }); + + it('retries with origin URL when path-based discovery fails (auto-discovered path)', async () => { + // No clientInfo — uses the auto-discovered branch + const metadata = { + serverName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + }; + + const originMetadata = { + issuer: 'https://mcp.sentry.dev/', + authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize', + token_endpoint: 'https://mcp.sentry.dev/oauth/token', + response_types_supported: ['code'], + jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + } as AuthorizationServerMetadata; + + // First call (path-based URL) fails, second call (origin URL) succeeds + mockDiscoverAuthorizationServerMetadata + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce(originMetadata); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + }), + } as Response); + + const result = await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + metadata, + {}, + {}, + ); + + // Discovery attempted twice: once with path URL, once with origin URL + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2); + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith( + 1, + expect.any(URL), + expect.any(Object), + ); + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith( + 2, + expect.any(URL), + expect.any(Object), + ); + const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL; + const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL; + expect(firstDiscoveryUrl).toBeInstanceOf(URL); + expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp'); + expect(secondDiscoveryUrl).toBeInstanceOf(URL); + expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/'); + + // Token endpoint from origin discovery metadata is used + expect(mockFetch).toHaveBeenCalled(); + const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0]; + expect(fetchUrl).toBeInstanceOf(URL); + expect(fetchUrl.toString()).toBe('https://mcp.sentry.dev/oauth/token'); + expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' })); + expect(result.access_token).toBe('new-access-token'); + }); + + it('does not retry with origin when server URL has no path (root URL)', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://auth.example.com/', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + }, + }; + + // Root URL discovery fails — no retry + mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ access_token: 'new-token', expires_in: 3600 }), + } as Response); + + await MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}); + + // Only one discovery attempt for a root URL + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1); + }); + }); }); }); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 92b9f1211c..6ef444bf47 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -161,20 +161,7 @@ export class MCPOAuthHandler { logger.debug( `[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`, ); - let rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, { - fetchFn, - }); - - // If discovery failed and we're using a path-based URL, try the base URL - if (!rawMetadata && authServerUrl.pathname !== '/') { - const baseUrl = new URL(authServerUrl.origin); - logger.debug( - `[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`, - ); - rawMetadata = await discoverAuthorizationServerMetadata(baseUrl, { - fetchFn, - }); - } + const rawMetadata = await this.discoverWithOriginFallback(authServerUrl, fetchFn); if (!rawMetadata) { /** @@ -221,6 +208,27 @@ export class MCPOAuthHandler { }; } + /** + * Discovers OAuth authorization server metadata with origin-URL fallback. + * If discovery fails for a path-based URL, retries with just the origin. + * Mirrors the fallback behavior in `discoverMetadata` and `initiateOAuthFlow`. + */ + private static async discoverWithOriginFallback( + serverUrl: URL, + fetchFn: FetchLike, + ): ReturnType { + const metadata = await discoverAuthorizationServerMetadata(serverUrl, { fetchFn }); + // If discovery failed and we're using a path-based URL, try the base URL + if (!metadata && serverUrl.pathname !== '/') { + const baseUrl = new URL(serverUrl.origin); + logger.debug( + `[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`, + ); + return discoverAuthorizationServerMetadata(baseUrl, { fetchFn }); + } + return metadata; + } + /** * Registers an OAuth client dynamically */ @@ -735,9 +743,10 @@ export class MCPOAuthHandler { throw new Error('No token URL available for refresh'); } else { /** Auto-discover OAuth configuration for refresh */ - const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, { - fetchFn: this.createOAuthFetch(oauthHeaders), - }); + const serverUrl = new URL(metadata.serverUrl); + const fetchFn = this.createOAuthFetch(oauthHeaders); + const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn); + if (!oauthMetadata) { /** * No metadata discovered - use fallback /token endpoint. @@ -911,9 +920,9 @@ export class MCPOAuthHandler { } /** Auto-discover OAuth configuration for refresh */ - const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, { - fetchFn: this.createOAuthFetch(oauthHeaders), - }); + const serverUrl = new URL(metadata.serverUrl); + const fetchFn = this.createOAuthFetch(oauthHeaders); + const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn); let tokenUrl: URL; if (!oauthMetadata?.token_endpoint) { From c0e876a2e6f6346b76604587b5d4ef1e74ca9ad8 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 10 Mar 2026 16:19:07 -0400 Subject: [PATCH 03/39] =?UTF-8?q?=F0=9F=94=84=20refactor:=20OAuth=20Metada?= =?UTF-8?q?ta=20Discovery=20with=20Origin=20Fallback=20(#12170)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔄 refactor: OAuth Metadata Discovery with Origin Fallback Updated the `discoverWithOriginFallback` method to improve the handling of OAuth authorization server metadata discovery. The method now retries with the origin URL when discovery fails for a path-based URL, ensuring consistent behavior across `discoverMetadata` and token refresh flows. This change reduces code duplication and enhances the reliability of the OAuth flow by providing a unified implementation for origin fallback logic. * 🧪 test: Add tests for OAuth Token Refresh with Origin Fallback Introduced new tests for the `refreshOAuthTokens` method in `MCPOAuthHandler` to validate the retry mechanism with the origin URL when path-based discovery fails. The tests cover scenarios where the first discovery attempt throws an error and the subsequent attempt succeeds, as well as cases where the discovery fails entirely. This enhances the reliability of the OAuth token refresh process by ensuring proper handling of discovery failures. * chore: imports order * fix: Improve Base URL Logging and Metadata Discovery in MCPOAuthHandler Updated the logging to use a consistent base URL object when handling discovery failures in the MCPOAuthHandler. This change enhances error reporting by ensuring that the base URL is logged correctly, and it refines the metadata discovery process by returning the result of the discovery attempt with the base URL, improving the reliability of the OAuth flow. --- .../api/src/mcp/__tests__/handler.test.ts | 141 +++++++++++++++++- packages/api/src/mcp/oauth/handler.ts | 22 ++- .../MCPReinitRecovery.integration.test.ts | 6 +- 3 files changed, 153 insertions(+), 16 deletions(-) diff --git a/packages/api/src/mcp/__tests__/handler.test.ts b/packages/api/src/mcp/__tests__/handler.test.ts index e5d94b23e3..3b68d88e9c 100644 --- a/packages/api/src/mcp/__tests__/handler.test.ts +++ b/packages/api/src/mcp/__tests__/handler.test.ts @@ -1498,20 +1498,19 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { ); const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL; const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL; - expect(firstDiscoveryUrl).toBeInstanceOf(URL); expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp'); - expect(secondDiscoveryUrl).toBeInstanceOf(URL); expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/'); - // Token endpoint from origin discovery metadata is used + // Token endpoint from origin discovery metadata is used (string in stored-clientInfo branch) expect(mockFetch).toHaveBeenCalled(); const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0]; - expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/oauth/token'); + expect(typeof fetchUrl).toBe('string'); + expect(fetchUrl).toBe('https://mcp.sentry.dev/oauth/token'); expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' })); expect(result.access_token).toBe('new-access-token'); }); - it('retries with origin URL when path-based discovery fails (auto-discovered path)', async () => { + it('retries with origin URL when path-based discovery fails (no stored clientInfo)', async () => { // No clientInfo — uses the auto-discovered branch const metadata = { serverName: 'sentry', @@ -1563,12 +1562,10 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { ); const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL; const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL; - expect(firstDiscoveryUrl).toBeInstanceOf(URL); expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp'); - expect(secondDiscoveryUrl).toBeInstanceOf(URL); expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/'); - // Token endpoint from origin discovery metadata is used + // Token endpoint from origin discovery metadata is used (URL object in auto-discovered branch) expect(mockFetch).toHaveBeenCalled(); const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0]; expect(fetchUrl).toBeInstanceOf(URL); @@ -1577,6 +1574,46 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { expect(result.access_token).toBe('new-access-token'); }); + it('falls back to /token when both path and origin discovery fail', async () => { + const metadata = { + serverName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + grant_types: ['authorization_code', 'refresh_token'], + }, + }; + + // Both path AND origin discovery return undefined + mockDiscoverAuthorizationServerMetadata + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce(undefined); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + }), + } as Response); + + const result = await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + metadata, + {}, + {}, + ); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2); + + // Falls back to /token relative to server URL origin + const [fetchUrl] = mockFetch.mock.calls[0]; + expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/token'); + expect(result.access_token).toBe('new-access-token'); + }); + it('does not retry with origin when server URL has no path (root URL)', async () => { const metadata = { serverName: 'test-server', @@ -1600,6 +1637,94 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { // Only one discovery attempt for a root URL expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1); }); + + it('retries with origin when path-based discovery throws', async () => { + const metadata = { + serverName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + grant_types: ['authorization_code', 'refresh_token'], + }, + }; + + const originMetadata = { + issuer: 'https://mcp.sentry.dev/', + authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize', + token_endpoint: 'https://mcp.sentry.dev/oauth/token', + token_endpoint_auth_methods_supported: ['client_secret_post'], + response_types_supported: ['code'], + } as AuthorizationServerMetadata; + + // First call throws, second call succeeds + mockDiscoverAuthorizationServerMetadata + .mockRejectedValueOnce(new Error('Network error')) + .mockResolvedValueOnce(originMetadata); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + }), + } as Response); + + const result = await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + metadata, + {}, + {}, + ); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2); + const [fetchUrl] = mockFetch.mock.calls[0]; + expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/oauth/token'); + expect(result.access_token).toBe('new-access-token'); + }); + + it('propagates the throw when root URL discovery throws', async () => { + const metadata = { + serverName: 'test-server', + serverUrl: 'https://auth.example.com/', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + }, + }; + + mockDiscoverAuthorizationServerMetadata.mockRejectedValueOnce( + new Error('Discovery failed'), + ); + + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('Discovery failed'); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1); + }); + + it('propagates the throw when both path and origin discovery throw', async () => { + const metadata = { + serverName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + clientInfo: { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + }, + }; + + mockDiscoverAuthorizationServerMetadata + .mockRejectedValueOnce(new Error('Network error')) + .mockRejectedValueOnce(new Error('Origin also failed')); + + await expect( + MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}), + ).rejects.toThrow('Origin also failed'); + + expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2); + }); }); }); }); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 6ef444bf47..83e855591e 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -209,16 +209,28 @@ export class MCPOAuthHandler { } /** - * Discovers OAuth authorization server metadata with origin-URL fallback. - * If discovery fails for a path-based URL, retries with just the origin. - * Mirrors the fallback behavior in `discoverMetadata` and `initiateOAuthFlow`. + * Discovers OAuth authorization server metadata, retrying with just the origin + * when discovery fails for a path-based URL. Shared implementation used by + * `discoverMetadata` and both `refreshOAuthTokens` branches. */ private static async discoverWithOriginFallback( serverUrl: URL, fetchFn: FetchLike, ): ReturnType { - const metadata = await discoverAuthorizationServerMetadata(serverUrl, { fetchFn }); - // If discovery failed and we're using a path-based URL, try the base URL + let metadata: Awaited>; + try { + metadata = await discoverAuthorizationServerMetadata(serverUrl, { fetchFn }); + } catch (err) { + if (serverUrl.pathname === '/') { + throw err; + } + const baseUrl = new URL(serverUrl.origin); + logger.debug( + `[MCPOAuth] Discovery threw for path URL, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`, + { error: err }, + ); + return discoverAuthorizationServerMetadata(baseUrl, { fetchFn }); + } if (!metadata && serverUrl.pathname !== '/') { const baseUrl = new URL(serverUrl.origin); logger.debug( diff --git a/packages/api/src/mcp/registry/__tests__/MCPReinitRecovery.integration.test.ts b/packages/api/src/mcp/registry/__tests__/MCPReinitRecovery.integration.test.ts index b171e84d13..9545486fde 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPReinitRecovery.integration.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPReinitRecovery.integration.test.ts @@ -17,20 +17,20 @@ import * as net from 'net'; import * as http from 'http'; +import { Keyv } from 'keyv'; import { Agent } from 'undici'; +import { Types } from 'mongoose'; import { randomUUID } from 'crypto'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; -import { Keyv } from 'keyv'; -import { Types } from 'mongoose'; import type { IUser } from '@librechat/data-schemas'; import type { Socket } from 'net'; import type * as t from '~/mcp/types'; -import { MCPInspectionFailedError } from '~/mcp/errors'; import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache'; import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { MCPInspectionFailedError } from '~/mcp/errors'; import { FlowStateManager } from '~/flow/manager'; import { MCPConnection } from '~/mcp/connection'; import { MCPManager } from '~/mcp/MCPManager'; From 6167ce6e57f37b2c33563fe72f9c7205f9d99da3 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 10 Mar 2026 17:44:13 -0400 Subject: [PATCH 04/39] =?UTF-8?q?=F0=9F=A7=AA=20chore:=20MCP=20Reconnect?= =?UTF-8?q?=20Storm=20Follow-Up=20Fixes=20and=20Integration=20Tests=20(#12?= =?UTF-8?q?172)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🧪 test: Add reconnection storm regression tests for MCPConnection Introduced a comprehensive test suite for reconnection storm scenarios, validating circuit breaker, throttling, cooldown, and timeout fixes. The tests utilize real MCP SDK transports and a StreamableHTTP server to ensure accurate behavior under rapid connect/disconnect cycles and error handling for SSE 400/405 responses. This enhances the reliability of the MCPConnection by ensuring proper handling of reconnection logic and circuit breaker functionality. * 🔧 fix: Update createUnavailableToolStub to return structured response Modified the `createUnavailableToolStub` function to return an array containing the unavailable message and a null value, enhancing the response structure. Additionally, added a debug log to skip tool creation when the result is null, improving the handling of reconnection scenarios in the MCP service. * 🧪 test: Enhance MCP tool creation tests for cache and throttle interactions Added new test cases for the `createMCPTool` function to validate the caching behavior when tools are unavailable or throttled. The tests ensure that tools are correctly cached as missing and prevent unnecessary reconnects across different users, improving the reliability of the MCP service under concurrent usage scenarios. Additionally, introduced a test for the `createMCPTools` function to verify that it returns an empty array when reconnect is throttled, ensuring proper handling of throttling logic. * 📝 docs: Update AGENTS.md with testing philosophy and guidelines Expanded the testing section in AGENTS.md to emphasize the importance of using real logic over mocks, advocating for the use of spies and real dependencies in tests. Added specific recommendations for testing with MongoDB and MCP SDK, highlighting the need to mock only uncontrollable external services. This update aims to improve testing practices and encourage more robust test implementations. * 🧪 test: Enhance reconnection storm tests with socket tracking and SSE handling Updated the reconnection storm test suite to include a new socket tracking mechanism for better resource management during tests. Improved the handling of SSE 400/405 responses by ensuring they are processed in the same branch as 404 errors, preventing unhandled cases. This enhances the reliability of the MCPConnection under rapid reconnect scenarios and ensures proper error handling. * 🔧 fix: Implement cache eviction for stale reconnect attempts and missing tools Added an `evictStale` function to manage the size of the `lastReconnectAttempts` and `missingToolCache` maps, ensuring they do not exceed a maximum cache size. This enhancement improves resource management by removing outdated entries based on a specified time-to-live (TTL), thereby optimizing the MCP service's performance during reconnection scenarios. --- AGENTS.md | 10 +- api/server/services/MCP.js | 24 +- api/server/services/MCP.spec.js | 183 ++++++ .../mcp/__tests__/reconnection-storm.test.ts | 521 ++++++++++++++++++ 4 files changed, 736 insertions(+), 2 deletions(-) create mode 100644 packages/api/src/mcp/__tests__/reconnection-storm.test.ts diff --git a/AGENTS.md b/AGENTS.md index 23b5fc0fbb..ec44607aa7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -149,7 +149,15 @@ Multi-line imports count total character length across all lines. Consolidate va - Run tests from their workspace directory: `cd api && npx jest `, `cd packages/api && npx jest `, etc. - Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering. - Cover loading, success, and error states for UI/data flows. -- Mock data-provider hooks and external dependencies. + +### Philosophy + +- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort. +- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic. +- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls. +- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals. +- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls. +- Heavy mocking is a code smell, not a testing strategy. --- diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 4f8cdc8195..c66eb0b6ef 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -34,12 +34,28 @@ const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); +const MAX_CACHE_SIZE = 1000; const lastReconnectAttempts = new Map(); const RECONNECT_THROTTLE_MS = 10_000; const missingToolCache = new Map(); const MISSING_TOOL_TTL_MS = 10_000; +function evictStale(map, ttl) { + if (map.size <= MAX_CACHE_SIZE) { + return; + } + const now = Date.now(); + for (const [key, timestamp] of map) { + if (now - timestamp >= ttl) { + map.delete(key); + } + if (map.size <= MAX_CACHE_SIZE) { + return; + } + } +} + const unavailableMsg = "This tool's MCP server is temporarily unavailable. Please try again shortly."; @@ -49,7 +65,7 @@ const unavailableMsg = */ function createUnavailableToolStub(toolName, serverName) { const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; - const _call = async () => unavailableMsg; + const _call = async () => [unavailableMsg, null]; const toolInstance = tool(_call, { schema: { type: 'object', @@ -253,6 +269,7 @@ async function reconnectServer({ return null; } lastReconnectAttempts.set(throttleKey, now); + evictStale(lastReconnectAttempts, RECONNECT_THROTTLE_MS); const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; const flowId = `${user.id}:${serverName}:${Date.now()}`; @@ -373,6 +390,10 @@ async function createMCPTools({ userMCPAuthMap, streamId, }); + if (result === null) { + logger.debug(`[MCP][${serverName}] Reconnect throttled, skipping tool creation.`); + return []; + } if (!result || !result.tools) { logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`); return []; @@ -469,6 +490,7 @@ async function createMCPTool({ if (!toolDefinition) { missingToolCache.set(toolKey, Date.now()); + evictStale(missingToolCache, MISSING_TOOL_TTL_MS); } } diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index b2caebc91e..14a9ef90ed 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -45,6 +45,7 @@ const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus, + createUnavailableToolStub, } = require('./MCP'); jest.mock('./Config', () => ({ @@ -1098,6 +1099,188 @@ describe('User parameter passing tests', () => { }); }); + describe('createUnavailableToolStub', () => { + it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => { + const stub = createUnavailableToolStub('myTool', 'myServer'); + // invoke() goes through langchain's base tool, which checks responseFormat. + // CONTENT_AND_ARTIFACT requires [content, artifact] — a bare string would throw: + // "Tool response format is "content_and_artifact" but the output was not a two-tuple" + const result = await stub.invoke({}); + // If we reach here without throwing, the two-tuple format is correct. + // invoke() returns the content portion of [content, artifact] as a string. + expect(result).toContain('temporarily unavailable'); + }); + }); + + describe('negative tool cache and throttle interaction', () => { + it('should cache tool as missing even when throttled (cross-user dedup)', async () => { + const mockUser = { id: 'throttle-test-user' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // First call: reconnect succeeds but tool not found + mockReinitMCPServer.mockResolvedValueOnce({ + availableTools: {}, + }); + + await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: `missing-tool${D}cache-dedup-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + // Second call within 10s for DIFFERENT tool on same server: + // reconnect is throttled (returns null), tool is still cached as missing. + // This is intentional: the cache acts as cross-user dedup since the + // throttle is per-user-per-server and can't prevent N different users + // from each triggering their own reconnect. + const result2 = await createMCPTool({ + res: mockRes, + user: mockUser, + toolKey: `other-tool${D}cache-dedup-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + expect(result2).toBeDefined(); + expect(result2.name).toContain('other-tool'); + expect(mockReinitMCPServer).toHaveBeenCalledTimes(1); + }); + + it('should prevent user B from triggering reconnect when user A already cached the tool', async () => { + const userA = { id: 'cache-user-A' }; + const userB = { id: 'cache-user-B' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // User A: real reconnect, tool not found → cached + mockReinitMCPServer.mockResolvedValueOnce({ + availableTools: {}, + }); + + await createMCPTool({ + res: mockRes, + user: userA, + toolKey: `shared-tool${D}cross-user-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + expect(mockReinitMCPServer).toHaveBeenCalledTimes(1); + + // User B requests the SAME tool within 10s. + // The negative cache is keyed by toolKey (no user prefix), so user B + // gets a cache hit and no reconnect fires. This is the cross-user + // storm protection: without this, user B's unthrottled first request + // would trigger a second reconnect to the same server. + const result = await createMCPTool({ + res: mockRes, + user: userB, + toolKey: `shared-tool${D}cross-user-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + expect(result).toBeDefined(); + expect(result.name).toContain('shared-tool'); + // reinitMCPServer still called only once — user B hit the cache + expect(mockReinitMCPServer).toHaveBeenCalledTimes(1); + }); + + it('should prevent user B from triggering reconnect for throttle-cached tools', async () => { + const userA = { id: 'storm-user-A' }; + const userB = { id: 'storm-user-B' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // User A: real reconnect for tool-1, tool not found → cached + mockReinitMCPServer.mockResolvedValueOnce({ + availableTools: {}, + }); + + await createMCPTool({ + res: mockRes, + user: userA, + toolKey: `tool-1${D}storm-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + // User A: tool-2 on same server within 10s → throttled → cached from throttle + await createMCPTool({ + res: mockRes, + user: userA, + toolKey: `tool-2${D}storm-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + expect(mockReinitMCPServer).toHaveBeenCalledTimes(1); + + // User B requests tool-2 — gets cache hit from the throttle-cached entry. + // Without this caching, user B would trigger a real reconnect since + // user B has their own throttle key and hasn't reconnected yet. + const result = await createMCPTool({ + res: mockRes, + user: userB, + toolKey: `tool-2${D}storm-server`, + provider: 'openai', + userMCPAuthMap: {}, + availableTools: undefined, + }); + + expect(result).toBeDefined(); + expect(result.name).toContain('tool-2'); + // Still only 1 real reconnect — user B was protected by the cache + expect(mockReinitMCPServer).toHaveBeenCalledTimes(1); + }); + }); + + describe('createMCPTools throttle handling', () => { + it('should return empty array with debug log when reconnect is throttled', async () => { + const mockUser = { id: 'throttle-tools-user' }; + const mockRes = { write: jest.fn(), flush: jest.fn() }; + + // First call: real reconnect + mockReinitMCPServer.mockResolvedValueOnce({ + tools: [{ name: 'tool1' }], + availableTools: { + [`tool1${D}throttle-tools-server`]: { + function: { description: 'Tool 1', parameters: {} }, + }, + }, + }); + + await createMCPTools({ + res: mockRes, + user: mockUser, + serverName: 'throttle-tools-server', + provider: 'openai', + userMCPAuthMap: {}, + }); + + // Second call within 10s — throttled + const result = await createMCPTools({ + res: mockRes, + user: mockUser, + serverName: 'throttle-tools-server', + provider: 'openai', + userMCPAuthMap: {}, + }); + + expect(result).toEqual([]); + // reinitMCPServer called only once — second was throttled + expect(mockReinitMCPServer).toHaveBeenCalledTimes(1); + // Should log at debug level (not warn) for throttled case + expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled')); + }); + }); + describe('User parameter integrity', () => { it('should preserve user object properties through the call chain', async () => { const complexUser = { diff --git a/packages/api/src/mcp/__tests__/reconnection-storm.test.ts b/packages/api/src/mcp/__tests__/reconnection-storm.test.ts new file mode 100644 index 0000000000..c1cf0ec5df --- /dev/null +++ b/packages/api/src/mcp/__tests__/reconnection-storm.test.ts @@ -0,0 +1,521 @@ +/** + * Reconnection storm regression tests for PR #12162. + * + * Validates circuit breaker, throttling, cooldown, and timeout fixes using real + * MCP SDK transports (no mocked stubs). A real StreamableHTTP server is spun up + * per test suite and MCPConnection talks to it through a genuine HTTP stack. + */ +import http from 'http'; +import { randomUUID } from 'crypto'; +import express from 'express'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import type { Socket } from 'net'; +import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker'; +import { MCPConnection } from '~/mcp/connection'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +/* ------------------------------------------------------------------ */ +/* Helpers */ +/* ------------------------------------------------------------------ */ + +interface TestServer { + url: string; + httpServer: http.Server; + close: () => Promise; +} + +function trackSockets(httpServer: http.Server): () => Promise { + const sockets = new Set(); + httpServer.on('connection', (socket: Socket) => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + return () => + new Promise((resolve) => { + for (const socket of sockets) { + socket.destroy(); + } + sockets.clear(); + httpServer.close(() => resolve()); + }); +} + +function startMCPServer(): Promise { + const app = express(); + app.use(express.json()); + + const transports: Record = {}; + + function createServer(): McpServer { + const server = new McpServer({ name: 'test-server', version: '1.0.0' }); + server.tool('echo', 'echoes input', { message: { type: 'string' } as never }, async (args) => { + const msg = (args as Record).message ?? ''; + return { content: [{ type: 'text', text: msg }] }; + }); + return server; + } + + app.all('/mcp', async (req, res) => { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + + if (sessionId && transports[sessionId]) { + await transports[sessionId].handleRequest(req, res, req.body); + return; + } + + if (!sessionId && isInitializeRequest(req.body)) { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: (sid) => { + transports[sid] = transport; + }, + }); + transport.onclose = () => { + const sid = transport.sessionId; + if (sid) { + delete transports[sid]; + } + }; + const server = createServer(); + await server.connect(transport); + await transport.handleRequest(req, res, req.body); + return; + } + + if (req.method === 'GET') { + res.status(404).send('Not Found'); + return; + } + + res.status(400).json({ + jsonrpc: '2.0', + error: { code: -32000, message: 'Bad Request: No valid session ID provided' }, + id: null, + }); + }); + + return new Promise((resolve) => { + const httpServer = app.listen(0, '127.0.0.1', () => { + const destroySockets = trackSockets(httpServer); + const addr = httpServer.address() as { port: number }; + resolve({ + url: `http://127.0.0.1:${addr.port}/mcp`, + httpServer, + close: async () => { + for (const t of Object.values(transports)) { + t.close().catch(() => {}); + } + await destroySockets(); + }, + }); + }); + }); +} + +function createConnection(serverName: string, url: string, initTimeout = 5000): MCPConnection { + return new MCPConnection({ + serverName, + serverConfig: { url, type: 'streamable-http', initTimeout } as never, + }); +} + +async function teardownConnection(conn: MCPConnection): Promise { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (conn as any).shouldStopReconnecting = true; + conn.removeAllListeners(); + await conn.disconnect(); +} + +afterEach(() => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (MCPConnection as any).circuitBreakers.clear(); +}); + +/* ------------------------------------------------------------------ */ +/* Fix #2 — Circuit breaker trips after rapid connect/disconnect */ +/* cycles (5 cycles within 60s -> 30s cooldown) */ +/* ------------------------------------------------------------------ */ +describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => { + it('blocks connection after 5 rapid cycles via static circuit breaker', async () => { + const srv = await startMCPServer(); + const conn = createConnection('cycling-server', srv.url); + + let completedCycles = 0; + let breakerMessage = ''; + for (let cycle = 0; cycle < 10; cycle++) { + try { + await conn.connect(); + await teardownConnection(conn); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (conn as any).shouldStopReconnecting = false; + completedCycles++; + } catch (e) { + breakerMessage = (e as Error).message; + break; + } + } + + expect(breakerMessage).toContain('Circuit breaker is open'); + expect(completedCycles).toBeLessThanOrEqual(5); + + await srv.close(); + }); +}); + +/* ------------------------------------------------------------------ */ +/* Fix #3 — SSE 400/405 handled in same branch as 404 */ +/* ------------------------------------------------------------------ */ +describe('Fix #3: SSE 400/405 handled in same branch as 404', () => { + it('400 with active session triggers reconnection (session lost)', async () => { + const srv = await startMCPServer(); + const conn = createConnection('sse-400', srv.url); + await conn.connect(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (conn as any).shouldStopReconnecting = true; + + const changes: string[] = []; + conn.on('connectionChange', (s: string) => changes.push(s)); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const transport = (conn as any).transport; + transport.onerror({ message: 'Failed to open SSE stream', code: 400 }); + + expect(changes).toContain('error'); + + await teardownConnection(conn); + await srv.close(); + }); + + it('405 with active session triggers reconnection (session lost)', async () => { + const srv = await startMCPServer(); + const conn = createConnection('sse-405', srv.url); + await conn.connect(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (conn as any).shouldStopReconnecting = true; + + const changes: string[] = []; + conn.on('connectionChange', (s: string) => changes.push(s)); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const transport = (conn as any).transport; + transport.onerror({ message: 'Method Not Allowed', code: 405 }); + + expect(changes).toContain('error'); + + await teardownConnection(conn); + await srv.close(); + }); +}); + +/* ------------------------------------------------------------------ */ +/* Fix #4 — Circuit breaker state persists in static Map across */ +/* instance replacements */ +/* ------------------------------------------------------------------ */ +describe('Fix #4: Circuit breaker state persists across instance replacement', () => { + it('new MCPConnection for same serverName inherits breaker state from static Map', async () => { + const srv = await startMCPServer(); + + const conn1 = createConnection('replace', srv.url); + await conn1.connect(); + await teardownConnection(conn1); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cbAfterConn1 = (MCPConnection as any).circuitBreakers.get('replace'); + expect(cbAfterConn1).toBeDefined(); + const cyclesAfterConn1 = cbAfterConn1.cycleCount; + expect(cyclesAfterConn1).toBeGreaterThan(0); + + const conn2 = createConnection('replace', srv.url); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cbFromConn2 = (conn2 as any).getCircuitBreaker(); + expect(cbFromConn2.cycleCount).toBe(cyclesAfterConn1); + + await teardownConnection(conn2); + await srv.close(); + }); + + it('clearCooldown resets static state so explicit retry proceeds', () => { + const conn = createConnection('replace', 'http://127.0.0.1:1/mcp'); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cb = (conn as any).getCircuitBreaker(); + cb.cooldownUntil = Date.now() + 999_999; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + expect((conn as any).isCircuitOpen()).toBe(true); + + MCPConnection.clearCooldown('replace'); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + expect((conn as any).isCircuitOpen()).toBe(false); + }); +}); + +/* ------------------------------------------------------------------ */ +/* Fix #5 — Dead servers now trigger circuit breaker via */ +/* recordFailedRound() in the catch path */ +/* ------------------------------------------------------------------ */ +describe('Fix #5: Dead server triggers circuit breaker', () => { + it('3 failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => { + const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000); + const spy = jest.spyOn(conn.client, 'connect'); + + const errors: string[] = []; + for (let i = 0; i < 5; i++) { + try { + await conn.connect(); + } catch (e) { + errors.push((e as Error).message); + } + } + + expect(spy.mock.calls.length).toBe(3); + expect(errors).toHaveLength(5); + expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2); + + await conn.disconnect(); + }); + + it('user B is immediately blocked when user A already tripped the breaker for the same server', async () => { + const deadUrl = 'http://127.0.0.1:1/mcp'; + + const userA = new MCPConnection({ + serverName: 'shared-dead', + serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never, + userId: 'user-A', + }); + + for (let i = 0; i < 3; i++) { + try { + await userA.connect(); + } catch { + // expected + } + } + + const userB = new MCPConnection({ + serverName: 'shared-dead', + serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never, + userId: 'user-B', + }); + const spyB = jest.spyOn(userB.client, 'connect'); + + let blockedMessage = ''; + try { + await userB.connect(); + } catch (e) { + blockedMessage = (e as Error).message; + } + + expect(blockedMessage).toContain('Circuit breaker is open'); + expect(spyB).toHaveBeenCalledTimes(0); + + await userA.disconnect(); + await userB.disconnect(); + }); + + it('clearCooldown after user retry unblocks all users', async () => { + const deadUrl = 'http://127.0.0.1:1/mcp'; + + const userA = new MCPConnection({ + serverName: 'shared-dead-clear', + serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never, + userId: 'user-A', + }); + for (let i = 0; i < 3; i++) { + try { + await userA.connect(); + } catch { + // expected + } + } + + const userB = new MCPConnection({ + serverName: 'shared-dead-clear', + serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never, + userId: 'user-B', + }); + try { + await userB.connect(); + } catch (e) { + expect((e as Error).message).toContain('Circuit breaker is open'); + } + + MCPConnection.clearCooldown('shared-dead-clear'); + + const spyB = jest.spyOn(userB.client, 'connect'); + try { + await userB.connect(); + } catch { + // expected — server is still dead + } + + expect(spyB).toHaveBeenCalledTimes(1); + + await userA.disconnect(); + await userB.disconnect(); + }); +}); + +/* ------------------------------------------------------------------ */ +/* Fix #5b — disconnect(false) preserves cycle tracking */ +/* ------------------------------------------------------------------ */ +describe('Fix #5b: disconnect(false) preserves cycle tracking', () => { + it('connect() passes false to disconnect, which calls recordCycle()', async () => { + const srv = await startMCPServer(); + const conn = createConnection('wipe', srv.url); + const spy = jest.spyOn(conn, 'disconnect'); + + await conn.connect(); + expect(spy).toHaveBeenCalledWith(false); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cb = (MCPConnection as any).circuitBreakers.get('wipe'); + expect(cb).toBeDefined(); + expect(cb.cycleCount).toBeGreaterThan(0); + + await teardownConnection(conn); + await srv.close(); + }); +}); + +/* ------------------------------------------------------------------ */ +/* Fix #6 — OAuth failure uses cooldown-based retry */ +/* ------------------------------------------------------------------ */ +describe('Fix #6: OAuth failure uses cooldown-based retry', () => { + beforeEach(() => jest.useFakeTimers()); + afterEach(() => jest.useRealTimers()); + + it('isFailed expires after first cooldown of 5 min', () => { + jest.setSystemTime(Date.now()); + const tracker = new OAuthReconnectionTracker(); + tracker.setFailed('u1', 'srv'); + + expect(tracker.isFailed('u1', 'srv')).toBe(true); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(false); + }); + + it('progressive cooldown: 5m, 10m, 20m, 30m (capped)', () => { + jest.setSystemTime(Date.now()); + const tracker = new OAuthReconnectionTracker(); + + tracker.setFailed('u1', 'srv'); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(false); + + tracker.setFailed('u1', 'srv'); + jest.advanceTimersByTime(10 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(false); + + tracker.setFailed('u1', 'srv'); + jest.advanceTimersByTime(20 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(false); + + tracker.setFailed('u1', 'srv'); + jest.advanceTimersByTime(29 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(true); + jest.advanceTimersByTime(1 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(false); + }); + + it('removeFailed resets attempt count so next failure starts at 5m', () => { + jest.setSystemTime(Date.now()); + const tracker = new OAuthReconnectionTracker(); + + tracker.setFailed('u1', 'srv'); + tracker.setFailed('u1', 'srv'); + tracker.setFailed('u1', 'srv'); + tracker.removeFailed('u1', 'srv'); + + tracker.setFailed('u1', 'srv'); + jest.advanceTimersByTime(5 * 60 * 1000); + expect(tracker.isFailed('u1', 'srv')).toBe(false); + }); +}); + +/* ------------------------------------------------------------------ */ +/* Integration: Circuit breaker caps rapid cycling with real transport */ +/* ------------------------------------------------------------------ */ +describe('Cascade: Circuit breaker caps rapid cycling', () => { + it('breaker trips before 10 cycles complete against a live server', async () => { + const srv = await startMCPServer(); + const conn = createConnection('cascade', srv.url); + const spy = jest.spyOn(conn.client, 'connect'); + + let completedCycles = 0; + for (let i = 0; i < 10; i++) { + try { + await conn.connect(); + await teardownConnection(conn); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (conn as any).shouldStopReconnecting = false; + completedCycles++; + } catch (e) { + if ((e as Error).message.includes('Circuit breaker is open')) { + break; + } + throw e; + } + } + + expect(completedCycles).toBeLessThanOrEqual(5); + expect(spy.mock.calls.length).toBeLessThanOrEqual(5); + + await srv.close(); + }); + + it('breaker bounds failures against a killed server', async () => { + const srv = await startMCPServer(); + const conn = createConnection('cascade-die', srv.url, 2000); + + await conn.connect(); + await teardownConnection(conn); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (conn as any).shouldStopReconnecting = false; + await srv.close(); + + let breakerTripped = false; + for (let i = 0; i < 10; i++) { + try { + await conn.connect(); + } catch (e) { + if ((e as Error).message.includes('Circuit breaker is open')) { + breakerTripped = true; + break; + } + } + } + + expect(breakerTripped).toBe(true); + }, 30_000); +}); + +/* ------------------------------------------------------------------ */ +/* Sanity: Real transport works end-to-end */ +/* ------------------------------------------------------------------ */ +describe('Sanity: Real MCP SDK transport works correctly', () => { + it('connects, lists tools, and disconnects cleanly', async () => { + const srv = await startMCPServer(); + const conn = createConnection('sanity', srv.url); + + await conn.connect(); + expect(await conn.isConnected()).toBe(true); + + const tools = await conn.fetchTools(); + expect(tools).toEqual(expect.arrayContaining([expect.objectContaining({ name: 'echo' })])); + + await teardownConnection(conn); + await srv.close(); + }); +}); From fcb344da47cbbe51634ee4c6620598acc88e145b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 10 Mar 2026 21:15:01 -0400 Subject: [PATCH 05/39] =?UTF-8?q?=F0=9F=9B=82=20fix:=20MCP=20OAuth=20Race?= =?UTF-8?q?=20Conditions,=20CSRF=20Fallback,=20and=20Token=20Expiry=20Hand?= =?UTF-8?q?ling=20(#12171)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: Implement race conditions in MCP OAuth flow - Added connection mutex to coalesce concurrent `getUserConnection` calls, preventing multiple simultaneous attempts. - Enhanced flow state management to retry once when a flow state is missing, improving resilience against race conditions. - Introduced `ReauthenticationRequiredError` for better error handling when access tokens are expired or missing. - Updated tests to cover new race condition scenarios and ensure proper handling of OAuth flows. * fix: Stale PENDING flow detection and OAuth URL re-issuance PENDING flows in handleOAuthRequired now check createdAt age — flows older than 2 minutes are treated as stale and replaced instead of joined. Fixes the case where a leftover PENDING flow from a previous session blocks new OAuth initiation. authorizationUrl is now stored in MCPOAuthFlowMetadata so that when a second caller joins an active PENDING flow (e.g., the SSE-emitting path in ToolService), it can re-issue the URL to the user via oauthStart. * fix: CSRF fallback via active PENDING flow in OAuth callback When the OAuth callback arrives without CSRF or session cookies (common in the chat/SSE flow where cookies can't be set on streaming responses), fall back to validating that a PENDING flow exists for the flowId. This is safe because the flow was created server-side after JWT authentication and the authorization code is PKCE-protected. * test: Extract shared OAuth test server helpers Move MockKeyv, getFreePort, trackSockets, and createOAuthMCPServer into a shared helpers/oauthTestServer module. Enhance the test server with refresh token support, token rotation, metadata discovery, and dynamic client registration endpoints. Add InMemoryTokenStore for token storage tests. Refactor MCPOAuthRaceCondition.test.ts to import from shared helpers. * test: Add comprehensive MCP OAuth test modules MCPOAuthTokenStorage — 21 tests for storeTokens/getTokens with InMemoryTokenStore: encrypt/decrypt round-trips, expiry calculation, refresh callback wiring, ReauthenticationRequiredError paths. MCPOAuthFlow — 10 tests against real HTTP server: token refresh with stored client info, refresh token rotation, metadata discovery, dynamic client registration, full store/retrieve/expire/refresh lifecycle. MCPOAuthConnectionEvents — 5 tests for MCPConnection OAuth event cycle with real OAuth-gated MCP server: oauthRequired emission on 401, oauthHandled reconnection, oauthFailed rejection, token expiry detection. MCPOAuthTokenExpiry — 12 tests for the token expiry edge case: refresh success/failure paths, ReauthenticationRequiredError, PENDING flow CSRF fallback, authorizationUrl metadata storage, full re-auth cycle after refresh failure, concurrent expired token coalescing, stale PENDING flow detection. * test: Enhance MCP OAuth connection tests with cooldown reset Added a `beforeEach` hook to clear the cooldown for `MCPConnection` before each test, ensuring a clean state. Updated the race condition handling in the tests to properly clear the timeout, improving reliability in the event data retrieval process. * refactor: PENDING flow management and state recovery in MCP OAuth - Introduced a constant `PENDING_STALE_MS` to define the age threshold for PENDING flows, improving the handling of stale flows. - Updated the logic in `MCPConnectionFactory` and `FlowStateManager` to check the age of PENDING flows before joining or reusing them. - Modified the `completeFlow` method to return false when the flow state is deleted, ensuring graceful handling of race conditions. - Enhanced tests to validate the new behavior and ensure robustness against state recovery issues. * refactor: MCP OAuth flow management and testing - Updated the `completeFlow` method to log warnings when a tool flow state is not found during completion, improving error handling. - Introduced a new `normalizeExpiresAt` function to standardize expiration timestamp handling across the application. - Refactored token expiration checks in `MCPConnectionFactory` to utilize the new normalization function, ensuring consistent behavior. - Added a comprehensive test suite for OAuth callback CSRF fallback logic, validating the handling of PENDING flows and their staleness. - Enhanced existing tests to cover new expiration normalization logic and ensure robust flow state management. * test: Add CSRF fallback tests for active PENDING flows in MCP OAuth - Introduced new tests to validate CSRF fallback behavior when a fresh PENDING flow exists without cookies, ensuring successful OAuth callback handling. - Added scenarios to reject requests when no PENDING flow exists, when only a COMPLETED flow is present, and when a PENDING flow is stale, enhancing the robustness of flow state management. - Improved overall test coverage for OAuth callback logic, reinforcing the handling of CSRF validation failures. * chore: imports order * refactor: Update UserConnectionManager to conditionally manage pending connections - Modified the logic in `UserConnectionManager` to only set pending connections if `forceNew` is false, preventing unnecessary overwrites. - Adjusted the cleanup process to ensure pending connections are only deleted when not forced, enhancing connection management efficiency. * refactor: MCP OAuth flow state management - Introduced a new method `storeStateMapping` in `MCPOAuthHandler` to securely map the OAuth state parameter to the flow ID, improving callback resolution and security against forgery. - Updated the OAuth initiation and callback handling in `mcp.js` to utilize the new state mapping functionality, ensuring robust flow management. - Refactored `MCPConnectionFactory` to store state mappings during flow initialization, enhancing the integrity of the OAuth process. - Adjusted comments to clarify the purpose of state parameters in authorization URLs, reinforcing code readability. * refactor: MCPConnection with OAuth recovery handling - Added `oauthRecovery` flag to manage OAuth recovery state during connection attempts. - Introduced `decrementCycleCount` method to reduce the circuit breaker's cycle count upon successful reconnection after OAuth recovery. - Updated connection logic to reset the `oauthRecovery` flag after handling OAuth, improving state management and connection reliability. * chore: Add debug logging for OAuth recovery cycle count decrement - Introduced a debug log statement in the `MCPConnection` class to track the decrement of the cycle count after a successful reconnection during OAuth recovery. - This enhancement improves observability and aids in troubleshooting connection issues related to OAuth recovery. * test: Add OAuth recovery cycle management tests - Introduced new tests for the OAuth recovery cycle in `MCPConnection`, validating the decrement of cycle counts after successful reconnections. - Added scenarios to ensure that the cycle count is not decremented on OAuth failures, enhancing the robustness of connection management. - Improved test coverage for OAuth reconnect scenarios, ensuring reliable behavior under various conditions. * feat: Implement circuit breaker configuration in MCP - Added circuit breaker settings to `.env.example` for max cycles, cycle window, and cooldown duration. - Refactored `MCPConnection` to utilize the new configuration values from `mcpConfig`, enhancing circuit breaker management. - Improved code maintainability by centralizing circuit breaker parameters in the configuration file. * refactor: Update decrementCycleCount method for circuit breaker management - Changed the visibility of the `decrementCycleCount` method in `MCPConnection` from private to public static, allowing it to be called with a server name parameter. - Updated calls to `decrementCycleCount` in `MCPConnectionFactory` to use the new static method, improving clarity and consistency in circuit breaker management during connection failures and OAuth recovery. - Enhanced the handling of circuit breaker state by ensuring the method checks for the existence of the circuit breaker before decrementing the cycle count. * refactor: cycle count decrement on tool listing failure - Added a call to `MCPConnection.decrementCycleCount` in the `MCPConnectionFactory` to handle cases where unauthenticated tool listing fails, improving circuit breaker management. - This change ensures that the cycle count is decremented appropriately, maintaining the integrity of the connection recovery process. * refactor: Update circuit breaker configuration and logic - Enhanced circuit breaker settings in `.env.example` to include new parameters for failed rounds and backoff strategies. - Refactored `MCPConnection` to utilize the updated configuration values from `mcpConfig`, improving circuit breaker management. - Updated tests to reflect changes in circuit breaker logic, ensuring accurate validation of connection behavior under rapid reconnect scenarios. * feat: Implement state mapping deletion in MCP flow management - Added a new method `deleteStateMapping` in `MCPOAuthHandler` to remove orphaned state mappings when a flow is replaced, preventing old authorization URLs from resolving after a flow restart. - Updated `MCPConnectionFactory` to call `deleteStateMapping` during flow cleanup, ensuring proper management of OAuth states. - Enhanced test coverage for state mapping functionality to validate the new deletion logic. --- .env.example | 21 + api/server/routes/__tests__/mcp.spec.js | 121 ++++ api/server/routes/mcp.js | 68 +- packages/api/jest.config.mjs | 1 + packages/api/package.json | 4 +- packages/api/src/flow/manager.ts | 79 +-- packages/api/src/mcp/MCPConnectionFactory.ts | 105 ++- packages/api/src/mcp/UserConnectionManager.ts | 86 ++- .../__tests__/MCPConnectionFactory.test.ts | 4 +- .../__tests__/MCPOAuthCSRFFallback.test.ts | 232 +++++++ .../MCPOAuthConnectionEvents.test.ts | 268 +++++++ .../src/mcp/__tests__/MCPOAuthFlow.test.ts | 538 ++++++++++++++ .../__tests__/MCPOAuthRaceCondition.test.ts | 516 ++++++++++++++ .../mcp/__tests__/MCPOAuthTokenExpiry.test.ts | 654 ++++++++++++++++++ .../__tests__/MCPOAuthTokenStorage.test.ts | 544 +++++++++++++++ .../mcp/__tests__/helpers/oauthTestServer.ts | 449 ++++++++++++ .../mcp/__tests__/reconnection-storm.test.ts | 175 ++++- packages/api/src/mcp/connection.ts | 43 +- packages/api/src/mcp/mcpConfig.ts | 14 + packages/api/src/mcp/oauth/handler.ts | 46 +- packages/api/src/mcp/oauth/tokens.ts | 24 +- packages/api/src/mcp/oauth/types.ts | 1 + 22 files changed, 3865 insertions(+), 128 deletions(-) create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts create mode 100644 packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts diff --git a/.env.example b/.env.example index b851749baf..e746737ea4 100644 --- a/.env.example +++ b/.env.example @@ -850,3 +850,24 @@ OPENWEATHER_API_KEY= # Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it) # When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration # MCP_SKIP_CODE_CHALLENGE_CHECK=false + +# Circuit breaker: max connect/disconnect cycles before tripping (per server) +# MCP_CB_MAX_CYCLES=7 + +# Circuit breaker: sliding window (ms) for counting cycles +# MCP_CB_CYCLE_WINDOW_MS=45000 + +# Circuit breaker: cooldown (ms) after the cycle breaker trips +# MCP_CB_CYCLE_COOLDOWN_MS=15000 + +# Circuit breaker: max consecutive failed connection rounds before backoff +# MCP_CB_MAX_FAILED_ROUNDS=3 + +# Circuit breaker: sliding window (ms) for counting failed rounds +# MCP_CB_FAILED_WINDOW_MS=120000 + +# Circuit breaker: base backoff (ms) after failed round threshold is reached +# MCP_CB_BASE_BACKOFF_MS=30000 + +# Circuit breaker: max backoff cap (ms) for exponential backoff +# MCP_CB_MAX_BACKOFF_MS=300000 diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index e87fcf8f15..009b602604 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => { getFlowState: jest.fn(), completeOAuthFlow: jest.fn(), generateFlowId: jest.fn(), + resolveStateToFlowId: jest.fn(async (state) => state), + storeStateMapping: jest.fn(), + deleteStateMapping: jest.fn(), }, MCPTokenStorage: { storeTokens: jest.fn(), @@ -180,7 +183,10 @@ describe('MCP Routes', () => { MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', flowId: 'test-user-id:test-server', + flowMetadata: { state: 'random-state-value' }, }); + MCPOAuthHandler.storeStateMapping.mockResolvedValue(); + mockFlowManager.initFlow = jest.fn().mockResolvedValue(); const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', @@ -367,6 +373,121 @@ describe('MCP Routes', () => { expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`); }); + describe('CSRF fallback via active PENDING flow', () => { + it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'PENDING', + createdAt: Date.now(), + }), + completeFlow: jest.fn().mockResolvedValue(true), + deleteFlow: jest.fn().mockResolvedValue(true), + }; + const mockFlowState = { + serverName: 'test-server', + userId: 'test-user-id', + metadata: {}, + clientInfo: {}, + codeVerifier: 'test-verifier', + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); + MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({ + access_token: 'test-token', + }); + MCPTokenStorage.storeTokens.mockResolvedValue(); + mockRegistryInstance.getServerConfig.mockResolvedValue({}); + + const mockMcpManager = { + getUserConnection: jest.fn().mockResolvedValue({ + fetchTools: jest.fn().mockResolvedValue([]), + }), + }; + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + require('~/config').getOAuthReconnectionManager.mockReturnValue({ + clearReconnection: jest.fn(), + }); + require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue(); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toContain(`${basePath}/oauth/success`); + }); + + it('should reject when no PENDING flow exists and no cookies are present', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue(null), + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should reject when only a COMPLETED flow exists (not PENDING)', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'COMPLETED', + createdAt: Date.now(), + }), + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => { + const flowId = 'test-user-id:test-server'; + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'PENDING', + createdAt: Date.now() - 3 * 60 * 1000, + }), + }; + + getLogStores.mockReturnValue({}); + require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .query({ code: 'test-code', state: flowId }); + + const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + }); + it('should handle OAuth callback successfully', async () => { // mockRegistryInstance is defined at the top of the file const mockFlowManager = { diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 2db8c2c462..0afac81192 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -13,6 +13,7 @@ const { MCPOAuthHandler, MCPTokenStorage, setOAuthSession, + PENDING_STALE_MS, getUserMCPAuthMap, validateOAuthCsrf, OAUTH_CSRF_COOKIE, @@ -91,7 +92,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async } const oauthHeaders = await getOAuthHeaders(serverName, userId); - const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow( + const { + authorizationUrl, + flowId: oauthFlowId, + flowMetadata, + } = await MCPOAuthHandler.initiateOAuthFlow( serverName, serverUrl, userId, @@ -101,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); + await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager); setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH); res.redirect(authorizationUrl); } catch (error) { @@ -143,30 +149,52 @@ router.get('/:serverName/oauth/callback', async (req, res) => { return res.redirect(`${basePath}/oauth/error?error=missing_state`); } - const flowId = state; - logger.debug('[MCP OAuth] Using flow ID from state', { flowId }); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager); + if (!flowId) { + logger.error('[MCP OAuth] Could not resolve state to flow ID', { state }); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); + } + logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId }); const flowParts = flowId.split(':'); if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) { - logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId }); + logger.error('[MCP OAuth] Invalid flow ID format', { flowId }); return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } const [flowUserId] = flowParts; - if ( - !validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) && - !validateOAuthSession(req, flowUserId) - ) { - logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', { - flowId, - hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], - hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], - }); - return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + + const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH); + const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId); + let hasActiveFlow = false; + if (!hasCsrf && !hasSession) { + const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); + const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity; + hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS; + if (hasActiveFlow) { + logger.debug( + '[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow', + { + flowId, + }, + ); + } } - const flowsCache = getLogStores(CacheKeys.FLOWS); - const flowManager = getFlowStateManager(flowsCache); + if (!hasCsrf && !hasSession && !hasActiveFlow) { + logger.error( + '[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow', + { + flowId, + hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], + hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], + }, + ); + return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + } logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId); const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager); @@ -281,7 +309,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const toolFlowId = flowState.metadata?.toolFlowId; if (toolFlowId) { logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId }); - await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); + const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); + if (!completed) { + logger.warn( + '[MCP OAuth] Tool flow state not found during completion — waiter will time out', + { toolFlowId }, + ); + } } /** Redirect to success page with flowId and serverName */ diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index 530150a7fa..df9cf6bcc2 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -7,6 +7,7 @@ export default { '\\.dev\\.ts$', '\\.helper\\.ts$', '\\.helper\\.d\\.ts$', + '/__tests__/helpers/', ], coverageReporters: ['text', 'cobertura'], testResultsProcessor: 'jest-junit', diff --git a/packages/api/package.json b/packages/api/package.json index e4ca4ef3c5..46587797a5 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,8 +18,8 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", + "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"", "test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", "test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", diff --git a/packages/api/src/flow/manager.ts b/packages/api/src/flow/manager.ts index 4f9023a3d7..b68b9edb7a 100644 --- a/packages/api/src/flow/manager.ts +++ b/packages/api/src/flow/manager.ts @@ -3,6 +3,18 @@ import { logger } from '@librechat/data-schemas'; import type { StoredDataNoRaw } from 'keyv'; import type { FlowState, FlowMetadata, FlowManagerOptions } from './types'; +export const PENDING_STALE_MS = 2 * 60 * 1000; + +const SECONDS_THRESHOLD = 1e10; + +/** + * Normalizes an expiration timestamp to milliseconds. + * Timestamps below 10 billion are assumed to be in seconds (valid until ~2286). + */ +export function normalizeExpiresAt(timestamp: number): number { + return timestamp < SECONDS_THRESHOLD ? timestamp * 1000 : timestamp; +} + export class FlowStateManager { private keyv: Keyv; private ttl: number; @@ -45,32 +57,8 @@ export class FlowStateManager { return `${type}:${flowId}`; } - /** - * Normalizes an expiration timestamp to milliseconds. - * Detects whether the input is in seconds or milliseconds based on magnitude. - * Timestamps below 10 billion are assumed to be in seconds (valid until ~2286). - * @param timestamp - The expiration timestamp (in seconds or milliseconds) - * @returns The timestamp normalized to milliseconds - */ - private normalizeExpirationTimestamp(timestamp: number): number { - const SECONDS_THRESHOLD = 1e10; - if (timestamp < SECONDS_THRESHOLD) { - return timestamp * 1000; - } - return timestamp; - } - - /** - * Checks if a flow's token has expired based on its expires_at field - * @param flowState - The flow state to check - * @returns true if the token has expired, false otherwise (including if no expires_at exists) - */ private isTokenExpired(flowState: FlowState | undefined): boolean { - if (!flowState?.result) { - return false; - } - - if (typeof flowState.result !== 'object') { + if (!flowState?.result || typeof flowState.result !== 'object') { return false; } @@ -79,13 +67,11 @@ export class FlowStateManager { } const expiresAt = (flowState.result as { expires_at: unknown }).expires_at; - if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) { return false; } - const normalizedExpiresAt = this.normalizeExpirationTimestamp(expiresAt); - return normalizedExpiresAt < Date.now(); + return normalizeExpiresAt(expiresAt) < Date.now(); } /** @@ -149,6 +135,8 @@ export class FlowStateManager { let elapsedTime = 0; let isCleanedUp = false; let intervalId: NodeJS.Timeout | null = null; + let missingStateRetried = false; + let isRetrying = false; // Cleanup function to avoid duplicate cleanup const cleanup = () => { @@ -188,16 +176,29 @@ export class FlowStateManager { } intervalId = setInterval(async () => { - if (isCleanedUp) return; + if (isCleanedUp || isRetrying) return; try { - const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; + let flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; if (!flowState) { - cleanup(); - logger.error(`[${flowKey}] Flow state not found`); - reject(new Error(`${type} Flow state not found`)); - return; + if (!missingStateRetried) { + missingStateRetried = true; + isRetrying = true; + logger.warn( + `[${flowKey}] Flow state not found, retrying once after 500ms (race recovery)`, + ); + await new Promise((r) => setTimeout(r, 500)); + flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; + isRetrying = false; + } + + if (!flowState) { + cleanup(); + logger.error(`[${flowKey}] Flow state not found after retry`); + reject(new Error(`${type} Flow state not found`)); + return; + } } if (signal?.aborted) { @@ -251,10 +252,10 @@ export class FlowStateManager { const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; if (!flowState) { - logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', { - flowId, - type, - }); + logger.warn( + '[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping', + { flowId, type }, + ); return false; } @@ -297,7 +298,7 @@ export class FlowStateManager { async isFlowStale( flowId: string, type: string, - staleThresholdMs: number = 2 * 60 * 1000, + staleThresholdMs: number = PENDING_STALE_MS, ): Promise<{ isStale: boolean; age: number; status?: string }> { const flowKey = this.getFlowKey(flowId, type); const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 03131b659b..0fc86e0315 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -2,11 +2,11 @@ import { logger } from '@librechat/data-schemas'; import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; import type { Tool } from '@modelcontextprotocol/sdk/types.js'; import type { TokenMethods } from '@librechat/data-schemas'; -import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth'; +import type { MCPOAuthTokens, OAuthMetadata, MCPOAuthFlowMetadata } from '~/mcp/oauth'; import type { FlowStateManager } from '~/flow/manager'; -import type { FlowMetadata } from '~/flow/types'; import type * as t from './types'; -import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { PENDING_STALE_MS, normalizeExpiresAt } from '~/flow/manager'; import { sanitizeUrlForLogging } from './utils'; import { withTimeout } from '~/utils/promise'; import { MCPConnection } from './connection'; @@ -104,6 +104,7 @@ export class MCPConnectionFactory { return { tools, connection, oauthRequired: false, oauthUrl: null }; } } catch { + MCPConnection.decrementCycleCount(this.serverName); logger.debug( `${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`, ); @@ -125,7 +126,9 @@ export class MCPConnectionFactory { } return { tools, connection: null, oauthRequired, oauthUrl }; } + MCPConnection.decrementCycleCount(this.serverName); } catch (listError) { + MCPConnection.decrementCycleCount(this.serverName); logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError); } @@ -265,6 +268,10 @@ export class MCPConnectionFactory { if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`); return tokens; } catch (error) { + if (error instanceof ReauthenticationRequiredError) { + logger.info(`${this.logPrefix} ${error.message}, will trigger OAuth flow`); + return null; + } logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error); return null; } @@ -306,11 +313,21 @@ export class MCPConnectionFactory { const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth'); if (existingFlow?.status === 'PENDING') { + const pendingAge = existingFlow.createdAt + ? Date.now() - existingFlow.createdAt + : Infinity; + + if (pendingAge < PENDING_STALE_MS) { + logger.debug( + `${this.logPrefix} Recent PENDING OAuth flow exists (${Math.round(pendingAge / 1000)}s old), skipping new initiation`, + ); + connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); + return; + } + logger.debug( - `${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`, + `${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will replace`, ); - connection.emit('oauthFailed', new Error('OAuth flow initiated - return early')); - return; } const { @@ -326,11 +343,17 @@ export class MCPConnectionFactory { ); if (existingFlow) { + const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state; await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth'); + if (oldState) { + await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!); + } } // Store flow state BEFORE redirecting so the callback can find it - await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata); + const metadataWithUrl = { ...flowMetadata, authorizationUrl }; + await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl); + await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager!); // Start monitoring in background — createFlow will find the existing PENDING state // written by initFlow above, so metadata arg is unused (pass {} to make that explicit) @@ -495,11 +518,75 @@ export class MCPConnectionFactory { const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth'); if (existingFlow) { + const flowMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined; + + if (existingFlow.status === 'PENDING') { + const pendingAge = existingFlow.createdAt + ? Date.now() - existingFlow.createdAt + : Infinity; + + if (pendingAge < PENDING_STALE_MS) { + logger.debug( + `${this.logPrefix} Found recent PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), joining instead of creating new one`, + ); + + const storedAuthUrl = flowMeta?.authorizationUrl; + if (storedAuthUrl && typeof this.oauthStart === 'function') { + logger.info( + `${this.logPrefix} Re-issuing stored authorization URL to caller while joining PENDING flow`, + ); + await this.oauthStart(storedAuthUrl); + } + + const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal); + if (typeof this.oauthEnd === 'function') { + await this.oauthEnd(); + } + logger.info( + `${this.logPrefix} Joined existing OAuth flow completed for ${this.serverName}`, + ); + return { + tokens, + clientInfo: flowMeta?.clientInfo, + metadata: flowMeta?.metadata, + }; + } + + logger.debug( + `${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will delete and start fresh`, + ); + } + + if (existingFlow.status === 'COMPLETED') { + const completedAge = existingFlow.completedAt + ? Date.now() - existingFlow.completedAt + : Infinity; + const cachedTokens = existingFlow.result as MCPOAuthTokens | null | undefined; + const isTokenExpired = + cachedTokens?.expires_at != null && + normalizeExpiresAt(cachedTokens.expires_at) < Date.now(); + + if (completedAge <= PENDING_STALE_MS && cachedTokens !== undefined && !isTokenExpired) { + logger.debug( + `${this.logPrefix} Found non-stale COMPLETED OAuth flow, reusing cached tokens`, + ); + return { + tokens: cachedTokens, + clientInfo: flowMeta?.clientInfo, + metadata: flowMeta?.metadata, + }; + } + } + logger.debug( `${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`, ); try { + const oldState = flowMeta?.state; await this.flowManager.deleteFlow(flowId, 'mcp_oauth'); + if (oldState) { + await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager); + } } catch (error) { logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error); } @@ -519,7 +606,9 @@ export class MCPConnectionFactory { ); // Store flow state BEFORE redirecting so the callback can find it - await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata); + const metadataWithUrl = { ...flowMetadata, authorizationUrl }; + await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl); + await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager); if (typeof this.oauthStart === 'function') { logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`); diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 0828b1720a..76523fc0fc 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -1,10 +1,10 @@ import { logger } from '@librechat/data-schemas'; import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; -import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; -import { MCPConnection } from './connection'; import type * as t from './types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPConnection } from './connection'; import { mcpConfig } from './mcpConfig'; /** @@ -21,6 +21,8 @@ export abstract class UserConnectionManager { protected userConnections: Map> = new Map(); /** Last activity timestamp for users (not per server) */ protected userLastActivity: Map = new Map(); + /** In-flight connection promises keyed by `userId:serverName` — coalesces concurrent attempts */ + protected pendingConnections: Map> = new Map(); /** Updates the last activity timestamp for a user */ protected updateUserLastActivity(userId: string): void { @@ -31,29 +33,64 @@ export abstract class UserConnectionManager { ); } - /** Gets or creates a connection for a specific user */ - public async getUserConnection({ - serverName, - forceNew, - user, - flowManager, - customUserVars, - requestBody, - tokenMethods, - oauthStart, - oauthEnd, - signal, - returnOnOAuth = false, - connectionTimeout, - }: { - serverName: string; - forceNew?: boolean; - } & Omit): Promise { + /** Gets or creates a connection for a specific user, coalescing concurrent attempts */ + public async getUserConnection( + opts: { + serverName: string; + forceNew?: boolean; + } & Omit, + ): Promise { + const { serverName, forceNew, user } = opts; const userId = user?.id; if (!userId) { throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); } + const lockKey = `${userId}:${serverName}`; + + if (!forceNew) { + const pending = this.pendingConnections.get(lockKey); + if (pending) { + logger.debug(`[MCP][User: ${userId}][${serverName}] Joining in-flight connection attempt`); + return pending; + } + } + + const connectionPromise = this.createUserConnectionInternal(opts, userId); + + if (!forceNew) { + this.pendingConnections.set(lockKey, connectionPromise); + } + + try { + return await connectionPromise; + } finally { + if (!forceNew && this.pendingConnections.get(lockKey) === connectionPromise) { + this.pendingConnections.delete(lockKey); + } + } + } + + private async createUserConnectionInternal( + { + serverName, + forceNew, + user, + flowManager, + customUserVars, + requestBody, + tokenMethods, + oauthStart, + oauthEnd, + signal, + returnOnOAuth = false, + connectionTimeout, + }: { + serverName: string; + forceNew?: boolean; + } & Omit, + userId: string, + ): Promise { if (await this.appConnections!.has(serverName)) { throw new McpError( ErrorCode.InvalidRequest, @@ -188,6 +225,7 @@ export abstract class UserConnectionManager { /** Disconnects and removes a specific user connection */ public async disconnectUserConnection(userId: string, serverName: string): Promise { + this.pendingConnections.delete(`${userId}:${serverName}`); const userMap = this.userConnections.get(userId); const connection = userMap?.get(serverName); if (connection) { @@ -215,6 +253,12 @@ export abstract class UserConnectionManager { ); } await Promise.allSettled(disconnectPromises); + // Clean up any pending connection promises for this user + for (const key of this.pendingConnections.keys()) { + if (key.startsWith(`${userId}:`)) { + this.pendingConnections.delete(key); + } + } // Ensure user activity timestamp is removed this.userLastActivity.delete(userId); logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index de18e27e89..bceb23b246 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -275,7 +275,7 @@ describe('MCPConnectionFactory', () => { expect(mockFlowManager.initFlow).toHaveBeenCalledWith( 'flow123', 'mcp_oauth', - mockFlowData.flowMetadata, + expect.objectContaining(mockFlowData.flowMetadata), ); const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock @@ -550,7 +550,7 @@ describe('MCPConnectionFactory', () => { expect(mockFlowManager.initFlow).toHaveBeenCalledWith( 'flow123', 'mcp_oauth', - mockFlowData.flowMetadata, + expect.objectContaining(mockFlowData.flowMetadata), ); const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0]; const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock diff --git a/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts new file mode 100644 index 0000000000..cdba06cf8d --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts @@ -0,0 +1,232 @@ +/** + * Tests for the OAuth callback CSRF fallback logic. + * + * The callback route validates requests via three mechanisms (in order): + * 1. CSRF cookie (HMAC-based, set during initiate) + * 2. Session cookie (bound to authenticated userId) + * 3. Active PENDING flow in FlowStateManager (fallback for SSE/chat flows) + * + * This suite tests mechanism 3 — the PENDING flow fallback — including + * staleness enforcement and rejection of non-PENDING flows. + * + * These tests exercise the validation functions directly for fast, + * focused coverage. Route-level integration tests using supertest + * are in api/server/routes/__tests__/mcp.spec.js ("CSRF fallback + * via active PENDING flow" describe block). + */ + +import { Keyv } from 'keyv'; +import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager'; +import type { Request, Response } from 'express'; +import { + generateOAuthCsrfToken, + OAUTH_SESSION_COOKIE, + validateOAuthSession, + OAUTH_CSRF_COOKIE, + validateOAuthCsrf, +} from '~/oauth/csrf'; +import { MockKeyv } from './helpers/oauthTestServer'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +const CSRF_COOKIE_PATH = '/api/mcp'; + +function makeReq(cookies: Record = {}): Request { + return { cookies } as unknown as Request; +} + +function makeRes(): Response { + const res = { + clearCookie: jest.fn(), + } as unknown as Response; + return res; +} + +/** + * Replicate the callback route's three-tier validation logic. + * Returns which mechanism (if any) authorized the request. + */ +async function validateCallback( + req: Request, + res: Response, + flowId: string, + flowManager: FlowStateManager, +): Promise<'csrf' | 'session' | 'pendingFlow' | false> { + const flowUserId = flowId.split(':')[0]; + + const hasCsrf = validateOAuthCsrf(req, res, flowId, CSRF_COOKIE_PATH); + if (hasCsrf) { + return 'csrf'; + } + + const hasSession = validateOAuthSession(req, flowUserId); + if (hasSession) { + return 'session'; + } + + const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth'); + const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity; + if (pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS) { + return 'pendingFlow'; + } + + return false; +} + +describe('OAuth Callback CSRF Fallback', () => { + let flowManager: FlowStateManager; + + beforeEach(() => { + process.env.JWT_SECRET = 'test-secret-for-csrf'; + const store = new MockKeyv(); + flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 300000, ci: true }); + }); + + afterEach(() => { + delete process.env.JWT_SECRET; + jest.clearAllMocks(); + }); + + describe('CSRF cookie validation (mechanism 1)', () => { + it('should accept valid CSRF cookie', async () => { + const flowId = 'user1:test-server'; + const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('csrf'); + }); + + it('should reject invalid CSRF cookie', async () => { + const flowId = 'user1:test-server'; + const req = makeReq({ [OAUTH_CSRF_COOKIE]: 'wrong-token-value' }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + }); + + describe('Session cookie validation (mechanism 2)', () => { + it('should accept valid session cookie when CSRF is absent', async () => { + const flowId = 'user1:test-server'; + const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('session'); + }); + }); + + describe('PENDING flow fallback (mechanism 3)', () => { + it('should accept when a fresh PENDING flow exists and no cookies are present', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('pendingFlow'); + }); + + it('should reject when no PENDING flow, no CSRF cookie, and no session cookie', async () => { + const flowId = 'user1:test-server'; + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should reject when only a COMPLETED flow exists (not PENDING)', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + await flowManager.completeFlow(flowId, 'mcp_oauth', { access_token: 'tok' } as never); + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should reject when only a FAILED flow exists', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + await flowManager.failFlow(flowId, 'mcp_oauth', 'some error'); + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + // Artificially age the flow past the staleness threshold + const store = (flowManager as unknown as { keyv: { get: (k: string) => Promise } }) + .keyv; + const flowState = (await store.get(`mcp_oauth:${flowId}`)) as { createdAt: number }; + flowState.createdAt = Date.now() - PENDING_STALE_MS - 1000; + + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe(false); + }); + + it('should accept PENDING flow that is just under the staleness threshold', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + // Flow was just created, well under threshold + const req = makeReq(); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('pendingFlow'); + }); + }); + + describe('Priority ordering', () => { + it('should prefer CSRF cookie over PENDING flow', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('csrf'); + }); + + it('should prefer session cookie over PENDING flow when CSRF is absent', async () => { + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' }); + + const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf'); + const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken }); + const res = makeRes(); + + const result = await validateCallback(req, res, flowId, flowManager); + expect(result).toBe('session'); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts new file mode 100644 index 0000000000..4e168d00f3 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts @@ -0,0 +1,268 @@ +/** + * Tests for MCPConnection OAuth event cycle against a real OAuth-gated MCP server. + * + * Verifies: oauthRequired emission on 401, oauthHandled reconnection, + * oauthFailed rejection, timeout behavior, and token expiry mid-session. + */ + +import { MCPConnection } from '~/mcp/connection'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +jest.mock('~/auth', () => ({ + createSSRFSafeUndiciConnect: jest.fn(() => undefined), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +jest.mock('~/mcp/mcpConfig', () => ({ + mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 }, +})); + +async function safeDisconnect(conn: MCPConnection | null): Promise { + if (!conn) { + return; + } + try { + await conn.disconnect(); + } catch { + // Ignore disconnect errors during cleanup + } +} + +async function exchangeCodeForToken(serverUrl: string): Promise { + const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, { + redirect: 'manual', + }); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${serverUrl}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const data = (await tokenRes.json()) as { access_token: string }; + return data.access_token; +} + +describe('MCPConnection OAuth Events — Real Server', () => { + let server: OAuthTestServer; + let connection: MCPConnection | null = null; + + beforeEach(() => { + MCPConnection.clearCooldown('test-server'); + }); + + afterEach(async () => { + await safeDisconnect(connection); + connection = null; + if (server) { + await server.close(); + } + jest.clearAllMocks(); + }); + + describe('oauthRequired event', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should emit oauthRequired when connecting without a token', async () => { + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { type: 'streamable-http', url: server.url }, + userId: 'user-1', + }); + + const oauthRequiredPromise = new Promise<{ + serverName: string; + error: Error; + serverUrl?: string; + userId?: string; + }>((resolve) => { + connection!.on('oauthRequired', (data) => { + resolve( + data as { + serverName: string; + error: Error; + serverUrl?: string; + userId?: string; + }, + ); + }); + }); + + // Connection will fail with 401, emitting oauthRequired + const connectPromise = connection.connect().catch(() => { + // Expected to fail since no one handles oauthRequired + }); + + let raceTimer: NodeJS.Timeout | undefined; + const eventData = await Promise.race([ + oauthRequiredPromise, + new Promise((_, reject) => { + raceTimer = setTimeout( + () => reject(new Error('Timed out waiting for oauthRequired')), + 10000, + ); + }), + ]).finally(() => clearTimeout(raceTimer)); + + expect(eventData.serverName).toBe('test-server'); + expect(eventData.error).toBeDefined(); + + // Emit oauthFailed to unblock connect() + connection.emit('oauthFailed', new Error('test cleanup')); + await connectPromise.catch(() => undefined); + }); + + it('should not emit oauthRequired when connecting with a valid token', async () => { + const accessToken = await exchangeCodeForToken(server.url); + + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { type: 'streamable-http', url: server.url }, + userId: 'user-1', + oauthTokens: { + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens, + }); + + let oauthFired = false; + connection.on('oauthRequired', () => { + oauthFired = true; + }); + + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + expect(oauthFired).toBe(false); + }); + }); + + describe('oauthHandled reconnection', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should succeed on retry after oauthHandled provides valid tokens', async () => { + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + }, + userId: 'user-1', + }); + + // First connect fails with 401 → oauthRequired fires + let oauthFired = false; + connection.on('oauthRequired', () => { + oauthFired = true; + connection!.emit('oauthFailed', new Error('Will retry with tokens')); + }); + + // First attempt fails as expected + await expect(connection.connect()).rejects.toThrow(); + expect(oauthFired).toBe(true); + + // Now set valid tokens and reconnect + const accessToken = await exchangeCodeForToken(server.url); + connection.setOAuthTokens({ + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + }); + }); + + describe('oauthFailed rejection', () => { + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + it('should reject connect() when oauthFailed is emitted', async () => { + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + }, + userId: 'user-1', + }); + + connection.on('oauthRequired', () => { + connection!.emit('oauthFailed', new Error('User denied OAuth')); + }); + + await expect(connection.connect()).rejects.toThrow(); + }); + }); + + describe('Token expiry during session', () => { + it('should detect expired token on reconnect and emit oauthRequired', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 1000 }); + + const accessToken = await exchangeCodeForToken(server.url); + + connection = new MCPConnection({ + serverName: 'test-server', + serverConfig: { + type: 'streamable-http', + url: server.url, + initTimeout: 15000, + }, + userId: 'user-1', + oauthTokens: { + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens, + }); + + // Initial connect should succeed + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + await connection.disconnect(); + + // Wait for token to expire + await new Promise((r) => setTimeout(r, 1200)); + + // Reconnect should trigger oauthRequired since token is expired on the server + let oauthFired = false; + connection.on('oauthRequired', () => { + oauthFired = true; + connection!.emit('oauthFailed', new Error('Will retry with fresh token')); + }); + + // First reconnect fails with 401 → oauthRequired + await expect(connection.connect()).rejects.toThrow(); + expect(oauthFired).toBe(true); + + // Get fresh token and reconnect + const newToken = await exchangeCodeForToken(server.url); + connection.setOAuthTokens({ + access_token: newToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + + await connection.connect(); + expect(await connection.isConnected()).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts new file mode 100644 index 0000000000..8437177c86 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -0,0 +1,538 @@ +/** + * OAuth flow tests against a real HTTP server. + * + * Tests MCPOAuthHandler.refreshOAuthTokens and MCPTokenStorage lifecycle + * using a real test OAuth server (not mocked SDK functions). + */ + +import { createHash } from 'crypto'; +import { Keyv } from 'keyv'; +import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; +import { FlowStateManager } from '~/flow/manager'; +import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +describe('MCP OAuth Flow — Real HTTP Server', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Token refresh with real server', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should refresh tokens with stored client info via real /token endpoint', async () => { + // First get initial tokens + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + expect(initial.refresh_token).toBeDefined(); + + // Register a client so we have clientInfo + const regRes = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const clientInfo = (await regRes.json()) as { + client_id: string; + client_secret: string; + }; + + // Refresh tokens using the real endpoint + const refreshed = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'test-server', + serverUrl: server.url, + clientInfo: { + ...clientInfo, + redirect_uris: ['http://localhost/callback'], + }, + }, + {}, + { + token_url: `${server.url}token`, + client_id: clientInfo.client_id, + client_secret: clientInfo.client_secret, + token_exchange_method: 'DefaultPost', + }, + ); + + expect(refreshed.access_token).toBeDefined(); + expect(refreshed.access_token).not.toBe(initial.access_token); + expect(refreshed.token_type).toBe('Bearer'); + expect(refreshed.obtained_at).toBeDefined(); + }); + + it('should get new refresh token when server rotates', async () => { + const rotatingServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + rotateRefreshTokens: true, + }); + + try { + const code = await rotatingServer.getAuthCode(); + const tokenRes = await fetch(`${rotatingServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const refreshed = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'test-server', + serverUrl: rotatingServer.url, + }, + {}, + { + token_url: `${rotatingServer.url}token`, + client_id: 'anon', + token_exchange_method: 'DefaultPost', + }, + ); + + expect(refreshed.access_token).not.toBe(initial.access_token); + expect(refreshed.refresh_token).toBeDefined(); + expect(refreshed.refresh_token).not.toBe(initial.refresh_token); + } finally { + await rotatingServer.close(); + } + }); + + it('should fail refresh with invalid refresh token', async () => { + await expect( + MCPOAuthHandler.refreshOAuthTokens( + 'invalid-refresh-token', + { + serverName: 'test-server', + serverUrl: server.url, + }, + {}, + { + token_url: `${server.url}token`, + client_id: 'anon', + token_exchange_method: 'DefaultPost', + }, + ), + ).rejects.toThrow(); + }); + }); + + describe('OAuth server metadata discovery', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ issueRefreshTokens: true }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should expose /.well-known/oauth-authorization-server', async () => { + const res = await fetch(`${server.url}.well-known/oauth-authorization-server`); + expect(res.status).toBe(200); + + const metadata = (await res.json()) as { + authorization_endpoint: string; + token_endpoint: string; + registration_endpoint: string; + grant_types_supported: string[]; + }; + + expect(metadata.authorization_endpoint).toContain('/authorize'); + expect(metadata.token_endpoint).toContain('/token'); + expect(metadata.registration_endpoint).toContain('/register'); + expect(metadata.grant_types_supported).toContain('authorization_code'); + expect(metadata.grant_types_supported).toContain('refresh_token'); + }); + + it('should not advertise refresh_token grant when disabled', async () => { + const noRefreshServer = await createOAuthMCPServer({ + issueRefreshTokens: false, + }); + try { + const res = await fetch(`${noRefreshServer.url}.well-known/oauth-authorization-server`); + const metadata = (await res.json()) as { grant_types_supported: string[] }; + expect(metadata.grant_types_supported).not.toContain('refresh_token'); + } finally { + await noRefreshServer.close(); + } + }); + }); + + describe('Dynamic client registration', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer(); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should register a client via /register endpoint', async () => { + const res = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + redirect_uris: ['http://localhost/callback'], + }), + }); + + expect(res.status).toBe(200); + const client = (await res.json()) as { + client_id: string; + client_secret: string; + redirect_uris: string[]; + }; + + expect(client.client_id).toBeDefined(); + expect(client.client_secret).toBeDefined(); + expect(client.redirect_uris).toEqual(['http://localhost/callback']); + expect(server.registeredClients.has(client.client_id)).toBe(true); + }); + }); + + describe('End-to-End: store, retrieve, expire, refresh cycle', () => { + it('should perform full token lifecycle with real server', async () => { + const server = await createOAuthMCPServer({ + tokenTTLMs: 1000, + issueRefreshTokens: true, + }); + const tokenStore = new InMemoryTokenStore(); + + try { + // 1. Get initial tokens via auth code exchange + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // 2. Store tokens + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'test-srv', + tokens: initial, + createToken: tokenStore.createToken, + }); + + // 3. Retrieve — should succeed + const valid = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }); + expect(valid).not.toBeNull(); + expect(valid!.access_token).toBe(initial.access_token); + expect(valid!.refresh_token).toBe(initial.refresh_token); + + // 4. Wait for expiry + await new Promise((r) => setTimeout(r, 1200)); + + // 5. Retrieve again — should trigger refresh via callback + const refreshCallback = async (refreshToken: string): Promise => { + const refreshRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + + if (!refreshRes.ok) { + throw new Error(`Refresh failed: ${refreshRes.status}`); + } + + const data = (await refreshRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token?: string; + }; + + return { + ...data, + obtained_at: Date.now(), + expires_at: Date.now() + data.expires_in * 1000, + }; + }; + + const refreshed = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(refreshed).not.toBeNull(); + expect(refreshed!.access_token).not.toBe(initial.access_token); + + // 6. Verify the refreshed token works against the server + const mcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${refreshed!.access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + expect(mcpRes.status).toBe(200); + } finally { + await server.close(); + } + }); + }); + + describe('completeOAuthFlow via FlowStateManager', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ issueRefreshTokens: true }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should exchange auth code and complete flow in FlowStateManager', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'test-user:test-server'; + const code = await server.getAuthCode(); + + // Initialize the flow with metadata the handler needs + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverUrl: server.url, + clientInfo: { + client_id: 'test-client', + redirect_uris: ['http://localhost/callback'], + }, + codeVerifier: 'test-verifier', + metadata: { + token_endpoint: `${server.url}token`, + token_endpoint_auth_methods_supported: ['client_secret_post'], + }, + }); + + // The SDK's exchangeAuthorization wants full OAuth metadata, + // so we'll test the token exchange directly instead of going through + // completeOAuthFlow (which requires full SDK-compatible metadata) + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + const tokens = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token?: string; + }; + + const mcpTokens: MCPOAuthTokens = { + ...tokens, + obtained_at: Date.now(), + expires_at: Date.now() + tokens.expires_in * 1000, + }; + + // Complete the flow + const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens); + expect(completed).toBe(true); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('COMPLETED'); + expect(state?.result?.access_token).toBe(tokens.access_token); + }); + + it('should fail flow when authorization code is invalid', async () => { + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: 'grant_type=authorization_code&code=invalid-code', + }); + + expect(tokenRes.status).toBe(400); + const body = (await tokenRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + + it('should fail when authorization code is reused', async () => { + const code = await server.getAuthCode(); + + // First exchange succeeds + const firstRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + expect(firstRes.status).toBe(200); + + // Second exchange fails + const secondRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + expect(secondRes.status).toBe(400); + const body = (await secondRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + }); + + describe('PKCE verification', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + afterEach(async () => { + await server.close(); + }); + + function generatePKCE(): { verifier: string; challenge: string } { + const verifier = 'dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk'; + const challenge = createHash('sha256').update(verifier).digest('base64url'); + return { verifier, challenge }; + } + + it('should accept valid code_verifier matching code_challenge', async () => { + const { verifier, challenge } = generatePKCE(); + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&code_verifier=${verifier}`, + }); + + expect(tokenRes.status).toBe(200); + const data = (await tokenRes.json()) as { access_token: string }; + expect(data.access_token).toBeDefined(); + }); + + it('should reject wrong code_verifier', async () => { + const { challenge } = generatePKCE(); + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&code_verifier=wrong-verifier`, + }); + + expect(tokenRes.status).toBe(400); + const body = (await tokenRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + + it('should reject missing code_verifier when code_challenge was provided', async () => { + const { challenge } = generatePKCE(); + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + expect(tokenRes.status).toBe(400); + const body = (await tokenRes.json()) as { error: string }; + expect(body.error).toBe('invalid_grant'); + }); + + it('should still accept codes without PKCE when no code_challenge was provided', async () => { + const code = await server.getAuthCode(); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + expect(tokenRes.status).toBe(200); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts new file mode 100644 index 0000000000..85febb3ece --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -0,0 +1,516 @@ +/** + * Tests for MCP OAuth race condition fixes: + * + * 1. Connection mutex coalesces concurrent getUserConnection() calls + * 2. PENDING OAuth flows are reused, not deleted + * 3. No-refresh-token expiry throws ReauthenticationRequiredError + * 4. completeFlow recovers when flow state was deleted by a race + * 5. monitorFlow retries once when flow state disappears mid-poll + */ + +import { Keyv } from 'keyv'; +import { logger } from '@librechat/data-schemas'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; +import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { MockKeyv, createOAuthMCPServer } from './helpers/oauthTestServer'; +import { FlowStateManager } from '~/flow/manager'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +jest.mock('~/auth', () => ({ + createSSRFSafeUndiciConnect: jest.fn(() => undefined), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +jest.mock('~/mcp/mcpConfig', () => ({ + mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 }, +})); + +const mockLogger = logger as jest.Mocked; + +describe('MCP OAuth Race Condition Fixes', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Fix 1: Connection mutex coalesces concurrent attempts', () => { + it('should return the same pending promise for concurrent getUserConnection calls', async () => { + const { UserConnectionManager } = await import('~/mcp/UserConnectionManager'); + + class TestManager extends UserConnectionManager { + public createCallCount = 0; + + getPendingConnections() { + return this.pendingConnections; + } + } + + const manager = new TestManager(); + + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn().mockResolvedValue(undefined), + isStale: jest.fn().mockReturnValue(false), + }; + + const mockAppConnections = { has: jest.fn().mockResolvedValue(false) }; + manager.appConnections = mockAppConnections as never; + + const mockConfig = { + type: 'streamable-http', + url: 'http://localhost:9999/', + updatedAt: undefined, + dbId: undefined, + }; + + jest + .spyOn( + // eslint-disable-next-line @typescript-eslint/no-require-imports + require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry, + 'getInstance', + ) + .mockReturnValue({ + getServerConfig: jest.fn().mockResolvedValue(mockConfig), + shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + }); + + const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); + const createSpy = jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => { + manager.createCallCount++; + await new Promise((r) => setTimeout(r, 100)); + return mockConnection as never; + }); + + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + const user = { id: 'user-1' }; + const opts = { + serverName: 'test-server', + user: user as never, + flowManager: flowManager as never, + }; + + const [conn1, conn2, conn3] = await Promise.all([ + manager.getUserConnection(opts), + manager.getUserConnection(opts), + manager.getUserConnection(opts), + ]); + + expect(conn1).toBe(conn2); + expect(conn2).toBe(conn3); + expect(createSpy).toHaveBeenCalledTimes(1); + expect(manager.createCallCount).toBe(1); + + createSpy.mockRestore(); + }); + + it('should not coalesce when forceNew is true', async () => { + const { UserConnectionManager } = await import('~/mcp/UserConnectionManager'); + + class TestManager extends UserConnectionManager {} + + const manager = new TestManager(); + + let callCount = 0; + const makeConnection = () => ({ + isConnected: jest.fn().mockResolvedValue(true), + disconnect: jest.fn().mockResolvedValue(undefined), + isStale: jest.fn().mockReturnValue(false), + }); + + const mockAppConnections = { has: jest.fn().mockResolvedValue(false) }; + manager.appConnections = mockAppConnections as never; + + const mockConfig = { + type: 'streamable-http', + url: 'http://localhost:9999/', + updatedAt: undefined, + dbId: undefined, + }; + + jest + .spyOn( + // eslint-disable-next-line @typescript-eslint/no-require-imports + require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry, + 'getInstance', + ) + .mockReturnValue({ + getServerConfig: jest.fn().mockResolvedValue(mockConfig), + shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + }); + + const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); + jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => { + callCount++; + await new Promise((r) => setTimeout(r, 50)); + return makeConnection() as never; + }); + + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + const user = { id: 'user-2' }; + + const [conn1, conn2] = await Promise.all([ + manager.getUserConnection({ + serverName: 'test-server', + forceNew: true, + user: user as never, + flowManager: flowManager as never, + }), + manager.getUserConnection({ + serverName: 'test-server', + forceNew: true, + user: user as never, + flowManager: flowManager as never, + }), + ]); + + expect(callCount).toBe(2); + expect(conn1).not.toBe(conn2); + }); + }); + + describe('Fix 2: PENDING flow is reused, not deleted', () => { + it('should join an existing PENDING flow via createFlow instead of deleting it', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'test-flow-pending'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + clientInfo: { client_id: 'test-client' }, + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('PENDING'); + + const deleteSpy = jest.spyOn(flowManager, 'deleteFlow'); + + const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {}); + + await new Promise((r) => setTimeout(r, 500)); + + await flowManager.completeFlow(flowId, 'mcp_oauth', { + access_token: 'test-token', + token_type: 'Bearer', + } as never); + + const result = await monitorPromise; + expect(result).toEqual( + expect.objectContaining({ access_token: 'test-token', token_type: 'Bearer' }), + ); + expect(deleteSpy).not.toHaveBeenCalled(); + + deleteSpy.mockRestore(); + }); + + it('should delete and recreate FAILED flows', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'test-flow-failed'; + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + await flowManager.failFlow(flowId, 'mcp_oauth', 'previous error'); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('FAILED'); + + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + + const afterDelete = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(afterDelete).toBeUndefined(); + }); + }); + + describe('Fix 3: completeFlow handles deleted state gracefully', () => { + it('should return false when state was deleted by race', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'race-deleted-flow'; + + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + + const stateBeforeComplete = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(stateBeforeComplete).toBeUndefined(); + + const result = await flowManager.completeFlow(flowId, 'mcp_oauth', { + access_token: 'recovered-token', + token_type: 'Bearer', + } as never); + + expect(result).toBe(false); + + const stateAfterComplete = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(stateAfterComplete).toBeUndefined(); + + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('cannot recover metadata'), + expect.any(Object), + ); + }); + + it('should reject monitorFlow when state is deleted and not recoverable', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true }); + + const flowId = 'monitor-retry-flow'; + + await flowManager.initFlow(flowId, 'mcp_oauth', {}); + + const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {}); + + await new Promise((r) => setTimeout(r, 500)); + + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + + await expect(monitorPromise).rejects.toThrow('Flow state not found'); + }); + }); + + describe('State mapping cleanup on flow replacement', () => { + it('should delete old state mapping when a flow is replaced', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'user1:test-server'; + const oldState = 'old-random-state-abc123'; + const newState = 'new-random-state-xyz789'; + + // Simulate initial flow with state mapping + await flowManager.initFlow(flowId, 'mcp_oauth', { state: oldState }); + await MCPOAuthHandler.storeStateMapping(oldState, flowId, flowManager); + + // Old state should resolve + const resolvedBefore = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager); + expect(resolvedBefore).toBe(flowId); + + // Replace the flow: delete old, create new, clean up old state mapping + await flowManager.deleteFlow(flowId, 'mcp_oauth'); + await MCPOAuthHandler.deleteStateMapping(oldState, flowManager); + await flowManager.initFlow(flowId, 'mcp_oauth', { state: newState }); + await MCPOAuthHandler.storeStateMapping(newState, flowId, flowManager); + + // Old state should no longer resolve + const resolvedOld = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager); + expect(resolvedOld).toBeNull(); + + // New state should resolve + const resolvedNew = await MCPOAuthHandler.resolveStateToFlowId(newState, flowManager); + expect(resolvedNew).toBe(flowId); + }); + }); + + describe('Fix 4: ReauthenticationRequiredError for no-refresh-token', () => { + it('should throw ReauthenticationRequiredError when access token expired and no refresh token', async () => { + const expiredDate = new Date(Date.now() - 60000); + + const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => { + if (filter.type === 'mcp_oauth') { + return { + token: 'enc:expired-access-token', + expiresAt: expiredDate, + createdAt: new Date(Date.now() - 120000), + }; + } + if (filter.type === 'mcp_oauth_refresh') { + return null; + } + return null; + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }), + ).rejects.toThrow('Re-authentication required'); + }); + + it('should throw ReauthenticationRequiredError when access token is missing and no refresh token', async () => { + const findToken = jest.fn().mockResolvedValue(null); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should not throw when access token is valid', async () => { + const futureDate = new Date(Date.now() + 3600000); + + const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => { + if (filter.type === 'mcp_oauth') { + return { + token: 'enc:valid-access-token', + expiresAt: futureDate, + createdAt: new Date(), + }; + } + if (filter.type === 'mcp_oauth_refresh') { + return null; + } + return null; + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'user-1', + serverName: 'test-server', + findToken, + }); + + expect(result).not.toBeNull(); + expect(result?.access_token).toBe('valid-access-token'); + }); + }); + + describe('E2E: OAuth-gated MCP server with no refresh tokens', () => { + let server: OAuthTestServer; + + beforeEach(async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should start OAuth-gated MCP server that validates Bearer tokens', async () => { + const res = await fetch(server.url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'initialize', id: 1 }), + }); + + expect(res.status).toBe(401); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('invalid_token'); + }); + + it('should issue tokens via authorization code exchange with no refresh token', async () => { + const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, { + redirect: 'manual', + }); + + expect(authRes.status).toBe(302); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + expect(code).toBeTruthy(); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + expect(tokenRes.status).toBe(200); + const tokenBody = (await tokenRes.json()) as { + access_token: string; + token_type: string; + refresh_token?: string; + }; + expect(tokenBody.access_token).toBeTruthy(); + expect(tokenBody.token_type).toBe('Bearer'); + expect(tokenBody.refresh_token).toBeUndefined(); + }); + + it('should allow MCP requests with valid Bearer token', async () => { + const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, { + redirect: 'manual', + }); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + const { access_token } = (await tokenRes.json()) as { access_token: string }; + + const mcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + + expect(mcpRes.status).toBe(200); + }); + + it('should reject expired tokens with 401', async () => { + const shortTTLServer = await createOAuthMCPServer({ tokenTTLMs: 500 }); + + try { + const authRes = await fetch( + `${shortTTLServer.url}authorize?redirect_uri=http://localhost&state=s1`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + const tokenRes = await fetch(`${shortTTLServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + + const { access_token } = (await tokenRes.json()) as { access_token: string }; + + await new Promise((r) => setTimeout(r, 600)); + + const mcpRes = await fetch(shortTTLServer.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${access_token}`, + }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 2 }), + }); + + expect(mcpRes.status).toBe(401); + } finally { + await shortTTLServer.close(); + } + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts new file mode 100644 index 0000000000..986ac4c8b4 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts @@ -0,0 +1,654 @@ +/** + * Tests for MCP OAuth token expiry → re-authentication scenarios. + * + * Reproduces the edge case where: + * 1. Tokens are stored (access + refresh) + * 2. Access token expires + * 3. Refresh attempt fails (server rejects/revokes refresh token) + * 4. System must fall back to full OAuth re-auth via handleOAuthRequired + * 5. The CSRF cookie may be absent (chat/SSE flow), so the PENDING flow fallback is needed + * + * Also tests the happy path: access token expired but refresh succeeds. + */ + +import { Keyv } from 'keyv'; +import { logger } from '@librechat/data-schemas'; +import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager'; +import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { MockKeyv, InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +describe('MCP OAuth Token Expiry Scenarios', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('Access token expired + refresh token available + refresh succeeds', () => { + let server: OAuthTestServer; + let tokenStore: InMemoryTokenStore; + + beforeEach(async () => { + server = await createOAuthMCPServer({ + tokenTTLMs: 500, + issueRefreshTokens: true, + }); + tokenStore = new InMemoryTokenStore(); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should refresh expired access token via real /token endpoint', async () => { + // Get initial tokens from real server + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // Store expired access token directly (bypassing storeTokens' expiresIn clamping) + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: `enc:${initial.access_token}`, + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: `enc:${initial.refresh_token}`, + expiresIn: 86400, + }); + + const refreshCallback = async (refreshToken: string): Promise => { + const res = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + if (!res.ok) { + throw new Error(`Refresh failed: ${res.status}`); + } + const data = (await res.json()) as { + access_token: string; + token_type: string; + expires_in: number; + }; + return { + ...data, + obtained_at: Date.now(), + expires_at: Date.now() + data.expires_in * 1000, + }; + }; + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(result).not.toBeNull(); + expect(result!.access_token).not.toBe(initial.access_token); + + // Verify the refreshed token works against the server + const mcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${result!.access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + expect(mcpRes.status).toBe(200); + }); + }); + + describe('Access token expired + refresh token rejected by server', () => { + let tokenStore: InMemoryTokenStore; + + beforeEach(() => { + tokenStore = new InMemoryTokenStore(); + }); + + it('should return null when refresh token is rejected (invalid_grant)', async () => { + const server = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + + try { + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // Store expired access token directly + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: `enc:${initial.access_token}`, + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: `enc:${initial.refresh_token}`, + expiresIn: 86400, + }); + + // Simulate server revoking the refresh token + server.issuedRefreshTokens.clear(); + + const refreshCallback = async (refreshToken: string): Promise => { + const res = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + if (!res.ok) { + const body = (await res.json()) as { error: string }; + throw new Error(`Token refresh failed: ${body.error}`); + } + const data = (await res.json()) as MCPOAuthTokens; + return data; + }; + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(result).toBeNull(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to refresh tokens'), + expect.any(Error), + ); + } finally { + await server.close(); + } + }); + + it('should return null when refresh endpoint returns unauthorized_client', async () => { + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: 'enc:some-refresh-token', + expiresIn: 86400, + }); + + const refreshCallback = jest + .fn() + .mockRejectedValue(new Error('unauthorized_client: client not authorized for refresh')); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + expect(result).toBeNull(); + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining('does not support refresh tokens'), + ); + }); + }); + + describe('Access token expired + NO refresh token → ReauthenticationRequiredError', () => { + let tokenStore: InMemoryTokenStore; + + beforeEach(() => { + tokenStore = new InMemoryTokenStore(); + }); + + it('should throw ReauthenticationRequiredError when no refresh token stored', async () => { + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should throw ReauthenticationRequiredError with correct reason for expired token', async () => { + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }), + ).rejects.toThrow('access token expired'); + }); + + it('should throw ReauthenticationRequiredError with correct reason for missing token', async () => { + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }), + ).rejects.toThrow('access token missing'); + }); + }); + + describe('PENDING flow fallback for CSRF-less OAuth callbacks', () => { + it('should allow OAuth completion when PENDING flow exists (simulating chat/SSE path)', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'user1:test-server'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + userId: 'user1', + serverUrl: 'https://example.com', + state: 'test-state', + authorizationUrl: 'https://example.com/authorize?state=user1:test-server', + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(state?.status).toBe('PENDING'); + + const tokens: MCPOAuthTokens = { + access_token: 'new-access-token', + token_type: 'Bearer', + refresh_token: 'new-refresh-token', + obtained_at: Date.now(), + expires_at: Date.now() + 3600000, + }; + + const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', tokens); + expect(completed).toBe(true); + + const completedState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(completedState?.status).toBe('COMPLETED'); + expect((completedState?.result as MCPOAuthTokens | undefined)?.access_token).toBe( + 'new-access-token', + ); + }); + + it('should store authorizationUrl in flow metadata for re-issuance', async () => { + const store = new MockKeyv(); + const flowManager = new FlowStateManager(store as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const flowId = 'user1:test-server'; + const authUrl = 'https://auth.example.com/authorize?client_id=abc&state=user1:test-server'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + userId: 'user1', + serverUrl: 'https://example.com', + state: 'test-state', + authorizationUrl: authUrl, + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect((state?.metadata as Record)?.authorizationUrl).toBe(authUrl); + }); + }); + + describe('Full token expiry → refresh failure → re-auth flow', () => { + let server: OAuthTestServer; + let tokenStore: InMemoryTokenStore; + + beforeEach(async () => { + server = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + tokenStore = new InMemoryTokenStore(); + }); + + afterEach(async () => { + await server.close(); + }); + + it('should go through full cycle: get tokens → expire → refresh fails → re-auth needed', async () => { + // Step 1: Get initial tokens + const code = await server.getAuthCode(); + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token: string; + }; + + // Step 2: Store tokens with valid expiry first + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'test-srv', + tokens: initial, + createToken: tokenStore.createToken, + }); + + // Step 3: Verify tokens work + const validResult = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }); + expect(validResult).not.toBeNull(); + expect(validResult!.access_token).toBe(initial.access_token); + + // Step 4: Simulate token expiry by directly updating the stored token's expiresAt + await tokenStore.updateToken({ userId: 'u1', identifier: 'mcp:test-srv' }, { expiresIn: -1 }); + + // Step 5: Revoke refresh token on server side (simulating server-side revocation) + server.issuedRefreshTokens.clear(); + + // Step 6: Try to get tokens — refresh should fail, return null + const refreshCallback = async (refreshToken: string): Promise => { + const res = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=refresh_token&refresh_token=${refreshToken}`, + }); + if (!res.ok) { + const body = (await res.json()) as { error: string }; + throw new Error(`Refresh failed: ${body.error}`); + } + const data = (await res.json()) as MCPOAuthTokens; + return data; + }; + + const expiredResult = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + + // Refresh failed → returns null → triggers OAuth re-auth flow + expect(expiredResult).toBeNull(); + + // Step 7: Simulate the re-auth flow via FlowStateManager + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + const flowId = 'u1:test-srv'; + + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-srv', + userId: 'u1', + serverUrl: server.url, + state: 'test-state', + authorizationUrl: `${server.url}authorize?state=${flowId}`, + }); + + // Step 8: Get a new auth code and exchange for tokens (simulating user re-auth) + const newCode = await server.getAuthCode(); + const newTokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${newCode}`, + }); + const newTokens = (await newTokenRes.json()) as { + access_token: string; + token_type: string; + expires_in: number; + refresh_token?: string; + }; + + // Step 9: Complete the flow + const mcpTokens: MCPOAuthTokens = { + ...newTokens, + obtained_at: Date.now(), + expires_at: Date.now() + newTokens.expires_in * 1000, + }; + await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens); + + // Step 10: Store the new tokens + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'test-srv', + tokens: mcpTokens, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + findToken: tokenStore.findToken, + }); + + // Step 11: Verify new tokens work + const newResult = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + }); + expect(newResult).not.toBeNull(); + expect(newResult!.access_token).toBe(newTokens.access_token); + + // Step 12: Verify new token works against server + const finalMcpRes = await fetch(server.url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Authorization: `Bearer ${newResult!.access_token}`, + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + id: 1, + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'test', version: '0.0.1' }, + }, + }), + }); + expect(finalMcpRes.status).toBe(200); + }); + }); + + describe('Concurrent token expiry with connection mutex', () => { + it('should handle multiple concurrent getTokens calls when token is expired', async () => { + const tokenStore = new InMemoryTokenStore(); + + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:test-srv', + token: 'enc:expired-token', + expiresIn: -1, + }); + await tokenStore.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:test-srv:refresh', + token: 'enc:valid-refresh', + expiresIn: 86400, + }); + + let refreshCallCount = 0; + const refreshCallback = jest.fn().mockImplementation(async () => { + refreshCallCount++; + await new Promise((r) => setTimeout(r, 100)); + return { + access_token: `refreshed-token-${refreshCallCount}`, + token_type: 'Bearer', + expires_in: 3600, + obtained_at: Date.now(), + expires_at: Date.now() + 3600000, + }; + }); + + // Fire 3 concurrent getTokens calls via FlowStateManager (like the connection mutex does) + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 30000, + ci: true, + }); + + const getTokensViaFlow = () => + flowManager.createFlowWithHandler('u1:test-srv', 'mcp_get_tokens', async () => { + return await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'test-srv', + findToken: tokenStore.findToken, + createToken: tokenStore.createToken, + updateToken: tokenStore.updateToken, + refreshTokens: refreshCallback, + }); + }); + + const [r1, r2, r3] = await Promise.all([ + getTokensViaFlow(), + getTokensViaFlow(), + getTokensViaFlow(), + ]); + + // All should get tokens (either directly or via flow coalescing) + expect(r1).not.toBeNull(); + expect(r2).not.toBeNull(); + expect(r3).not.toBeNull(); + + // The refresh callback should only be called once due to flow coalescing + expect(refreshCallback).toHaveBeenCalledTimes(1); + }); + }); + + describe('Stale PENDING flow detection', () => { + it('should treat PENDING flows older than 2 minutes as stale', async () => { + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 300000, + ci: true, + }); + + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + authorizationUrl: 'https://example.com/auth', + }); + + // Manually age the flow to 3 minutes + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (state) { + state.createdAt = Date.now() - 3 * 60 * 1000; + await (flowStore as unknown as { set: (k: string, v: unknown) => Promise }).set( + `mcp_oauth:${flowId}`, + state, + ); + } + + const agedState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + expect(agedState?.status).toBe('PENDING'); + + const age = agedState?.createdAt ? Date.now() - agedState.createdAt : 0; + expect(age).toBeGreaterThan(2 * 60 * 1000); + + // A new flow should be created (the stale one would be deleted + recreated) + // This verifies our staleness check threshold + expect(age > PENDING_STALE_MS).toBe(true); + }); + + it('should not treat recent PENDING flows as stale', async () => { + const flowStore = new MockKeyv(); + const flowManager = new FlowStateManager(flowStore as unknown as Keyv, { + ttl: 300000, + ci: true, + }); + + const flowId = 'user1:test-server'; + await flowManager.initFlow(flowId, 'mcp_oauth', { + serverName: 'test-server', + authorizationUrl: 'https://example.com/auth', + }); + + const state = await flowManager.getFlowState(flowId, 'mcp_oauth'); + const age = state?.createdAt ? Date.now() - state.createdAt : Infinity; + + expect(age < PENDING_STALE_MS).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts new file mode 100644 index 0000000000..3805586453 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts @@ -0,0 +1,544 @@ +/** + * Integration tests for MCPTokenStorage.storeTokens() and MCPTokenStorage.getTokens(). + * + * Uses InMemoryTokenStore to exercise encrypt/decrypt round-trips, expiry calculation, + * refresh callback wiring, and ReauthenticationRequiredError paths. + */ + +import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth'; +import { InMemoryTokenStore } from './helpers/oauthTestServer'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +describe('MCPTokenStorage', () => { + let store: InMemoryTokenStore; + + beforeEach(() => { + store = new InMemoryTokenStore(); + jest.clearAllMocks(); + }); + + describe('storeTokens', () => { + it('should create new access token with expires_in', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved).not.toBeNull(); + expect(saved!.token).toBe('enc:at1'); + const expiresInMs = saved!.expiresAt.getTime() - Date.now(); + expect(expiresInMs).toBeGreaterThan(3500 * 1000); + expect(expiresInMs).toBeLessThanOrEqual(3600 * 1000); + }); + + it('should create new access token with expires_at (MCPOAuthTokens format)', async () => { + const expiresAt = Date.now() + 7200 * 1000; + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'at1', + token_type: 'Bearer', + expires_at: expiresAt, + obtained_at: Date.now(), + }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved).not.toBeNull(); + const diff = Math.abs(saved!.expiresAt.getTime() - expiresAt); + expect(diff).toBeLessThan(2000); + }); + + it('should default to 1-year expiry when none provided', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer' }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + const oneYearMs = 365 * 24 * 60 * 60 * 1000; + const expiresInMs = saved!.expiresAt.getTime() - Date.now(); + expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000); + }); + + it('should update existing access token', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:old-token', + expiresIn: 3600, + }); + + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'new-token', token_type: 'Bearer', expires_in: 7200 }, + createToken: store.createToken, + updateToken: store.updateToken, + findToken: store.findToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved!.token).toBe('enc:new-token'); + }); + + it('should store refresh token alongside access token', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'at1', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'rt1', + }, + createToken: store.createToken, + }); + + const refreshSaved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + }); + expect(refreshSaved).not.toBeNull(); + expect(refreshSaved!.token).toBe('enc:rt1'); + }); + + it('should skip refresh token when not in response', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + }); + + const refreshSaved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + }); + expect(refreshSaved).toBeNull(); + }); + + it('should store client info when provided', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] }, + }); + + const clientSaved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth_client', + identifier: 'mcp:srv1:client', + }); + expect(clientSaved).not.toBeNull(); + expect(clientSaved!.token).toContain('enc:'); + expect(clientSaved!.token).toContain('cid'); + }); + + it('should use existingTokens to skip DB lookups', async () => { + const findSpy = jest.fn(); + + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 }, + createToken: store.createToken, + updateToken: store.updateToken, + findToken: findSpy, + existingTokens: { + accessToken: null, + refreshToken: null, + clientInfoToken: null, + }, + }); + + expect(findSpy).not.toHaveBeenCalled(); + }); + + it('should handle invalid NaN expiry date', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'at1', + token_type: 'Bearer', + expires_at: NaN, + obtained_at: Date.now(), + }, + createToken: store.createToken, + }); + + const saved = await store.findToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + }); + expect(saved).not.toBeNull(); + const oneYearMs = 365 * 24 * 60 * 60 * 1000; + const expiresInMs = saved!.expiresAt.getTime() - Date.now(); + expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000); + }); + }); + + describe('getTokens', () => { + it('should return valid non-expired tokens', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:valid-token', + expiresIn: 3600, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result).not.toBeNull(); + expect(result!.access_token).toBe('valid-token'); + expect(result!.token_type).toBe('Bearer'); + }); + + it('should return tokens with refresh token when available', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:at', + expiresIn: 3600, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result!.refresh_token).toBe('rt'); + }); + + it('should return tokens without refresh token field when none stored', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:at', + expiresIn: 3600, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result!.refresh_token).toBeUndefined(); + }); + + it('should throw ReauthenticationRequiredError when expired and no refresh', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should throw ReauthenticationRequiredError when missing and no refresh', async () => { + await expect( + MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }), + ).rejects.toThrow(ReauthenticationRequiredError); + }); + + it('should refresh expired access token when refresh token and callback are available', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockResolvedValue({ + access_token: 'refreshed-at', + token_type: 'Bearer', + expires_in: 3600, + obtained_at: Date.now(), + expires_at: Date.now() + 3600000, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + updateToken: store.updateToken, + refreshTokens, + }); + + expect(result).not.toBeNull(); + expect(result!.access_token).toBe('refreshed-at'); + expect(refreshTokens).toHaveBeenCalledWith( + 'rt', + expect.objectContaining({ userId: 'u1', serverName: 'srv1' }), + ); + }); + + it('should return null when refresh fails', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockRejectedValue(new Error('refresh failed')); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + updateToken: store.updateToken, + refreshTokens, + }); + + expect(result).toBeNull(); + }); + + it('should return null when no refreshTokens callback provided', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result).toBeNull(); + }); + + it('should return null when no createToken callback provided', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + refreshTokens: jest.fn(), + }); + + expect(result).toBeNull(); + }); + + it('should pass client info to refreshTokens metadata', async () => { + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_client', + identifier: 'mcp:srv1:client', + token: 'enc:{"client_id":"cid","client_secret":"csec"}', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockResolvedValue({ + access_token: 'new-at', + token_type: 'Bearer', + expires_in: 3600, + }); + + await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + updateToken: store.updateToken, + refreshTokens, + }); + + expect(refreshTokens).toHaveBeenCalledWith( + 'rt', + expect.objectContaining({ + clientInfo: expect.objectContaining({ client_id: 'cid' }), + }), + ); + }); + + it('should handle unauthorized_client refresh error', async () => { + const { logger } = await import('@librechat/data-schemas'); + + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth', + identifier: 'mcp:srv1', + token: 'enc:expired-token', + expiresIn: -1, + }); + await store.createToken({ + userId: 'u1', + type: 'mcp_oauth_refresh', + identifier: 'mcp:srv1:refresh', + token: 'enc:rt', + expiresIn: 86400, + }); + + const refreshTokens = jest.fn().mockRejectedValue(new Error('unauthorized_client')); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + createToken: store.createToken, + refreshTokens, + }); + + expect(result).toBeNull(); + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining('does not support refresh tokens'), + ); + }); + }); + + describe('storeTokens + getTokens round-trip', () => { + it('should store and retrieve tokens with full encrypt/decrypt cycle', async () => { + await MCPTokenStorage.storeTokens({ + userId: 'u1', + serverName: 'srv1', + tokens: { + access_token: 'my-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'my-refresh-token', + }, + createToken: store.createToken, + clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] }, + }); + + const result = await MCPTokenStorage.getTokens({ + userId: 'u1', + serverName: 'srv1', + findToken: store.findToken, + }); + + expect(result!.access_token).toBe('my-access-token'); + expect(result!.refresh_token).toBe('my-refresh-token'); + expect(result!.token_type).toBe('Bearer'); + expect(result!.obtained_at).toBeDefined(); + expect(result!.expires_at).toBeDefined(); + }); + }); +}); diff --git a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts new file mode 100644 index 0000000000..3b68b2ded4 --- /dev/null +++ b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts @@ -0,0 +1,449 @@ +import * as http from 'http'; +import * as net from 'net'; +import { randomUUID, createHash } from 'crypto'; +import { z } from 'zod'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import type { FlowState } from '~/flow/types'; +import type { Socket } from 'net'; + +export class MockKeyv { + private store: Map>; + + constructor() { + this.store = new Map(); + } + + async get(key: string): Promise | undefined> { + return this.store.get(key); + } + + async set(key: string, value: FlowState, _ttl?: number): Promise { + this.store.set(key, value); + return true; + } + + async delete(key: string): Promise { + return this.store.delete(key); + } +} + +export function getFreePort(): Promise { + return new Promise((resolve, reject) => { + const srv = net.createServer(); + srv.listen(0, '127.0.0.1', () => { + const addr = srv.address() as net.AddressInfo; + srv.close((err) => (err ? reject(err) : resolve(addr.port))); + }); + }); +} + +export function trackSockets(httpServer: http.Server): () => Promise { + const sockets = new Set(); + httpServer.on('connection', (socket: Socket) => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + return () => + new Promise((resolve) => { + for (const socket of sockets) { + socket.destroy(); + } + sockets.clear(); + httpServer.close(() => resolve()); + }); +} + +export interface OAuthTestServerOptions { + tokenTTLMs?: number; + issueRefreshTokens?: boolean; + refreshTokenTTLMs?: number; + rotateRefreshTokens?: boolean; +} + +export interface OAuthTestServer { + url: string; + port: number; + close: () => Promise; + issuedTokens: Set; + tokenTTL: number; + tokenIssueTimes: Map; + issuedRefreshTokens: Map; + registeredClients: Map; + getAuthCode: () => Promise; +} + +async function readRequestBody(req: http.IncomingMessage): Promise { + const chunks: Uint8Array[] = []; + for await (const chunk of req) { + chunks.push(chunk as Uint8Array); + } + return Buffer.concat(chunks).toString(); +} + +function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null { + if (contentType?.includes('application/x-www-form-urlencoded')) { + return new URLSearchParams(body); + } + if (contentType?.includes('application/json')) { + const json = JSON.parse(body) as Record; + return new URLSearchParams(json); + } + return new URLSearchParams(body); +} + +export async function createOAuthMCPServer( + options: OAuthTestServerOptions = {}, +): Promise { + const { + tokenTTLMs = 60000, + issueRefreshTokens = false, + refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000, + rotateRefreshTokens = false, + } = options; + + const sessions = new Map(); + const issuedTokens = new Set(); + const tokenIssueTimes = new Map(); + const issuedRefreshTokens = new Map(); + const refreshTokenIssueTimes = new Map(); + const authCodes = new Map(); + const registeredClients = new Map(); + + let port = 0; + + const httpServer = http.createServer(async (req, res) => { + const url = new URL(req.url ?? '/', `http://${req.headers.host}`); + + if (url.pathname === '/.well-known/oauth-authorization-server' && req.method === 'GET') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + issuer: `http://127.0.0.1:${port}`, + authorization_endpoint: `http://127.0.0.1:${port}/authorize`, + token_endpoint: `http://127.0.0.1:${port}/token`, + registration_endpoint: `http://127.0.0.1:${port}/register`, + response_types_supported: ['code'], + grant_types_supported: issueRefreshTokens + ? ['authorization_code', 'refresh_token'] + : ['authorization_code'], + token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post'], + code_challenge_methods_supported: ['S256'], + }), + ); + return; + } + + if (url.pathname === '/register' && req.method === 'POST') { + const body = await readRequestBody(req); + const data = JSON.parse(body) as { redirect_uris?: string[] }; + const clientId = `client-${randomUUID().slice(0, 8)}`; + const clientSecret = `secret-${randomUUID()}`; + registeredClients.set(clientId, { client_id: clientId, client_secret: clientSecret }); + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + client_id: clientId, + client_secret: clientSecret, + redirect_uris: data.redirect_uris ?? [], + }), + ); + return; + } + + if (url.pathname === '/authorize') { + const code = randomUUID(); + const codeChallenge = url.searchParams.get('code_challenge') ?? undefined; + const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined; + authCodes.set(code, { codeChallenge, codeChallengeMethod }); + const redirectUri = url.searchParams.get('redirect_uri') ?? ''; + const state = url.searchParams.get('state') ?? ''; + res.writeHead(302, { + Location: `${redirectUri}?code=${code}&state=${state}`, + }); + res.end(); + return; + } + + if (url.pathname === '/token' && req.method === 'POST') { + const body = await readRequestBody(req); + const params = parseTokenRequest(body, req.headers['content-type']); + if (!params) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_request' })); + return; + } + + const grantType = params.get('grant_type'); + + if (grantType === 'authorization_code') { + const code = params.get('code'); + const codeData = code ? authCodes.get(code) : undefined; + if (!code || !codeData) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + + if (codeData.codeChallenge) { + const codeVerifier = params.get('code_verifier'); + if (!codeVerifier) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + if (codeData.codeChallengeMethod === 'S256') { + const expected = createHash('sha256').update(codeVerifier).digest('base64url'); + if (expected !== codeData.codeChallenge) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + } + } + + authCodes.delete(code); + + const accessToken = randomUUID(); + issuedTokens.add(accessToken); + tokenIssueTimes.set(accessToken, Date.now()); + + const tokenResponse: Record = { + access_token: accessToken, + token_type: 'Bearer', + expires_in: Math.ceil(tokenTTLMs / 1000), + }; + + if (issueRefreshTokens) { + const refreshToken = randomUUID(); + issuedRefreshTokens.set(refreshToken, accessToken); + refreshTokenIssueTimes.set(refreshToken, Date.now()); + tokenResponse.refresh_token = refreshToken; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify(tokenResponse)); + return; + } + + if (grantType === 'refresh_token' && issueRefreshTokens) { + const refreshToken = params.get('refresh_token'); + if (!refreshToken || !issuedRefreshTokens.has(refreshToken)) { + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + + const issueTime = refreshTokenIssueTimes.get(refreshToken) ?? 0; + if (Date.now() - issueTime > refreshTokenTTLMs) { + issuedRefreshTokens.delete(refreshToken); + refreshTokenIssueTimes.delete(refreshToken); + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_grant' })); + return; + } + + const newAccessToken = randomUUID(); + issuedTokens.add(newAccessToken); + tokenIssueTimes.set(newAccessToken, Date.now()); + + const tokenResponse: Record = { + access_token: newAccessToken, + token_type: 'Bearer', + expires_in: Math.ceil(tokenTTLMs / 1000), + }; + + if (rotateRefreshTokens) { + issuedRefreshTokens.delete(refreshToken); + refreshTokenIssueTimes.delete(refreshToken); + const newRefreshToken = randomUUID(); + issuedRefreshTokens.set(newRefreshToken, newAccessToken); + refreshTokenIssueTimes.set(newRefreshToken, Date.now()); + tokenResponse.refresh_token = newRefreshToken; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify(tokenResponse)); + return; + } + + res.writeHead(400, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'unsupported_grant_type' })); + return; + } + + // All other paths require Bearer token auth + const authHeader = req.headers.authorization; + if (!authHeader || !authHeader.startsWith('Bearer ')) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_token' })); + return; + } + + const token = authHeader.slice(7); + if (!issuedTokens.has(token)) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_token' })); + return; + } + + const issueTime = tokenIssueTimes.get(token) ?? 0; + if (Date.now() - issueTime > tokenTTLMs) { + issuedTokens.delete(token); + tokenIssueTimes.delete(token); + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_token' })); + return; + } + + // Authenticated MCP request — route to transport + const sid = req.headers['mcp-session-id'] as string | undefined; + let transport = sid ? sessions.get(sid) : undefined; + + if (!transport) { + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + }); + const mcp = new McpServer({ name: 'oauth-test-server', version: '0.0.1' }); + mcp.tool('echo', { message: z.string() }, async (args) => ({ + content: [{ type: 'text' as const, text: `echo: ${args.message}` }], + })); + await mcp.connect(transport); + } + + await transport.handleRequest(req, res); + + if (transport.sessionId && !sessions.has(transport.sessionId)) { + sessions.set(transport.sessionId, transport); + transport.onclose = () => sessions.delete(transport!.sessionId!); + } + }); + + const destroySockets = trackSockets(httpServer); + port = await getFreePort(); + await new Promise((resolve) => httpServer.listen(port, '127.0.0.1', resolve)); + + return { + url: `http://127.0.0.1:${port}/`, + port, + issuedTokens, + tokenTTL: tokenTTLMs, + tokenIssueTimes, + issuedRefreshTokens, + registeredClients, + getAuthCode: async () => { + const authRes = await fetch( + `http://127.0.0.1:${port}/authorize?redirect_uri=http://localhost&state=test`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + return new URL(location).searchParams.get('code') ?? ''; + }, + close: async () => { + const closing = [...sessions.values()].map((t) => t.close().catch(() => undefined)); + sessions.clear(); + await Promise.all(closing); + await destroySockets(); + }, + }; +} + +export interface InMemoryToken { + userId: string; + type: string; + identifier: string; + token: string; + expiresAt: Date; + createdAt: Date; + metadata?: Map | Record; +} + +export class InMemoryTokenStore { + private tokens: Map = new Map(); + + private key(filter: { userId?: string; type?: string; identifier?: string }): string { + return `${filter.userId}:${filter.type}:${filter.identifier}`; + } + + findToken = async (filter: { + userId?: string; + type?: string; + identifier?: string; + }): Promise => { + for (const token of this.tokens.values()) { + const matchUserId = !filter.userId || token.userId === filter.userId; + const matchType = !filter.type || token.type === filter.type; + const matchIdentifier = !filter.identifier || token.identifier === filter.identifier; + if (matchUserId && matchType && matchIdentifier) { + return token; + } + } + return null; + }; + + createToken = async (data: { + userId: string; + type: string; + identifier: string; + token: string; + expiresIn?: number; + metadata?: Record; + }): Promise => { + const expiresIn = data.expiresIn ?? 365 * 24 * 60 * 60; + const token: InMemoryToken = { + userId: data.userId, + type: data.type, + identifier: data.identifier, + token: data.token, + expiresAt: new Date(Date.now() + expiresIn * 1000), + createdAt: new Date(), + metadata: data.metadata, + }; + this.tokens.set(this.key(data), token); + return token; + }; + + updateToken = async ( + filter: { userId?: string; type?: string; identifier?: string }, + data: { + userId?: string; + type?: string; + identifier?: string; + token?: string; + expiresIn?: number; + metadata?: Record; + }, + ): Promise => { + const existing = await this.findToken(filter); + if (!existing) { + throw new Error(`Token not found for filter: ${JSON.stringify(filter)}`); + } + const existingKey = this.key(existing); + const expiresIn = + data.expiresIn ?? Math.floor((existing.expiresAt.getTime() - Date.now()) / 1000); + const updated: InMemoryToken = { + ...existing, + token: data.token ?? existing.token, + expiresAt: data.expiresIn ? new Date(Date.now() + expiresIn * 1000) : existing.expiresAt, + metadata: data.metadata ?? existing.metadata, + }; + this.tokens.set(existingKey, updated); + return updated; + }; + + deleteToken = async (filter: { + userId: string; + type: string; + identifier: string; + }): Promise => { + this.tokens.delete(this.key(filter)); + }; + + getAll(): InMemoryToken[] { + return [...this.tokens.values()]; + } + + clear(): void { + this.tokens.clear(); + } +} diff --git a/packages/api/src/mcp/__tests__/reconnection-storm.test.ts b/packages/api/src/mcp/__tests__/reconnection-storm.test.ts index c1cf0ec5df..e073dca8a3 100644 --- a/packages/api/src/mcp/__tests__/reconnection-storm.test.ts +++ b/packages/api/src/mcp/__tests__/reconnection-storm.test.ts @@ -12,8 +12,12 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import type { Socket } from 'net'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import type { MCPOAuthTokens } from '~/mcp/oauth'; import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; import { MCPConnection } from '~/mcp/connection'; +import { mcpConfig } from '~/mcp/mcpConfig'; jest.mock('@librechat/data-schemas', () => ({ logger: { @@ -143,16 +147,17 @@ afterEach(() => { /* ------------------------------------------------------------------ */ /* Fix #2 — Circuit breaker trips after rapid connect/disconnect */ -/* cycles (5 cycles within 60s -> 30s cooldown) */ +/* cycles (CB_MAX_CYCLES within window -> cooldown) */ /* ------------------------------------------------------------------ */ describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => { - it('blocks connection after 5 rapid cycles via static circuit breaker', async () => { + it('blocks connection after CB_MAX_CYCLES rapid cycles via static circuit breaker', async () => { const srv = await startMCPServer(); const conn = createConnection('cycling-server', srv.url); let completedCycles = 0; let breakerMessage = ''; - for (let cycle = 0; cycle < 10; cycle++) { + const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2; + for (let cycle = 0; cycle < maxAttempts; cycle++) { try { await conn.connect(); await teardownConnection(conn); @@ -166,7 +171,7 @@ describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => { } expect(breakerMessage).toContain('Circuit breaker is open'); - expect(completedCycles).toBeLessThanOrEqual(5); + expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES); await srv.close(); }); @@ -266,12 +271,13 @@ describe('Fix #4: Circuit breaker state persists across instance replacement', ( /* recordFailedRound() in the catch path */ /* ------------------------------------------------------------------ */ describe('Fix #5: Dead server triggers circuit breaker', () => { - it('3 failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => { + it('failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => { const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000); const spy = jest.spyOn(conn.client, 'connect'); + const totalAttempts = mcpConfig.CB_MAX_FAILED_ROUNDS + 2; const errors: string[] = []; - for (let i = 0; i < 5; i++) { + for (let i = 0; i < totalAttempts; i++) { try { await conn.connect(); } catch (e) { @@ -279,8 +285,8 @@ describe('Fix #5: Dead server triggers circuit breaker', () => { } } - expect(spy.mock.calls.length).toBe(3); - expect(errors).toHaveLength(5); + expect(spy.mock.calls.length).toBe(mcpConfig.CB_MAX_FAILED_ROUNDS); + expect(errors).toHaveLength(totalAttempts); expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2); await conn.disconnect(); @@ -295,7 +301,7 @@ describe('Fix #5: Dead server triggers circuit breaker', () => { userId: 'user-A', }); - for (let i = 0; i < 3; i++) { + for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) { try { await userA.connect(); } catch { @@ -332,7 +338,7 @@ describe('Fix #5: Dead server triggers circuit breaker', () => { serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never, userId: 'user-A', }); - for (let i = 0; i < 3; i++) { + for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) { try { await userA.connect(); } catch { @@ -448,13 +454,14 @@ describe('Fix #6: OAuth failure uses cooldown-based retry', () => { /* Integration: Circuit breaker caps rapid cycling with real transport */ /* ------------------------------------------------------------------ */ describe('Cascade: Circuit breaker caps rapid cycling', () => { - it('breaker trips before 10 cycles complete against a live server', async () => { + it('breaker trips before double CB_MAX_CYCLES complete against a live server', async () => { const srv = await startMCPServer(); const conn = createConnection('cascade', srv.url); const spy = jest.spyOn(conn.client, 'connect'); let completedCycles = 0; - for (let i = 0; i < 10; i++) { + const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2; + for (let i = 0; i < maxAttempts; i++) { try { await conn.connect(); await teardownConnection(conn); @@ -469,8 +476,8 @@ describe('Cascade: Circuit breaker caps rapid cycling', () => { } } - expect(completedCycles).toBeLessThanOrEqual(5); - expect(spy.mock.calls.length).toBeLessThanOrEqual(5); + expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES); + expect(spy.mock.calls.length).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES); await srv.close(); }); @@ -501,6 +508,146 @@ describe('Cascade: Circuit breaker caps rapid cycling', () => { }, 30_000); }); +/* ------------------------------------------------------------------ */ +/* OAuth: cycle recovery after successful OAuth reconnect */ +/* ------------------------------------------------------------------ */ +describe('OAuth: cycle budget recovery after successful OAuth', () => { + let oauthServer: OAuthTestServer; + + beforeEach(async () => { + oauthServer = await createOAuthMCPServer({ tokenTTLMs: 60000 }); + }); + + afterEach(async () => { + await oauthServer.close(); + }); + + async function exchangeCodeForToken(serverUrl: string): Promise { + const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, { + redirect: 'manual', + }); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code') ?? ''; + const tokenRes = await fetch(`${serverUrl}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const data = (await tokenRes.json()) as { access_token: string }; + return data.access_token; + } + + it('should decrement cycle count after successful OAuth recovery', async () => { + const serverName = 'oauth-cycle-test'; + MCPConnection.clearCooldown(serverName); + + const conn = new MCPConnection({ + serverName, + serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 }, + userId: 'user-1', + }); + + // When oauthRequired fires, get a token and emit oauthHandled + // This triggers the oauthRecovery path inside connectClient + conn.on('oauthRequired', async () => { + const accessToken = await exchangeCodeForToken(oauthServer.url); + conn.setOAuthTokens({ + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + conn.emit('oauthHandled'); + }); + + // connect() → 401 → oauthRequired → oauthHandled → connectClient returns + // connect() sees not connected → throws "Connection not established" + await expect(conn.connect()).rejects.toThrow('Connection not established'); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cb = (MCPConnection as any).circuitBreakers.get(serverName); + const cyclesBeforeRetry = cb.cycleCount; + + // Retry — should succeed and decrement cycle count via oauthRecovery + await conn.connect(); + expect(await conn.isConnected()).toBe(true); + + const cyclesAfterSuccess = cb.cycleCount; + // The retry adds +1 cycle (disconnect(false)) then -1 (oauthRecovery decrement) + // So cyclesAfterSuccess should equal cyclesBeforeRetry, not cyclesBeforeRetry + 1 + expect(cyclesAfterSuccess).toBe(cyclesBeforeRetry); + + await teardownConnection(conn); + }); + + it('should allow more OAuth reconnects than non-OAuth before breaker trips', async () => { + const serverName = 'oauth-budget'; + MCPConnection.clearCooldown(serverName); + + // Each OAuth flow: connect (+1) → 401 → oauthHandled → retry connect (+1) → success (-1) = net 1 + // Without the decrement it would be net 2 per flow, tripping the breaker after ~2 users + let successfulFlows = 0; + for (let i = 0; i < 10; i++) { + const conn = new MCPConnection({ + serverName, + serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 }, + userId: `user-${i}`, + }); + + conn.on('oauthRequired', async () => { + const accessToken = await exchangeCodeForToken(oauthServer.url); + conn.setOAuthTokens({ + access_token: accessToken, + token_type: 'Bearer', + } as MCPOAuthTokens); + conn.emit('oauthHandled'); + }); + + try { + // First connect: 401 → oauthHandled → returns without connection + await conn.connect().catch(() => {}); + // Retry: succeeds with token, decrements cycle + await conn.connect(); + successfulFlows++; + await teardownConnection(conn); + } catch (e) { + conn.removeAllListeners(); + if ((e as Error).message.includes('Circuit breaker is open')) { + break; + } + } + } + + // With the oauthRecovery decrement, each flow is net ~1 cycle instead of ~2, + // so we should get more successful flows before the breaker trips + expect(successfulFlows).toBeGreaterThanOrEqual(3); + }); + + it('should not decrement cycle count when OAuth fails', async () => { + const serverName = 'oauth-failed-no-decrement'; + MCPConnection.clearCooldown(serverName); + + const conn = new MCPConnection({ + serverName, + serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 }, + userId: 'user-1', + }); + + conn.on('oauthRequired', () => { + conn.emit('oauthFailed', new Error('user denied')); + }); + + await expect(conn.connect()).rejects.toThrow(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cb = (MCPConnection as any).circuitBreakers.get(serverName); + const cyclesAfterFailure = cb.cycleCount; + + // connect() recorded +1 cycle, oauthFailed should NOT decrement + expect(cyclesAfterFailure).toBeGreaterThanOrEqual(1); + + conn.removeAllListeners(); + }); +}); + /* ------------------------------------------------------------------ */ /* Sanity: Real transport works end-to-end */ /* ------------------------------------------------------------------ */ diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index cac0a4afc5..8dc857cd3b 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -82,14 +82,6 @@ interface CircuitBreakerState { failedBackoffUntil: number; } -const CB_MAX_CYCLES = 5; -const CB_CYCLE_WINDOW_MS = 60_000; -const CB_CYCLE_COOLDOWN_MS = 30_000; - -const CB_MAX_FAILED_ROUNDS = 3; -const CB_FAILED_WINDOW_MS = 120_000; -const CB_BASE_BACKOFF_MS = 30_000; -const CB_MAX_BACKOFF_MS = 300_000; /** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */ const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES; @@ -281,6 +273,7 @@ export class MCPConnection extends EventEmitter { private oauthTokens?: MCPOAuthTokens | null; private requestHeaders?: Record | null; private oauthRequired = false; + private oauthRecovery = false; private readonly useSSRFProtection: boolean; iconPath?: string; timeout?: number; @@ -325,17 +318,17 @@ export class MCPConnection extends EventEmitter { private recordCycle(): void { const cb = this.getCircuitBreaker(); const now = Date.now(); - if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) { + if (now - cb.cycleWindowStart > mcpConfig.CB_CYCLE_WINDOW_MS) { cb.cycleCount = 0; cb.cycleWindowStart = now; } cb.cycleCount++; - if (cb.cycleCount >= CB_MAX_CYCLES) { - cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS; + if (cb.cycleCount >= mcpConfig.CB_MAX_CYCLES) { + cb.cooldownUntil = now + mcpConfig.CB_CYCLE_COOLDOWN_MS; cb.cycleCount = 0; cb.cycleWindowStart = now; logger.warn( - `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${CB_CYCLE_COOLDOWN_MS}ms`, + `${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${mcpConfig.CB_CYCLE_COOLDOWN_MS}ms`, ); } } @@ -343,15 +336,16 @@ export class MCPConnection extends EventEmitter { private recordFailedRound(): void { const cb = this.getCircuitBreaker(); const now = Date.now(); - if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) { + if (now - cb.failedWindowStart > mcpConfig.CB_FAILED_WINDOW_MS) { cb.failedRounds = 0; cb.failedWindowStart = now; } cb.failedRounds++; - if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) { + if (cb.failedRounds >= mcpConfig.CB_MAX_FAILED_ROUNDS) { const backoff = Math.min( - CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS), - CB_MAX_BACKOFF_MS, + mcpConfig.CB_BASE_BACKOFF_MS * + Math.pow(2, cb.failedRounds - mcpConfig.CB_MAX_FAILED_ROUNDS), + mcpConfig.CB_MAX_BACKOFF_MS, ); cb.failedBackoffUntil = now + backoff; logger.warn( @@ -367,6 +361,13 @@ export class MCPConnection extends EventEmitter { cb.failedBackoffUntil = 0; } + public static decrementCycleCount(serverName: string): void { + const cb = MCPConnection.circuitBreakers.get(serverName); + if (cb && cb.cycleCount > 0) { + cb.cycleCount--; + } + } + setRequestHeaders(headers: Record | null): void { if (!headers) { return; @@ -816,6 +817,13 @@ export class MCPConnection extends EventEmitter { this.emit('connectionChange', 'connected'); this.reconnectAttempts = 0; this.resetFailedRounds(); + if (this.oauthRecovery) { + MCPConnection.decrementCycleCount(this.serverName); + this.oauthRecovery = false; + logger.debug( + `${this.getLogPrefix()} OAuth recovery: decremented cycle count after successful reconnect`, + ); + } } catch (error) { // Check if it's a rate limit error - stop immediately to avoid making it worse if (this.isRateLimitError(error)) { @@ -899,9 +907,8 @@ export class MCPConnection extends EventEmitter { try { // Wait for OAuth to be handled await oauthHandledPromise; - // Reset the oauthRequired flag this.oauthRequired = false; - // Don't throw the error - just return so connection can be retried + this.oauthRecovery = true; logger.info( `${this.getLogPrefix()} OAuth handled successfully, connection will be retried`, ); diff --git a/packages/api/src/mcp/mcpConfig.ts b/packages/api/src/mcp/mcpConfig.ts index f3efd3592b..a81752e909 100644 --- a/packages/api/src/mcp/mcpConfig.ts +++ b/packages/api/src/mcp/mcpConfig.ts @@ -12,4 +12,18 @@ export const mcpConfig = { USER_CONNECTION_IDLE_TIMEOUT: math( process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000, ), + /** Max connect/disconnect cycles before the circuit breaker trips. Default: 7 */ + CB_MAX_CYCLES: math(process.env.MCP_CB_MAX_CYCLES ?? 7), + /** Sliding window (ms) for counting cycles. Default: 45s */ + CB_CYCLE_WINDOW_MS: math(process.env.MCP_CB_CYCLE_WINDOW_MS ?? 45_000), + /** Cooldown (ms) after the cycle breaker trips. Default: 15s */ + CB_CYCLE_COOLDOWN_MS: math(process.env.MCP_CB_CYCLE_COOLDOWN_MS ?? 15_000), + /** Max consecutive failed connection rounds before backoff. Default: 3 */ + CB_MAX_FAILED_ROUNDS: math(process.env.MCP_CB_MAX_FAILED_ROUNDS ?? 3), + /** Sliding window (ms) for counting failed rounds. Default: 120s */ + CB_FAILED_WINDOW_MS: math(process.env.MCP_CB_FAILED_WINDOW_MS ?? 120_000), + /** Base backoff (ms) after failed round threshold is reached. Default: 30s */ + CB_BASE_BACKOFF_MS: math(process.env.MCP_CB_BASE_BACKOFF_MS ?? 30_000), + /** Max backoff cap (ms) for exponential backoff. Default: 300s */ + CB_MAX_BACKOFF_MS: math(process.env.MCP_CB_MAX_BACKOFF_MS ?? 300_000), }; diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 83e855591e..366d0d2fde 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -426,8 +426,8 @@ export class MCPOAuthHandler { scope: config.scope, }); - /** Add state parameter with flowId to the authorization URL */ - authorizationUrl.searchParams.set('state', flowId); + /** Add cryptographic state parameter to the authorization URL */ + authorizationUrl.searchParams.set('state', state); logger.debug(`[MCPOAuth] Added state parameter to authorization URL`); const flowMetadata: MCPOAuthFlowMetadata = { @@ -505,8 +505,8 @@ export class MCPOAuthHandler { `[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`, ); - /** Add state parameter with flowId to the authorization URL */ - authorizationUrl.searchParams.set('state', flowId); + /** Add cryptographic state parameter to the authorization URL */ + authorizationUrl.searchParams.set('state', state); logger.debug(`[MCPOAuth] Added state parameter to authorization URL`); if (resourceMetadata?.resource != null && resourceMetadata.resource) { @@ -672,6 +672,44 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } + private static readonly STATE_MAP_TYPE = 'mcp_oauth_state'; + + /** + * Stores a mapping from the opaque OAuth state parameter to the flowId. + * This allows the callback to resolve the flowId from an unguessable state + * value, preventing attackers from forging callback requests. + */ + static async storeStateMapping( + state: string, + flowId: string, + flowManager: FlowStateManager, + ): Promise { + await flowManager.initFlow(state, this.STATE_MAP_TYPE, { flowId }); + } + + /** + * Resolves an opaque OAuth state parameter back to the original flowId. + * Returns null if the state is not found (expired or never stored). + */ + static async resolveStateToFlowId( + state: string, + flowManager: FlowStateManager, + ): Promise { + const mapping = await flowManager.getFlowState(state, this.STATE_MAP_TYPE); + return (mapping?.metadata?.flowId as string) ?? null; + } + + /** + * Deletes an orphaned state mapping when a flow is replaced. + * Prevents old authorization URLs from resolving after a flow restart. + */ + static async deleteStateMapping( + state: string, + flowManager: FlowStateManager, + ): Promise { + await flowManager.deleteFlow(state, this.STATE_MAP_TYPE); + } + /** * Gets the default redirect URI for a server */ diff --git a/packages/api/src/mcp/oauth/tokens.ts b/packages/api/src/mcp/oauth/tokens.ts index 005ed7dd9a..7b1d189347 100644 --- a/packages/api/src/mcp/oauth/tokens.ts +++ b/packages/api/src/mcp/oauth/tokens.ts @@ -4,6 +4,15 @@ import type { TokenMethods, IToken } from '@librechat/data-schemas'; import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types'; import { isSystemUserId } from '~/mcp/enum'; +export class ReauthenticationRequiredError extends Error { + constructor(serverName: string, reason: 'expired' | 'missing') { + super( + `Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`, + ); + this.name = 'ReauthenticationRequiredError'; + } +} + interface StoreTokensParams { userId: string; serverName: string; @@ -27,7 +36,12 @@ interface GetTokensParams { findToken: TokenMethods['findToken']; refreshTokens?: ( refreshToken: string, - metadata: { userId: string; serverName: string; identifier: string }, + metadata: { + userId: string; + serverName: string; + identifier: string; + clientInfo?: OAuthClientInformation; + }, ) => Promise; createToken?: TokenMethods['createToken']; updateToken?: TokenMethods['updateToken']; @@ -273,10 +287,11 @@ export class MCPTokenStorage { }); if (!refreshTokenData) { + const reason = isMissing ? 'missing' : 'expired'; logger.info( - `${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`, + `${logPrefix} Access token ${reason} and no refresh token available — re-authentication required`, ); - return null; + throw new ReauthenticationRequiredError(serverName, reason); } if (!refreshTokens) { @@ -395,6 +410,9 @@ export class MCPTokenStorage { logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`); return tokens; } catch (error) { + if (error instanceof ReauthenticationRequiredError) { + throw error; + } logger.error(`${logPrefix} Failed to retrieve tokens`, error); return null; } diff --git a/packages/api/src/mcp/oauth/types.ts b/packages/api/src/mcp/oauth/types.ts index 178e20e35b..2138b4a782 100644 --- a/packages/api/src/mcp/oauth/types.ts +++ b/packages/api/src/mcp/oauth/types.ts @@ -88,6 +88,7 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata { clientInfo?: OAuthClientInformation; metadata?: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; + authorizationUrl?: string; } export interface MCPOAuthTokens extends OAuthTokens { From 9a5d7eaa4ef1dbae21d06ea82e587c54e867b48b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 10 Mar 2026 23:14:52 -0400 Subject: [PATCH 06/39] =?UTF-8?q?=E2=9A=A1=20refactor:=20Replace=20`tiktok?= =?UTF-8?q?en`=20with=20`ai-tokenizer`=20(#12175)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: Update dependencies by adding ai-tokenizer and removing tiktoken - Added ai-tokenizer version 1.0.6 to package.json and package-lock.json across multiple packages. - Removed tiktoken version 1.0.15 from package.json and package-lock.json in the same locations, streamlining dependency management. * refactor: replace js-tiktoken with ai-tokenizer - Added support for 'claude' encoding in the AgentClient class to improve model compatibility. - Updated Tokenizer class to utilize 'ai-tokenizer' for both 'o200k_base' and 'claude' encodings, replacing the previous 'tiktoken' dependency. - Refactored tests to reflect changes in tokenizer behavior and ensure accurate token counting for both encoding types. - Removed deprecated references to 'tiktoken' and adjusted related tests for improved clarity and functionality. * chore: remove tiktoken mocks from DALLE3 tests - Eliminated mock implementations of 'tiktoken' from DALLE3-related test files to streamline test setup and align with recent dependency updates. - Adjusted related test structures to ensure compatibility with the new tokenizer implementation. * chore: Add distinct encoding support for Anthropic Claude models - Introduced a new method `getEncoding` in the AgentClient class to handle the specific BPE tokenizer for Claude models, ensuring compatibility with the distinct encoding requirements. - Updated documentation to clarify the encoding logic for Claude and other models. * docs: Update return type documentation for getEncoding method in AgentClient - Clarified the return type of the getEncoding method to specify that it can return an EncodingName or undefined, enhancing code readability and type safety. * refactor: Tokenizer class and error handling - Exported the EncodingName type for broader usage. - Renamed encodingMap to encodingData for clarity. - Improved error handling in getTokenCount method to ensure recovery attempts are logged and return 0 on failure. - Updated countTokens function documentation to specify the use of 'o200k_base' encoding. * refactor: Simplify encoding documentation and export type - Updated the getEncoding method documentation to clarify the default behavior for non-Anthropic Claude models. - Exported the EncodingName type separately from the Tokenizer module for improved clarity and usage. * test: Update text processing tests for token limits - Adjusted test cases to handle smaller text sizes, changing scenarios from ~120k tokens to ~20k tokens for both the real tokenizer and countTokens functions. - Updated token limits in tests to reflect new constraints, ensuring tests accurately assess performance and call reduction. - Enhanced console log messages for clarity regarding token counts and reductions in the updated scenarios. * refactor: Update Tokenizer imports and exports - Moved Tokenizer and countTokens exports to the tokenizer module for better organization. - Adjusted imports in memory.ts to reflect the new structure, ensuring consistent usage across the codebase. - Updated memory.test.ts to mock the Tokenizer from the correct module path, enhancing test accuracy. * refactor: Tokenizer initialization and error handling - Introduced an async `initEncoding` method to preload tokenizers, improving performance and accuracy in token counting. - Updated `getTokenCount` to handle uninitialized tokenizers more gracefully, ensuring proper recovery and logging on errors. - Removed deprecated synchronous tokenizer retrieval, streamlining the overall tokenizer management process. * test: Enhance tokenizer tests with initialization and encoding checks - Added `beforeAll` hooks to initialize tokenizers for 'o200k_base' and 'claude' encodings before running tests, ensuring proper setup. - Updated tests to validate the loading of encodings and the correctness of token counts for both 'o200k_base' and 'claude'. - Improved test structure to deduplicate concurrent initialization calls, enhancing performance and reliability. --- .../structured/specs/DALLE3-proxy.spec.js | 1 - .../tools/structured/specs/DALLE3.spec.js | 9 -- api/package.json | 2 +- api/server/controllers/agents/client.js | 4 + api/strategies/samlStrategy.spec.js | 1 - package-lock.json | 23 ++- packages/api/package.json | 2 +- .../api/src/agents/__tests__/memory.test.ts | 5 +- packages/api/src/agents/memory.ts | 3 +- packages/api/src/index.ts | 2 + packages/api/src/utils/index.ts | 1 - packages/api/src/utils/text.spec.ts | 62 +++----- packages/api/src/utils/tokenizer.spec.ts | 137 ++++-------------- packages/api/src/utils/tokenizer.ts | 98 +++++-------- packages/api/src/utils/tokens.ts | 39 ----- 15 files changed, 112 insertions(+), 277 deletions(-) diff --git a/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js b/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js index 4481a7d70f..262842b3c2 100644 --- a/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js @@ -1,7 +1,6 @@ const DALLE3 = require('../DALLE3'); const { ProxyAgent } = require('undici'); -jest.mock('tiktoken'); const processFileURL = jest.fn(); describe('DALLE3 Proxy Configuration', () => { diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index d2040989f9..6071929bfc 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => { }; }); -jest.mock('tiktoken', () => { - return { - encoding_for_model: jest.fn().mockReturnValue({ - encode: jest.fn(), - decode: jest.fn(), - }), - }; -}); - const processFileURL = jest.fn(); const generate = jest.fn(); diff --git a/api/package.json b/api/package.json index fcd353af57..1618481b58 100644 --- a/api/package.json +++ b/api/package.json @@ -51,6 +51,7 @@ "@modelcontextprotocol/sdk": "^1.27.1", "@node-saml/passport-saml": "^5.1.0", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "bcryptjs": "^2.4.3", "compression": "^1.8.1", @@ -106,7 +107,6 @@ "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", "sharp": "^0.33.5", - "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", "undici": "^7.18.2", diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 5f99a0762b..0ecd62b819 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -1172,7 +1172,11 @@ class AgentClient extends BaseClient { } } + /** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */ getEncoding() { + if (this.model && this.model.toLowerCase().includes('claude')) { + return 'claude'; + } return 'o200k_base'; } diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index 06c969ce46..1d16719b87 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -1,5 +1,4 @@ // --- Mocks --- -jest.mock('tiktoken'); jest.mock('fs'); jest.mock('path'); jest.mock('node-fetch'); diff --git a/package-lock.json b/package-lock.json index 09c5219afb..a2db2df389 100644 --- a/package-lock.json +++ b/package-lock.json @@ -66,6 +66,7 @@ "@modelcontextprotocol/sdk": "^1.27.1", "@node-saml/passport-saml": "^5.1.0", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "bcryptjs": "^2.4.3", "compression": "^1.8.1", @@ -121,7 +122,6 @@ "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", "sharp": "^0.33.5", - "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", "undici": "^7.18.2", @@ -22230,6 +22230,20 @@ "node": ">= 14" } }, + "node_modules/ai-tokenizer": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/ai-tokenizer/-/ai-tokenizer-1.0.6.tgz", + "integrity": "sha512-GaakQFxen0pRH/HIA4v68ZM40llCH27HUYUSBLK+gVuZ57e53pYJe1xFvSTj4sJJjbWU92m1X6NjPWyeWkFDow==", + "license": "MIT", + "peerDependencies": { + "ai": "^5.0.0" + }, + "peerDependenciesMeta": { + "ai": { + "optional": true + } + } + }, "node_modules/ajv": { "version": "8.18.0", "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", @@ -41485,11 +41499,6 @@ "node": ">=0.8" } }, - "node_modules/tiktoken": { - "version": "1.0.15", - "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.15.tgz", - "integrity": "sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw==" - }, "node_modules/timers-browserify": { "version": "2.0.12", "resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz", @@ -44200,6 +44209,7 @@ "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "connect-redis": "^8.1.0", "eventsource": "^3.0.2", @@ -44222,7 +44232,6 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "tiktoken": "^1.0.15", "undici": "^7.18.2", "zod": "^3.22.4" } diff --git a/packages/api/package.json b/packages/api/package.json index 46587797a5..966447c51b 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -94,6 +94,7 @@ "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "connect-redis": "^8.1.0", "eventsource": "^3.0.2", @@ -116,7 +117,6 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "tiktoken": "^1.0.15", "undici": "^7.18.2", "zod": "^3.22.4" } diff --git a/packages/api/src/agents/__tests__/memory.test.ts b/packages/api/src/agents/__tests__/memory.test.ts index 74cd0f4354..dabe6de629 100644 --- a/packages/api/src/agents/__tests__/memory.test.ts +++ b/packages/api/src/agents/__tests__/memory.test.ts @@ -22,8 +22,9 @@ jest.mock('winston', () => ({ })); // Mock the Tokenizer -jest.mock('~/utils', () => ({ - Tokenizer: { +jest.mock('~/utils/tokenizer', () => ({ + __esModule: true, + default: { getTokenCount: jest.fn((text: string) => text.length), // Simple mock: 1 char = 1 token }, })); diff --git a/packages/api/src/agents/memory.ts b/packages/api/src/agents/memory.ts index b8f65a9772..b7ae8a8123 100644 --- a/packages/api/src/agents/memory.ts +++ b/packages/api/src/agents/memory.ts @@ -19,7 +19,8 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider'; import type { BaseMessage, ToolMessage } from '@langchain/core/messages'; import type { Response as ServerResponse } from 'express'; import { GenerationJobManager } from '~/stream/GenerationJobManager'; -import { Tokenizer, resolveHeaders, createSafeUser } from '~/utils'; +import { resolveHeaders, createSafeUser } from '~/utils'; +import Tokenizer from '~/utils/tokenizer'; type RequiredMemoryMethods = Pick< MemoryMethods, diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index a7edb3882d..687ee7aa49 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -15,6 +15,8 @@ export * from './mcp/errors'; /* Utilities */ export * from './mcp/utils'; export * from './utils'; +export { default as Tokenizer, countTokens } from './utils/tokenizer'; +export type { EncodingName } from './utils/tokenizer'; export * from './db/utils'; /* OAuth */ export * from './oauth'; diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts index 470780cd5c..441c2e02d7 100644 --- a/packages/api/src/utils/index.ts +++ b/packages/api/src/utils/index.ts @@ -19,7 +19,6 @@ export * from './promise'; export * from './sanitizeTitle'; export * from './tempChatRetention'; export * from './text'; -export { default as Tokenizer, countTokens } from './tokenizer'; export * from './yaml'; export * from './http'; export * from './tokens'; diff --git a/packages/api/src/utils/text.spec.ts b/packages/api/src/utils/text.spec.ts index 1b8d8aac98..30185f9da7 100644 --- a/packages/api/src/utils/text.spec.ts +++ b/packages/api/src/utils/text.spec.ts @@ -65,7 +65,7 @@ const createRealTokenCounter = () => { let callCount = 0; const tokenCountFn = (text: string): number => { callCount++; - return Tokenizer.getTokenCount(text, 'cl100k_base'); + return Tokenizer.getTokenCount(text, 'o200k_base'); }; return { tokenCountFn, @@ -590,9 +590,9 @@ describe('processTextWithTokenLimit', () => { }); }); - describe('direct comparison with REAL tiktoken tokenizer', () => { - beforeEach(() => { - Tokenizer.freeAndResetAllEncoders(); + describe('direct comparison with REAL ai-tokenizer', () => { + beforeAll(async () => { + await Tokenizer.initEncoding('o200k_base'); }); it('should produce valid truncation with real tokenizer', async () => { @@ -611,7 +611,7 @@ describe('processTextWithTokenLimit', () => { expect(result.text.length).toBeLessThan(text.length); }); - it('should use fewer tiktoken calls than old implementation (realistic text)', async () => { + it('should use fewer tokenizer calls than old implementation (realistic text)', async () => { const oldCounter = createRealTokenCounter(); const newCounter = createRealTokenCounter(); const text = createRealisticText(15000); @@ -623,8 +623,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -634,17 +632,17 @@ describe('processTextWithTokenLimit', () => { const oldCalls = oldCounter.getCallCount(); const newCalls = newCounter.getCallCount(); - console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`); - console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); + console.log(`[Real tokenizer ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`); + console.log(`[Real tokenizer] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); expect(newCalls).toBeLessThan(oldCalls); }); - it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => { + it('should handle large text with real tokenizer (~20k tokens)', async () => { const oldCounter = createRealTokenCounter(); const newCounter = createRealTokenCounter(); - const text = createRealisticText(120000); - const tokenLimit = 100000; + const text = createRealisticText(20000); + const tokenLimit = 15000; const startOld = performance.now(); await processTextWithTokenLimitOLD({ @@ -654,8 +652,6 @@ describe('processTextWithTokenLimit', () => { }); const timeOld = performance.now() - startOld; - Tokenizer.freeAndResetAllEncoders(); - const startNew = performance.now(); const result = await processTextWithTokenLimit({ text, @@ -667,9 +663,9 @@ describe('processTextWithTokenLimit', () => { const oldCalls = oldCounter.getCallCount(); const newCalls = newCounter.getCallCount(); - console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`); - console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`); - console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`); + console.log(`\n[REAL TOKENIZER - ~20k tokens]`); + console.log(`OLD implementation: ${oldCalls} tokenizer calls, ${timeOld.toFixed(0)}ms`); + console.log(`NEW implementation: ${newCalls} tokenizer calls, ${timeNew.toFixed(0)}ms`); console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`); console.log( @@ -684,8 +680,8 @@ describe('processTextWithTokenLimit', () => { it('should achieve at least 70% reduction with real tokenizer', async () => { const oldCounter = createRealTokenCounter(); const newCounter = createRealTokenCounter(); - const text = createRealisticText(50000); - const tokenLimit = 10000; + const text = createRealisticText(15000); + const tokenLimit = 5000; await processTextWithTokenLimitOLD({ text, @@ -693,8 +689,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -706,7 +700,7 @@ describe('processTextWithTokenLimit', () => { const reduction = 1 - newCalls / oldCalls; console.log( - `[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, + `[Real tokenizer 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, ); expect(reduction).toBeGreaterThanOrEqual(0.7); @@ -714,10 +708,6 @@ describe('processTextWithTokenLimit', () => { }); describe('using countTokens async function from @librechat/api', () => { - beforeEach(() => { - Tokenizer.freeAndResetAllEncoders(); - }); - it('countTokens should return correct token count', async () => { const text = 'Hello, world!'; const count = await countTokens(text); @@ -759,8 +749,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -776,11 +764,11 @@ describe('processTextWithTokenLimit', () => { expect(newCalls).toBeLessThan(oldCalls); }); - it('should handle user reported scenario with countTokens (~120k tokens)', async () => { + it('should handle large text with countTokens (~20k tokens)', async () => { const oldCounter = createCountTokensCounter(); const newCounter = createCountTokensCounter(); - const text = createRealisticText(120000); - const tokenLimit = 100000; + const text = createRealisticText(20000); + const tokenLimit = 15000; const startOld = performance.now(); await processTextWithTokenLimitOLD({ @@ -790,8 +778,6 @@ describe('processTextWithTokenLimit', () => { }); const timeOld = performance.now() - startOld; - Tokenizer.freeAndResetAllEncoders(); - const startNew = performance.now(); const result = await processTextWithTokenLimit({ text, @@ -803,7 +789,7 @@ describe('processTextWithTokenLimit', () => { const oldCalls = oldCounter.getCallCount(); const newCalls = newCounter.getCallCount(); - console.log(`\n[countTokens - User reported scenario: ~120k tokens]`); + console.log(`\n[countTokens - ~20k tokens]`); console.log(`OLD implementation: ${oldCalls} countTokens calls, ${timeOld.toFixed(0)}ms`); console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`); console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`); @@ -820,8 +806,8 @@ describe('processTextWithTokenLimit', () => { it('should achieve at least 70% reduction with countTokens', async () => { const oldCounter = createCountTokensCounter(); const newCounter = createCountTokensCounter(); - const text = createRealisticText(50000); - const tokenLimit = 10000; + const text = createRealisticText(15000); + const tokenLimit = 5000; await processTextWithTokenLimitOLD({ text, @@ -829,8 +815,6 @@ describe('processTextWithTokenLimit', () => { tokenCountFn: oldCounter.tokenCountFn, }); - Tokenizer.freeAndResetAllEncoders(); - await processTextWithTokenLimit({ text, tokenLimit, @@ -842,7 +826,7 @@ describe('processTextWithTokenLimit', () => { const reduction = 1 - newCalls / oldCalls; console.log( - `[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, + `[countTokens 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`, ); expect(reduction).toBeGreaterThanOrEqual(0.7); diff --git a/packages/api/src/utils/tokenizer.spec.ts b/packages/api/src/utils/tokenizer.spec.ts index edd6fe14de..b8c1bd8d98 100644 --- a/packages/api/src/utils/tokenizer.spec.ts +++ b/packages/api/src/utils/tokenizer.spec.ts @@ -1,12 +1,3 @@ -/** - * @file Tokenizer.spec.cjs - * - * Tests the real TokenizerSingleton (no mocking of `tiktoken`). - * Make sure to install `tiktoken` and have it configured properly. - */ - -import { logger } from '@librechat/data-schemas'; -import type { Tiktoken } from 'tiktoken'; import Tokenizer from './tokenizer'; jest.mock('@librechat/data-schemas', () => ({ @@ -17,127 +8,49 @@ jest.mock('@librechat/data-schemas', () => ({ describe('Tokenizer', () => { it('should be a singleton (same instance)', async () => { - const AnotherTokenizer = await import('./tokenizer'); // same path + const AnotherTokenizer = await import('./tokenizer'); expect(Tokenizer).toBe(AnotherTokenizer.default); }); - describe('getTokenizer', () => { - it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => { - // The real `encoding_for_model` will be called internally - // as soon as we pass isModelName = true. - const tokenizer = Tokenizer.getTokenizer('gpt-4', true); - - // Basic sanity checks - expect(tokenizer).toBeDefined(); - // You can optionally check certain properties from `tiktoken` if they exist - // e.g., expect(typeof tokenizer.encode).toBe('function'); + describe('initEncoding', () => { + it('should load o200k_base encoding', async () => { + await Tokenizer.initEncoding('o200k_base'); + const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base'); + expect(count).toBeGreaterThan(0); }); - it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => { - // The real `get_encoding` will be called internally - // as soon as we pass isModelName = false. - const tokenizer = Tokenizer.getTokenizer('cl100k_base', false); - - expect(tokenizer).toBeDefined(); - // e.g., expect(typeof tokenizer.encode).toBe('function'); + it('should load claude encoding', async () => { + await Tokenizer.initEncoding('claude'); + const count = Tokenizer.getTokenCount('Hello, world!', 'claude'); + expect(count).toBeGreaterThan(0); }); - it('should return cached tokenizer if previously fetched', () => { - const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false); - const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false); - // Should be the exact same instance from the cache - expect(tokenizer1).toBe(tokenizer2); - }); - }); - - describe('freeAndResetAllEncoders', () => { - beforeEach(() => { - jest.clearAllMocks(); - }); - - it('should free all encoders and reset tokenizerCallsCount to 1', () => { - // By creating two different encodings, we populate the cache - Tokenizer.getTokenizer('cl100k_base', false); - Tokenizer.getTokenizer('r50k_base', false); - - // Now free them - Tokenizer.freeAndResetAllEncoders(); - - // The internal cache is cleared - expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined(); - expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined(); - - // tokenizerCallsCount is reset to 1 - expect(Tokenizer.tokenizerCallsCount).toBe(1); - }); - - it('should catch and log errors if freeing fails', () => { - // Mock logger.error before the test - const mockLoggerError = jest.spyOn(logger, 'error'); - - // Set up a problematic tokenizer in the cache - Tokenizer.tokenizersCache['cl100k_base'] = { - free() { - throw new Error('Intentional free error'); - }, - } as unknown as Tiktoken; - - // Should not throw uncaught errors - Tokenizer.freeAndResetAllEncoders(); - - // Verify logger.error was called with correct arguments - expect(mockLoggerError).toHaveBeenCalledWith( - '[Tokenizer] Free and reset encoders error', - expect.any(Error), - ); - - // Clean up - mockLoggerError.mockRestore(); - Tokenizer.tokenizersCache = {}; + it('should deduplicate concurrent init calls', async () => { + const [, , count] = await Promise.all([ + Tokenizer.initEncoding('o200k_base'), + Tokenizer.initEncoding('o200k_base'), + Tokenizer.initEncoding('o200k_base').then(() => + Tokenizer.getTokenCount('test', 'o200k_base'), + ), + ]); + expect(count).toBeGreaterThan(0); }); }); describe('getTokenCount', () => { - beforeEach(() => { - jest.clearAllMocks(); - Tokenizer.freeAndResetAllEncoders(); + beforeAll(async () => { + await Tokenizer.initEncoding('o200k_base'); + await Tokenizer.initEncoding('claude'); }); it('should return the number of tokens in the given text', () => { - const text = 'Hello, world!'; - const count = Tokenizer.getTokenCount(text, 'cl100k_base'); + const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base'); expect(count).toBeGreaterThan(0); }); - it('should reset encoders if an error is thrown', () => { - // We can simulate an error by temporarily overriding the selected tokenizer's `encode` method. - const tokenizer = Tokenizer.getTokenizer('cl100k_base', false); - const originalEncode = tokenizer.encode; - tokenizer.encode = () => { - throw new Error('Forced error'); - }; - - // Despite the forced error, the code should catch and reset, then re-encode - const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base'); + it('should count tokens using claude encoding', () => { + const count = Tokenizer.getTokenCount('Hello, world!', 'claude'); expect(count).toBeGreaterThan(0); - - // Restore the original encode - tokenizer.encode = originalEncode; - }); - - it('should reset tokenizers after 25 calls', () => { - // Spy on freeAndResetAllEncoders - const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders'); - - // Make 24 calls; should NOT reset yet - for (let i = 0; i < 24; i++) { - Tokenizer.getTokenCount('test text', 'cl100k_base'); - } - expect(resetSpy).not.toHaveBeenCalled(); - - // 25th call triggers the reset - Tokenizer.getTokenCount('the 25th call!', 'cl100k_base'); - expect(resetSpy).toHaveBeenCalledTimes(1); }); }); }); diff --git a/packages/api/src/utils/tokenizer.ts b/packages/api/src/utils/tokenizer.ts index 0b0282d36b..4c638c948e 100644 --- a/packages/api/src/utils/tokenizer.ts +++ b/packages/api/src/utils/tokenizer.ts @@ -1,74 +1,46 @@ import { logger } from '@librechat/data-schemas'; -import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken'; -import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken'; +import { Tokenizer as AiTokenizer } from 'ai-tokenizer'; -interface TokenizerOptions { - debug?: boolean; -} +export type EncodingName = 'o200k_base' | 'claude'; + +type EncodingData = ConstructorParameters[0]; class Tokenizer { - tokenizersCache: Record; - tokenizerCallsCount: number; - private options?: TokenizerOptions; + private tokenizersCache: Partial> = {}; + private loadingPromises: Partial>> = {}; - constructor() { - this.tokenizersCache = {}; - this.tokenizerCallsCount = 0; - } - - getTokenizer( - encoding: TiktokenModel | TiktokenEncoding, - isModelName = false, - extendSpecialTokens: Record = {}, - ): Tiktoken { - let tokenizer: Tiktoken; + /** Pre-loads an encoding so that subsequent getTokenCount calls are accurate. */ + async initEncoding(encoding: EncodingName): Promise { if (this.tokenizersCache[encoding]) { - tokenizer = this.tokenizersCache[encoding]; - } else { - if (isModelName) { - tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens); - } else { - tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens); - } - this.tokenizersCache[encoding] = tokenizer; + return; } - return tokenizer; + if (this.loadingPromises[encoding]) { + return this.loadingPromises[encoding]; + } + this.loadingPromises[encoding] = (async () => { + const data: EncodingData = + encoding === 'claude' + ? await import('ai-tokenizer/encoding/claude') + : await import('ai-tokenizer/encoding/o200k_base'); + this.tokenizersCache[encoding] = new AiTokenizer(data); + })(); + return this.loadingPromises[encoding]; } - freeAndResetAllEncoders(): void { + getTokenCount(text: string, encoding: EncodingName = 'o200k_base'): number { + const tokenizer = this.tokenizersCache[encoding]; + if (!tokenizer) { + this.initEncoding(encoding); + return Math.ceil(text.length / 4); + } try { - Object.keys(this.tokenizersCache).forEach((key) => { - if (this.tokenizersCache[key]) { - this.tokenizersCache[key].free(); - delete this.tokenizersCache[key]; - } - }); - this.tokenizerCallsCount = 1; - } catch (error) { - logger.error('[Tokenizer] Free and reset encoders error', error); - } - } - - resetTokenizersIfNecessary(): void { - if (this.tokenizerCallsCount >= 25) { - if (this.options?.debug) { - logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...'); - } - this.freeAndResetAllEncoders(); - } - this.tokenizerCallsCount++; - } - - getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number { - this.resetTokenizersIfNecessary(); - try { - const tokenizer = this.getTokenizer(encoding); - return tokenizer.encode(text, 'all').length; + return tokenizer.count(text); } catch (error) { logger.error('[Tokenizer] Error getting token count:', error); - this.freeAndResetAllEncoders(); - const tokenizer = this.getTokenizer(encoding); - return tokenizer.encode(text, 'all').length; + delete this.tokenizersCache[encoding]; + delete this.loadingPromises[encoding]; + this.initEncoding(encoding); + return Math.ceil(text.length / 4); } } } @@ -76,13 +48,13 @@ class Tokenizer { const TokenizerSingleton = new Tokenizer(); /** - * Counts the number of tokens in a given text using tiktoken. - * This is an async wrapper around Tokenizer.getTokenCount for compatibility. - * @param text - The text to be tokenized. Defaults to an empty string if not provided. + * Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding. + * @param text - The text to count tokens in. Defaults to an empty string. * @returns The number of tokens in the provided text. */ export async function countTokens(text = ''): Promise { - return TokenizerSingleton.getTokenCount(text, 'cl100k_base'); + await TokenizerSingleton.initEncoding('o200k_base'); + return TokenizerSingleton.getTokenCount(text, 'o200k_base'); } export default TokenizerSingleton; diff --git a/packages/api/src/utils/tokens.ts b/packages/api/src/utils/tokens.ts index 32b2fc6036..ae09da4f28 100644 --- a/packages/api/src/utils/tokens.ts +++ b/packages/api/src/utils/tokens.ts @@ -593,42 +593,3 @@ export function processModelData(input: z.infer): EndpointTo return tokenConfig; } - -export const tiktokenModels = new Set([ - 'text-davinci-003', - 'text-davinci-002', - 'text-davinci-001', - 'text-curie-001', - 'text-babbage-001', - 'text-ada-001', - 'davinci', - 'curie', - 'babbage', - 'ada', - 'code-davinci-002', - 'code-davinci-001', - 'code-cushman-002', - 'code-cushman-001', - 'davinci-codex', - 'cushman-codex', - 'text-davinci-edit-001', - 'code-davinci-edit-001', - 'text-embedding-ada-002', - 'text-similarity-davinci-001', - 'text-similarity-curie-001', - 'text-similarity-babbage-001', - 'text-similarity-ada-001', - 'text-search-davinci-doc-001', - 'text-search-curie-doc-001', - 'text-search-babbage-doc-001', - 'text-search-ada-doc-001', - 'code-search-babbage-code-001', - 'code-search-ada-code-001', - 'gpt2', - 'gpt-4', - 'gpt-4-0314', - 'gpt-4-32k', - 'gpt-4-32k-0314', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-0301', -]); From fc6f7a337dc66eb95f1b24e2b095a745eb36d1d2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:46:55 -0400 Subject: [PATCH 07/39] =?UTF-8?q?=F0=9F=8C=8D=20i18n:=20Update=20translati?= =?UTF-8?q?on.json=20with=20latest=20translations=20(#12176)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- client/src/locales/lv/translation.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/client/src/locales/lv/translation.json b/client/src/locales/lv/translation.json index 76c2db24ea..57794a9e2a 100644 --- a/client/src/locales/lv/translation.json +++ b/client/src/locales/lv/translation.json @@ -39,7 +39,7 @@ "com_agents_description_card": "Apraksts: {{description}}", "com_agents_description_placeholder": "Pēc izvēles: aprakstiet savu aģentu šeit", "com_agents_empty_state_heading": "Nav atrasts neviens aģents", - "com_agents_enable_file_search": "Iespējot vektorizēto meklēšanu", + "com_agents_enable_file_search": "Iespējot meklēšanu dokumentos", "com_agents_error_bad_request_message": "Pieprasījumu nevarēja apstrādāt.", "com_agents_error_bad_request_suggestion": "Lūdzu, pārbaudiet ievadītos datus un mēģiniet vēlreiz.", "com_agents_error_category_title": "Kategorija Kļūda", @@ -66,7 +66,7 @@ "com_agents_file_context_description": "Visi augšupielādētie faili tiek pilnībā pārveidoti tekstā un nekavējoties pievienoti aģenta pamata kontekstam kā nemainīgs saturs, kas pieejams visu sarunas laiku. Ja augšupielādētajam faila tipam ir pieejams vai konfigurēts OCR, teksta izvilkšana notiek automātiski. Šī metode ir piemērota gadījumos, kad nepieciešams analizēt visu dokumenta, attēla ar tekstu vai PDF faila saturu, taču jāņem vērā, ka tas ievērojami palielina atmiņas patēriņu un izmaksas.", "com_agents_file_context_disabled": "Pirms failu augšupielādes, lai to pievienotu kā kontekstu, ir jāizveido aģents.", "com_agents_file_context_label": "Pievienot failu kā kontekstu", - "com_agents_file_search_disabled": "Lai varētu iespējot vektorizētu meklēšanu ir jāizveido aģents.", + "com_agents_file_search_disabled": "Lai varētu iespējot meklēšanu dokumentos ir jāizveido aģents.", "com_agents_file_search_info": "Kad šī opcija ir iespējota, aģents izmanto vektorizētu datu meklēšanu (RAG pieeju), kas ļauj efektīvi un izmaksu ziņā izdevīgi izgūt atbilstošu kontekstu tikai no būtiskākajām faila daļām, balstoties uz lietotāja jautājumu, nevis analizē visu failu pilnā apjomā.", "com_agents_grid_announcement": "Rādu {{count}} aģentus {{category}} kategorijā", "com_agents_instructions_placeholder": "Sistēmas instrukcijas, ko izmantos aģents", @@ -126,7 +126,7 @@ "com_assistants_delete_actions_success": "Darbība veiksmīgi dzēsta no asistenta", "com_assistants_description_placeholder": "Pēc izvēles: Šeit aprakstiet savu asistentu", "com_assistants_domain_info": "Asistents nosūtīja šo informāciju {{0}}", - "com_assistants_file_search": "Vektorizētā Meklēšana (RAG)", + "com_assistants_file_search": "Meklēšana dokumentos", "com_assistants_file_search_info": "Šī funkcija ļauj asistentam izmantot augšupielādēto failu saturu, pievienojot zināšanas tieši no lietotāja vai citu lietotāju failiem. Pēc faila augšupielādes asistents automātiski identificē un izgūst nepieciešamās teksta daļas atbilstoši lietotāja pieprasījumam, neiekļaujot visu failu pilnā apjomā. Vektoru datubāzu (vector store) pieslēgšana tieši šai funkcijai šobrīd nav atbalstīta; tās iespējams pievienot tikai Provider Playground vidē vai augšupielādējot failus sarunas pavedienam ikreizējai meklēšanai.", "com_assistants_function_use": "Izmantotais asistents {{0}}", "com_assistants_image_vision": "Attēla redzējums", @@ -136,7 +136,7 @@ "com_assistants_knowledge_info": "Ja augšupielādējat failus sadaļā Zināšanas, sarunās ar asistentu var tikt iekļauts faila saturs.", "com_assistants_max_starters_reached": "Sasniegts maksimālais sarunu uzsākšanas iespēju skaits", "com_assistants_name_placeholder": "Pēc izvēles: Asistenta nosaukums", - "com_assistants_non_retrieval_model": "Šajā modelī vektorizētā meklēšana nav iespējota. Lūdzu, izvēlieties citu modeli.", + "com_assistants_non_retrieval_model": "Šajā modelī meklēšana dokumentos nav iespējota. Lūdzu, izvēlieties citu modeli.", "com_assistants_retrieval": "Atgūšana", "com_assistants_running_action": "Darbība palaista", "com_assistants_running_var": "Strādā {{0}}", @@ -232,7 +232,7 @@ "com_endpoint_anthropic_thinking_budget": "Nosaka maksimālo žetonu skaitu, ko Claude drīkst izmantot savā iekšējā spriešanas procesā. Lielāki budžeti var uzlabot atbilžu kvalitāti, nodrošinot rūpīgāku analīzi sarežģītām problēmām, lai gan Claude var neizmantot visu piešķirto budžetu, īpaši diapazonos virs 32 000. Šim iestatījumam jābūt zemākam par \"Maksimālie izvades tokeni\".", "com_endpoint_anthropic_topk": "Top-k maina to, kā modelis atlasa marķierus izvadei. Ja top-k ir 1, tas nozīmē, ka atlasītais marķieris ir visticamākais starp visiem modeļa vārdu krājumā esošajiem marķieriem (to sauc arī par alkatīgo dekodēšanu), savukārt, ja top-k ir 3, tas nozīmē, ka nākamais marķieris tiek izvēlēts no 3 visticamākajiem marķieriem (izmantojot temperatūru).", "com_endpoint_anthropic_topp": "`Top-p` maina to, kā modelis atlasa marķierus izvadei. Marķieri tiek atlasīti no K (skatīt parametru topK) ticamākās līdz vismazāk ticamajai, līdz to varbūtību summa ir vienāda ar `top-p` vērtību.", - "com_endpoint_anthropic_use_web_search": "Iespējojiet tīmekļa meklēšanas funkcionalitāti, izmantojot Anthropic iebūvētās meklēšanas iespējas. Tas ļauj modelim meklēt tīmeklī jaunāko informāciju un sniegt precīzākas un aktuālākas atbildes.", + "com_endpoint_anthropic_use_web_search": "Iespējojiet meklēšanu tīmeklī funkcionalitāti, izmantojot Anthropic iebūvētās meklēšanas iespējas. Tas ļauj modelim meklēt tīmeklī jaunāko informāciju un sniegt precīzākas un aktuālākas atbildes.", "com_endpoint_assistant": "Asistents", "com_endpoint_assistant_model": "Asistenta modelis", "com_endpoint_assistant_placeholder": "Lūdzu, labajā sānu panelī atlasiet asistentu.", @@ -1486,7 +1486,7 @@ "com_ui_version_var": "Versija {{0}}", "com_ui_versions": "Versijas", "com_ui_view_memory": "Skatīt atmiņu", - "com_ui_web_search": "Tīmekļa meklēšana", + "com_ui_web_search": "Meklēšana tīmeklī", "com_ui_web_search_cohere_key": "Ievadiet Cohere API atslēgu", "com_ui_web_search_firecrawl_url": "Firecrawl API URL (pēc izvēles)", "com_ui_web_search_jina_key": "Ievadiet Jina API atslēgu", From 3ddf62c8e5511a2c30672dbe3e6e07bedff374e6 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 12 Mar 2026 20:43:23 -0400 Subject: [PATCH 08/39] =?UTF-8?q?=F0=9F=AB=99=20fix:=20Force=20MeiliSearch?= =?UTF-8?q?=20Full=20Sync=20on=20Empty=20Index=20State=20(#12202)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: meili index sync with unindexed documents - Updated `performSync` function to force a full sync when a fresh MeiliSearch index is detected, even if the number of unindexed messages or convos is below the sync threshold. - Added logging to indicate when a fresh index is detected and a full sync is initiated. - Introduced new tests to validate the behavior of the sync logic under various conditions, ensuring proper handling of fresh indexes and threshold scenarios. This change improves the reliability of the synchronization process, ensuring that all documents are indexed correctly when starting with a fresh index. * refactor: update sync logic for unindexed documents in MeiliSearch - Renamed variables in `performSync` to improve clarity, changing `freshIndex` to `noneIndexed` for better understanding of the sync condition. - Adjusted the logic to ensure a full sync is forced when no messages or conversations are marked as indexed, even if below the sync threshold. - Updated related tests to reflect the new logging messages and conditions, enhancing the accuracy of the sync threshold logic. This change improves the readability and reliability of the synchronization process, ensuring all documents are indexed correctly when starting with a fresh index. * fix: enhance MeiliSearch index creation error handling - Updated the `mongoMeili` function to improve logging and error handling during index creation in MeiliSearch. - Added handling for `MeiliSearchTimeOutError` to log a warning when index creation times out. - Enhanced logging to differentiate between successful index creation and specific failure reasons, including cases where the index already exists. - Improved debug logging for index creation tasks to provide clearer insights into the process. This change enhances the robustness of the index creation process and improves observability for troubleshooting. * fix: update MeiliSearch index creation error handling - Modified the `mongoMeili` function to check for any status other than 'succeeded' during index creation, enhancing error detection. - Improved logging to provide clearer insights when an index creation task fails, particularly for cases where the index already exists. This change strengthens the error handling mechanism for index creation in MeiliSearch, ensuring better observability and reliability. --- api/db/indexSync.js | 14 +++- api/db/indexSync.spec.js | 65 +++++++++++++++++++ .../src/models/plugins/mongoMeili.ts | 31 ++++++--- 3 files changed, 99 insertions(+), 11 deletions(-) diff --git a/api/db/indexSync.js b/api/db/indexSync.js index 8e8e999d92..130cde77b8 100644 --- a/api/db/indexSync.js +++ b/api/db/indexSync.js @@ -236,8 +236,12 @@ async function performSync(flowManager, flowId, flowType) { const messageCount = messageProgress.totalDocuments; const messagesIndexed = messageProgress.totalProcessed; const unindexedMessages = messageCount - messagesIndexed; + const noneIndexed = messagesIndexed === 0 && unindexedMessages > 0; - if (settingsUpdated || unindexedMessages > syncThreshold) { + if (settingsUpdated || noneIndexed || unindexedMessages > syncThreshold) { + if (noneIndexed && !settingsUpdated) { + logger.info('[indexSync] No messages marked as indexed, forcing full sync'); + } logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`); await Message.syncWithMeili(); messagesSync = true; @@ -261,9 +265,13 @@ async function performSync(flowManager, flowId, flowType) { const convoCount = convoProgress.totalDocuments; const convosIndexed = convoProgress.totalProcessed; - const unindexedConvos = convoCount - convosIndexed; - if (settingsUpdated || unindexedConvos > syncThreshold) { + const noneConvosIndexed = convosIndexed === 0 && unindexedConvos > 0; + + if (settingsUpdated || noneConvosIndexed || unindexedConvos > syncThreshold) { + if (noneConvosIndexed && !settingsUpdated) { + logger.info('[indexSync] No conversations marked as indexed, forcing full sync'); + } logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`); await Conversation.syncWithMeili(); convosSync = true; diff --git a/api/db/indexSync.spec.js b/api/db/indexSync.spec.js index c2e5901d6a..dbe07c7595 100644 --- a/api/db/indexSync.spec.js +++ b/api/db/indexSync.spec.js @@ -462,4 +462,69 @@ describe('performSync() - syncThreshold logic', () => { ); expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)'); }); + + test('forces sync when zero documents indexed (reset scenario) even if below threshold', async () => { + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 0, + totalDocuments: 680, + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 0, + totalDocuments: 76, + isComplete: false, + }); + + Message.syncWithMeili.mockResolvedValue(undefined); + Conversation.syncWithMeili.mockResolvedValue(undefined); + + const indexSync = require('./indexSync'); + await indexSync(); + + expect(Message.syncWithMeili).toHaveBeenCalledTimes(1); + expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] No messages marked as indexed, forcing full sync', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting message sync (680 unindexed)', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] No conversations marked as indexed, forcing full sync', + ); + expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (76 unindexed)'); + }); + + test('does NOT force sync when some documents already indexed and below threshold', async () => { + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 630, + totalDocuments: 680, + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 70, + totalDocuments: 76, + isComplete: false, + }); + + const indexSync = require('./indexSync'); + await indexSync(); + + expect(Message.syncWithMeili).not.toHaveBeenCalled(); + expect(Conversation.syncWithMeili).not.toHaveBeenCalled(); + expect(mockLogger.info).not.toHaveBeenCalledWith( + '[indexSync] No messages marked as indexed, forcing full sync', + ); + expect(mockLogger.info).not.toHaveBeenCalledWith( + '[indexSync] No conversations marked as indexed, forcing full sync', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] 50 messages unindexed (below threshold: 1000, skipping)', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] 6 convos unindexed (below threshold: 1000, skipping)', + ); + }); }); diff --git a/packages/data-schemas/src/models/plugins/mongoMeili.ts b/packages/data-schemas/src/models/plugins/mongoMeili.ts index 66530e2aba..cc01dbb6c7 100644 --- a/packages/data-schemas/src/models/plugins/mongoMeili.ts +++ b/packages/data-schemas/src/models/plugins/mongoMeili.ts @@ -1,7 +1,7 @@ import _ from 'lodash'; -import { MeiliSearch } from 'meilisearch'; import { parseTextParts } from 'librechat-data-provider'; -import type { SearchResponse, SearchParams, Index } from 'meilisearch'; +import { MeiliSearch, MeiliSearchTimeOutError } from 'meilisearch'; +import type { SearchResponse, SearchParams, Index, MeiliSearchErrorInfo } from 'meilisearch'; import type { CallbackWithoutResultAndOptionalError, FilterQuery, @@ -581,7 +581,6 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions): /** Create index only if it doesn't exist */ const index = client.index(indexName); - // Check if index exists and create if needed (async () => { try { await index.getRawInfo(); @@ -591,18 +590,34 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions): if (errorCode === 'index_not_found') { try { logger.info(`[mongoMeili] Creating new index: ${indexName}`); - await client.createIndex(indexName, { primaryKey }); - logger.info(`[mongoMeili] Successfully created index: ${indexName}`); + const enqueued = await client.createIndex(indexName, { primaryKey }); + const task = await client.waitForTask(enqueued.taskUid, { + timeOutMs: 10000, + intervalMs: 100, + }); + logger.debug(`[mongoMeili] Index ${indexName} creation task:`, task); + if (task.status !== 'succeeded') { + const taskError = task.error as MeiliSearchErrorInfo | null; + if (taskError?.code === 'index_already_exists') { + logger.debug(`[mongoMeili] Index ${indexName} was created by another instance`); + } else { + logger.warn(`[mongoMeili] Index ${indexName} creation failed:`, taskError); + } + } else { + logger.info(`[mongoMeili] Successfully created index: ${indexName}`); + } } catch (createError) { - // Index might have been created by another instance - logger.debug(`[mongoMeili] Index ${indexName} may already exist:`, createError); + if (createError instanceof MeiliSearchTimeOutError) { + logger.warn(`[mongoMeili] Timed out waiting for index ${indexName} creation`); + } else { + logger.warn(`[mongoMeili] Error creating index ${indexName}:`, createError); + } } } else { logger.error(`[mongoMeili] Error checking index ${indexName}:`, error); } } - // Configure index settings to make 'user' field filterable try { await index.updateSettings({ filterableAttributes: ['user'], From 65b0bfde1b07141b4d2fefd246379c6102b5fed4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:48:05 -0400 Subject: [PATCH 09/39] =?UTF-8?q?=F0=9F=8C=8D=20i18n:=20Update=20translati?= =?UTF-8?q?on.json=20with=20latest=20translations=20(#12203)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- client/src/locales/fr/translation.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/src/locales/fr/translation.json b/client/src/locales/fr/translation.json index c9d78ac3f5..7838b33739 100644 --- a/client/src/locales/fr/translation.json +++ b/client/src/locales/fr/translation.json @@ -1203,7 +1203,7 @@ "com_ui_upload_image_input": "Téléverser une image", "com_ui_upload_invalid": "Fichier non valide pour le téléchargement. L'image ne doit pas dépasser la limite", "com_ui_upload_invalid_var": "Fichier non valide pour le téléchargement. L'image ne doit pas dépasser {{0}} Mo", - "com_ui_upload_ocr_text": "Téléchager en tant que texte", + "com_ui_upload_ocr_text": "Télécharger en tant que texte", "com_ui_upload_provider": "Télécharger vers le fournisseur", "com_ui_upload_success": "Fichier téléversé avec succès", "com_ui_upload_type": "Sélectionner le type de téléversement", From f32907cd362c9d87b362661cd30a8f8a718fc864 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 12 Mar 2026 23:19:31 -0400 Subject: [PATCH 10/39] =?UTF-8?q?=F0=9F=94=8F=20fix:=20MCP=20Server=20URL?= =?UTF-8?q?=20Schema=20Validation=20(#12204)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: MCP server configuration validation and schema - Added tests to reject URLs containing environment variable references for SSE, streamable-http, and websocket types in the MCP routes. - Introduced a new schema in the data provider to ensure user input URLs do not resolve environment variables, enhancing security against potential leaks. - Updated existing MCP server user input schema to utilize the new validation logic, ensuring consistent handling of user-supplied URLs across the application. * fix: MCP URL validation to reject env variable references - Updated tests to ensure that URLs for SSE, streamable-http, and websocket types containing environment variable patterns are rejected, improving security against potential leaks. - Refactored the MCP server user input schema to enforce stricter validation rules, preventing the resolution of environment variables in user-supplied URLs. - Introduced new test cases for various URL types to validate the rejection logic, ensuring consistent handling across the application. * test: Enhance MCPServerUserInputSchema tests for environment variable handling - Introduced new test cases to validate the prevention of environment variable exfiltration through user input URLs in the MCPServerUserInputSchema. - Updated existing tests to confirm that URLs containing environment variable patterns are correctly resolved or rejected, improving security against potential leaks. - Refactored test structure to better organize environment variable handling scenarios, ensuring comprehensive coverage of edge cases. --- api/server/routes/__tests__/mcp.spec.js | 90 ++++++++++++++ packages/data-provider/specs/mcp.spec.ts | 147 +++++++++++++++++++++++ packages/data-provider/src/mcp.ts | 35 +++++- 3 files changed, 269 insertions(+), 3 deletions(-) create mode 100644 packages/data-provider/specs/mcp.spec.ts diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 009b602604..e0cb680169 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1819,6 +1819,51 @@ describe('MCP Routes', () => { expect(response.body.message).toBe('Invalid configuration'); }); + it('should reject SSE URL containing env variable references', async () => { + const response = await request(app) + .post('/api/mcp/servers') + .send({ + config: { + type: 'sse', + url: 'http://attacker.com/?secret=${JWT_SECRET}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); + }); + + it('should reject streamable-http URL containing env variable references', async () => { + const response = await request(app) + .post('/api/mcp/servers') + .send({ + config: { + type: 'streamable-http', + url: 'http://attacker.com/?key=${CREDS_KEY}&iv=${CREDS_IV}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); + }); + + it('should reject websocket URL containing env variable references', async () => { + const response = await request(app) + .post('/api/mcp/servers') + .send({ + config: { + type: 'websocket', + url: 'ws://attacker.com/?secret=${MONGO_URI}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); + }); + it('should return 500 when registry throws error', async () => { const validConfig = { type: 'sse', @@ -1918,6 +1963,51 @@ describe('MCP Routes', () => { expect(response.body.errors).toBeDefined(); }); + it('should reject SSE URL containing env variable references', async () => { + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ + config: { + type: 'sse', + url: 'http://attacker.com/?secret=${JWT_SECRET}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled(); + }); + + it('should reject streamable-http URL containing env variable references', async () => { + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ + config: { + type: 'streamable-http', + url: 'http://attacker.com/?key=${CREDS_KEY}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled(); + }); + + it('should reject websocket URL containing env variable references', async () => { + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ + config: { + type: 'websocket', + url: 'ws://attacker.com/?secret=${MONGO_URI}', + }, + }); + + expect(response.status).toBe(400); + expect(response.body.message).toBe('Invalid configuration'); + expect(mockRegistryInstance.updateServer).not.toHaveBeenCalled(); + }); + it('should return 500 when registry throws error', async () => { const validConfig = { type: 'sse', diff --git a/packages/data-provider/specs/mcp.spec.ts b/packages/data-provider/specs/mcp.spec.ts new file mode 100644 index 0000000000..573769c4fa --- /dev/null +++ b/packages/data-provider/specs/mcp.spec.ts @@ -0,0 +1,147 @@ +import { SSEOptionsSchema, MCPServerUserInputSchema } from '../src/mcp'; + +describe('MCPServerUserInputSchema', () => { + describe('env variable exfiltration prevention', () => { + it('should confirm admin schema resolves env vars (attack vector baseline)', () => { + process.env.FAKE_SECRET = 'leaked-secret-value'; + const adminResult = SSEOptionsSchema.safeParse({ + type: 'sse', + url: 'http://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(adminResult.success).toBe(true); + if (adminResult.success) { + expect(adminResult.data.url).toContain('leaked-secret-value'); + } + delete process.env.FAKE_SECRET; + }); + + it('should reject the same URL through user input schema', () => { + process.env.FAKE_SECRET = 'leaked-secret-value'; + const userResult = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'http://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(userResult.success).toBe(false); + delete process.env.FAKE_SECRET; + }); + }); + + describe('env variable rejection', () => { + it('should reject SSE URLs containing env variable patterns', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'http://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(result.success).toBe(false); + }); + + it('should reject streamable-http URLs containing env variable patterns', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'http://attacker.com/?jwt=${JWT_SECRET}', + }); + expect(result.success).toBe(false); + }); + + it('should reject WebSocket URLs containing env variable patterns', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'ws://attacker.com/?secret=${FAKE_SECRET}', + }); + expect(result.success).toBe(false); + }); + }); + + describe('protocol allowlisting', () => { + it('should reject file:// URLs for SSE', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'file:///etc/passwd', + }); + expect(result.success).toBe(false); + }); + + it('should reject ftp:// URLs for streamable-http', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'ftp://internal-server/data', + }); + expect(result.success).toBe(false); + }); + + it('should reject http:// URLs for WebSocket', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'http://example.com/ws', + }); + expect(result.success).toBe(false); + }); + + it('should reject ws:// URLs for SSE', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'ws://example.com/sse', + }); + expect(result.success).toBe(false); + }); + }); + + describe('valid URL acceptance', () => { + it('should accept valid https:// SSE URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'https://mcp-server.com/sse', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.url).toBe('https://mcp-server.com/sse'); + } + }); + + it('should accept valid http:// SSE URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'sse', + url: 'http://mcp-server.com/sse', + }); + expect(result.success).toBe(true); + }); + + it('should accept valid wss:// WebSocket URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'wss://mcp-server.com/ws', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.url).toBe('wss://mcp-server.com/ws'); + } + }); + + it('should accept valid ws:// WebSocket URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'websocket', + url: 'ws://mcp-server.com/ws', + }); + expect(result.success).toBe(true); + }); + + it('should accept valid https:// streamable-http URLs', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'https://mcp-server.com/http', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.url).toBe('https://mcp-server.com/http'); + } + }); + + it('should accept valid http:// streamable-http URLs with "http" alias', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'http', + url: 'http://mcp-server.com/mcp', + }); + expect(result.success).toBe(true); + }); + }); +}); diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index 3911e91ed0..3ad296c4ec 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -223,6 +223,23 @@ const omitServerManagedFields = >(schema: T oauth_headers: true, }); +const envVarPattern = /\$\{[^}]+\}/; +const isWsProtocol = (val: string): boolean => /^wss?:/i.test(val); +const isHttpProtocol = (val: string): boolean => /^https?:/i.test(val); + +/** + * Builds a URL schema for user input that rejects ${VAR} env variable patterns + * and validates protocol constraints without resolving environment variables. + */ +const userUrlSchema = (protocolCheck: (val: string) => boolean, message: string) => + z + .string() + .refine((val) => !envVarPattern.test(val), { + message: 'Environment variable references are not allowed in URLs', + }) + .pipe(z.string().url()) + .refine(protocolCheck, { message }); + /** * MCP Server configuration that comes from UI/API input only. * Omits server-managed fields like startup, timeout, customUserVars, etc. @@ -232,11 +249,23 @@ const omitServerManagedFields = >(schema: T * Stdio allows arbitrary command execution and should only be configured * by administrators via the YAML config file (librechat.yaml). * Only remote transports (SSE, HTTP, WebSocket) are allowed via the API. + * + * SECURITY: URL fields use userUrlSchema instead of the admin schemas' + * extractEnvVariable transform to prevent env variable exfiltration + * through user-controlled URLs (e.g. http://attacker.com/?k=${JWT_SECRET}). + * Protocol checks use positive allowlists (http(s) / ws(s)) to block + * file://, ftp://, javascript:, and other non-network schemes. */ export const MCPServerUserInputSchema = z.union([ - omitServerManagedFields(WebSocketOptionsSchema), - omitServerManagedFields(SSEOptionsSchema), - omitServerManagedFields(StreamableHTTPOptionsSchema), + omitServerManagedFields(WebSocketOptionsSchema).extend({ + url: userUrlSchema(isWsProtocol, 'WebSocket URL must use ws:// or wss://'), + }), + omitServerManagedFields(SSEOptionsSchema).extend({ + url: userUrlSchema(isHttpProtocol, 'SSE URL must use http:// or https://'), + }), + omitServerManagedFields(StreamableHTTPOptionsSchema).extend({ + url: userUrlSchema(isHttpProtocol, 'Streamable HTTP URL must use http:// or https://'), + }), ]); export type MCPServerUserInput = z.infer; From fa9e1b228a09fb02541068902635a97686eb32cc Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 13 Mar 2026 23:18:56 -0400 Subject: [PATCH 11/39] =?UTF-8?q?=F0=9F=AA=AA=20fix:=20MCP=20API=20Respons?= =?UTF-8?q?es=20and=20OAuth=20Validation=20(#12217)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔒 fix: Validate MCP Configs in Server Responses * 🔒 fix: Enhance OAuth URL Validation in MCPOAuthHandler - Introduced validation for OAuth URLs to ensure they do not target private or internal addresses, enhancing security against SSRF attacks. - Updated the OAuth flow to validate both authorization and token URLs before use, ensuring compliance with security standards. - Refactored redirect URI handling to streamline the OAuth client registration process. - Added comprehensive error handling for invalid URLs, improving robustness in OAuth interactions. * 🔒 feat: Implement Permission Checks for MCP Server Management - Added permission checkers for MCP server usage and creation, enhancing access control. - Updated routes for reinitializing MCP servers and retrieving authentication values to include these permission checks, ensuring only authorized users can access these functionalities. - Refactored existing permission logic to improve clarity and maintainability. * 🔒 fix: Enhance MCP Server Response Validation and Redaction - Updated MCP route tests to use `toMatchObject` for better validation of server response structures, ensuring consistency in expected properties. - Refactored the `redactServerSecrets` function to streamline the removal of sensitive information, ensuring that user-sourced API keys are properly redacted while retaining their source. - Improved OAuth security tests to validate rejection of private URLs across multiple endpoints, enhancing protection against SSRF vulnerabilities. - Added comprehensive tests for the `redactServerSecrets` function to ensure proper handling of various server configurations, reinforcing security measures. * chore: eslint * 🔒 fix: Enhance OAuth Server URL Validation in MCPOAuthHandler - Added validation for discovered authorization server URLs to ensure they meet security standards. - Improved logging to provide clearer insights when an authorization server is found from resource metadata. - Refactored the handling of authorization server URLs to enhance robustness against potential security vulnerabilities. * 🔒 test: Bypass SSRF validation for MCP OAuth Flow tests - Mocked SSRF validation functions to allow tests to use real local HTTP servers, facilitating more accurate testing of the MCP OAuth flow. - Updated test setup to ensure compatibility with the new mocking strategy, enhancing the reliability of the tests. * 🔒 fix: Add Validation for OAuth Metadata Endpoints in MCPOAuthHandler - Implemented checks for the presence and validity of registration and token endpoints in the OAuth metadata, enhancing security by ensuring that these URLs are properly validated before use. - Improved error handling and logging to provide better insights during the OAuth metadata processing, reinforcing the robustness of the OAuth flow. * 🔒 refactor: Simplify MCP Auth Values Endpoint Logic - Removed redundant permission checks for accessing the MCP server resource in the auth-values endpoint, streamlining the request handling process. - Consolidated error handling and response structure for improved clarity and maintainability. - Enhanced logging for better insights during the authentication value checks, reinforcing the robustness of the endpoint. * 🔒 test: Refactor LeaderElection Integration Tests for Improved Cleanup - Moved Redis key cleanup to the beforeEach hook to ensure a clean state before each test. - Enhanced afterEach logic to handle instance resignations and Redis key deletion more robustly, improving test reliability and maintainability. --- api/server/controllers/mcp.js | 14 +- api/server/routes/__tests__/mcp.spec.js | 118 ++++++++- api/server/routes/mcp.js | 143 +++++------ .../LeaderElection.cache_integration.spec.ts | 18 +- .../src/mcp/__tests__/MCPOAuthFlow.test.ts | 7 + .../mcp/__tests__/MCPOAuthSecurity.test.ts | 228 ++++++++++++++++++ packages/api/src/mcp/__tests__/utils.test.ts | 201 ++++++++++++++- packages/api/src/mcp/oauth/handler.ts | 60 ++++- .../__tests__/ServerConfigsDB.test.ts | 98 ++++++++ packages/api/src/mcp/utils.ts | 60 +++++ 10 files changed, 845 insertions(+), 102 deletions(-) create mode 100644 packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index e5dfff61ca..729f01da9d 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -7,9 +7,11 @@ */ const { logger } = require('@librechat/data-schemas'); const { + MCPErrorCodes, + redactServerSecrets, + redactAllServerSecrets, isMCPDomainNotAllowedError, isMCPInspectionFailedError, - MCPErrorCodes, } = require('@librechat/api'); const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider'); const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config'); @@ -181,10 +183,8 @@ const getMCPServersList = async (req, res) => { return res.status(401).json({ message: 'Unauthorized' }); } - // 2. Get all server configs from registry (YAML + DB) const serverConfigs = await getMCPServersRegistry().getAllServerConfigs(userId); - - return res.json(serverConfigs); + return res.json(redactAllServerSecrets(serverConfigs)); } catch (error) { logger.error('[getMCPServersList]', error); res.status(500).json({ error: error.message }); @@ -215,7 +215,7 @@ const createMCPServerController = async (req, res) => { ); res.status(201).json({ serverName: result.serverName, - ...result.config, + ...redactServerSecrets(result.config), }); } catch (error) { logger.error('[createMCPServer]', error); @@ -243,7 +243,7 @@ const getMCPServerById = async (req, res) => { return res.status(404).json({ message: 'MCP server not found' }); } - res.status(200).json(parsedConfig); + res.status(200).json(redactServerSecrets(parsedConfig)); } catch (error) { logger.error('[getMCPServerById]', error); res.status(500).json({ message: error.message }); @@ -274,7 +274,7 @@ const updateMCPServerController = async (req, res) => { userId, ); - res.status(200).json(parsedConfig); + res.status(200).json(redactServerSecrets(parsedConfig)); } catch (error) { logger.error('[updateMCPServer]', error); const mcpErrorResponse = handleMCPError(error, res); diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index e0cb680169..1ad8cac087 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1693,12 +1693,14 @@ describe('MCP Routes', () => { it('should return all server configs for authenticated user', async () => { const mockServerConfigs = { 'server-1': { - endpoint: 'http://server1.com', - name: 'Server 1', + type: 'sse', + url: 'http://server1.com/sse', + title: 'Server 1', }, 'server-2': { - endpoint: 'http://server2.com', - name: 'Server 2', + type: 'sse', + url: 'http://server2.com/sse', + title: 'Server 2', }, }; @@ -1707,7 +1709,18 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/servers'); expect(response.status).toBe(200); - expect(response.body).toEqual(mockServerConfigs); + expect(response.body['server-1']).toMatchObject({ + type: 'sse', + url: 'http://server1.com/sse', + title: 'Server 1', + }); + expect(response.body['server-2']).toMatchObject({ + type: 'sse', + url: 'http://server2.com/sse', + title: 'Server 2', + }); + expect(response.body['server-1'].headers).toBeUndefined(); + expect(response.body['server-2'].headers).toBeUndefined(); expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith('test-user-id'); }); @@ -1762,10 +1775,10 @@ describe('MCP Routes', () => { const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); expect(response.status).toBe(201); - expect(response.body).toEqual({ - serverName: 'test-sse-server', - ...validConfig, - }); + expect(response.body.serverName).toBe('test-sse-server'); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://mcp-server.example.com/sse'); + expect(response.body.title).toBe('Test SSE Server'); expect(mockRegistryInstance.addServer).toHaveBeenCalledWith( 'temp_server_name', expect.objectContaining({ @@ -1864,6 +1877,33 @@ describe('MCP Routes', () => { expect(mockRegistryInstance.addServer).not.toHaveBeenCalled(); }); + it('should redact secrets from create response', async () => { + const validConfig = { + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Test Server', + }; + + mockRegistryInstance.addServer.mockResolvedValue({ + serverName: 'test-server', + config: { + ...validConfig, + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'admin-secret-key' }, + oauth: { client_id: 'cid', client_secret: 'admin-oauth-secret' }, + headers: { Authorization: 'Bearer leaked-token' }, + }, + }); + + const response = await request(app).post('/api/mcp/servers').send({ config: validConfig }); + + expect(response.status).toBe(201); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.headers).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_id).toBe('cid'); + }); + it('should return 500 when registry throws error', async () => { const validConfig = { type: 'sse', @@ -1893,7 +1933,9 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/servers/test-server'); expect(response.status).toBe(200); - expect(response.body).toEqual(mockConfig); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://mcp-server.example.com/sse'); + expect(response.body.title).toBe('Test Server'); expect(mockRegistryInstance.getServerConfig).toHaveBeenCalledWith( 'test-server', 'test-user-id', @@ -1909,6 +1951,29 @@ describe('MCP Routes', () => { expect(response.body).toEqual({ message: 'MCP server not found' }); }); + it('should redact secrets from get response', async () => { + mockRegistryInstance.getServerConfig.mockResolvedValue({ + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Secret Server', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'decrypted-admin-key' }, + oauth: { client_id: 'cid', client_secret: 'decrypted-oauth-secret' }, + headers: { Authorization: 'Bearer internal-token' }, + oauth_headers: { 'X-OAuth': 'secret-value' }, + }); + + const response = await request(app).get('/api/mcp/servers/secret-server'); + + expect(response.status).toBe(200); + expect(response.body.title).toBe('Secret Server'); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.oauth?.client_id).toBe('cid'); + expect(response.body.headers).toBeUndefined(); + expect(response.body.oauth_headers).toBeUndefined(); + }); + it('should return 500 when registry throws error', async () => { mockRegistryInstance.getServerConfig.mockRejectedValue(new Error('Database error')); @@ -1935,7 +2000,9 @@ describe('MCP Routes', () => { .send({ config: updatedConfig }); expect(response.status).toBe(200); - expect(response.body).toEqual(updatedConfig); + expect(response.body.type).toBe('sse'); + expect(response.body.url).toBe('https://updated-mcp-server.example.com/sse'); + expect(response.body.title).toBe('Updated Server'); expect(mockRegistryInstance.updateServer).toHaveBeenCalledWith( 'test-server', expect.objectContaining({ @@ -1947,6 +2014,35 @@ describe('MCP Routes', () => { ); }); + it('should redact secrets from update response', async () => { + const validConfig = { + type: 'sse', + url: 'https://mcp-server.example.com/sse', + title: 'Updated Server', + }; + + mockRegistryInstance.updateServer.mockResolvedValue({ + ...validConfig, + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'preserved-admin-key' }, + oauth: { client_id: 'cid', client_secret: 'preserved-oauth-secret' }, + headers: { Authorization: 'Bearer internal-token' }, + env: { DATABASE_URL: 'postgres://admin:pass@localhost/db' }, + }); + + const response = await request(app) + .patch('/api/mcp/servers/test-server') + .send({ config: validConfig }); + + expect(response.status).toBe(200); + expect(response.body.title).toBe('Updated Server'); + expect(response.body.apiKey?.key).toBeUndefined(); + expect(response.body.apiKey?.source).toBe('admin'); + expect(response.body.oauth?.client_secret).toBeUndefined(); + expect(response.body.oauth?.client_id).toBe('cid'); + expect(response.body.headers).toBeUndefined(); + expect(response.body.env).toBeUndefined(); + }); + it('should return 400 for invalid configuration', async () => { const invalidConfig = { type: 'sse', diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 0afac81192..57a99d199a 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -50,6 +50,18 @@ const router = Router(); const OAUTH_CSRF_COOKIE_PATH = '/api/mcp'; +const checkMCPUsePermissions = generateCheckAccess({ + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkMCPCreate = generateCheckAccess({ + permissionType: PermissionTypes.MCP_SERVERS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); + /** * Get all MCP tools available to the user * Returns only MCP tools, completely decoupled from regular LibreChat tools @@ -470,69 +482,75 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { * Reinitialize MCP server * This endpoint allows reinitializing a specific MCP server */ -router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => { - try { - const { serverName } = req.params; - const user = createSafeUser(req.user); +router.post( + '/:serverName/reinitialize', + requireJwtAuth, + checkMCPUsePermissions, + setOAuthSession, + async (req, res) => { + try { + const { serverName } = req.params; + const user = createSafeUser(req.user); - if (!user.id) { - return res.status(401).json({ error: 'User not authenticated' }); - } + if (!user.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } - logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); + logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); - const mcpManager = getMCPManager(); - const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); - if (!serverConfig) { - return res.status(404).json({ - error: `MCP server '${serverName}' not found in configuration`, + const mcpManager = getMCPManager(); + const serverConfig = await getMCPServersRegistry().getServerConfig(serverName, user.id); + if (!serverConfig) { + return res.status(404).json({ + error: `MCP server '${serverName}' not found in configuration`, + }); + } + + await mcpManager.disconnectUserConnection(user.id, serverName); + logger.info( + `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, + ); + + /** @type {Record> | undefined} */ + let userMCPAuthMap; + if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { + userMCPAuthMap = await getUserMCPAuthMap({ + userId: user.id, + servers: [serverName], + findPluginAuthsByKeys, + }); + } + + const result = await reinitMCPServer({ + user, + serverName, + userMCPAuthMap, }); - } - await mcpManager.disconnectUserConnection(user.id, serverName); - logger.info( - `[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`, - ); + if (!result) { + return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); + } - /** @type {Record> | undefined} */ - let userMCPAuthMap; - if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') { - userMCPAuthMap = await getUserMCPAuthMap({ - userId: user.id, - servers: [serverName], - findPluginAuthsByKeys, + const { success, message, oauthRequired, oauthUrl } = result; + + if (oauthRequired) { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + } + + res.json({ + success, + message, + oauthUrl, + serverName, + oauthRequired, }); + } catch (error) { + logger.error('[MCP Reinitialize] Unexpected error', error); + res.status(500).json({ error: 'Internal server error' }); } - - const result = await reinitMCPServer({ - user, - serverName, - userMCPAuthMap, - }); - - if (!result) { - return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' }); - } - - const { success, message, oauthRequired, oauthUrl } = result; - - if (oauthRequired) { - const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); - setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); - } - - res.json({ - success, - message, - oauthUrl, - serverName, - oauthRequired, - }); - } catch (error) { - logger.error('[MCP Reinitialize] Unexpected error', error); - res.status(500).json({ error: 'Internal server error' }); - } -}); + }, +); /** * Get connection status for all MCP servers @@ -639,7 +657,7 @@ router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => * Check which authentication values exist for a specific MCP server * This endpoint returns only boolean flags indicating if values are set, not the actual values */ -router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { +router.get('/:serverName/auth-values', requireJwtAuth, checkMCPUsePermissions, async (req, res) => { try { const { serverName } = req.params; const user = req.user; @@ -696,19 +714,6 @@ async function getOAuthHeaders(serverName, userId) { MCP Server CRUD Routes (User-Managed MCP Servers) */ -// Permission checkers for MCP server management -const checkMCPUsePermissions = generateCheckAccess({ - permissionType: PermissionTypes.MCP_SERVERS, - permissions: [Permissions.USE], - getRoleByName, -}); - -const checkMCPCreate = generateCheckAccess({ - permissionType: PermissionTypes.MCP_SERVERS, - permissions: [Permissions.USE, Permissions.CREATE], - getRoleByName, -}); - /** * Get list of accessible MCP servers * @route GET /api/mcp/servers diff --git a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts index 9bad4dcfac..f1558db795 100644 --- a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts +++ b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts @@ -32,14 +32,22 @@ describe('LeaderElection with Redis', () => { process.setMaxListeners(200); }); - afterEach(async () => { - await Promise.all(instances.map((instance) => instance.resign())); - instances = []; - - // Clean up: clear the leader key directly from Redis + beforeEach(async () => { if (keyvRedisClient) { await keyvRedisClient.del(LeaderElection.LEADER_KEY); } + new LeaderElection().clearRefreshTimer(); + }); + + afterEach(async () => { + try { + await Promise.all(instances.map((instance) => instance.resign())); + } finally { + instances = []; + if (keyvRedisClient) { + await keyvRedisClient.del(LeaderElection.LEADER_KEY); + } + } }); afterAll(async () => { diff --git a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts index 8437177c86..f73a5ed3e8 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts @@ -24,6 +24,13 @@ jest.mock('@librechat/data-schemas', () => ({ decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), })); +/** Bypass SSRF validation — these tests use real local HTTP servers. */ +jest.mock('~/auth', () => ({ + ...jest.requireActual('~/auth'), + isSSRFTarget: jest.fn(() => false), + resolveHostnameSSRF: jest.fn(async () => false), +})); + describe('MCP OAuth Flow — Real HTTP Server', () => { afterEach(() => { jest.clearAllMocks(); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts new file mode 100644 index 0000000000..a5188e24b0 --- /dev/null +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -0,0 +1,228 @@ +/** + * Tests verifying MCP OAuth security hardening: + * + * 1. SSRF via OAuth URLs — validates that the OAuth handler rejects + * token_url, authorization_url, and revocation_endpoint values + * pointing to private/internal addresses. + * + * 2. redirect_uri manipulation — validates that user-supplied redirect_uri + * is ignored in favor of the server-controlled default. + */ + +import * as http from 'http'; +import * as net from 'net'; +import { TokenExchangeMethodEnum } from 'librechat-data-provider'; +import type { Socket } from 'net'; +import type { OAuthTestServer } from './helpers/oauthTestServer'; +import { createOAuthMCPServer } from './helpers/oauthTestServer'; +import { MCPOAuthHandler } from '~/mcp/oauth'; + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, + encryptV2: jest.fn(async (val: string) => `enc:${val}`), + decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')), +})); + +/** + * Mock only the DNS-dependent resolveHostnameSSRF; keep isSSRFTarget real. + * SSRF tests use literal private IPs (127.0.0.1, 169.254.169.254, 10.0.0.1) + * which are caught by isSSRFTarget before resolveHostnameSSRF is reached. + * This avoids non-deterministic DNS lookups in test execution. + */ +jest.mock('~/auth', () => ({ + ...jest.requireActual('~/auth'), + resolveHostnameSSRF: jest.fn(async () => false), +})); + +function getFreePort(): Promise { + return new Promise((resolve, reject) => { + const srv = net.createServer(); + srv.listen(0, '127.0.0.1', () => { + const addr = srv.address() as net.AddressInfo; + srv.close((err) => (err ? reject(err) : resolve(addr.port))); + }); + }); +} + +function trackSockets(httpServer: http.Server): () => Promise { + const sockets = new Set(); + httpServer.on('connection', (socket: Socket) => { + sockets.add(socket); + socket.once('close', () => sockets.delete(socket)); + }); + return () => + new Promise((resolve) => { + for (const socket of sockets) { + socket.destroy(); + } + sockets.clear(); + httpServer.close(() => resolve()); + }); +} + +describe('MCP OAuth SSRF protection', () => { + let oauthServer: OAuthTestServer; + let ssrfTargetServer: http.Server; + let ssrfTargetPort: number; + let ssrfRequestReceived: boolean; + let destroySSRFSockets: () => Promise; + + beforeEach(async () => { + ssrfRequestReceived = false; + + oauthServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + + ssrfTargetPort = await getFreePort(); + ssrfTargetServer = http.createServer((_req, res) => { + ssrfRequestReceived = true; + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + access_token: 'ssrf-token', + token_type: 'Bearer', + expires_in: 3600, + }), + ); + }); + destroySSRFSockets = trackSockets(ssrfTargetServer); + await new Promise((resolve) => + ssrfTargetServer.listen(ssrfTargetPort, '127.0.0.1', resolve), + ); + }); + + afterEach(async () => { + try { + await oauthServer.close(); + } finally { + await destroySSRFSockets(); + } + }); + + it('should reject token_url pointing to a private IP (refreshOAuthTokens)', async () => { + const code = await oauthServer.getAuthCode(); + const tokenRes = await fetch(`${oauthServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const regRes = await fetch(`${oauthServer.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const clientInfo = (await regRes.json()) as { + client_id: string; + client_secret: string; + }; + + const ssrfTokenUrl = `http://127.0.0.1:${ssrfTargetPort}/latest/meta-data/iam/security-credentials/`; + + await expect( + MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'ssrf-test-server', + serverUrl: oauthServer.url, + clientInfo: { + ...clientInfo, + redirect_uris: ['http://localhost/callback'], + }, + }, + {}, + { + token_url: ssrfTokenUrl, + client_id: clientInfo.client_id, + client_secret: clientInfo.client_secret, + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + }, + ), + ).rejects.toThrow(/targets a blocked address/); + + expect(ssrfRequestReceived).toBe(false); + }); + + it('should reject private authorization_url in initiateOAuthFlow', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://169.254.169.254/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should reject private token_url in initiateOAuthFlow', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'https://auth.example.com/authorize', + token_url: `http://127.0.0.1:${ssrfTargetPort}/token`, + client_id: 'client', + client_secret: 'secret', + }, + ), + ).rejects.toThrow(/targets a blocked address/); + + expect(ssrfRequestReceived).toBe(false); + }); + + it('should reject private revocationEndpoint in revokeOAuthToken', async () => { + await expect( + MCPOAuthHandler.revokeOAuthToken('test-server', 'some-token', 'access', { + serverUrl: 'https://mcp.example.com/', + clientId: 'client', + clientSecret: 'secret', + revocationEndpoint: 'http://10.0.0.1/revoke', + }), + ).rejects.toThrow(/targets a blocked address/); + }); +}); + +describe('MCP OAuth redirect_uri enforcement', () => { + it('should ignore attacker-supplied redirect_uri and use the server default', async () => { + const attackerRedirectUri = 'https://attacker.example.com/steal-code'; + + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'victim-server', + 'https://mcp.example.com/', + 'victim-user-id', + {}, + { + authorization_url: 'https://auth.example.com/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'attacker-client', + client_secret: 'attacker-secret', + redirect_uri: attackerRedirectUri, + }, + ); + + const authUrl = new URL(result.authorizationUrl); + const expectedRedirectUri = `${process.env.DOMAIN_SERVER || 'http://localhost:3080'}/api/mcp/victim-server/oauth/callback`; + expect(authUrl.searchParams.get('redirect_uri')).toBe(expectedRedirectUri); + expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri); + }); +}); diff --git a/packages/api/src/mcp/__tests__/utils.test.ts b/packages/api/src/mcp/__tests__/utils.test.ts index 716a230ebe..e4fb31bdad 100644 --- a/packages/api/src/mcp/__tests__/utils.test.ts +++ b/packages/api/src/mcp/__tests__/utils.test.ts @@ -1,4 +1,5 @@ -import { normalizeServerName } from '../utils'; +import { normalizeServerName, redactServerSecrets, redactAllServerSecrets } from '~/mcp/utils'; +import type { ParsedServerConfig } from '~/mcp/types'; describe('normalizeServerName', () => { it('should not modify server names that already match the pattern', () => { @@ -26,3 +27,201 @@ describe('normalizeServerName', () => { expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/); }); }); + +describe('redactServerSecrets', () => { + it('should strip apiKey.key from admin-sourced keys', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'super-secret-api-key', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.source).toBe('admin'); + expect(redacted.apiKey?.authorization_type).toBe('bearer'); + }); + + it('should strip oauth.client_secret', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + oauth: { + client_id: 'my-client', + client_secret: 'super-secret-oauth', + scope: 'read', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.oauth?.client_secret).toBeUndefined(); + expect(redacted.oauth?.client_id).toBe('my-client'); + expect(redacted.oauth?.scope).toBe('read'); + }); + + it('should strip both apiKey.key and oauth.client_secret simultaneously', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { + source: 'admin', + authorization_type: 'custom', + custom_header: 'X-API-Key', + key: 'secret-key', + }, + oauth: { + client_id: 'cid', + client_secret: 'csecret', + }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.custom_header).toBe('X-API-Key'); + expect(redacted.oauth?.client_secret).toBeUndefined(); + expect(redacted.oauth?.client_id).toBe('cid'); + }); + + it('should exclude headers from SSE configs', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'SSE Server', + }; + (config as ParsedServerConfig & { headers: Record }).headers = { + Authorization: 'Bearer admin-token-123', + 'X-Custom': 'safe-value', + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).headers).toBeUndefined(); + expect(redacted.title).toBe('SSE Server'); + }); + + it('should exclude env from stdio configs', () => { + const config: ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['server.js'], + env: { DATABASE_URL: 'postgres://admin:password@localhost/db', PATH: '/usr/bin' }, + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).env).toBeUndefined(); + expect((redacted as Record).command).toBeUndefined(); + expect((redacted as Record).args).toBeUndefined(); + }); + + it('should exclude oauth_headers', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + oauth_headers: { Authorization: 'Bearer oauth-admin-token' }, + }; + const redacted = redactServerSecrets(config); + expect((redacted as Record).oauth_headers).toBeUndefined(); + }); + + it('should strip apiKey.key even for user-sourced keys', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { source: 'user', authorization_type: 'bearer', key: 'my-own-key' }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.apiKey?.key).toBeUndefined(); + expect(redacted.apiKey?.source).toBe('user'); + }); + + it('should not mutate the original config', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'secret' }, + oauth: { client_id: 'cid', client_secret: 'csecret' }, + }; + redactServerSecrets(config); + expect(config.apiKey?.key).toBe('secret'); + expect(config.oauth?.client_secret).toBe('csecret'); + }); + + it('should preserve all safe metadata fields', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'My Server', + description: 'A test server', + iconPath: '/icons/test.png', + chatMenu: true, + requiresOAuth: false, + capabilities: '{"tools":{}}', + tools: 'tool_a, tool_b', + dbId: 'abc123', + updatedAt: 1700000000000, + consumeOnly: false, + inspectionFailed: false, + customUserVars: { API_KEY: { title: 'API Key', description: 'Your key' } }, + }; + const redacted = redactServerSecrets(config); + expect(redacted.title).toBe('My Server'); + expect(redacted.description).toBe('A test server'); + expect(redacted.iconPath).toBe('/icons/test.png'); + expect(redacted.chatMenu).toBe(true); + expect(redacted.requiresOAuth).toBe(false); + expect(redacted.capabilities).toBe('{"tools":{}}'); + expect(redacted.tools).toBe('tool_a, tool_b'); + expect(redacted.dbId).toBe('abc123'); + expect(redacted.updatedAt).toBe(1700000000000); + expect(redacted.consumeOnly).toBe(false); + expect(redacted.inspectionFailed).toBe(false); + expect(redacted.customUserVars).toEqual(config.customUserVars); + }); + + it('should pass URLs through unchanged', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://mcp.example.com/sse?param=value', + }; + const redacted = redactServerSecrets(config); + expect(redacted.url).toBe('https://mcp.example.com/sse?param=value'); + }); + + it('should only include explicitly allowlisted fields', () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Test', + }; + (config as Record).someNewSensitiveField = 'leaked-value'; + const redacted = redactServerSecrets(config); + expect((redacted as Record).someNewSensitiveField).toBeUndefined(); + expect(redacted.title).toBe('Test'); + }); +}); + +describe('redactAllServerSecrets', () => { + it('should redact secrets from all configs in the map', () => { + const configs: Record = { + 'server-a': { + type: 'sse', + url: 'https://a.com/mcp', + apiKey: { source: 'admin', authorization_type: 'bearer', key: 'key-a' }, + }, + 'server-b': { + type: 'sse', + url: 'https://b.com/mcp', + oauth: { client_id: 'cid-b', client_secret: 'secret-b' }, + }, + 'server-c': { + type: 'stdio', + command: 'node', + args: ['c.js'], + }, + }; + const redacted = redactAllServerSecrets(configs); + expect(redacted['server-a'].apiKey?.key).toBeUndefined(); + expect(redacted['server-a'].apiKey?.source).toBe('admin'); + expect(redacted['server-b'].oauth?.client_secret).toBeUndefined(); + expect(redacted['server-b'].oauth?.client_id).toBe('cid-b'); + expect((redacted['server-c'] as Record).command).toBeUndefined(); + }); +}); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 366d0d2fde..8d863bfe79 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -24,6 +24,7 @@ import { selectRegistrationAuthMethod, inferClientAuthMethod, } from './methods'; +import { isSSRFTarget, resolveHostnameSSRF } from '~/auth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; /** Type for the OAuth metadata from the SDK */ @@ -144,7 +145,9 @@ export class MCPOAuthHandler { resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn); if (resourceMetadata?.authorization_servers?.length) { - authServerUrl = new URL(resourceMetadata.authorization_servers[0]); + const discoveredAuthServer = resourceMetadata.authorization_servers[0]; + await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server'); + authServerUrl = new URL(discoveredAuthServer); logger.debug( `[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`, ); @@ -200,6 +203,19 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] OAuth metadata discovered successfully`); const metadata = await OAuthMetadataSchema.parseAsync(rawMetadata); + const endpointChecks: Promise[] = []; + if (metadata.registration_endpoint) { + endpointChecks.push( + this.validateOAuthUrl(metadata.registration_endpoint, 'registration_endpoint'), + ); + } + if (metadata.token_endpoint) { + endpointChecks.push(this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint')); + } + if (endpointChecks.length > 0) { + await Promise.all(endpointChecks); + } + logger.debug(`[MCPOAuth] OAuth metadata parsed successfully`); return { metadata: metadata as unknown as OAuthMetadata, @@ -355,10 +371,14 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] Generated flowId: ${flowId}, state: ${state}`); try { - // Check if we have pre-configured OAuth settings if (config?.authorization_url && config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`); + await Promise.all([ + this.validateOAuthUrl(config.authorization_url, 'authorization_url'), + this.validateOAuthUrl(config.token_url, 'token_url'), + ]); + const skipCodeChallengeCheck = config?.skip_code_challenge_check === true || process.env.MCP_SKIP_CODE_CHALLENGE_CHECK === 'true'; @@ -410,10 +430,11 @@ export class MCPOAuthHandler { code_challenge_methods_supported: codeChallengeMethodsSupported, }; logger.debug(`[MCPOAuth] metadata for "${serverName}": ${JSON.stringify(metadata)}`); + const redirectUri = this.getDefaultRedirectUri(serverName); const clientInfo: OAuthClientInformation = { client_id: config.client_id, client_secret: config.client_secret, - redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)], + redirect_uris: [redirectUri], scope: config.scope, token_endpoint_auth_method: tokenEndpointAuthMethod, }; @@ -422,7 +443,7 @@ export class MCPOAuthHandler { const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { metadata: metadata as unknown as SDKOAuthMetadata, clientInformation: clientInfo, - redirectUrl: clientInfo.redirect_uris?.[0] || this.getDefaultRedirectUri(serverName), + redirectUrl: redirectUri, scope: config.scope, }); @@ -462,8 +483,7 @@ export class MCPOAuthHandler { `[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`, ); - /** Dynamic client registration based on the discovered metadata */ - const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName); + const redirectUri = this.getDefaultRedirectUri(serverName); logger.debug(`[MCPOAuth] Registering OAuth client with redirect URI: ${redirectUri}`); const clientInfo = await this.registerOAuthClient( @@ -672,6 +692,24 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } + /** Validates an OAuth URL is not targeting a private/internal address */ + private static async validateOAuthUrl(url: string, fieldName: string): Promise { + let hostname: string; + try { + hostname = new URL(url).hostname; + } catch { + throw new Error(`Invalid OAuth ${fieldName}: ${sanitizeUrlForLogging(url)}`); + } + + if (isSSRFTarget(hostname)) { + throw new Error(`OAuth ${fieldName} targets a blocked address`); + } + + if (await resolveHostnameSSRF(hostname)) { + throw new Error(`OAuth ${fieldName} resolves to a private IP address`); + } + } + private static readonly STATE_MAP_TYPE = 'mcp_oauth_state'; /** @@ -783,10 +821,10 @@ export class MCPOAuthHandler { scope: metadata.clientInfo.scope, }); - /** Use the stored client information and metadata to determine the token URL */ let tokenUrl: string; let authMethods: string[] | undefined; if (config?.token_url) { + await this.validateOAuthUrl(config.token_url, 'token_url'); tokenUrl = config.token_url; authMethods = config.token_endpoint_auth_methods_supported; } else if (!metadata.serverUrl) { @@ -813,6 +851,7 @@ export class MCPOAuthHandler { tokenUrl = oauthMetadata.token_endpoint; authMethods = oauthMetadata.token_endpoint_auth_methods_supported; } + await this.validateOAuthUrl(tokenUrl, 'token_url'); } const body = new URLSearchParams({ @@ -886,10 +925,10 @@ export class MCPOAuthHandler { return this.processRefreshResponse(tokens, metadata.serverName, 'stored client info'); } - // Fallback: If we have pre-configured OAuth settings, use them if (config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`); + await this.validateOAuthUrl(config.token_url, 'token_url'); const tokenUrl = new URL(config.token_url); const body = new URLSearchParams({ @@ -987,6 +1026,7 @@ export class MCPOAuthHandler { } else { tokenUrl = new URL(oauthMetadata.token_endpoint); } + await this.validateOAuthUrl(tokenUrl.href, 'token_url'); const body = new URLSearchParams({ grant_type: 'refresh_token', @@ -1036,7 +1076,9 @@ export class MCPOAuthHandler { }, oauthHeaders: Record = {}, ): Promise { - // build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided + if (metadata.revocationEndpoint != null) { + await this.validateOAuthUrl(metadata.revocationEndpoint, 'revocation_endpoint'); + } const revokeUrl: URL = metadata.revocationEndpoint != null ? new URL(metadata.revocationEndpoint) diff --git a/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts b/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts index 1c755ae0f0..38ed51cd99 100644 --- a/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts +++ b/packages/api/src/mcp/registry/__tests__/ServerConfigsDB.test.ts @@ -1456,4 +1456,102 @@ describe('ServerConfigsDB', () => { expect(retrieved?.apiKey?.key).toBeUndefined(); }); }); + + describe('DB layer returns decrypted secrets (redaction is at controller layer)', () => { + it('should return decrypted apiKey.key to VIEW-only user via get()', async () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Secret API Key Server', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'admin-secret-api-key', + }, + }; + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.get(created.serverName, userId2); + expect(result).toBeDefined(); + expect(result?.apiKey?.key).toBe('admin-secret-api-key'); + }); + + it('should return decrypted oauth.client_secret to VIEW-only user via get()', async () => { + const config = createSSEConfig('Secret OAuth Server', 'Test', { + client_id: 'my-client-id', + client_secret: 'admin-oauth-secret', + }); + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.get(created.serverName, userId2); + expect(result).toBeDefined(); + expect(result?.oauth?.client_secret).toBe('admin-oauth-secret'); + }); + + it('should return decrypted secrets to VIEW-only user via getAll()', async () => { + const config: ParsedServerConfig = { + type: 'sse', + url: 'https://example.com/mcp', + title: 'Shared Secret Server', + apiKey: { + source: 'admin', + authorization_type: 'bearer', + key: 'shared-api-key', + }, + oauth: { + client_id: 'shared-client', + client_secret: 'shared-oauth-secret', + }, + }; + const created = await serverConfigsDB.add('temp-name', config, userId); + + const role = await mongoose.models.AccessRole.findOne({ + accessRoleId: AccessRoleIds.MCPSERVER_VIEWER, + }); + await mongoose.models.AclEntry.create({ + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: new mongoose.Types.ObjectId(userId2), + resourceType: ResourceType.MCPSERVER, + resourceId: new mongoose.Types.ObjectId(created.config.dbId!), + permBits: PermissionBits.VIEW, + roleId: role!._id, + grantedBy: new mongoose.Types.ObjectId(userId), + }); + + const result = await serverConfigsDB.getAll(userId2); + const serverConfig = result[created.serverName]; + expect(serverConfig).toBeDefined(); + expect(serverConfig?.apiKey?.key).toBe('shared-api-key'); + expect(serverConfig?.oauth?.client_secret).toBe('shared-oauth-secret'); + }); + }); }); diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index fddebb9db3..c517388a76 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -1,6 +1,66 @@ import { Constants } from 'librechat-data-provider'; +import type { ParsedServerConfig } from '~/mcp/types'; export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); + +/** + * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; + * new fields added to ParsedServerConfig are excluded by default until allowlisted here. + * + * URLs are returned as-is: DB-stored configs reject ${VAR} patterns at validation time + * (MCPServerUserInputSchema), and YAML configs are admin-managed. Env variable resolution + * is handled at the schema/input boundary, not the output boundary. + */ +export function redactServerSecrets(config: ParsedServerConfig): Partial { + const safe: Partial = { + type: config.type, + url: config.url, + title: config.title, + description: config.description, + iconPath: config.iconPath, + chatMenu: config.chatMenu, + requiresOAuth: config.requiresOAuth, + capabilities: config.capabilities, + tools: config.tools, + toolFunctions: config.toolFunctions, + initDuration: config.initDuration, + updatedAt: config.updatedAt, + dbId: config.dbId, + consumeOnly: config.consumeOnly, + inspectionFailed: config.inspectionFailed, + customUserVars: config.customUserVars, + serverInstructions: config.serverInstructions, + }; + + if (config.apiKey) { + safe.apiKey = { + source: config.apiKey.source, + authorization_type: config.apiKey.authorization_type, + ...(config.apiKey.custom_header && { custom_header: config.apiKey.custom_header }), + }; + } + + if (config.oauth) { + const { client_secret: _secret, ...safeOAuth } = config.oauth; + safe.oauth = safeOAuth; + } + + return Object.fromEntries( + Object.entries(safe).filter(([, v]) => v !== undefined), + ) as Partial; +} + +/** Applies allowlist-based sanitization to a map of server configs. */ +export function redactAllServerSecrets( + configs: Record, +): Record> { + const result: Record> = {}; + for (const [key, config] of Object.entries(configs)) { + result[key] = redactServerSecrets(config); + } + return result; +} + /** * Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$ * This is required for Azure OpenAI models with Tool Calling From ca79a03135cde32eb112b4292ab02f80e83bcc69 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 13 Mar 2026 23:40:44 -0400 Subject: [PATCH 12/39] =?UTF-8?q?=F0=9F=9A=A6=20fix:=20Add=20Rate=20Limiti?= =?UTF-8?q?ng=20to=20Conversation=20Duplicate=20Endpoint=20(#12218)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: add rate limiting to conversation duplicate endpoint * chore: linter * fix: address review findings for conversation duplicate rate limiting * refactor: streamline test mocks for conversation routes - Consolidated mock implementations into a dedicated `convos-route-mocks.js` file to enhance maintainability and readability of test files. - Updated tests in `convos-duplicate-ratelimit.spec.js` and `convos.spec.js` to utilize the new mock structure, improving clarity and reducing redundancy. - Enhanced the `duplicateConversation` function to accept an optional title parameter for better flexibility in conversation duplication. * chore: rename files --- .../middleware/limiters/forkLimiters.js | 2 +- .../__test-utils__/convos-route-mocks.js | 92 ++++++++++++ .../convos-duplicate-ratelimit.spec.js | 135 ++++++++++++++++++ api/server/routes/__tests__/convos.spec.js | 119 +++------------ api/server/routes/convos.js | 3 +- api/server/utils/import/fork.js | 14 +- 6 files changed, 252 insertions(+), 113 deletions(-) create mode 100644 api/server/routes/__test-utils__/convos-route-mocks.js create mode 100644 api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js diff --git a/api/server/middleware/limiters/forkLimiters.js b/api/server/middleware/limiters/forkLimiters.js index e0aa65700c..f1e9b15f11 100644 --- a/api/server/middleware/limiters/forkLimiters.js +++ b/api/server/middleware/limiters/forkLimiters.js @@ -48,7 +48,7 @@ const createForkHandler = (ip = true) => { }; await logViolation(req, res, type, errorMessage, forkViolationScore); - res.status(429).json({ message: 'Too many conversation fork requests. Try again later' }); + res.status(429).json({ message: 'Too many requests. Try again later' }); }; }; diff --git a/api/server/routes/__test-utils__/convos-route-mocks.js b/api/server/routes/__test-utils__/convos-route-mocks.js new file mode 100644 index 0000000000..ca5bafeda9 --- /dev/null +++ b/api/server/routes/__test-utils__/convos-route-mocks.js @@ -0,0 +1,92 @@ +module.exports = { + agents: () => ({ sleep: jest.fn() }), + + api: (overrides = {}) => ({ + isEnabled: jest.fn(), + createAxiosInstance: jest.fn(() => ({ + get: jest.fn(), + post: jest.fn(), + put: jest.fn(), + delete: jest.fn(), + })), + logAxiosError: jest.fn(), + ...overrides, + }), + + dataSchemas: () => ({ + logger: { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }, + createModels: jest.fn(() => ({ + User: {}, + Conversation: {}, + Message: {}, + SharedLink: {}, + })), + }), + + dataProvider: (overrides = {}) => ({ + CacheKeys: { GEN_TITLE: 'GEN_TITLE' }, + EModelEndpoint: { + azureAssistants: 'azureAssistants', + assistants: 'assistants', + }, + ...overrides, + }), + + conversationModel: () => ({ + getConvosByCursor: jest.fn(), + getConvo: jest.fn(), + deleteConvos: jest.fn(), + saveConvo: jest.fn(), + }), + + toolCallModel: () => ({ deleteToolCalls: jest.fn() }), + + sharedModels: () => ({ + deleteAllSharedLinks: jest.fn(), + deleteConvoSharedLink: jest.fn(), + }), + + requireJwtAuth: () => (req, res, next) => next(), + + middlewarePassthrough: () => ({ + createImportLimiters: jest.fn(() => ({ + importIpLimiter: (req, res, next) => next(), + importUserLimiter: (req, res, next) => next(), + })), + createForkLimiters: jest.fn(() => ({ + forkIpLimiter: (req, res, next) => next(), + forkUserLimiter: (req, res, next) => next(), + })), + configMiddleware: (req, res, next) => next(), + validateConvoAccess: (req, res, next) => next(), + }), + + forkUtils: () => ({ + forkConversation: jest.fn(), + duplicateConversation: jest.fn(), + }), + + importUtils: () => ({ importConversations: jest.fn() }), + + logStores: () => jest.fn(), + + multerSetup: () => ({ + storage: {}, + importFileFilter: jest.fn(), + }), + + multerLib: () => + jest.fn(() => ({ + single: jest.fn(() => (req, res, next) => { + req.file = { path: '/tmp/test-file.json' }; + next(); + }), + })), + + assistantEndpoint: () => ({ initializeClient: jest.fn() }), +}; diff --git a/api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js b/api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js new file mode 100644 index 0000000000..788119a569 --- /dev/null +++ b/api/server/routes/__tests__/convos-duplicate-ratelimit.spec.js @@ -0,0 +1,135 @@ +const express = require('express'); +const request = require('supertest'); + +const MOCKS = '../__test-utils__/convos-route-mocks'; + +jest.mock('@librechat/agents', () => require(MOCKS).agents()); +jest.mock('@librechat/api', () => require(MOCKS).api({ limiterCache: jest.fn(() => undefined) })); +jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas()); +jest.mock('librechat-data-provider', () => + require(MOCKS).dataProvider({ ViolationTypes: { FILE_UPLOAD_LIMIT: 'file_upload_limit' } }), +); + +jest.mock('~/cache/logViolation', () => jest.fn().mockResolvedValue(undefined)); +jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores()); +jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel()); +jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel()); +jest.mock('~/models', () => require(MOCKS).sharedModels()); +jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth()); + +jest.mock('~/server/middleware', () => { + const { createForkLimiters } = jest.requireActual('~/server/middleware/limiters/forkLimiters'); + return { + createImportLimiters: jest.fn(() => ({ + importIpLimiter: (req, res, next) => next(), + importUserLimiter: (req, res, next) => next(), + })), + createForkLimiters, + configMiddleware: (req, res, next) => next(), + validateConvoAccess: (req, res, next) => next(), + }; +}); + +jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils()); +jest.mock('~/server/utils/import', () => require(MOCKS).importUtils()); +jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup()); +jest.mock('multer', () => require(MOCKS).multerLib()); +jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint()); +jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint()); + +describe('POST /api/convos/duplicate - Rate Limiting', () => { + let app; + let duplicateConversation; + const savedEnv = {}; + + beforeAll(() => { + savedEnv.FORK_USER_MAX = process.env.FORK_USER_MAX; + savedEnv.FORK_USER_WINDOW = process.env.FORK_USER_WINDOW; + savedEnv.FORK_IP_MAX = process.env.FORK_IP_MAX; + savedEnv.FORK_IP_WINDOW = process.env.FORK_IP_WINDOW; + }); + + afterAll(() => { + for (const key of Object.keys(savedEnv)) { + if (savedEnv[key] === undefined) { + delete process.env[key]; + } else { + process.env[key] = savedEnv[key]; + } + } + }); + + const setupApp = () => { + jest.clearAllMocks(); + jest.isolateModules(() => { + const convosRouter = require('../convos'); + ({ duplicateConversation } = require('~/server/utils/import/fork')); + + app = express(); + app.use(express.json()); + app.use((req, res, next) => { + req.user = { id: 'rate-limit-test-user' }; + next(); + }); + app.use('/api/convos', convosRouter); + }); + + duplicateConversation.mockResolvedValue({ + conversation: { conversationId: 'duplicated-conv' }, + }); + }; + + describe('user limit', () => { + beforeEach(() => { + process.env.FORK_USER_MAX = '2'; + process.env.FORK_USER_WINDOW = '1'; + process.env.FORK_IP_MAX = '100'; + process.env.FORK_IP_WINDOW = '1'; + setupApp(); + }); + + it('should return 429 after exceeding the user rate limit', async () => { + const userMax = parseInt(process.env.FORK_USER_MAX, 10); + + for (let i = 0; i < userMax; i++) { + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(201); + } + + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(429); + expect(res.body.message).toMatch(/too many/i); + }); + }); + + describe('IP limit', () => { + beforeEach(() => { + process.env.FORK_USER_MAX = '100'; + process.env.FORK_USER_WINDOW = '1'; + process.env.FORK_IP_MAX = '2'; + process.env.FORK_IP_WINDOW = '1'; + setupApp(); + }); + + it('should return 429 after exceeding the IP rate limit', async () => { + const ipMax = parseInt(process.env.FORK_IP_MAX, 10); + + for (let i = 0; i < ipMax; i++) { + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(201); + } + + const res = await request(app) + .post('/api/convos/duplicate') + .send({ conversationId: 'conv-123' }); + expect(res.status).toBe(429); + expect(res.body.message).toMatch(/too many/i); + }); + }); +}); diff --git a/api/server/routes/__tests__/convos.spec.js b/api/server/routes/__tests__/convos.spec.js index 931ef006d0..3bdeac32db 100644 --- a/api/server/routes/__tests__/convos.spec.js +++ b/api/server/routes/__tests__/convos.spec.js @@ -1,109 +1,24 @@ const express = require('express'); const request = require('supertest'); -jest.mock('@librechat/agents', () => ({ - sleep: jest.fn(), -})); +const MOCKS = '../__test-utils__/convos-route-mocks'; -jest.mock('@librechat/api', () => ({ - isEnabled: jest.fn(), - createAxiosInstance: jest.fn(() => ({ - get: jest.fn(), - post: jest.fn(), - put: jest.fn(), - delete: jest.fn(), - })), - logAxiosError: jest.fn(), -})); - -jest.mock('@librechat/data-schemas', () => ({ - logger: { - debug: jest.fn(), - info: jest.fn(), - warn: jest.fn(), - error: jest.fn(), - }, - createModels: jest.fn(() => ({ - User: {}, - Conversation: {}, - Message: {}, - SharedLink: {}, - })), -})); - -jest.mock('~/models/Conversation', () => ({ - getConvosByCursor: jest.fn(), - getConvo: jest.fn(), - deleteConvos: jest.fn(), - saveConvo: jest.fn(), -})); - -jest.mock('~/models/ToolCall', () => ({ - deleteToolCalls: jest.fn(), -})); - -jest.mock('~/models', () => ({ - deleteAllSharedLinks: jest.fn(), - deleteConvoSharedLink: jest.fn(), -})); - -jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next()); - -jest.mock('~/server/middleware', () => ({ - createImportLimiters: jest.fn(() => ({ - importIpLimiter: (req, res, next) => next(), - importUserLimiter: (req, res, next) => next(), - })), - createForkLimiters: jest.fn(() => ({ - forkIpLimiter: (req, res, next) => next(), - forkUserLimiter: (req, res, next) => next(), - })), - configMiddleware: (req, res, next) => next(), - validateConvoAccess: (req, res, next) => next(), -})); - -jest.mock('~/server/utils/import/fork', () => ({ - forkConversation: jest.fn(), - duplicateConversation: jest.fn(), -})); - -jest.mock('~/server/utils/import', () => ({ - importConversations: jest.fn(), -})); - -jest.mock('~/cache/getLogStores', () => jest.fn()); - -jest.mock('~/server/routes/files/multer', () => ({ - storage: {}, - importFileFilter: jest.fn(), -})); - -jest.mock('multer', () => { - return jest.fn(() => ({ - single: jest.fn(() => (req, res, next) => { - req.file = { path: '/tmp/test-file.json' }; - next(); - }), - })); -}); - -jest.mock('librechat-data-provider', () => ({ - CacheKeys: { - GEN_TITLE: 'GEN_TITLE', - }, - EModelEndpoint: { - azureAssistants: 'azureAssistants', - assistants: 'assistants', - }, -})); - -jest.mock('~/server/services/Endpoints/azureAssistants', () => ({ - initializeClient: jest.fn(), -})); - -jest.mock('~/server/services/Endpoints/assistants', () => ({ - initializeClient: jest.fn(), -})); +jest.mock('@librechat/agents', () => require(MOCKS).agents()); +jest.mock('@librechat/api', () => require(MOCKS).api()); +jest.mock('@librechat/data-schemas', () => require(MOCKS).dataSchemas()); +jest.mock('librechat-data-provider', () => require(MOCKS).dataProvider()); +jest.mock('~/models/Conversation', () => require(MOCKS).conversationModel()); +jest.mock('~/models/ToolCall', () => require(MOCKS).toolCallModel()); +jest.mock('~/models', () => require(MOCKS).sharedModels()); +jest.mock('~/server/middleware/requireJwtAuth', () => require(MOCKS).requireJwtAuth()); +jest.mock('~/server/middleware', () => require(MOCKS).middlewarePassthrough()); +jest.mock('~/server/utils/import/fork', () => require(MOCKS).forkUtils()); +jest.mock('~/server/utils/import', () => require(MOCKS).importUtils()); +jest.mock('~/cache/getLogStores', () => require(MOCKS).logStores()); +jest.mock('~/server/routes/files/multer', () => require(MOCKS).multerSetup()); +jest.mock('multer', () => require(MOCKS).multerLib()); +jest.mock('~/server/services/Endpoints/azureAssistants', () => require(MOCKS).assistantEndpoint()); +jest.mock('~/server/services/Endpoints/assistants', () => require(MOCKS).assistantEndpoint()); describe('Convos Routes', () => { let app; diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index bb9c4ebea9..5f0c35fa0a 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -224,6 +224,7 @@ router.post('/update', validateConvoAccess, async (req, res) => { }); const { importIpLimiter, importUserLimiter } = createImportLimiters(); +/** Fork and duplicate share one rate-limit budget (same "clone" operation class) */ const { forkIpLimiter, forkUserLimiter } = createForkLimiters(); const upload = multer({ storage: storage, fileFilter: importFileFilter }); @@ -280,7 +281,7 @@ router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => { } }); -router.post('/duplicate', async (req, res) => { +router.post('/duplicate', forkIpLimiter, forkUserLimiter, async (req, res) => { const { conversationId, title } = req.body; try { diff --git a/api/server/utils/import/fork.js b/api/server/utils/import/fork.js index c4ce8cb5d4..f896de378c 100644 --- a/api/server/utils/import/fork.js +++ b/api/server/utils/import/fork.js @@ -358,16 +358,15 @@ function splitAtTargetLevel(messages, targetMessageId) { * @param {object} params - The parameters for duplicating the conversation. * @param {string} params.userId - The ID of the user duplicating the conversation. * @param {string} params.conversationId - The ID of the conversation to duplicate. + * @param {string} [params.title] - Optional title override for the duplicate. * @returns {Promise<{ conversation: TConversation, messages: TMessage[] }>} The duplicated conversation and messages. */ -async function duplicateConversation({ userId, conversationId }) { - // Get original conversation +async function duplicateConversation({ userId, conversationId, title }) { const originalConvo = await getConvo(userId, conversationId); if (!originalConvo) { throw new Error('Conversation not found'); } - // Get original messages const originalMessages = await getMessages({ user: userId, conversationId, @@ -383,14 +382,11 @@ async function duplicateConversation({ userId, conversationId }) { cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder); - const result = importBatchBuilder.finishConversation( - originalConvo.title, - new Date(), - originalConvo, - ); + const duplicateTitle = title || originalConvo.title; + const result = importBatchBuilder.finishConversation(duplicateTitle, new Date(), originalConvo); await importBatchBuilder.saveBatch(); logger.debug( - `user: ${userId} | New conversation "${originalConvo.title}" duplicated from conversation ID ${conversationId}`, + `user: ${userId} | New conversation "${duplicateTitle}" duplicated from conversation ID ${conversationId}`, ); const conversation = await getConvo(userId, result.conversation.conversationId); From 189cdf581d1e5f4f565894f6a9fd3b2704f4fd20 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 13 Mar 2026 23:42:37 -0400 Subject: [PATCH 13/39] =?UTF-8?q?=F0=9F=94=90=20fix:=20Add=20User=20Filter?= =?UTF-8?q?=20to=20Message=20Deletion=20(#12220)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: add user filter to message deletion to prevent IDOR * refactor: streamline DELETE request syntax in messages-delete test - Simplified the DELETE request syntax in the messages-delete.spec.js test file by combining multiple lines into a single line for improved readability. This change enhances the clarity of the test code without altering its functionality. * fix: address review findings for message deletion IDOR fix * fix: add user filter to message deletion in conversation tests - Included a user filter in the message deletion test to ensure proper handling of user-specific deletions, enhancing the accuracy of the test case and preventing potential IDOR vulnerabilities. * chore: lint --- api/models/Conversation.js | 3 +- api/models/Conversation.spec.js | 1 + .../routes/__tests__/messages-delete.spec.js | 200 ++++++++++++++++++ api/server/routes/messages.js | 4 +- 4 files changed, 205 insertions(+), 3 deletions(-) create mode 100644 api/server/routes/__tests__/messages-delete.spec.js diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 32eac1a764..121eaa9696 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -228,7 +228,7 @@ module.exports = { }, ], }; - } catch (err) { + } catch (_err) { logger.warn('[getConvosByCursor] Invalid cursor format, starting from beginning'); } if (cursorFilter) { @@ -361,6 +361,7 @@ module.exports = { const deleteMessagesResult = await deleteMessages({ conversationId: { $in: conversationIds }, + user, }); return { ...deleteConvoResult, messages: deleteMessagesResult }; diff --git a/api/models/Conversation.spec.js b/api/models/Conversation.spec.js index bd415b4165..e9e4b5762d 100644 --- a/api/models/Conversation.spec.js +++ b/api/models/Conversation.spec.js @@ -549,6 +549,7 @@ describe('Conversation Operations', () => { expect(result.messages.deletedCount).toBe(5); expect(deleteMessages).toHaveBeenCalledWith({ conversationId: { $in: [mockConversationData.conversationId] }, + user: 'user123', }); // Verify conversation was deleted diff --git a/api/server/routes/__tests__/messages-delete.spec.js b/api/server/routes/__tests__/messages-delete.spec.js new file mode 100644 index 0000000000..e134eecfd0 --- /dev/null +++ b/api/server/routes/__tests__/messages-delete.spec.js @@ -0,0 +1,200 @@ +const mongoose = require('mongoose'); +const express = require('express'); +const request = require('supertest'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +jest.mock('@librechat/agents', () => ({ + sleep: jest.fn(), +})); + +jest.mock('@librechat/api', () => ({ + unescapeLaTeX: jest.fn((x) => x), + countTokens: jest.fn().mockResolvedValue(10), +})); + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }, +})); + +jest.mock('librechat-data-provider', () => ({ + ...jest.requireActual('librechat-data-provider'), +})); + +jest.mock('~/models', () => ({ + saveConvo: jest.fn(), + getMessage: jest.fn(), + saveMessage: jest.fn(), + getMessages: jest.fn(), + updateMessage: jest.fn(), + deleteMessages: jest.fn(), +})); + +jest.mock('~/server/services/Artifacts/update', () => ({ + findAllArtifacts: jest.fn(), + replaceArtifactContent: jest.fn(), +})); + +jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next()); + +jest.mock('~/server/middleware', () => ({ + requireJwtAuth: (req, res, next) => next(), + validateMessageReq: (req, res, next) => next(), +})); + +jest.mock('~/models/Conversation', () => ({ + getConvosQueried: jest.fn(), +})); + +jest.mock('~/db/models', () => ({ + Message: { + findOne: jest.fn(), + find: jest.fn(), + meiliSearch: jest.fn(), + }, +})); + +/* ─── Model-level tests: real MongoDB, proves cross-user deletion is prevented ─── */ + +const { messageSchema } = require('@librechat/data-schemas'); + +describe('deleteMessages – model-level IDOR prevention', () => { + let mongoServer; + let Message; + + const ownerUserId = 'user-owner-111'; + const attackerUserId = 'user-attacker-222'; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + Message = mongoose.models.Message || mongoose.model('Message', messageSchema); + await mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Message.deleteMany({}); + }); + + it("should NOT delete another user's message when attacker supplies victim messageId", async () => { + const conversationId = uuidv4(); + const victimMsgId = 'victim-msg-001'; + + await Message.create({ + messageId: victimMsgId, + conversationId, + user: ownerUserId, + text: 'Sensitive owner data', + }); + + await Message.deleteMany({ messageId: victimMsgId, user: attackerUserId }); + + const victimMsg = await Message.findOne({ messageId: victimMsgId }).lean(); + expect(victimMsg).not.toBeNull(); + expect(victimMsg.user).toBe(ownerUserId); + expect(victimMsg.text).toBe('Sensitive owner data'); + }); + + it("should delete the user's own message", async () => { + const conversationId = uuidv4(); + const ownMsgId = 'own-msg-001'; + + await Message.create({ + messageId: ownMsgId, + conversationId, + user: ownerUserId, + text: 'My message', + }); + + const result = await Message.deleteMany({ messageId: ownMsgId, user: ownerUserId }); + expect(result.deletedCount).toBe(1); + + const deleted = await Message.findOne({ messageId: ownMsgId }).lean(); + expect(deleted).toBeNull(); + }); + + it('should scope deletion by conversationId, messageId, and user together', async () => { + const convoA = uuidv4(); + const convoB = uuidv4(); + + await Message.create([ + { messageId: 'msg-a1', conversationId: convoA, user: ownerUserId, text: 'A1' }, + { messageId: 'msg-b1', conversationId: convoB, user: ownerUserId, text: 'B1' }, + ]); + + await Message.deleteMany({ messageId: 'msg-a1', conversationId: convoA, user: attackerUserId }); + + const remaining = await Message.find({ user: ownerUserId }).lean(); + expect(remaining).toHaveLength(2); + }); +}); + +/* ─── Route-level tests: supertest + mocked deleteMessages ─── */ + +describe('DELETE /:conversationId/:messageId – route handler', () => { + let app; + const { deleteMessages } = require('~/models'); + + const authenticatedUserId = 'user-owner-123'; + + beforeAll(() => { + const messagesRouter = require('../messages'); + + app = express(); + app.use(express.json()); + app.use((req, res, next) => { + req.user = { id: authenticatedUserId }; + next(); + }); + app.use('/api/messages', messagesRouter); + }); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should pass user and conversationId in the deleteMessages filter', async () => { + deleteMessages.mockResolvedValue({ deletedCount: 1 }); + + await request(app).delete('/api/messages/convo-1/msg-1'); + + expect(deleteMessages).toHaveBeenCalledTimes(1); + expect(deleteMessages).toHaveBeenCalledWith({ + messageId: 'msg-1', + conversationId: 'convo-1', + user: authenticatedUserId, + }); + }); + + it('should return 204 on successful deletion', async () => { + deleteMessages.mockResolvedValue({ deletedCount: 1 }); + + const response = await request(app).delete('/api/messages/convo-1/msg-owned'); + + expect(response.status).toBe(204); + expect(deleteMessages).toHaveBeenCalledWith({ + messageId: 'msg-owned', + conversationId: 'convo-1', + user: authenticatedUserId, + }); + }); + + it('should return 500 when deleteMessages throws', async () => { + deleteMessages.mockRejectedValue(new Error('DB failure')); + + const response = await request(app).delete('/api/messages/convo-1/msg-1'); + + expect(response.status).toBe(500); + expect(response.body).toEqual({ error: 'Internal server error' }); + }); +}); diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index c208e9c406..03286bc7f1 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -404,8 +404,8 @@ router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (re router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { - const { messageId } = req.params; - await deleteMessages({ messageId }); + const { conversationId, messageId } = req.params; + await deleteMessages({ messageId, conversationId, user: req.user.id }); res.status(204).send(); } catch (error) { logger.error('Error deleting message:', error); From 71a3b48504785362f9e705726990126014200095 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 01:51:31 -0400 Subject: [PATCH 14/39] =?UTF-8?q?=F0=9F=94=91=20fix:=20Require=20OTP=20Ver?= =?UTF-8?q?ification=20for=202FA=20Re-Enrollment=20and=20Backup=20Code=20R?= =?UTF-8?q?egeneration=20(#12223)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: require OTP verification for 2FA re-enrollment and backup code regeneration * fix: require OTP verification for account deletion when 2FA is enabled * refactor: Improve code formatting and readability in TwoFactorController and UserController - Reformatted code in TwoFactorController and UserController for better readability by aligning parameters and breaking long lines. - Updated test cases in deleteUser.spec.js and TwoFactorController.spec.js to enhance clarity by formatting object parameters consistently. * refactor: Consolidate OTP and backup code verification logic in TwoFactorController and UserController - Introduced a new `verifyOTPOrBackupCode` function to streamline the verification process for TOTP tokens and backup codes across multiple controllers. - Updated the `enable2FA`, `disable2FA`, and `deleteUserController` methods to utilize the new verification function, enhancing code reusability and readability. - Adjusted related tests to reflect the changes in verification logic, ensuring consistent behavior across different scenarios. - Improved error handling and response messages for verification failures, providing clearer feedback to users. * chore: linting * refactor: Update BackupCodesItem component to enhance OTP verification logic - Consolidated OTP input handling by moving the 2FA verification UI logic to a more consistent location within the component. - Improved the state management for OTP readiness, ensuring the regenerate button is only enabled when the OTP is ready. - Cleaned up imports by removing redundant type imports, enhancing code clarity and maintainability. * chore: lint * fix: stage 2FA re-enrollment in pending fields to prevent disarmament window enable2FA now writes to pendingTotpSecret/pendingBackupCodes instead of overwriting the live fields. confirm2FA performs the atomic swap only after the new TOTP code is verified. If the user abandons mid-flow, their existing 2FA remains active and intact. --- api/server/controllers/TwoFactorController.js | 107 +++++-- api/server/controllers/UserController.js | 18 ++ .../__tests__/TwoFactorController.spec.js | 264 +++++++++++++++ .../controllers/__tests__/deleteUser.spec.js | 302 ++++++++++++++++++ api/server/routes/auth.js | 2 +- api/server/services/twoFactorService.js | 54 +++- .../SettingsTabs/Account/BackupCodesItem.tsx | 96 +++++- .../SettingsTabs/Account/DeleteAccount.tsx | 92 +++++- client/src/data-provider/Auth/mutations.ts | 23 +- client/src/locales/en/translation.json | 1 + packages/data-provider/src/data-service.ts | 14 +- packages/data-provider/src/types.ts | 43 ++- packages/data-schemas/src/schema/user.ts | 9 + packages/data-schemas/src/types/user.ts | 6 + 14 files changed, 927 insertions(+), 104 deletions(-) create mode 100644 api/server/controllers/__tests__/TwoFactorController.spec.js create mode 100644 api/server/controllers/__tests__/deleteUser.spec.js diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js index fde5965261..18a0ee3f5a 100644 --- a/api/server/controllers/TwoFactorController.js +++ b/api/server/controllers/TwoFactorController.js @@ -1,5 +1,6 @@ const { encryptV3, logger } = require('@librechat/data-schemas'); const { + verifyOTPOrBackupCode, generateBackupCodes, generateTOTPSecret, verifyBackupCode, @@ -13,24 +14,42 @@ const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); /** * Enable 2FA for the user by generating a new TOTP secret and backup codes. * The secret is encrypted and stored, and 2FA is marked as disabled until confirmed. + * If 2FA is already enabled, requires OTP or backup code verification to re-enroll. */ const enable2FA = async (req, res) => { try { const userId = req.user.id; + const existingUser = await getUserById( + userId, + '+totpSecret +backupCodes _id twoFactorEnabled email', + ); + + if (existingUser && existingUser.twoFactorEnabled) { + const { token, backupCode } = req.body; + const result = await verifyOTPOrBackupCode({ + user: existingUser, + token, + backupCode, + persistBackupUse: false, + }); + + if (!result.verified) { + const msg = result.message ?? 'TOTP token or backup code is required to re-enroll 2FA'; + return res.status(result.status ?? 400).json({ message: msg }); + } + } + const secret = generateTOTPSecret(); const { plainCodes, codeObjects } = await generateBackupCodes(); - - // Encrypt the secret with v3 encryption before saving. const encryptedSecret = encryptV3(secret); - // Update the user record: store the secret & backup codes and set twoFactorEnabled to false. const user = await updateUser(userId, { - totpSecret: encryptedSecret, - backupCodes: codeObjects, - twoFactorEnabled: false, + pendingTotpSecret: encryptedSecret, + pendingBackupCodes: codeObjects, }); - const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`; + const email = user.email || (existingUser && existingUser.email) || ''; + const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`; return res.status(200).json({ otpauthUrl, backupCodes: plainCodes }); } catch (err) { @@ -46,13 +65,14 @@ const verify2FA = async (req, res) => { try { const userId = req.user.id; const { token, backupCode } = req.body; - const user = await getUserById(userId, '_id totpSecret backupCodes'); + const user = await getUserById(userId, '+totpSecret +pendingTotpSecret +backupCodes _id'); + const secretSource = user?.pendingTotpSecret ?? user?.totpSecret; - if (!user || !user.totpSecret) { + if (!user || !secretSource) { return res.status(400).json({ message: '2FA not initiated' }); } - const secret = await getTOTPSecret(user.totpSecret); + const secret = await getTOTPSecret(secretSource); let isVerified = false; if (token) { @@ -78,15 +98,28 @@ const confirm2FA = async (req, res) => { try { const userId = req.user.id; const { token } = req.body; - const user = await getUserById(userId, '_id totpSecret'); + const user = await getUserById( + userId, + '+totpSecret +pendingTotpSecret +pendingBackupCodes _id', + ); + const secretSource = user?.pendingTotpSecret ?? user?.totpSecret; - if (!user || !user.totpSecret) { + if (!user || !secretSource) { return res.status(400).json({ message: '2FA not initiated' }); } - const secret = await getTOTPSecret(user.totpSecret); + const secret = await getTOTPSecret(secretSource); if (await verifyTOTP(secret, token)) { - await updateUser(userId, { twoFactorEnabled: true }); + const update = { + totpSecret: user.pendingTotpSecret ?? user.totpSecret, + twoFactorEnabled: true, + pendingTotpSecret: null, + pendingBackupCodes: [], + }; + if (user.pendingBackupCodes?.length) { + update.backupCodes = user.pendingBackupCodes; + } + await updateUser(userId, update); return res.status(200).json(); } return res.status(400).json({ message: 'Invalid token.' }); @@ -104,31 +137,27 @@ const disable2FA = async (req, res) => { try { const userId = req.user.id; const { token, backupCode } = req.body; - const user = await getUserById(userId, '_id totpSecret backupCodes'); + const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled'); if (!user || !user.totpSecret) { return res.status(400).json({ message: '2FA is not setup for this user' }); } if (user.twoFactorEnabled) { - const secret = await getTOTPSecret(user.totpSecret); - let isVerified = false; + const result = await verifyOTPOrBackupCode({ user, token, backupCode }); - if (token) { - isVerified = await verifyTOTP(secret, token); - } else if (backupCode) { - isVerified = await verifyBackupCode({ user, backupCode }); - } else { - return res - .status(400) - .json({ message: 'Either token or backup code is required to disable 2FA' }); - } - - if (!isVerified) { - return res.status(401).json({ message: 'Invalid token or backup code' }); + if (!result.verified) { + const msg = result.message ?? 'Either token or backup code is required to disable 2FA'; + return res.status(result.status ?? 400).json({ message: msg }); } } - await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false }); + await updateUser(userId, { + totpSecret: null, + backupCodes: [], + twoFactorEnabled: false, + pendingTotpSecret: null, + pendingBackupCodes: [], + }); return res.status(200).json(); } catch (err) { logger.error('[disable2FA]', err); @@ -138,10 +167,28 @@ const disable2FA = async (req, res) => { /** * Regenerate backup codes for the user. + * Requires OTP or backup code verification if 2FA is already enabled. */ const regenerateBackupCodes = async (req, res) => { try { const userId = req.user.id; + const user = await getUserById(userId, '+totpSecret +backupCodes _id twoFactorEnabled'); + + if (!user) { + return res.status(404).json({ message: 'User not found' }); + } + + if (user.twoFactorEnabled) { + const { token, backupCode } = req.body; + const result = await verifyOTPOrBackupCode({ user, token, backupCode }); + + if (!result.verified) { + const msg = + result.message ?? 'TOTP token or backup code is required to regenerate backup codes'; + return res.status(result.status ?? 400).json({ message: msg }); + } + } + const { plainCodes, codeObjects } = await generateBackupCodes(); await updateUser(userId, { backupCodes: codeObjects }); return res.status(200).json({ diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 7a9dd8125e..b3160bb3d3 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -14,6 +14,7 @@ const { deleteMessages, deletePresets, deleteUserKey, + getUserById, deleteConvos, deleteFiles, updateUser, @@ -34,6 +35,7 @@ const { User, } = require('~/db/models'); const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); +const { verifyOTPOrBackupCode } = require('~/server/services/twoFactorService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config'); const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools'); @@ -241,6 +243,22 @@ const deleteUserController = async (req, res) => { const { user } = req; try { + const existingUser = await getUserById( + user.id, + '+totpSecret +backupCodes _id twoFactorEnabled', + ); + if (existingUser && existingUser.twoFactorEnabled) { + const { token, backupCode } = req.body; + const result = await verifyOTPOrBackupCode({ user: existingUser, token, backupCode }); + + if (!result.verified) { + const msg = + result.message ?? + 'TOTP token or backup code is required to delete account with 2FA enabled'; + return res.status(result.status ?? 400).json({ message: msg }); + } + } + await deleteMessages({ user: user.id }); // delete user messages await deleteAllUserSessions({ userId: user.id }); // delete user sessions await Transaction.deleteMany({ user: user.id }); // delete user transactions diff --git a/api/server/controllers/__tests__/TwoFactorController.spec.js b/api/server/controllers/__tests__/TwoFactorController.spec.js new file mode 100644 index 0000000000..62531d94a1 --- /dev/null +++ b/api/server/controllers/__tests__/TwoFactorController.spec.js @@ -0,0 +1,264 @@ +const mockGetUserById = jest.fn(); +const mockUpdateUser = jest.fn(); +const mockVerifyOTPOrBackupCode = jest.fn(); +const mockGenerateTOTPSecret = jest.fn(); +const mockGenerateBackupCodes = jest.fn(); +const mockEncryptV3 = jest.fn(); + +jest.mock('@librechat/data-schemas', () => ({ + encryptV3: (...args) => mockEncryptV3(...args), + logger: { error: jest.fn() }, +})); + +jest.mock('~/server/services/twoFactorService', () => ({ + verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args), + generateBackupCodes: (...args) => mockGenerateBackupCodes(...args), + generateTOTPSecret: (...args) => mockGenerateTOTPSecret(...args), + verifyBackupCode: jest.fn(), + getTOTPSecret: jest.fn(), + verifyTOTP: jest.fn(), +})); + +jest.mock('~/models', () => ({ + getUserById: (...args) => mockGetUserById(...args), + updateUser: (...args) => mockUpdateUser(...args), +})); + +const { enable2FA, regenerateBackupCodes } = require('~/server/controllers/TwoFactorController'); + +function createRes() { + const res = {}; + res.status = jest.fn().mockReturnValue(res); + res.json = jest.fn().mockReturnValue(res); + return res; +} + +const PLAIN_CODES = ['code1', 'code2', 'code3']; +const CODE_OBJECTS = [ + { codeHash: 'h1', used: false, usedAt: null }, + { codeHash: 'h2', used: false, usedAt: null }, + { codeHash: 'h3', used: false, usedAt: null }, +]; + +beforeEach(() => { + jest.clearAllMocks(); + mockGenerateTOTPSecret.mockReturnValue('NEWSECRET'); + mockGenerateBackupCodes.mockResolvedValue({ plainCodes: PLAIN_CODES, codeObjects: CODE_OBJECTS }); + mockEncryptV3.mockReturnValue('encrypted-secret'); +}); + +describe('enable2FA', () => { + it('allows first-time setup without token — writes to pending fields', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false, email: 'a@b.com' }); + mockUpdateUser.mockResolvedValue({ email: 'a@b.com' }); + + await enable2FA(req, res); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ otpauthUrl: expect.any(String), backupCodes: PLAIN_CODES }), + ); + expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled(); + const updateCall = mockUpdateUser.mock.calls[0][1]; + expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret'); + expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS); + expect(updateCall).not.toHaveProperty('twoFactorEnabled'); + expect(updateCall).not.toHaveProperty('totpSecret'); + expect(updateCall).not.toHaveProperty('backupCodes'); + }); + + it('re-enrollment writes to pending fields, leaving live 2FA intact', async () => { + const req = { user: { id: 'user1' }, body: { token: '123456' } }; + const res = createRes(); + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + email: 'a@b.com', + }; + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({ email: 'a@b.com' }); + + await enable2FA(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: '123456', + backupCode: undefined, + persistBackupUse: false, + }); + expect(res.status).toHaveBeenCalledWith(200); + const updateCall = mockUpdateUser.mock.calls[0][1]; + expect(updateCall).toHaveProperty('pendingTotpSecret', 'encrypted-secret'); + expect(updateCall).toHaveProperty('pendingBackupCodes', CODE_OBJECTS); + expect(updateCall).not.toHaveProperty('twoFactorEnabled'); + expect(updateCall).not.toHaveProperty('totpSecret'); + }); + + it('allows re-enrollment with valid backup code (persistBackupUse: false)', async () => { + const req = { user: { id: 'user1' }, body: { backupCode: 'backup123' } }; + const res = createRes(); + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + email: 'a@b.com', + }; + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({ email: 'a@b.com' }); + + await enable2FA(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith( + expect.objectContaining({ persistBackupUse: false }), + ); + expect(res.status).toHaveBeenCalledWith(200); + }); + + it('returns error when no token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 }); + + await enable2FA(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(mockUpdateUser).not.toHaveBeenCalled(); + }); + + it('returns 401 when invalid token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: { token: 'wrong' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await enable2FA(req, res); + + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' }); + expect(mockUpdateUser).not.toHaveBeenCalled(); + }); +}); + +describe('regenerateBackupCodes', () => { + it('returns 404 when user not found', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue(null); + + await regenerateBackupCodes(req, res); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ message: 'User not found' }); + }); + + it('requires OTP when 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: { token: '123456' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({}); + + await regenerateBackupCodes(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith({ + backupCodes: PLAIN_CODES, + backupCodesHash: CODE_OBJECTS, + }); + }); + + it('returns error when no token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 }); + + await regenerateBackupCodes(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + }); + + it('returns 401 when invalid token provided and 2FA is enabled', async () => { + const req = { user: { id: 'user1' }, body: { token: 'wrong' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await regenerateBackupCodes(req, res); + + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' }); + }); + + it('includes backupCodesHash in response', async () => { + const req = { user: { id: 'user1' }, body: { token: '123456' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + mockUpdateUser.mockResolvedValue({}); + + await regenerateBackupCodes(req, res); + + const responseBody = res.json.mock.calls[0][0]; + expect(responseBody).toHaveProperty('backupCodesHash', CODE_OBJECTS); + expect(responseBody).toHaveProperty('backupCodes', PLAIN_CODES); + }); + + it('allows regeneration without token when 2FA is not enabled', async () => { + const req = { user: { id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: false, + }); + mockUpdateUser.mockResolvedValue({}); + + await regenerateBackupCodes(req, res); + + expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.json).toHaveBeenCalledWith({ + backupCodes: PLAIN_CODES, + backupCodesHash: CODE_OBJECTS, + }); + }); +}); diff --git a/api/server/controllers/__tests__/deleteUser.spec.js b/api/server/controllers/__tests__/deleteUser.spec.js new file mode 100644 index 0000000000..d0f54a046f --- /dev/null +++ b/api/server/controllers/__tests__/deleteUser.spec.js @@ -0,0 +1,302 @@ +const mockGetUserById = jest.fn(); +const mockDeleteMessages = jest.fn(); +const mockDeleteAllUserSessions = jest.fn(); +const mockDeleteUserById = jest.fn(); +const mockDeleteAllSharedLinks = jest.fn(); +const mockDeletePresets = jest.fn(); +const mockDeleteUserKey = jest.fn(); +const mockDeleteConvos = jest.fn(); +const mockDeleteFiles = jest.fn(); +const mockGetFiles = jest.fn(); +const mockUpdateUserPlugins = jest.fn(); +const mockUpdateUser = jest.fn(); +const mockFindToken = jest.fn(); +const mockVerifyOTPOrBackupCode = jest.fn(); +const mockDeleteUserPluginAuth = jest.fn(); +const mockProcessDeleteRequest = jest.fn(); +const mockDeleteToolCalls = jest.fn(); +const mockDeleteUserAgents = jest.fn(); +const mockDeleteUserPrompts = jest.fn(); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { error: jest.fn(), info: jest.fn() }, + webSearchKeys: [], +})); + +jest.mock('librechat-data-provider', () => ({ + Tools: {}, + CacheKeys: {}, + Constants: { mcp_delimiter: '::', mcp_prefix: 'mcp_' }, + FileSources: {}, +})); + +jest.mock('@librechat/api', () => ({ + MCPOAuthHandler: {}, + MCPTokenStorage: {}, + normalizeHttpError: jest.fn(), + extractWebSearchEnvVars: jest.fn(), +})); + +jest.mock('~/models', () => ({ + deleteAllUserSessions: (...args) => mockDeleteAllUserSessions(...args), + deleteAllSharedLinks: (...args) => mockDeleteAllSharedLinks(...args), + updateUserPlugins: (...args) => mockUpdateUserPlugins(...args), + deleteUserById: (...args) => mockDeleteUserById(...args), + deleteMessages: (...args) => mockDeleteMessages(...args), + deletePresets: (...args) => mockDeletePresets(...args), + deleteUserKey: (...args) => mockDeleteUserKey(...args), + getUserById: (...args) => mockGetUserById(...args), + deleteConvos: (...args) => mockDeleteConvos(...args), + deleteFiles: (...args) => mockDeleteFiles(...args), + updateUser: (...args) => mockUpdateUser(...args), + findToken: (...args) => mockFindToken(...args), + getFiles: (...args) => mockGetFiles(...args), +})); + +jest.mock('~/db/models', () => ({ + ConversationTag: { deleteMany: jest.fn() }, + AgentApiKey: { deleteMany: jest.fn() }, + Transaction: { deleteMany: jest.fn() }, + MemoryEntry: { deleteMany: jest.fn() }, + Assistant: { deleteMany: jest.fn() }, + AclEntry: { deleteMany: jest.fn() }, + Balance: { deleteMany: jest.fn() }, + Action: { deleteMany: jest.fn() }, + Group: { updateMany: jest.fn() }, + Token: { deleteMany: jest.fn() }, + User: {}, +})); + +jest.mock('~/server/services/PluginService', () => ({ + updateUserPluginAuth: jest.fn(), + deleteUserPluginAuth: (...args) => mockDeleteUserPluginAuth(...args), +})); + +jest.mock('~/server/services/twoFactorService', () => ({ + verifyOTPOrBackupCode: (...args) => mockVerifyOTPOrBackupCode(...args), +})); + +jest.mock('~/server/services/AuthService', () => ({ + verifyEmail: jest.fn(), + resendVerificationEmail: jest.fn(), +})); + +jest.mock('~/config', () => ({ + getMCPManager: jest.fn(), + getFlowStateManager: jest.fn(), + getMCPServersRegistry: jest.fn(), +})); + +jest.mock('~/server/services/Config/getCachedTools', () => ({ + invalidateCachedTools: jest.fn(), +})); + +jest.mock('~/server/services/Files/S3/crud', () => ({ + needsRefresh: jest.fn(), + getNewS3URL: jest.fn(), +})); + +jest.mock('~/server/services/Files/process', () => ({ + processDeleteRequest: (...args) => mockProcessDeleteRequest(...args), +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), +})); + +jest.mock('~/models/ToolCall', () => ({ + deleteToolCalls: (...args) => mockDeleteToolCalls(...args), +})); + +jest.mock('~/models/Prompt', () => ({ + deleteUserPrompts: (...args) => mockDeleteUserPrompts(...args), +})); + +jest.mock('~/models/Agent', () => ({ + deleteUserAgents: (...args) => mockDeleteUserAgents(...args), +})); + +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(), +})); + +const { deleteUserController } = require('~/server/controllers/UserController'); + +function createRes() { + const res = {}; + res.status = jest.fn().mockReturnValue(res); + res.json = jest.fn().mockReturnValue(res); + res.send = jest.fn().mockReturnValue(res); + return res; +} + +function stubDeletionMocks() { + mockDeleteMessages.mockResolvedValue(); + mockDeleteAllUserSessions.mockResolvedValue(); + mockDeleteUserKey.mockResolvedValue(); + mockDeletePresets.mockResolvedValue(); + mockDeleteConvos.mockResolvedValue(); + mockDeleteUserPluginAuth.mockResolvedValue(); + mockDeleteUserById.mockResolvedValue(); + mockDeleteAllSharedLinks.mockResolvedValue(); + mockGetFiles.mockResolvedValue([]); + mockProcessDeleteRequest.mockResolvedValue(); + mockDeleteFiles.mockResolvedValue(); + mockDeleteToolCalls.mockResolvedValue(); + mockDeleteUserAgents.mockResolvedValue(); + mockDeleteUserPrompts.mockResolvedValue(); +} + +beforeEach(() => { + jest.clearAllMocks(); + stubDeletionMocks(); +}); + +describe('deleteUserController - 2FA enforcement', () => { + it('proceeds with deletion when 2FA is not enabled', async () => { + const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ _id: 'user1', twoFactorEnabled: false }); + + await deleteUserController(req, res); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + expect(mockDeleteMessages).toHaveBeenCalled(); + expect(mockVerifyOTPOrBackupCode).not.toHaveBeenCalled(); + }); + + it('proceeds with deletion when user has no 2FA record', async () => { + const req = { user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue(null); + + await deleteUserController(req, res); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + }); + + it('returns error when 2FA is enabled and verification fails with 400', async () => { + const req = { user: { id: 'user1', _id: 'user1' }, body: {} }; + const res = createRes(); + mockGetUserById.mockResolvedValue({ + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: false, status: 400 }); + + await deleteUserController(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(mockDeleteMessages).not.toHaveBeenCalled(); + }); + + it('returns 401 when 2FA is enabled and invalid TOTP token provided', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }; + const req = { user: { id: 'user1', _id: 'user1' }, body: { token: 'wrong' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: 'wrong', + backupCode: undefined, + }); + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ message: 'Invalid token or backup code' }); + expect(mockDeleteMessages).not.toHaveBeenCalled(); + }); + + it('returns 401 when 2FA is enabled and invalid backup code provided', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + backupCodes: [], + }; + const req = { user: { id: 'user1', _id: 'user1' }, body: { backupCode: 'bad-code' } }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ + verified: false, + status: 401, + message: 'Invalid token or backup code', + }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: undefined, + backupCode: 'bad-code', + }); + expect(res.status).toHaveBeenCalledWith(401); + expect(mockDeleteMessages).not.toHaveBeenCalled(); + }); + + it('deletes account when valid TOTP token provided with 2FA enabled', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + }; + const req = { + user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, + body: { token: '123456' }, + }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: '123456', + backupCode: undefined, + }); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + expect(mockDeleteMessages).toHaveBeenCalled(); + }); + + it('deletes account when valid backup code provided with 2FA enabled', async () => { + const existingUser = { + _id: 'user1', + twoFactorEnabled: true, + totpSecret: 'enc-secret', + backupCodes: [{ codeHash: 'h1', used: false }], + }; + const req = { + user: { id: 'user1', _id: 'user1', email: 'a@b.com' }, + body: { backupCode: 'valid-code' }, + }; + const res = createRes(); + mockGetUserById.mockResolvedValue(existingUser); + mockVerifyOTPOrBackupCode.mockResolvedValue({ verified: true }); + + await deleteUserController(req, res); + + expect(mockVerifyOTPOrBackupCode).toHaveBeenCalledWith({ + user: existingUser, + token: undefined, + backupCode: 'valid-code', + }); + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalledWith({ message: 'User deleted' }); + expect(mockDeleteMessages).toHaveBeenCalled(); + }); +}); diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index e84442f65f..d55684f3de 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -63,7 +63,7 @@ router.post( resetPasswordController, ); -router.get('/2fa/enable', middleware.requireJwtAuth, enable2FA); +router.post('/2fa/enable', middleware.requireJwtAuth, enable2FA); router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA); router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken); router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA); diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js index cce24e2322..313c557133 100644 --- a/api/server/services/twoFactorService.js +++ b/api/server/services/twoFactorService.js @@ -153,9 +153,11 @@ const generateBackupCodes = async (count = 10) => { * @param {Object} params * @param {Object} params.user * @param {string} params.backupCode + * @param {boolean} [params.persist=true] - Whether to persist the used-mark to the database. + * Pass `false` when the caller will immediately overwrite `backupCodes` (e.g. re-enrollment). * @returns {Promise} */ -const verifyBackupCode = async ({ user, backupCode }) => { +const verifyBackupCode = async ({ user, backupCode, persist = true }) => { if (!backupCode || !user || !Array.isArray(user.backupCodes)) { return false; } @@ -165,17 +167,50 @@ const verifyBackupCode = async ({ user, backupCode }) => { (codeObj) => codeObj.codeHash === hashedInput && !codeObj.used, ); - if (matchingCode) { + if (!matchingCode) { + return false; + } + + if (persist) { const updatedBackupCodes = user.backupCodes.map((codeObj) => codeObj.codeHash === hashedInput && !codeObj.used ? { ...codeObj, used: true, usedAt: new Date() } : codeObj, ); - // Update the user record with the marked backup code. await updateUser(user._id, { backupCodes: updatedBackupCodes }); - return true; } - return false; + return true; +}; + +/** + * Verifies a user's identity via TOTP token or backup code. + * @param {Object} params + * @param {Object} params.user - The user document (must include totpSecret and backupCodes). + * @param {string} [params.token] - A 6-digit TOTP token. + * @param {string} [params.backupCode] - An 8-character backup code. + * @param {boolean} [params.persistBackupUse=true] - Whether to mark the backup code as used in the DB. + * @returns {Promise<{ verified: boolean, status?: number, message?: string }>} + */ +const verifyOTPOrBackupCode = async ({ user, token, backupCode, persistBackupUse = true }) => { + if (!token && !backupCode) { + return { verified: false, status: 400 }; + } + + if (token) { + const secret = await getTOTPSecret(user.totpSecret); + if (!secret) { + return { verified: false, status: 400, message: '2FA secret is missing or corrupted' }; + } + const ok = await verifyTOTP(secret, token); + return ok + ? { verified: true } + : { verified: false, status: 401, message: 'Invalid token or backup code' }; + } + + const ok = await verifyBackupCode({ user, backupCode, persist: persistBackupUse }); + return ok + ? { verified: true } + : { verified: false, status: 401, message: 'Invalid token or backup code' }; }; /** @@ -213,11 +248,12 @@ const generate2FATempToken = (userId) => { }; module.exports = { - generateTOTPSecret, - generateTOTP, - verifyTOTP, + verifyOTPOrBackupCode, + generate2FATempToken, generateBackupCodes, + generateTOTPSecret, verifyBackupCode, getTOTPSecret, - generate2FATempToken, + generateTOTP, + verifyTOTP, }; diff --git a/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx index c89ce61fff..e66cb7b08a 100644 --- a/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx +++ b/client/src/components/Nav/SettingsTabs/Account/BackupCodesItem.tsx @@ -1,12 +1,23 @@ import React, { useState } from 'react'; import { RefreshCcw } from 'lucide-react'; +import { useSetRecoilState } from 'recoil'; import { motion, AnimatePresence } from 'framer-motion'; -import { TBackupCode, TRegenerateBackupCodesResponse, type TUser } from 'librechat-data-provider'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; +import type { + TRegenerateBackupCodesResponse, + TRegenerateBackupCodesRequest, + TBackupCode, + TUser, +} from 'librechat-data-provider'; import { - OGDialog, + InputOTPSeparator, + InputOTPGroup, + InputOTPSlot, OGDialogContent, OGDialogTitle, OGDialogTrigger, + OGDialog, + InputOTP, Button, Label, Spinner, @@ -15,7 +26,6 @@ import { } from '@librechat/client'; import { useRegenerateBackupCodesMutation } from '~/data-provider'; import { useAuthContext, useLocalize } from '~/hooks'; -import { useSetRecoilState } from 'recoil'; import store from '~/store'; const BackupCodesItem: React.FC = () => { @@ -24,25 +34,30 @@ const BackupCodesItem: React.FC = () => { const { showToast } = useToastContext(); const setUser = useSetRecoilState(store.user); const [isDialogOpen, setDialogOpen] = useState(false); + const [otpToken, setOtpToken] = useState(''); + const [useBackup, setUseBackup] = useState(false); const { mutate: regenerateBackupCodes, isLoading } = useRegenerateBackupCodesMutation(); + const needs2FA = !!user?.twoFactorEnabled; + const fetchBackupCodes = (auto: boolean = false) => { - regenerateBackupCodes(undefined, { + let payload: TRegenerateBackupCodesRequest | undefined; + if (needs2FA && otpToken.trim()) { + payload = useBackup ? { backupCode: otpToken.trim() } : { token: otpToken.trim() }; + } + + regenerateBackupCodes(payload, { onSuccess: (data: TRegenerateBackupCodesResponse) => { - const newBackupCodes: TBackupCode[] = data.backupCodesHash.map((codeHash) => ({ - codeHash, - used: false, - usedAt: null, - })); + const newBackupCodes: TBackupCode[] = data.backupCodesHash; setUser((prev) => ({ ...prev, backupCodes: newBackupCodes }) as TUser); + setOtpToken(''); showToast({ message: localize('com_ui_backup_codes_regenerated'), status: 'success', }); - // Trigger file download only when user explicitly clicks the button. if (!auto && newBackupCodes.length) { const codesString = data.backupCodes.join('\n'); const blob = new Blob([codesString], { type: 'text/plain;charset=utf-8' }); @@ -66,6 +81,8 @@ const BackupCodesItem: React.FC = () => { fetchBackupCodes(false); }; + const otpReady = !needs2FA || otpToken.length === (useBackup ? 8 : 6); + return (
@@ -161,10 +178,10 @@ const BackupCodesItem: React.FC = () => { ); })}
-
+
)} + {needs2FA && ( +
+ +
+ + {useBackup ? ( + + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + + )} + +
+ +
+ )} diff --git a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx index e879a0f2c6..d9c432c6a2 100644 --- a/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx +++ b/client/src/components/Nav/SettingsTabs/Account/DeleteAccount.tsx @@ -1,16 +1,22 @@ -import { LockIcon, Trash } from 'lucide-react'; import React, { useState, useCallback } from 'react'; +import { LockIcon, Trash } from 'lucide-react'; +import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp'; import { - Label, - Input, - Button, - Spinner, - OGDialog, + InputOTPSeparator, OGDialogContent, OGDialogTrigger, OGDialogHeader, + InputOTPGroup, OGDialogTitle, + InputOTPSlot, + OGDialog, + InputOTP, + Spinner, + Button, + Label, + Input, } from '@librechat/client'; +import type { TDeleteUserRequest } from 'librechat-data-provider'; import { useDeleteUserMutation } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import { LocalizeFunction } from '~/common'; @@ -21,16 +27,27 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea const localize = useLocalize(); const { user, logout } = useAuthContext(); const { mutate: deleteUser, isLoading: isDeleting } = useDeleteUserMutation({ - onMutate: () => logout(), + onSuccess: () => logout(), }); const [isDialogOpen, setDialogOpen] = useState(false); const [isLocked, setIsLocked] = useState(true); + const [otpToken, setOtpToken] = useState(''); + const [useBackup, setUseBackup] = useState(false); + + const needs2FA = !!user?.twoFactorEnabled; const handleDeleteUser = () => { - if (!isLocked) { - deleteUser(undefined); + if (isLocked) { + return; } + + let payload: TDeleteUserRequest | undefined; + if (needs2FA && otpToken.trim()) { + payload = useBackup ? { backupCode: otpToken.trim() } : { token: otpToken.trim() }; + } + + deleteUser(payload); }; const handleInputChange = useCallback( @@ -42,6 +59,8 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea [user?.email], ); + const otpReady = !needs2FA || otpToken.length === (useBackup ? 8 : 6); + return ( <> @@ -79,7 +98,60 @@ const DeleteAccount = ({ disabled = false }: { title?: string; disabled?: boolea (e) => handleInputChange(e.target.value), )}
- {renderDeleteButton(handleDeleteUser, isDeleting, isLocked, localize)} + {needs2FA && ( +
+ +
+ + {useBackup ? ( + + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + + )} + +
+ +
+ )} + {renderDeleteButton(handleDeleteUser, isDeleting, isLocked || !otpReady, localize)}
diff --git a/client/src/data-provider/Auth/mutations.ts b/client/src/data-provider/Auth/mutations.ts index 298ddd9b64..9930e42b4f 100644 --- a/client/src/data-provider/Auth/mutations.ts +++ b/client/src/data-provider/Auth/mutations.ts @@ -68,14 +68,14 @@ export const useRefreshTokenMutation = ( /* User */ export const useDeleteUserMutation = ( - options?: t.MutationOptions, -): UseMutationResult => { + options?: t.MutationOptions, +): UseMutationResult => { const queryClient = useQueryClient(); const clearStates = useClearStates(); const resetDefaultPreset = useResetRecoilState(store.defaultPreset); return useMutation([MutationKeys.deleteUser], { - mutationFn: () => dataService.deleteUser(), + mutationFn: (payload?: t.TDeleteUserRequest) => dataService.deleteUser(payload), ...(options || {}), onSuccess: (...args) => { resetDefaultPreset(); @@ -90,11 +90,11 @@ export const useDeleteUserMutation = ( export const useEnableTwoFactorMutation = (): UseMutationResult< t.TEnable2FAResponse, unknown, - void, + t.TEnable2FARequest | undefined, unknown > => { const queryClient = useQueryClient(); - return useMutation(() => dataService.enableTwoFactor(), { + return useMutation((payload?: t.TEnable2FARequest) => dataService.enableTwoFactor(payload), { onSuccess: (data) => { queryClient.setQueryData([QueryKeys.user, '2fa'], data); }, @@ -146,15 +146,18 @@ export const useDisableTwoFactorMutation = (): UseMutationResult< export const useRegenerateBackupCodesMutation = (): UseMutationResult< t.TRegenerateBackupCodesResponse, unknown, - void, + t.TRegenerateBackupCodesRequest | undefined, unknown > => { const queryClient = useQueryClient(); - return useMutation(() => dataService.regenerateBackupCodes(), { - onSuccess: (data) => { - queryClient.setQueryData([QueryKeys.user, '2fa', 'backup'], data); + return useMutation( + (payload?: t.TRegenerateBackupCodesRequest) => dataService.regenerateBackupCodes(payload), + { + onSuccess: (data) => { + queryClient.setQueryData([QueryKeys.user, '2fa', 'backup'], data); + }, }, - }); + ); }; export const useVerifyTwoFactorTempMutation = ( diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 35d8300489..196ea2ad4a 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -639,6 +639,7 @@ "com_ui_2fa_generate_error": "There was an error generating two-factor authentication settings", "com_ui_2fa_invalid": "Invalid two-factor authentication code", "com_ui_2fa_setup": "Setup 2FA", + "com_ui_2fa_verification_required": "Enter your 2FA code to continue", "com_ui_2fa_verified": "Successfully verified Two-Factor Authentication", "com_ui_accept": "I accept", "com_ui_action_button": "Action Button", diff --git a/packages/data-provider/src/data-service.ts b/packages/data-provider/src/data-service.ts index be5cccd43b..2c7a402d1f 100644 --- a/packages/data-provider/src/data-service.ts +++ b/packages/data-provider/src/data-service.ts @@ -21,8 +21,8 @@ export function revokeAllUserKeys(): Promise { return request.delete(endpoints.revokeAllUserKeys()); } -export function deleteUser(): Promise { - return request.delete(endpoints.deleteUser()); +export function deleteUser(payload?: t.TDeleteUserRequest): Promise { + return request.deleteWithOptions(endpoints.deleteUser(), { data: payload }); } export type FavoriteItem = { @@ -970,8 +970,8 @@ export function updateFeedback( } // 2FA -export function enableTwoFactor(): Promise { - return request.get(endpoints.enableTwoFactor()); +export function enableTwoFactor(payload?: t.TEnable2FARequest): Promise { + return request.post(endpoints.enableTwoFactor(), payload); } export function verifyTwoFactor(payload: t.TVerify2FARequest): Promise { @@ -986,8 +986,10 @@ export function disableTwoFactor(payload?: t.TDisable2FARequest): Promise { - return request.post(endpoints.regenerateBackupCodes()); +export function regenerateBackupCodes( + payload?: t.TRegenerateBackupCodesRequest, +): Promise { + return request.post(endpoints.regenerateBackupCodes(), payload); } export function verifyTwoFactorTemp( diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts index 3b04c40f45..5895fba321 100644 --- a/packages/data-provider/src/types.ts +++ b/packages/data-provider/src/types.ts @@ -425,28 +425,29 @@ export type TLoginResponse = { tempToken?: string; }; +/** Shared payload for any operation that requires OTP or backup-code verification. */ +export type TOTPVerificationPayload = { + token?: string; + backupCode?: string; +}; + +export type TEnable2FARequest = TOTPVerificationPayload; + export type TEnable2FAResponse = { otpauthUrl: string; backupCodes: string[]; message?: string; }; -export type TVerify2FARequest = { - token?: string; - backupCode?: string; -}; +export type TVerify2FARequest = TOTPVerificationPayload; export type TVerify2FAResponse = { message: string; }; -/** - * For verifying 2FA during login with a temporary token. - */ -export type TVerify2FATempRequest = { +/** For verifying 2FA during login with a temporary token. */ +export type TVerify2FATempRequest = TOTPVerificationPayload & { tempToken: string; - token?: string; - backupCode?: string; }; export type TVerify2FATempResponse = { @@ -455,30 +456,22 @@ export type TVerify2FATempResponse = { message?: string; }; -/** - * Request for disabling 2FA. - */ -export type TDisable2FARequest = { - token?: string; - backupCode?: string; -}; +export type TDisable2FARequest = TOTPVerificationPayload; -/** - * Response from disabling 2FA. - */ export type TDisable2FAResponse = { message: string; }; -/** - * Response from regenerating backup codes. - */ +export type TRegenerateBackupCodesRequest = TOTPVerificationPayload; + export type TRegenerateBackupCodesResponse = { - message: string; + message?: string; backupCodes: string[]; - backupCodesHash: string[]; + backupCodesHash: TBackupCode[]; }; +export type TDeleteUserRequest = TOTPVerificationPayload; + export type TRequestPasswordReset = { email: string; }; diff --git a/packages/data-schemas/src/schema/user.ts b/packages/data-schemas/src/schema/user.ts index c2bdc6fd34..57c8f8574e 100644 --- a/packages/data-schemas/src/schema/user.ts +++ b/packages/data-schemas/src/schema/user.ts @@ -121,6 +121,15 @@ const userSchema = new Schema( type: [BackupCodeSchema], select: false, }, + pendingTotpSecret: { + type: String, + select: false, + }, + pendingBackupCodes: { + type: [BackupCodeSchema], + select: false, + default: undefined, + }, refreshToken: { type: [SessionSchema], }, diff --git a/packages/data-schemas/src/types/user.ts b/packages/data-schemas/src/types/user.ts index a78c4679f2..e1cecb7518 100644 --- a/packages/data-schemas/src/types/user.ts +++ b/packages/data-schemas/src/types/user.ts @@ -26,6 +26,12 @@ export interface IUser extends Document { used: boolean; usedAt?: Date | null; }>; + pendingTotpSecret?: string; + pendingBackupCodes?: Array<{ + codeHash: string; + used: boolean; + usedAt?: Date | null; + }>; refreshToken?: Array<{ refreshToken: string; }>; From c6982dc180c26d6c26e1802ee2c63b9dcf3046ee Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 02:57:56 -0400 Subject: [PATCH 15/39] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20Agent=20Pe?= =?UTF-8?q?rmission=20Check=20on=20Image=20Upload=20Route=20(#12219)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: add agent permission check to image upload route * refactor: remove unused SystemRoles import and format test file for clarity * fix: address review findings for image upload agent permission check * refactor: move agent upload auth logic to TypeScript in packages/api Extract pure authorization logic from agentPermCheck.js into checkAgentUploadAuth() in packages/api/src/files/agentUploadAuth.ts. The function returns a structured result ({ allowed, status, error }) instead of writing HTTP responses directly, eliminating the dual responsibility and confusing sentinel return value. The JS wrapper in /api is now a thin adapter that translates the result to HTTP. * test: rewrite image upload permission tests as integration tests Replace mock-heavy images-agent-perm.spec.js with integration tests using MongoMemoryServer, real models, and real PermissionService. Follows the established pattern in files.agents.test.js. Moves test to sibling location (images.agents.test.js) matching backend convention. Adds temp file cleanup assertions on 403/404 responses and covers message_file exemption paths (boolean true, string "true", false). * fix: widen AgentUploadAuthDeps types to accept ObjectId from Mongoose The injected getAgent returns Mongoose documents where _id and author are Types.ObjectId at runtime, not string. Widen the DI interface to accept string | Types.ObjectId for _id, author, and resourceId so the contract accurately reflects real callers. * chore: move agent upload auth into files/agents/ subdirectory * refactor: delete agentPermCheck.js wrapper, move verifyAgentUploadPermission to packages/api The /api-only dependencies (getAgent, checkPermission) are now passed as object-field params from the route call sites. Both images.js and files.js import verifyAgentUploadPermission from @librechat/api and inject the deps directly, eliminating the intermediate JS wrapper. * style: fix import type ordering in agent upload auth * fix: prevent token TTL race in MCPTokenStorage.storeTokens When expires_in is provided, use it directly instead of round-tripping through Date arithmetic. The previous code computed accessTokenExpiry as a Date, then after an async encryptV2 call, recomputed expiresIn by subtracting Date.now(). On loaded CI runners the elapsed time caused Math.floor to truncate to 0, triggering the 1-year fallback and making the token appear permanently valid — so refresh never fired. --- api/server/routes/files/files.js | 53 +-- api/server/routes/files/images.agents.test.js | 376 ++++++++++++++++++ api/server/routes/files/images.js | 13 + packages/api/src/files/agents/auth.ts | 113 ++++++ packages/api/src/files/agents/index.ts | 1 + packages/api/src/files/index.ts | 1 + packages/api/src/mcp/oauth/tokens.ts | 28 +- 7 files changed, 525 insertions(+), 60 deletions(-) create mode 100644 api/server/routes/files/images.agents.test.js create mode 100644 packages/api/src/files/agents/auth.ts create mode 100644 packages/api/src/files/agents/index.ts diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index 5de2ddb379..9290d1a7ed 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -2,12 +2,12 @@ const fs = require('fs').promises; const express = require('express'); const { EnvVar } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); +const { verifyAgentUploadPermission } = require('@librechat/api'); const { Time, isUUID, CacheKeys, FileSources, - SystemRoles, ResourceType, EModelEndpoint, PermissionBits, @@ -381,48 +381,15 @@ router.post('/', async (req, res) => { return await processFileUpload({ req, res, metadata }); } - /** - * Check agent permissions for permanent agent file uploads (not message attachments). - * Message attachments (message_file=true) are temporary files for a single conversation - * and should be allowed for users who can chat with the agent. - * Permanent file uploads to tool_resources require EDIT permission. - */ - const isMessageAttachment = metadata.message_file === true || metadata.message_file === 'true'; - if (metadata.agent_id && metadata.tool_resource && !isMessageAttachment) { - const userId = req.user.id; - - /** Admin users bypass permission checks */ - if (req.user.role !== SystemRoles.ADMIN) { - const agent = await getAgent({ id: metadata.agent_id }); - - if (!agent) { - return res.status(404).json({ - error: 'Not Found', - message: 'Agent not found', - }); - } - - /** Check if user is the author or has edit permission */ - if (agent.author.toString() !== userId) { - const hasEditPermission = await checkPermission({ - userId, - role: req.user.role, - resourceType: ResourceType.AGENT, - resourceId: agent._id, - requiredPermission: PermissionBits.EDIT, - }); - - if (!hasEditPermission) { - logger.warn( - `[/files] User ${userId} denied upload to agent ${metadata.agent_id} (insufficient permissions)`, - ); - return res.status(403).json({ - error: 'Forbidden', - message: 'Insufficient permissions to upload files to this agent', - }); - } - } - } + const denied = await verifyAgentUploadPermission({ + req, + res, + metadata, + getAgent, + checkPermission, + }); + if (denied) { + return; } return await processAgentFileUpload({ req, res, metadata }); diff --git a/api/server/routes/files/images.agents.test.js b/api/server/routes/files/images.agents.test.js new file mode 100644 index 0000000000..862ab87d63 --- /dev/null +++ b/api/server/routes/files/images.agents.test.js @@ -0,0 +1,376 @@ +const express = require('express'); +const request = require('supertest'); +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { createMethods } = require('@librechat/data-schemas'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { + SystemRoles, + AccessRoleIds, + ResourceType, + PrincipalType, +} = require('librechat-data-provider'); +const { createAgent } = require('~/models/Agent'); + +jest.mock('~/server/services/Files/process', () => ({ + processAgentFileUpload: jest.fn().mockImplementation(async ({ res }) => { + return res.status(200).json({ message: 'Agent file uploaded', file_id: 'test-file-id' }); + }), + processImageFile: jest.fn().mockImplementation(async ({ res }) => { + return res.status(200).json({ message: 'Image processed' }); + }), + filterFile: jest.fn(), +})); + +jest.mock('fs', () => { + const actualFs = jest.requireActual('fs'); + return { + ...actualFs, + promises: { + ...actualFs.promises, + unlink: jest.fn().mockResolvedValue(undefined), + }, + }; +}); + +const fs = require('fs'); +const { processAgentFileUpload } = require('~/server/services/Files/process'); + +const router = require('~/server/routes/files/images'); + +describe('POST /images - Agent Upload Permission Check (Integration)', () => { + let mongoServer; + let authorId; + let otherUserId; + let agentCustomId; + let User; + let Agent; + let AclEntry; + let methods; + let modelsToCleanup = []; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + + const { createModels } = require('@librechat/data-schemas'); + const models = createModels(mongoose); + modelsToCleanup = Object.keys(models); + Object.assign(mongoose.models, models); + methods = createMethods(mongoose); + + User = models.User; + Agent = models.Agent; + AclEntry = models.AclEntry; + + await methods.seedDefaultRoles(); + }); + + afterAll(async () => { + const collections = mongoose.connection.collections; + for (const key in collections) { + await collections[key].deleteMany({}); + } + for (const modelName of modelsToCleanup) { + if (mongoose.models[modelName]) { + delete mongoose.models[modelName]; + } + } + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + await User.deleteMany({}); + await AclEntry.deleteMany({}); + + authorId = new mongoose.Types.ObjectId(); + otherUserId = new mongoose.Types.ObjectId(); + agentCustomId = `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`; + + await User.create({ _id: authorId, username: 'author', email: 'author@test.com' }); + await User.create({ _id: otherUserId, username: 'other', email: 'other@test.com' }); + + jest.clearAllMocks(); + }); + + const createAppWithUser = (userId, userRole = SystemRoles.USER) => { + const app = express(); + app.use(express.json()); + app.use((req, _res, next) => { + if (req.method === 'POST') { + req.file = { + originalname: 'test.png', + mimetype: 'image/png', + size: 100, + path: '/tmp/t.png', + filename: 'test.png', + }; + req.file_id = uuidv4(); + } + next(); + }); + app.use((req, _res, next) => { + req.user = { id: userId.toString(), role: userRole }; + req.app = { locals: {} }; + req.config = { fileStrategy: 'local', paths: { imageOutput: '/tmp/images' } }; + next(); + }); + app.use('/images', router); + return app; + }; + + it('should return 403 when user has no permission on agent', async () => { + await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(403); + expect(response.body.error).toBe('Forbidden'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); + + it('should allow upload for agent owner', async () => { + await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const app = createAppWithUser(authorId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should allow upload for admin regardless of ownership', async () => { + await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const app = createAppWithUser(otherUserId, SystemRoles.ADMIN); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should allow upload for user with EDIT permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_EDITOR, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should deny upload for user with only VIEW permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(403); + expect(response.body.error).toBe('Forbidden'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); + + it('should skip permission check for regular image uploads without agent_id/tool_resource', async () => { + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + }); + + it('should return 404 for non-existent agent', async () => { + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: 'agent_nonexistent123456789', + tool_resource: 'context', + file_id: uuidv4(), + }); + + expect(response.status).toBe(404); + expect(response.body.error).toBe('Not Found'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); + + it('should allow message_file attachment (boolean true) without EDIT permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + message_file: true, + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should allow message_file attachment (string "true") without EDIT permission', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + message_file: 'true', + file_id: uuidv4(), + }); + + expect(response.status).toBe(200); + expect(processAgentFileUpload).toHaveBeenCalled(); + }); + + it('should deny upload when message_file is false (not a message attachment)', async () => { + const agent = await createAgent({ + id: agentCustomId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const { grantPermission } = require('~/server/services/PermissionService'); + await grantPermission({ + principalType: PrincipalType.USER, + principalId: otherUserId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_VIEWER, + grantedBy: authorId, + }); + + const app = createAppWithUser(otherUserId); + const response = await request(app).post('/images').send({ + endpoint: 'agents', + agent_id: agentCustomId, + tool_resource: 'context', + message_file: false, + file_id: uuidv4(), + }); + + expect(response.status).toBe(403); + expect(response.body.error).toBe('Forbidden'); + expect(processAgentFileUpload).not.toHaveBeenCalled(); + expect(fs.promises.unlink).toHaveBeenCalledWith('/tmp/t.png'); + }); +}); diff --git a/api/server/routes/files/images.js b/api/server/routes/files/images.js index 8072612a69..185ec7a671 100644 --- a/api/server/routes/files/images.js +++ b/api/server/routes/files/images.js @@ -2,12 +2,15 @@ const path = require('path'); const fs = require('fs').promises; const express = require('express'); const { logger } = require('@librechat/data-schemas'); +const { verifyAgentUploadPermission } = require('@librechat/api'); const { isAssistantsEndpoint } = require('librechat-data-provider'); const { processAgentFileUpload, processImageFile, filterFile, } = require('~/server/services/Files/process'); +const { checkPermission } = require('~/server/services/PermissionService'); +const { getAgent } = require('~/models/Agent'); const router = express.Router(); @@ -22,6 +25,16 @@ router.post('/', async (req, res) => { metadata.file_id = req.file_id; if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) { + const denied = await verifyAgentUploadPermission({ + req, + res, + metadata, + getAgent, + checkPermission, + }); + if (denied) { + return; + } return await processAgentFileUpload({ req, res, metadata }); } diff --git a/packages/api/src/files/agents/auth.ts b/packages/api/src/files/agents/auth.ts new file mode 100644 index 0000000000..d9fb2b7423 --- /dev/null +++ b/packages/api/src/files/agents/auth.ts @@ -0,0 +1,113 @@ +import type { IUser } from '@librechat/data-schemas'; +import type { Response } from 'express'; +import type { Types } from 'mongoose'; +import { logger } from '@librechat/data-schemas'; +import { SystemRoles, ResourceType, PermissionBits } from 'librechat-data-provider'; +import type { ServerRequest } from '~/types'; + +export type AgentUploadAuthResult = + | { allowed: true } + | { allowed: false; status: number; error: string; message: string }; + +export interface AgentUploadAuthParams { + userId: string; + userRole: string; + agentId?: string; + toolResource?: string | null; + messageFile?: boolean | string; +} + +export interface AgentUploadAuthDeps { + getAgent: (params: { id: string }) => Promise<{ + _id: string | Types.ObjectId; + author?: string | Types.ObjectId | null; + } | null>; + checkPermission: (params: { + userId: string; + role: string; + resourceType: ResourceType; + resourceId: string | Types.ObjectId; + requiredPermission: number; + }) => Promise; +} + +export async function checkAgentUploadAuth( + params: AgentUploadAuthParams, + deps: AgentUploadAuthDeps, +): Promise { + const { userId, userRole, agentId, toolResource, messageFile } = params; + const { getAgent, checkPermission } = deps; + + const isMessageAttachment = messageFile === true || messageFile === 'true'; + if (!agentId || toolResource == null || isMessageAttachment) { + return { allowed: true }; + } + + if (userRole === SystemRoles.ADMIN) { + return { allowed: true }; + } + + const agent = await getAgent({ id: agentId }); + if (!agent) { + return { allowed: false, status: 404, error: 'Not Found', message: 'Agent not found' }; + } + + if (agent.author?.toString() === userId) { + return { allowed: true }; + } + + const hasEditPermission = await checkPermission({ + userId, + role: userRole, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + requiredPermission: PermissionBits.EDIT, + }); + + if (hasEditPermission) { + return { allowed: true }; + } + + logger.warn( + `[agentUploadAuth] User ${userId} denied upload to agent ${agentId} (insufficient permissions)`, + ); + return { + allowed: false, + status: 403, + error: 'Forbidden', + message: 'Insufficient permissions to upload files to this agent', + }; +} + +/** @returns true if denied (response already sent), false if allowed */ +export async function verifyAgentUploadPermission({ + req, + res, + metadata, + getAgent, + checkPermission, +}: { + req: ServerRequest; + res: Response; + metadata: { agent_id?: string; tool_resource?: string | null; message_file?: boolean | string }; + getAgent: AgentUploadAuthDeps['getAgent']; + checkPermission: AgentUploadAuthDeps['checkPermission']; +}): Promise { + const user = req.user as IUser; + const result = await checkAgentUploadAuth( + { + userId: user.id, + userRole: user.role ?? '', + agentId: metadata.agent_id, + toolResource: metadata.tool_resource, + messageFile: metadata.message_file, + }, + { getAgent, checkPermission }, + ); + + if (!result.allowed) { + res.status(result.status).json({ error: result.error, message: result.message }); + return true; + } + return false; +} diff --git a/packages/api/src/files/agents/index.ts b/packages/api/src/files/agents/index.ts new file mode 100644 index 0000000000..269586ee8b --- /dev/null +++ b/packages/api/src/files/agents/index.ts @@ -0,0 +1 @@ +export * from './auth'; diff --git a/packages/api/src/files/index.ts b/packages/api/src/files/index.ts index 707f2ef7fb..c3bdb49478 100644 --- a/packages/api/src/files/index.ts +++ b/packages/api/src/files/index.ts @@ -1,3 +1,4 @@ +export * from './agents'; export * from './audio'; export * from './context'; export * from './documents/crud'; diff --git a/packages/api/src/mcp/oauth/tokens.ts b/packages/api/src/mcp/oauth/tokens.ts index 7b1d189347..6094a05386 100644 --- a/packages/api/src/mcp/oauth/tokens.ts +++ b/packages/api/src/mcp/oauth/tokens.ts @@ -83,46 +83,40 @@ export class MCPTokenStorage { `${logPrefix} Token expires_in: ${'expires_in' in tokens ? tokens.expires_in : 'N/A'}, expires_at: ${'expires_at' in tokens ? tokens.expires_at : 'N/A'}`, ); - // Handle both expires_in and expires_at formats + const defaultTTL = 365 * 24 * 60 * 60; + let accessTokenExpiry: Date; + let expiresInSeconds: number; if ('expires_at' in tokens && tokens.expires_at) { /** MCPOAuthTokens format - already has calculated expiry */ logger.debug(`${logPrefix} Using expires_at: ${tokens.expires_at}`); accessTokenExpiry = new Date(tokens.expires_at); + expiresInSeconds = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000); } else if (tokens.expires_in) { - /** Standard OAuthTokens format - calculate expiry */ + /** Standard OAuthTokens format - use expires_in directly to avoid lossy Date round-trip */ logger.debug(`${logPrefix} Using expires_in: ${tokens.expires_in}`); + expiresInSeconds = tokens.expires_in; accessTokenExpiry = new Date(Date.now() + tokens.expires_in * 1000); } else { - /** No expiry provided - default to 1 year */ logger.debug(`${logPrefix} No expiry provided, using default`); - accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000); + expiresInSeconds = defaultTTL; + accessTokenExpiry = new Date(Date.now() + defaultTTL * 1000); } logger.debug(`${logPrefix} Calculated expiry date: ${accessTokenExpiry.toISOString()}`); - logger.debug( - `${logPrefix} Date object: ${JSON.stringify({ - time: accessTokenExpiry.getTime(), - valid: !isNaN(accessTokenExpiry.getTime()), - iso: accessTokenExpiry.toISOString(), - })}`, - ); - // Ensure the date is valid before passing to createToken if (isNaN(accessTokenExpiry.getTime())) { logger.error(`${logPrefix} Invalid expiry date calculated, using default`); - accessTokenExpiry = new Date(Date.now() + 365 * 24 * 60 * 60 * 1000); + accessTokenExpiry = new Date(Date.now() + defaultTTL * 1000); + expiresInSeconds = defaultTTL; } - // Calculate expiresIn (seconds from now) - const expiresIn = Math.floor((accessTokenExpiry.getTime() - Date.now()) / 1000); - const accessTokenData = { userId, type: 'mcp_oauth', identifier, token: encryptedAccessToken, - expiresIn: expiresIn > 0 ? expiresIn : 365 * 24 * 60 * 60, // Default to 1 year if negative + expiresIn: expiresInSeconds > 0 ? expiresInSeconds : defaultTTL, }; // Check if token already exists and update if it does From 35a35dc2e9280e0bad29c5bb586bb0fd72d4104d Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 03:06:29 -0400 Subject: [PATCH 16/39] =?UTF-8?q?=F0=9F=93=8F=20refactor:=20Add=20File=20S?= =?UTF-8?q?ize=20Limits=20to=20Conversation=20Imports=20(#12221)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: add file size limits to conversation import multer instance * fix: address review findings for conversation import file size limits * fix: use local jest.mock for data-schemas instead of global moduleNameMapper The global @librechat/data-schemas mock in jest.config.js only provided logger, breaking all tests that depend on createModels from the same package. Replace with a virtual jest.mock scoped to the import spec file. * fix: move import to top of file, pre-compute upload middleware, assert logger.warn in tests * refactor: move resolveImportMaxFileSize to packages/api New backend logic belongs in packages/api as TypeScript. Delete the api/server/utils/import/limits.js wrapper and import directly from @librechat/api in convos.js and importConversations.js. Resolver unit tests move to packages/api; the api/ spec retains only multer behavior tests. * chore: rename importLimits to import * fix: stale type reference and mock isolation in import tests Update typeof import path from '../importLimits' to '../import' after the rename. Clear mockLogger.warn in beforeEach to prevent cross-test accumulation. * fix: add resolveImportMaxFileSize to @librechat/api mock in convos.spec.js * fix: resolve jest.mock hoisting issue in import tests jest.mock factories are hoisted above const declarations, so the mockLogger reference was undefined at factory evaluation time. Use a direct import of the mocked logger module instead. * fix: remove virtual flag from data-schemas mock for CI compatibility virtual: true prevents the mock from intercepting the real module in CI where @librechat/data-schemas is built, causing import.ts to use the real logger while the test asserts against the mock. --- api/jest.config.js | 2 +- .../__test-utils__/convos-route-mocks.js | 1 + .../routes/__tests__/convos-import.spec.js | 98 +++++++++++++++++++ api/server/routes/convos.js | 24 ++++- .../utils/import/importConversations.js | 8 +- .../api/src/utils/__tests__/import.test.ts | 76 ++++++++++++++ packages/api/src/utils/import.ts | 20 ++++ packages/api/src/utils/index.ts | 1 + 8 files changed, 223 insertions(+), 7 deletions(-) create mode 100644 api/server/routes/__tests__/convos-import.spec.js create mode 100644 packages/api/src/utils/__tests__/import.test.ts create mode 100644 packages/api/src/utils/import.ts diff --git a/api/jest.config.js b/api/jest.config.js index 3b752403c1..47f8b7287b 100644 --- a/api/jest.config.js +++ b/api/jest.config.js @@ -9,7 +9,7 @@ module.exports = { moduleNameMapper: { '~/(.*)': '/$1', '~/data/auth.json': '/__mocks__/auth.mock.json', - '^openid-client/passport$': '/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part + '^openid-client/passport$': '/test/__mocks__/openid-client-passport.js', '^openid-client$': '/test/__mocks__/openid-client.js', }, transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'], diff --git a/api/server/routes/__test-utils__/convos-route-mocks.js b/api/server/routes/__test-utils__/convos-route-mocks.js index ca5bafeda9..f89b77db3f 100644 --- a/api/server/routes/__test-utils__/convos-route-mocks.js +++ b/api/server/routes/__test-utils__/convos-route-mocks.js @@ -3,6 +3,7 @@ module.exports = { api: (overrides = {}) => ({ isEnabled: jest.fn(), + resolveImportMaxFileSize: jest.fn(() => 262144000), createAxiosInstance: jest.fn(() => ({ get: jest.fn(), post: jest.fn(), diff --git a/api/server/routes/__tests__/convos-import.spec.js b/api/server/routes/__tests__/convos-import.spec.js new file mode 100644 index 0000000000..c4ea139931 --- /dev/null +++ b/api/server/routes/__tests__/convos-import.spec.js @@ -0,0 +1,98 @@ +const express = require('express'); +const request = require('supertest'); +const multer = require('multer'); + +const importFileFilter = (req, file, cb) => { + if (file.mimetype === 'application/json') { + cb(null, true); + } else { + cb(new Error('Only JSON files are allowed'), false); + } +}; + +/** Proxy app that mirrors the production multer + error-handling pattern */ +function createImportApp(fileSize) { + const app = express(); + const upload = multer({ + storage: multer.memoryStorage(), + fileFilter: importFileFilter, + limits: { fileSize }, + }); + const uploadSingle = upload.single('file'); + + function handleUpload(req, res, next) { + uploadSingle(req, res, (err) => { + if (err && err.code === 'LIMIT_FILE_SIZE') { + return res.status(413).json({ message: 'File exceeds the maximum allowed size' }); + } + if (err) { + return next(err); + } + next(); + }); + } + + app.post('/import', handleUpload, (req, res) => { + res.status(201).json({ message: 'success', size: req.file.size }); + }); + + app.use((err, _req, res, _next) => { + res.status(400).json({ error: err.message }); + }); + + return app; +} + +describe('Conversation Import - Multer File Size Limits', () => { + describe('multer rejects files exceeding the configured limit', () => { + it('returns 413 for files larger than the limit', async () => { + const limit = 1024; + const app = createImportApp(limit); + const oversized = Buffer.alloc(limit + 512, 'x'); + + const res = await request(app) + .post('/import') + .attach('file', oversized, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(413); + expect(res.body.message).toBe('File exceeds the maximum allowed size'); + }); + + it('accepts files within the limit', async () => { + const limit = 4096; + const app = createImportApp(limit); + const valid = Buffer.from(JSON.stringify({ title: 'test' })); + + const res = await request(app) + .post('/import') + .attach('file', valid, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(201); + expect(res.body.message).toBe('success'); + }); + + it('rejects at the exact boundary (limit + 1 byte)', async () => { + const limit = 512; + const app = createImportApp(limit); + const boundary = Buffer.alloc(limit + 1, 'a'); + + const res = await request(app) + .post('/import') + .attach('file', boundary, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(413); + }); + + it('accepts a file just under the limit', async () => { + const limit = 512; + const app = createImportApp(limit); + const underLimit = Buffer.alloc(limit - 1, 'b'); + + const res = await request(app) + .post('/import') + .attach('file', underLimit, { filename: 'import.json', contentType: 'application/json' }); + + expect(res.status).toBe(201); + }); + }); +}); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 5f0c35fa0a..578796170a 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,7 +1,7 @@ const multer = require('multer'); const express = require('express'); const { sleep } = require('@librechat/agents'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, resolveImportMaxFileSize } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { @@ -226,7 +226,25 @@ router.post('/update', validateConvoAccess, async (req, res) => { const { importIpLimiter, importUserLimiter } = createImportLimiters(); /** Fork and duplicate share one rate-limit budget (same "clone" operation class) */ const { forkIpLimiter, forkUserLimiter } = createForkLimiters(); -const upload = multer({ storage: storage, fileFilter: importFileFilter }); +const importMaxFileSize = resolveImportMaxFileSize(); +const upload = multer({ + storage, + fileFilter: importFileFilter, + limits: { fileSize: importMaxFileSize }, +}); +const uploadSingle = upload.single('file'); + +function handleUpload(req, res, next) { + uploadSingle(req, res, (err) => { + if (err && err.code === 'LIMIT_FILE_SIZE') { + return res.status(413).json({ message: 'File exceeds the maximum allowed size' }); + } + if (err) { + return next(err); + } + next(); + }); +} /** * Imports a conversation from a JSON file and saves it to the database. @@ -239,7 +257,7 @@ router.post( importIpLimiter, importUserLimiter, configMiddleware, - upload.single('file'), + handleUpload, async (req, res) => { try { /* TODO: optimize to return imported conversations and add manually */ diff --git a/api/server/utils/import/importConversations.js b/api/server/utils/import/importConversations.js index d9e4d4332d..e56176c609 100644 --- a/api/server/utils/import/importConversations.js +++ b/api/server/utils/import/importConversations.js @@ -1,7 +1,10 @@ const fs = require('fs').promises; +const { resolveImportMaxFileSize } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { getImporter } = require('./importers'); +const maxFileSize = resolveImportMaxFileSize(); + /** * Job definition for importing a conversation. * @param {{ filepath, requestUserId }} job - The job object. @@ -11,11 +14,10 @@ const importConversations = async (job) => { try { logger.debug(`user: ${requestUserId} | Importing conversation(s) from file...`); - /* error if file is too large */ const fileInfo = await fs.stat(filepath); - if (fileInfo.size > process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES) { + if (fileInfo.size > maxFileSize) { throw new Error( - `File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES} bytes.`, + `File size is ${fileInfo.size} bytes. It exceeds the maximum limit of ${maxFileSize} bytes.`, ); } diff --git a/packages/api/src/utils/__tests__/import.test.ts b/packages/api/src/utils/__tests__/import.test.ts new file mode 100644 index 0000000000..08fa94669d --- /dev/null +++ b/packages/api/src/utils/__tests__/import.test.ts @@ -0,0 +1,76 @@ +jest.mock('@librechat/data-schemas', () => ({ + logger: { info: jest.fn(), warn: jest.fn(), error: jest.fn(), debug: jest.fn() }, +})); + +import { DEFAULT_IMPORT_MAX_FILE_SIZE, resolveImportMaxFileSize } from '../import'; +import { logger } from '@librechat/data-schemas'; + +describe('resolveImportMaxFileSize', () => { + let originalEnv: string | undefined; + + beforeEach(() => { + originalEnv = process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + jest.clearAllMocks(); + }); + + afterEach(() => { + if (originalEnv !== undefined) { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = originalEnv; + } else { + delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + } + }); + + it('returns 262144000 (250 MiB) when env var is not set', () => { + delete process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + expect(resolveImportMaxFileSize()).toBe(262144000); + expect(DEFAULT_IMPORT_MAX_FILE_SIZE).toBe(262144000); + }); + + it('returns default when env var is empty string', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = ''; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + }); + + it('respects a custom numeric value', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '5242880'; + expect(resolveImportMaxFileSize()).toBe(5242880); + }); + + it('parses string env var to number', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '1048576'; + expect(resolveImportMaxFileSize()).toBe(1048576); + }); + + it('falls back to default and warns for non-numeric string', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = 'abc'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); + + it('falls back to default and warns for negative values', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '-100'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); + + it('falls back to default and warns for zero', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = '0'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); + + it('falls back to default and warns for Infinity', () => { + process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES = 'Infinity'; + expect(resolveImportMaxFileSize()).toBe(DEFAULT_IMPORT_MAX_FILE_SIZE); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES'), + ); + }); +}); diff --git a/packages/api/src/utils/import.ts b/packages/api/src/utils/import.ts new file mode 100644 index 0000000000..94a2c8f818 --- /dev/null +++ b/packages/api/src/utils/import.ts @@ -0,0 +1,20 @@ +import { logger } from '@librechat/data-schemas'; + +/** 250 MiB — default max file size for conversation imports */ +export const DEFAULT_IMPORT_MAX_FILE_SIZE = 262144000; + +/** Resolves the import file-size limit from the env var, falling back to the 250 MiB default */ +export function resolveImportMaxFileSize(): number { + const raw = process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES; + if (!raw) { + return DEFAULT_IMPORT_MAX_FILE_SIZE; + } + const parsed = Number(raw); + if (!Number.isFinite(parsed) || parsed <= 0) { + logger.warn( + `[imports] Invalid CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES="${raw}"; using default ${DEFAULT_IMPORT_MAX_FILE_SIZE}`, + ); + return DEFAULT_IMPORT_MAX_FILE_SIZE; + } + return parsed; +} diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts index 441c2e02d7..5b9315d8c7 100644 --- a/packages/api/src/utils/index.ts +++ b/packages/api/src/utils/index.ts @@ -6,6 +6,7 @@ export * from './email'; export * from './env'; export * from './events'; export * from './files'; +export * from './import'; export * from './generators'; export * from './graph'; export * from './path'; From f67bbb2bc5e1722c91a106f166629a026c6f0d6a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 03:09:26 -0400 Subject: [PATCH 17/39] =?UTF-8?q?=F0=9F=A7=B9=20fix:=20Sanitize=20Artifact?= =?UTF-8?q?=20Filenames=20in=20Code=20Execution=20Output=20(#12222)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: sanitize artifact filenames to prevent path traversal in code output * test: Mock sanitizeFilename function in process.spec.js to return the original filename - Added a mock implementation for the `sanitizeFilename` function in the `process.spec.js` test file to return the original filename, ensuring that tests can run without altering the filename during the testing process. * fix: use path.relative for traversal check, sanitize all filenames, add security logging - Replace startsWith with path.relative pattern in saveLocalBuffer, consistent with deleteLocalFile and getLocalFileStream in the same file - Hoist sanitizeFilename call before the image/non-image branch so both code paths store the sanitized name in MongoDB - Log a warning when sanitizeFilename mutates a filename (potential traversal) - Log a specific warning when saveLocalBuffer throws a traversal error, so security events are distinguishable from generic network errors in the catch * test: improve traversal test coverage and remove mock reimplementation - Remove partial sanitizeFilename reimplementation from process-traversal tests; use controlled mock returns to verify processCodeOutput wiring instead - Add test for image branch sanitization - Use mkdtempSync for test isolation in crud-traversal to avoid parallel worker collisions - Add prefix-collision bypass test case (../user10/evil vs user1 directory) * fix: use path.relative in isValidPath to prevent prefix-collision bypass Pre-existing startsWith check without path separator had the same class of prefix-collision vulnerability fixed in saveLocalBuffer. --- .../Code/__tests__/process-traversal.spec.js | 124 ++++++++++++++++++ api/server/services/Files/Code/process.js | 20 ++- .../services/Files/Code/process.spec.js | 1 + .../Local/__tests__/crud-traversal.spec.js | 69 ++++++++++ api/server/services/Files/Local/crud.js | 16 ++- 5 files changed, 221 insertions(+), 9 deletions(-) create mode 100644 api/server/services/Files/Code/__tests__/process-traversal.spec.js create mode 100644 api/server/services/Files/Local/__tests__/crud-traversal.spec.js diff --git a/api/server/services/Files/Code/__tests__/process-traversal.spec.js b/api/server/services/Files/Code/__tests__/process-traversal.spec.js new file mode 100644 index 0000000000..2db366d06b --- /dev/null +++ b/api/server/services/Files/Code/__tests__/process-traversal.spec.js @@ -0,0 +1,124 @@ +jest.mock('uuid', () => ({ v4: jest.fn(() => 'mock-uuid') })); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { warn: jest.fn(), debug: jest.fn(), error: jest.fn() }, +})); + +jest.mock('@librechat/agents', () => ({ + getCodeBaseURL: jest.fn(() => 'http://localhost:8000'), +})); + +const mockSanitizeFilename = jest.fn(); + +jest.mock('@librechat/api', () => ({ + logAxiosError: jest.fn(), + getBasePath: jest.fn(() => ''), + sanitizeFilename: mockSanitizeFilename, +})); + +jest.mock('librechat-data-provider', () => ({ + ...jest.requireActual('librechat-data-provider'), + mergeFileConfig: jest.fn(() => ({ serverFileSizeLimit: 100 * 1024 * 1024 })), + getEndpointFileConfig: jest.fn(() => ({ + fileSizeLimit: 100 * 1024 * 1024, + supportedMimeTypes: ['*/*'], + })), + fileConfig: { checkType: jest.fn(() => true) }, +})); + +jest.mock('~/models', () => ({ + createFile: jest.fn().mockResolvedValue({}), + getFiles: jest.fn().mockResolvedValue([]), + updateFile: jest.fn(), + claimCodeFile: jest.fn().mockResolvedValue({ file_id: 'mock-uuid', usage: 0 }), +})); + +const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/user123/mock-uuid__output.csv'); + +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(() => ({ + saveBuffer: mockSaveBuffer, + })), +})); + +jest.mock('~/server/services/Files/permissions', () => ({ + filterFilesByAgentAccess: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/server/services/Files/images/convert', () => ({ + convertImage: jest.fn(), +})); + +jest.mock('~/server/utils', () => ({ + determineFileType: jest.fn().mockResolvedValue({ mime: 'text/csv' }), +})); + +jest.mock('axios', () => + jest.fn().mockResolvedValue({ + data: Buffer.from('file-content'), + }), +); + +const { createFile } = require('~/models'); +const { processCodeOutput } = require('../process'); + +const baseParams = { + req: { + user: { id: 'user123' }, + config: { + fileStrategy: 'local', + imageOutputType: 'webp', + fileConfig: {}, + }, + }, + id: 'code-file-id', + apiKey: 'test-key', + toolCallId: 'tool-1', + conversationId: 'conv-1', + messageId: 'msg-1', + session_id: 'session-1', +}; + +describe('processCodeOutput path traversal protection', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test('sanitizeFilename is called with the raw artifact name', async () => { + mockSanitizeFilename.mockReturnValueOnce('output.csv'); + await processCodeOutput({ ...baseParams, name: 'output.csv' }); + expect(mockSanitizeFilename).toHaveBeenCalledWith('output.csv'); + }); + + test('sanitized name is used in saveBuffer fileName', async () => { + mockSanitizeFilename.mockReturnValueOnce('sanitized-name.txt'); + await processCodeOutput({ ...baseParams, name: '../../../tmp/poc.txt' }); + + expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../tmp/poc.txt'); + const call = mockSaveBuffer.mock.calls[0][0]; + expect(call.fileName).toBe('mock-uuid__sanitized-name.txt'); + }); + + test('sanitized name is stored as filename in the file record', async () => { + mockSanitizeFilename.mockReturnValueOnce('safe-output.csv'); + await processCodeOutput({ ...baseParams, name: 'unsafe/../../output.csv' }); + + const fileArg = createFile.mock.calls[0][0]; + expect(fileArg.filename).toBe('safe-output.csv'); + }); + + test('sanitized name is used for image file records', async () => { + const { convertImage } = require('~/server/services/Files/images/convert'); + convertImage.mockResolvedValueOnce({ + filepath: '/images/user123/mock-uuid.webp', + bytes: 100, + }); + + mockSanitizeFilename.mockReturnValueOnce('safe-chart.png'); + await processCodeOutput({ ...baseParams, name: '../../../chart.png' }); + + expect(mockSanitizeFilename).toHaveBeenCalledWith('../../../chart.png'); + const fileArg = createFile.mock.calls[0][0]; + expect(fileArg.filename).toBe('safe-chart.png'); + }); +}); diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 3f0bfcfc87..e878b00255 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -3,7 +3,7 @@ const { v4 } = require('uuid'); const axios = require('axios'); const { logger } = require('@librechat/data-schemas'); const { getCodeBaseURL } = require('@librechat/agents'); -const { logAxiosError, getBasePath } = require('@librechat/api'); +const { logAxiosError, getBasePath, sanitizeFilename } = require('@librechat/api'); const { Tools, megabyte, @@ -146,6 +146,13 @@ const processCodeOutput = async ({ ); } + const safeName = sanitizeFilename(name); + if (safeName !== name) { + logger.warn( + `[processCodeOutput] Filename sanitized: "${name}" -> "${safeName}" | conv=${conversationId}`, + ); + } + if (isImage) { const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1; const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); @@ -156,7 +163,7 @@ const processCodeOutput = async ({ file_id, messageId, usage, - filename: name, + filename: safeName, conversationId, user: req.user.id, type: `image/${appConfig.imageOutputType}`, @@ -200,7 +207,7 @@ const processCodeOutput = async ({ ); } - const fileName = `${file_id}__${name}`; + const fileName = `${file_id}__${safeName}`; const filepath = await saveBuffer({ userId: req.user.id, buffer, @@ -213,7 +220,7 @@ const processCodeOutput = async ({ filepath, messageId, object: 'file', - filename: name, + filename: safeName, type: mimeType, conversationId, user: req.user.id, @@ -229,6 +236,11 @@ const processCodeOutput = async ({ await createFile(file, true); return Object.assign(file, { messageId, toolCallId }); } catch (error) { + if (error?.message === 'Path traversal detected in filename') { + logger.warn( + `[processCodeOutput] Path traversal blocked for file "${name}" | conv=${conversationId}`, + ); + } logAxiosError({ message: 'Error downloading/processing code environment file', error, diff --git a/api/server/services/Files/Code/process.spec.js b/api/server/services/Files/Code/process.spec.js index f01a623f90..b89a6c6307 100644 --- a/api/server/services/Files/Code/process.spec.js +++ b/api/server/services/Files/Code/process.spec.js @@ -58,6 +58,7 @@ jest.mock('@librechat/agents', () => ({ jest.mock('@librechat/api', () => ({ logAxiosError: jest.fn(), getBasePath: jest.fn(() => ''), + sanitizeFilename: jest.fn((name) => name), })); // Mock models diff --git a/api/server/services/Files/Local/__tests__/crud-traversal.spec.js b/api/server/services/Files/Local/__tests__/crud-traversal.spec.js new file mode 100644 index 0000000000..57ba221d68 --- /dev/null +++ b/api/server/services/Files/Local/__tests__/crud-traversal.spec.js @@ -0,0 +1,69 @@ +jest.mock('@librechat/api', () => ({ deleteRagFile: jest.fn() })); +jest.mock('@librechat/data-schemas', () => ({ + logger: { warn: jest.fn(), error: jest.fn() }, +})); + +const mockTmpBase = require('fs').mkdtempSync( + require('path').join(require('os').tmpdir(), 'crud-traversal-'), +); + +jest.mock('~/config/paths', () => { + const path = require('path'); + return { + publicPath: path.join(mockTmpBase, 'public'), + uploads: path.join(mockTmpBase, 'uploads'), + }; +}); + +const fs = require('fs'); +const path = require('path'); +const { saveLocalBuffer } = require('../crud'); + +describe('saveLocalBuffer path containment', () => { + beforeAll(() => { + fs.mkdirSync(path.join(mockTmpBase, 'public', 'images'), { recursive: true }); + fs.mkdirSync(path.join(mockTmpBase, 'uploads'), { recursive: true }); + }); + + afterAll(() => { + fs.rmSync(mockTmpBase, { recursive: true, force: true }); + }); + + test('rejects filenames with path traversal sequences', async () => { + await expect( + saveLocalBuffer({ + userId: 'user1', + buffer: Buffer.from('malicious'), + fileName: '../../../etc/passwd', + basePath: 'uploads', + }), + ).rejects.toThrow('Path traversal detected in filename'); + }); + + test('rejects prefix-collision traversal (startsWith bypass)', async () => { + fs.mkdirSync(path.join(mockTmpBase, 'uploads', 'user10'), { recursive: true }); + await expect( + saveLocalBuffer({ + userId: 'user1', + buffer: Buffer.from('malicious'), + fileName: '../user10/evil', + basePath: 'uploads', + }), + ).rejects.toThrow('Path traversal detected in filename'); + }); + + test('allows normal filenames', async () => { + const result = await saveLocalBuffer({ + userId: 'user1', + buffer: Buffer.from('safe content'), + fileName: 'file-id__output.csv', + basePath: 'uploads', + }); + + expect(result).toBe('/uploads/user1/file-id__output.csv'); + + const filePath = path.join(mockTmpBase, 'uploads', 'user1', 'file-id__output.csv'); + expect(fs.existsSync(filePath)).toBe(true); + fs.unlinkSync(filePath); + }); +}); diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index 1f38a01f83..c86774d472 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -78,7 +78,13 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' } fs.mkdirSync(directoryPath, { recursive: true }); } - fs.writeFileSync(path.join(directoryPath, fileName), buffer); + const resolvedDir = path.resolve(directoryPath); + const resolvedPath = path.resolve(resolvedDir, fileName); + const rel = path.relative(resolvedDir, resolvedPath); + if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { + throw new Error('Path traversal detected in filename'); + } + fs.writeFileSync(resolvedPath, buffer); const filePath = path.posix.join('/', basePath, userId, fileName); @@ -165,9 +171,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) { } /** - * Validates if a given filepath is within a specified subdirectory under a base path. This function constructs - * the expected base path using the base, subfolder, and user id from the request, and then checks if the - * provided filepath starts with this constructed base path. + * Validates that a filepath is strictly contained within a subdirectory under a base path, + * using path.relative to prevent prefix-collision bypasses. * * @param {ServerRequest} req - The request object from Express. It should contain a `user` property with an `id`. * @param {string} base - The base directory path. @@ -180,7 +185,8 @@ async function getLocalFileURL({ fileName, basePath = 'images' }) { const isValidPath = (req, base, subfolder, filepath) => { const normalizedBase = path.resolve(base, subfolder, req.user.id); const normalizedFilepath = path.resolve(filepath); - return normalizedFilepath.startsWith(normalizedBase); + const rel = path.relative(normalizedBase, normalizedFilepath); + return !rel.startsWith('..') && !path.isAbsolute(rel) && !rel.includes(`..${path.sep}`); }; /** From cbdc6f606057296aa2a92eaf7636347757432849 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 03:36:03 -0400 Subject: [PATCH 18/39] =?UTF-8?q?=F0=9F=93=A6=20chore:=20Bump=20NPM=20Audi?= =?UTF-8?q?t=20Packages=20(#12227)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 chore: Update file-type dependency to version 21.3.2 in package-lock.json and package.json - Upgraded the "file-type" package from version 18.7.0 to 21.3.2 to ensure compatibility with the latest features and security updates. - Added new dependencies related to the updated "file-type" package, enhancing functionality and performance. * 🔧 chore: Upgrade undici dependency to version 7.24.1 in package-lock.json and package.json - Updated the "undici" package from version 7.18.2 to 7.24.1 across multiple package files to ensure compatibility with the latest features and security updates. * 🔧 chore: Upgrade yauzl dependency to version 3.2.1 in package-lock.json - Updated the "yauzl" package from version 3.2.0 to 3.2.1 to incorporate the latest features and security updates. * 🔧 chore: Upgrade hono dependency to version 4.12.7 in package-lock.json - Updated the "hono" package from version 4.12.5 to 4.12.7 to incorporate the latest features and security updates. --- api/package.json | 4 +- package-lock.json | 208 ++++++++++++++++++++++---------------- packages/api/package.json | 2 +- 3 files changed, 124 insertions(+), 90 deletions(-) diff --git a/api/package.json b/api/package.json index 1618481b58..0305446818 100644 --- a/api/package.json +++ b/api/package.json @@ -67,7 +67,7 @@ "express-rate-limit": "^8.3.0", "express-session": "^1.18.2", "express-static-gzip": "^2.2.0", - "file-type": "^18.7.0", + "file-type": "^21.3.2", "firebase": "^11.0.2", "form-data": "^4.0.4", "handlebars": "^4.7.7", @@ -109,7 +109,7 @@ "sharp": "^0.33.5", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.18.2", + "undici": "^7.24.1", "winston": "^3.11.0", "winston-daily-rotate-file": "^5.0.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", diff --git a/package-lock.json b/package-lock.json index a2db2df389..502b3a8eed 100644 --- a/package-lock.json +++ b/package-lock.json @@ -82,7 +82,7 @@ "express-rate-limit": "^8.3.0", "express-session": "^1.18.2", "express-static-gzip": "^2.2.0", - "file-type": "^18.7.0", + "file-type": "^21.3.2", "firebase": "^11.0.2", "form-data": "^4.0.4", "handlebars": "^4.7.7", @@ -124,7 +124,7 @@ "sharp": "^0.33.5", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.18.2", + "undici": "^7.24.1", "winston": "^3.11.0", "winston-daily-rotate-file": "^5.0.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", @@ -270,6 +270,24 @@ "node": ">= 0.8.0" } }, + "api/node_modules/file-type": { + "version": "21.3.2", + "resolved": "https://registry.npmjs.org/file-type/-/file-type-21.3.2.tgz", + "integrity": "sha512-DLkUvGwep3poOV2wpzbHCOnSKGk1LzyXTv+aHFgN2VFl96wnp8YA9YjO2qPzg5PuL8q/SW9Pdi6WTkYOIh995w==", + "license": "MIT", + "dependencies": { + "@tokenizer/inflate": "^0.4.1", + "strtok3": "^10.3.4", + "token-types": "^6.1.1", + "uint8array-extras": "^1.4.0" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sindresorhus/file-type?sponsor=1" + } + }, "api/node_modules/jose": { "version": "6.1.3", "resolved": "https://registry.npmjs.org/jose/-/jose-6.1.3.tgz", @@ -348,6 +366,40 @@ "@img/sharp-win32-x64": "0.33.5" } }, + "api/node_modules/strtok3": { + "version": "10.3.4", + "resolved": "https://registry.npmjs.org/strtok3/-/strtok3-10.3.4.tgz", + "integrity": "sha512-KIy5nylvC5le1OdaaoCJ07L+8iQzJHGH6pWDuzS+d07Cu7n1MZ2x26P8ZKIWfbK02+XIL8Mp4RkWeqdUCrDMfg==", + "license": "MIT", + "dependencies": { + "@tokenizer/token": "^0.3.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, + "api/node_modules/token-types": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/token-types/-/token-types-6.1.2.tgz", + "integrity": "sha512-dRXchy+C0IgK8WPC6xvCHFRIWYUbqqdEIKPaKo/AcTUNzwLTK6AH7RjdLWsEZcAN/TBdtfUw3PYEgPr5VPr6ww==", + "license": "MIT", + "dependencies": { + "@borewit/text-codec": "^0.2.1", + "@tokenizer/token": "^0.3.0", + "ieee754": "^1.2.1" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, "api/node_modules/winston-daily-rotate-file": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/winston-daily-rotate-file/-/winston-daily-rotate-file-5.0.0.tgz", @@ -7286,6 +7338,16 @@ "dev": true, "license": "MIT" }, + "node_modules/@borewit/text-codec": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/@borewit/text-codec/-/text-codec-0.2.2.tgz", + "integrity": "sha512-DDaRehssg1aNrH4+2hnj1B7vnUGEjU6OIlyRdkMd0aUdIUvKXrJfXsy8LVtXAy7DRvYVluWbMspsRhz2lcW0mQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, "node_modules/@braintree/sanitize-url": { "version": "7.1.1", "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-7.1.1.tgz", @@ -20799,6 +20861,41 @@ "@testing-library/dom": ">=7.21.4" } }, + "node_modules/@tokenizer/inflate": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/@tokenizer/inflate/-/inflate-0.4.1.tgz", + "integrity": "sha512-2mAv+8pkG6GIZiF1kNg1jAjh27IDxEPKwdGul3snfztFerfPGI1LjDezZp3i7BElXompqEtPmoPx6c2wgtWsOA==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.3", + "token-types": "^6.1.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, + "node_modules/@tokenizer/inflate/node_modules/token-types": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/token-types/-/token-types-6.1.2.tgz", + "integrity": "sha512-dRXchy+C0IgK8WPC6xvCHFRIWYUbqqdEIKPaKo/AcTUNzwLTK6AH7RjdLWsEZcAN/TBdtfUw3PYEgPr5VPr6ww==", + "license": "MIT", + "dependencies": { + "@borewit/text-codec": "^0.2.1", + "@tokenizer/token": "^0.3.0", + "ieee754": "^1.2.1" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Borewit" + } + }, "node_modules/@tokenizer/token": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/@tokenizer/token/-/token-0.3.0.tgz", @@ -27513,22 +27610,6 @@ "moment": "^2.29.1" } }, - "node_modules/file-type": { - "version": "18.7.0", - "resolved": "https://registry.npmjs.org/file-type/-/file-type-18.7.0.tgz", - "integrity": "sha512-ihHtXRzXEziMrQ56VSgU7wkxh55iNchFkosu7Y9/S+tXHdKyrGjVK0ujbqNnsxzea+78MaLhN6PGmfYSAv1ACw==", - "dependencies": { - "readable-web-to-node-stream": "^3.0.2", - "strtok3": "^7.0.0", - "token-types": "^5.0.1" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sindresorhus/file-type?sponsor=1" - } - }, "node_modules/filelist": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.6.tgz", @@ -28817,9 +28898,9 @@ "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" }, "node_modules/hono": { - "version": "4.12.5", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz", - "integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==", + "version": "4.12.7", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz", + "integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==", "license": "MIT", "engines": { "node": ">=16.9.0" @@ -35702,18 +35783,6 @@ "node-readable-to-web-readable-stream": "^0.4.2" } }, - "node_modules/peek-readable": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/peek-readable/-/peek-readable-5.0.0.tgz", - "integrity": "sha512-YtCKvLUOvwtMGmrniQPdO7MwPjgkFBtFIrmfSbYmYuq3tKDV/mcfAhBth1+C3ru7uXIZasc/pHnb+YDYNkkj4A==", - "engines": { - "node": ">=14.16" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/pend": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/pend/-/pend-1.2.0.tgz", @@ -38519,21 +38588,6 @@ "node": ">= 6" } }, - "node_modules/readable-web-to-node-stream": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/readable-web-to-node-stream/-/readable-web-to-node-stream-3.0.2.tgz", - "integrity": "sha512-ePeK6cc1EcKLEhJFt/AebMCLL+GgSKhuygrZ/GLaKZYEecIgIECf4UaUuaByiGtzckwR4ain9VzUh95T1exYGw==", - "dependencies": { - "readable-stream": "^3.6.0" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/readdirp": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", @@ -40920,22 +40974,6 @@ ], "license": "MIT" }, - "node_modules/strtok3": { - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/strtok3/-/strtok3-7.0.0.tgz", - "integrity": "sha512-pQ+V+nYQdC5H3Q7qBZAz/MO6lwGhoC2gOAjuouGf/VO0m7vQRh8QNMl2Uf6SwAtzZ9bOw3UIeBukEGNJl5dtXQ==", - "dependencies": { - "@tokenizer/token": "^0.3.0", - "peek-readable": "^5.0.0" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/style-inject": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/style-inject/-/style-inject-0.3.0.tgz", @@ -41640,22 +41678,6 @@ "node": ">=0.6" } }, - "node_modules/token-types": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/token-types/-/token-types-5.0.1.tgz", - "integrity": "sha512-Y2fmSnZjQdDb9W4w4r1tswlMHylzWIeOKpx0aZH9BgGtACHhrk3OkT52AzwcuqTRBZtvvnTjDBh8eynMulu8Vg==", - "dependencies": { - "@tokenizer/token": "^0.3.0", - "ieee754": "^1.2.1" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/Borewit" - } - }, "node_modules/touch": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/touch/-/touch-3.1.0.tgz", @@ -42206,6 +42228,18 @@ "resolved": "https://registry.npmjs.org/uid2/-/uid2-0.0.4.tgz", "integrity": "sha512-IevTus0SbGwQzYh3+fRsAMTVVPOoIVufzacXcHPmdlle1jUpq7BRL+mw3dgeLanvGZdwwbWhRV6XrcFNdBmjWA==" }, + "node_modules/uint8array-extras": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/uint8array-extras/-/uint8array-extras-1.5.0.tgz", + "integrity": "sha512-rvKSBiC5zqCCiDZ9kAOszZcDvdAHwwIKJG33Ykj43OKcWsnmcBRL09YTU4nOeHZ8Y2a7l1MgTd08SBe9A8Qj6A==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/unbox-primitive": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.1.0.tgz", @@ -42238,9 +42272,9 @@ "license": "MIT" }, "node_modules/undici": { - "version": "7.20.0", - "resolved": "https://registry.npmjs.org/undici/-/undici-7.20.0.tgz", - "integrity": "sha512-MJZrkjyd7DeC+uPZh+5/YaMDxFiiEEaDgbUSVMXayofAkDWF1088CDo+2RPg7B1BuS1qf1vgNE7xqwPxE0DuSQ==", + "version": "7.24.1", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.1.tgz", + "integrity": "sha512-5xoBibbmnjlcR3jdqtY2Lnx7WbrD/tHlT01TmvqZUFVc9Q1w4+j5hbnapTqbcXITMH1ovjq/W7BkqBilHiVAaA==", "license": "MIT", "engines": { "node": ">=20.18.1" @@ -44097,9 +44131,9 @@ } }, "node_modules/yauzl": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-3.2.0.tgz", - "integrity": "sha512-Ow9nuGZE+qp1u4JIPvg+uCiUr7xGQWdff7JQSk5VGYTAZMDe2q8lxJ10ygv10qmSj031Ty/6FNJpLO4o1Sgc+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-3.2.1.tgz", + "integrity": "sha512-k1isifdbpNSFEHFJ1ZY4YDewv0IH9FR61lDetaRMD3j2ae3bIXGV+7c+LHCqtQGofSd8PIyV4X6+dHMAnSr60A==", "dev": true, "license": "MIT", "dependencies": { @@ -44232,7 +44266,7 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "undici": "^7.18.2", + "undici": "^7.24.1", "zod": "^3.22.4" } }, diff --git a/packages/api/package.json b/packages/api/package.json index 966447c51b..77258fc0b3 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -117,7 +117,7 @@ "node-fetch": "2.7.0", "pdfjs-dist": "^5.4.624", "rate-limit-redis": "^4.2.0", - "undici": "^7.18.2", + "undici": "^7.24.1", "zod": "^3.22.4" } } From 7bc793b18db312feab9ad15e765f2f2da34d0667 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 10:54:26 -0400 Subject: [PATCH 19/39] =?UTF-8?q?=F0=9F=8C=8A=20fix:=20Prevent=20Buffered?= =?UTF-8?q?=20Event=20Duplication=20on=20SSE=20Resume=20Connections=20(#12?= =?UTF-8?q?225)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: skipBufferReplay for job resume connections - Introduced a new option `skipBufferReplay` in the `subscribe` method of `GenerationJobManagerClass` to prevent duplication of events when resuming a connection. - Updated the logic to conditionally skip replaying buffered events if a sync event has already been sent, enhancing the efficiency of event handling during reconnections. - Added integration tests to verify the correct behavior of the new option, ensuring that no buffered events are replayed when `skipBufferReplay` is true, while still allowing for normal replay behavior when false. * refactor: Update GenerationJobManager to handle sync events more efficiently - Modified the `subscribe` method to utilize a new `skipBufferReplay` option, allowing for the prevention of duplicate events during resume connections. - Enhanced the logic in the `chat/stream` route to conditionally skip replaying buffered events if a sync event has already been sent, improving event handling efficiency. - Updated integration tests to verify the correct behavior of the new option, ensuring that no buffered events are replayed when `skipBufferReplay` is true, while maintaining normal replay behavior when false. * test: Enhance GenerationJobManager integration tests for Redis mode - Updated integration tests to conditionally run based on the USE_REDIS environment variable, allowing for better control over Redis-related tests. - Refactored test descriptions to utilize a dynamic `describeRedis` function, improving clarity and organization of tests related to Redis functionality. - Removed redundant checks for Redis availability within individual tests, streamlining the test logic and enhancing readability. * fix: sync handler state for new messages on resume The sync event's else branch (new response message) was missing resetContentHandler() and syncStepMessage() calls, leaving stale handler state that caused subsequent deltas to build on partial content instead of the synced aggregatedContent. * feat: atomic subscribeWithResume to close resume event gap Replaces separate getResumeState() + subscribe() calls with a single subscribeWithResume() that atomically drains earlyEventBuffer between the resume snapshot and the subscribe. In in-memory mode, drained events are returned as pendingEvents for the client to replay after sync. In Redis mode, pendingEvents is empty since chunks are already persisted. The route handler now uses the atomic method for resume connections and extracted shared SSE write helpers to reduce duplication. The client replays any pendingEvents through the existing step/content handlers after applying aggregatedContent from the sync payload. * fix: only capture gap events in subscribeWithResume, not pre-snapshot buffer The previous implementation drained the entire earlyEventBuffer into pendingEvents, but pre-snapshot events are already reflected in aggregatedContent. Replaying them re-introduced the duplication bug through a different vector. Now records buffer length before getResumeState() and slices from that index, so only events arriving during the async gap are returned as pendingEvents. Also: - Handle pendingEvents when resumeState is null (replay directly) - Hoist duplicate test helpers to shared scope - Remove redundant writableEnded guard in onDone --- api/server/routes/agents/index.js | 82 +- client/src/hooks/SSE/useResumableSSE.ts | 44 +- .../api/src/stream/GenerationJobManager.ts | 67 +- ...ationJobManager.stream_integration.spec.ts | 745 +++++++++++++----- packages/api/src/types/stream.ts | 21 + 5 files changed, 700 insertions(+), 259 deletions(-) diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index f8d39cb4d8..a99fdca592 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -76,52 +76,62 @@ router.get('/chat/stream/:streamId', async (req, res) => { logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`); - // Send sync event with resume state for ALL reconnecting clients - // This supports multi-tab scenarios where each tab needs run step data - if (isResume) { - const resumeState = await GenerationJobManager.getResumeState(streamId); - if (resumeState && !res.writableEnded) { - // Send sync event with run steps AND aggregatedContent - // Client will use aggregatedContent to initialize message state - res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`); + const writeEvent = (event) => { + if (!res.writableEnded) { + res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); if (typeof res.flush === 'function') { res.flush(); } - logger.debug( - `[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`, - ); } - } + }; - const result = await GenerationJobManager.subscribe( - streamId, - (event) => { - if (!res.writableEnded) { - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); + const onDone = (event) => { + writeEvent(event); + res.end(); + }; + + const onError = (error) => { + if (!res.writableEnded) { + res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`); + if (typeof res.flush === 'function') { + res.flush(); + } + res.end(); + } + }; + + let result; + + if (isResume) { + const { subscription, resumeState, pendingEvents } = + await GenerationJobManager.subscribeWithResume(streamId, writeEvent, onDone, onError); + + if (!res.writableEnded) { + if (resumeState) { + res.write( + `event: message\ndata: ${JSON.stringify({ sync: true, resumeState, pendingEvents })}\n\n`, + ); if (typeof res.flush === 'function') { res.flush(); } - } - }, - (event) => { - if (!res.writableEnded) { - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); - if (typeof res.flush === 'function') { - res.flush(); + GenerationJobManager.markSyncSent(streamId); + logger.debug( + `[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps, ${pendingEvents.length} pending events`, + ); + } else if (pendingEvents.length > 0) { + for (const event of pendingEvents) { + writeEvent(event); } - res.end(); + logger.warn( + `[AgentStream] Resume state null for ${streamId}, replayed ${pendingEvents.length} gap events directly`, + ); } - }, - (error) => { - if (!res.writableEnded) { - res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`); - if (typeof res.flush === 'function') { - res.flush(); - } - res.end(); - } - }, - ); + } + + result = subscription; + } else { + result = await GenerationJobManager.subscribe(streamId, writeEvent, onDone, onError); + } if (!result) { return res.status(404).json({ error: 'Failed to subscribe to stream' }); diff --git a/client/src/hooks/SSE/useResumableSSE.ts b/client/src/hooks/SSE/useResumableSSE.ts index 831bf042ad..4d4cb4841a 100644 --- a/client/src/hooks/SSE/useResumableSSE.ts +++ b/client/src/hooks/SSE/useResumableSSE.ts @@ -226,12 +226,12 @@ export default function useResumableSSE( if (data.sync != null) { console.log('[ResumableSSE] SYNC received', { runSteps: data.resumeState?.runSteps?.length ?? 0, + pendingEvents: data.pendingEvents?.length ?? 0, }); const runId = v4(); setActiveRunId(runId); - // Replay run steps if (data.resumeState?.runSteps) { for (const runStep of data.resumeState.runSteps) { stepHandler({ event: 'on_run_step', data: runStep }, { @@ -241,19 +241,15 @@ export default function useResumableSSE( } } - // Set message content from aggregatedContent if (data.resumeState?.aggregatedContent && userMessage?.messageId) { const messages = getMessages() ?? []; const userMsgId = userMessage.messageId; const serverResponseId = data.resumeState.responseMessageId; - // Find the EXACT response message - prioritize responseMessageId from server - // This is critical when there are multiple responses to the same user message let responseIdx = -1; if (serverResponseId) { responseIdx = messages.findIndex((m) => m.messageId === serverResponseId); } - // Fallback: find by parentMessageId pattern (for new messages) if (responseIdx < 0) { responseIdx = messages.findIndex( (m) => @@ -272,7 +268,6 @@ export default function useResumableSSE( }); if (responseIdx >= 0) { - // Update existing response message with aggregatedContent const updated = [...messages]; const oldContent = updated[responseIdx]?.content; updated[responseIdx] = { @@ -285,25 +280,34 @@ export default function useResumableSSE( newContentLength: data.resumeState.aggregatedContent?.length, }); setMessages(updated); - // Sync both content handler and step handler with the updated message - // so subsequent deltas build on synced content, not stale content resetContentHandler(); syncStepMessage(updated[responseIdx]); console.log('[ResumableSSE] SYNC complete, handlers synced'); } else { - // Add new response message const responseId = serverResponseId ?? `${userMsgId}_`; - setMessages([ - ...messages, - { - messageId: responseId, - parentMessageId: userMsgId, - conversationId: currentSubmission.conversation?.conversationId ?? '', - text: '', - content: data.resumeState.aggregatedContent, - isCreatedByUser: false, - } as TMessage, - ]); + const newMessage = { + messageId: responseId, + parentMessageId: userMsgId, + conversationId: currentSubmission.conversation?.conversationId ?? '', + text: '', + content: data.resumeState.aggregatedContent, + isCreatedByUser: false, + } as TMessage; + setMessages([...messages, newMessage]); + resetContentHandler(); + syncStepMessage(newMessage); + } + } + + if (data.pendingEvents?.length > 0) { + console.log(`[ResumableSSE] Replaying ${data.pendingEvents.length} pending events`); + const submission = { ...currentSubmission, userMessage } as EventSubmission; + for (const pendingEvent of data.pendingEvents) { + if (pendingEvent.event != null) { + stepHandler(pendingEvent, submission); + } else if (pendingEvent.type != null) { + contentHandler({ data: pendingEvent, submission }); + } } } diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index cd5ff04eb0..1b612dcb8f 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -707,6 +707,10 @@ class GenerationJobManagerClass { * @param onChunk - Handler for chunk events (streamed tokens, run steps, etc.) * @param onDone - Handler for completion event (includes final message) * @param onError - Handler for error events + * @param options - Subscription configuration + * @param options.skipBufferReplay - When true, skips replaying the earlyEventBuffer. + * Use this when a sync event was already sent (resume), since the sync's + * aggregatedContent already includes all buffered events. * @returns Subscription object with unsubscribe function, or null if job not found */ async subscribe( @@ -714,6 +718,7 @@ class GenerationJobManagerClass { onChunk: t.ChunkHandler, onDone?: t.DoneHandler, onError?: t.ErrorHandler, + options?: t.SubscribeOptions, ): Promise<{ unsubscribe: t.UnsubscribeFn } | null> { // Use lazy initialization to support cross-replica subscriptions const runtime = await this.getOrCreateRuntimeState(streamId); @@ -763,11 +768,17 @@ class GenerationJobManagerClass { runtime.hasSubscriber = true; if (runtime.earlyEventBuffer.length > 0) { - logger.debug( - `[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`, - ); - for (const bufferedEvent of runtime.earlyEventBuffer) { - onChunk(bufferedEvent); + if (options?.skipBufferReplay) { + logger.debug( + `[GenerationJobManager] Skipping ${runtime.earlyEventBuffer.length} buffered events for ${streamId} (skipBufferReplay)`, + ); + } else { + logger.debug( + `[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`, + ); + for (const bufferedEvent of runtime.earlyEventBuffer) { + onChunk(bufferedEvent); + } } runtime.earlyEventBuffer = []; } @@ -785,6 +796,52 @@ class GenerationJobManagerClass { return subscription; } + /** + * Atomic resume + subscribe: snapshots resume state and drains the early event buffer + * in one synchronous step, then subscribes with skipBufferReplay. + * + * Closes the timing gap between separate `getResumeState()` and `subscribe()` calls + * where events could arrive in earlyEventBuffer after the snapshot but before subscribe + * clears the buffer. + * + * In-memory mode: drained buffer events are returned as `pendingEvents` since + * they exist nowhere else. The caller must deliver them after the sync payload. + * Redis mode: `pendingEvents` is empty — chunks are persisted via appendChunk + * and will appear in aggregatedContent on the next resume. + */ + async subscribeWithResume( + streamId: string, + onChunk: t.ChunkHandler, + onDone?: t.DoneHandler, + onError?: t.ErrorHandler, + ): Promise { + const bufferLengthAtSnapshot = !this._isRedis + ? (this.runtimeState.get(streamId)?.earlyEventBuffer.length ?? 0) + : 0; + + const resumeState = await this.getResumeState(streamId); + + let pendingEvents: t.ServerSentEvent[] = []; + if (!this._isRedis) { + const runtime = this.runtimeState.get(streamId); + if (runtime) { + pendingEvents = runtime.earlyEventBuffer.slice(bufferLengthAtSnapshot); + runtime.earlyEventBuffer = []; + if (pendingEvents.length > 0) { + logger.debug( + `[GenerationJobManager] Captured ${pendingEvents.length} gap events for ${streamId}`, + ); + } + } + } + + const subscription = await this.subscribe(streamId, onChunk, onDone, onError, { + skipBufferReplay: true, + }); + + return { subscription, resumeState, pendingEvents }; + } + /** * Emit a chunk event to all subscribers. * Uses runtime state check for performance (avoids async job store lookup per token). diff --git a/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts b/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts index 59fe32e4e5..2f23510018 100644 --- a/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts +++ b/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts @@ -1,3 +1,4 @@ +/* eslint jest/no-standalone-expect: ["error", { "additionalTestBlockFunctions": ["testRedis"] }] */ import type { Redis, Cluster } from 'ioredis'; import type { ServerSentEvent } from '~/types/events'; import { InMemoryEventTransport } from '~/stream/implementations/InMemoryEventTransport'; @@ -27,6 +28,9 @@ describe('GenerationJobManager Integration Tests', () => { let dynamicKeyvClient: unknown = null; let dynamicKeyvReady: Promise | null = null; const testPrefix = 'JobManager-Integration-Test'; + const redisConfigured = process.env.USE_REDIS === 'true'; + const describeRedis = redisConfigured ? describe : describe.skip; + const testRedis = redisConfigured ? test : test.skip; beforeAll(async () => { originalEnv = { ...process.env }; @@ -82,6 +86,68 @@ describe('GenerationJobManager Integration Tests', () => { process.env = originalEnv; }); + function createInMemoryManager(): GenerationJobManagerClass { + const manager = new GenerationJobManagerClass(); + manager.configure({ + jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }), + eventTransport: new InMemoryEventTransport(), + isRedis: false, + }); + manager.initialize(); + return manager; + } + + function createRedisManager(): GenerationJobManagerClass { + const manager = new GenerationJobManagerClass(); + manager.configure( + createStreamServices({ + useRedis: true, + redisClient: ioredisClient!, + }), + ); + manager.initialize(); + return manager; + } + + async function setupDisconnectedStream( + manager: GenerationJobManagerClass, + streamId: string, + delay: number, + ): Promise { + const firstEvents: ServerSentEvent[] = []; + const sub = await manager.subscribe(streamId, (event) => firstEvents.push(event)); + + await new Promise((resolve) => setTimeout(resolve, delay)); + + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', runId: 'run-1', index: 0, stepDetails: { type: 'message_creation' } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: 'Hello' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, delay)); + expect(firstEvents.length).toBe(2); + + sub?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, delay)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: ' world' } } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: '!' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, delay)); + + return firstEvents; + } + describe('In-Memory Mode', () => { test('should create and manage jobs', async () => { // Configure with in-memory @@ -171,13 +237,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Redis Mode', () => { + describeRedis('Redis Mode', () => { test('should create and manage jobs via Redis', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // Create Redis services const services = createStreamServices({ useRedis: true, @@ -209,11 +270,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should persist chunks for cross-instance resume', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -264,11 +320,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should handle abort and return content', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -374,7 +425,7 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Cross-Replica Support (Redis)', () => { + describeRedis('Cross-Replica Support (Redis)', () => { /** * Problem: In k8s with Redis and multiple replicas, when a user sends a message: * 1. POST /api/agents/chat hits Replica A, creates job @@ -387,15 +438,10 @@ describe('GenerationJobManager Integration Tests', () => { * when the job exists in Redis but not in local memory. */ test('should NOT return 404 when stream endpoint hits different replica than job creator', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // === REPLICA A: Creates the job === // Simulate Replica A creating the job directly in Redis // (In real scenario, this happens via GenerationJobManager.createJob on Replica A) - const replicaAJobStore = new RedisJobStore(ioredisClient); + const replicaAJobStore = new RedisJobStore(ioredisClient!); await replicaAJobStore.initialize(); const streamId = `cross-replica-404-test-${Date.now()}`; @@ -452,13 +498,8 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should lazily create runtime state for jobs created on other replicas', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // Instance 1: Create the job directly in Redis (simulating another replica) - const jobStore = new RedisJobStore(ioredisClient); + const jobStore = new RedisJobStore(ioredisClient!); await jobStore.initialize(); const streamId = `cross-replica-${Date.now()}`; @@ -500,11 +541,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should persist syncSent to Redis for cross-replica consistency', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -539,11 +575,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should persist finalEvent to Redis for cross-replica access', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -581,11 +612,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should emit cross-replica abort signal via Redis pub/sub', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, @@ -620,16 +646,11 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should handle abort for lazily-initialized cross-replica jobs', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // This test validates that jobs created on Replica A and lazily-initialized // on Replica B can still receive and handle abort signals. // === Replica A: Create job directly in Redis === - const replicaAJobStore = new RedisJobStore(ioredisClient); + const replicaAJobStore = new RedisJobStore(ioredisClient!); await replicaAJobStore.initialize(); const streamId = `lazy-abort-${Date.now()}`; @@ -675,11 +696,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should abort generation when abort signal received from another replica', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // This test simulates: // 1. Replica A creates a job and starts generation // 2. Replica B receives abort request and emits abort signal @@ -729,13 +745,8 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should handle wasSyncSent for cross-replica scenarios', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - // Create job directly in Redis with syncSent: true - const jobStore = new RedisJobStore(ioredisClient); + const jobStore = new RedisJobStore(ioredisClient!); await jobStore.initialize(); const streamId = `cross-sync-${Date.now()}`; @@ -762,7 +773,7 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Sequential Event Ordering (Redis)', () => { + describeRedis('Sequential Event Ordering (Redis)', () => { /** * These tests verify that events are delivered in strict sequential order * when using Redis mode. This is critical because: @@ -773,11 +784,6 @@ describe('GenerationJobManager Integration Tests', () => { * The fix: emitChunk now awaits Redis publish to ensure ordered delivery. */ test('should maintain strict order for rapid sequential emits', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -823,11 +829,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should maintain order for tool call argument deltas', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -882,11 +883,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should maintain order: on_run_step before on_run_step_delta', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -945,11 +941,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should not block other streams when awaiting emitChunk', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - jest.resetModules(); const services = createStreamServices({ @@ -1069,12 +1060,7 @@ describe('GenerationJobManager Integration Tests', () => { await manager.destroy(); }); - test('should buffer and replay events emitted before subscribe (Redis)', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - + testRedis('should buffer and replay events emitted before subscribe (Redis)', async () => { const manager = new GenerationJobManagerClass(); const services = createStreamServices({ useRedis: true, @@ -1118,67 +1104,60 @@ describe('GenerationJobManager Integration Tests', () => { await manager.destroy(); }); - test('should not lose events when emitting before and after subscribe (Redis)', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - - const manager = new GenerationJobManagerClass(); - const services = createStreamServices({ - useRedis: true, - redisClient: ioredisClient, - }); - - manager.configure(services); - manager.initialize(); - - const streamId = `no-loss-${Date.now()}`; - await manager.createJob(streamId, 'user-1'); - - await manager.emitChunk(streamId, { - created: true, - message: { text: 'hello' }, - streamId, - } as unknown as ServerSentEvent); - await manager.emitChunk(streamId, { - event: 'on_run_step', - data: { id: 'step-1', type: 'message_creation', index: 0 }, - }); - - const receivedEvents: unknown[] = []; - const subscription = await manager.subscribe(streamId, (event: unknown) => - receivedEvents.push(event), - ); - - await new Promise((resolve) => setTimeout(resolve, 100)); - - for (let i = 0; i < 10; i++) { - await manager.emitChunk(streamId, { - event: 'on_message_delta', - data: { delta: { content: { type: 'text', text: `word${i} ` } }, index: i }, + testRedis( + 'should not lose events when emitting before and after subscribe (Redis)', + async () => { + const manager = new GenerationJobManagerClass(); + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, }); - } - await new Promise((resolve) => setTimeout(resolve, 300)); + manager.configure(services); + manager.initialize(); - expect(receivedEvents.length).toBe(12); - expect((receivedEvents[0] as Record).created).toBe(true); - expect((receivedEvents[1] as Record).event).toBe('on_run_step'); - for (let i = 0; i < 10; i++) { - expect((receivedEvents[i + 2] as Record).event).toBe('on_message_delta'); - } + const streamId = `no-loss-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); - subscription?.unsubscribe(); - await manager.destroy(); - }); + await manager.emitChunk(streamId, { + created: true, + message: { text: 'hello' }, + streamId, + } as unknown as ServerSentEvent); + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', type: 'message_creation', index: 0 }, + }); - test('RedisEventTransport.subscribe() should return a ready promise', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } + const receivedEvents: unknown[] = []; + const subscription = await manager.subscribe(streamId, (event: unknown) => + receivedEvents.push(event), + ); + await new Promise((resolve) => setTimeout(resolve, 100)); + + for (let i = 0; i < 10; i++) { + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: `word${i} ` } }, index: i }, + }); + } + + await new Promise((resolve) => setTimeout(resolve, 300)); + + expect(receivedEvents.length).toBe(12); + expect((receivedEvents[0] as Record).created).toBe(true); + expect((receivedEvents[1] as Record).event).toBe('on_run_step'); + for (let i = 0; i < 10; i++) { + expect((receivedEvents[i + 2] as Record).event).toBe('on_message_delta'); + } + + subscription?.unsubscribe(); + await manager.destroy(); + }, + ); + + testRedis('RedisEventTransport.subscribe() should return a ready promise', async () => { const subscriber = (ioredisClient as unknown as { duplicate: () => unknown }).duplicate(); const transport = new RedisEventTransport(ioredisClient as never, subscriber as never); @@ -1211,6 +1190,421 @@ describe('GenerationJobManager Integration Tests', () => { }); }); + describe('Resume: skipBufferReplay prevents duplication', () => { + /** + * Verifies the fix for duplicated content when navigating away from an + * in-progress conversation and back. Events accumulate in earlyEventBuffer + * while the subscriber is absent. On resume, the sync event delivers all + * accumulated content via aggregatedContent, so buffer replay must be + * skipped to prevent duplication. + */ + + test('should NOT replay buffer when skipBufferReplay is true (resume scenario)', async () => { + const manager = createInMemoryManager(); + const streamId = `skip-buf-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + await setupDisconnectedStream(manager, streamId, 10); + + const resumeState = await manager.getResumeState(streamId); + expect(resumeState).not.toBeNull(); + + const resumeEvents: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => resumeEvents.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(resumeEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: ' Live!' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(resumeEvents.length).toBe(1); + expect(resumeEvents[0].event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }); + + test('should replay buffer by default when no options are passed', async () => { + const manager = createInMemoryManager(); + const streamId = `replay-buf-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1Events: ServerSentEvent[] = []; + const sub1 = await manager.subscribe(streamId, (event) => sub1Events.push(event)); + await new Promise((resolve) => setTimeout(resolve, 10)); + + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', runId: 'run-1', index: 0, stepDetails: { type: 'message_creation' } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: 'buffered' } } }, + }); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe(streamId, (event) => sub2Events.push(event)); + await new Promise((resolve) => setTimeout(resolve, 20)); + + expect(sub2Events.length).toBe(1); + expect(sub2Events[0].event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }); + + test('should clear earlyEventBuffer even when skipping replay (no memory leak)', async () => { + const manager = createInMemoryManager(); + const streamId = `buf-clear-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buf1' } } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buf2' } } }, + }); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => sub2Events.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(sub2Events.length).toBe(0); + + sub2?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'new-event' } } }, + }); + + const sub3Events: ServerSentEvent[] = []; + const sub3 = await manager.subscribe(streamId, (event) => sub3Events.push(event)); + await new Promise((resolve) => setTimeout(resolve, 20)); + + expect(sub3Events.length).toBe(1); + const event = sub3Events[0] as { + event: string; + data: { delta: { content: { text: string } } }; + }; + expect(event.data.delta.content.text).toBe('new-event'); + + sub3?.unsubscribe(); + await manager.destroy(); + }); + + test('should handle multiple disconnect/reconnect cycles with skipBufferReplay', async () => { + const manager = createInMemoryManager(); + const streamId = `multi-reconnect-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'initial' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-1' } } }, + }); + + const resumeState1 = await manager.getResumeState(streamId); + expect(resumeState1).not.toBeNull(); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => sub2Events.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub2Events.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-1' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub2Events.length).toBe(1); + + sub2?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-2' } } }, + }); + + const resumeState2 = await manager.getResumeState(streamId); + expect(resumeState2).not.toBeNull(); + + const sub3Events: ServerSentEvent[] = []; + const sub3 = await manager.subscribe( + streamId, + (event) => sub3Events.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub3Events.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-2' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(sub3Events.length).toBe(1); + + sub3?.unsubscribe(); + await manager.destroy(); + }); + + testRedis('should NOT replay buffer when skipBufferReplay is true (Redis)', async () => { + const manager = createRedisManager(); + const streamId = `skip-buf-redis-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + await setupDisconnectedStream(manager, streamId, 100); + + const resumeState = await manager.getResumeState(streamId); + expect(resumeState).not.toBeNull(); + expect(resumeState!.aggregatedContent?.length).toBeGreaterThan(0); + + const resumeEvents: ServerSentEvent[] = []; + const sub2 = await manager.subscribe( + streamId, + (event) => resumeEvents.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(resumeEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: ' Live!' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(resumeEvents.length).toBe(1); + expect(resumeEvents[0].event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }); + + testRedis( + 'should replay buffer without skipBufferReplay after disconnect (Redis)', + async () => { + const manager = createRedisManager(); + const streamId = `replay-buf-redis-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 100)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 100)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-redis' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const sub2Events: ServerSentEvent[] = []; + const sub2 = await manager.subscribe(streamId, (event) => sub2Events.push(event)); + + await new Promise((resolve) => setTimeout(resolve, 200)); + + expect(sub2Events.length).toBe(1); + expect(sub2Events[0].event).toBe('on_message_delta'); + + sub2?.unsubscribe(); + await manager.destroy(); + }, + ); + }); + + describe('Atomic subscribeWithResume', () => { + test('should return empty pendingEvents for pre-snapshot buffer events (in-memory)', async () => { + const manager = createInMemoryManager(); + const streamId = `atomic-drain-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_run_step', + data: { id: 'step-1', runId: 'run-1', index: 0, stepDetails: { type: 'message_creation' } }, + }); + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { id: 'step-1', delta: { content: { type: 'text', text: 'buffered' } } }, + }); + + const liveEvents: ServerSentEvent[] = []; + const { subscription, resumeState, pendingEvents } = await manager.subscribeWithResume( + streamId, + (event) => liveEvents.push(event), + ); + + expect(resumeState).not.toBeNull(); + expect(pendingEvents.length).toBe(0); + expect(liveEvents.length).toBe(0); + + subscription?.unsubscribe(); + await manager.destroy(); + }); + + test('should return empty pendingEvents when buffer is empty', async () => { + const manager = createInMemoryManager(); + const streamId = `atomic-empty-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'delivered' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + const { pendingEvents } = await manager.subscribeWithResume(streamId, () => {}); + + expect(pendingEvents.length).toBe(0); + + await manager.destroy(); + }); + + test('should deliver live events after subscribeWithResume', async () => { + const manager = createInMemoryManager(); + const streamId = `atomic-live-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 10)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 20)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-pre-snapshot' } } }, + }); + + const liveEvents: ServerSentEvent[] = []; + const { subscription, pendingEvents } = await manager.subscribeWithResume(streamId, (event) => + liveEvents.push(event), + ); + + expect(pendingEvents.length).toBe(0); + expect(liveEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-after' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(liveEvents.length).toBe(1); + const liveEvent = liveEvents[0] as { + event: string; + data: { delta: { content: { text: string } } }; + }; + expect(liveEvent.data.delta.content.text).toBe('live-after'); + + subscription?.unsubscribe(); + await manager.destroy(); + }); + + testRedis( + 'should return empty pendingEvents in Redis mode (chunks already persisted)', + async () => { + const manager = createRedisManager(); + const streamId = `atomic-redis-${Date.now()}`; + await manager.createJob(streamId, 'user-1'); + + const sub1 = await manager.subscribe(streamId, () => {}); + await new Promise((resolve) => setTimeout(resolve, 100)); + sub1?.unsubscribe(); + await new Promise((resolve) => setTimeout(resolve, 100)); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'buffered-redis' } } }, + }); + await new Promise((resolve) => setTimeout(resolve, 100)); + + const liveEvents: ServerSentEvent[] = []; + const { subscription, resumeState, pendingEvents } = await manager.subscribeWithResume( + streamId, + (event) => liveEvents.push(event), + ); + + expect(resumeState).not.toBeNull(); + expect(pendingEvents.length).toBe(0); + + await manager.emitChunk(streamId, { + event: 'on_message_delta', + data: { delta: { content: { type: 'text', text: 'live-redis' } } }, + }); + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(liveEvents.length).toBe(1); + + subscription?.unsubscribe(); + await manager.destroy(); + }, + ); + }); + describe('Error Preservation for Late Subscribers', () => { /** * These tests verify the fix for the race condition where errors @@ -1369,14 +1763,9 @@ describe('GenerationJobManager Integration Tests', () => { await GenerationJobManager.destroy(); }); - test('should handle error preservation in Redis mode (cross-replica)', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - + testRedis('should handle error preservation in Redis mode (cross-replica)', async () => { // === Replica A: Creates job and emits error === - const replicaAJobStore = new RedisJobStore(ioredisClient); + const replicaAJobStore = new RedisJobStore(ioredisClient!); await replicaAJobStore.initialize(); const streamId = `redis-error-${Date.now()}`; @@ -1463,13 +1852,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Cross-Replica Live Streaming (Redis)', () => { + describeRedis('Cross-Replica Live Streaming (Redis)', () => { test('should publish events to Redis even when no local subscriber exists', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const replicaA = new GenerationJobManagerClass(); const servicesA = createStreamServices({ useRedis: true, @@ -1489,7 +1873,7 @@ describe('GenerationJobManager Integration Tests', () => { const streamId = `cross-live-${Date.now()}`; await replicaA.createJob(streamId, 'user-1'); - const replicaBJobStore = new RedisJobStore(ioredisClient); + const replicaBJobStore = new RedisJobStore(ioredisClient!); await replicaBJobStore.initialize(); await replicaBJobStore.createJob(streamId, 'user-1'); @@ -1519,11 +1903,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should not cause data loss on cross-replica subscribers when local subscriber joins', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const replicaA = new GenerationJobManagerClass(); const servicesA = createStreamServices({ useRedis: true, @@ -1543,7 +1922,7 @@ describe('GenerationJobManager Integration Tests', () => { const streamId = `cross-seq-safe-${Date.now()}`; await replicaA.createJob(streamId, 'user-1'); - const replicaBJobStore = new RedisJobStore(ioredisClient); + const replicaBJobStore = new RedisJobStore(ioredisClient!); await replicaBJobStore.initialize(); await replicaBJobStore.createJob(streamId, 'user-1'); @@ -1603,11 +1982,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should deliver buffered events locally AND publish live events cross-replica', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const replicaA = new GenerationJobManagerClass(); const servicesA = createStreamServices({ useRedis: true, @@ -1641,7 +2015,7 @@ describe('GenerationJobManager Integration Tests', () => { replicaB.configure(servicesB); replicaB.initialize(); - const replicaBJobStore = new RedisJobStore(ioredisClient); + const replicaBJobStore = new RedisJobStore(ioredisClient!); await replicaBJobStore.initialize(); await replicaBJobStore.createJob(streamId, 'user-1'); @@ -1671,13 +2045,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Concurrent Subscriber Readiness (Redis)', () => { + describeRedis('Concurrent Subscriber Readiness (Redis)', () => { test('should return ready promise to all concurrent subscribers for same stream', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const subscriber = ( ioredisClient as unknown as { duplicate: () => typeof ioredisClient } ).duplicate()!; @@ -1706,13 +2075,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Sequence Reset Safety (Redis)', () => { + describeRedis('Sequence Reset Safety (Redis)', () => { test('should not receive stale pre-subscribe events via Redis after sequence reset', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const manager = new GenerationJobManagerClass(); const services = createStreamServices({ useRedis: true, @@ -1774,11 +2138,6 @@ describe('GenerationJobManager Integration Tests', () => { }); test('should not reset sequence when second subscriber joins mid-stream', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const manager = new GenerationJobManagerClass(); const services = createStreamServices({ useRedis: true, @@ -1837,13 +2196,8 @@ describe('GenerationJobManager Integration Tests', () => { }); }); - describe('Subscribe Error Recovery (Redis)', () => { + describeRedis('Subscribe Error Recovery (Redis)', () => { test('should allow resubscription after Redis subscribe failure', async () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - const subscriber = ( ioredisClient as unknown as { duplicate: () => typeof ioredisClient } ).duplicate()!; @@ -1892,12 +2246,7 @@ describe('GenerationJobManager Integration Tests', () => { }); describe('createStreamServices Auto-Detection', () => { - test('should use Redis when useRedis is true and client is available', () => { - if (!ioredisClient) { - console.warn('Redis not available, skipping test'); - return; - } - + testRedis('should use Redis when useRedis is true and client is available', () => { const services = createStreamServices({ useRedis: true, redisClient: ioredisClient, diff --git a/packages/api/src/types/stream.ts b/packages/api/src/types/stream.ts index 79b29d774f..068d9c8db8 100644 --- a/packages/api/src/types/stream.ts +++ b/packages/api/src/types/stream.ts @@ -47,3 +47,24 @@ export type ChunkHandler = (event: ServerSentEvent) => void; export type DoneHandler = (event: ServerSentEvent) => void; export type ErrorHandler = (error: string) => void; export type UnsubscribeFn = () => void; + +/** Options for subscribing to a job event stream */ +export interface SubscribeOptions { + /** + * When true, skips replaying the earlyEventBuffer. + * Use for resume connections after a sync event has been sent. + */ + skipBufferReplay?: boolean; +} + +/** Result of an atomic subscribe-with-resume operation */ +export interface SubscribeWithResumeResult { + subscription: { unsubscribe: UnsubscribeFn } | null; + resumeState: ResumeState | null; + /** + * Events that arrived between the resume snapshot and the subscribe call. + * In-memory mode: drained from earlyEventBuffer (only place they exist). + * Redis mode: empty — chunks are persisted to the store and appear in aggregatedContent on next resume. + */ + pendingEvents: ServerSentEvent[]; +} From 83184467043016d3e48b0d86c212d73b510eefa3 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 21:22:25 -0400 Subject: [PATCH 20/39] =?UTF-8?q?=F0=9F=92=81=20refactor:=20Better=20Confi?= =?UTF-8?q?g=20UX=20for=20MCP=20STDIO=20with=20`customUserVars`=20(#12226)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: Better UX for MCP stdio with Custom User Variables - Updated the ConnectionsRepository to prevent connections when customUserVars are defined, improving security and access control. - Modified the MCPServerInspector to skip capabilities fetch when customUserVars are present, streamlining server inspection. - Added tests to validate connection restrictions with customUserVars, ensuring robust handling of various server configurations. This change enhances the overall integrity of the connection management process by enforcing stricter rules around custom user variables. * fix: guard against empty customUserVars and add JSDoc context - Extract `hasCustomUserVars()` helper to guard against truthy `{}` (Zod's `.record().optional()` yields `{}` on empty input, not `undefined`) - Add JSDoc to `isAllowedToConnectToServer` explaining why customUserVars servers are excluded from app-level connections * test: improve customUserVars test coverage and fixture hygiene - Add no-connection-provided test for MCPServerInspector (production path) - Fix test descriptions to match actual fixture values - Replace real package name with fictional @test/mcp-stdio-server --- packages/api/src/mcp/ConnectionsRepository.ts | 18 +++++-- .../__tests__/ConnectionsRepository.test.ts | 44 +++++++++++++++++ .../src/mcp/registry/MCPServerInspector.ts | 7 ++- .../__tests__/MCPServerInspector.test.ts | 48 +++++++++++++++++++ packages/api/src/mcp/utils.ts | 5 ++ 5 files changed, 116 insertions(+), 6 deletions(-) diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index e629934dda..970e7ea4b9 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -1,8 +1,9 @@ import { logger } from '@librechat/data-schemas'; -import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { MCPConnection } from './connection'; -import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import type * as t from './types'; +import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { hasCustomUserVars } from './utils'; +import { MCPConnection } from './connection'; const CONNECT_CONCURRENCY = 3; @@ -139,12 +140,19 @@ export class ConnectionsRepository { return `[MCP][${serverName}]`; } + /** + * App-level (shared) connections cannot serve servers that need per-user context: + * env/header placeholders like `{{MY_KEY}}` are only resolved by `processMCPEnv()` + * when real `customUserVars` values exist — which requires a user-level connection. + */ private isAllowedToConnectToServer(config: t.ParsedServerConfig) { if (config.inspectionFailed) { return false; } - //the repository is not allowed to be connected in case the Connection repository is shared (ownerId is undefined/null) and the server requires Auth or startup false. - if (this.ownerId === undefined && (config.startup === false || config.requiresOAuth)) { + if ( + this.ownerId === undefined && + (config.startup === false || config.requiresOAuth || hasCustomUserVars(config)) + ) { return false; } return true; diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index 3b827774d0..98e15eca18 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -392,6 +392,36 @@ describe('ConnectionsRepository', () => { expect(await repository.has('oauthDisabledServer')).toBe(false); }); + it('should NOT allow connection to servers with customUserVars', async () => { + mockServerConfigs.customVarServer = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + expect(await repository.has('customVarServer')).toBe(false); + }); + + it('should NOT allow connection when customUserVars is defined, even when startup is explicitly true', async () => { + mockServerConfigs.customVarStartupServer = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { TOKEN: '{{USER_TOKEN}}' }, + startup: true, + requiresOAuth: false, + customUserVars: { + USER_TOKEN: { title: 'Token', description: 'Your token' }, + }, + }; + + expect(await repository.has('customVarStartupServer')).toBe(false); + }); + it('should disconnect existing connection when server becomes not allowed', async () => { // Initially setup as regular server mockServerConfigs.changingServer = { @@ -471,6 +501,20 @@ describe('ConnectionsRepository', () => { expect(await repository.has('oauthDisabledServer')).toBe(true); }); + it('should allow connection to servers with customUserVars', async () => { + mockServerConfigs.customVarServer = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + expect(await repository.has('customVarServer')).toBe(true); + }); + it('should return null from get() when server config does not exist', async () => { const connection = await repository.get('nonexistent'); expect(connection).toBeNull(); diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index eea52bbf2e..a477d9b412 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -6,6 +6,7 @@ import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPDomainNotAllowedError } from '~/mcp/errors'; import { detectOAuthRequirement } from '~/mcp/oauth'; +import { hasCustomUserVars } from '~/mcp/utils'; import { isEnabled } from '~/utils'; /** @@ -54,7 +55,11 @@ export class MCPServerInspector { private async inspectServer(): Promise { await this.detectOAuth(); - if (this.config.startup !== false && !this.config.requiresOAuth) { + if ( + this.config.startup !== false && + !this.config.requiresOAuth && + !hasCustomUserVars(this.config) + ) { let tempConnection = false; if (!this.connection) { tempConnection = true; diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts index b79f2d044a..f0ab75c9b4 100644 --- a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -100,6 +100,54 @@ describe('MCPServerInspector', () => { }); }); + it('should skip capabilities fetch when customUserVars is defined', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + requiresOAuth: false, + initDuration: expect.any(Number), + }); + + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + expect(mockConnection.disconnect).not.toHaveBeenCalled(); + }); + + it('should NOT create a temp connection when customUserVars is defined and no connection is provided', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'npx', + args: ['-y', '@test/mcp-stdio-server'], + env: { API_KEY: '{{MY_KEY}}' }, + customUserVars: { + MY_KEY: { title: 'API Key', description: 'Your API key' }, + }, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig); + + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + expect(result.requiresOAuth).toBe(false); + expect(result.capabilities).toBeUndefined(); + expect(result.toolFunctions).toBeUndefined(); + }); + it('should keep custom serverInstructions string and not fetch from server', async () => { const rawConfig: t.MCPOptions = { type: 'stdio', diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index c517388a76..ff367725fc 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -3,6 +3,11 @@ import type { ParsedServerConfig } from '~/mcp/types'; export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); +/** Checks that `customUserVars` is present AND non-empty (guards against truthy `{}`) */ +export function hasCustomUserVars(config: Pick): boolean { + return !!config.customUserVars && Object.keys(config.customUserVars).length > 0; +} + /** * Allowlist-based sanitization for API responses. Only explicitly listed fields are included; * new fields added to ParsedServerConfig are excluded by default until allowlisted here. From 7c39a4594463442d07d09bba03c4e24639d59d70 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 14 Mar 2026 22:43:18 -0400 Subject: [PATCH 21/39] =?UTF-8?q?=F0=9F=90=8D=20refactor:=20Normalize=20No?= =?UTF-8?q?n-Standard=20Browser=20MIME=20Type=20Aliases=20in=20`inferMimeT?= =?UTF-8?q?ype`=20(#12240)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 fix: Normalize non-standard browser MIME types in inferMimeType macOS Chrome/Firefox report .py files as text/x-python-script instead of text/x-python, causing client-side validation to reject Python file uploads. inferMimeType now normalizes known MIME type aliases before returning, so non-standard variants match the accepted regex patterns. * 🧪 test: Add tests for MIME type alias normalization in inferMimeType * 🐛 fix: Restore JSDoc params and make mimeTypeAliases immutable * 🧪 test: Add checkType integration tests, remove redundant DragDropModal tests --- .../data-provider/src/file-config.spec.ts | 41 ++++++++++++++++++- packages/data-provider/src/file-config.ts | 16 +++++--- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/packages/data-provider/src/file-config.spec.ts b/packages/data-provider/src/file-config.spec.ts index 018b4dbfcf..0ab9f23a3e 100644 --- a/packages/data-provider/src/file-config.spec.ts +++ b/packages/data-provider/src/file-config.spec.ts @@ -1,15 +1,52 @@ import type { FileConfig } from './types/files'; import { fileConfig as baseFileConfig, + documentParserMimeTypes, getEndpointFileConfig, - mergeFileConfig, applicationMimeTypes, defaultOCRMimeTypes, - documentParserMimeTypes, supportedMimeTypes, + mergeFileConfig, + inferMimeType, + textMimeTypes, } from './file-config'; import { EModelEndpoint } from './schemas'; +describe('inferMimeType', () => { + it('should normalize text/x-python-script to text/x-python', () => { + expect(inferMimeType('test.py', 'text/x-python-script')).toBe('text/x-python'); + }); + + it('should return a type that matches textMimeTypes after normalization', () => { + const normalized = inferMimeType('test.py', 'text/x-python-script'); + expect(textMimeTypes.test(normalized)).toBe(true); + }); + + it('should pass through standard browser types unchanged', () => { + expect(inferMimeType('test.py', 'text/x-python')).toBe('text/x-python'); + expect(inferMimeType('doc.pdf', 'application/pdf')).toBe('application/pdf'); + }); + + it('should infer from extension when browser type is empty', () => { + expect(inferMimeType('test.py', '')).toBe('text/x-python'); + expect(inferMimeType('code.js', '')).toBe('text/javascript'); + expect(inferMimeType('photo.heic', '')).toBe('image/heic'); + }); + + it('should return empty string for unknown extension with no browser type', () => { + expect(inferMimeType('file.xyz', '')).toBe(''); + }); + + it('should produce a type accepted by checkType after normalizing text/x-python-script', () => { + const normalized = inferMimeType('test.py', 'text/x-python-script'); + expect(baseFileConfig.checkType(normalized)).toBe(true); + }); + + it('should reject raw text/x-python-script without normalization', () => { + expect(baseFileConfig.checkType('text/x-python-script')).toBe(false); + }); +}); + describe('applicationMimeTypes', () => { const odfTypes = [ 'application/vnd.oasis.opendocument.text', diff --git a/packages/data-provider/src/file-config.ts b/packages/data-provider/src/file-config.ts index 033c868a80..67b4197958 100644 --- a/packages/data-provider/src/file-config.ts +++ b/packages/data-provider/src/file-config.ts @@ -357,15 +357,21 @@ export const imageTypeMapping: { [key: string]: string } = { heif: 'image/heif', }; +/** Normalizes non-standard MIME types that browsers may report to their canonical forms */ +export const mimeTypeAliases: Readonly> = { + 'text/x-python-script': 'text/x-python', +}; + /** - * Infers the MIME type from a file's extension when the browser doesn't recognize it - * @param fileName - The name of the file including extension - * @param currentType - The current MIME type reported by the browser (may be empty) - * @returns The inferred MIME type if browser didn't provide one, otherwise the original type + * Infers the MIME type from a file's extension when the browser doesn't recognize it, + * and normalizes known non-standard MIME type aliases to their canonical forms. + * @param fileName - The file name including its extension + * @param currentType - The MIME type reported by the browser (may be empty string) + * @returns The normalized or inferred MIME type; empty string if unresolvable */ export function inferMimeType(fileName: string, currentType: string): string { if (currentType) { - return currentType; + return mimeTypeAliases[currentType] ?? currentType; } const extension = fileName.split('.').pop()?.toLowerCase() ?? ''; From 0c27ad2d55e7229cc43f4c72111fa8833369f60f Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 10:19:29 -0400 Subject: [PATCH 22/39] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20refactor:=20Scope?= =?UTF-8?q?=20Action=20Mutations=20by=20Parent=20Resource=20Ownership=20(#?= =?UTF-8?q?12237)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Scope action mutations by parent resource ownership Prevent cross-tenant action overwrites by validating that an existing action's agent_id/assistant_id matches the URL parameter before allowing updates or deletes. Without this, a user with EDIT access on their own agent could reference a foreign action_id to hijack another agent's action record. * 🛡️ fix: Harden action ownership checks and scope write filters - Remove && short-circuit that bypassed the guard when agent_id or assistant_id was falsy (e.g. assistant-owned actions have no agent_id, so the check was skipped entirely on the agents route). - Include agent_id / assistant_id in the updateAction and deleteAction query filters so the DB write itself enforces ownership atomically. - Log a warning when deleteAction returns null (silent no-op from data-integrity mismatch). * 📝 docs: Update Action model JSDoc to reflect scoped query params * ✅ test: Add Action ownership scoping tests Cover update, delete, and cross-type protection scenarios using MongoMemoryServer to verify that scoped query filters (agent_id, assistant_id) prevent cross-tenant overwrites and deletions at the database level. * 🛡️ fix: Scope updateAction filter in agent duplication handler * 🐛 fix: Use action metadata domain instead of action_id when duplicating agent actions The duplicate handler was splitting `action.action_id` by `actionDelimiter` to extract the domain, but `action_id` is a bare nanoid that doesn't contain the delimiter. This produced malformed entries in the duplicated agent's actions array (nanoid_action_newNanoid instead of domain_action_newNanoid). The domain is available on `action.metadata.domain`. * ✅ test: Add integration tests for agent duplication action handling Uses MongoMemoryServer with real Agent and Action models to verify: - Duplicated actions use metadata.domain (not action_id) for the agent actions array entries - Sensitive metadata fields are stripped from duplicated actions - Original action documents are not modified --- api/models/Action.js | 10 +- api/models/Action.spec.js | 250 ++++++++++++++++++ .../__tests__/v1.duplicate-actions.spec.js | 159 +++++++++++ api/server/controllers/agents/v1.js | 4 +- api/server/routes/agents/actions.js | 13 +- api/server/routes/assistants/actions.js | 15 +- 6 files changed, 437 insertions(+), 14 deletions(-) create mode 100644 api/models/Action.spec.js create mode 100644 api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js diff --git a/api/models/Action.js b/api/models/Action.js index 20aa20a7e4..f14c415d5b 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -4,9 +4,7 @@ const { Action } = require('~/db/models'); * Update an action with new data without overwriting existing properties, * or create a new action if it doesn't exist. * - * @param {Object} searchParams - The search parameters to find the action to update. - * @param {string} searchParams.action_id - The ID of the action to update. - * @param {string} searchParams.user - The user ID of the action's author. + * @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams * @param {Object} updateData - An object containing the properties to update. * @returns {Promise} The updated or newly created action document as a plain object. */ @@ -47,10 +45,8 @@ const getActions = async (searchParams, includeSensitive = false) => { /** * Deletes an action by params. * - * @param {Object} searchParams - The search parameters to find the action to delete. - * @param {string} searchParams.action_id - The ID of the action to delete. - * @param {string} searchParams.user - The user ID of the action's author. - * @returns {Promise} A promise that resolves to the deleted action document as a plain object, or null if no document was found. + * @param {{ action_id: string, agent_id?: string, assistant_id?: string, user?: string }} searchParams + * @returns {Promise} The deleted action document as a plain object, or null if no match. */ const deleteAction = async (searchParams) => { return await Action.findOneAndDelete(searchParams).lean(); diff --git a/api/models/Action.spec.js b/api/models/Action.spec.js new file mode 100644 index 0000000000..61a3b10f0f --- /dev/null +++ b/api/models/Action.spec.js @@ -0,0 +1,250 @@ +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { actionSchema } = require('@librechat/data-schemas'); +const { updateAction, getActions, deleteAction } = require('./Action'); + +let mongoServer; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + if (!mongoose.models.Action) { + mongoose.model('Action', actionSchema); + } + await mongoose.connect(mongoUri); +}, 20000); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.models.Action.deleteMany({}); +}); + +const userId = new mongoose.Types.ObjectId(); + +describe('Action ownership scoping', () => { + describe('updateAction', () => { + it('updates when action_id and agent_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_1', + agent_id: 'agent_A', + metadata: { domain: 'example.com' }, + }); + + const result = await updateAction( + { action_id: 'act_1', agent_id: 'agent_A' }, + { metadata: { domain: 'updated.com' } }, + ); + + expect(result).not.toBeNull(); + expect(result.metadata.domain).toBe('updated.com'); + expect(result.agent_id).toBe('agent_A'); + }); + + it('does not update when agent_id does not match (creates a new doc via upsert)', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_1', + agent_id: 'agent_B', + metadata: { domain: 'victim.com', api_key: 'secret' }, + }); + + const result = await updateAction( + { action_id: 'act_1', agent_id: 'agent_A' }, + { user: userId, metadata: { domain: 'attacker.com' } }, + ); + + expect(result.metadata.domain).toBe('attacker.com'); + + const original = await mongoose.models.Action.findOne({ + action_id: 'act_1', + agent_id: 'agent_B', + }).lean(); + expect(original).not.toBeNull(); + expect(original.metadata.domain).toBe('victim.com'); + expect(original.metadata.api_key).toBe('secret'); + }); + + it('updates when action_id and assistant_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_2', + assistant_id: 'asst_X', + metadata: { domain: 'example.com' }, + }); + + const result = await updateAction( + { action_id: 'act_2', assistant_id: 'asst_X' }, + { metadata: { domain: 'updated.com' } }, + ); + + expect(result).not.toBeNull(); + expect(result.metadata.domain).toBe('updated.com'); + }); + + it('does not overwrite when assistant_id does not match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_2', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com', api_key: 'secret' }, + }); + + await updateAction( + { action_id: 'act_2', assistant_id: 'asst_attacker' }, + { user: userId, metadata: { domain: 'attacker.com' } }, + ); + + const original = await mongoose.models.Action.findOne({ + action_id: 'act_2', + assistant_id: 'asst_victim', + }).lean(); + expect(original).not.toBeNull(); + expect(original.metadata.domain).toBe('victim.com'); + expect(original.metadata.api_key).toBe('secret'); + }); + }); + + describe('deleteAction', () => { + it('deletes when action_id and agent_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del', + agent_id: 'agent_A', + metadata: { domain: 'example.com' }, + }); + + const result = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' }); + expect(result).not.toBeNull(); + expect(result.action_id).toBe('act_del'); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(0); + }); + + it('returns null and preserves the document when agent_id does not match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del', + agent_id: 'agent_B', + metadata: { domain: 'victim.com' }, + }); + + const result = await deleteAction({ action_id: 'act_del', agent_id: 'agent_A' }); + expect(result).toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(1); + }); + + it('deletes when action_id and assistant_id both match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del_asst', + assistant_id: 'asst_X', + metadata: { domain: 'example.com' }, + }); + + const result = await deleteAction({ action_id: 'act_del_asst', assistant_id: 'asst_X' }); + expect(result).not.toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(0); + }); + + it('returns null and preserves the document when assistant_id does not match', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_del_asst', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com' }, + }); + + const result = await deleteAction({ + action_id: 'act_del_asst', + assistant_id: 'asst_attacker', + }); + expect(result).toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(1); + }); + }); + + describe('getActions (unscoped baseline)', () => { + it('returns actions by action_id regardless of agent_id', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_shared', + agent_id: 'agent_B', + metadata: { domain: 'example.com' }, + }); + + const results = await getActions({ action_id: 'act_shared' }, true); + expect(results).toHaveLength(1); + expect(results[0].agent_id).toBe('agent_B'); + }); + + it('returns actions scoped by agent_id when provided', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_scoped', + agent_id: 'agent_A', + metadata: { domain: 'a.com' }, + }); + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_other', + agent_id: 'agent_B', + metadata: { domain: 'b.com' }, + }); + + const results = await getActions({ agent_id: 'agent_A' }); + expect(results).toHaveLength(1); + expect(results[0].action_id).toBe('act_scoped'); + }); + }); + + describe('cross-type protection', () => { + it('updateAction with agent_id filter does not overwrite assistant-owned action', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_cross', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com', api_key: 'secret' }, + }); + + await updateAction( + { action_id: 'act_cross', agent_id: 'agent_attacker' }, + { user: userId, metadata: { domain: 'evil.com' } }, + ); + + const original = await mongoose.models.Action.findOne({ + action_id: 'act_cross', + assistant_id: 'asst_victim', + }).lean(); + expect(original).not.toBeNull(); + expect(original.metadata.domain).toBe('victim.com'); + expect(original.metadata.api_key).toBe('secret'); + }); + + it('deleteAction with agent_id filter does not delete assistant-owned action', async () => { + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_cross_del', + assistant_id: 'asst_victim', + metadata: { domain: 'victim.com' }, + }); + + const result = await deleteAction({ action_id: 'act_cross_del', agent_id: 'agent_attacker' }); + expect(result).toBeNull(); + + const remaining = await mongoose.models.Action.countDocuments(); + expect(remaining).toBe(1); + }); + }); +}); diff --git a/api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js b/api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js new file mode 100644 index 0000000000..cc298bd03a --- /dev/null +++ b/api/server/controllers/agents/__tests__/v1.duplicate-actions.spec.js @@ -0,0 +1,159 @@ +jest.mock('~/server/services/PermissionService', () => ({ + findPubliclyAccessibleResources: jest.fn(), + findAccessibleResources: jest.fn(), + hasPublicPermission: jest.fn(), + grantPermission: jest.fn().mockResolvedValue({}), +})); + +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), +})); + +const mongoose = require('mongoose'); +const { actionDelimiter } = require('librechat-data-provider'); +const { agentSchema, actionSchema } = require('@librechat/data-schemas'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { duplicateAgent } = require('../v1'); + +let mongoServer; + +beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + if (!mongoose.models.Agent) { + mongoose.model('Agent', agentSchema); + } + if (!mongoose.models.Action) { + mongoose.model('Action', actionSchema); + } + await mongoose.connect(mongoUri); +}, 20000); + +afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); +}); + +beforeEach(async () => { + await mongoose.models.Agent.deleteMany({}); + await mongoose.models.Action.deleteMany({}); +}); + +describe('duplicateAgentHandler — action domain extraction', () => { + it('builds duplicated action entries using metadata.domain, not action_id', async () => { + const userId = new mongoose.Types.ObjectId(); + const originalAgentId = `agent_original`; + + const agent = await mongoose.models.Agent.create({ + id: originalAgentId, + name: 'Test Agent', + author: userId.toString(), + provider: 'openai', + model: 'gpt-4', + tools: [], + actions: [`api.example.com${actionDelimiter}act_original`], + versions: [{ name: 'Test Agent', createdAt: new Date(), updatedAt: new Date() }], + }); + + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_original', + agent_id: originalAgentId, + metadata: { domain: 'api.example.com' }, + }); + + const req = { + params: { id: agent.id }, + user: { id: userId.toString() }, + }; + const res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(201); + + const { agent: newAgent, actions: newActions } = res.json.mock.calls[0][0]; + + expect(newAgent.id).not.toBe(originalAgentId); + expect(String(newAgent.author)).toBe(userId.toString()); + expect(newActions).toHaveLength(1); + expect(newActions[0].metadata.domain).toBe('api.example.com'); + expect(newActions[0].agent_id).toBe(newAgent.id); + + for (const actionEntry of newAgent.actions) { + const [domain, actionId] = actionEntry.split(actionDelimiter); + expect(domain).toBe('api.example.com'); + expect(actionId).toBeTruthy(); + expect(actionId).not.toBe('act_original'); + } + + const allActions = await mongoose.models.Action.find({}).lean(); + expect(allActions).toHaveLength(2); + + const originalAction = allActions.find((a) => a.action_id === 'act_original'); + expect(originalAction.agent_id).toBe(originalAgentId); + + const duplicatedAction = allActions.find((a) => a.action_id !== 'act_original'); + expect(duplicatedAction.agent_id).toBe(newAgent.id); + expect(duplicatedAction.metadata.domain).toBe('api.example.com'); + }); + + it('strips sensitive metadata fields from duplicated actions', async () => { + const userId = new mongoose.Types.ObjectId(); + const originalAgentId = 'agent_sensitive'; + + await mongoose.models.Agent.create({ + id: originalAgentId, + name: 'Sensitive Agent', + author: userId.toString(), + provider: 'openai', + model: 'gpt-4', + tools: [], + actions: [`secure.api.com${actionDelimiter}act_secret`], + versions: [{ name: 'Sensitive Agent', createdAt: new Date(), updatedAt: new Date() }], + }); + + await mongoose.models.Action.create({ + user: userId, + action_id: 'act_secret', + agent_id: originalAgentId, + metadata: { + domain: 'secure.api.com', + api_key: 'sk-secret-key-12345', + oauth_client_id: 'client_id_xyz', + oauth_client_secret: 'client_secret_xyz', + }, + }); + + const req = { + params: { id: originalAgentId }, + user: { id: userId.toString() }, + }; + const res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(201); + + const duplicatedAction = await mongoose.models.Action.findOne({ + agent_id: { $ne: originalAgentId }, + }).lean(); + + expect(duplicatedAction.metadata.domain).toBe('secure.api.com'); + expect(duplicatedAction.metadata.api_key).toBeUndefined(); + expect(duplicatedAction.metadata.oauth_client_id).toBeUndefined(); + expect(duplicatedAction.metadata.oauth_client_secret).toBeUndefined(); + + const originalAction = await mongoose.models.Action.findOne({ + action_id: 'act_secret', + }).lean(); + expect(originalAction.metadata.api_key).toBe('sk-secret-key-12345'); + }); +}); diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index a2c0d55186..1abba8b2c8 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -371,7 +371,7 @@ const duplicateAgentHandler = async (req, res) => { */ const duplicateAction = async (action) => { const newActionId = nanoid(); - const [domain] = action.action_id.split(actionDelimiter); + const { domain } = action.metadata; const fullActionId = `${domain}${actionDelimiter}${newActionId}`; // Sanitize sensitive metadata before persisting @@ -381,7 +381,7 @@ const duplicateAgentHandler = async (req, res) => { } const newAction = await updateAction( - { action_id: newActionId }, + { action_id: newActionId, agent_id: newAgentId }, { metadata: filteredMetadata, agent_id: newAgentId, diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 12168ba28a..4643f096aa 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -143,6 +143,9 @@ router.post( if (actions_result && actions_result.length) { const action = actions_result[0]; + if (action.agent_id !== agent_id) { + return res.status(403).json({ message: 'Action does not belong to this agent' }); + } metadata = { ...action.metadata, ...metadata }; } @@ -184,7 +187,7 @@ router.post( } /** @type {[Action]} */ - const updatedAction = await updateAction({ action_id }, actionUpdateData); + const updatedAction = await updateAction({ action_id, agent_id }, actionUpdateData); const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret']; for (let field of sensitiveFields) { @@ -251,7 +254,13 @@ router.delete( { tools: updatedTools, actions: updatedActions }, { updatingUserId: req.user.id, forceVersion: true }, ); - await deleteAction({ action_id }); + const deleted = await deleteAction({ action_id, agent_id }); + if (!deleted) { + logger.warn('[Agent Action Delete] No matching action document found', { + action_id, + agent_id, + }); + } res.status(200).json({ message: 'Action deleted successfully' }); } catch (error) { const message = 'Trouble deleting the Agent Action'; diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index 57975d32a7..b085fbd36a 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -60,6 +60,9 @@ router.post('/:assistant_id', async (req, res) => { if (actions_result && actions_result.length) { const action = actions_result[0]; + if (action.assistant_id !== assistant_id) { + return res.status(403).json({ message: 'Action does not belong to this assistant' }); + } metadata = { ...action.metadata, ...metadata }; } @@ -117,7 +120,7 @@ router.post('/:assistant_id', async (req, res) => { // For new actions, use the assistant owner's user ID actionUpdateData.user = assistant_user || req.user.id; } - promises.push(updateAction({ action_id }, actionUpdateData)); + promises.push(updateAction({ action_id, assistant_id }, actionUpdateData)); /** @type {[AssistantDocument, Action]} */ let [assistantDocument, updatedAction] = await Promise.all(promises); @@ -196,9 +199,15 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { assistantUpdateData.user = req.user.id; } promises.push(updateAssistantDoc({ assistant_id }, assistantUpdateData)); - promises.push(deleteAction({ action_id })); + promises.push(deleteAction({ action_id, assistant_id })); - await Promise.all(promises); + const [, deletedAction] = await Promise.all(promises); + if (!deletedAction) { + logger.warn('[Assistant Action Delete] No matching action document found', { + action_id, + assistant_id, + }); + } res.status(200).json({ message: 'Action deleted successfully' }); } catch (error) { const message = 'Trouble deleting the Assistant Action'; From 93a628d7a2c2ba07b1c92a118a7d0774da9d706e Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 10:35:44 -0400 Subject: [PATCH 23/39] =?UTF-8?q?=F0=9F=93=8E=20fix:=20Respect=20fileConfi?= =?UTF-8?q?g.disabled=20for=20Agents=20Endpoint=20Upload=20Button=20(#1223?= =?UTF-8?q?8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: respect fileConfig.disabled for agents endpoint upload button The isAgents check was OR'd without the !isUploadDisabled guard, bypassing the fileConfig.endpoints.agents.disabled setting and always rendering the attach file menu for agents. * test: add regression tests for fileConfig.disabled upload guard Cover the isUploadDisabled rendering gate for agents and assistants endpoints, preventing silent reintroduction of the bypass bug. * test: cover disabled fallback chain in useAgentFileConfig Verify agents-disabled propagates when no provider is set, when provider has no specific config (agents as fallback), and that provider-specific enabled overrides agents disabled. --- .../Chat/Input/Files/AttachFileChat.tsx | 2 +- .../Files/__tests__/AttachFileChat.spec.tsx | 59 ++++++++++++++++++- .../Agents/__tests__/AgentFileConfig.spec.tsx | 47 ++++++++++++++- 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/client/src/components/Chat/Input/Files/AttachFileChat.tsx b/client/src/components/Chat/Input/Files/AttachFileChat.tsx index 00a0b7aaa8..2f954d01d5 100644 --- a/client/src/components/Chat/Input/Files/AttachFileChat.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileChat.tsx @@ -91,7 +91,7 @@ function AttachFileChat({ if (isAssistants && endpointSupportsFiles && !isUploadDisabled) { return ; - } else if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) { + } else if ((isAgents || endpointSupportsFiles) && !isUploadDisabled) { return ( > = {}; let mockAgentQueryData: Partial | undefined; @@ -65,6 +67,7 @@ function renderComponent(conversation: Record | null, disableIn describe('AttachFileChat', () => { beforeEach(() => { + mockFileConfig = defaultFileConfig; mockAgentsMap = {}; mockAgentQueryData = undefined; mockAttachFileMenuProps = {}; @@ -148,6 +151,60 @@ describe('AttachFileChat', () => { }); }); + describe('upload disabled rendering', () => { + it('renders null for agents endpoint when fileConfig.agents.disabled is true', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.agents]: { disabled: true }, + }, + }); + const { container } = renderComponent({ + endpoint: EModelEndpoint.agents, + agent_id: 'agent-1', + }); + expect(container.innerHTML).toBe(''); + }); + + it('renders null for agents endpoint when disableInputs is true', () => { + const { container } = renderComponent( + { endpoint: EModelEndpoint.agents, agent_id: 'agent-1' }, + true, + ); + expect(container.innerHTML).toBe(''); + }); + + it('renders AttachFile for assistants endpoint when not disabled', () => { + renderComponent({ endpoint: EModelEndpoint.assistants }); + expect(screen.getByTestId('attach-file')).toBeInTheDocument(); + }); + + it('renders AttachFileMenu when provider-specific config overrides agents disabled', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + Moonshot: { disabled: false, fileLimit: 5 }, + [EModelEndpoint.agents]: { disabled: true }, + }, + }); + mockAgentsMap = { + 'agent-1': { provider: 'Moonshot', model_parameters: {} } as Partial, + }; + renderComponent({ endpoint: EModelEndpoint.agents, agent_id: 'agent-1' }); + expect(screen.getByTestId('attach-file-menu')).toBeInTheDocument(); + }); + + it('renders null for assistants endpoint when fileConfig.assistants.disabled is true', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.assistants]: { disabled: true }, + }, + }); + const { container } = renderComponent({ + endpoint: EModelEndpoint.assistants, + }); + expect(container.innerHTML).toBe(''); + }); + }); + describe('endpointFileConfig resolution', () => { it('passes Moonshot-specific file config for agent with Moonshot provider', () => { mockAgentsMap = { diff --git a/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx b/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx index aeb0dd3ff9..2bbd3fea22 100644 --- a/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx +++ b/client/src/components/SidePanel/Agents/__tests__/AgentFileConfig.spec.tsx @@ -18,7 +18,7 @@ const mockEndpointsConfig: TEndpointsConfig = { 'Some Endpoint': { type: EModelEndpoint.custom, userProvide: false, order: 9999 }, }; -let mockFileConfig = mergeFileConfig({ +const defaultFileConfig = mergeFileConfig({ endpoints: { Moonshot: { fileLimit: 5 }, [EModelEndpoint.agents]: { fileLimit: 20 }, @@ -26,6 +26,8 @@ let mockFileConfig = mergeFileConfig({ }, }); +let mockFileConfig = defaultFileConfig; + jest.mock('~/data-provider', () => ({ useGetEndpointsQuery: () => ({ data: mockEndpointsConfig }), useGetFileConfig: ({ select }: { select?: (data: unknown) => unknown }) => ({ @@ -118,13 +120,16 @@ describe('AgentPanel file config resolution (useAgentFileConfig)', () => { }); describe('disabled state', () => { + beforeEach(() => { + mockFileConfig = defaultFileConfig; + }); + it('reports not disabled for standard config', () => { render(); expect(screen.getByTestId('disabled').textContent).toBe('false'); }); it('reports disabled when provider-specific config is disabled', () => { - const original = mockFileConfig; mockFileConfig = mergeFileConfig({ endpoints: { Moonshot: { disabled: true }, @@ -135,8 +140,44 @@ describe('AgentPanel file config resolution (useAgentFileConfig)', () => { render(); expect(screen.getByTestId('disabled').textContent).toBe('true'); + }); - mockFileConfig = original; + it('reports disabled when agents config is disabled and no provider set', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.agents]: { disabled: true }, + default: { fileLimit: 10 }, + }, + }); + + render(); + expect(screen.getByTestId('disabled').textContent).toBe('true'); + }); + + it('reports disabled when agents is disabled and provider has no specific config', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + [EModelEndpoint.agents]: { disabled: true }, + default: { fileLimit: 10 }, + }, + }); + + render(); + expect(screen.getByTestId('disabled').textContent).toBe('true'); + }); + + it('provider-specific enabled overrides agents disabled', () => { + mockFileConfig = mergeFileConfig({ + endpoints: { + Moonshot: { disabled: false, fileLimit: 5 }, + [EModelEndpoint.agents]: { disabled: true }, + default: { fileLimit: 10 }, + }, + }); + + render(); + expect(screen.getByTestId('disabled').textContent).toBe('false'); + expect(screen.getByTestId('fileLimit').textContent).toBe('5'); }); }); From e079fc4900113711a704933e3059fdf215e50317 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 10:39:42 -0400 Subject: [PATCH 24/39] =?UTF-8?q?=F0=9F=93=8E=20fix:=20Enforce=20File=20Co?= =?UTF-8?q?unt=20and=20Size=20Limits=20Across=20All=20Attachment=20Paths?= =?UTF-8?q?=20(#12239)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 fix: Enforce fileLimit and totalSizeLimit in Attached Files panel The Files side panel (PanelTable) was not checking fileLimit or totalSizeLimit from fileConfig when attaching previously uploaded files, allowing users to bypass per-endpoint file count and total size limits. * 🔧 fix: Address review findings on file limit enforcement - Fix totalSizeLimit double-counting size of already-attached files - Clarify fileLimit error message: "File limit reached: N files (endpoint)" - Replace Array.from(...).reduce with for...of loop to avoid intermediate allocation - Extract inline `type TFile` into standalone `import type` per project conventions * ✅ test: Add PanelTable handleFileClick file limit tests Cover fileLimit guard, totalSizeLimit guard, passing case, double-count prevention for re-attached files, and boundary case. * 🔧 test: Harden PanelTable test mock setup - Use explicit endpoint key matching mockConversation.endpoint instead of relying on default fallback behavior - Add supportedMimeTypes to mock config for explicit MIME coverage - Throw on missing filename cell in clickFilenameCell to prevent silent false-positive blocking assertions * ♻️ refactor: Align file validation ordering and messaging across upload paths - Reorder handleFileClick checks to match validateFiles: disabled → fileLimit → fileSizeLimit → checkType → totalSizeLimit - Change fileSizeLimit comparison from > to >= in handleFileClick to match validateFiles behavior - Align validateFiles error strings with localized key wording: "File limit reached:", "File size limit exceeded:", etc. - Remove stray console.log in validateFiles MIME-type check * ✅ test: Add validateFiles unit tests for both paths' consistency 13 tests covering disabled, empty, fileLimit (reject + boundary), fileSizeLimit (>= at limit + under limit), checkType, totalSizeLimit (reject + at limit), duplicate detection, and check ordering. Ensures both validateFiles and handleFileClick enforce the same validation rules in the same order. --- .../components/SidePanel/Files/PanelTable.tsx | 38 ++- .../Files/__tests__/PanelTable.spec.tsx | 239 ++++++++++++++++++ client/src/locales/en/translation.json | 2 + .../src/utils/__tests__/validateFiles.spec.ts | 172 +++++++++++++ client/src/utils/files.ts | 9 +- 5 files changed, 448 insertions(+), 12 deletions(-) create mode 100644 client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx create mode 100644 client/src/utils/__tests__/validateFiles.spec.ts diff --git a/client/src/components/SidePanel/Files/PanelTable.tsx b/client/src/components/SidePanel/Files/PanelTable.tsx index 2fc8f7031b..e67e16abdd 100644 --- a/client/src/components/SidePanel/Files/PanelTable.tsx +++ b/client/src/components/SidePanel/Files/PanelTable.tsx @@ -24,14 +24,14 @@ import { type ColumnFiltersState, } from '@tanstack/react-table'; import { - fileConfig as defaultFileConfig, - checkOpenAIStorage, - mergeFileConfig, megabyte, + mergeFileConfig, + checkOpenAIStorage, isAssistantsEndpoint, getEndpointFileConfig, - type TFile, + fileConfig as defaultFileConfig, } from 'librechat-data-provider'; +import type { TFile } from 'librechat-data-provider'; import { MyFilesModal } from '~/components/Chat/Input/Files/MyFilesModal'; import { useFileMapContext, useChatContext } from '~/Providers'; import { useLocalize, useUpdateFiles } from '~/hooks'; @@ -86,7 +86,7 @@ export default function DataTable({ columns, data }: DataTablePro const fileMap = useFileMapContext(); const { showToast } = useToastContext(); - const { setFiles, conversation } = useChatContext(); + const { files, setFiles, conversation } = useChatContext(); const { data: fileConfig = null } = useGetFileConfig({ select: (data) => mergeFileConfig(data), }); @@ -142,7 +142,15 @@ export default function DataTable({ columns, data }: DataTablePro return; } - if (fileData.bytes > (endpointFileConfig.fileSizeLimit ?? Number.MAX_SAFE_INTEGER)) { + if (endpointFileConfig.fileLimit && files.size >= endpointFileConfig.fileLimit) { + showToast({ + message: `${localize('com_ui_attach_error_limit')} ${endpointFileConfig.fileLimit} files (${endpoint})`, + status: 'error', + }); + return; + } + + if (fileData.bytes >= (endpointFileConfig.fileSizeLimit ?? Number.MAX_SAFE_INTEGER)) { showToast({ message: `${localize('com_ui_attach_error_size')} ${ (endpointFileConfig.fileSizeLimit ?? 0) / megabyte @@ -160,6 +168,22 @@ export default function DataTable({ columns, data }: DataTablePro return; } + if (endpointFileConfig.totalSizeLimit) { + const existing = files.get(fileData.file_id); + let currentTotalSize = 0; + for (const f of files.values()) { + currentTotalSize += f.size; + } + currentTotalSize -= existing?.size ?? 0; + if (currentTotalSize + fileData.bytes > endpointFileConfig.totalSizeLimit) { + showToast({ + message: `${localize('com_ui_attach_error_total_size')} ${endpointFileConfig.totalSizeLimit / megabyte} MB (${endpoint})`, + status: 'error', + }); + return; + } + } + addFile({ progress: 1, attached: true, @@ -175,7 +199,7 @@ export default function DataTable({ columns, data }: DataTablePro metadata: fileData.metadata, }); }, - [addFile, fileMap, conversation, localize, showToast, fileConfig], + [addFile, files, fileMap, conversation, localize, showToast, fileConfig], ); const filenameFilter = table.getColumn('filename')?.getFilterValue() as string; diff --git a/client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx b/client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx new file mode 100644 index 0000000000..2639d3c100 --- /dev/null +++ b/client/src/components/SidePanel/Files/__tests__/PanelTable.spec.tsx @@ -0,0 +1,239 @@ +import React from 'react'; +import { render, screen, fireEvent } from '@testing-library/react'; +import { FileSources } from 'librechat-data-provider'; +import type { TFile } from 'librechat-data-provider'; +import type { ExtendedFile } from '~/common'; +import DataTable from '../PanelTable'; +import { columns } from '../PanelColumns'; + +const mockShowToast = jest.fn(); +const mockAddFile = jest.fn(); + +let mockFileMap: Record = {}; +let mockFiles: Map = new Map(); +let mockConversation: Record | null = { endpoint: 'openAI' }; +let mockRawFileConfig: Record | null = { + endpoints: { + openAI: { fileLimit: 10, supportedMimeTypes: ['application/pdf', 'text/plain'] }, + }, +}; + +jest.mock('@librechat/client', () => ({ + Table: ({ children, ...props }: { children: React.ReactNode }) => ( + {children}
+ ), + Button: ({ + children, + ...props + }: { children: React.ReactNode } & React.ButtonHTMLAttributes) => ( + + ), + TableRow: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + TableHead: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + TableBody: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + TableCell: ({ + children, + ...props + }: { children: React.ReactNode } & React.TdHTMLAttributes) => ( + {children} + ), + FilterInput: () => , + TableHeader: ({ children, ...props }: { children: React.ReactNode }) => ( + {children} + ), + useToastContext: () => ({ showToast: mockShowToast }), +})); + +jest.mock('~/Providers', () => ({ + useFileMapContext: () => mockFileMap, + useChatContext: () => ({ + files: mockFiles, + setFiles: jest.fn(), + conversation: mockConversation, + }), +})); + +jest.mock('~/hooks', () => ({ + useLocalize: () => (key: string) => key, + useUpdateFiles: () => ({ addFile: mockAddFile }), +})); + +jest.mock('~/data-provider', () => ({ + useGetFileConfig: ({ select }: { select?: (d: unknown) => unknown }) => ({ + data: select != null ? select(mockRawFileConfig) : mockRawFileConfig, + }), +})); + +jest.mock('~/components/Chat/Input/Files/MyFilesModal', () => ({ + MyFilesModal: () => null, +})); + +jest.mock('../PanelFileCell', () => ({ row }: { row: { original: TFile } }) => ( + {row.original?.filename} +)); + +function makeFile(overrides: Partial = {}): TFile { + return { + user: 'user-1', + file_id: 'file-1', + bytes: 1024, + embedded: false, + filename: 'test.pdf', + filepath: '/files/test.pdf', + object: 'file', + type: 'application/pdf', + usage: 0, + source: FileSources.local, + ...overrides, + }; +} + +function makeExtendedFile(overrides: Partial = {}): ExtendedFile { + return { + file_id: 'ext-1', + size: 1024, + progress: 1, + source: FileSources.local, + ...overrides, + }; +} + +function renderTable(data: TFile[]) { + return render(); +} + +function clickFilenameCell() { + const cells = screen.getAllByRole('button'); + const filenameCell = cells.find( + (cell) => cell.tagName === 'TD' && cell.textContent && !cell.textContent.includes('com_ui_'), + ); + if (!filenameCell) { + throw new Error('Could not find filename cell with role="button" — check mock setup'); + } + fireEvent.click(filenameCell); + return filenameCell; +} + +describe('PanelTable handleFileClick', () => { + beforeEach(() => { + mockShowToast.mockClear(); + mockAddFile.mockClear(); + mockFiles = new Map(); + mockConversation = { endpoint: 'openAI' }; + mockRawFileConfig = { + endpoints: { + openAI: { + fileLimit: 5, + totalSizeLimit: 10, + supportedMimeTypes: ['application/pdf', 'text/plain'], + }, + }, + }; + }); + + it('calls addFile when within file limits', () => { + const file = makeFile(); + mockFileMap = { [file.file_id]: file }; + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).toHaveBeenCalledTimes(1); + expect(mockAddFile).toHaveBeenCalledWith( + expect.objectContaining({ + file_id: file.file_id, + attached: true, + progress: 1, + }), + ); + expect(mockShowToast).not.toHaveBeenCalledWith(expect.objectContaining({ status: 'error' })); + }); + + it('blocks attachment when fileLimit is reached', () => { + const file = makeFile({ file_id: 'new-file', filename: 'new.pdf' }); + mockFileMap = { [file.file_id]: file }; + + mockFiles = new Map( + Array.from({ length: 5 }, (_, i) => [ + `existing-${i}`, + makeExtendedFile({ file_id: `existing-${i}` }), + ]), + ); + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).not.toHaveBeenCalled(); + expect(mockShowToast).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('com_ui_attach_error_limit'), + status: 'error', + }), + ); + }); + + it('blocks attachment when totalSizeLimit would be exceeded', () => { + const MB = 1024 * 1024; + const largeFile = makeFile({ file_id: 'large-file', bytes: 6 * MB }); + mockFileMap = { [largeFile.file_id]: largeFile }; + + mockFiles = new Map([ + ['existing-1', makeExtendedFile({ file_id: 'existing-1', size: 5 * MB })], + ]); + + renderTable([largeFile]); + clickFilenameCell(); + + expect(mockAddFile).not.toHaveBeenCalled(); + expect(mockShowToast).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('com_ui_attach_error_total_size'), + status: 'error', + }), + ); + }); + + it('does not double-count size of already-attached file', () => { + const MB = 1024 * 1024; + const file = makeFile({ file_id: 'reattach', bytes: 5 * MB }); + mockFileMap = { [file.file_id]: file }; + + mockFiles = new Map([ + ['reattach', makeExtendedFile({ file_id: 'reattach', size: 5 * MB })], + ['other', makeExtendedFile({ file_id: 'other', size: 4 * MB })], + ]); + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).toHaveBeenCalledTimes(1); + expect(mockShowToast).not.toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('com_ui_attach_error_total_size'), + }), + ); + }); + + it('allows attachment when just under fileLimit', () => { + const file = makeFile({ file_id: 'under-limit' }); + mockFileMap = { [file.file_id]: file }; + + mockFiles = new Map( + Array.from({ length: 4 }, (_, i) => [ + `existing-${i}`, + makeExtendedFile({ file_id: `existing-${i}` }), + ]), + ); + + renderTable([file]); + clickFilenameCell(); + + expect(mockAddFile).toHaveBeenCalledTimes(1); + }); +}); diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 196ea2ad4a..f45cdd5f8c 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -748,7 +748,9 @@ "com_ui_attach_error": "Cannot attach file. Create or select a conversation, or try refreshing the page.", "com_ui_attach_error_disabled": "File uploads are disabled for this endpoint", "com_ui_attach_error_openai": "Cannot attach Assistant files to other endpoints", + "com_ui_attach_error_limit": "File limit reached:", "com_ui_attach_error_size": "File size limit exceeded for endpoint:", + "com_ui_attach_error_total_size": "Total file size limit exceeded for endpoint:", "com_ui_attach_error_type": "Unsupported file type for endpoint:", "com_ui_attach_remove": "Remove file", "com_ui_attach_warn_endpoint": "Non-Assistant files may be ignored without a compatible tool", diff --git a/client/src/utils/__tests__/validateFiles.spec.ts b/client/src/utils/__tests__/validateFiles.spec.ts new file mode 100644 index 0000000000..6d690bf62a --- /dev/null +++ b/client/src/utils/__tests__/validateFiles.spec.ts @@ -0,0 +1,172 @@ +import { megabyte, fileConfig as defaultFileConfig } from 'librechat-data-provider'; +import type { EndpointFileConfig, FileConfig } from 'librechat-data-provider'; +import type { ExtendedFile } from '~/common'; +import { validateFiles } from '../files'; + +const supportedMimeTypes = defaultFileConfig.endpoints.default.supportedMimeTypes; + +function makeEndpointConfig(overrides: Partial = {}): EndpointFileConfig { + return { + fileLimit: 10, + fileSizeLimit: 25 * megabyte, + totalSizeLimit: 100 * megabyte, + supportedMimeTypes, + disabled: false, + ...overrides, + }; +} + +function makeFile(name: string, type: string, size: number): File { + const content = new ArrayBuffer(size); + return new File([content], name, { type }); +} + +function makeExtendedFile(overrides: Partial = {}): ExtendedFile { + return { + file_id: 'ext-1', + size: 1024, + progress: 1, + type: 'application/pdf', + ...overrides, + }; +} + +describe('validateFiles', () => { + let setError: jest.Mock; + let files: Map; + let endpointFileConfig: EndpointFileConfig; + const fileConfig: FileConfig | null = null; + + beforeEach(() => { + setError = jest.fn(); + files = new Map(); + endpointFileConfig = makeEndpointConfig(); + }); + + it('returns true when all checks pass', () => { + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + expect(setError).not.toHaveBeenCalled(); + }); + + it('rejects when endpoint is disabled', () => { + endpointFileConfig = makeEndpointConfig({ disabled: true }); + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('com_ui_attach_error_disabled'); + }); + + it('rejects empty files (zero bytes)', () => { + const fileList = [makeFile('empty.pdf', 'application/pdf', 0)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('com_error_files_empty'); + }); + + it('rejects when fileLimit would be exceeded', () => { + endpointFileConfig = makeEndpointConfig({ fileLimit: 3 }); + files = new Map([ + ['f1', makeExtendedFile({ file_id: 'f1', filename: 'one.pdf', size: 2048 })], + ['f2', makeExtendedFile({ file_id: 'f2', filename: 'two.pdf', size: 3072 })], + ]); + const fileList = [ + makeFile('a.pdf', 'application/pdf', 1024), + makeFile('b.pdf', 'application/pdf', 2048), + ]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('File limit reached: 3 files'); + }); + + it('allows upload when exactly at fileLimit boundary', () => { + endpointFileConfig = makeEndpointConfig({ fileLimit: 3 }); + files = new Map([ + ['f1', makeExtendedFile({ file_id: 'f1', filename: 'one.pdf', size: 2048 })], + ['f2', makeExtendedFile({ file_id: 'f2', filename: 'two.pdf', size: 3072 })], + ]); + const fileList = [makeFile('a.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + }); + + it('rejects unsupported MIME type', () => { + const fileList = [makeFile('data.xyz', 'application/x-unknown', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('Unsupported file type: application/x-unknown'); + }); + + it('rejects when file size equals fileSizeLimit (>= comparison)', () => { + const limit = 5 * megabyte; + endpointFileConfig = makeEndpointConfig({ fileSizeLimit: limit }); + const fileList = [makeFile('exact.pdf', 'application/pdf', limit)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith(`File size limit exceeded: ${limit / megabyte} MB`); + }); + + it('allows file just under fileSizeLimit', () => { + const limit = 5 * megabyte; + endpointFileConfig = makeEndpointConfig({ fileSizeLimit: limit }); + const fileList = [makeFile('under.pdf', 'application/pdf', limit - 1)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + }); + + it('rejects when totalSizeLimit would be exceeded', () => { + const limit = 10 * megabyte; + endpointFileConfig = makeEndpointConfig({ totalSizeLimit: limit }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', size: 6 * megabyte })]]); + const fileList = [makeFile('big.pdf', 'application/pdf', 5 * megabyte)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith(`Total file size limit exceeded: ${limit / megabyte} MB`); + }); + + it('allows when totalSizeLimit is exactly met', () => { + const limit = 10 * megabyte; + endpointFileConfig = makeEndpointConfig({ totalSizeLimit: limit }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', size: 5 * megabyte })]]); + const fileList = [makeFile('fits.pdf', 'application/pdf', 5 * megabyte)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(true); + }); + + it('rejects duplicate files', () => { + files = new Map([ + [ + 'f1', + makeExtendedFile({ + file_id: 'f1', + file: makeFile('doc.pdf', 'application/pdf', 1024), + filename: 'doc.pdf', + size: 1024, + type: 'application/pdf', + }), + ], + ]); + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + const result = validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(result).toBe(false); + expect(setError).toHaveBeenCalledWith('com_error_files_dupe'); + }); + + it('enforces check ordering: disabled before fileLimit', () => { + endpointFileConfig = makeEndpointConfig({ disabled: true, fileLimit: 1 }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', filename: 'existing.pdf' })]]); + const fileList = [makeFile('doc.pdf', 'application/pdf', 1024)]; + validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(setError).toHaveBeenCalledWith('com_ui_attach_error_disabled'); + }); + + it('enforces check ordering: fileLimit before fileSizeLimit', () => { + const limit = 1; + endpointFileConfig = makeEndpointConfig({ fileLimit: 1, fileSizeLimit: limit }); + files = new Map([['f1', makeExtendedFile({ file_id: 'f1', filename: 'existing.pdf' })]]); + const fileList = [makeFile('huge.pdf', 'application/pdf', limit)]; + validateFiles({ files, fileList, setError, endpointFileConfig, fileConfig }); + expect(setError).toHaveBeenCalledWith('File limit reached: 1 files'); + }); +}); diff --git a/client/src/utils/files.ts b/client/src/utils/files.ts index b4d362d456..be81a31b79 100644 --- a/client/src/utils/files.ts +++ b/client/src/utils/files.ts @@ -251,7 +251,7 @@ export const validateFiles = ({ const currentTotalSize = existingFiles.reduce((total, file) => total + file.size, 0); if (fileLimit && fileList.length + files.size > fileLimit) { - setError(`You can only upload up to ${fileLimit} files at a time.`); + setError(`File limit reached: ${fileLimit} files`); return false; } @@ -282,19 +282,18 @@ export const validateFiles = ({ } if (!checkType(originalFile.type, mimeTypesToCheck)) { - console.log(originalFile); - setError('Currently, unsupported file type: ' + originalFile.type); + setError(`Unsupported file type: ${originalFile.type}`); return false; } if (fileSizeLimit && originalFile.size >= fileSizeLimit) { - setError(`File size exceeds ${fileSizeLimit / megabyte} MB.`); + setError(`File size limit exceeded: ${fileSizeLimit / megabyte} MB`); return false; } } if (totalSizeLimit && currentTotalSize + incomingTotalSize > totalSizeLimit) { - setError(`The total size of the files cannot exceed ${totalSizeLimit / megabyte} MB.`); + setError(`Total file size limit exceeded: ${totalSizeLimit / megabyte} MB`); return false; } From a01959b3d2eddb0961c611d429a703187d2e347b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 11:11:10 -0400 Subject: [PATCH 25/39] =?UTF-8?q?=F0=9F=9B=B0=EF=B8=8F=20fix:=20Cross-Repl?= =?UTF-8?q?ica=20Created=20Event=20Delivery=20(#12231)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: emit created event from metadata on cross-replica subscribe In multi-instance Redis deployments, the created event (which triggers sidebar conversation creation) was lost when the SSE subscriber connected to a different instance than the one generating. The event was only in the generating instance's local earlyEventBuffer and the Redis pub/sub message was already gone by the time the subscriber's channel was active. When subscribing cross-replica (empty buffer, Redis mode, userMessage already in job metadata), reconstruct and emit the created event directly from stored metadata. * test: add skipBufferReplay regression guard for cross-replica created event Add test asserting the resume path (skipBufferReplay: true) does NOT emit a created event on cross-replica subscribe — prevents the duplication fix from PR #12225 from regressing. Add explanatory JSDoc on the cross-replica fallback branch documenting which fields are preserved from trackUserMessage() and why sender/isCreatedByUser are hardcoded. * refactor: replace as-unknown-as casts with discriminated ServerSentEvent union Split ServerSentEvent into StreamEvent | CreatedEvent | FinalEvent so event shapes are statically typed. Removes all as-unknown-as casts in GenerationJobManager and test file; narrows with proper union members where properties are accessed. * fix: await trackUserMessage before PUBLISH for structural ordering trackUserMessage was fire-and-forget — the HSET for userMessage could theoretically race with the PUBLISH. Await it so the write commits before the pub/sub fires, guaranteeing any cross-replica getJob() after the pub/sub window always finds userMessage in Redis. No-op for non-created events (early return before any async work). * refactor: type CreatedEvent.message explicitly, fix JSDoc and import Give CreatedEvent.message its full known shape instead of Record. Update sendEvent JSDoc to reflect the discriminated union. Use barrel import in test file. * refactor: type FinalEvent fields with explicit message and conversation shapes Replace Record on requestMessage, responseMessage, conversation, and runMessages with FinalMessageFields and a typed conversation shape. Captures the known field set used by all final event constructors (abort handler in GenerationJobManager and normal completion in request.js) while allowing extension via index signature for fields contributed by the full TMessage/TConversation schemas. * refactor: narrow trackUserMessage with discriminated union, disambiguate error fields Use 'created' in event to narrow ServerSentEvent to CreatedEvent, eliminating all Record casts and manual field assertions. Add JSDoc to the two distinct error fields on FinalMessageFields and FinalEvent to prevent confusion. * fix: update cross-replica test to expect created event from metadata The cross-replica subscribe fallback now correctly emits a created event reconstructed from persisted metadata when userMessage exists in the Redis job hash. Replica B receives 4 events (created + 3 deltas) instead of 3. --- .../api/src/stream/GenerationJobManager.ts | 50 ++++-- ...ationJobManager.stream_integration.spec.ts | 142 ++++++++++++++++-- packages/api/src/types/events.ts | 49 +++++- packages/api/src/utils/events.ts | 9 +- 4 files changed, 218 insertions(+), 32 deletions(-) diff --git a/packages/api/src/stream/GenerationJobManager.ts b/packages/api/src/stream/GenerationJobManager.ts index 1b612dcb8f..3e04ab734b 100644 --- a/packages/api/src/stream/GenerationJobManager.ts +++ b/packages/api/src/stream/GenerationJobManager.ts @@ -656,7 +656,7 @@ class GenerationJobManagerClass { aborted: true, // Flag for early abort - no messages saved, frontend should go to new chat earlyAbort: isEarlyAbort, - } as unknown as t.ServerSentEvent; + } satisfies t.FinalEvent as t.ServerSentEvent; if (runtime) { runtime.finalEvent = abortFinalEvent; @@ -781,6 +781,27 @@ class GenerationJobManagerClass { } } runtime.earlyEventBuffer = []; + } else if (this._isRedis && !options?.skipBufferReplay && jobData?.userMessage) { + /** + * Cross-replica fallback: the created event was buffered on the generating + * instance and published via Redis pub/sub before this subscriber was active. + * Reconstruct from persisted metadata. Only fields stored by trackUserMessage() + * are available (messageId, parentMessageId, conversationId, text); + * sender/isCreatedByUser are invariant for user messages and added back here. + */ + logger.debug( + `[GenerationJobManager] Cross-replica subscribe: emitting created event from metadata for ${streamId}`, + ); + const createdEvent: t.CreatedEvent = { + created: true, + message: { + ...jobData.userMessage, + sender: 'User', + isCreatedByUser: true, + }, + streamId, + }; + onChunk(createdEvent); } this.eventTransport.syncReorderBuffer?.(streamId); @@ -858,8 +879,7 @@ class GenerationJobManagerClass { return; } - // Track user message from created event - this.trackUserMessage(streamId, event); + await this.trackUserMessage(streamId, event); // For Redis mode, persist chunk for later reconstruction (fire-and-forget for resumability) if (this._isRedis) { @@ -943,29 +963,31 @@ class GenerationJobManagerClass { } /** - * Track user message from created event. + * Persist user message metadata from the created event. + * Awaited in emitChunk so the HSET commits before the PUBLISH, + * guaranteeing any cross-replica getJob() after the pub/sub window + * finds userMessage in Redis. */ - private trackUserMessage(streamId: string, event: t.ServerSentEvent): void { - const data = event as Record; - if (!data.created || !data.message) { + private async trackUserMessage(streamId: string, event: t.ServerSentEvent): Promise { + if (!('created' in event)) { return; } - const message = data.message as Record; + const { message } = event; const updates: Partial = { userMessage: { - messageId: message.messageId as string, - parentMessageId: message.parentMessageId as string | undefined, - conversationId: message.conversationId as string | undefined, - text: message.text as string | undefined, + messageId: message.messageId, + parentMessageId: message.parentMessageId, + conversationId: message.conversationId, + text: message.text, }, }; if (message.conversationId) { - updates.conversationId = message.conversationId as string; + updates.conversationId = message.conversationId; } - this.jobStore.updateJob(streamId, updates); + await this.jobStore.updateJob(streamId, updates); } /** diff --git a/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts b/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts index 2f23510018..3e85ace56d 100644 --- a/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts +++ b/packages/api/src/stream/__tests__/GenerationJobManager.stream_integration.spec.ts @@ -1,6 +1,6 @@ /* eslint jest/no-standalone-expect: ["error", { "additionalTestBlockFunctions": ["testRedis"] }] */ import type { Redis, Cluster } from 'ioredis'; -import type { ServerSentEvent } from '~/types/events'; +import type { ServerSentEvent, StreamEvent, CreatedEvent } from '~/types'; import { InMemoryEventTransport } from '~/stream/implementations/InMemoryEventTransport'; import { RedisEventTransport } from '~/stream/implementations/RedisEventTransport'; import { InMemoryJobStore } from '~/stream/implementations/InMemoryJobStore'; @@ -771,6 +771,127 @@ describe('GenerationJobManager Integration Tests', () => { await GenerationJobManager.destroy(); await jobStore.destroy(); }); + + test('should emit created event from metadata on cross-replica subscribe', async () => { + const replicaAJobStore = new RedisJobStore(ioredisClient!); + await replicaAJobStore.initialize(); + + const streamId = `cross-created-${Date.now()}`; + const userId = 'test-user'; + + await replicaAJobStore.createJob(streamId, userId); + await replicaAJobStore.updateJob(streamId, { + userMessage: { + messageId: 'msg-123', + parentMessageId: '00000000-0000-0000-0000-000000000000', + conversationId: streamId, + text: 'hello world', + }, + }); + + jest.resetModules(); + + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, + }); + + GenerationJobManager.configure(services); + GenerationJobManager.initialize(); + + const received: unknown[] = []; + const subscription = await GenerationJobManager.subscribe( + streamId, + (event) => received.push(event), + ); + + expect(subscription).not.toBeNull(); + expect(received.length).toBe(1); + + const created = received[0] as CreatedEvent; + expect(created.created).toBe(true); + expect(created.streamId).toBe(streamId); + expect(created.message.messageId).toBe('msg-123'); + expect(created.message.conversationId).toBe(streamId); + expect(created.message.sender).toBe('User'); + expect(created.message.isCreatedByUser).toBe(true); + + subscription?.unsubscribe(); + await GenerationJobManager.destroy(); + await replicaAJobStore.destroy(); + }); + + test('should NOT emit created event from metadata when userMessage is not set', async () => { + const replicaAJobStore = new RedisJobStore(ioredisClient!); + await replicaAJobStore.initialize(); + + const streamId = `cross-no-created-${Date.now()}`; + await replicaAJobStore.createJob(streamId, 'test-user'); + + jest.resetModules(); + + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, + }); + + GenerationJobManager.configure(services); + GenerationJobManager.initialize(); + + const received: unknown[] = []; + const subscription = await GenerationJobManager.subscribe( + streamId, + (event) => received.push(event), + ); + + expect(subscription).not.toBeNull(); + expect(received.length).toBe(0); + + subscription?.unsubscribe(); + await GenerationJobManager.destroy(); + await replicaAJobStore.destroy(); + }); + + test('should NOT emit created event when skipBufferReplay is true (resume path)', async () => { + const replicaAJobStore = new RedisJobStore(ioredisClient!); + await replicaAJobStore.initialize(); + + const streamId = `cross-no-replay-${Date.now()}`; + await replicaAJobStore.createJob(streamId, 'test-user'); + await replicaAJobStore.updateJob(streamId, { + userMessage: { + messageId: 'msg-456', + conversationId: streamId, + text: 'hi', + }, + }); + + jest.resetModules(); + + const services = createStreamServices({ + useRedis: true, + redisClient: ioredisClient, + }); + + GenerationJobManager.configure(services); + GenerationJobManager.initialize(); + + const received: unknown[] = []; + const subscription = await GenerationJobManager.subscribe( + streamId, + (event) => received.push(event), + undefined, + undefined, + { skipBufferReplay: true }, + ); + + expect(subscription).not.toBeNull(); + expect(received.length).toBe(0); + + subscription?.unsubscribe(); + await GenerationJobManager.destroy(); + await replicaAJobStore.destroy(); + }); }); describeRedis('Sequential Event Ordering (Redis)', () => { @@ -1040,7 +1161,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); await manager.emitChunk(streamId, { event: 'on_message_delta', data: { delta: { content: { type: 'text', text: 'First chunk' } } }, @@ -1077,7 +1198,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); await manager.emitChunk(streamId, { event: 'on_message_delta', data: { delta: { content: { type: 'text', text: 'First' } } }, @@ -1123,7 +1244,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); await manager.emitChunk(streamId, { event: 'on_run_step', data: { id: 'step-1', type: 'message_creation', index: 0 }, @@ -1228,7 +1349,7 @@ describe('GenerationJobManager Integration Tests', () => { await new Promise((resolve) => setTimeout(resolve, 20)); expect(resumeEvents.length).toBe(1); - expect(resumeEvents[0].event).toBe('on_message_delta'); + expect((resumeEvents[0] as StreamEvent).event).toBe('on_message_delta'); sub2?.unsubscribe(); await manager.destroy(); @@ -1262,7 +1383,7 @@ describe('GenerationJobManager Integration Tests', () => { await new Promise((resolve) => setTimeout(resolve, 20)); expect(sub2Events.length).toBe(1); - expect(sub2Events[0].event).toBe('on_message_delta'); + expect((sub2Events[0] as StreamEvent).event).toBe('on_message_delta'); sub2?.unsubscribe(); await manager.destroy(); @@ -1427,7 +1548,7 @@ describe('GenerationJobManager Integration Tests', () => { await new Promise((resolve) => setTimeout(resolve, 200)); expect(resumeEvents.length).toBe(1); - expect(resumeEvents[0].event).toBe('on_message_delta'); + expect((resumeEvents[0] as StreamEvent).event).toBe('on_message_delta'); sub2?.unsubscribe(); await manager.destroy(); @@ -1458,7 +1579,7 @@ describe('GenerationJobManager Integration Tests', () => { await new Promise((resolve) => setTimeout(resolve, 200)); expect(sub2Events.length).toBe(1); - expect(sub2Events[0].event).toBe('on_message_delta'); + expect((sub2Events[0] as StreamEvent).event).toBe('on_message_delta'); sub2?.unsubscribe(); await manager.destroy(); @@ -1997,7 +2118,7 @@ describe('GenerationJobManager Integration Tests', () => { created: true, message: { text: 'hello' }, streamId, - } as unknown as ServerSentEvent); + } as CreatedEvent); const receivedOnA: unknown[] = []; const subA = await replicaA.subscribe(streamId, (event: unknown) => receivedOnA.push(event)); @@ -2035,7 +2156,8 @@ describe('GenerationJobManager Integration Tests', () => { await new Promise((resolve) => setTimeout(resolve, 700)); expect(receivedOnA.length).toBe(4); - expect(receivedOnB.length).toBe(3); + expect(receivedOnB.length).toBe(4); + expect((receivedOnB[0] as CreatedEvent).created).toBe(true); subA?.unsubscribe(); subB?.unsubscribe(); diff --git a/packages/api/src/types/events.ts b/packages/api/src/types/events.ts index 1e866fa840..d068888b17 100644 --- a/packages/api/src/types/events.ts +++ b/packages/api/src/types/events.ts @@ -1,4 +1,49 @@ -export type ServerSentEvent = { +/** SSE streaming event (on_run_step, on_message_delta, etc.) */ +export type StreamEvent = { + event: string; data: string | Record; - event?: string; }; + +/** Control event emitted when user message is created and generation starts */ +export type CreatedEvent = { + created: true; + message: { + messageId: string; + parentMessageId?: string; + conversationId?: string; + text?: string; + sender: string; + isCreatedByUser: boolean; + }; + streamId: string; +}; + +export type FinalMessageFields = { + messageId?: string; + parentMessageId?: string; + conversationId?: string; + text?: string; + content?: unknown[]; + sender?: string; + isCreatedByUser?: boolean; + unfinished?: boolean; + /** Per-message error flag — matches TMessage.error (boolean or error text) */ + error?: boolean | string; + [key: string]: unknown; +}; + +/** Terminal event emitted when generation completes or is aborted */ +export type FinalEvent = { + final: true; + requestMessage?: FinalMessageFields | null; + responseMessage?: FinalMessageFields | null; + conversation?: { conversationId?: string; [key: string]: unknown } | null; + title?: string; + aborted?: boolean; + earlyAbort?: boolean; + runMessages?: FinalMessageFields[]; + /** Top-level event error (abort-during-completion edge case) */ + error?: { message: string }; +}; + +export type ServerSentEvent = StreamEvent | CreatedEvent | FinalEvent; diff --git a/packages/api/src/utils/events.ts b/packages/api/src/utils/events.ts index 20c9583993..e084e631f5 100644 --- a/packages/api/src/utils/events.ts +++ b/packages/api/src/utils/events.ts @@ -2,14 +2,11 @@ import type { Response as ServerResponse } from 'express'; import type { ServerSentEvent } from '~/types'; /** - * Sends message data in Server Sent Events format. - * @param res - The server response. - * @param event - The message event. - * @param event.event - The type of event. - * @param event.data - The message to be sent. + * Sends a Server-Sent Event to the client. + * Empty-string StreamEvent data is silently dropped. */ export function sendEvent(res: ServerResponse, event: ServerSentEvent): void { - if (typeof event.data === 'string' && event.data.length === 0) { + if ('data' in event && typeof event.data === 'string' && event.data.length === 0) { return; } res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); From a0b4949a059732ee4b457be8a52d7d4015cc950b Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 17:07:55 -0400 Subject: [PATCH 26/39] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20Cover=20fu?= =?UTF-8?q?ll=20fe80::/10=20link-local=20range=20in=20IPv6=20check=20(#122?= =?UTF-8?q?44)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Cover full fe80::/10 link-local range in SSRF IPv6 check The `isPrivateIP` check used `startsWith('fe80')` which only matched fe80:: but missed fe90::–febf:: (the rest of the RFC 4291 fe80::/10 link-local block). Replace with a proper bitwise hextet check. * 🛡️ fix: Guard isIPv6LinkLocal against parseInt partial-parse on hostnames parseInt('fe90.example.com', 16) stops at the dot and returns 0xfe90, which passes the bitmask check and false-positives legitimate domains. Add colon-presence guard (IPv6 literals always contain ':') and a hex regex validation on the first hextet before parseInt. Also document why fc/fd use startsWith while fe80::/10 needs bitwise. * ✅ test: Harden IPv6 link-local SSRF tests with false-positive guards - Assert fe90/fea0/febf hostnames are NOT blocked (regression guard) - Add feb0::1 and bracket form [fe90::1] to isPrivateIP coverage - Extend resolveHostnameSSRF tests for fe90::1 and febf::1 --- packages/api/src/auth/domain.spec.ts | 25 ++++++++++++++++++++++++- packages/api/src/auth/domain.ts | 18 ++++++++++++++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index 9812960cd9..8ba72d82a2 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -177,6 +177,20 @@ describe('isSSRFTarget', () => { expect(isSSRFTarget('fd00::1')).toBe(true); expect(isSSRFTarget('fe80::1')).toBe(true); }); + + it('should block full fe80::/10 link-local range (fe80–febf)', () => { + expect(isSSRFTarget('fe90::1')).toBe(true); + expect(isSSRFTarget('fea0::1')).toBe(true); + expect(isSSRFTarget('feb0::1')).toBe(true); + expect(isSSRFTarget('febf::1')).toBe(true); + expect(isSSRFTarget('fec0::1')).toBe(false); + }); + + it('should NOT false-positive on hostnames whose first label resembles a link-local prefix', () => { + expect(isSSRFTarget('fe90.example.com')).toBe(false); + expect(isSSRFTarget('fea0.api.io')).toBe(false); + expect(isSSRFTarget('febf.service.net')).toBe(false); + }); }); describe('internal hostnames', () => { @@ -277,10 +291,17 @@ describe('isPrivateIP', () => { expect(isPrivateIP('[::1]')).toBe(true); }); - it('should detect unique local (fc/fd) and link-local (fe80)', () => { + it('should detect unique local (fc/fd) and link-local (fe80::/10)', () => { expect(isPrivateIP('fc00::1')).toBe(true); expect(isPrivateIP('fd00::1')).toBe(true); expect(isPrivateIP('fe80::1')).toBe(true); + expect(isPrivateIP('fe90::1')).toBe(true); + expect(isPrivateIP('fea0::1')).toBe(true); + expect(isPrivateIP('feb0::1')).toBe(true); + expect(isPrivateIP('febf::1')).toBe(true); + expect(isPrivateIP('[fe90::1]')).toBe(true); + expect(isPrivateIP('fec0::1')).toBe(false); + expect(isPrivateIP('fe90.example.com')).toBe(false); }); }); @@ -482,6 +503,8 @@ describe('resolveHostnameSSRF', () => { expect(await resolveHostnameSSRF('::1')).toBe(true); expect(await resolveHostnameSSRF('fc00::1')).toBe(true); expect(await resolveHostnameSSRF('fe80::1')).toBe(true); + expect(await resolveHostnameSSRF('fe90::1')).toBe(true); + expect(await resolveHostnameSSRF('febf::1')).toBe(true); expect(mockedLookup).not.toHaveBeenCalled(); }); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index 2761a80b55..37510f5e9b 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -59,6 +59,20 @@ function isPrivateIPv4(a: number, b: number, c: number): boolean { return false; } +/** Checks if a pre-normalized (lowercase, bracket-stripped) IPv6 address falls within fe80::/10 */ +function isIPv6LinkLocal(ipv6: string): boolean { + if (!ipv6.includes(':')) { + return false; + } + const firstHextet = ipv6.split(':', 1)[0]; + if (!firstHextet || !/^[0-9a-f]{1,4}$/.test(firstHextet)) { + return false; + } + const hextet = parseInt(firstHextet, 16); + // /10 mask (0xffc0) preserves top 10 bits: fe80 = 1111_1110_10xx_xxxx + return (hextet & 0xffc0) === 0xfe80; +} + /** Checks if an IPv6 address embeds a private IPv4 via 6to4, NAT64, or Teredo */ function hasPrivateEmbeddedIPv4(ipv6: string): boolean { if (!ipv6.startsWith('2002:') && !ipv6.startsWith('64:ff9b::') && !ipv6.startsWith('2001::')) { @@ -132,9 +146,9 @@ export function isPrivateIP(ip: string): boolean { if ( normalized === '::1' || normalized === '::' || - normalized.startsWith('fc') || + normalized.startsWith('fc') || // fc00::/7 — exactly prefixes 'fc' and 'fd' normalized.startsWith('fd') || - normalized.startsWith('fe80') + isIPv6LinkLocal(normalized) // fe80::/10 — spans 0xfe80–0xfebf; bitwise check required ) { return true; } From 07d0ce4ce9885281633800d7a14d7b2abab5c76f Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 17:08:43 -0400 Subject: [PATCH 27/39] =?UTF-8?q?=F0=9F=AA=A4=20fix:=20Fail-Closed=20MCP?= =?UTF-8?q?=20Domain=20Validation=20for=20Unparseable=20URLs=20(#12245)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Fail-closed MCP domain validation for unparseable URLs `isMCPDomainAllowed` returned true (allow) when `extractMCPServerDomain` could not parse the URL, treating it identically to a stdio transport. A URL containing template placeholders or invalid syntax bypassed the domain allowlist, then `processMCPEnv` resolved it to a valid—and potentially disallowed—host at connection time. Distinguish "no URL" (stdio, allowed) from "has URL but unparseable" (rejected when an allowlist is active) by checking whether `config.url` is an explicit non-empty string before falling through to the stdio path. When no allowlist is configured the guard does not fire—unparseable URLs fall through to connection-level SSRF protection via `createSSRFSafeUndiciConnect`, preserving legitimate `customUserVars` template-URL configs. * test: Expand MCP domain validation coverage for invalid/templated URLs Cover all branches of the fail-closed guard: - Invalid/templated URLs rejected when allowlist is configured - Invalid/templated URLs allowed when no allowlist (null/undefined/[]) - Whitespace-only and empty-string URLs treated as absent across all allowedDomains configurations - Stdio configs (no url property) remain allowed --- packages/api/src/auth/domain.spec.ts | 31 +++++++++++++++++++++++++++- packages/api/src/auth/domain.ts | 20 ++++++++++++++++-- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index 8ba72d82a2..76f50213db 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -1046,8 +1046,37 @@ describe('isMCPDomainAllowed', () => { }); describe('invalid URL handling', () => { - it('should allow config with invalid URL (treated as stdio)', async () => { + it('should reject invalid URL when allowlist is configured', async () => { const config = { url: 'not-a-valid-url' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(false); + }); + + it('should reject templated URL when allowlist is configured', async () => { + const config = { url: 'http://{{CUSTOM_HOST}}/mcp' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(false); + }); + + it('should allow invalid URL when no allowlist is configured (defers to connection-level SSRF)', async () => { + const config = { url: 'http://{{CUSTOM_HOST}}/mcp' }; + expect(await isMCPDomainAllowed(config, null)).toBe(true); + expect(await isMCPDomainAllowed(config, undefined)).toBe(true); + expect(await isMCPDomainAllowed(config, [])).toBe(true); + }); + + it('should allow config with whitespace-only URL (treated as absent)', async () => { + const config = { url: ' ' }; + expect(await isMCPDomainAllowed(config, [])).toBe(true); + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + expect(await isMCPDomainAllowed(config, null)).toBe(true); + }); + + it('should allow config with empty string URL (treated as absent)', async () => { + const config = { url: '' }; + expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); + }); + + it('should allow config with no url property (stdio)', async () => { + const config = { command: 'node', args: ['server.js'] }; expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true); }); }); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index 37510f5e9b..3babb09aa6 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -442,7 +442,10 @@ export async function isActionDomainAllowed( /** * Extracts full domain spec (protocol://hostname:port) from MCP server config URL. * Returns the full origin for proper protocol/port matching against allowedDomains. - * Returns null for stdio transports (no URL) or invalid URLs. + * @returns The full origin string, or null when: + * - No `url` property, non-string, or empty (stdio transport — always allowed upstream) + * - URL string present but cannot be parsed (rejected fail-closed upstream when allowlist active) + * Callers must distinguish these two null cases; see {@link isMCPDomainAllowed}. * @param config - MCP server configuration (accepts any config with optional url field) */ export function extractMCPServerDomain(config: Record): string | null { @@ -466,6 +469,11 @@ export function extractMCPServerDomain(config: Record): string * Validates MCP server domain against allowedDomains. * Supports HTTP, HTTPS, WS, and WSS protocols (per MCP specification). * Stdio transports (no URL) are always allowed. + * Configs with a non-empty URL that cannot be parsed are rejected fail-closed when an + * allowlist is active, preventing template placeholders (e.g. `{{HOST}}`) from bypassing + * domain validation after `processMCPEnv` resolves them at connection time. + * When no allowlist is configured, unparseable URLs fall through to connection-level + * SSRF protection (`createSSRFSafeUndiciConnect`). * @param config - MCP server configuration with optional url field * @param allowedDomains - List of allowed domains (with wildcard support) */ @@ -474,8 +482,16 @@ export async function isMCPDomainAllowed( allowedDomains?: string[] | null, ): Promise { const domain = extractMCPServerDomain(config); + const hasAllowlist = Array.isArray(allowedDomains) && allowedDomains.length > 0; - // Stdio transports don't have domains - always allowed + const hasExplicitUrl = + Object.hasOwn(config, 'url') && typeof config.url === 'string' && config.url.trim().length > 0; + + if (!domain && hasExplicitUrl && hasAllowlist) { + return false; + } + + // Stdio transports (no URL) are always allowed if (!domain) { return true; } From 8dc6d60750df434682a9f519ec944529c470dd7e Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 17:12:45 -0400 Subject: [PATCH 28/39] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20Enforce=20?= =?UTF-8?q?MULTI=5FCONVO=20and=20agent=20ACL=20checks=20on=20addedConvo=20?= =?UTF-8?q?(#12243)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Enforce MULTI_CONVO and agent ACL checks on addedConvo addedConvo.agent_id was passed through to loadAddedAgent without any permission check, enabling an authenticated user to load and execute another user's private agent via the parallel multi-convo feature. The middleware now chains a checkAddedConvoAccess gate after the primary agent check: when req.body.addedConvo is present it verifies the user has MULTI_CONVO:USE role permission, and when the addedConvo agent_id is a real (non-ephemeral) agent it runs the same canAccessResource ACL check used for the primary agent. * refactor: Harden addedConvo middleware and avoid duplicate agent fetch - Convert checkAddedConvoAccess to curried factory matching Express middleware signature: (requiredPermission) => (req, res, next) - Call checkPermission directly for the addedConvo agent instead of routing through canAccessResource's tempReq pattern; this avoids orphaning the resolved agent document and enables caching it on req.resolvedAddedAgent for downstream loadAddedAgent - Update loadAddedAgent to use req.resolvedAddedAgent when available, eliminating a duplicate getAgent DB call per chat request - Validate addedConvo is a plain object and agent_id is a string before passing to isEphemeralAgentId (prevents TypeError on object injection, returns 400-equivalent early exit instead of 500) - Fix JSDoc: "VIEW access" → "same permission as primary agent", add @param/@returns to helpers, restore @example on factory - Fix redundant return await in resolveAgentIdFromBody * test: Add canAccessAgentFromBody spec covering IDOR fix 26 integration tests using MongoMemoryServer with real models, ACL entries, and PermissionService — no mocks for core logic. Covered paths: - Factory validation (requiredPermission type check) - Primary agent: missing agent_id, ephemeral, non-agents endpoint - addedConvo absent / invalid shape (string, array, object injection) - MULTI_CONVO:USE gate: denied, missing role, ADMIN bypass - Agent resource ACL: no ACL → 403, insufficient bits → 403, nonexistent agent → 404, valid ACL → next + cached on req - End-to-end: both real agents, primary denied short-circuits, ephemeral primary + real addedConvo --- api/models/loadAddedAgent.js | 12 +- .../accessResources/canAccessAgentFromBody.js | 156 ++++-- .../canAccessAgentFromBody.spec.js | 509 ++++++++++++++++++ 3 files changed, 637 insertions(+), 40 deletions(-) create mode 100644 api/server/middleware/accessResources/canAccessAgentFromBody.spec.js diff --git a/api/models/loadAddedAgent.js b/api/models/loadAddedAgent.js index aa83375eae..101ee96685 100644 --- a/api/models/loadAddedAgent.js +++ b/api/models/loadAddedAgent.js @@ -48,14 +48,14 @@ const loadAddedAgent = async ({ req, conversation, primaryAgent }) => { return null; } - // If there's an agent_id, load the existing agent if (conversation.agent_id && !isEphemeralAgentId(conversation.agent_id)) { - if (!getAgent) { - throw new Error('getAgent not initialized - call setGetAgent first'); + let agent = req.resolvedAddedAgent; + if (!agent) { + if (!getAgent) { + throw new Error('getAgent not initialized - call setGetAgent first'); + } + agent = await getAgent({ id: conversation.agent_id }); } - const agent = await getAgent({ - id: conversation.agent_id, - }); if (!agent) { logger.warn(`[loadAddedAgent] Agent ${conversation.agent_id} not found`); diff --git a/api/server/middleware/accessResources/canAccessAgentFromBody.js b/api/server/middleware/accessResources/canAccessAgentFromBody.js index f8112af14d..572a86f5e5 100644 --- a/api/server/middleware/accessResources/canAccessAgentFromBody.js +++ b/api/server/middleware/accessResources/canAccessAgentFromBody.js @@ -1,42 +1,144 @@ const { logger } = require('@librechat/data-schemas'); const { Constants, + Permissions, ResourceType, + SystemRoles, + PermissionTypes, isAgentsEndpoint, isEphemeralAgentId, } = require('librechat-data-provider'); +const { checkPermission } = require('~/server/services/PermissionService'); const { canAccessResource } = require('./canAccessResource'); +const { getRoleByName } = require('~/models/Role'); const { getAgent } = require('~/models/Agent'); /** - * Agent ID resolver function for agent_id from request body - * Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId - * This is used specifically for chat routes where agent_id comes from request body - * + * Resolves custom agent ID (e.g., "agent_abc123") to a MongoDB document. * @param {string} agentCustomId - Custom agent ID from request body - * @returns {Promise} Agent document with _id field, or null if not found + * @returns {Promise} Agent document with _id field, or null if ephemeral/not found */ const resolveAgentIdFromBody = async (agentCustomId) => { - // Handle ephemeral agents - they don't need permission checks - // Real agent IDs always start with "agent_", so anything else is ephemeral if (isEphemeralAgentId(agentCustomId)) { - return null; // No permission check needed for ephemeral agents + return null; } - - return await getAgent({ id: agentCustomId }); + return getAgent({ id: agentCustomId }); }; /** - * Middleware factory that creates middleware to check agent access permissions from request body. - * This middleware is specifically designed for chat routes where the agent_id comes from req.body - * instead of route parameters. + * Creates a `canAccessResource` middleware for the given agent ID + * and chains to the provided continuation on success. + * + * @param {string} agentId - The agent's custom string ID (e.g., "agent_abc123") + * @param {number} requiredPermission - Permission bit(s) required + * @param {import('express').Request} req + * @param {import('express').Response} res - Written on deny; continuation called on allow + * @param {Function} continuation - Called when the permission check passes + * @returns {Promise} + */ +const checkAgentResourceAccess = (agentId, requiredPermission, req, res, continuation) => { + const middleware = canAccessResource({ + resourceType: ResourceType.AGENT, + requiredPermission, + resourceIdParam: 'agent_id', + idResolver: () => resolveAgentIdFromBody(agentId), + }); + + const tempReq = { + ...req, + params: { ...req.params, agent_id: agentId }, + }; + + return middleware(tempReq, res, continuation); +}; + +/** + * Middleware factory that validates MULTI_CONVO:USE role permission and, when + * addedConvo.agent_id is a non-ephemeral agent, the same resource-level permission + * required for the primary agent (`requiredPermission`). Caches the resolved agent + * document on `req.resolvedAddedAgent` to avoid a duplicate DB fetch in `loadAddedAgent`. + * + * @param {number} requiredPermission - Permission bit(s) to check on the added agent resource + * @returns {(req: import('express').Request, res: import('express').Response, next: Function) => Promise} + */ +const checkAddedConvoAccess = (requiredPermission) => async (req, res, next) => { + const addedConvo = req.body?.addedConvo; + if (!addedConvo || typeof addedConvo !== 'object' || Array.isArray(addedConvo)) { + return next(); + } + + try { + if (!req.user?.role) { + return res.status(403).json({ + error: 'Forbidden', + message: 'Insufficient permissions for multi-conversation', + }); + } + + if (req.user.role !== SystemRoles.ADMIN) { + const role = await getRoleByName(req.user.role); + const hasMultiConvo = role?.permissions?.[PermissionTypes.MULTI_CONVO]?.[Permissions.USE]; + if (!hasMultiConvo) { + return res.status(403).json({ + error: 'Forbidden', + message: 'Multi-conversation feature is not enabled', + }); + } + } + + const addedAgentId = addedConvo.agent_id; + if (!addedAgentId || typeof addedAgentId !== 'string' || isEphemeralAgentId(addedAgentId)) { + return next(); + } + + if (req.user.role === SystemRoles.ADMIN) { + return next(); + } + + const agent = await resolveAgentIdFromBody(addedAgentId); + if (!agent) { + return res.status(404).json({ + error: 'Not Found', + message: `${ResourceType.AGENT} not found`, + }); + } + + const hasPermission = await checkPermission({ + userId: req.user.id, + role: req.user.role, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + requiredPermission, + }); + + if (!hasPermission) { + return res.status(403).json({ + error: 'Forbidden', + message: `Insufficient permissions to access this ${ResourceType.AGENT}`, + }); + } + + req.resolvedAddedAgent = agent; + return next(); + } catch (error) { + logger.error('Failed to validate addedConvo access permissions', error); + return res.status(500).json({ + error: 'Internal Server Error', + message: 'Failed to validate addedConvo access permissions', + }); + } +}; + +/** + * Middleware factory that checks agent access permissions from request body. + * Validates both the primary agent_id and, when present, addedConvo.agent_id + * (which also requires MULTI_CONVO:USE role permission). * * @param {Object} options - Configuration options * @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share) * @returns {Function} Express middleware function * * @example - * // Basic usage for agent chat (requires VIEW permission) * router.post('/chat', * canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }), * buildEndpointOption, @@ -46,11 +148,12 @@ const resolveAgentIdFromBody = async (agentCustomId) => { const canAccessAgentFromBody = (options) => { const { requiredPermission } = options; - // Validate required options if (!requiredPermission || typeof requiredPermission !== 'number') { throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number'); } + const addedConvoMiddleware = checkAddedConvoAccess(requiredPermission); + return async (req, res, next) => { try { const { endpoint, agent_id } = req.body; @@ -67,28 +170,13 @@ const canAccessAgentFromBody = (options) => { }); } - // Skip permission checks for ephemeral agents - // Real agent IDs always start with "agent_", so anything else is ephemeral + const afterPrimaryCheck = () => addedConvoMiddleware(req, res, next); + if (isEphemeralAgentId(agentId)) { - return next(); + return afterPrimaryCheck(); } - const agentAccessMiddleware = canAccessResource({ - resourceType: ResourceType.AGENT, - requiredPermission, - resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver - idResolver: () => resolveAgentIdFromBody(agentId), - }); - - const tempReq = { - ...req, - params: { - ...req.params, - agent_id: agentId, - }, - }; - - return agentAccessMiddleware(tempReq, res, next); + return checkAgentResourceAccess(agentId, requiredPermission, req, res, afterPrimaryCheck); } catch (error) { logger.error('Failed to validate agent access permissions', error); return res.status(500).json({ diff --git a/api/server/middleware/accessResources/canAccessAgentFromBody.spec.js b/api/server/middleware/accessResources/canAccessAgentFromBody.spec.js new file mode 100644 index 0000000000..47f1130d13 --- /dev/null +++ b/api/server/middleware/accessResources/canAccessAgentFromBody.spec.js @@ -0,0 +1,509 @@ +const mongoose = require('mongoose'); +const { + ResourceType, + SystemRoles, + PrincipalType, + PrincipalModel, +} = require('librechat-data-provider'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { canAccessAgentFromBody } = require('./canAccessAgentFromBody'); +const { User, Role, AclEntry } = require('~/db/models'); +const { createAgent } = require('~/models/Agent'); + +describe('canAccessAgentFromBody middleware', () => { + let mongoServer; + let req, res, next; + let testUser, otherUser; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await mongoose.connection.dropDatabase(); + + await Role.create({ + name: 'test-role', + permissions: { + AGENTS: { USE: true, CREATE: true, SHARE: true }, + MULTI_CONVO: { USE: true }, + }, + }); + + await Role.create({ + name: 'no-multi-convo', + permissions: { + AGENTS: { USE: true, CREATE: true, SHARE: true }, + MULTI_CONVO: { USE: false }, + }, + }); + + await Role.create({ + name: SystemRoles.ADMIN, + permissions: { + AGENTS: { USE: true, CREATE: true, SHARE: true }, + MULTI_CONVO: { USE: true }, + }, + }); + + testUser = await User.create({ + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + role: 'test-role', + }); + + otherUser = await User.create({ + email: 'other@example.com', + name: 'Other User', + username: 'otheruser', + role: 'test-role', + }); + + req = { + user: { id: testUser._id, role: testUser.role }, + params: {}, + body: { + endpoint: 'agents', + agent_id: 'ephemeral_primary', + }, + }; + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + next = jest.fn(); + + jest.clearAllMocks(); + }); + + describe('middleware factory', () => { + test('throws if requiredPermission is missing', () => { + expect(() => canAccessAgentFromBody({})).toThrow( + 'canAccessAgentFromBody: requiredPermission is required and must be a number', + ); + }); + + test('throws if requiredPermission is not a number', () => { + expect(() => canAccessAgentFromBody({ requiredPermission: '1' })).toThrow( + 'canAccessAgentFromBody: requiredPermission is required and must be a number', + ); + }); + + test('returns a middleware function', () => { + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + expect(typeof middleware).toBe('function'); + expect(middleware.length).toBe(3); + }); + }); + + describe('primary agent checks', () => { + test('returns 400 when agent_id is missing on agents endpoint', async () => { + req.body.agent_id = undefined; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(400); + }); + + test('proceeds for ephemeral primary agent without addedConvo', async () => { + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + + test('proceeds for non-agents endpoint (ephemeral fallback)', async () => { + req.body.endpoint = 'openAI'; + req.body.agent_id = undefined; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + }); + + describe('addedConvo — absent or invalid shape', () => { + test('calls next when addedConvo is absent', async () => { + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when addedConvo is a string', async () => { + req.body.addedConvo = 'not-an-object'; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when addedConvo is an array', async () => { + req.body.addedConvo = [{ agent_id: 'agent_something' }]; + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + }); + + describe('addedConvo — MULTI_CONVO permission gate', () => { + test('returns 403 when user lacks MULTI_CONVO:USE', async () => { + req.user.role = 'no-multi-convo'; + req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ message: 'Multi-conversation feature is not enabled' }), + ); + }); + + test('returns 403 when user.role is missing', async () => { + req.user = { id: testUser._id }; + req.body.addedConvo = { agent_id: 'agent_x', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('ADMIN bypasses MULTI_CONVO check', async () => { + req.user.role = SystemRoles.ADMIN; + req.body.addedConvo = { agent_id: 'ephemeral_x', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + }); + + describe('addedConvo — agent_id shape validation', () => { + test('calls next when agent_id is ephemeral', async () => { + req.body.addedConvo = { agent_id: 'ephemeral_xyz', endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when agent_id is absent', async () => { + req.body.addedConvo = { endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + + test('calls next when agent_id is not a string (object injection)', async () => { + req.body.addedConvo = { agent_id: { $gt: '' }, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + }); + }); + + describe('addedConvo — agent resource ACL (IDOR prevention)', () => { + let addedAgent; + + beforeEach(async () => { + addedAgent = await createAgent({ + id: `agent_added_${Date.now()}`, + name: 'Private Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + }); + + test('returns 403 when requester has no ACL for the added agent', async () => { + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + message: 'Insufficient permissions to access this agent', + }), + ); + }); + + test('returns 404 when added agent does not exist', async () => { + req.body.addedConvo = { + agent_id: 'agent_nonexistent_999', + endpoint: 'agents', + model: 'gpt-4', + }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(404); + }); + + test('proceeds when requester has ACL for the added agent', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + + test('denies when ACL permission bits are insufficient', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 2 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('caches resolved agent on req.resolvedAddedAgent', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(req.resolvedAddedAgent).toBeDefined(); + expect(req.resolvedAddedAgent._id.toString()).toBe(addedAgent._id.toString()); + }); + + test('ADMIN bypasses agent resource ACL for addedConvo', async () => { + req.user.role = SystemRoles.ADMIN; + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + expect(req.resolvedAddedAgent).toBeUndefined(); + }); + }); + + describe('end-to-end: primary real agent + addedConvo real agent', () => { + let primaryAgent, addedAgent; + + beforeEach(async () => { + primaryAgent = await createAgent({ + id: `agent_primary_${Date.now()}`, + name: 'Primary Agent', + provider: 'openai', + model: 'gpt-4', + author: testUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: primaryAgent._id, + permBits: 15, + grantedBy: testUser._id, + }); + + addedAgent = await createAgent({ + id: `agent_added_${Date.now()}`, + name: 'Added Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + + req.body.agent_id = primaryAgent.id; + }); + + test('both checks pass when user has ACL for both agents', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + expect(req.resolvedAddedAgent).toBeDefined(); + }); + + test('primary passes but addedConvo denied → 403', async () => { + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('primary denied → 403 without reaching addedConvo check', async () => { + const foreignAgent = await createAgent({ + id: `agent_foreign_${Date.now()}`, + name: 'Foreign Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: foreignAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + + req.body.agent_id = foreignAgent.id; + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + }); + + describe('ephemeral primary + real addedConvo agent', () => { + let addedAgent; + + beforeEach(async () => { + addedAgent = await createAgent({ + id: `agent_added_${Date.now()}`, + name: 'Added Agent', + provider: 'openai', + model: 'gpt-4', + author: otherUser._id, + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: otherUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 15, + grantedBy: otherUser._id, + }); + }); + + test('runs full addedConvo ACL check even when primary is ephemeral', async () => { + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(403); + }); + + test('proceeds when user has ACL for added agent (ephemeral primary)', async () => { + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: addedAgent._id, + permBits: 1, + grantedBy: otherUser._id, + }); + + req.body.addedConvo = { agent_id: addedAgent.id, endpoint: 'agents', model: 'gpt-4' }; + + const middleware = canAccessAgentFromBody({ requiredPermission: 1 }); + await middleware(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(res.status).not.toHaveBeenCalled(); + }); + }); +}); From 1312cd757c3e2634d59a73fb96f30dcd4c0d17e3 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 18:05:08 -0400 Subject: [PATCH 29/39] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20fix:=20Validate?= =?UTF-8?q?=20User-provided=20URLs=20for=20Web=20Search=20(#12247)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: SSRF-validate user-provided URLs in web search auth User-controlled URL fields (jinaApiUrl, firecrawlApiUrl, searxngInstanceUrl) flow from plugin auth into outbound HTTP requests without validation. Reuse existing isSSRFTarget/resolveHostnameSSRF to block private/internal targets while preserving admin-configured (env var) internal URLs. * 🛡️ fix: Harden web search SSRF validation - Reject non-HTTP(S) schemes (file://, ftp://, etc.) in isSSRFUrl - Conditional write: only assign to authResult after SSRF check passes - Move isUserProvided tracking after SSRF gate to avoid false positives - Add authenticated assertions for optional-field SSRF blocks in tests - Add file:// scheme rejection test - Wrap process.env mutation in try/finally guard - Add JSDoc + sync-obligation comment on WEB_SEARCH_URL_KEYS * 🛡️ fix: Correct auth-type reporting for SSRF-stripped optional URLs SSRF-stripped optional URL fields no longer pollute isUserProvided. Track whether the field actually contributed to authResult before crediting it as user-provided, so categories report SYSTEM_DEFINED when all surviving values match env vars. --- packages/api/src/web/web.spec.ts | 360 +++++++++++++++++++++++++++++++ packages/api/src/web/web.ts | 50 ++++- 2 files changed, 408 insertions(+), 2 deletions(-) diff --git a/packages/api/src/web/web.spec.ts b/packages/api/src/web/web.spec.ts index c7bb3f4962..74e02b20ef 100644 --- a/packages/api/src/web/web.spec.ts +++ b/packages/api/src/web/web.spec.ts @@ -18,6 +18,14 @@ jest.mock('../utils', () => ({ }, })); +const mockIsSSRFTarget = jest.fn().mockReturnValue(false); +const mockResolveHostnameSSRF = jest.fn().mockResolvedValue(false); + +jest.mock('../auth', () => ({ + isSSRFTarget: (...args: unknown[]) => mockIsSSRFTarget(...args), + resolveHostnameSSRF: (...args: unknown[]) => mockResolveHostnameSSRF(...args), +})); + describe('web.ts', () => { describe('extractWebSearchEnvVars', () => { it('should return empty array if config is undefined', () => { @@ -1227,4 +1235,356 @@ describe('web.ts', () => { expect(result.authResult.firecrawlOptions).toBeUndefined(); // Should be undefined }); }); + + describe('SSRF protection for user-provided URLs', () => { + const userId = 'test-user-id'; + let mockLoadAuthValues: jest.Mock; + + beforeEach(() => { + jest.clearAllMocks(); + mockLoadAuthValues = jest.fn(); + mockIsSSRFTarget.mockReturnValue(false); + mockResolveHostnameSSRF.mockResolvedValue(false); + }); + + it('should block user-provided jinaApiUrl targeting localhost', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === 'localhost'); + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'JINA_API_URL') { + result[field] = 'http://localhost:8080/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.jinaApiUrl).toBeUndefined(); + expect(mockIsSSRFTarget).toHaveBeenCalledWith('localhost'); + }); + + it('should block user-provided firecrawlApiUrl resolving to private IP', async () => { + mockResolveHostnameSSRF.mockImplementation((hostname: string) => + Promise.resolve(hostname === 'evil.internal-service.com'), + ); + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'https://evil.internal-service.com/scrape'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + const scrapersAuth = result.authTypes.find(([c]) => c === 'scrapers')?.[1]; + expect(scrapersAuth).toBe(AuthType.USER_PROVIDED); + }); + + it('should block user-provided searxngInstanceUrl targeting metadata endpoint', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === '169.254.169.254'); + + const webSearchConfig: TCustomConfig['webSearch'] = { + searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + searchProvider: 'searxng' as SearchProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'SEARXNG_INSTANCE_URL') { + result[field] = 'http://169.254.169.254/latest/meta-data'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.searxngInstanceUrl).toBeUndefined(); + expect(result.authenticated).toBe(false); + }); + + it('should allow system-defined URLs even if they match SSRF patterns', async () => { + mockIsSSRFTarget.mockReturnValue(true); + + const originalEnv = process.env; + try { + process.env = { + ...originalEnv, + JINA_API_KEY: 'system-jina-key', + JINA_API_URL: 'http://jina-internal:8080/rerank', + }; + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'JINA_API_KEY') { + result[field] = 'system-jina-key'; + } else if (field === 'JINA_API_URL') { + result[field] = 'http://jina-internal:8080/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.jinaApiUrl).toBe('http://jina-internal:8080/rerank'); + expect(result.authenticated).toBe(true); + } finally { + process.env = originalEnv; + } + }); + + it('should reject URLs with invalid format', async () => { + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'not-a-valid-url'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + const scrapersAuth = result.authTypes.find(([c]) => c === 'scrapers')?.[1]; + expect(scrapersAuth).toBe(AuthType.USER_PROVIDED); + }); + + it('should reject non-HTTP schemes like file://', async () => { + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'file:///etc/passwd'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + }); + + it('should allow legitimate external URLs', async () => { + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + firecrawlApiUrl: '${FIRECRAWL_API_URL}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + scraperProvider: 'firecrawl' as ScraperProviders, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'FIRECRAWL_API_URL') { + result[field] = 'https://api.firecrawl.dev'; + } else if (field === 'JINA_API_URL') { + result[field] = 'https://api.jina.ai/v1/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.firecrawlApiUrl).toBe('https://api.firecrawl.dev'); + expect(result.authResult.jinaApiUrl).toBe('https://api.jina.ai/v1/rerank'); + expect(result.authenticated).toBe(true); + }); + + it('should fail required URL field and mark category unauthenticated', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === '127.0.0.1'); + + const webSearchConfig: TCustomConfig['webSearch'] = { + searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}', + searxngApiKey: '${SEARXNG_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + safeSearch: SafeSearchTypes.MODERATE, + searchProvider: 'searxng' as SearchProviders, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'SEARXNG_INSTANCE_URL') { + result[field] = 'http://127.0.0.1:8888/search'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authenticated).toBe(false); + const providersAuthType = result.authTypes.find( + ([category]) => category === 'providers', + )?.[1]; + expect(providersAuthType).toBe(AuthType.USER_PROVIDED); + }); + + it('should report SYSTEM_DEFINED when only user-provided field is a stripped SSRF URL', async () => { + mockIsSSRFTarget.mockImplementation((hostname: string) => hostname === 'localhost'); + + const originalEnv = process.env; + try { + process.env = { + ...originalEnv, + JINA_API_KEY: 'system-jina-key', + }; + + const webSearchConfig: TCustomConfig['webSearch'] = { + serperApiKey: '${SERPER_API_KEY}', + firecrawlApiKey: '${FIRECRAWL_API_KEY}', + jinaApiKey: '${JINA_API_KEY}', + jinaApiUrl: '${JINA_API_URL}', + safeSearch: SafeSearchTypes.MODERATE, + rerankerType: 'jina' as RerankerTypes, + }; + + mockLoadAuthValues.mockImplementation(({ authFields }) => { + const result: Record = {}; + authFields.forEach((field: string) => { + if (field === 'JINA_API_KEY') { + result[field] = 'system-jina-key'; + } else if (field === 'JINA_API_URL') { + result[field] = 'http://localhost:9999/rerank'; + } else { + result[field] = 'test-api-key'; + } + }); + return Promise.resolve(result); + }); + + const result = await loadWebSearchAuth({ + userId, + webSearchConfig, + loadAuthValues: mockLoadAuthValues, + }); + + expect(result.authResult.jinaApiUrl).toBeUndefined(); + expect(result.authenticated).toBe(true); + const rerankersAuth = result.authTypes.find(([c]) => c === 'rerankers')?.[1]; + expect(rerankersAuth).toBe(AuthType.SYSTEM_DEFINED); + } finally { + process.env = originalEnv; + } + }); + }); }); diff --git a/packages/api/src/web/web.ts b/packages/api/src/web/web.ts index ad172e187f..cc0d8688ca 100644 --- a/packages/api/src/web/web.ts +++ b/packages/api/src/web/web.ts @@ -13,6 +13,37 @@ import type { TWebSearchConfig, } from 'librechat-data-provider'; import type { TWebSearchKeys, TWebSearchCategories } from '@librechat/data-schemas'; +import { isSSRFTarget, resolveHostnameSSRF } from '../auth'; + +/** + * URL-type keys in TWebSearchKeys (not API keys or version strings). + * Must stay in sync with URL-typed fields in webSearchAuth (packages/data-schemas). + */ +const WEB_SEARCH_URL_KEYS = new Set([ + 'searxngInstanceUrl', + 'firecrawlApiUrl', + 'jinaApiUrl', +]); + +/** + * Returns true if the URL should be blocked for SSRF risk. + * Fail-closed: unparseable URLs and non-HTTP(S) schemes return true. + */ +async function isSSRFUrl(url: string): Promise { + let parsed: URL; + try { + parsed = new URL(url); + } catch { + return true; + } + if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') { + return true; + } + if (isSSRFTarget(parsed.hostname)) { + return true; + } + return resolveHostnameSSRF(parsed.hostname); +} export function extractWebSearchEnvVars({ keys, @@ -149,12 +180,27 @@ export async function loadWebSearchAuth({ const field = allAuthFields[j]; const value = authValues[field]; const originalKey = allKeys[j]; - if (originalKey) authResult[originalKey] = value; + if (!optionalSet.has(field) && !value) { allFieldsAuthenticated = false; break; } - if (!isUserProvided && process.env[field] !== value) { + + const isFieldUserProvided = value != null && process.env[field] !== value; + const isUrlKey = originalKey != null && WEB_SEARCH_URL_KEYS.has(originalKey); + let contributed = false; + + if (isUrlKey && isFieldUserProvided && (await isSSRFUrl(value))) { + if (!optionalSet.has(field)) { + allFieldsAuthenticated = false; + break; + } + } else if (originalKey) { + authResult[originalKey] = value; + contributed = true; + } + + if (!isUserProvided && isFieldUserProvided && contributed) { isUserProvided = true; } } From bcf45519bd5c10740ec38cf37f74ee7d1bac287a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 18:08:57 -0400 Subject: [PATCH 30/39] =?UTF-8?q?=F0=9F=AA=AA=20fix:=20Enforce=20VIEW=20AC?= =?UTF-8?q?L=20on=20Agent=20Edge=20References=20at=20Write=20and=20Runtime?= =?UTF-8?q?=20(#12246)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Enforce ACL checks on handoff edge and added-convo agent loading Edge-linked agents and added-convo agents were fetched by ID via getAgent without verifying the requesting user's access permissions. This allowed an authenticated user to reference another user's private agent in edges or addedConvo and have it initialized at runtime. Add checkPermission(VIEW) gate in processAgent before initializing any handoff agent, and in processAddedConvo for non-ephemeral added agents. Unauthorized agents are logged and added to skippedAgentIds so orphaned-edge filtering removes them cleanly. * 🛡️ fix: Validate edge agent access at agent create/update time Reject agent create/update requests that reference agents in edges the requesting user cannot VIEW. This provides early feedback and prevents storing unauthorized agent references as defense-in-depth alongside the runtime ACL gate in processAgent. Add collectEdgeAgentIds utility to extract all unique agent IDs from an edge array, and validateEdgeAgentAccess helper in the v1 handler. * 🧪 test: Improve ACL gate test coverage and correctness - Add processAgent ACL gate tests for initializeClient (skip/allow handoff agents) - Fix addedConvo.spec.js to mock loadAddedAgent directly instead of getAgent - Seed permMap with ownedAgent VIEW bits in v1.spec.js update-403 test * 🧹 chore: Remove redundant addedConvo ACL gate (now in middleware) PR #12243 moved the addedConvo agent ACL check upstream into canAccessAgentFromBody middleware, making the runtime check in processAddedConvo and its spec redundant. * 🧪 test: Rewrite processAgent ACL test with real DB and minimal mocking Replace heavy mock-based test (12 mocks, Providers.XAI crash) with MongoMemoryServer-backed integration test that exercises real getAgent, checkPermission, and AclEntry — only external I/O (initializeAgent, ToolService, AgentClient) remains mocked. Load edge utilities directly from packages/api/src/agents/edges to sidestep the config.ts barrel. * 🧪 fix: Use requireActual spread for @librechat/agents and @librechat/api mocks The Providers.XAI crash was caused by mocking @librechat/agents with a minimal replacement object, breaking the @librechat/api initialization chain. Match the established pattern from client.test.js and recordCollectedUsage.spec.js: spread jest.requireActual for both packages, overriding only the functions under test. --- api/server/controllers/agents/v1.js | 63 +++++- api/server/controllers/agents/v1.spec.js | 113 +++++++++- .../services/Endpoints/agents/initialize.js | 19 ++ .../Endpoints/agents/initialize.spec.js | 201 ++++++++++++++++++ packages/api/src/agents/edges.spec.ts | 51 ++++- packages/api/src/agents/edges.ts | 14 ++ 6 files changed, 457 insertions(+), 4 deletions(-) create mode 100644 api/server/services/Endpoints/agents/initialize.spec.js diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 1abba8b2c8..dbb97df24b 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -6,6 +6,7 @@ const { agentCreateSchema, agentUpdateSchema, refreshListAvatars, + collectEdgeAgentIds, mergeAgentOcrConversion, MAX_AVATAR_REFRESH_AGENTS, convertOcrToContextInPlace, @@ -35,6 +36,7 @@ const { } = require('~/models/Agent'); const { findPubliclyAccessibleResources, + getResourcePermissionsMap, findAccessibleResources, hasPublicPermission, grantPermission, @@ -58,6 +60,44 @@ const systemTools = { const MAX_SEARCH_LEN = 100; const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); +/** + * Validates that the requesting user has VIEW access to every agent referenced in edges. + * Agents that do not exist in the database are skipped — at create time, the `from` field + * often references the agent being built, which has no DB record yet. + * @param {import('librechat-data-provider').GraphEdge[]} edges + * @param {string} userId + * @param {string} userRole - Used for group/role principal resolution + * @returns {Promise} Agent IDs the user cannot VIEW (empty if all accessible) + */ +const validateEdgeAgentAccess = async (edges, userId, userRole) => { + const edgeAgentIds = collectEdgeAgentIds(edges); + if (edgeAgentIds.size === 0) { + return []; + } + + const agents = (await Promise.all([...edgeAgentIds].map((id) => getAgent({ id })))).filter( + Boolean, + ); + + if (agents.length === 0) { + return []; + } + + const permissionsMap = await getResourcePermissionsMap({ + userId, + role: userRole, + resourceType: ResourceType.AGENT, + resourceIds: agents.map((a) => a._id), + }); + + return agents + .filter((a) => { + const bits = permissionsMap.get(a._id.toString()) ?? 0; + return (bits & PermissionBits.VIEW) === 0; + }) + .map((a) => a.id); +}; + /** * Creates an Agent. * @route POST /Agents @@ -75,7 +115,17 @@ const createAgentHandler = async (req, res) => { agentData.model_parameters = removeNullishValues(agentData.model_parameters, true); } - const { id: userId } = req.user; + const { id: userId, role: userRole } = req.user; + + if (agentData.edges?.length) { + const unauthorized = await validateEdgeAgentAccess(agentData.edges, userId, userRole); + if (unauthorized.length > 0) { + return res.status(403).json({ + error: 'You do not have access to one or more agents referenced in edges', + agent_ids: unauthorized, + }); + } + } agentData.id = `agent_${nanoid()}`; agentData.author = userId; @@ -243,6 +293,17 @@ const updateAgentHandler = async (req, res) => { updateData.avatar = avatarField; } + if (updateData.edges?.length) { + const { id: userId, role: userRole } = req.user; + const unauthorized = await validateEdgeAgentAccess(updateData.edges, userId, userRole); + if (unauthorized.length > 0) { + return res.status(403).json({ + error: 'You do not have access to one or more agents referenced in edges', + agent_ids: unauthorized, + }); + } + } + // Convert OCR to context in incoming updateData convertOcrToContextInPlace(updateData); diff --git a/api/server/controllers/agents/v1.spec.js b/api/server/controllers/agents/v1.spec.js index ce68cc241f..ede4ea416a 100644 --- a/api/server/controllers/agents/v1.spec.js +++ b/api/server/controllers/agents/v1.spec.js @@ -2,7 +2,7 @@ const mongoose = require('mongoose'); const { nanoid } = require('nanoid'); const { v4: uuidv4 } = require('uuid'); const { agentSchema } = require('@librechat/data-schemas'); -const { FileSources } = require('librechat-data-provider'); +const { FileSources, PermissionBits } = require('librechat-data-provider'); const { MongoMemoryServer } = require('mongodb-memory-server'); // Only mock the dependencies that are not database-related @@ -46,9 +46,9 @@ jest.mock('~/models/File', () => ({ jest.mock('~/server/services/PermissionService', () => ({ findAccessibleResources: jest.fn().mockResolvedValue([]), findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]), + getResourcePermissionsMap: jest.fn().mockResolvedValue(new Map()), grantPermission: jest.fn(), hasPublicPermission: jest.fn().mockResolvedValue(false), - checkPermission: jest.fn().mockResolvedValue(true), })); jest.mock('~/models', () => ({ @@ -74,6 +74,7 @@ const { const { findAccessibleResources, findPubliclyAccessibleResources, + getResourcePermissionsMap, } = require('~/server/services/PermissionService'); const { refreshS3Url } = require('~/server/services/Files/S3/crud'); @@ -1647,4 +1648,112 @@ describe('Agent Controllers - Mass Assignment Protection', () => { expect(agent.avatar.filepath).toBe('old-s3-path.jpg'); }); }); + + describe('Edge ACL validation', () => { + let targetAgent; + + beforeEach(async () => { + targetAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: new mongoose.Types.ObjectId().toString(), + name: 'Target Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + }); + + test('createAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => { + const permMap = new Map(); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.body = { + name: 'Attacker Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.agent_ids).toContain(targetAgent.id); + }); + + test('createAgentHandler should succeed when user has VIEW on all edge-referenced agents', async () => { + const permMap = new Map([[targetAgent._id.toString(), 1]]); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.body = { + name: 'Legit Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + }); + + test('createAgentHandler should allow edges referencing non-existent agents (self-reference at create time)', async () => { + mockReq.body = { + name: 'Self-Ref Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'agent_does_not_exist_yet', to: 'agent_also_new', edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + }); + + test('updateAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => { + const ownedAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: mockReq.user.id, + name: 'Owned Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + + const permMap = new Map([[ownedAgent._id.toString(), PermissionBits.VIEW]]); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.params = { id: ownedAgent.id }; + mockReq.body = { + edges: [{ from: ownedAgent.id, to: targetAgent.id, edgeType: 'handoff' }], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.agent_ids).toContain(targetAgent.id); + expect(response.agent_ids).not.toContain(ownedAgent.id); + }); + + test('updateAgentHandler should succeed when edges field is absent from payload', async () => { + const ownedAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: mockReq.user.id, + name: 'Owned Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + + mockReq.params = { id: ownedAgent.id }; + mockReq.body = { name: 'Renamed Agent' }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.name).toBe('Renamed Agent'); + }); + }); }); diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index e71270ef85..44583e6dbc 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -10,6 +10,8 @@ const { createSequentialChainEdges, } = require('@librechat/api'); const { + ResourceType, + PermissionBits, EModelEndpoint, isAgentsEndpoint, getResponseSender, @@ -21,6 +23,7 @@ const { } = require('~/server/controllers/agents/callbacks'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { getModelsConfig } = require('~/server/controllers/ModelController'); +const { checkPermission } = require('~/server/services/PermissionService'); const AgentClient = require('~/server/controllers/agents/client'); const { getConvoFiles } = require('~/models/Conversation'); const { processAddedConvo } = require('./addedConvo'); @@ -229,6 +232,22 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { return null; } + const hasAccess = await checkPermission({ + userId: req.user.id, + role: req.user.role, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + requiredPermission: PermissionBits.VIEW, + }); + + if (!hasAccess) { + logger.warn( + `[processAgent] User ${req.user.id} lacks VIEW access to handoff agent ${agentId}, skipping`, + ); + skippedAgentIds.add(agentId); + return null; + } + const validationResult = await validateAgentModel({ req, res, diff --git a/api/server/services/Endpoints/agents/initialize.spec.js b/api/server/services/Endpoints/agents/initialize.spec.js new file mode 100644 index 0000000000..16b41aca65 --- /dev/null +++ b/api/server/services/Endpoints/agents/initialize.spec.js @@ -0,0 +1,201 @@ +const mongoose = require('mongoose'); +const { + ResourceType, + PermissionBits, + PrincipalType, + PrincipalModel, +} = require('librechat-data-provider'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +const mockInitializeAgent = jest.fn(); +const mockValidateAgentModel = jest.fn(); + +jest.mock('@librechat/agents', () => ({ + ...jest.requireActual('@librechat/agents'), + createContentAggregator: jest.fn(() => ({ + contentParts: [], + aggregateContent: jest.fn(), + })), +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + initializeAgent: (...args) => mockInitializeAgent(...args), + validateAgentModel: (...args) => mockValidateAgentModel(...args), + GenerationJobManager: { setCollectedUsage: jest.fn() }, + getCustomEndpointConfig: jest.fn(), + createSequentialChainEdges: jest.fn(), +})); + +jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn(() => jest.fn()), + getDefaultHandlers: jest.fn(() => ({})), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn(), + loadToolsForExecution: jest.fn(), +})); + +jest.mock('~/server/controllers/ModelController', () => ({ + getModelsConfig: jest.fn().mockResolvedValue({}), +})); + +let agentClientArgs; +jest.mock('~/server/controllers/agents/client', () => { + return jest.fn().mockImplementation((args) => { + agentClientArgs = args; + return {}; + }); +}); + +jest.mock('./addedConvo', () => ({ + processAddedConvo: jest.fn().mockResolvedValue({ userMCPAuthMap: undefined }), +})); + +jest.mock('~/cache', () => ({ + logViolation: jest.fn(), +})); + +const { initializeClient } = require('./initialize'); +const { createAgent } = require('~/models/Agent'); +const { User, AclEntry } = require('~/db/models'); + +const PRIMARY_ID = 'agent_primary'; +const TARGET_ID = 'agent_target'; +const AUTHORIZED_ID = 'agent_authorized'; + +describe('initializeClient — processAgent ACL gate', () => { + let mongoServer; + let testUser; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await mongoose.connection.dropDatabase(); + jest.clearAllMocks(); + agentClientArgs = undefined; + + testUser = await User.create({ + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + role: 'USER', + }); + + mockValidateAgentModel.mockResolvedValue({ isValid: true }); + }); + + const makeReq = () => ({ + user: { id: testUser._id.toString(), role: 'USER' }, + body: { conversationId: 'conv_1', files: [] }, + config: { endpoints: {} }, + _resumableStreamId: null, + }); + + const makeEndpointOption = () => ({ + agent: Promise.resolve({ + id: PRIMARY_ID, + name: 'Primary', + provider: 'openai', + model: 'gpt-4', + tools: [], + }), + model_parameters: { model: 'gpt-4' }, + endpoint: 'agents', + }); + + const makePrimaryConfig = (edges) => ({ + id: PRIMARY_ID, + endpoint: 'agents', + edges, + toolDefinitions: [], + toolRegistry: new Map(), + userMCPAuthMap: null, + tool_resources: {}, + resendFiles: true, + maxContextTokens: 4096, + }); + + it('should skip handoff agent and filter its edge when user lacks VIEW access', async () => { + await createAgent({ + id: TARGET_ID, + name: 'Target Agent', + provider: 'openai', + model: 'gpt-4', + author: new mongoose.Types.ObjectId(), + tools: [], + }); + + const edges = [{ from: PRIMARY_ID, to: TARGET_ID, edgeType: 'handoff' }]; + mockInitializeAgent.mockResolvedValue(makePrimaryConfig(edges)); + + await initializeClient({ + req: makeReq(), + res: {}, + signal: new AbortController().signal, + endpointOption: makeEndpointOption(), + }); + + expect(mockInitializeAgent).toHaveBeenCalledTimes(1); + expect(agentClientArgs.agent.edges).toEqual([]); + }); + + it('should initialize handoff agent and keep its edge when user has VIEW access', async () => { + const authorizedAgent = await createAgent({ + id: AUTHORIZED_ID, + name: 'Authorized Agent', + provider: 'openai', + model: 'gpt-4', + author: new mongoose.Types.ObjectId(), + tools: [], + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: authorizedAgent._id, + permBits: PermissionBits.VIEW, + grantedBy: testUser._id, + }); + + const edges = [{ from: PRIMARY_ID, to: AUTHORIZED_ID, edgeType: 'handoff' }]; + const handoffConfig = { + id: AUTHORIZED_ID, + edges: [], + toolDefinitions: [], + toolRegistry: new Map(), + userMCPAuthMap: null, + tool_resources: {}, + }; + + let callCount = 0; + mockInitializeAgent.mockImplementation(() => { + callCount++; + return callCount === 1 + ? Promise.resolve(makePrimaryConfig(edges)) + : Promise.resolve(handoffConfig); + }); + + await initializeClient({ + req: makeReq(), + res: {}, + signal: new AbortController().signal, + endpointOption: makeEndpointOption(), + }); + + expect(mockInitializeAgent).toHaveBeenCalledTimes(2); + expect(agentClientArgs.agent.edges).toHaveLength(1); + expect(agentClientArgs.agent.edges[0].to).toBe(AUTHORIZED_ID); + }); +}); diff --git a/packages/api/src/agents/edges.spec.ts b/packages/api/src/agents/edges.spec.ts index 1b30a202d0..b23f00f63f 100644 --- a/packages/api/src/agents/edges.spec.ts +++ b/packages/api/src/agents/edges.spec.ts @@ -1,5 +1,11 @@ import type { GraphEdge } from 'librechat-data-provider'; -import { getEdgeKey, getEdgeParticipants, filterOrphanedEdges, createEdgeCollector } from './edges'; +import { + getEdgeKey, + getEdgeParticipants, + collectEdgeAgentIds, + filterOrphanedEdges, + createEdgeCollector, +} from './edges'; describe('edges utilities', () => { describe('getEdgeKey', () => { @@ -70,6 +76,49 @@ describe('edges utilities', () => { }); }); + describe('collectEdgeAgentIds', () => { + it('should return empty set for undefined input', () => { + expect(collectEdgeAgentIds(undefined)).toEqual(new Set()); + }); + + it('should return empty set for empty array', () => { + expect(collectEdgeAgentIds([])).toEqual(new Set()); + }); + + it('should collect IDs from simple string from/to', () => { + const edges: GraphEdge[] = [{ from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }]; + expect(collectEdgeAgentIds(edges)).toEqual(new Set(['agent_a', 'agent_b'])); + }); + + it('should collect IDs from array from/to values', () => { + const edges: GraphEdge[] = [ + { from: ['agent_a', 'agent_b'], to: ['agent_c', 'agent_d'], edgeType: 'handoff' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual( + new Set(['agent_a', 'agent_b', 'agent_c', 'agent_d']), + ); + }); + + it('should deduplicate IDs across edges', () => { + const edges: GraphEdge[] = [ + { from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }, + { from: 'agent_b', to: 'agent_c', edgeType: 'handoff' }, + { from: 'agent_a', to: 'agent_c', edgeType: 'direct' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual(new Set(['agent_a', 'agent_b', 'agent_c'])); + }); + + it('should handle mixed scalar and array edges', () => { + const edges: GraphEdge[] = [ + { from: 'agent_a', to: ['agent_b', 'agent_c'], edgeType: 'handoff' }, + { from: ['agent_c', 'agent_d'], to: 'agent_e', edgeType: 'direct' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual( + new Set(['agent_a', 'agent_b', 'agent_c', 'agent_d', 'agent_e']), + ); + }); + }); + describe('filterOrphanedEdges', () => { const edges: GraphEdge[] = [ { from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }, diff --git a/packages/api/src/agents/edges.ts b/packages/api/src/agents/edges.ts index 4d2883d165..9a36105b74 100644 --- a/packages/api/src/agents/edges.ts +++ b/packages/api/src/agents/edges.ts @@ -43,6 +43,20 @@ export function filterOrphanedEdges(edges: GraphEdge[], skippedAgentIds: Set { + const ids = new Set(); + if (!edges || edges.length === 0) { + return ids; + } + for (const edge of edges) { + for (const id of getEdgeParticipants(edge)) { + ids.add(id); + } + } + return ids; +} + /** * Result of discovering and aggregating edges from connected agents. */ From f9927f01687f40eb0b77b9c98308a9dc1b05898c Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 18:40:42 -0400 Subject: [PATCH 31/39] =?UTF-8?q?=F0=9F=93=91=20fix:=20Sanitize=20Markdown?= =?UTF-8?q?=20Artifacts=20(#12249)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Sanitize markdown artifact rendering to prevent stored XSS Replace marked-react with react-markdown + remark-gfm for artifact markdown preview. react-markdown's skipHtml strips raw HTML tags, and a urlTransform guard blocks javascript: and data: protocol links. * fix: Update useArtifactProps test to expect react-markdown dependencies * fix: Harden markdown artifact sanitization - Convert isSafeUrl from denylist to allowlist (http, https, mailto, tel plus relative/anchor URLs); unknown protocols are now fail-closed - Add remark-breaks to restore single-newline-to-
behavior that was silently dropped when replacing marked-react - Export isSafeUrl from the host module and add 16 direct unit tests covering allowed protocols, blocked schemes (javascript, data, blob, vbscript, file, custom), edge cases (empty, whitespace, mixed case) - Hoist remarkPlugins to a module-level constant to avoid per-render array allocation in the generated Sandpack component - Fix import order in generated template (shortest to longest per AGENTS.md) and remove pre-existing trailing whitespace * fix: Return null for blocked URLs, add sync-guard comments and test - urlTransform returns null (not '') for blocked URLs so react-markdown omits the href/src attribute entirely instead of producing - Hoist urlTransform to module-level constant alongside remarkPlugins - Add JSDoc sync-guard comments tying the exported isSafeUrl to its template-string mirror, so future maintainers know to update both - Add synchronization test asserting the embedded isSafeUrl contains the same allowlist set, URL parsing, and relative-path checks as the export --- .../__tests__/useArtifactProps.test.ts | 6 +- client/src/utils/__tests__/markdown.test.ts | 96 +++++++++++++++++-- client/src/utils/artifacts.ts | 4 +- client/src/utils/markdown.ts | 55 ++++++++++- 4 files changed, 148 insertions(+), 13 deletions(-) diff --git a/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts b/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts index f9f29e0c56..e46a285c50 100644 --- a/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts +++ b/client/src/hooks/Artifacts/__tests__/useArtifactProps.test.ts @@ -112,7 +112,7 @@ describe('useArtifactProps', () => { expect(result.current.files['content.md']).toBe('# No content provided'); }); - it('should provide marked-react dependency', () => { + it('should provide react-markdown dependency', () => { const artifact = createArtifact({ type: 'text/markdown', content: '# Test', @@ -120,7 +120,9 @@ describe('useArtifactProps', () => { const { result } = renderHook(() => useArtifactProps({ artifact })); - expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('marked-react'); + expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('react-markdown'); + expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('remark-gfm'); + expect(result.current.sharedProps.customSetup?.dependencies).toHaveProperty('remark-breaks'); }); it('should update files when content changes', () => { diff --git a/client/src/utils/__tests__/markdown.test.ts b/client/src/utils/__tests__/markdown.test.ts index fcc0f169e6..9734e0e18a 100644 --- a/client/src/utils/__tests__/markdown.test.ts +++ b/client/src/utils/__tests__/markdown.test.ts @@ -1,4 +1,72 @@ -import { getMarkdownFiles } from '../markdown'; +import { isSafeUrl, getMarkdownFiles } from '../markdown'; + +describe('isSafeUrl', () => { + it('allows https URLs', () => { + expect(isSafeUrl('https://example.com')).toBe(true); + }); + + it('allows http URLs', () => { + expect(isSafeUrl('http://example.com/path')).toBe(true); + }); + + it('allows mailto links', () => { + expect(isSafeUrl('mailto:user@example.com')).toBe(true); + }); + + it('allows tel links', () => { + expect(isSafeUrl('tel:+1234567890')).toBe(true); + }); + + it('allows relative paths', () => { + expect(isSafeUrl('/path/to/page')).toBe(true); + expect(isSafeUrl('./relative')).toBe(true); + expect(isSafeUrl('../parent')).toBe(true); + }); + + it('allows anchor links', () => { + expect(isSafeUrl('#section')).toBe(true); + }); + + it('blocks javascript: protocol', () => { + expect(isSafeUrl('javascript:alert(1)')).toBe(false); + }); + + it('blocks javascript: with leading whitespace', () => { + expect(isSafeUrl(' javascript:alert(1)')).toBe(false); + }); + + it('blocks javascript: with mixed case', () => { + expect(isSafeUrl('JavaScript:alert(1)')).toBe(false); + }); + + it('blocks data: protocol', () => { + expect(isSafeUrl('data:text/html,x')).toBe(false); + }); + + it('blocks blob: protocol', () => { + expect(isSafeUrl('blob:http://example.com/uuid')).toBe(false); + }); + + it('blocks vbscript: protocol', () => { + expect(isSafeUrl('vbscript:MsgBox("xss")')).toBe(false); + }); + + it('blocks file: protocol', () => { + expect(isSafeUrl('file:///etc/passwd')).toBe(false); + }); + + it('blocks empty strings', () => { + expect(isSafeUrl('')).toBe(false); + }); + + it('blocks whitespace-only strings', () => { + expect(isSafeUrl(' ')).toBe(false); + }); + + it('blocks unknown/custom protocols', () => { + expect(isSafeUrl('custom:payload')).toBe(false); + }); +}); describe('markdown artifacts', () => { describe('getMarkdownFiles', () => { @@ -41,7 +109,7 @@ describe('markdown artifacts', () => { const markdown = '# Test'; const files = getMarkdownFiles(markdown); - expect(files['/components/ui/MarkdownRenderer.tsx']).toContain('import Markdown from'); + expect(files['/components/ui/MarkdownRenderer.tsx']).toContain('import ReactMarkdown from'); expect(files['/components/ui/MarkdownRenderer.tsx']).toContain('MarkdownRendererProps'); expect(files['/components/ui/MarkdownRenderer.tsx']).toContain( 'export default MarkdownRenderer', @@ -162,13 +230,29 @@ describe('markdown artifacts', () => { }); describe('markdown component structure', () => { - it('should generate a MarkdownRenderer component that uses marked-react', () => { + it('should generate a MarkdownRenderer component with safe markdown rendering', () => { const files = getMarkdownFiles('# Test'); const rendererCode = files['/components/ui/MarkdownRenderer.tsx']; - // Verify the component imports and uses Markdown from marked-react - expect(rendererCode).toContain("import Markdown from 'marked-react'"); - expect(rendererCode).toContain('{content}'); + expect(rendererCode).toContain("import ReactMarkdown from 'react-markdown'"); + expect(rendererCode).toContain("import remarkBreaks from 'remark-breaks'"); + expect(rendererCode).toContain('skipHtml={true}'); + expect(rendererCode).toContain('SAFE_PROTOCOLS'); + expect(rendererCode).toContain('isSafeUrl'); + expect(rendererCode).toContain('urlTransform={urlTransform}'); + expect(rendererCode).toContain('remarkPlugins={remarkPlugins}'); + expect(rendererCode).toContain('isSafeUrl(url) ? url : null'); + }); + + it('should embed isSafeUrl logic matching the exported version', () => { + const files = getMarkdownFiles('# Test'); + const rendererCode = files['/components/ui/MarkdownRenderer.tsx']; + + expect(rendererCode).toContain("new Set(['http:', 'https:', 'mailto:', 'tel:'])"); + expect(rendererCode).toContain('new URL(trimmed).protocol'); + expect(rendererCode).toContain("trimmed.startsWith('/')"); + expect(rendererCode).toContain("trimmed.startsWith('#')"); + expect(rendererCode).toContain("trimmed.startsWith('.')"); }); it('should pass markdown content to the Markdown component', () => { diff --git a/client/src/utils/artifacts.ts b/client/src/utils/artifacts.ts index 13f3a23b47..e862d18a40 100644 --- a/client/src/utils/artifacts.ts +++ b/client/src/utils/artifacts.ts @@ -108,7 +108,9 @@ const mermaidDependencies = { }; const markdownDependencies = { - 'marked-react': '^2.0.0', + 'remark-gfm': '^4.0.0', + 'remark-breaks': '^4.0.0', + 'react-markdown': '^9.0.1', }; const dependenciesMap: Record< diff --git a/client/src/utils/markdown.ts b/client/src/utils/markdown.ts index 12556c1a24..24d5105863 100644 --- a/client/src/utils/markdown.ts +++ b/client/src/utils/markdown.ts @@ -1,23 +1,70 @@ import dedent from 'dedent'; -const markdownRenderer = dedent(`import React, { useEffect, useState } from 'react'; -import Markdown from 'marked-react'; +const SAFE_PROTOCOLS = new Set(['http:', 'https:', 'mailto:', 'tel:']); + +/** + * Allowlist-based URL validator for markdown artifact rendering. + * Mirrored verbatim in the markdownRenderer template string below — + * any logic change MUST be applied to both copies. + */ +export const isSafeUrl = (url: string): boolean => { + const trimmed = url.trim(); + if (!trimmed) { + return false; + } + if (trimmed.startsWith('/') || trimmed.startsWith('#') || trimmed.startsWith('.')) { + return true; + } + try { + return SAFE_PROTOCOLS.has(new URL(trimmed).protocol); + } catch { + return false; + } +}; + +const markdownRenderer = dedent(`import React from 'react'; +import remarkGfm from 'remark-gfm'; +import remarkBreaks from 'remark-breaks'; +import ReactMarkdown from 'react-markdown'; interface MarkdownRendererProps { content: string; } +/** Mirror of the exported isSafeUrl in markdown.ts — keep in sync. */ +const SAFE_PROTOCOLS = new Set(['http:', 'https:', 'mailto:', 'tel:']); + +const isSafeUrl = (url: string): boolean => { + const trimmed = url.trim(); + if (!trimmed) return false; + if (trimmed.startsWith('/') || trimmed.startsWith('#') || trimmed.startsWith('.')) return true; + try { + return SAFE_PROTOCOLS.has(new URL(trimmed).protocol); + } catch { + return false; + } +}; + +const remarkPlugins = [remarkGfm, remarkBreaks]; +const urlTransform = (url: string) => (isSafeUrl(url) ? url : null); + const MarkdownRenderer: React.FC = ({ content }) => { return (
- {content} + + {content} +
); }; From f7ab5e645ad25ff36463a820d791d277afec985a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 18:41:59 -0400 Subject: [PATCH 32/39] =?UTF-8?q?=F0=9F=AB=B7=20fix:=20Validate=20User-Pro?= =?UTF-8?q?vided=20Base=20URL=20in=20Endpoint=20Init=20(#12248)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Block SSRF via user-provided baseURL in endpoint initialization User-provided baseURL values (when endpoint is configured with `user_provided`) were passed through to the OpenAI SDK without validation. Combined with `directEndpoint`, this allowed arbitrary server-side requests to internal/metadata URLs. Adds `validateEndpointURL` that checks against known SSRF targets and DNS-resolves hostnames to block private IPs. Applied in both custom and OpenAI endpoint initialization paths. * 🧪 test: Add validateEndpointURL SSRF tests Covers unparseable URLs, localhost, private IPs, link-local/metadata, internal Docker/K8s hostnames, DNS resolution to private IPs, and legitimate public URLs. * 🛡️ fix: Add protocol enforcement and import order fix - Reject non-HTTP/HTTPS schemes (ftp://, file://, data:, etc.) in validateEndpointURL before SSRF hostname checks - Document DNS rebinding limitation and fail-open semantics in JSDoc - Fix import order in custom/initialize.ts per project conventions * 🧪 test: Expand SSRF validation coverage and add initializer integration tests Unit tests for validateEndpointURL: - Non-HTTP/HTTPS schemes (ftp, file, data) - IPv6 loopback, link-local, and unique-local addresses - .local and .internal TLD hostnames - DNS fail-open path (lookup failure allows request) Integration tests for initializeCustom and initializeOpenAI: - Guard fires when userProvidesURL is true - Guard skipped when URL is system-defined or falsy - SSRF rejection propagates and prevents getOpenAIConfig call * 🐛 fix: Correct broken env restore in OpenAI initialize spec process.env was captured by reference, not by value, making the restore closure a no-op. Snapshot individual env keys before mutation so they can be properly restored after each test. * 🛡️ fix: Throw structured ErrorTypes for SSRF base URL validation Replace plain-string Error throws in validateEndpointURL with JSON-structured errors using type 'invalid_base_url' (matching new ErrorTypes.INVALID_BASE_URL enum value). This ensures the client-side Error component can look up a localized message instead of falling through to the raw-text default. Changes across workspaces: - data-provider: add INVALID_BASE_URL to ErrorTypes enum - packages/api: throwInvalidBaseURL helper emits structured JSON - client: add errorMessages entry and localization key - tests: add structured JSON format assertion * 🧹 refactor: Use ErrorTypes enum key in Error.tsx for consistency Replace bare string literal 'invalid_base_url' with computed property [ErrorTypes.INVALID_BASE_URL] to match every other entry in the errorMessages map. --- .../src/components/Messages/Content/Error.tsx | 1 + client/src/locales/en/translation.json | 1 + packages/api/src/auth/domain.spec.ts | 133 +++++++++++++++++ packages/api/src/auth/domain.ts | 42 ++++++ .../src/endpoints/custom/initialize.spec.ts | 119 +++++++++++++++ .../api/src/endpoints/custom/initialize.ts | 7 +- .../src/endpoints/openai/initialize.spec.ts | 135 ++++++++++++++++++ .../api/src/endpoints/openai/initialize.ts | 5 + packages/data-provider/src/config.ts | 4 + 9 files changed, 446 insertions(+), 1 deletion(-) create mode 100644 packages/api/src/endpoints/custom/initialize.spec.ts create mode 100644 packages/api/src/endpoints/openai/initialize.spec.ts diff --git a/client/src/components/Messages/Content/Error.tsx b/client/src/components/Messages/Content/Error.tsx index 469e29fe32..ff2f2d7e90 100644 --- a/client/src/components/Messages/Content/Error.tsx +++ b/client/src/components/Messages/Content/Error.tsx @@ -41,6 +41,7 @@ const errorMessages = { [ErrorTypes.NO_USER_KEY]: 'com_error_no_user_key', [ErrorTypes.INVALID_USER_KEY]: 'com_error_invalid_user_key', [ErrorTypes.NO_BASE_URL]: 'com_error_no_base_url', + [ErrorTypes.INVALID_BASE_URL]: 'com_error_invalid_base_url', [ErrorTypes.INVALID_ACTION]: `com_error_${ErrorTypes.INVALID_ACTION}`, [ErrorTypes.INVALID_REQUEST]: `com_error_${ErrorTypes.INVALID_REQUEST}`, [ErrorTypes.REFUSAL]: 'com_error_refusal', diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index f45cdd5f8c..36d882c6a2 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -372,6 +372,7 @@ "com_error_missing_model": "No model selected for {{0}}. Please select a model and try again.", "com_error_models_not_loaded": "Models configuration could not be loaded. Please refresh the page and try again.", "com_error_moderation": "It appears that the content submitted has been flagged by our moderation system for not aligning with our community guidelines. We're unable to proceed with this specific topic. If you have any other questions or topics you'd like to explore, please edit your message, or create a new conversation.", + "com_error_invalid_base_url": "The base URL you provided targets a restricted address. Please use a valid external URL and try again.", "com_error_no_base_url": "No base URL found. Please provide one and try again.", "com_error_no_user_key": "No key found. Please provide a key and try again.", "com_error_refusal": "Response refused by safety filters. Rewrite your message and try again. If you encounter this frequently while using Claude Sonnet 4.5 or Opus 4.1, you can try Sonnet 4, which has different usage restrictions.", diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index 76f50213db..a7140528a9 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -12,6 +12,7 @@ import { isPrivateIP, isSSRFTarget, resolveHostnameSSRF, + validateEndpointURL, } from './domain'; const mockedLookup = lookup as jest.MockedFunction; @@ -1209,3 +1210,135 @@ describe('isMCPDomainAllowed', () => { }); }); }); + +describe('validateEndpointURL', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should throw for unparseable URLs', async () => { + await expect(validateEndpointURL('not-a-url', 'test-ep')).rejects.toThrow( + 'Invalid base URL for test-ep', + ); + }); + + it('should throw for localhost URLs', async () => { + await expect(validateEndpointURL('http://localhost:8080/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for private IP URLs', async () => { + await expect(validateEndpointURL('http://192.168.1.1/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + await expect(validateEndpointURL('http://10.0.0.1/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + await expect(validateEndpointURL('http://172.16.0.1/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for link-local / metadata IP', async () => { + await expect( + validateEndpointURL('http://169.254.169.254/latest/meta-data/', 'test-ep'), + ).rejects.toThrow('targets a restricted address'); + }); + + it('should throw for loopback IP', async () => { + await expect(validateEndpointURL('http://127.0.0.1:11434/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for internal Docker/Kubernetes hostnames', async () => { + await expect(validateEndpointURL('http://redis:6379/', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + await expect(validateEndpointURL('http://mongodb:27017/', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw when hostname DNS-resolves to a private IP', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '10.0.0.5', family: 4 }] as never); + await expect(validateEndpointURL('https://evil.example.com/v1', 'test-ep')).rejects.toThrow( + 'resolves to a restricted address', + ); + }); + + it('should allow public URLs', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '104.18.7.192', family: 4 }] as never); + await expect( + validateEndpointURL('https://api.openai.com/v1', 'test-ep'), + ).resolves.toBeUndefined(); + }); + + it('should allow public URLs that resolve to public IPs', async () => { + mockedLookup.mockResolvedValueOnce([{ address: '8.8.8.8', family: 4 }] as never); + await expect( + validateEndpointURL('https://api.example.com/v1/chat', 'test-ep'), + ).resolves.toBeUndefined(); + }); + + it('should throw for non-HTTP/HTTPS schemes', async () => { + await expect(validateEndpointURL('ftp://example.com/v1', 'test-ep')).rejects.toThrow( + 'only HTTP and HTTPS are permitted', + ); + await expect(validateEndpointURL('file:///etc/passwd', 'test-ep')).rejects.toThrow( + 'only HTTP and HTTPS are permitted', + ); + await expect(validateEndpointURL('data:text/plain,hello', 'test-ep')).rejects.toThrow( + 'only HTTP and HTTPS are permitted', + ); + }); + + it('should throw for IPv6 loopback URL', async () => { + await expect(validateEndpointURL('http://[::1]:8080/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for IPv6 link-local URL', async () => { + await expect(validateEndpointURL('http://[fe80::1]/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for IPv6 unique-local URL', async () => { + await expect(validateEndpointURL('http://[fc00::1]/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for .local TLD hostname', async () => { + await expect(validateEndpointURL('http://myservice.local/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should throw for .internal TLD hostname', async () => { + await expect(validateEndpointURL('http://api.internal/v1', 'test-ep')).rejects.toThrow( + 'targets a restricted address', + ); + }); + + it('should pass when DNS lookup fails (fail-open)', async () => { + mockedLookup.mockRejectedValueOnce(new Error('ENOTFOUND')); + await expect( + validateEndpointURL('https://nonexistent.example.com/v1', 'test-ep'), + ).resolves.toBeUndefined(); + }); + + it('should throw structured JSON with type invalid_base_url', async () => { + const error = await validateEndpointURL('http://169.254.169.254/latest/', 'my-ep').catch( + (err: Error) => err, + ); + expect(error).toBeInstanceOf(Error); + const parsed = JSON.parse((error as Error).message); + expect(parsed.type).toBe('invalid_base_url'); + expect(parsed.message).toContain('my-ep'); + expect(parsed.message).toContain('targets a restricted address'); + }); +}); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index 3babb09aa6..fabe2502ff 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -499,3 +499,45 @@ export async function isMCPDomainAllowed( // Use MCP_PROTOCOLS (HTTP/HTTPS/WS/WSS) for MCP server validation return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS); } + +/** Matches ErrorTypes.INVALID_BASE_URL — string literal avoids build-time dependency on data-provider */ +const INVALID_BASE_URL_TYPE = 'invalid_base_url'; + +function throwInvalidBaseURL(message: string): never { + throw new Error(JSON.stringify({ type: INVALID_BASE_URL_TYPE, message })); +} + +/** + * Validates that a user-provided endpoint URL does not target private/internal addresses. + * Throws if the URL is unparseable, uses a non-HTTP(S) scheme, targets a known SSRF hostname, + * or DNS-resolves to a private IP. + * + * @note DNS rebinding: validation performs a single DNS lookup. An adversary controlling + * DNS with TTL=0 could respond with a public IP at validation time and a private IP + * at request time. This is an accepted limitation of point-in-time DNS checks. + * @note Fail-open on DNS errors: a resolution failure here implies a failure at request + * time as well, matching {@link resolveHostnameSSRF} semantics. + */ +export async function validateEndpointURL(url: string, endpoint: string): Promise { + let hostname: string; + let protocol: string; + try { + const parsed = new URL(url); + hostname = parsed.hostname; + protocol = parsed.protocol; + } catch { + throwInvalidBaseURL(`Invalid base URL for ${endpoint}: unable to parse URL.`); + } + + if (protocol !== 'http:' && protocol !== 'https:') { + throwInvalidBaseURL(`Invalid base URL for ${endpoint}: only HTTP and HTTPS are permitted.`); + } + + if (isSSRFTarget(hostname)) { + throwInvalidBaseURL(`Base URL for ${endpoint} targets a restricted address.`); + } + + if (await resolveHostnameSSRF(hostname)) { + throwInvalidBaseURL(`Base URL for ${endpoint} resolves to a restricted address.`); + } +} diff --git a/packages/api/src/endpoints/custom/initialize.spec.ts b/packages/api/src/endpoints/custom/initialize.spec.ts new file mode 100644 index 0000000000..911e17c446 --- /dev/null +++ b/packages/api/src/endpoints/custom/initialize.spec.ts @@ -0,0 +1,119 @@ +import { AuthType } from 'librechat-data-provider'; +import type { BaseInitializeParams } from '~/types'; + +const mockValidateEndpointURL = jest.fn(); +jest.mock('~/auth', () => ({ + validateEndpointURL: (...args: unknown[]) => mockValidateEndpointURL(...args), +})); + +const mockGetOpenAIConfig = jest.fn().mockReturnValue({ + llmConfig: { model: 'test-model' }, + configOptions: {}, +}); +jest.mock('~/endpoints/openai/config', () => ({ + getOpenAIConfig: (...args: unknown[]) => mockGetOpenAIConfig(...args), +})); + +jest.mock('~/endpoints/models', () => ({ + fetchModels: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + standardCache: jest.fn(() => ({ get: jest.fn().mockResolvedValue(null) })), +})); + +jest.mock('~/utils', () => ({ + isUserProvided: (val: string) => val === 'user_provided', + checkUserKeyExpiry: jest.fn(), +})); + +const mockGetCustomEndpointConfig = jest.fn(); +jest.mock('~/app/config', () => ({ + getCustomEndpointConfig: (...args: unknown[]) => mockGetCustomEndpointConfig(...args), +})); + +import { initializeCustom } from './initialize'; + +function createParams(overrides: { + apiKey?: string; + baseURL?: string; + userBaseURL?: string; + userApiKey?: string; + expiresAt?: string; +}): BaseInitializeParams { + const { apiKey = 'sk-test-key', baseURL = 'https://api.example.com/v1' } = overrides; + + mockGetCustomEndpointConfig.mockReturnValue({ + apiKey, + baseURL, + models: {}, + }); + + const db = { + getUserKeyValues: jest.fn().mockResolvedValue({ + apiKey: overrides.userApiKey ?? 'sk-user-key', + baseURL: overrides.userBaseURL ?? 'https://user-api.example.com/v1', + }), + } as unknown as BaseInitializeParams['db']; + + return { + req: { + user: { id: 'user-1' }, + body: { key: overrides.expiresAt ?? '2099-01-01' }, + config: {}, + } as unknown as BaseInitializeParams['req'], + endpoint: 'test-custom', + model_parameters: { model: 'gpt-4' }, + db, + }; +} + +describe('initializeCustom – SSRF guard wiring', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should call validateEndpointURL when baseURL is user_provided', async () => { + const params = createParams({ + apiKey: 'sk-test-key', + baseURL: AuthType.USER_PROVIDED, + userBaseURL: 'https://user-api.example.com/v1', + expiresAt: '2099-01-01', + }); + + await initializeCustom(params); + + expect(mockValidateEndpointURL).toHaveBeenCalledTimes(1); + expect(mockValidateEndpointURL).toHaveBeenCalledWith( + 'https://user-api.example.com/v1', + 'test-custom', + ); + }); + + it('should NOT call validateEndpointURL when baseURL is system-defined', async () => { + const params = createParams({ + apiKey: 'sk-test-key', + baseURL: 'https://api.provider.com/v1', + }); + + await initializeCustom(params); + + expect(mockValidateEndpointURL).not.toHaveBeenCalled(); + }); + + it('should propagate SSRF rejection from validateEndpointURL', async () => { + mockValidateEndpointURL.mockRejectedValueOnce( + new Error('Base URL for test-custom targets a restricted address.'), + ); + + const params = createParams({ + apiKey: 'sk-test-key', + baseURL: AuthType.USER_PROVIDED, + userBaseURL: 'http://169.254.169.254/latest/meta-data/', + expiresAt: '2099-01-01', + }); + + await expect(initializeCustom(params)).rejects.toThrow('targets a restricted address'); + expect(mockGetOpenAIConfig).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/api/src/endpoints/custom/initialize.ts b/packages/api/src/endpoints/custom/initialize.ts index 7930b1c12f..15b6b873c7 100644 --- a/packages/api/src/endpoints/custom/initialize.ts +++ b/packages/api/src/endpoints/custom/initialize.ts @@ -9,9 +9,10 @@ import type { TEndpoint } from 'librechat-data-provider'; import type { AppConfig } from '@librechat/data-schemas'; import type { BaseInitializeParams, InitializeResultBase, EndpointTokenConfig } from '~/types'; import { getOpenAIConfig } from '~/endpoints/openai/config'; +import { isUserProvided, checkUserKeyExpiry } from '~/utils'; import { getCustomEndpointConfig } from '~/app/config'; import { fetchModels } from '~/endpoints/models'; -import { isUserProvided, checkUserKeyExpiry } from '~/utils'; +import { validateEndpointURL } from '~/auth'; import { standardCache } from '~/cache'; const { PROXY } = process.env; @@ -123,6 +124,10 @@ export async function initializeCustom({ throw new Error(`${endpoint} Base URL not provided.`); } + if (userProvidesURL) { + await validateEndpointURL(baseURL, endpoint); + } + let endpointTokenConfig: EndpointTokenConfig | undefined; const userId = req.user?.id ?? ''; diff --git a/packages/api/src/endpoints/openai/initialize.spec.ts b/packages/api/src/endpoints/openai/initialize.spec.ts new file mode 100644 index 0000000000..ae91571fb3 --- /dev/null +++ b/packages/api/src/endpoints/openai/initialize.spec.ts @@ -0,0 +1,135 @@ +import { AuthType, EModelEndpoint } from 'librechat-data-provider'; +import type { BaseInitializeParams } from '~/types'; + +const mockValidateEndpointURL = jest.fn(); +jest.mock('~/auth', () => ({ + validateEndpointURL: (...args: unknown[]) => mockValidateEndpointURL(...args), +})); + +const mockGetOpenAIConfig = jest.fn().mockReturnValue({ + llmConfig: { model: 'gpt-4' }, + configOptions: {}, +}); +jest.mock('./config', () => ({ + getOpenAIConfig: (...args: unknown[]) => mockGetOpenAIConfig(...args), +})); + +jest.mock('~/utils', () => ({ + getAzureCredentials: jest.fn(), + resolveHeaders: jest.fn(() => ({})), + isUserProvided: (val: string) => val === 'user_provided', + checkUserKeyExpiry: jest.fn(), +})); + +import { initializeOpenAI } from './initialize'; + +function createParams(env: Record): BaseInitializeParams { + const savedEnv: Record = {}; + for (const key of Object.keys(env)) { + savedEnv[key] = process.env[key]; + } + Object.assign(process.env, env); + + const db = { + getUserKeyValues: jest.fn().mockResolvedValue({ + apiKey: 'sk-user-key', + baseURL: 'https://user-proxy.example.com/v1', + }), + } as unknown as BaseInitializeParams['db']; + + const params: BaseInitializeParams = { + req: { + user: { id: 'user-1' }, + body: { key: '2099-01-01' }, + config: { endpoints: {} }, + } as unknown as BaseInitializeParams['req'], + endpoint: EModelEndpoint.openAI, + model_parameters: { model: 'gpt-4' }, + db, + }; + + const restore = () => { + for (const key of Object.keys(env)) { + if (savedEnv[key] === undefined) { + delete process.env[key]; + } else { + process.env[key] = savedEnv[key]; + } + } + }; + + return Object.assign(params, { _restore: restore }); +} + +describe('initializeOpenAI – SSRF guard wiring', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should call validateEndpointURL when OPENAI_REVERSE_PROXY is user_provided', async () => { + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + OPENAI_REVERSE_PROXY: AuthType.USER_PROVIDED, + }); + + try { + await initializeOpenAI(params); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockValidateEndpointURL).toHaveBeenCalledTimes(1); + expect(mockValidateEndpointURL).toHaveBeenCalledWith( + 'https://user-proxy.example.com/v1', + EModelEndpoint.openAI, + ); + }); + + it('should NOT call validateEndpointURL when OPENAI_REVERSE_PROXY is a system URL', async () => { + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + OPENAI_REVERSE_PROXY: 'https://api.openai.com/v1', + }); + + try { + await initializeOpenAI(params); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockValidateEndpointURL).not.toHaveBeenCalled(); + }); + + it('should NOT call validateEndpointURL when baseURL is falsy', async () => { + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + }); + + try { + await initializeOpenAI(params); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockValidateEndpointURL).not.toHaveBeenCalled(); + }); + + it('should propagate SSRF rejection from validateEndpointURL', async () => { + mockValidateEndpointURL.mockRejectedValueOnce( + new Error('Base URL for openAI targets a restricted address.'), + ); + + const params = createParams({ + OPENAI_API_KEY: 'sk-test', + OPENAI_REVERSE_PROXY: AuthType.USER_PROVIDED, + }); + + try { + await expect(initializeOpenAI(params)).rejects.toThrow('targets a restricted address'); + } finally { + (params as unknown as { _restore: () => void })._restore(); + } + + expect(mockGetOpenAIConfig).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/api/src/endpoints/openai/initialize.ts b/packages/api/src/endpoints/openai/initialize.ts index 33ce233d34..a6ad6df895 100644 --- a/packages/api/src/endpoints/openai/initialize.ts +++ b/packages/api/src/endpoints/openai/initialize.ts @@ -6,6 +6,7 @@ import type { UserKeyValues, } from '~/types'; import { getAzureCredentials, resolveHeaders, isUserProvided, checkUserKeyExpiry } from '~/utils'; +import { validateEndpointURL } from '~/auth'; import { getOpenAIConfig } from './config'; /** @@ -55,6 +56,10 @@ export async function initializeOpenAI({ ? userValues?.baseURL : baseURLOptions[endpoint as keyof typeof baseURLOptions]; + if (userProvidesURL && baseURL) { + await validateEndpointURL(baseURL, endpoint); + } + const clientOptions: OpenAIConfigOptions = { proxy: PROXY ?? undefined, reverseProxyUrl: baseURL || undefined, diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index e13521c019..bb0c180209 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -1560,6 +1560,10 @@ export enum ErrorTypes { * No Base URL Provided. */ NO_BASE_URL = 'no_base_url', + /** + * Base URL targets a restricted or invalid address (SSRF protection). + */ + INVALID_BASE_URL = 'invalid_base_url', /** * Moderation error */ From ad08df4db682b2865f54fd0e77d4706ba9eaf843 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 18:54:34 -0400 Subject: [PATCH 33/39] =?UTF-8?q?=F0=9F=94=8F=20fix:=20Scope=20Agent-Autho?= =?UTF-8?q?r=20File=20Access=20to=20Attached=20Files=20Only=20(#12251)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Scope agent-author file access to attached files only The hasAccessToFilesViaAgent helper short-circuited for agent authors, granting access to all requested file IDs without verifying they were attached to the agent's tool_resources. This enabled an IDOR where any agent author could delete arbitrary files by supplying their agent_id alongside unrelated file IDs. Now both the author and non-author paths check file IDs against the agent's tool_resources before granting access. * chore: Use Object.values/for...of and add JSDoc in getAttachedFileIds * test: Add boundary cases for agent file access authorization - Agent with no tool_resources denies all access (fail-closed) - Files across multiple resource types are all reachable - Author + isDelete: true still scopes to attached files only --- api/models/File.spec.js | 121 +++++++++++++++++++++-- api/server/services/Files/permissions.js | 45 +++++---- 2 files changed, 141 insertions(+), 25 deletions(-) diff --git a/api/models/File.spec.js b/api/models/File.spec.js index 2d4282cff7..ecb2e21b08 100644 --- a/api/models/File.spec.js +++ b/api/models/File.spec.js @@ -152,12 +152,11 @@ describe('File Access Control', () => { expect(accessMap.get(fileIds[3])).toBe(false); }); - it('should grant access to all files when user is the agent author', async () => { + it('should only grant author access to files attached to the agent', async () => { const authorId = new mongoose.Types.ObjectId(); const agentId = uuidv4(); const fileIds = [uuidv4(), uuidv4(), uuidv4()]; - // Create author user await User.create({ _id: authorId, email: 'author@example.com', @@ -165,7 +164,6 @@ describe('File Access Control', () => { provider: 'local', }); - // Create agent await createAgent({ id: agentId, name: 'Test Agent', @@ -174,12 +172,83 @@ describe('File Access Control', () => { provider: 'openai', tool_resources: { file_search: { - file_ids: [fileIds[0]], // Only one file attached + file_ids: [fileIds[0]], + }, + }, + }); + + const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); + const accessMap = await hasAccessToFilesViaAgent({ + userId: authorId, + role: SystemRoles.USER, + fileIds, + agentId, + }); + + expect(accessMap.get(fileIds[0])).toBe(true); + expect(accessMap.get(fileIds[1])).toBe(false); + expect(accessMap.get(fileIds[2])).toBe(false); + }); + + it('should deny all access when agent has no tool_resources', async () => { + const authorId = new mongoose.Types.ObjectId(); + const agentId = uuidv4(); + const fileId = uuidv4(); + + await User.create({ + _id: authorId, + email: 'author-no-resources@example.com', + emailVerified: true, + provider: 'local', + }); + + await createAgent({ + id: agentId, + name: 'Bare Agent', + author: authorId, + model: 'gpt-4', + provider: 'openai', + }); + + const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); + const accessMap = await hasAccessToFilesViaAgent({ + userId: authorId, + role: SystemRoles.USER, + fileIds: [fileId], + agentId, + }); + + expect(accessMap.get(fileId)).toBe(false); + }); + + it('should grant access to files across multiple resource types', async () => { + const authorId = new mongoose.Types.ObjectId(); + const agentId = uuidv4(); + const fileIds = [uuidv4(), uuidv4(), uuidv4()]; + + await User.create({ + _id: authorId, + email: 'author-multi@example.com', + emailVerified: true, + provider: 'local', + }); + + await createAgent({ + id: agentId, + name: 'Multi Resource Agent', + author: authorId, + model: 'gpt-4', + provider: 'openai', + tool_resources: { + file_search: { + file_ids: [fileIds[0]], + }, + execute_code: { + file_ids: [fileIds[1]], }, }, }); - // Check access as the author const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); const accessMap = await hasAccessToFilesViaAgent({ userId: authorId, @@ -188,10 +257,48 @@ describe('File Access Control', () => { agentId, }); - // Author should have access to all files expect(accessMap.get(fileIds[0])).toBe(true); expect(accessMap.get(fileIds[1])).toBe(true); - expect(accessMap.get(fileIds[2])).toBe(true); + expect(accessMap.get(fileIds[2])).toBe(false); + }); + + it('should grant author access to attached files when isDelete is true', async () => { + const authorId = new mongoose.Types.ObjectId(); + const agentId = uuidv4(); + const attachedFileId = uuidv4(); + const unattachedFileId = uuidv4(); + + await User.create({ + _id: authorId, + email: 'author-delete@example.com', + emailVerified: true, + provider: 'local', + }); + + await createAgent({ + id: agentId, + name: 'Delete Test Agent', + author: authorId, + model: 'gpt-4', + provider: 'openai', + tool_resources: { + file_search: { + file_ids: [attachedFileId], + }, + }, + }); + + const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions'); + const accessMap = await hasAccessToFilesViaAgent({ + userId: authorId, + role: SystemRoles.USER, + fileIds: [attachedFileId, unattachedFileId], + agentId, + isDelete: true, + }); + + expect(accessMap.get(attachedFileId)).toBe(true); + expect(accessMap.get(unattachedFileId)).toBe(false); }); it('should handle non-existent agent gracefully', async () => { diff --git a/api/server/services/Files/permissions.js b/api/server/services/Files/permissions.js index d909afe25a..df484f7c29 100644 --- a/api/server/services/Files/permissions.js +++ b/api/server/services/Files/permissions.js @@ -4,7 +4,26 @@ const { checkPermission } = require('~/server/services/PermissionService'); const { getAgent } = require('~/models/Agent'); /** - * Checks if a user has access to multiple files through a shared agent (batch operation) + * @param {Object} agent - The agent document (lean) + * @returns {Set} All file IDs attached across all resource types + */ +function getAttachedFileIds(agent) { + const attachedFileIds = new Set(); + if (agent.tool_resources) { + for (const resource of Object.values(agent.tool_resources)) { + if (resource?.file_ids && Array.isArray(resource.file_ids)) { + for (const fileId of resource.file_ids) { + attachedFileIds.add(fileId); + } + } + } + } + return attachedFileIds; +} + +/** + * Checks if a user has access to multiple files through a shared agent (batch operation). + * Access is always scoped to files actually attached to the agent's tool_resources. * @param {Object} params - Parameters object * @param {string} params.userId - The user ID to check access for * @param {string} [params.role] - Optional user role to avoid DB query @@ -16,7 +35,6 @@ const { getAgent } = require('~/models/Agent'); const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => { const accessMap = new Map(); - // Initialize all files as no access fileIds.forEach((fileId) => accessMap.set(fileId, false)); try { @@ -26,13 +44,17 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele return accessMap; } - // Check if user is the author - if so, grant access to all files + const attachedFileIds = getAttachedFileIds(agent); + if (agent.author.toString() === userId.toString()) { - fileIds.forEach((fileId) => accessMap.set(fileId, true)); + fileIds.forEach((fileId) => { + if (attachedFileIds.has(fileId)) { + accessMap.set(fileId, true); + } + }); return accessMap; } - // Check if user has at least VIEW permission on the agent const hasViewPermission = await checkPermission({ userId, role, @@ -46,7 +68,6 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele } if (isDelete) { - // Check if user has EDIT permission (which would indicate collaborative access) const hasEditPermission = await checkPermission({ userId, role, @@ -55,23 +76,11 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele requiredPermission: PermissionBits.EDIT, }); - // If user only has VIEW permission, they can't access files - // Only users with EDIT permission or higher can access agent files if (!hasEditPermission) { return accessMap; } } - const attachedFileIds = new Set(); - if (agent.tool_resources) { - for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) { - if (resource?.file_ids && Array.isArray(resource.file_ids)) { - resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId)); - } - } - } - - // Grant access only to files that are attached to this agent fileIds.forEach((fileId) => { if (attachedFileIds.has(fileId)) { accessMap.set(fileId, true); From aee1ced81713d06246646709d89dffe56d5f9d5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Airam=20Hern=C3=A1ndez=20Hern=C3=A1ndez?= <100208966+Airamhh@users.noreply.github.com> Date: Sun, 15 Mar 2026 23:09:53 +0000 Subject: [PATCH 34/39] =?UTF-8?q?=F0=9F=AA=99=20fix:=20Resolve=20Azure=20A?= =?UTF-8?q?D=20Group=20Overage=20via=20OBO=20Token=20Exchange=20for=20Open?= =?UTF-8?q?ID=20(#12187)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When Azure AD users belong to 200+ groups, group claims are moved out of the ID token (overage). The existing resolveGroupsFromOverage() called Microsoft Graph directly with the app-audience access token, which Graph rejected (401/403). Changes: - Add exchangeTokenForOverage() dedicated OBO exchange with User.Read scope - Update resolveGroupsFromOverage() to exchange token before Graph call - Add overage handling to OPENID_ADMIN_ROLE block (was silently failing) - Share resolved overage groups between required role and admin role checks - Always resolve via Graph when overage detected (even with partial groups) - Remove debug-only bypass that forced Graph resolution - Add tests for OBO exchange, caching, and admin role overage scenarios Co-authored-by: Airam Hernández Hernández --- api/strategies/openidStrategy.js | 104 ++++++++- api/strategies/openidStrategy.spec.js | 313 +++++++++++++++++++++++++- 2 files changed, 406 insertions(+), 11 deletions(-) diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 0ebdcb04e1..7c43358297 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -315,24 +315,85 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Exchange the access token for a Graph-scoped token using the On-Behalf-Of (OBO) flow. + * + * The original access token has the app's own audience (api://), which Microsoft Graph + * rejects. This exchange produces a token with audience https://graph.microsoft.com and the + * minimum delegated scope (User.Read) required by /me/getMemberObjects. + * + * Uses a dedicated cache key (`${sub}:overage`) to avoid collisions with other OBO exchanges + * in the codebase (userinfo, Graph principal search). + * + * @param {string} accessToken - The original access token from the OpenID tokenset + * @param {string} sub - The subject identifier for cache keying + * @returns {Promise} A Graph-scoped access token + * @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow + */ +async function exchangeTokenForOverage(accessToken, sub) { + if (!openidConfig) { + throw new Error('[openidStrategy] OpenID config not initialized; cannot exchange OBO token'); + } + + const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS); + const cacheKey = `${sub}:overage`; + + const cached = await tokensCache.get(cacheKey); + if (cached?.access_token) { + logger.debug('[openidStrategy] Using cached Graph token for overage resolution'); + return cached.access_token; + } + + const grantResponse = await client.genericGrantRequest( + openidConfig, + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + { + scope: 'https://graph.microsoft.com/User.Read', + assertion: accessToken, + requested_token_use: 'on_behalf_of', + }, + ); + + if (!grantResponse.access_token) { + throw new Error( + '[openidStrategy] OBO exchange succeeded but returned no access_token; cannot call Graph API', + ); + } + + const ttlMs = + Number.isFinite(grantResponse.expires_in) && grantResponse.expires_in > 0 + ? grantResponse.expires_in * 1000 + : 3600 * 1000; + + await tokensCache.set(cacheKey, { access_token: grantResponse.access_token }, ttlMs); + + return grantResponse.access_token; +} + /** * Resolve Azure AD groups when group overage is in effect (groups moved to _claim_names/_claim_sources). * * NOTE: Microsoft recommends treating _claim_names/_claim_sources as a signal only and using Microsoft Graph * to resolve group membership instead of calling the endpoint in _claim_sources directly. * - * @param {string} accessToken - Access token with Microsoft Graph permissions + * Before calling Graph, the access token is exchanged via the OBO flow to obtain a token with the + * correct audience (https://graph.microsoft.com) and User.Read scope. + * + * @param {string} accessToken - Access token from the OpenID tokenset (app audience) + * @param {string} sub - The subject identifier of the user (for OBO exchange and cache keying) * @returns {Promise} Resolved group IDs or null on failure * @see https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#groups-overage-claim * @see https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects */ -async function resolveGroupsFromOverage(accessToken) { +async function resolveGroupsFromOverage(accessToken, sub) { try { if (!accessToken) { logger.error('[openidStrategy] Access token missing; cannot resolve group overage'); return null; } + const graphToken = await exchangeTokenForOverage(accessToken, sub); + // Use /me/getMemberObjects so least-privileged delegated permission User.Read is sufficient // when resolving the signed-in user's group membership. const url = 'https://graph.microsoft.com/v1.0/me/getMemberObjects'; @@ -344,7 +405,7 @@ async function resolveGroupsFromOverage(accessToken) { const fetchOptions = { method: 'POST', headers: { - Authorization: `Bearer ${accessToken}`, + Authorization: `Bearer ${graphToken}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ securityEnabledOnly: false }), @@ -364,6 +425,7 @@ async function resolveGroupsFromOverage(accessToken) { } const data = await response.json(); + const values = Array.isArray(data?.value) ? data.value : null; if (!values) { logger.error( @@ -432,6 +494,8 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { const fullName = getFullName(userinfo); const requiredRole = process.env.OPENID_REQUIRED_ROLE; + let resolvedOverageGroups = null; + if (requiredRole) { const requiredRoles = requiredRole .split(',') @@ -451,19 +515,21 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { // Handle Azure AD group overage for ID token groups: when hasgroups or _claim_* indicate overage, // resolve groups via Microsoft Graph instead of relying on token group values. + const hasOverage = + decodedToken?.hasgroups || + (decodedToken?._claim_names?.groups && + decodedToken?._claim_sources?.[decodedToken._claim_names.groups]); + if ( - !Array.isArray(roles) && - typeof roles !== 'string' && requiredRoleTokenKind === 'id' && requiredRoleParameterPath === 'groups' && decodedToken && - (decodedToken.hasgroups || - (decodedToken._claim_names?.groups && - decodedToken._claim_sources?.[decodedToken._claim_names.groups])) + hasOverage ) { - const overageGroups = await resolveGroupsFromOverage(tokenset.access_token); + const overageGroups = await resolveGroupsFromOverage(tokenset.access_token, claims.sub); if (overageGroups) { roles = overageGroups; + resolvedOverageGroups = overageGroups; } } @@ -550,7 +616,25 @@ async function processOpenIDAuth(tokenset, existingUsersOnly = false) { throw new Error('Invalid admin role token kind'); } - const adminRoles = get(adminRoleObject, adminRoleParameterPath); + let adminRoles = get(adminRoleObject, adminRoleParameterPath); + + // Handle Azure AD group overage for admin role when using ID token groups + if (adminRoleTokenKind === 'id' && adminRoleParameterPath === 'groups' && adminRoleObject) { + const hasAdminOverage = + adminRoleObject.hasgroups || + (adminRoleObject._claim_names?.groups && + adminRoleObject._claim_sources?.[adminRoleObject._claim_names.groups]); + + if (hasAdminOverage) { + const overageGroups = + resolvedOverageGroups || + (await resolveGroupsFromOverage(tokenset.access_token, claims.sub)); + if (overageGroups) { + adminRoles = overageGroups; + } + } + } + let adminRoleValues = []; if (Array.isArray(adminRoles)) { adminRoleValues = adminRoles; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 485b77829e..16fa548a59 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -64,6 +64,10 @@ jest.mock('openid-client', () => { // Only return additional properties, but don't override any claims return Promise.resolve({}); }), + genericGrantRequest: jest.fn().mockResolvedValue({ + access_token: 'exchanged_graph_token', + expires_in: 3600, + }), customFetch: Symbol('customFetch'), }; }); @@ -730,7 +734,7 @@ describe('setupOpenId', () => { expect.objectContaining({ method: 'POST', headers: expect.objectContaining({ - Authorization: `Bearer ${tokenset.access_token}`, + Authorization: 'Bearer exchanged_graph_token', }), }), ); @@ -745,6 +749,313 @@ describe('setupOpenId', () => { ); }); + describe('OBO token exchange for overage', () => { + it('exchanges access token via OBO before calling Graph API', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + await validate(tokenset); + + expect(openidClient.genericGrantRequest).toHaveBeenCalledWith( + expect.anything(), + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + expect.objectContaining({ + scope: 'https://graph.microsoft.com/User.Read', + assertion: tokenset.access_token, + requested_token_use: 'on_behalf_of', + }), + ); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer exchanged_graph_token', + }), + }), + ); + }); + + it('caches the exchanged token and reuses it on subsequent calls', async () => { + const openidClient = require('openid-client'); + const getLogStores = require('~/cache/getLogStores'); + const mockSet = jest.fn(); + const mockGet = jest + .fn() + .mockResolvedValueOnce(undefined) + .mockResolvedValueOnce({ access_token: 'exchanged_graph_token' }); + getLogStores.mockReturnValue({ get: mockGet, set: mockSet }); + + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + // First call: cache miss → OBO exchange → cache set + await validate(tokenset); + expect(mockSet).toHaveBeenCalledWith( + '1234:overage', + { access_token: 'exchanged_graph_token' }, + 3600000, + ); + expect(openidClient.genericGrantRequest).toHaveBeenCalledTimes(1); + + // Second call: cache hit → no new OBO exchange + openidClient.genericGrantRequest.mockClear(); + await validate(tokenset); + expect(openidClient.genericGrantRequest).not.toHaveBeenCalled(); + }); + }); + + describe('admin role group overage', () => { + it('resolves admin groups via Graph when overage is detected for admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('ADMIN'); + }); + + it('does not grant admin when overage groups do not contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'other-group'] }), + }); + + const { user } = await validate(tokenset); + + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); + }); + + it('reuses already-resolved overage groups for admin role check (no duplicate Graph call)', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required', 'admin-group-id'] }), + }); + + await validate(tokenset); + + // Graph API should be called only once (for required role), admin role reuses the result + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('demotes existing admin when overage groups no longer contain admin role', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + const existingAdminUser = { + _id: 'existingAdminId', + provider: 'openid', + email: tokenset.claims().email, + openidId: tokenset.claims().sub, + username: 'adminuser', + name: 'Admin User', + role: 'ADMIN', + }; + + findUser.mockImplementation(async (query) => { + if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) { + return existingAdminUser; + } + return null; + }); + + jwtDecode.mockReturnValue({ hasgroups: true }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['group-required'] }), + }); + + const { user } = await validate(tokenset); + + expect(user.role).toBe('USER'); + }); + + it('does not attempt overage for admin role when token kind is not id', async () => { + process.env.OPENID_REQUIRED_ROLE = 'requiredRole'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + process.env.OPENID_ADMIN_ROLE = 'admin'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'access'; + + jwtDecode.mockReturnValue({ + roles: ['requiredRole'], + hasgroups: true, + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user } = await validate(tokenset); + + // No Graph call since admin uses access token (not id) + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user.role).toBeUndefined(); + }); + + it('resolves admin via Graph independently when OPENID_REQUIRED_ROLE is not configured', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['admin-group-id'] }), + }); + + const { user } = await validate(tokenset); + expect(user.role).toBe('ADMIN'); + expect(undici.fetch).toHaveBeenCalledTimes(1); + }); + + it('denies admin when OPENID_REQUIRED_ROLE is absent and Graph does not contain admin group', async () => { + delete process.env.OPENID_REQUIRED_ROLE; + process.env.OPENID_ADMIN_ROLE = 'admin-group-id'; + process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_ADMIN_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: ['other-group'] }), + }); + + const { user } = await validate(tokenset); + expect(user).toBeTruthy(); + expect(user.role).toBeUndefined(); + }); + + it('denies login and logs error when OBO exchange throws', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockRejectedValueOnce(new Error('OBO exchange rejected')); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + + it('denies login when OBO exchange returns no access_token', async () => { + const openidClient = require('openid-client'); + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ hasgroups: true }); + openidClient.genericGrantRequest.mockResolvedValueOnce({ expires_in: 3600 }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + expect(undici.fetch).not.toHaveBeenCalled(); + }); + }); + it('should attempt to download and save the avatar if picture is provided', async () => { // Act const { user } = await validate(tokenset); From a26eeea59281889f3d05885bc765c732f1028147 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 20:08:34 -0400 Subject: [PATCH 35/39] =?UTF-8?q?=F0=9F=94=8F=20fix:=20Enforce=20MCP=20Ser?= =?UTF-8?q?ver=20Authorization=20on=20Agent=20Tool=20Persistence=20(#12250?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🛡️ fix: Validate MCP tool authorization on agent create/update Agent creation and update accepted arbitrary MCP tool strings without verifying the user has access to the referenced MCP servers. This allowed a user to embed unauthorized server names in tool identifiers (e.g. "anything_mcp_"), causing mcpServerNames to be stored on the agent and granting consumeOnly access via hasAccessViaAgent(). Adds filterAuthorizedTools() that checks MCP tool strings against the user's accessible server configs (via getAllServerConfigs) before persisting. Applied to create, update, and duplicate agent paths. * 🛡️ fix: Harden MCP tool authorization and add test coverage Addresses review findings on the MCP agent tool authorization fix: - Wrap getMCPServersRegistry() in try/catch so uninitialized registry gracefully filters all MCP tools instead of causing a 500 (DoS risk) - Guard revertAgentVersionHandler: filter unauthorized MCP tools after reverting to a previous version snapshot - Preserve existing MCP tools on collaborative updates: only validate newly added tools, preventing silent stripping of tools the editing user lacks direct access to - Add audit logging (logger.warn) when MCP tools are rejected - Refactor to single-pass lazy-fetch (registry queried only on first MCP tool encountered) - Export filterAuthorizedTools for direct unit testing - Add 18 tests covering: authorized/unauthorized/mixed tools, registry unavailable fallback, create/update/duplicate/revert handler paths, collaborative update preservation, and mcpServerNames persistence * test: Add duplicate handler test, use Constants.mcp_delimiter, DB assertions - N1: Add duplicateAgentHandler integration test verifying unauthorized MCP tools are stripped from the cloned agent and mcpServerNames are correctly persisted in the database - N2: Replace all hardcoded '_mcp_' delimiter literals with Constants.mcp_delimiter to prevent silent false-positive tests if the delimiter value ever changes - N3: Add DB state assertion to the revert-with-strip test confirming persisted tools match the response after unauthorized tools are removed * fix: Enforce exact 2-segment format for MCP tool keys Reject MCP tool keys with multiple delimiters to prevent authorization/execution mismatch when `.pop()` vs `split[1]` extract different server names from the same key. * fix: Preserve existing MCP tools when registry is unavailable When the MCP registry is uninitialized (e.g. server restart), existing tools already persisted on the agent are preserved instead of silently stripped. New MCP tools are still rejected when the registry cannot verify them. Applies to duplicate and revert handlers via existingTools param; update handler already preserves existing tools via its diff logic. --- .../agents/filterAuthorizedTools.spec.js | 677 ++++++++++++++++++ api/server/controllers/agents/v1.js | 134 +++- 2 files changed, 801 insertions(+), 10 deletions(-) create mode 100644 api/server/controllers/agents/filterAuthorizedTools.spec.js diff --git a/api/server/controllers/agents/filterAuthorizedTools.spec.js b/api/server/controllers/agents/filterAuthorizedTools.spec.js new file mode 100644 index 0000000000..259e41fb0d --- /dev/null +++ b/api/server/controllers/agents/filterAuthorizedTools.spec.js @@ -0,0 +1,677 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { Constants } = require('librechat-data-provider'); +const { agentSchema } = require('@librechat/data-schemas'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +const d = Constants.mcp_delimiter; + +const mockGetAllServerConfigs = jest.fn(); + +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn().mockResolvedValue({ + web_search: true, + execute_code: true, + file_search: true, + }), +})); + +jest.mock('~/config', () => ({ + getMCPServersRegistry: jest.fn(() => ({ + getAllServerConfigs: mockGetAllServerConfigs, + })), +})); + +jest.mock('~/models/Project', () => ({ + getProjectByName: jest.fn().mockResolvedValue(null), +})); + +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +jest.mock('~/server/services/Files/images/avatar', () => ({ + resizeAvatar: jest.fn(), +})); + +jest.mock('~/server/services/Files/S3/crud', () => ({ + refreshS3Url: jest.fn(), +})); + +jest.mock('~/server/services/Files/process', () => ({ + filterFile: jest.fn(), +})); + +jest.mock('~/models/Action', () => ({ + updateAction: jest.fn(), + getActions: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/File', () => ({ + deleteFileByFilter: jest.fn(), +})); + +jest.mock('~/server/services/PermissionService', () => ({ + findAccessibleResources: jest.fn().mockResolvedValue([]), + findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]), + grantPermission: jest.fn(), + hasPublicPermission: jest.fn().mockResolvedValue(false), + checkPermission: jest.fn().mockResolvedValue(true), +})); + +jest.mock('~/models', () => ({ + getCategoriesWithCounts: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(() => ({ + get: jest.fn(), + set: jest.fn(), + delete: jest.fn(), + })), +})); + +const { + filterAuthorizedTools, + createAgent: createAgentHandler, + updateAgent: updateAgentHandler, + duplicateAgent: duplicateAgentHandler, + revertAgentVersion: revertAgentVersionHandler, +} = require('./v1'); + +const { getMCPServersRegistry } = require('~/config'); + +let Agent; + +describe('MCP Tool Authorization', () => { + let mongoServer; + let mockReq; + let mockRes; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + }, 20000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + jest.clearAllMocks(); + + getMCPServersRegistry.mockImplementation(() => ({ + getAllServerConfigs: mockGetAllServerConfigs, + })); + mockGetAllServerConfigs.mockResolvedValue({ + authorizedServer: { type: 'sse', url: 'https://authorized.example.com' }, + anotherServer: { type: 'sse', url: 'https://another.example.com' }, + }); + + mockReq = { + user: { + id: new mongoose.Types.ObjectId().toString(), + role: 'USER', + }, + body: {}, + params: {}, + query: {}, + app: { locals: { fileStrategy: 'local' } }, + }; + + mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + }; + }); + + describe('filterAuthorizedTools', () => { + const availableTools = { web_search: true, custom_tool: true }; + const userId = 'test-user-123'; + + test('should keep authorized MCP tools and strip unauthorized ones', async () => { + const result = await filterAuthorizedTools({ + tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`, 'web_search'], + userId, + availableTools, + }); + + expect(result).toContain(`toolA${d}authorizedServer`); + expect(result).toContain('web_search'); + expect(result).not.toContain(`toolB${d}forbiddenServer`); + }); + + test('should keep system tools without querying MCP registry', async () => { + const result = await filterAuthorizedTools({ + tools: ['execute_code', 'file_search', 'web_search'], + userId, + availableTools: {}, + }); + + expect(result).toEqual(['execute_code', 'file_search', 'web_search']); + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should not query MCP registry when no MCP tools are present', async () => { + const result = await filterAuthorizedTools({ + tools: ['web_search', 'custom_tool'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search', 'custom_tool']); + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should filter all MCP tools when registry is uninitialized', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const result = await filterAuthorizedTools({ + tools: [`toolA${d}someServer`, 'web_search'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search']); + expect(result).not.toContain(`toolA${d}someServer`); + }); + + test('should handle mixed authorized and unauthorized MCP tools', async () => { + const result = await filterAuthorizedTools({ + tools: [ + 'web_search', + `search${d}authorizedServer`, + `attack${d}victimServer`, + 'execute_code', + `list${d}anotherServer`, + `steal${d}nonexistent`, + ], + userId, + availableTools, + }); + + expect(result).toEqual([ + 'web_search', + `search${d}authorizedServer`, + 'execute_code', + `list${d}anotherServer`, + ]); + }); + + test('should handle empty tools array', async () => { + const result = await filterAuthorizedTools({ + tools: [], + userId, + availableTools, + }); + + expect(result).toEqual([]); + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should handle null/undefined tool entries gracefully', async () => { + const result = await filterAuthorizedTools({ + tools: [null, undefined, '', 'web_search'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search']); + }); + + test('should call getAllServerConfigs with the correct userId', async () => { + await filterAuthorizedTools({ + tools: [`tool${d}authorizedServer`], + userId: 'specific-user-id', + availableTools, + }); + + expect(mockGetAllServerConfigs).toHaveBeenCalledWith('specific-user-id'); + }); + + test('should only call getAllServerConfigs once even with multiple MCP tools', async () => { + await filterAuthorizedTools({ + tools: [`tool1${d}authorizedServer`, `tool2${d}anotherServer`, `tool3${d}unknownServer`], + userId, + availableTools, + }); + + expect(mockGetAllServerConfigs).toHaveBeenCalledTimes(1); + }); + + test('should preserve existing MCP tools when registry is unavailable', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const existingTools = [`toolA${d}serverA`, `toolB${d}serverB`]; + + const result = await filterAuthorizedTools({ + tools: [...existingTools, `newTool${d}unknownServer`, 'web_search'], + userId, + availableTools, + existingTools, + }); + + expect(result).toContain(`toolA${d}serverA`); + expect(result).toContain(`toolB${d}serverB`); + expect(result).toContain('web_search'); + expect(result).not.toContain(`newTool${d}unknownServer`); + }); + + test('should still reject all MCP tools when registry is unavailable and no existingTools', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const result = await filterAuthorizedTools({ + tools: [`toolA${d}serverA`, 'web_search'], + userId, + availableTools, + }); + + expect(result).toEqual(['web_search']); + }); + + test('should not preserve malformed existing tools when registry is unavailable', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + const malformedTool = `a${d}b${d}c`; + const result = await filterAuthorizedTools({ + tools: [malformedTool, `legit${d}serverA`, 'web_search'], + userId, + availableTools, + existingTools: [malformedTool, `legit${d}serverA`], + }); + + expect(result).toContain(`legit${d}serverA`); + expect(result).toContain('web_search'); + expect(result).not.toContain(malformedTool); + }); + + test('should reject malformed MCP tool keys with multiple delimiters', async () => { + const result = await filterAuthorizedTools({ + tools: [ + `attack${d}victimServer${d}authorizedServer`, + `legit${d}authorizedServer`, + `a${d}b${d}c${d}d`, + 'web_search', + ], + userId, + availableTools, + }); + + expect(result).toEqual([`legit${d}authorizedServer`, 'web_search']); + expect(result).not.toContainEqual(expect.stringContaining('victimServer')); + expect(result).not.toContainEqual(expect.stringContaining(`a${d}b`)); + }); + }); + + describe('createAgentHandler - MCP tool authorization', () => { + test('should strip unauthorized MCP tools on create', async () => { + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'MCP Test Agent', + tools: ['web_search', `validTool${d}authorizedServer`, `attack${d}forbiddenServer`], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const agent = mockRes.json.mock.calls[0][0]; + expect(agent.tools).toContain('web_search'); + expect(agent.tools).toContain(`validTool${d}authorizedServer`); + expect(agent.tools).not.toContain(`attack${d}forbiddenServer`); + }); + + test('should not 500 when MCP registry is uninitialized', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'MCP Uninitialized Test', + tools: [`tool${d}someServer`, 'web_search'], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const agent = mockRes.json.mock.calls[0][0]; + expect(agent.tools).toEqual(['web_search']); + }); + + test('should store mcpServerNames only for authorized servers', async () => { + mockReq.body = { + provider: 'openai', + model: 'gpt-4', + name: 'MCP Names Test', + tools: [`toolA${d}authorizedServer`, `toolB${d}forbiddenServer`], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const agent = mockRes.json.mock.calls[0][0]; + const agentInDb = await Agent.findOne({ id: agent.id }); + expect(agentInDb.mcpServerNames).toContain('authorizedServer'); + expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer'); + }); + }); + + describe('updateAgentHandler - MCP tool authorization', () => { + let existingAgentId; + let existingAgentAuthorId; + + beforeEach(async () => { + existingAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Original Agent', + provider: 'openai', + model: 'gpt-4', + author: existingAgentAuthorId, + tools: ['web_search', `existingTool${d}authorizedServer`], + mcpServerNames: ['authorizedServer'], + versions: [ + { + name: 'Original Agent', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search', `existingTool${d}authorizedServer`], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + existingAgentId = agent.id; + }); + + test('should preserve existing MCP tools even if editor lacks access', async () => { + mockGetAllServerConfigs.mockResolvedValue({}); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.tools).toContain('web_search'); + }); + + test('should reject newly added unauthorized MCP tools', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`, `attack${d}forbiddenServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain('web_search'); + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.tools).not.toContain(`attack${d}forbiddenServer`); + }); + + test('should allow adding authorized MCP tools', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`, `newTool${d}anotherServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`newTool${d}anotherServer`); + }); + + test('should not query MCP registry when no new MCP tools added', async () => { + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockGetAllServerConfigs).not.toHaveBeenCalled(); + }); + + test('should preserve existing MCP tools when registry unavailable and user edits agent', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Renamed After Restart', + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.tools).toContain('web_search'); + expect(updatedAgent.name).toBe('Renamed After Restart'); + }); + + test('should preserve existing MCP tools when server not in configs (disconnected)', async () => { + mockGetAllServerConfigs.mockResolvedValue({}); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { + name: 'Edited While Disconnected', + tools: ['web_search', `existingTool${d}authorizedServer`], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const updatedAgent = mockRes.json.mock.calls[0][0]; + expect(updatedAgent.tools).toContain(`existingTool${d}authorizedServer`); + expect(updatedAgent.name).toBe('Edited While Disconnected'); + }); + }); + + describe('duplicateAgentHandler - MCP tool authorization', () => { + let sourceAgentId; + let sourceAgentAuthorId; + + beforeEach(async () => { + sourceAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Source Agent', + provider: 'openai', + model: 'gpt-4', + author: sourceAgentAuthorId, + tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`], + mcpServerNames: ['authorizedServer', 'forbiddenServer'], + versions: [ + { + name: 'Source Agent', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search', `tool${d}authorizedServer`, `tool${d}forbiddenServer`], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + sourceAgentId = agent.id; + }); + + test('should strip unauthorized MCP tools from duplicated agent', async () => { + mockGetAllServerConfigs.mockResolvedValue({ + authorizedServer: { type: 'sse' }, + }); + + mockReq.user.id = sourceAgentAuthorId.toString(); + mockReq.params.id = sourceAgentId; + + await duplicateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const { agent: newAgent } = mockRes.json.mock.calls[0][0]; + expect(newAgent.id).not.toBe(sourceAgentId); + expect(newAgent.tools).toContain('web_search'); + expect(newAgent.tools).toContain(`tool${d}authorizedServer`); + expect(newAgent.tools).not.toContain(`tool${d}forbiddenServer`); + + const agentInDb = await Agent.findOne({ id: newAgent.id }); + expect(agentInDb.mcpServerNames).toContain('authorizedServer'); + expect(agentInDb.mcpServerNames).not.toContain('forbiddenServer'); + }); + + test('should preserve source agent MCP tools when registry is unavailable', async () => { + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.user.id = sourceAgentAuthorId.toString(); + mockReq.params.id = sourceAgentId; + + await duplicateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + const { agent: newAgent } = mockRes.json.mock.calls[0][0]; + expect(newAgent.tools).toContain('web_search'); + expect(newAgent.tools).toContain(`tool${d}authorizedServer`); + expect(newAgent.tools).toContain(`tool${d}forbiddenServer`); + }); + }); + + describe('revertAgentVersionHandler - MCP tool authorization', () => { + let existingAgentId; + let existingAgentAuthorId; + + beforeEach(async () => { + existingAgentAuthorId = new mongoose.Types.ObjectId(); + const agent = await Agent.create({ + id: `agent_${uuidv4()}`, + name: 'Reverted Agent V2', + provider: 'openai', + model: 'gpt-4', + author: existingAgentAuthorId, + tools: ['web_search'], + versions: [ + { + name: 'Reverted Agent V1', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search', `oldTool${d}revokedServer`], + createdAt: new Date(Date.now() - 10000), + updatedAt: new Date(Date.now() - 10000), + }, + { + name: 'Reverted Agent V2', + provider: 'openai', + model: 'gpt-4', + tools: ['web_search'], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }); + existingAgentId = agent.id; + }); + + test('should strip unauthorized MCP tools after reverting to a previous version', async () => { + mockGetAllServerConfigs.mockResolvedValue({ + authorizedServer: { type: 'sse' }, + }); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { version_index: 0 }; + + await revertAgentVersionHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const result = mockRes.json.mock.calls[0][0]; + expect(result.tools).toContain('web_search'); + expect(result.tools).not.toContain(`oldTool${d}revokedServer`); + + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.tools).toContain('web_search'); + expect(agentInDb.tools).not.toContain(`oldTool${d}revokedServer`); + }); + + test('should keep authorized MCP tools after revert', async () => { + await Agent.updateOne( + { id: existingAgentId }, + { $set: { 'versions.0.tools': ['web_search', `tool${d}authorizedServer`] } }, + ); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { version_index: 0 }; + + await revertAgentVersionHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const result = mockRes.json.mock.calls[0][0]; + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain(`tool${d}authorizedServer`); + }); + + test('should preserve version MCP tools when registry is unavailable on revert', async () => { + await Agent.updateOne( + { id: existingAgentId }, + { + $set: { + 'versions.0.tools': [ + 'web_search', + `validTool${d}authorizedServer`, + `otherTool${d}anotherServer`, + ], + }, + }, + ); + + getMCPServersRegistry.mockImplementation(() => { + throw new Error('MCPServersRegistry has not been initialized.'); + }); + + mockReq.user.id = existingAgentAuthorId.toString(); + mockReq.params.id = existingAgentId; + mockReq.body = { version_index: 0 }; + + await revertAgentVersionHandler(mockReq, mockRes); + + expect(mockRes.json).toHaveBeenCalled(); + const result = mockRes.json.mock.calls[0][0]; + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain(`validTool${d}authorizedServer`); + expect(result.tools).toContain(`otherTool${d}anotherServer`); + + const agentInDb = await Agent.findOne({ id: existingAgentId }); + expect(agentInDb.tools).toContain(`validTool${d}authorizedServer`); + expect(agentInDb.tools).toContain(`otherTool${d}anotherServer`); + }); + }); +}); diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index dbb97df24b..309873e56c 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -49,6 +49,7 @@ const { refreshS3Url } = require('~/server/services/Files/S3/crud'); const { filterFile } = require('~/server/services/Files/process'); const { updateAction, getActions } = require('~/models/Action'); const { getCachedTools } = require('~/server/services/Config'); +const { getMCPServersRegistry } = require('~/config'); const { getLogStores } = require('~/cache'); const systemTools = { @@ -98,6 +99,78 @@ const validateEdgeAgentAccess = async (edges, userId, userRole) => { .map((a) => a.id); }; +/** + * Filters tools to only include those the user is authorized to use. + * MCP tools must match the exact format `{toolName}_mcp_{serverName}` (exactly 2 segments). + * Multi-delimiter keys are rejected to prevent authorization/execution mismatch. + * Non-MCP tools must appear in availableTools (global tool cache) or systemTools. + * + * When `existingTools` is provided and the MCP registry is unavailable (e.g. server restart), + * tools already present on the agent are preserved rather than stripped — they were validated + * when originally added, and we cannot re-verify them without the registry. + * @param {object} params + * @param {string[]} params.tools - Raw tool strings from the request + * @param {string} params.userId - Requesting user ID for MCP server access check + * @param {Record} params.availableTools - Global non-MCP tool cache + * @param {string[]} [params.existingTools] - Tools already persisted on the agent document + * @returns {Promise} Only the authorized subset of tools + */ +const filterAuthorizedTools = async ({ tools, userId, availableTools, existingTools }) => { + const filteredTools = []; + let mcpServerConfigs; + let registryUnavailable = false; + const existingToolSet = existingTools?.length ? new Set(existingTools) : null; + + for (const tool of tools) { + if (availableTools[tool] || systemTools[tool]) { + filteredTools.push(tool); + continue; + } + + if (!tool?.includes(Constants.mcp_delimiter)) { + continue; + } + + if (mcpServerConfigs === undefined) { + try { + mcpServerConfigs = (await getMCPServersRegistry().getAllServerConfigs(userId)) ?? {}; + } catch (e) { + logger.warn( + '[filterAuthorizedTools] MCP registry unavailable, filtering all MCP tools', + e.message, + ); + mcpServerConfigs = {}; + registryUnavailable = true; + } + } + + const parts = tool.split(Constants.mcp_delimiter); + if (parts.length !== 2) { + logger.warn( + `[filterAuthorizedTools] Rejected malformed MCP tool key "${tool}" for user ${userId}`, + ); + continue; + } + + if (registryUnavailable && existingToolSet?.has(tool)) { + filteredTools.push(tool); + continue; + } + + const [, serverName] = parts; + if (!serverName || !Object.hasOwn(mcpServerConfigs, serverName)) { + logger.warn( + `[filterAuthorizedTools] Rejected MCP tool "${tool}" — server "${serverName}" not accessible to user ${userId}`, + ); + continue; + } + + filteredTools.push(tool); + } + + return filteredTools; +}; + /** * Creates an Agent. * @route POST /Agents @@ -132,15 +205,7 @@ const createAgentHandler = async (req, res) => { agentData.tools = []; const availableTools = (await getCachedTools()) ?? {}; - for (const tool of tools) { - if (availableTools[tool]) { - agentData.tools.push(tool); - } else if (systemTools[tool]) { - agentData.tools.push(tool); - } else if (tool.includes(Constants.mcp_delimiter)) { - agentData.tools.push(tool); - } - } + agentData.tools = await filterAuthorizedTools({ tools, userId, availableTools }); const agent = await createAgent(agentData); @@ -322,6 +387,26 @@ const updateAgentHandler = async (req, res) => { updateData.tools = ocrConversion.tools; } + if (updateData.tools) { + const existingToolSet = new Set(existingAgent.tools ?? []); + const newMCPTools = updateData.tools.filter( + (t) => !existingToolSet.has(t) && t?.includes(Constants.mcp_delimiter), + ); + + if (newMCPTools.length > 0) { + const availableTools = (await getCachedTools()) ?? {}; + const approvedNew = await filterAuthorizedTools({ + tools: newMCPTools, + userId: req.user.id, + availableTools, + }); + const rejectedSet = new Set(newMCPTools.filter((t) => !approvedNew.includes(t))); + if (rejectedSet.size > 0) { + updateData.tools = updateData.tools.filter((t) => !rejectedSet.has(t)); + } + } + } + let updatedAgent = Object.keys(updateData).length > 0 ? await updateAgent({ id }, updateData, { @@ -464,6 +549,17 @@ const duplicateAgentHandler = async (req, res) => { const agentActions = await Promise.all(promises); newAgentData.actions = agentActions; + + if (newAgentData.tools?.length) { + const availableTools = (await getCachedTools()) ?? {}; + newAgentData.tools = await filterAuthorizedTools({ + tools: newAgentData.tools, + userId, + availableTools, + existingTools: newAgentData.tools, + }); + } + const newAgent = await createAgent(newAgentData); try { @@ -792,7 +888,24 @@ const revertAgentVersionHandler = async (req, res) => { // Permissions are enforced via route middleware (ACL EDIT) - const updatedAgent = await revertAgentVersion({ id }, version_index); + let updatedAgent = await revertAgentVersion({ id }, version_index); + + if (updatedAgent.tools?.length) { + const availableTools = (await getCachedTools()) ?? {}; + const filteredTools = await filterAuthorizedTools({ + tools: updatedAgent.tools, + userId: req.user.id, + availableTools, + existingTools: updatedAgent.tools, + }); + if (filteredTools.length !== updatedAgent.tools.length) { + updatedAgent = await updateAgent( + { id }, + { tools: filteredTools }, + { updatingUserId: req.user.id }, + ); + } + } if (updatedAgent.author) { updatedAgent.author = updatedAgent.author.toString(); @@ -860,4 +973,5 @@ module.exports = { uploadAgentAvatar: uploadAgentAvatarHandler, revertAgentVersion: revertAgentVersionHandler, getAgentCategories, + filterAuthorizedTools, }; From 6f87b49df8fa2696dc54d52a7a0ab117da8dbd60 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 23:01:36 -0400 Subject: [PATCH 36/39] =?UTF-8?q?=F0=9F=9B=82=20fix:=20Enforce=20Actions?= =?UTF-8?q?=20Capability=20Gate=20Across=20All=20Event-Driven=20Tool=20Loa?= =?UTF-8?q?ding=20Paths=20(#12252)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: gate action tools by actions capability in all code paths Extract resolveAgentCapabilities helper to eliminate 3x-duplicated capability resolution. Apply early action-tool filtering in both loadToolDefinitionsWrapper and loadAgentTools non-definitions path. Gate loadActionToolsForExecution in loadToolsForExecution behind an actionsEnabled parameter with a cache-based fallback. Replace the late capability guard in loadAgentTools with a hasActionTools check to avoid unnecessary loadActionSets DB calls and duplicate warnings. * fix: thread actionsEnabled through InitializedAgent type Add actionsEnabled to the loadTools callback return type, InitializedAgent, and the initializeAgent destructuring/return so callers can forward the resolved value to loadToolsForExecution without redundant getEndpointsConfig cache lookups. * fix: pass actionsEnabled from callers to loadToolsForExecution Thread actionsEnabled through the agentToolContexts map in initialize.js (primary and handoff agents) and through primaryConfig in the openai.js and responses.js controllers, avoiding per-tool-call capability re-resolution on the hot path. * test: add regression tests for action capability gating Test the real exported functions (resolveAgentCapabilities, loadAgentTools, loadToolsForExecution) with mocked dependencies instead of shadow re-implementations. Covers definition filtering, execution gating, actionsEnabled param forwarding, and fallback capability resolution. * test: use Constants.EPHEMERAL_AGENT_ID in ephemeral fallback test Replaces a string guess with the canonical constant to avoid fragility if the ephemeral detection heuristic changes. * fix: populate agentToolContexts for addedConvo parallel agents After processAddedConvo returns, backfill agentToolContexts for any agents in agentConfigs not already present, so ON_TOOL_EXECUTE for added-convo agents receives actionsEnabled instead of falling back to a per-call cache lookup. --- api/server/controllers/agents/openai.js | 1 + api/server/controllers/agents/responses.js | 2 + .../services/Endpoints/agents/initialize.js | 16 + api/server/services/ToolService.js | 73 ++-- .../services/__tests__/ToolService.spec.js | 312 +++++++++++++++++- packages/api/src/agents/initialize.ts | 6 + 6 files changed, 372 insertions(+), 38 deletions(-) diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js index e8561f15fe..bab81f1535 100644 --- a/api/server/controllers/agents/openai.js +++ b/api/server/controllers/agents/openai.js @@ -265,6 +265,7 @@ const OpenAIChatCompletionController = async (req, res) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); }, toolEndCallback, diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js index 83e6ad6efd..bbf02580dd 100644 --- a/api/server/controllers/agents/responses.js +++ b/api/server/controllers/agents/responses.js @@ -429,6 +429,7 @@ const createResponse = async (req, res) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); }, toolEndCallback, @@ -586,6 +587,7 @@ const createResponse = async (req, res) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); }, toolEndCallback, diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 44583e6dbc..762236ea19 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -128,6 +128,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolRegistry: ctx.toolRegistry, userMCPAuthMap: ctx.userMCPAuthMap, tool_resources: ctx.tool_resources, + actionsEnabled: ctx.actionsEnabled, }); logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`); @@ -214,6 +215,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolRegistry: primaryConfig.toolRegistry, userMCPAuthMap: primaryConfig.userMCPAuthMap, tool_resources: primaryConfig.tool_resources, + actionsEnabled: primaryConfig.actionsEnabled, }); const agent_ids = primaryConfig.agent_ids; @@ -297,6 +299,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { toolRegistry: config.toolRegistry, userMCPAuthMap: config.userMCPAuthMap, tool_resources: config.tool_resources, + actionsEnabled: config.actionsEnabled, }); agentConfigs.set(agentId, config); @@ -370,6 +373,19 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { userMCPAuthMap = updatedMCPAuthMap; } + for (const [agentId, config] of agentConfigs) { + if (agentToolContexts.has(agentId)) { + continue; + } + agentToolContexts.set(agentId, { + agent: config, + toolRegistry: config.toolRegistry, + userMCPAuthMap: config.userMCPAuthMap, + tool_resources: config.tool_resources, + actionsEnabled: config.actionsEnabled, + }); + } + // Ensure edges is an array when we have multiple agents (multi-agent mode) // MultiAgentGraph.categorizeEdges requires edges to be iterable if (agentConfigs.size > 0 && !edges) { diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 62499348e6..5fc95e748d 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -64,6 +64,26 @@ const { redactMessage } = require('~/config/parsers'); const { findPluginAuthsByKeys } = require('~/models'); const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); + +/** + * Resolves the set of enabled agent capabilities from endpoints config, + * falling back to app-level or default capabilities for ephemeral agents. + * @param {ServerRequest} req + * @param {Object} appConfig + * @param {string} agentId + * @returns {Promise>} + */ +async function resolveAgentCapabilities(req, appConfig, agentId) { + const endpointsConfig = await getEndpointsConfig(req); + let capabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); + if (capabilities.size === 0 && isEphemeralAgentId(agentId)) { + capabilities = new Set( + appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities, + ); + } + return capabilities; +} + /** * Processes the required actions by calling the appropriate tools and returning the outputs. * @param {OpenAIClient} client - OpenAI or StreamRunManager Client. @@ -445,17 +465,11 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to } const appConfig = req.config; - const endpointsConfig = await getEndpointsConfig(req); - let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); - - if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) { - enabledCapabilities = new Set( - appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities, - ); - } + const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id); const checkCapability = (capability) => enabledCapabilities.has(capability); const areToolsEnabled = checkCapability(AgentCapabilities.tools); + const actionsEnabled = checkCapability(AgentCapabilities.actions); const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools); const filteredTools = agent.tools?.filter((tool) => { @@ -468,7 +482,10 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to if (tool === Tools.web_search) { return checkCapability(AgentCapabilities.web_search); } - if (!areToolsEnabled && !tool.includes(actionDelimiter)) { + if (tool.includes(actionDelimiter)) { + return actionsEnabled; + } + if (!areToolsEnabled) { return false; } return true; @@ -765,6 +782,7 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to toolContextMap, toolDefinitions, hasDeferredTools, + actionsEnabled, }; } @@ -808,14 +826,7 @@ async function loadAgentTools({ } const appConfig = req.config; - const endpointsConfig = await getEndpointsConfig(req); - let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); - /** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */ - if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) { - enabledCapabilities = new Set( - appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities, - ); - } + const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent.id); const checkCapability = (capability) => { const enabled = enabledCapabilities.has(capability); if (!enabled) { @@ -832,6 +843,7 @@ async function loadAgentTools({ return enabled; }; const areToolsEnabled = checkCapability(AgentCapabilities.tools); + const actionsEnabled = checkCapability(AgentCapabilities.actions); let includesWebSearch = false; const _agentTools = agent.tools?.filter((tool) => { @@ -842,7 +854,9 @@ async function loadAgentTools({ } else if (tool === Tools.web_search) { includesWebSearch = checkCapability(AgentCapabilities.web_search); return includesWebSearch; - } else if (!areToolsEnabled && !tool.includes(actionDelimiter)) { + } else if (tool.includes(actionDelimiter)) { + return actionsEnabled; + } else if (!areToolsEnabled) { return false; } return true; @@ -947,13 +961,15 @@ async function loadAgentTools({ agentTools.push(...additionalTools); - if (!checkCapability(AgentCapabilities.actions)) { + const hasActionTools = _agentTools.some((t) => t.includes(actionDelimiter)); + if (!hasActionTools) { return { toolRegistry, userMCPAuthMap, toolContextMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: agentTools, }; } @@ -969,6 +985,7 @@ async function loadAgentTools({ toolContextMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: agentTools, }; } @@ -1101,6 +1118,7 @@ async function loadAgentTools({ userMCPAuthMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: agentTools, }; } @@ -1118,9 +1136,11 @@ async function loadAgentTools({ * @param {AbortSignal} [params.signal] - Abort signal * @param {Object} params.agent - The agent object * @param {string[]} params.toolNames - Names of tools to load + * @param {Map} [params.toolRegistry] - Tool registry * @param {Record>} [params.userMCPAuthMap] - User MCP auth map * @param {Object} [params.tool_resources] - Tool resources * @param {string|null} [params.streamId] - Stream ID for web search callbacks + * @param {boolean} [params.actionsEnabled] - Whether the actions capability is enabled * @returns {Promise<{ loadedTools: Array, configurable: Object }>} */ async function loadToolsForExecution({ @@ -1133,11 +1153,17 @@ async function loadToolsForExecution({ userMCPAuthMap, tool_resources, streamId = null, + actionsEnabled, }) { const appConfig = req.config; const allLoadedTools = []; const configurable = { userMCPAuthMap }; + if (actionsEnabled === undefined) { + const enabledCapabilities = await resolveAgentCapabilities(req, appConfig, agent?.id); + actionsEnabled = enabledCapabilities.has(AgentCapabilities.actions); + } + const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH); const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING); @@ -1194,7 +1220,6 @@ async function loadToolsForExecution({ const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter)); const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter)); - /** @type {Record} */ if (regularToolNames.length > 0) { const includesWebSearch = regularToolNames.includes(Tools.web_search); const webSearchCallbacks = includesWebSearch ? createOnSearchResults(res, streamId) : undefined; @@ -1225,7 +1250,7 @@ async function loadToolsForExecution({ } } - if (actionToolNames.length > 0 && agent) { + if (actionToolNames.length > 0 && agent && actionsEnabled) { const actionTools = await loadActionToolsForExecution({ req, res, @@ -1235,6 +1260,11 @@ async function loadToolsForExecution({ actionToolNames, }); allLoadedTools.push(...actionTools); + } else if (actionToolNames.length > 0 && agent && !actionsEnabled) { + logger.warn( + `[loadToolsForExecution] Capability "${AgentCapabilities.actions}" disabled. ` + + `Skipping action tool execution. User: ${req.user.id} | Agent: ${agent.id} | Tools: ${actionToolNames.join(', ')}`, + ); } if (isPTC && allLoadedTools.length > 0) { @@ -1395,4 +1425,5 @@ module.exports = { loadAgentTools, loadToolsForExecution, processRequiredActions, + resolveAgentCapabilities, }; diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js index c44298b09c..a468a88eb3 100644 --- a/api/server/services/__tests__/ToolService.spec.js +++ b/api/server/services/__tests__/ToolService.spec.js @@ -1,19 +1,304 @@ const { + Tools, Constants, + EModelEndpoint, + actionDelimiter, AgentCapabilities, defaultAgentCapabilities, } = require('librechat-data-provider'); -/** - * Tests for ToolService capability checking logic. - * The actual loadAgentTools function has many dependencies, so we test - * the capability checking logic in isolation. - */ -describe('ToolService - Capability Checking', () => { +const mockGetEndpointsConfig = jest.fn(); +const mockGetMCPServerTools = jest.fn(); +const mockGetCachedTools = jest.fn(); +jest.mock('~/server/services/Config', () => ({ + getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), + getMCPServerTools: (...args) => mockGetMCPServerTools(...args), + getCachedTools: (...args) => mockGetCachedTools(...args), +})); + +const mockLoadToolDefinitions = jest.fn(); +const mockGetUserMCPAuthMap = jest.fn(); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + loadToolDefinitions: (...args) => mockLoadToolDefinitions(...args), + getUserMCPAuthMap: (...args) => mockGetUserMCPAuthMap(...args), +})); + +const mockLoadToolsUtil = jest.fn(); +jest.mock('~/app/clients/tools/util', () => ({ + loadTools: (...args) => mockLoadToolsUtil(...args), +})); + +const mockLoadActionSets = jest.fn(); +jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn().mockResolvedValue({}), +})); +jest.mock('~/server/services/Tools/search', () => ({ + createOnSearchResults: jest.fn(), +})); +jest.mock('~/server/services/Tools/mcp', () => ({ + reinitMCPServer: jest.fn(), +})); +jest.mock('~/server/services/Files/process', () => ({ + processFileURL: jest.fn(), + uploadImageBuffer: jest.fn(), +})); +jest.mock('~/app/clients/tools/util/fileSearch', () => ({ + primeFiles: jest.fn().mockResolvedValue({}), +})); +jest.mock('~/server/services/Files/Code/process', () => ({ + primeFiles: jest.fn().mockResolvedValue({}), +})); +jest.mock('../ActionService', () => ({ + loadActionSets: (...args) => mockLoadActionSets(...args), + decryptMetadata: jest.fn(), + createActionTool: jest.fn(), + domainParser: jest.fn(), +})); +jest.mock('~/server/services/Threads', () => ({ + recordUsage: jest.fn(), +})); +jest.mock('~/models', () => ({ + findPluginAuthsByKeys: jest.fn(), +})); +jest.mock('~/config', () => ({ + getFlowStateManager: jest.fn(() => ({})), +})); +jest.mock('~/cache', () => ({ + getLogStores: jest.fn(() => ({})), +})); + +const { + loadAgentTools, + loadToolsForExecution, + resolveAgentCapabilities, +} = require('../ToolService'); + +function createMockReq(capabilities) { + return { + user: { id: 'user_123' }, + config: { + endpoints: { + [EModelEndpoint.agents]: { + capabilities, + }, + }, + }, + }; +} + +function createEndpointsConfig(capabilities) { + return { + [EModelEndpoint.agents]: { capabilities }, + }; +} + +describe('ToolService - Action Capability Gating', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockLoadToolDefinitions.mockResolvedValue({ + toolDefinitions: [], + toolRegistry: new Map(), + hasDeferredTools: false, + }); + mockLoadToolsUtil.mockResolvedValue({ loadedTools: [], toolContextMap: {} }); + mockLoadActionSets.mockResolvedValue([]); + }); + + describe('resolveAgentCapabilities', () => { + it('should return capabilities from endpoints config', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + const result = await resolveAgentCapabilities(req, req.config, 'agent_123'); + + expect(result).toBeInstanceOf(Set); + expect(result.has(AgentCapabilities.tools)).toBe(true); + expect(result.has(AgentCapabilities.actions)).toBe(true); + expect(result.has(AgentCapabilities.web_search)).toBe(false); + }); + + it('should fall back to default capabilities for ephemeral agents with empty config', async () => { + const req = createMockReq(defaultAgentCapabilities); + mockGetEndpointsConfig.mockResolvedValue({}); + + const result = await resolveAgentCapabilities(req, req.config, Constants.EPHEMERAL_AGENT_ID); + + for (const cap of defaultAgentCapabilities) { + expect(result.has(cap)).toBe(true); + } + }); + + it('should return empty set when no capabilities and not ephemeral', async () => { + const req = createMockReq([]); + mockGetEndpointsConfig.mockResolvedValue({}); + + const result = await resolveAgentCapabilities(req, req.config, 'agent_123'); + + expect(result.size).toBe(0); + }); + }); + + describe('loadAgentTools (definitionsOnly=true) — action tool filtering', () => { + const actionToolName = `get_weather${actionDelimiter}api_example_com`; + const regularTool = 'calculator'; + + it('should exclude action tools from definitions when actions capability is disabled', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.web_search]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: true, + }); + + expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1); + const [callArgs] = mockLoadToolDefinitions.mock.calls[0]; + expect(callArgs.tools).toContain(regularTool); + expect(callArgs.tools).not.toContain(actionToolName); + }); + + it('should include action tools in definitions when actions capability is enabled', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: true, + }); + + expect(mockLoadToolDefinitions).toHaveBeenCalledTimes(1); + const [callArgs] = mockLoadToolDefinitions.mock.calls[0]; + expect(callArgs.tools).toContain(regularTool); + expect(callArgs.tools).toContain(actionToolName); + }); + + it('should return actionsEnabled in the result', async () => { + const capabilities = [AgentCapabilities.tools]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + const result = await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool] }, + definitionsOnly: true, + }); + + expect(result.actionsEnabled).toBe(false); + }); + }); + + describe('loadAgentTools (definitionsOnly=false) — action tool filtering', () => { + const actionToolName = `get_weather${actionDelimiter}api_example_com`; + const regularTool = 'calculator'; + + it('should not load action sets when actions capability is disabled', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.web_search]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: false, + }); + + expect(mockLoadActionSets).not.toHaveBeenCalled(); + }); + + it('should load action sets when actions capability is enabled and action tools present', async () => { + const capabilities = [AgentCapabilities.tools, AgentCapabilities.actions]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadAgentTools({ + req, + res: {}, + agent: { id: 'agent_123', tools: [regularTool, actionToolName] }, + definitionsOnly: false, + }); + + expect(mockLoadActionSets).toHaveBeenCalledWith({ agent_id: 'agent_123' }); + }); + }); + + describe('loadToolsForExecution — action tool gating', () => { + const actionToolName = `get_weather${actionDelimiter}api_example_com`; + const regularTool = Tools.web_search; + + it('should skip action tool loading when actionsEnabled=false', async () => { + const req = createMockReq([]); + req.config = {}; + + const result = await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [regularTool, actionToolName], + actionsEnabled: false, + }); + + expect(mockLoadActionSets).not.toHaveBeenCalled(); + expect(result.loadedTools).toBeDefined(); + }); + + it('should load action tools when actionsEnabled=true', async () => { + const req = createMockReq([AgentCapabilities.actions]); + req.config = {}; + + await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [actionToolName], + actionsEnabled: true, + }); + + expect(mockLoadActionSets).toHaveBeenCalledWith({ agent_id: 'agent_123' }); + }); + + it('should resolve actionsEnabled from capabilities when not explicitly provided', async () => { + const capabilities = [AgentCapabilities.tools]; + const req = createMockReq(capabilities); + mockGetEndpointsConfig.mockResolvedValue(createEndpointsConfig(capabilities)); + + await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [actionToolName], + }); + + expect(mockGetEndpointsConfig).toHaveBeenCalled(); + expect(mockLoadActionSets).not.toHaveBeenCalled(); + }); + + it('should not call loadActionSets when there are no action tools', async () => { + const req = createMockReq([AgentCapabilities.actions]); + req.config = {}; + + await loadToolsForExecution({ + req, + res: {}, + agent: { id: 'agent_123' }, + toolNames: [regularTool], + actionsEnabled: true, + }); + + expect(mockLoadActionSets).not.toHaveBeenCalled(); + }); + }); + describe('checkCapability logic', () => { - /** - * Simulates the checkCapability function from loadAgentTools - */ const createCheckCapability = (enabledCapabilities, logger = { warn: jest.fn() }) => { return (capability) => { const enabled = enabledCapabilities.has(capability); @@ -124,10 +409,6 @@ describe('ToolService - Capability Checking', () => { }); describe('userMCPAuthMap gating', () => { - /** - * Simulates the guard condition used in both loadToolDefinitionsWrapper - * and loadAgentTools to decide whether getUserMCPAuthMap should be called. - */ const shouldFetchMCPAuth = (tools) => tools?.some((t) => t.includes(Constants.mcp_delimiter)) ?? false; @@ -178,20 +459,17 @@ describe('ToolService - Capability Checking', () => { return (capability) => enabledCapabilities.has(capability); }; - // When deferred_tools is in capabilities const withDeferred = new Set([AgentCapabilities.deferred_tools, AgentCapabilities.tools]); const checkWithDeferred = createCheckCapability(withDeferred); expect(checkWithDeferred(AgentCapabilities.deferred_tools)).toBe(true); - // When deferred_tools is NOT in capabilities const withoutDeferred = new Set([AgentCapabilities.tools, AgentCapabilities.actions]); const checkWithoutDeferred = createCheckCapability(withoutDeferred); expect(checkWithoutDeferred(AgentCapabilities.deferred_tools)).toBe(false); }); it('should use defaultAgentCapabilities when no capabilities configured', () => { - // Simulates the fallback behavior in loadAgentTools - const endpointsConfig = {}; // No capabilities configured + const endpointsConfig = {}; const enabledCapabilities = new Set( endpointsConfig?.capabilities ?? defaultAgentCapabilities, ); diff --git a/packages/api/src/agents/initialize.ts b/packages/api/src/agents/initialize.ts index af604beb81..913835a007 100644 --- a/packages/api/src/agents/initialize.ts +++ b/packages/api/src/agents/initialize.ts @@ -52,6 +52,8 @@ export type InitializedAgent = Agent & { toolDefinitions?: LCTool[]; /** Precomputed flag indicating if any tools have defer_loading enabled (for efficient runtime checks) */ hasDeferredTools?: boolean; + /** Whether the actions capability is enabled (resolved during tool loading) */ + actionsEnabled?: boolean; }; /** @@ -90,6 +92,7 @@ export interface InitializeAgentParams { /** Serializable tool definitions for event-driven mode */ toolDefinitions?: LCTool[]; hasDeferredTools?: boolean; + actionsEnabled?: boolean; } | null>; /** Endpoint option (contains model_parameters and endpoint info) */ endpointOption?: Partial; @@ -283,6 +286,7 @@ export async function initializeAgent( userMCPAuthMap, toolDefinitions, hasDeferredTools, + actionsEnabled, tools: structuredTools, } = (await loadTools?.({ req, @@ -300,6 +304,7 @@ export async function initializeAgent( toolRegistry: undefined, toolDefinitions: [], hasDeferredTools: false, + actionsEnabled: undefined, }; const { getOptions, overrideProvider } = getProviderConfig({ @@ -409,6 +414,7 @@ export async function initializeAgent( userMCPAuthMap, toolDefinitions, hasDeferredTools, + actionsEnabled, attachments: finalAttachments, toolContextMap: toolContextMap ?? {}, useLegacyContent: !!options.useLegacyContent, From 8e8fb01d18b7471607b4e3bd4a894ae135d3cfaa Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 23:02:36 -0400 Subject: [PATCH 37/39] =?UTF-8?q?=F0=9F=A7=B1=20fix:=20Enforce=20Agent=20A?= =?UTF-8?q?ccess=20Control=20on=20Context=20and=20OCR=20File=20Loading=20(?= =?UTF-8?q?#12253)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔏 fix: Apply agent access control filtering to context/OCR resource loading The context/OCR file path in primeResources fetched files by file_id without applying filterFilesByAgentAccess, unlike the file_search and execute_code paths. Add filterFiles dependency injection to primeResources and invoke it after getFiles to enforce consistent access control. * fix: Wire filterFilesByAgentAccess into all agent initialization callers Pass the filterFilesByAgentAccess function from the JS layer into the TS initializeAgent → primeResources chain via dependency injection, covering primary, handoff, added-convo, and memory agent init paths. * test: Add access control filtering tests for primeResources Cover filterFiles invocation with context/OCR files, verify filtering rejects inaccessible files, and confirm graceful fallback when filterFiles, userId, or agentId are absent. * fix: Guard filterFilesByAgentAccess against ephemeral agent IDs Ephemeral agents have no DB document, so getAgent returns null and the access map defaults to all-false, silently blocking all non-owned files. Short-circuit with isEphemeralAgentId to preserve the pass-through behavior for inline-built agents (memory, tool agents). * fix: Clean up resources.ts and JS caller import order Remove redundant optional chain on req.user.role inside user-guarded block, update primeResources JSDoc with filterFiles and agentId params, and reorder JS imports to longest-to-shortest per project conventions. * test: Strengthen OCR assertion and add filterFiles error-path test Use toHaveBeenCalledWith for the OCR filtering test to verify exact arguments after the OCR→context merge step. Add test for filterFiles rejection to verify graceful degradation (logs error, returns original tool_resources). * fix: Correct import order in addedConvo.js and initialize.js Sort by total line length descending: loadAddedAgent (91) before filterFilesByAgentAccess (84), loadAgentTools (91) before filterFilesByAgentAccess (84). * test: Add unit tests for filterFilesByAgentAccess and hasAccessToFilesViaAgent Cover every branch in permissions.js: ephemeral agent guard, missing userId/agentId/files early returns, all-owned short-circuit, mixed owned + non-owned with VIEW/no-VIEW, agent-not-found fail-closed, author path scoped to attached files, EDIT gate on delete, DB error fail-closed, and agent with no tool_resources. * test: Cover file.user undefined/null in permissions spec Files with no user field fall into the non-owned path and get run through hasAccessToFilesViaAgent. Add two cases: attached file with no user field is returned, unattached file with no user field is excluded. --- api/server/controllers/agents/client.js | 2 + .../services/Endpoints/agents/addedConvo.js | 2 + .../services/Endpoints/agents/initialize.js | 3 + api/server/services/Files/permissions.js | 4 +- api/server/services/Files/permissions.spec.js | 409 ++++++++++++++++++ packages/api/src/agents/initialize.ts | 6 +- packages/api/src/agents/resources.test.ts | 275 +++++++++++- packages/api/src/agents/resources.ts | 32 +- 8 files changed, 708 insertions(+), 25 deletions(-) create mode 100644 api/server/services/Files/permissions.spec.js diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 0ecd62b819..c454bd65cf 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -44,6 +44,7 @@ const { isEphemeralAgentId, removeNullishValues, } = require('librechat-data-provider'); +const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { updateBalance, bulkInsertTransactions } = require('~/models'); @@ -479,6 +480,7 @@ class AgentClient extends BaseClient { getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); diff --git a/api/server/services/Endpoints/agents/addedConvo.js b/api/server/services/Endpoints/agents/addedConvo.js index 25b1327991..11b87e450e 100644 --- a/api/server/services/Endpoints/agents/addedConvo.js +++ b/api/server/services/Endpoints/agents/addedConvo.js @@ -1,6 +1,7 @@ const { logger } = require('@librechat/data-schemas'); const { initializeAgent, validateAgentModel } = require('@librechat/api'); const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent'); +const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { getConvoFiles } = require('~/models/Conversation'); const { getAgent } = require('~/models/Agent'); const db = require('~/models'); @@ -108,6 +109,7 @@ const processAddedConvo = async ({ getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 762236ea19..08f631c3d2 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -22,6 +22,7 @@ const { getDefaultHandlers, } = require('~/server/controllers/agents/callbacks'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); +const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { getModelsConfig } = require('~/server/controllers/ModelController'); const { checkPermission } = require('~/server/services/PermissionService'); const AgentClient = require('~/server/controllers/agents/client'); @@ -204,6 +205,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); @@ -284,6 +286,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, getCodeGeneratedFiles: db.getCodeGeneratedFiles, + filterFilesByAgentAccess, }, ); diff --git a/api/server/services/Files/permissions.js b/api/server/services/Files/permissions.js index df484f7c29..b9a5d6656f 100644 --- a/api/server/services/Files/permissions.js +++ b/api/server/services/Files/permissions.js @@ -1,5 +1,5 @@ const { logger } = require('@librechat/data-schemas'); -const { PermissionBits, ResourceType } = require('librechat-data-provider'); +const { PermissionBits, ResourceType, isEphemeralAgentId } = require('librechat-data-provider'); const { checkPermission } = require('~/server/services/PermissionService'); const { getAgent } = require('~/models/Agent'); @@ -104,7 +104,7 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele * @returns {Promise>} Filtered array of accessible files */ const filterFilesByAgentAccess = async ({ files, userId, role, agentId }) => { - if (!userId || !agentId || !files || files.length === 0) { + if (!userId || !agentId || !files || files.length === 0 || isEphemeralAgentId(agentId)) { return files; } diff --git a/api/server/services/Files/permissions.spec.js b/api/server/services/Files/permissions.spec.js new file mode 100644 index 0000000000..85e7b2dc5b --- /dev/null +++ b/api/server/services/Files/permissions.spec.js @@ -0,0 +1,409 @@ +jest.mock('@librechat/data-schemas', () => ({ + logger: { error: jest.fn() }, +})); + +jest.mock('~/server/services/PermissionService', () => ({ + checkPermission: jest.fn(), +})); + +jest.mock('~/models/Agent', () => ({ + getAgent: jest.fn(), +})); + +const { logger } = require('@librechat/data-schemas'); +const { Constants, PermissionBits, ResourceType } = require('librechat-data-provider'); +const { checkPermission } = require('~/server/services/PermissionService'); +const { getAgent } = require('~/models/Agent'); +const { filterFilesByAgentAccess, hasAccessToFilesViaAgent } = require('./permissions'); + +const AUTHOR_ID = 'author-user-id'; +const USER_ID = 'viewer-user-id'; +const AGENT_ID = 'agent_test-abc123'; +const AGENT_MONGO_ID = 'mongo-agent-id'; + +function makeFile(file_id, user) { + return { file_id, user, filename: `${file_id}.txt` }; +} + +function makeAgent(overrides = {}) { + return { + _id: AGENT_MONGO_ID, + id: AGENT_ID, + author: AUTHOR_ID, + tool_resources: { + file_search: { file_ids: ['attached-1', 'attached-2'] }, + execute_code: { file_ids: ['attached-3'] }, + }, + ...overrides, + }; +} + +beforeEach(() => { + jest.clearAllMocks(); +}); + +describe('filterFilesByAgentAccess', () => { + describe('early returns (no DB calls)', () => { + it('should return files unfiltered for ephemeral agentId', async () => { + const files = [makeFile('f1', 'other-user')]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: Constants.EPHEMERAL_AGENT_ID, + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return files unfiltered for non-agent_ prefixed agentId', async () => { + const files = [makeFile('f1', 'other-user')]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: 'custom-memory-id', + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return files when userId is missing', async () => { + const files = [makeFile('f1', 'someone')]; + const result = await filterFilesByAgentAccess({ + files, + userId: undefined, + agentId: AGENT_ID, + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return files when agentId is missing', async () => { + const files = [makeFile('f1', 'someone')]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: undefined, + }); + + expect(result).toBe(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return empty array when files is empty', async () => { + const result = await filterFilesByAgentAccess({ + files: [], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + expect(getAgent).not.toHaveBeenCalled(); + }); + + it('should return undefined when files is nullish', async () => { + const result = await filterFilesByAgentAccess({ + files: null, + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toBeNull(); + expect(getAgent).not.toHaveBeenCalled(); + }); + }); + + describe('all files owned by userId', () => { + it('should return all files without calling getAgent', async () => { + const files = [makeFile('f1', USER_ID), makeFile('f2', USER_ID)]; + const result = await filterFilesByAgentAccess({ + files, + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual(files); + expect(getAgent).not.toHaveBeenCalled(); + }); + }); + + describe('mixed owned and non-owned files', () => { + const ownedFile = makeFile('owned-1', USER_ID); + const sharedFile = makeFile('attached-1', AUTHOR_ID); + const unattachedFile = makeFile('not-attached', AUTHOR_ID); + + it('should return owned + accessible non-owned files when user has VIEW', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile, unattachedFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toHaveLength(2); + expect(result.map((f) => f.file_id)).toContain('owned-1'); + expect(result.map((f) => f.file_id)).toContain('attached-1'); + expect(result.map((f) => f.file_id)).not.toContain('not-attached'); + }); + + it('should return only owned files when user lacks VIEW permission', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(false); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toEqual([ownedFile]); + }); + + it('should return only owned files when agent is not found', async () => { + getAgent.mockResolvedValue(null); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([ownedFile]); + }); + + it('should return only owned files on DB error (fail-closed)', async () => { + getAgent.mockRejectedValue(new Error('DB connection lost')); + + const result = await filterFilesByAgentAccess({ + files: [ownedFile, sharedFile], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([ownedFile]); + expect(logger.error).toHaveBeenCalled(); + }); + }); + + describe('file with no user field', () => { + it('should treat file as non-owned and run through access check', async () => { + const noUserFile = makeFile('attached-1', undefined); + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [noUserFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(getAgent).toHaveBeenCalled(); + expect(result).toEqual([noUserFile]); + }); + + it('should exclude file with no user field when not attached to agent', async () => { + const noUserFile = makeFile('not-attached', null); + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [noUserFile], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + }); + }); + + describe('no owned files (all non-owned)', () => { + const file1 = makeFile('attached-1', AUTHOR_ID); + const file2 = makeFile('not-attached', AUTHOR_ID); + + it('should return only attached files when user has VIEW', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await filterFilesByAgentAccess({ + files: [file1, file2], + userId: USER_ID, + role: 'USER', + agentId: AGENT_ID, + }); + + expect(result).toEqual([file1]); + }); + + it('should return empty array when no VIEW permission', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(false); + + const result = await filterFilesByAgentAccess({ + files: [file1, file2], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + }); + + it('should return empty array when agent not found', async () => { + getAgent.mockResolvedValue(null); + + const result = await filterFilesByAgentAccess({ + files: [file1], + userId: USER_ID, + agentId: AGENT_ID, + }); + + expect(result).toEqual([]); + }); + }); +}); + +describe('hasAccessToFilesViaAgent', () => { + describe('agent not found', () => { + it('should return all-false map', async () => { + getAgent.mockResolvedValue(null); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['f1', 'f2'], + agentId: AGENT_ID, + }); + + expect(result.get('f1')).toBe(false); + expect(result.get('f2')).toBe(false); + }); + }); + + describe('author path', () => { + it('should grant access to attached files for the agent author', async () => { + getAgent.mockResolvedValue(makeAgent()); + + const result = await hasAccessToFilesViaAgent({ + userId: AUTHOR_ID, + fileIds: ['attached-1', 'not-attached'], + agentId: AGENT_ID, + }); + + expect(result.get('attached-1')).toBe(true); + expect(result.get('not-attached')).toBe(false); + expect(checkPermission).not.toHaveBeenCalled(); + }); + }); + + describe('VIEW permission path', () => { + it('should grant access to attached files for viewer with VIEW permission', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(true); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + role: 'USER', + fileIds: ['attached-1', 'attached-3', 'not-attached'], + agentId: AGENT_ID, + }); + + expect(result.get('attached-1')).toBe(true); + expect(result.get('attached-3')).toBe(true); + expect(result.get('not-attached')).toBe(false); + + expect(checkPermission).toHaveBeenCalledWith({ + userId: USER_ID, + role: 'USER', + resourceType: ResourceType.AGENT, + resourceId: AGENT_MONGO_ID, + requiredPermission: PermissionBits.VIEW, + }); + }); + + it('should deny all when VIEW permission is missing', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValue(false); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['attached-1'], + agentId: AGENT_ID, + }); + + expect(result.get('attached-1')).toBe(false); + }); + }); + + describe('delete path (EDIT permission required)', () => { + it('should grant access when both VIEW and EDIT pass', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(true); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['attached-1'], + agentId: AGENT_ID, + isDelete: true, + }); + + expect(result.get('attached-1')).toBe(true); + expect(checkPermission).toHaveBeenCalledTimes(2); + expect(checkPermission).toHaveBeenLastCalledWith( + expect.objectContaining({ requiredPermission: PermissionBits.EDIT }), + ); + }); + + it('should deny all when VIEW passes but EDIT fails', async () => { + getAgent.mockResolvedValue(makeAgent()); + checkPermission.mockResolvedValueOnce(true).mockResolvedValueOnce(false); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['attached-1'], + agentId: AGENT_ID, + isDelete: true, + }); + + expect(result.get('attached-1')).toBe(false); + }); + }); + + describe('error handling', () => { + it('should return all-false map on DB error (fail-closed)', async () => { + getAgent.mockRejectedValue(new Error('connection refused')); + + const result = await hasAccessToFilesViaAgent({ + userId: USER_ID, + fileIds: ['f1', 'f2'], + agentId: AGENT_ID, + }); + + expect(result.get('f1')).toBe(false); + expect(result.get('f2')).toBe(false); + expect(logger.error).toHaveBeenCalledWith( + '[hasAccessToFilesViaAgent] Error checking file access:', + expect.any(Error), + ); + }); + }); + + describe('agent with no tool_resources', () => { + it('should deny all files even for the author', async () => { + getAgent.mockResolvedValue(makeAgent({ tool_resources: undefined })); + + const result = await hasAccessToFilesViaAgent({ + userId: AUTHOR_ID, + fileIds: ['f1'], + agentId: AGENT_ID, + }); + + expect(result.get('f1')).toBe(false); + }); + }); +}); diff --git a/packages/api/src/agents/initialize.ts b/packages/api/src/agents/initialize.ts index 913835a007..d5bfca5aba 100644 --- a/packages/api/src/agents/initialize.ts +++ b/packages/api/src/agents/initialize.ts @@ -31,6 +31,7 @@ import { filterFilesByEndpointConfig } from '~/files'; import { generateArtifactsPrompt } from '~/prompts'; import { getProviderConfig } from '~/endpoints'; import { primeResources } from './resources'; +import type { TFilterFilesByAgentAccess } from './resources'; /** * Extended agent type with additional fields needed after initialization @@ -111,7 +112,9 @@ export interface InitializeAgentDbMethods extends EndpointDbMethods { /** Update usage tracking for multiple files */ updateFilesUsage: (files: Array<{ file_id: string }>, fileIds?: string[]) => Promise; /** Get files from database */ - getFiles: (filter: unknown, sort: unknown, select: unknown, opts?: unknown) => Promise; + getFiles: (filter: unknown, sort: unknown, select: unknown) => Promise; + /** Filter files by agent access permissions (ownership or agent attachment) */ + filterFilesByAgentAccess?: TFilterFilesByAgentAccess; /** Get tool files by IDs (user-uploaded files only, code files handled separately) */ getToolFilesByIds: (fileIds: string[], toolSet: Set) => Promise; /** Get conversation file IDs */ @@ -271,6 +274,7 @@ export async function initializeAgent( const { attachments: primedAttachments, tool_resources } = await primeResources({ req: req as never, getFiles: db.getFiles as never, + filterFiles: db.filterFilesByAgentAccess, appConfig: req.config, agentId: agent.id, attachments: currentFiles diff --git a/packages/api/src/agents/resources.test.ts b/packages/api/src/agents/resources.test.ts index bfd2327764..641fb9284c 100644 --- a/packages/api/src/agents/resources.test.ts +++ b/packages/api/src/agents/resources.test.ts @@ -4,7 +4,7 @@ import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-dat import type { TAgentsEndpoint, TFile } from 'librechat-data-provider'; import type { IUser, AppConfig } from '@librechat/data-schemas'; import type { Request as ServerRequest } from 'express'; -import type { TGetFiles } from './resources'; +import type { TGetFiles, TFilterFilesByAgentAccess } from './resources'; // Mock logger jest.mock('@librechat/data-schemas', () => ({ @@ -17,16 +17,16 @@ describe('primeResources', () => { let mockReq: ServerRequest & { user?: IUser }; let mockAppConfig: AppConfig; let mockGetFiles: jest.MockedFunction; + let mockFilterFiles: jest.MockedFunction; let requestFileSet: Set; beforeEach(() => { - // Reset mocks jest.clearAllMocks(); - // Setup mock request - mockReq = {} as unknown as ServerRequest & { user?: IUser }; + mockReq = { + user: { id: 'user1', role: 'USER' }, + } as unknown as ServerRequest & { user?: IUser }; - // Setup mock appConfig mockAppConfig = { endpoints: { [EModelEndpoint.agents]: { @@ -35,10 +35,9 @@ describe('primeResources', () => { }, } as AppConfig; - // Setup mock getFiles function mockGetFiles = jest.fn(); + mockFilterFiles = jest.fn().mockImplementation(({ files }) => Promise.resolve(files)); - // Setup request file set requestFileSet = new Set(['file1', 'file2', 'file3']); }); @@ -70,20 +69,21 @@ describe('primeResources', () => { req: mockReq, appConfig: mockAppConfig, getFiles: mockGetFiles, + filterFiles: mockFilterFiles, requestFileSet, attachments: undefined, tool_resources, + agentId: 'agent_test', }); - expect(mockGetFiles).toHaveBeenCalledWith( - { file_id: { $in: ['ocr-file-1'] } }, - {}, - {}, - { userId: undefined, agentId: undefined }, - ); + expect(mockGetFiles).toHaveBeenCalledWith({ file_id: { $in: ['ocr-file-1'] } }, {}, {}); + expect(mockFilterFiles).toHaveBeenCalledWith({ + files: mockOcrFiles, + userId: 'user1', + role: 'USER', + agentId: 'agent_test', + }); expect(result.attachments).toEqual(mockOcrFiles); - // Context field is deleted after files are fetched and re-categorized - // Since the file is not embedded and has no special properties, it won't be categorized expect(result.tool_resources).toEqual({}); }); }); @@ -1108,12 +1108,10 @@ describe('primeResources', () => { 'ocr-file-1', ); - // Verify getFiles was called with merged file_ids expect(mockGetFiles).toHaveBeenCalledWith( { file_id: { $in: ['context-file-1', 'ocr-file-1'] } }, {}, {}, - { userId: undefined, agentId: undefined }, ); }); @@ -1241,6 +1239,249 @@ describe('primeResources', () => { }); }); + describe('access control filtering', () => { + it('should filter context files through filterFiles when provided', async () => { + const ownedFile: TFile = { + user: 'user1', + file_id: 'owned-file', + filename: 'owned.pdf', + filepath: '/uploads/owned.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + const inaccessibleFile: TFile = { + user: 'other-user', + file_id: 'inaccessible-file', + filename: 'secret.pdf', + filepath: '/uploads/secret.pdf', + object: 'file', + type: 'application/pdf', + bytes: 2048, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([ownedFile, inaccessibleFile]); + mockFilterFiles.mockResolvedValue([ownedFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['owned-file', 'inaccessible-file'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_shared', + }); + + expect(mockFilterFiles).toHaveBeenCalledWith({ + files: [ownedFile, inaccessibleFile], + userId: 'user1', + role: 'USER', + agentId: 'agent_shared', + }); + expect(result.attachments).toEqual([ownedFile]); + expect(result.attachments).not.toContainEqual(inaccessibleFile); + }); + + it('should filter OCR files merged into context through filterFiles', async () => { + const ocrFile: TFile = { + user: 'other-user', + file_id: 'ocr-restricted', + filename: 'scan.pdf', + filepath: '/uploads/scan.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([ocrFile]); + mockFilterFiles.mockResolvedValue([]); + + const tool_resources = { + [EToolResources.ocr]: { + file_ids: ['ocr-restricted'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_shared', + }); + + expect(mockFilterFiles).toHaveBeenCalledWith({ + files: [ocrFile], + userId: 'user1', + role: 'USER', + agentId: 'agent_shared', + }); + expect(result.attachments).toBeUndefined(); + }); + + it('should skip filtering when filterFiles is not provided', async () => { + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_test', + }); + + expect(mockFilterFiles).not.toHaveBeenCalled(); + expect(result.attachments).toEqual([mockFile]); + }); + + it('should skip filtering when user ID is missing', async () => { + const reqNoUser = {} as unknown as ServerRequest & { user?: IUser }; + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: reqNoUser, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_test', + }); + + expect(mockFilterFiles).not.toHaveBeenCalled(); + expect(result.attachments).toEqual([mockFile]); + }); + + it('should gracefully handle filterFiles rejection', async () => { + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + mockFilterFiles.mockRejectedValue(new Error('DB failure')); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + agentId: 'agent_test', + }); + + expect(logger.error).toHaveBeenCalledWith('Error priming resources', expect.any(Error)); + expect(result.tool_resources).toEqual(tool_resources); + }); + + it('should skip filtering when agentId is missing', async () => { + const mockFile: TFile = { + user: 'user1', + file_id: 'file-1', + filename: 'doc.pdf', + filepath: '/uploads/doc.pdf', + object: 'file', + type: 'application/pdf', + bytes: 1024, + embedded: false, + usage: 0, + }; + + mockGetFiles.mockResolvedValue([mockFile]); + + const tool_resources = { + [EToolResources.context]: { + file_ids: ['file-1'], + }, + }; + + const result = await primeResources({ + req: mockReq, + appConfig: mockAppConfig, + getFiles: mockGetFiles, + filterFiles: mockFilterFiles, + requestFileSet, + attachments: undefined, + tool_resources, + }); + + expect(mockFilterFiles).not.toHaveBeenCalled(); + expect(result.attachments).toEqual([mockFile]); + }); + }); + describe('edge cases', () => { it('should handle missing appConfig agents endpoint gracefully', async () => { const reqWithoutLocals = {} as ServerRequest & { user?: IUser }; diff --git a/packages/api/src/agents/resources.ts b/packages/api/src/agents/resources.ts index 4655453847..e147c743cf 100644 --- a/packages/api/src/agents/resources.ts +++ b/packages/api/src/agents/resources.ts @@ -10,16 +10,26 @@ import type { Request as ServerRequest } from 'express'; * @param filter - MongoDB filter query for files * @param _sortOptions - Sorting options (currently unused) * @param selectFields - Field selection options - * @param options - Additional options including userId and agentId for access control * @returns Promise resolving to array of files */ export type TGetFiles = ( filter: FilterQuery, _sortOptions: ProjectionType | null | undefined, selectFields: QueryOptions | null | undefined, - options?: { userId?: string; agentId?: string }, ) => Promise>; +/** + * Function type for filtering files by agent access permissions. + * Used to enforce that only files the user has access to (via ownership or agent attachment) + * are returned after a raw DB query. + */ +export type TFilterFilesByAgentAccess = (params: { + files: Array; + userId: string; + role?: string; + agentId: string; +}) => Promise>; + /** * Helper function to add a file to a specific tool resource category * Prevents duplicate files within the same resource category @@ -128,7 +138,7 @@ const categorizeFileForToolResources = ({ /** * Primes resources for agent execution by processing attachments and tool resources * This function: - * 1. Fetches OCR files if OCR is enabled + * 1. Fetches context/OCR files (filtered by agent access control when available) * 2. Processes attachment files * 3. Categorizes files into appropriate tool resources * 4. Prevents duplicate files across all sources @@ -137,15 +147,18 @@ const categorizeFileForToolResources = ({ * @param params.req - Express request object * @param params.appConfig - Application configuration object * @param params.getFiles - Function to retrieve files from database + * @param params.filterFiles - Optional function to enforce agent-based file access control * @param params.requestFileSet - Set of file IDs from the current request * @param params.attachments - Promise resolving to array of attachment files * @param params.tool_resources - Existing tool resources for the agent + * @param params.agentId - Agent ID used for access control filtering * @returns Promise resolving to processed attachments and updated tool resources */ export const primeResources = async ({ req, appConfig, getFiles, + filterFiles, requestFileSet, attachments: _attachments, tool_resources: _tool_resources, @@ -157,6 +170,7 @@ export const primeResources = async ({ attachments: Promise> | undefined; tool_resources: AgentToolResources | undefined; getFiles: TGetFiles; + filterFiles?: TFilterFilesByAgentAccess; agentId?: string; }): Promise<{ attachments: Array | undefined; @@ -228,15 +242,23 @@ export const primeResources = async ({ if (fileIds.length > 0 && isContextEnabled) { delete tool_resources[EToolResources.context]; - const context = await getFiles( + let context = await getFiles( { file_id: { $in: fileIds }, }, {}, {}, - { userId: req.user?.id, agentId }, ); + if (filterFiles && req.user?.id && agentId) { + context = await filterFiles({ + files: context, + userId: req.user.id, + role: req.user.role, + agentId, + }); + } + for (const file of context) { if (!file?.file_id) { continue; From acd07e80852f6b931a4459372981b5d3db8082da Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 23:03:12 -0400 Subject: [PATCH 38/39] =?UTF-8?q?=F0=9F=97=9D=EF=B8=8F=20fix:=20Exempt=20A?= =?UTF-8?q?dmin-Trusted=20Domains=20from=20MCP=20OAuth=20Validation=20(#12?= =?UTF-8?q?255)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: exempt allowedDomains from MCP OAuth SSRF checks (#12254) The SSRF guard in validateOAuthUrl was context-blind — it blocked private/internal OAuth endpoints even for admin-trusted MCP servers listed in mcpSettings.allowedDomains. Add isHostnameAllowed() to domain.ts and skip SSRF checks in validateOAuthUrl when the OAuth endpoint hostname matches an allowed domain. * refactor: thread allowedDomains through MCP connection stack Pass allowedDomains from MCPServersRegistry through BasicConnectionOptions, MCPConnectionFactory, and into MCPOAuthHandler method calls so the OAuth layer can exempt admin-trusted domains from SSRF validation. * test: add allowedDomains bypass tests and fix registry mocks Add isHostnameAllowed unit tests (exact, wildcard, case-insensitive, private IPs). Add MCPOAuthSecurity tests covering the allowedDomains bypass for initiateOAuthFlow, refreshOAuthTokens, and revokeOAuthToken. Update registry mocks to include getAllowedDomains. * fix: enforce protocol/port constraints in OAuth allowedDomains bypass Replace isHostnameAllowed (hostname-only check) with isOAuthUrlAllowed which parses the full OAuth URL and matches against allowedDomains entries including protocol and explicit port constraints — mirroring isDomainAllowedCore's allowlist logic. Prevents a port-scoped entry like 'https://auth.internal:8443' from also exempting other ports. * test: cover auto-discovery and branch-3 refresh paths with allowedDomains Add three new integration tests using a real OAuth test server: - auto-discovered OAuth endpoints allowed when server IP is in allowedDomains - auto-discovered endpoints rejected when allowedDomains doesn't match - refreshOAuthTokens branch 3 (no clientInfo/config) with allowedDomains bypass Also rename describe block from ephemeral issue number to durable name. * docs: explain intentional absence of allowedDomains in completeOAuthFlow Prevents future contributors from assuming a missing parameter during security audits — URLs are pre-validated during initiateOAuthFlow. * test: update initiateOAuthFlow assertion for allowedDomains parameter * perf: avoid redundant URL parse for admin-trusted OAuth endpoints Move isOAuthUrlAllowed check before the hostname extraction so admin-trusted URLs short-circuit with a single URL parse instead of two. The hostname extraction (new URL) is now deferred to the SSRF-check path where it's actually needed. --- api/server/controllers/UserController.js | 3 + packages/api/src/auth/domain.spec.ts | 91 ++++++++ packages/api/src/auth/domain.ts | 46 ++++ packages/api/src/mcp/ConnectionsRepository.ts | 4 +- packages/api/src/mcp/MCPConnectionFactory.ts | 5 + packages/api/src/mcp/MCPManager.ts | 5 +- packages/api/src/mcp/UserConnectionManager.ts | 4 +- .../__tests__/ConnectionsRepository.test.ts | 4 + .../__tests__/MCPConnectionFactory.test.ts | 1 + .../api/src/mcp/__tests__/MCPManager.test.ts | 1 + .../__tests__/MCPOAuthRaceCondition.test.ts | 2 + .../mcp/__tests__/MCPOAuthSecurity.test.ts | 214 ++++++++++++++++++ packages/api/src/mcp/oauth/handler.ts | 59 +++-- .../src/mcp/registry/MCPServerInspector.ts | 10 +- packages/api/src/mcp/types/index.ts | 1 + 15 files changed, 432 insertions(+), 18 deletions(-) diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index b3160bb3d3..6d5df0ac8d 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -370,6 +370,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { serverConfig.oauth?.revocation_endpoint_auth_methods_supported ?? clientMetadata.revocation_endpoint_auth_methods_supported; const oauthHeaders = serverConfig.oauth_headers ?? {}; + const allowedDomains = getMCPServersRegistry().getAllowedDomains(); if (tokens?.access_token) { try { @@ -385,6 +386,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { revocationEndpointAuthMethodsSupported, }, oauthHeaders, + allowedDomains, ); } catch (error) { logger.error(`Error revoking OAuth access token for ${serverName}:`, error); @@ -405,6 +407,7 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { revocationEndpointAuthMethodsSupported, }, oauthHeaders, + allowedDomains, ); } catch (error) { logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error); diff --git a/packages/api/src/auth/domain.spec.ts b/packages/api/src/auth/domain.spec.ts index a7140528a9..88a7c98160 100644 --- a/packages/api/src/auth/domain.spec.ts +++ b/packages/api/src/auth/domain.spec.ts @@ -8,6 +8,7 @@ import { extractMCPServerDomain, isActionDomainAllowed, isEmailDomainAllowed, + isOAuthUrlAllowed, isMCPDomainAllowed, isPrivateIP, isSSRFTarget, @@ -1211,6 +1212,96 @@ describe('isMCPDomainAllowed', () => { }); }); +describe('isOAuthUrlAllowed', () => { + it('should return false when allowedDomains is null/undefined/empty', () => { + expect(isOAuthUrlAllowed('https://example.com/token', null)).toBe(false); + expect(isOAuthUrlAllowed('https://example.com/token', undefined)).toBe(false); + expect(isOAuthUrlAllowed('https://example.com/token', [])).toBe(false); + }); + + it('should return false for unparseable URLs', () => { + expect(isOAuthUrlAllowed('not-a-url', ['example.com'])).toBe(false); + }); + + it('should match exact hostnames', () => { + expect(isOAuthUrlAllowed('https://example.com/token', ['example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://other.com/token', ['example.com'])).toBe(false); + }); + + it('should match wildcard subdomains', () => { + expect(isOAuthUrlAllowed('https://api.example.com/token', ['*.example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://deep.nested.example.com/token', ['*.example.com'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('https://example.com/token', ['*.example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://other.com/token', ['*.example.com'])).toBe(false); + }); + + it('should be case-insensitive', () => { + expect(isOAuthUrlAllowed('https://EXAMPLE.COM/token', ['example.com'])).toBe(true); + expect(isOAuthUrlAllowed('https://example.com/token', ['EXAMPLE.COM'])).toBe(true); + }); + + it('should match private/internal URLs when hostname is in allowedDomains', () => { + expect(isOAuthUrlAllowed('http://localhost:8080/token', ['localhost'])).toBe(true); + expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['10.0.0.1'])).toBe(true); + expect( + isOAuthUrlAllowed('http://host.docker.internal:8044/token', ['host.docker.internal']), + ).toBe(true); + expect(isOAuthUrlAllowed('http://myserver.local/token', ['*.local'])).toBe(true); + }); + + it('should match internal URLs with wildcard patterns', () => { + expect(isOAuthUrlAllowed('https://auth.company.internal/token', ['*.company.internal'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('https://company.internal/token', ['*.company.internal'])).toBe(true); + }); + + it('should not match when hostname is absent from allowedDomains', () => { + expect(isOAuthUrlAllowed('http://10.0.0.1/token', ['192.168.1.1'])).toBe(false); + expect(isOAuthUrlAllowed('http://localhost/token', ['host.docker.internal'])).toBe(false); + }); + + describe('protocol and port constraint enforcement', () => { + it('should enforce protocol when allowedDomains specifies one', () => { + expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal'])).toBe( + true, + ); + expect(isOAuthUrlAllowed('http://auth.internal/token', ['https://auth.internal'])).toBe( + false, + ); + }); + + it('should allow any protocol when allowedDomains has bare hostname', () => { + expect(isOAuthUrlAllowed('http://auth.internal/token', ['auth.internal'])).toBe(true); + expect(isOAuthUrlAllowed('https://auth.internal/token', ['auth.internal'])).toBe(true); + }); + + it('should enforce port when allowedDomains specifies one', () => { + expect( + isOAuthUrlAllowed('https://auth.internal:8443/token', ['https://auth.internal:8443']), + ).toBe(true); + expect( + isOAuthUrlAllowed('https://auth.internal:6379/token', ['https://auth.internal:8443']), + ).toBe(false); + expect(isOAuthUrlAllowed('https://auth.internal/token', ['https://auth.internal:8443'])).toBe( + false, + ); + }); + + it('should allow any port when allowedDomains has no explicit port', () => { + expect(isOAuthUrlAllowed('https://auth.internal:8443/token', ['auth.internal'])).toBe(true); + expect(isOAuthUrlAllowed('https://auth.internal:22/token', ['auth.internal'])).toBe(true); + }); + + it('should reject wrong port even when hostname matches (prevents port-scanning)', () => { + expect(isOAuthUrlAllowed('http://10.0.0.1:6379/token', ['http://10.0.0.1:8080'])).toBe(false); + expect(isOAuthUrlAllowed('http://10.0.0.1:25/token', ['http://10.0.0.1:8080'])).toBe(false); + }); + }); +}); + describe('validateEndpointURL', () => { afterEach(() => { jest.clearAllMocks(); diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index fabe2502ff..f4f9f5f04e 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -500,6 +500,52 @@ export async function isMCPDomainAllowed( return isDomainAllowedCore(domain, allowedDomains, MCP_PROTOCOLS); } +/** + * Checks whether an OAuth URL matches any entry in the MCP allowedDomains list, + * honoring protocol and port constraints when specified by the admin. + * + * Mirrors the allowlist-matching logic of {@link isDomainAllowedCore} (hostname, + * protocol, and explicit-port checks) but is synchronous — no DNS resolution is + * needed because the caller is deciding whether to *skip* the subsequent + * SSRF/DNS checks, not replace them. + * + * @remarks `parseDomainSpec` normalizes `www.` prefixes, so both the input URL + * and allowedDomains entries starting with `www.` are matched without that prefix. + */ +export function isOAuthUrlAllowed(url: string, allowedDomains?: string[] | null): boolean { + if (!Array.isArray(allowedDomains) || allowedDomains.length === 0) { + return false; + } + + const inputSpec = parseDomainSpec(url); + if (!inputSpec) { + return false; + } + + for (const allowedDomain of allowedDomains) { + const allowedSpec = parseDomainSpec(allowedDomain); + if (!allowedSpec) { + continue; + } + if (!hostnameMatches(inputSpec.hostname, allowedSpec)) { + continue; + } + if (allowedSpec.protocol !== null) { + if (inputSpec.protocol === null || inputSpec.protocol !== allowedSpec.protocol) { + continue; + } + } + if (allowedSpec.explicitPort) { + if (!inputSpec.explicitPort || inputSpec.port !== allowedSpec.port) { + continue; + } + } + return true; + } + + return false; +} + /** Matches ErrorTypes.INVALID_BASE_URL — string literal avoids build-time dependency on data-provider */ const INVALID_BASE_URL_TYPE = 'invalid_base_url'; diff --git a/packages/api/src/mcp/ConnectionsRepository.ts b/packages/api/src/mcp/ConnectionsRepository.ts index 970e7ea4b9..6313faa8d4 100644 --- a/packages/api/src/mcp/ConnectionsRepository.ts +++ b/packages/api/src/mcp/ConnectionsRepository.ts @@ -77,12 +77,14 @@ export class ConnectionsRepository { await this.disconnect(serverName); } } + const registry = MCPServersRegistry.getInstance(); const connection = await MCPConnectionFactory.create( { serverName, serverConfig, dbSourced: !!(serverConfig as t.ParsedServerConfig).dbId, - useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), }, this.oauthOpts, ); diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 0fc86e0315..b5b3d61bf0 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -30,6 +30,7 @@ export class MCPConnectionFactory { protected readonly logPrefix: string; protected readonly useOAuth: boolean; protected readonly useSSRFProtection: boolean; + protected readonly allowedDomains?: string[] | null; // OAuth-related properties (only set when useOAuth is true) protected readonly userId?: string; @@ -197,6 +198,7 @@ export class MCPConnectionFactory { this.serverName = basic.serverName; this.useOAuth = !!oauth?.useOAuth; this.useSSRFProtection = basic.useSSRFProtection === true; + this.allowedDomains = basic.allowedDomains; this.connectionTimeout = oauth?.connectionTimeout; this.logPrefix = oauth?.user ? `[MCP][${basic.serverName}][${oauth.user.id}]` @@ -297,6 +299,7 @@ export class MCPConnectionFactory { }, this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, + this.allowedDomains, ); }; } @@ -340,6 +343,7 @@ export class MCPConnectionFactory { this.userId!, config?.oauth_headers ?? {}, config?.oauth, + this.allowedDomains, ); if (existingFlow) { @@ -603,6 +607,7 @@ export class MCPConnectionFactory { this.userId!, this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, + this.allowedDomains, ); // Store flow state BEFORE redirecting so the callback can find it diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 6fdf45c27a..afb6c68796 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -100,13 +100,16 @@ export class MCPManager extends UserConnectionManager { const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata); - const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection(); + const registry = MCPServersRegistry.getInstance(); + const useSSRFProtection = registry.shouldEnableSSRFProtection(); + const allowedDomains = registry.getAllowedDomains(); const dbSourced = !!serverConfig.dbId; const basic: t.BasicConnectionOptions = { dbSourced, serverName, serverConfig, useSSRFProtection, + allowedDomains, }; if (!useOAuth) { diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 76523fc0fc..2e9d5be467 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -153,12 +153,14 @@ export abstract class UserConnectionManager { logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); try { + const registry = MCPServersRegistry.getInstance(); connection = await MCPConnectionFactory.create( { serverConfig: config, serverName: serverName, dbSourced: !!config.dbId, - useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), }, { useOAuth: true, diff --git a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts index 98e15eca18..7a93960765 100644 --- a/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts +++ b/packages/api/src/mcp/__tests__/ConnectionsRepository.test.ts @@ -25,6 +25,7 @@ const mockRegistryInstance = { getServerConfig: jest.fn(), getAllServerConfigs: jest.fn(), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }; jest.mock('../registry/MCPServersRegistry', () => ({ @@ -110,6 +111,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: mockServerConfigs.server1, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -133,6 +135,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: mockServerConfigs.server1, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, @@ -173,6 +176,7 @@ describe('ConnectionsRepository', () => { serverName: 'server1', serverConfig: configWithCachedAt, useSSRFProtection: false, + allowedDomains: null, dbSourced: false, }, undefined, diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index bceb23b246..23bfa89d56 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -269,6 +269,7 @@ describe('MCPConnectionFactory', () => { 'user123', {}, undefined, + undefined, ); // initFlow must be awaited BEFORE the redirect to guarantee state is stored diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index bf63a6af3c..dd1ead0dd9 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -34,6 +34,7 @@ const mockRegistryInstance = { getAllServerConfigs: jest.fn(), getOAuthServers: jest.fn(), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }; jest.mock('~/mcp/registry/MCPServersRegistry', () => ({ diff --git a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts index 85febb3ece..cb6187ab45 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts @@ -82,6 +82,7 @@ describe('MCP OAuth Race Condition Fixes', () => { .mockReturnValue({ getServerConfig: jest.fn().mockResolvedValue(mockConfig), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }); const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); @@ -147,6 +148,7 @@ describe('MCP OAuth Race Condition Fixes', () => { .mockReturnValue({ getServerConfig: jest.fn().mockResolvedValue(mockConfig), shouldEnableSSRFProtection: jest.fn().mockReturnValue(false), + getAllowedDomains: jest.fn().mockReturnValue(null), }); const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory'); diff --git a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts index a5188e24b0..a2d0440d42 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthSecurity.test.ts @@ -7,6 +7,9 @@ * * 2. redirect_uri manipulation — validates that user-supplied redirect_uri * is ignored in favor of the server-controlled default. + * + * 3. allowedDomains SSRF exemption — validates that admin-configured allowedDomains + * exempts trusted domains from SSRF checks, including auto-discovery paths. */ import * as http from 'http'; @@ -226,3 +229,214 @@ describe('MCP OAuth redirect_uri enforcement', () => { expect(authUrl.searchParams.get('redirect_uri')).not.toBe(attackerRedirectUri); }); }); + +describe('MCP OAuth allowedDomains SSRF exemption for admin-trusted hosts', () => { + it('should allow private authorization_url when hostname is in allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'internal-server', + 'https://speedy-mcp.company.com/', + 'user-1', + {}, + { + authorization_url: 'http://10.0.0.1/authorize', + token_url: 'http://10.0.0.1/token', + client_id: 'client', + client_secret: 'secret', + }, + ['10.0.0.1'], + ); + + expect(result.authorizationUrl).toContain('10.0.0.1/authorize'); + }); + + it('should allow private token_url when hostname matches wildcard allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'internal-server', + 'https://speedy-mcp.company.com/', + 'user-1', + {}, + { + authorization_url: 'https://auth.company.internal/authorize', + token_url: 'https://auth.company.internal/token', + client_id: 'client', + client_secret: 'secret', + }, + ['*.company.internal'], + ); + + expect(result.authorizationUrl).toContain('auth.company.internal/authorize'); + }); + + it('should still reject private URLs when allowedDomains does not match', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://169.254.169.254/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + ['safe.example.com'], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should still reject when allowedDomains is empty', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'https://mcp.example.com/', + 'user-1', + {}, + { + authorization_url: 'http://10.0.0.1/authorize', + token_url: 'https://auth.example.com/token', + client_id: 'client', + client_secret: 'secret', + }, + [], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should allow private revocationEndpoint when hostname is in allowedDomains', async () => { + const mockFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + } as Response); + const originalFetch = global.fetch; + global.fetch = mockFetch; + + try { + await MCPOAuthHandler.revokeOAuthToken( + 'internal-server', + 'some-token', + 'access', + { + serverUrl: 'https://internal.corp.net/', + clientId: 'client', + clientSecret: 'secret', + revocationEndpoint: 'http://10.0.0.1/revoke', + }, + {}, + ['10.0.0.1'], + ); + + expect(mockFetch).toHaveBeenCalled(); + } finally { + global.fetch = originalFetch; + } + }); + + it('should allow localhost token_url in refreshOAuthTokens when localhost is in allowedDomains', async () => { + const mockFetch = jest.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + }), + } as Response); + const originalFetch = global.fetch; + global.fetch = mockFetch; + + try { + const tokens = await MCPOAuthHandler.refreshOAuthTokens( + 'old-refresh-token', + { + serverName: 'local-server', + serverUrl: 'http://localhost:8080/', + clientInfo: { + client_id: 'client-id', + client_secret: 'client-secret', + redirect_uris: ['http://localhost:3080/callback'], + }, + }, + {}, + { + token_url: 'http://localhost:8080/token', + client_id: 'client-id', + client_secret: 'client-secret', + }, + ['localhost'], + ); + + expect(tokens.access_token).toBe('new-access-token'); + expect(mockFetch).toHaveBeenCalled(); + } finally { + global.fetch = originalFetch; + } + }); + + describe('auto-discovery path with allowedDomains', () => { + let discoveryServer: OAuthTestServer; + + beforeEach(async () => { + discoveryServer = await createOAuthMCPServer({ + tokenTTLMs: 60000, + issueRefreshTokens: true, + }); + }); + + afterEach(async () => { + await discoveryServer.close(); + }); + + it('should allow auto-discovered OAuth endpoints when server IP is in allowedDomains', async () => { + const result = await MCPOAuthHandler.initiateOAuthFlow( + 'discovery-server', + discoveryServer.url, + 'user-1', + {}, + undefined, + ['127.0.0.1'], + ); + + expect(result.authorizationUrl).toContain('127.0.0.1'); + expect(result.flowId).toBeTruthy(); + }); + + it('should reject auto-discovered endpoints when allowedDomains does not cover server IP', async () => { + await expect( + MCPOAuthHandler.initiateOAuthFlow( + 'discovery-server', + discoveryServer.url, + 'user-1', + {}, + undefined, + ['safe.example.com'], + ), + ).rejects.toThrow(/targets a blocked address/); + }); + + it('should allow auto-discovered token_url in refreshOAuthTokens branch 3 (no clientInfo/config)', async () => { + const code = await discoveryServer.getAuthCode(); + const tokenRes = await fetch(`${discoveryServer.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}`, + }); + const initial = (await tokenRes.json()) as { + access_token: string; + refresh_token: string; + }; + + const tokens = await MCPOAuthHandler.refreshOAuthTokens( + initial.refresh_token, + { + serverName: 'discovery-refresh-server', + serverUrl: discoveryServer.url, + }, + {}, + undefined, + ['127.0.0.1'], + ); + + expect(tokens.access_token).toBeTruthy(); + }); + }); +}); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index 8d863bfe79..0a9154ff35 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -24,7 +24,7 @@ import { selectRegistrationAuthMethod, inferClientAuthMethod, } from './methods'; -import { isSSRFTarget, resolveHostnameSSRF } from '~/auth'; +import { isSSRFTarget, resolveHostnameSSRF, isOAuthUrlAllowed } from '~/auth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; /** Type for the OAuth metadata from the SDK */ @@ -123,6 +123,7 @@ export class MCPOAuthHandler { private static async discoverMetadata( serverUrl: string, oauthHeaders: Record, + allowedDomains?: string[] | null, ): Promise<{ metadata: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; @@ -146,7 +147,7 @@ export class MCPOAuthHandler { if (resourceMetadata?.authorization_servers?.length) { const discoveredAuthServer = resourceMetadata.authorization_servers[0]; - await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server'); + await this.validateOAuthUrl(discoveredAuthServer, 'authorization_server', allowedDomains); authServerUrl = new URL(discoveredAuthServer); logger.debug( `[MCPOAuth] Found authorization server from resource metadata: ${authServerUrl}`, @@ -206,11 +207,17 @@ export class MCPOAuthHandler { const endpointChecks: Promise[] = []; if (metadata.registration_endpoint) { endpointChecks.push( - this.validateOAuthUrl(metadata.registration_endpoint, 'registration_endpoint'), + this.validateOAuthUrl( + metadata.registration_endpoint, + 'registration_endpoint', + allowedDomains, + ), ); } if (metadata.token_endpoint) { - endpointChecks.push(this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint')); + endpointChecks.push( + this.validateOAuthUrl(metadata.token_endpoint, 'token_endpoint', allowedDomains), + ); } if (endpointChecks.length > 0) { await Promise.all(endpointChecks); @@ -360,6 +367,7 @@ export class MCPOAuthHandler { userId: string, oauthHeaders: Record, config?: MCPOptions['oauth'], + allowedDomains?: string[] | null, ): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> { logger.debug( `[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`, @@ -375,8 +383,8 @@ export class MCPOAuthHandler { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for ${serverName}`); await Promise.all([ - this.validateOAuthUrl(config.authorization_url, 'authorization_url'), - this.validateOAuthUrl(config.token_url, 'token_url'), + this.validateOAuthUrl(config.authorization_url, 'authorization_url', allowedDomains), + this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains), ]); const skipCodeChallengeCheck = @@ -477,6 +485,7 @@ export class MCPOAuthHandler { const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata( serverUrl, oauthHeaders, + allowedDomains, ); logger.debug( @@ -588,7 +597,11 @@ export class MCPOAuthHandler { } /** - * Completes the OAuth flow by exchanging the authorization code for tokens + * Completes the OAuth flow by exchanging the authorization code for tokens. + * + * `allowedDomains` is intentionally absent: all URLs used here (serverUrl, + * token_endpoint) originate from {@link MCPOAuthFlowMetadata} that was + * SSRF-validated during {@link initiateOAuthFlow}. No new URL resolution occurs. */ static async completeOAuthFlow( flowId: string, @@ -692,8 +705,20 @@ export class MCPOAuthHandler { return randomBytes(32).toString('base64url'); } - /** Validates an OAuth URL is not targeting a private/internal address */ - private static async validateOAuthUrl(url: string, fieldName: string): Promise { + /** + * Validates an OAuth URL is not targeting a private/internal address. + * Skipped when the full URL (hostname + protocol + port) matches an admin-trusted + * allowedDomains entry, honoring protocol/port constraints when the admin specifies them. + */ + private static async validateOAuthUrl( + url: string, + fieldName: string, + allowedDomains?: string[] | null, + ): Promise { + if (isOAuthUrlAllowed(url, allowedDomains)) { + return; + } + let hostname: string; try { hostname = new URL(url).hostname; @@ -799,6 +824,7 @@ export class MCPOAuthHandler { metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation }, oauthHeaders: Record, config?: MCPOptions['oauth'], + allowedDomains?: string[] | null, ): Promise { logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`); @@ -824,7 +850,7 @@ export class MCPOAuthHandler { let tokenUrl: string; let authMethods: string[] | undefined; if (config?.token_url) { - await this.validateOAuthUrl(config.token_url, 'token_url'); + await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains); tokenUrl = config.token_url; authMethods = config.token_endpoint_auth_methods_supported; } else if (!metadata.serverUrl) { @@ -851,7 +877,7 @@ export class MCPOAuthHandler { tokenUrl = oauthMetadata.token_endpoint; authMethods = oauthMetadata.token_endpoint_auth_methods_supported; } - await this.validateOAuthUrl(tokenUrl, 'token_url'); + await this.validateOAuthUrl(tokenUrl, 'token_url', allowedDomains); } const body = new URLSearchParams({ @@ -928,7 +954,7 @@ export class MCPOAuthHandler { if (config?.token_url && config?.client_id) { logger.debug(`[MCPOAuth] Using pre-configured OAuth settings for token refresh`); - await this.validateOAuthUrl(config.token_url, 'token_url'); + await this.validateOAuthUrl(config.token_url, 'token_url', allowedDomains); const tokenUrl = new URL(config.token_url); const body = new URLSearchParams({ @@ -1026,7 +1052,7 @@ export class MCPOAuthHandler { } else { tokenUrl = new URL(oauthMetadata.token_endpoint); } - await this.validateOAuthUrl(tokenUrl.href, 'token_url'); + await this.validateOAuthUrl(tokenUrl.href, 'token_url', allowedDomains); const body = new URLSearchParams({ grant_type: 'refresh_token', @@ -1075,9 +1101,14 @@ export class MCPOAuthHandler { revocationEndpointAuthMethodsSupported?: string[]; }, oauthHeaders: Record = {}, + allowedDomains?: string[] | null, ): Promise { if (metadata.revocationEndpoint != null) { - await this.validateOAuthUrl(metadata.revocationEndpoint, 'revocation_endpoint'); + await this.validateOAuthUrl( + metadata.revocationEndpoint, + 'revocation_endpoint', + allowedDomains, + ); } const revokeUrl: URL = metadata.revocationEndpoint != null diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts index a477d9b412..7f31211680 100644 --- a/packages/api/src/mcp/registry/MCPServerInspector.ts +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -20,6 +20,7 @@ export class MCPServerInspector { private readonly config: t.ParsedServerConfig, private connection: MCPConnection | undefined, private readonly useSSRFProtection: boolean = false, + private readonly allowedDomains?: string[] | null, ) {} /** @@ -46,7 +47,13 @@ export class MCPServerInspector { const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0; const start = Date.now(); - const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection); + const inspector = new MCPServerInspector( + serverName, + rawConfig, + connection, + useSSRFProtection, + allowedDomains, + ); await inspector.inspectServer(); inspector.config.initDuration = Date.now() - start; return inspector.config; @@ -68,6 +75,7 @@ export class MCPServerInspector { serverName: this.serverName, dbSourced: !!this.config.dbId, useSSRFProtection: this.useSSRFProtection, + allowedDomains: this.allowedDomains, }); } diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index bbdabb4428..0af10c7399 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -169,6 +169,7 @@ export interface BasicConnectionOptions { serverName: string; serverConfig: MCPOptions; useSSRFProtection?: boolean; + allowedDomains?: string[] | null; /** When true, only resolve customUserVars in processMCPEnv (for DB-stored servers) */ dbSourced?: boolean; } From 8271055c2da8c0eb18d6cb7703525e11289bef59 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sun, 15 Mar 2026 23:51:41 -0400 Subject: [PATCH 39/39] =?UTF-8?q?=F0=9F=93=A6=20chore:=20Bump=20`@librecha?= =?UTF-8?q?t/agents`=20to=20v3.1.56=20(#12258)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 📦 chore: Bump `@librechat/agents` to v3.1.56 * chore: resolve type error, URL property check in isMCPDomainAllowed function --- api/package.json | 2 +- package-lock.json | 11 ++++++----- packages/api/package.json | 2 +- packages/api/src/auth/domain.ts | 4 +++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/api/package.json b/api/package.json index 0305446818..89a5183ddd 100644 --- a/api/package.json +++ b/api/package.json @@ -44,7 +44,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", diff --git a/package-lock.json b/package-lock.json index 502b3a8eed..45f737ad8f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -59,7 +59,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", @@ -12324,9 +12324,9 @@ } }, "node_modules/@librechat/agents": { - "version": "3.1.55", - "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.55.tgz", - "integrity": "sha512-impxeKpCDlPkAVQFWnA6u6xkxDSBR/+H8uYq7rZomBeu0rUh/OhJLiI1fAwPhKXP33udNtHA8GyDi0QJj78R9w==", + "version": "3.1.56", + "resolved": "https://registry.npmjs.org/@librechat/agents/-/agents-3.1.56.tgz", + "integrity": "sha512-HJJwRnLM4XKpTWB4/wPDJR+iegyKBVUwqj7A8QHqzEcHzjKJDTr3wBPxZVH1tagGr6/mbbnErOJ14cH1OSNmpA==", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "^0.73.0", @@ -12347,6 +12347,7 @@ "@langfuse/tracing": "^4.3.0", "@opentelemetry/sdk-node": "^0.207.0", "@scarf/scarf": "^1.4.0", + "ai-tokenizer": "^1.0.6", "axios": "^1.13.5", "cheerio": "^1.0.0", "dotenv": "^16.4.7", @@ -44239,7 +44240,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/package.json b/packages/api/package.json index 77258fc0b3..b3b40c79a2 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -90,7 +90,7 @@ "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.1.55", + "@librechat/agents": "^3.1.56", "@librechat/data-schemas": "*", "@modelcontextprotocol/sdk": "^1.27.1", "@smithy/node-http-handler": "^4.4.5", diff --git a/packages/api/src/auth/domain.ts b/packages/api/src/auth/domain.ts index f4f9f5f04e..f5719829d5 100644 --- a/packages/api/src/auth/domain.ts +++ b/packages/api/src/auth/domain.ts @@ -485,7 +485,9 @@ export async function isMCPDomainAllowed( const hasAllowlist = Array.isArray(allowedDomains) && allowedDomains.length > 0; const hasExplicitUrl = - Object.hasOwn(config, 'url') && typeof config.url === 'string' && config.url.trim().length > 0; + Object.prototype.hasOwnProperty.call(config, 'url') && + typeof config.url === 'string' && + config.url.trim().length > 0; if (!domain && hasExplicitUrl && hasAllowlist) { return false;