Merge branch 'main' into claude/fix-mcp-accent-support-UBEjT

This commit is contained in:
Lionel Ringenbach 2026-03-11 11:27:32 -07:00 committed by GitHub
commit e3671a6835
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 5749 additions and 437 deletions

View file

@ -850,3 +850,24 @@ OPENWEATHER_API_KEY=
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
# MCP_CB_MAX_CYCLES=7
# Circuit breaker: sliding window (ms) for counting cycles
# MCP_CB_CYCLE_WINDOW_MS=45000
# Circuit breaker: cooldown (ms) after the cycle breaker trips
# MCP_CB_CYCLE_COOLDOWN_MS=15000
# Circuit breaker: max consecutive failed connection rounds before backoff
# MCP_CB_MAX_FAILED_ROUNDS=3
# Circuit breaker: sliding window (ms) for counting failed rounds
# MCP_CB_FAILED_WINDOW_MS=120000
# Circuit breaker: base backoff (ms) after failed round threshold is reached
# MCP_CB_BASE_BACKOFF_MS=30000
# Circuit breaker: max backoff cap (ms) for exponential backoff
# MCP_CB_MAX_BACKOFF_MS=300000

View file

@ -149,7 +149,15 @@ Multi-line imports count total character length across all lines. Consolidate va
- Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc.
- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering.
- Cover loading, success, and error states for UI/data flows.
- Mock data-provider hooks and external dependencies.
### Philosophy
- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort.
- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic.
- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls.
- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals.
- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls.
- Heavy mocking is a code smell, not a testing strategy.
---

View file

@ -1,7 +1,6 @@
const DALLE3 = require('../DALLE3');
const { ProxyAgent } = require('undici');
jest.mock('tiktoken');
const processFileURL = jest.fn();
describe('DALLE3 Proxy Configuration', () => {

View file

@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => {
};
});
jest.mock('tiktoken', () => {
return {
encoding_for_model: jest.fn().mockReturnValue({
encode: jest.fn(),
decode: jest.fn(),
}),
};
});
const processFileURL = jest.fn();
const generate = jest.fn();

View file

@ -51,6 +51,7 @@
"@modelcontextprotocol/sdk": "^1.27.1",
"@node-saml/passport-saml": "^5.1.0",
"@smithy/node-http-handler": "^4.4.5",
"ai-tokenizer": "^1.0.6",
"axios": "^1.13.5",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
@ -106,7 +107,6 @@
"pdfjs-dist": "^5.4.624",
"rate-limit-redis": "^4.2.0",
"sharp": "^0.33.5",
"tiktoken": "^1.0.15",
"traverse": "^0.6.7",
"ua-parser-js": "^1.0.36",
"undici": "^7.18.2",

View file

@ -1172,7 +1172,11 @@ class AgentClient extends BaseClient {
}
}
/** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */
getEncoding() {
if (this.model && this.model.toLowerCase().includes('claude')) {
return 'claude';
}
return 'o200k_base';
}

View file

@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => {
getFlowState: jest.fn(),
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
resolveStateToFlowId: jest.fn(async (state) => state),
storeStateMapping: jest.fn(),
deleteStateMapping: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
@ -180,7 +183,10 @@ describe('MCP Routes', () => {
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth',
flowId: 'test-user-id:test-server',
flowMetadata: { state: 'random-state-value' },
});
MCPOAuthHandler.storeStateMapping.mockResolvedValue();
mockFlowManager.initFlow = jest.fn().mockResolvedValue();
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
userId: 'test-user-id',
@ -367,6 +373,121 @@ describe('MCP Routes', () => {
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
});
describe('CSRF fallback via active PENDING flow', () => {
it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'PENDING',
createdAt: Date.now(),
}),
completeFlow: jest.fn().mockResolvedValue(true),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: {},
clientInfo: {},
codeVerifier: 'test-verifier',
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({
access_token: 'test-token',
});
MCPTokenStorage.storeTokens.mockResolvedValue();
mockRegistryInstance.getServerConfig.mockResolvedValue({});
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getOAuthReconnectionManager.mockReturnValue({
clearReconnection: jest.fn(),
});
require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue();
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
});
it('should reject when no PENDING flow exists and no cookies are present', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue(null),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'COMPLETED',
createdAt: Date.now(),
}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
const flowId = 'test-user-id:test-server';
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
status: 'PENDING',
createdAt: Date.now() - 3 * 60 * 1000,
}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const response = await request(app)
.get('/api/mcp/test-server/oauth/callback')
.query({ code: 'test-code', state: flowId });
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe(
`${basePath}/oauth/error?error=csrf_validation_failed`,
);
});
});
it('should handle OAuth callback successfully', async () => {
// mockRegistryInstance is defined at the top of the file
const mockFlowManager = {

View file

@ -13,6 +13,7 @@ const {
MCPOAuthHandler,
MCPTokenStorage,
setOAuthSession,
PENDING_STALE_MS,
getUserMCPAuthMap,
validateOAuthCsrf,
OAUTH_CSRF_COOKIE,
@ -91,7 +92,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
}
const oauthHeaders = await getOAuthHeaders(serverName, userId);
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
const {
authorizationUrl,
flowId: oauthFlowId,
flowMetadata,
} = await MCPOAuthHandler.initiateOAuthFlow(
serverName,
serverUrl,
userId,
@ -101,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager);
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
res.redirect(authorizationUrl);
} catch (error) {
@ -143,30 +149,52 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
}
const flowId = state;
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager);
if (!flowId) {
logger.error('[MCP OAuth] Could not resolve state to flow ID', { state });
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId });
const flowParts = flowId.split(':');
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
logger.error('[MCP OAuth] Invalid flow ID format', { flowId });
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
const [flowUserId] = flowParts;
if (
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
!validateOAuthSession(req, flowUserId)
) {
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
flowId,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
});
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH);
const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId);
let hasActiveFlow = false;
if (!hasCsrf && !hasSession) {
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS;
if (hasActiveFlow) {
logger.debug(
'[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow',
{
flowId,
},
);
}
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
if (!hasCsrf && !hasSession && !hasActiveFlow) {
logger.error(
'[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow',
{
flowId,
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
},
);
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
}
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
@ -281,7 +309,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const toolFlowId = flowState.metadata?.toolFlowId;
if (toolFlowId) {
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
if (!completed) {
logger.warn(
'[MCP OAuth] Tool flow state not found during completion — waiter will time out',
{ toolFlowId },
);
}
}
/** Redirect to success page with flowId and serverName */

View file

@ -34,6 +34,55 @@ const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache');
const MAX_CACHE_SIZE = 1000;
const lastReconnectAttempts = new Map();
const RECONNECT_THROTTLE_MS = 10_000;
const missingToolCache = new Map();
const MISSING_TOOL_TTL_MS = 10_000;
function evictStale(map, ttl) {
if (map.size <= MAX_CACHE_SIZE) {
return;
}
const now = Date.now();
for (const [key, timestamp] of map) {
if (now - timestamp >= ttl) {
map.delete(key);
}
if (map.size <= MAX_CACHE_SIZE) {
return;
}
}
}
const unavailableMsg =
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
/**
* @param {string} toolName
* @param {string} serverName
*/
function createUnavailableToolStub(toolName, serverName) {
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
const _call = async () => [unavailableMsg, null];
const toolInstance = tool(_call, {
schema: {
type: 'object',
properties: {
input: { type: 'string', description: 'Input for the tool' },
},
required: [],
},
name: normalizedToolKey,
description: unavailableMsg,
responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
});
toolInstance.mcp = true;
toolInstance.mcpRawServerName = serverName;
return toolInstance;
}
function isEmptyObjectSchema(jsonSchema) {
return (
jsonSchema != null &&
@ -211,6 +260,17 @@ async function reconnectServer({
logger.debug(
`[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
);
const throttleKey = `${user.id}:${serverName}`;
const now = Date.now();
const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0;
if (now - lastAttempt < RECONNECT_THROTTLE_MS) {
logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`);
return null;
}
lastReconnectAttempts.set(throttleKey, now);
evictStale(lastReconnectAttempts, RECONNECT_THROTTLE_MS);
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${user.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
@ -267,7 +327,7 @@ async function reconnectServer({
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
connectionTimeout: Time.THIRTY_SECONDS,
});
} finally {
// Clean up abort handler to prevent memory leaks
@ -330,9 +390,13 @@ async function createMCPTools({
userMCPAuthMap,
streamId,
});
if (result === null) {
logger.debug(`[MCP][${serverName}] Reconnect throttled, skipping tool creation.`);
return [];
}
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
return [];
}
const serverTools = [];
@ -402,6 +466,14 @@ async function createMCPTool({
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
const cachedAt = missingToolCache.get(toolKey);
if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) {
logger.debug(
`[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`,
);
return createUnavailableToolStub(toolName, serverName);
}
logger.warn(
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
);
@ -415,11 +487,18 @@ async function createMCPTool({
streamId,
});
toolDefinition = result?.availableTools?.[toolKey]?.function;
if (!toolDefinition) {
missingToolCache.set(toolKey, Date.now());
evictStale(missingToolCache, MISSING_TOOL_TTL_MS);
}
}
if (!toolDefinition) {
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
return;
logger.warn(
`[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`,
);
return createUnavailableToolStub(toolName, serverName);
}
return createToolInstance({
@ -720,4 +799,5 @@ module.exports = {
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
createUnavailableToolStub,
};

View file

@ -45,6 +45,7 @@ const {
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
createUnavailableToolStub,
} = require('./MCP');
jest.mock('./Config', () => ({
@ -1098,6 +1099,188 @@ describe('User parameter passing tests', () => {
});
});
describe('createUnavailableToolStub', () => {
it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => {
const stub = createUnavailableToolStub('myTool', 'myServer');
// invoke() goes through langchain's base tool, which checks responseFormat.
// CONTENT_AND_ARTIFACT requires [content, artifact] — a bare string would throw:
// "Tool response format is "content_and_artifact" but the output was not a two-tuple"
const result = await stub.invoke({});
// If we reach here without throwing, the two-tuple format is correct.
// invoke() returns the content portion of [content, artifact] as a string.
expect(result).toContain('temporarily unavailable');
});
});
describe('negative tool cache and throttle interaction', () => {
it('should cache tool as missing even when throttled (cross-user dedup)', async () => {
const mockUser = { id: 'throttle-test-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// First call: reconnect succeeds but tool not found
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: `missing-tool${D}cache-dedup-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
// Second call within 10s for DIFFERENT tool on same server:
// reconnect is throttled (returns null), tool is still cached as missing.
// This is intentional: the cache acts as cross-user dedup since the
// throttle is per-user-per-server and can't prevent N different users
// from each triggering their own reconnect.
const result2 = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: `other-tool${D}cache-dedup-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result2).toBeDefined();
expect(result2.name).toContain('other-tool');
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
it('should prevent user B from triggering reconnect when user A already cached the tool', async () => {
const userA = { id: 'cache-user-A' };
const userB = { id: 'cache-user-B' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// User A: real reconnect, tool not found → cached
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `shared-tool${D}cross-user-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// User B requests the SAME tool within 10s.
// The negative cache is keyed by toolKey (no user prefix), so user B
// gets a cache hit and no reconnect fires. This is the cross-user
// storm protection: without this, user B's unthrottled first request
// would trigger a second reconnect to the same server.
const result = await createMCPTool({
res: mockRes,
user: userB,
toolKey: `shared-tool${D}cross-user-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result).toBeDefined();
expect(result.name).toContain('shared-tool');
// reinitMCPServer still called only once — user B hit the cache
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
it('should prevent user B from triggering reconnect for throttle-cached tools', async () => {
const userA = { id: 'storm-user-A' };
const userB = { id: 'storm-user-B' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// User A: real reconnect for tool-1, tool not found → cached
mockReinitMCPServer.mockResolvedValueOnce({
availableTools: {},
});
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `tool-1${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
// User A: tool-2 on same server within 10s → throttled → cached from throttle
await createMCPTool({
res: mockRes,
user: userA,
toolKey: `tool-2${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// User B requests tool-2 — gets cache hit from the throttle-cached entry.
// Without this caching, user B would trigger a real reconnect since
// user B has their own throttle key and hasn't reconnected yet.
const result = await createMCPTool({
res: mockRes,
user: userB,
toolKey: `tool-2${D}storm-server`,
provider: 'openai',
userMCPAuthMap: {},
availableTools: undefined,
});
expect(result).toBeDefined();
expect(result.name).toContain('tool-2');
// Still only 1 real reconnect — user B was protected by the cache
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
});
});
describe('createMCPTools throttle handling', () => {
it('should return empty array with debug log when reconnect is throttled', async () => {
const mockUser = { id: 'throttle-tools-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// First call: real reconnect
mockReinitMCPServer.mockResolvedValueOnce({
tools: [{ name: 'tool1' }],
availableTools: {
[`tool1${D}throttle-tools-server`]: {
function: { description: 'Tool 1', parameters: {} },
},
},
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'throttle-tools-server',
provider: 'openai',
userMCPAuthMap: {},
});
// Second call within 10s — throttled
const result = await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'throttle-tools-server',
provider: 'openai',
userMCPAuthMap: {},
});
expect(result).toEqual([]);
// reinitMCPServer called only once — second was throttled
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
// Should log at debug level (not warn) for throttled case
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled'));
});
});
describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => {
const complexUser = {

View file

@ -1,5 +1,4 @@
// --- Mocks ---
jest.mock('tiktoken');
jest.mock('fs');
jest.mock('path');
jest.mock('node-fetch');

23
package-lock.json generated
View file

@ -66,6 +66,7 @@
"@modelcontextprotocol/sdk": "^1.27.1",
"@node-saml/passport-saml": "^5.1.0",
"@smithy/node-http-handler": "^4.4.5",
"ai-tokenizer": "^1.0.6",
"axios": "^1.13.5",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
@ -121,7 +122,6 @@
"pdfjs-dist": "^5.4.624",
"rate-limit-redis": "^4.2.0",
"sharp": "^0.33.5",
"tiktoken": "^1.0.15",
"traverse": "^0.6.7",
"ua-parser-js": "^1.0.36",
"undici": "^7.18.2",
@ -22230,6 +22230,20 @@
"node": ">= 14"
}
},
"node_modules/ai-tokenizer": {
"version": "1.0.6",
"resolved": "https://registry.npmjs.org/ai-tokenizer/-/ai-tokenizer-1.0.6.tgz",
"integrity": "sha512-GaakQFxen0pRH/HIA4v68ZM40llCH27HUYUSBLK+gVuZ57e53pYJe1xFvSTj4sJJjbWU92m1X6NjPWyeWkFDow==",
"license": "MIT",
"peerDependencies": {
"ai": "^5.0.0"
},
"peerDependenciesMeta": {
"ai": {
"optional": true
}
}
},
"node_modules/ajv": {
"version": "8.18.0",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
@ -41485,11 +41499,6 @@
"node": ">=0.8"
}
},
"node_modules/tiktoken": {
"version": "1.0.15",
"resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.15.tgz",
"integrity": "sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw=="
},
"node_modules/timers-browserify": {
"version": "2.0.12",
"resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz",
@ -44200,6 +44209,7 @@
"@librechat/data-schemas": "*",
"@modelcontextprotocol/sdk": "^1.27.1",
"@smithy/node-http-handler": "^4.4.5",
"ai-tokenizer": "^1.0.6",
"axios": "^1.13.5",
"connect-redis": "^8.1.0",
"eventsource": "^3.0.2",
@ -44222,7 +44232,6 @@
"node-fetch": "2.7.0",
"pdfjs-dist": "^5.4.624",
"rate-limit-redis": "^4.2.0",
"tiktoken": "^1.0.15",
"undici": "^7.18.2",
"zod": "^3.22.4"
}

View file

@ -7,6 +7,7 @@ export default {
'\\.dev\\.ts$',
'\\.helper\\.ts$',
'\\.helper\\.d\\.ts$',
'/__tests__/helpers/',
],
coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit',

View file

@ -18,8 +18,8 @@
"build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
"build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
"build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
"test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
"test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
@ -94,6 +94,7 @@
"@librechat/data-schemas": "*",
"@modelcontextprotocol/sdk": "^1.27.1",
"@smithy/node-http-handler": "^4.4.5",
"ai-tokenizer": "^1.0.6",
"axios": "^1.13.5",
"connect-redis": "^8.1.0",
"eventsource": "^3.0.2",
@ -116,7 +117,6 @@
"node-fetch": "2.7.0",
"pdfjs-dist": "^5.4.624",
"rate-limit-redis": "^4.2.0",
"tiktoken": "^1.0.15",
"undici": "^7.18.2",
"zod": "^3.22.4"
}

View file

@ -22,8 +22,9 @@ jest.mock('winston', () => ({
}));
// Mock the Tokenizer
jest.mock('~/utils', () => ({
Tokenizer: {
jest.mock('~/utils/tokenizer', () => ({
__esModule: true,
default: {
getTokenCount: jest.fn((text: string) => text.length), // Simple mock: 1 char = 1 token
},
}));

View file

@ -19,7 +19,8 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
import type { Response as ServerResponse } from 'express';
import { GenerationJobManager } from '~/stream/GenerationJobManager';
import { Tokenizer, resolveHeaders, createSafeUser } from '~/utils';
import { resolveHeaders, createSafeUser } from '~/utils';
import Tokenizer from '~/utils/tokenizer';
type RequiredMemoryMethods = Pick<
MemoryMethods,

View file

@ -3,6 +3,18 @@ import { logger } from '@librechat/data-schemas';
import type { StoredDataNoRaw } from 'keyv';
import type { FlowState, FlowMetadata, FlowManagerOptions } from './types';
export const PENDING_STALE_MS = 2 * 60 * 1000;
const SECONDS_THRESHOLD = 1e10;
/**
* Normalizes an expiration timestamp to milliseconds.
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
*/
export function normalizeExpiresAt(timestamp: number): number {
return timestamp < SECONDS_THRESHOLD ? timestamp * 1000 : timestamp;
}
export class FlowStateManager<T = unknown> {
private keyv: Keyv;
private ttl: number;
@ -45,32 +57,8 @@ export class FlowStateManager<T = unknown> {
return `${type}:${flowId}`;
}
/**
* Normalizes an expiration timestamp to milliseconds.
* Detects whether the input is in seconds or milliseconds based on magnitude.
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
* @param timestamp - The expiration timestamp (in seconds or milliseconds)
* @returns The timestamp normalized to milliseconds
*/
private normalizeExpirationTimestamp(timestamp: number): number {
const SECONDS_THRESHOLD = 1e10;
if (timestamp < SECONDS_THRESHOLD) {
return timestamp * 1000;
}
return timestamp;
}
/**
* Checks if a flow's token has expired based on its expires_at field
* @param flowState - The flow state to check
* @returns true if the token has expired, false otherwise (including if no expires_at exists)
*/
private isTokenExpired(flowState: FlowState<T> | undefined): boolean {
if (!flowState?.result) {
return false;
}
if (typeof flowState.result !== 'object') {
if (!flowState?.result || typeof flowState.result !== 'object') {
return false;
}
@ -79,13 +67,11 @@ export class FlowStateManager<T = unknown> {
}
const expiresAt = (flowState.result as { expires_at: unknown }).expires_at;
if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) {
return false;
}
const normalizedExpiresAt = this.normalizeExpirationTimestamp(expiresAt);
return normalizedExpiresAt < Date.now();
return normalizeExpiresAt(expiresAt) < Date.now();
}
/**
@ -149,6 +135,8 @@ export class FlowStateManager<T = unknown> {
let elapsedTime = 0;
let isCleanedUp = false;
let intervalId: NodeJS.Timeout | null = null;
let missingStateRetried = false;
let isRetrying = false;
// Cleanup function to avoid duplicate cleanup
const cleanup = () => {
@ -188,16 +176,29 @@ export class FlowStateManager<T = unknown> {
}
intervalId = setInterval(async () => {
if (isCleanedUp) return;
if (isCleanedUp || isRetrying) return;
try {
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
let flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (!flowState) {
cleanup();
logger.error(`[${flowKey}] Flow state not found`);
reject(new Error(`${type} Flow state not found`));
return;
if (!missingStateRetried) {
missingStateRetried = true;
isRetrying = true;
logger.warn(
`[${flowKey}] Flow state not found, retrying once after 500ms (race recovery)`,
);
await new Promise((r) => setTimeout(r, 500));
flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
isRetrying = false;
}
if (!flowState) {
cleanup();
logger.error(`[${flowKey}] Flow state not found after retry`);
reject(new Error(`${type} Flow state not found`));
return;
}
}
if (signal?.aborted) {
@ -251,10 +252,10 @@ export class FlowStateManager<T = unknown> {
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (!flowState) {
logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', {
flowId,
type,
});
logger.warn(
'[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping',
{ flowId, type },
);
return false;
}
@ -297,7 +298,7 @@ export class FlowStateManager<T = unknown> {
async isFlowStale(
flowId: string,
type: string,
staleThresholdMs: number = 2 * 60 * 1000,
staleThresholdMs: number = PENDING_STALE_MS,
): Promise<{ isStale: boolean; age: number; status?: string }> {
const flowKey = this.getFlowKey(flowId, type);
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;

View file

@ -15,6 +15,8 @@ export * from './mcp/errors';
/* Utilities */
export * from './mcp/utils';
export * from './utils';
export { default as Tokenizer, countTokens } from './utils/tokenizer';
export type { EncodingName } from './utils/tokenizer';
export * from './db/utils';
/* OAuth */
export * from './oauth';

View file

@ -2,11 +2,11 @@ import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
import type { MCPOAuthTokens, OAuthMetadata, MCPOAuthFlowMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import type * as t from './types';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
import { PENDING_STALE_MS, normalizeExpiresAt } from '~/flow/manager';
import { sanitizeUrlForLogging } from './utils';
import { withTimeout } from '~/utils/promise';
import { MCPConnection } from './connection';
@ -104,6 +104,7 @@ export class MCPConnectionFactory {
return { tools, connection, oauthRequired: false, oauthUrl: null };
}
} catch {
MCPConnection.decrementCycleCount(this.serverName);
logger.debug(
`${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`,
);
@ -125,7 +126,9 @@ export class MCPConnectionFactory {
}
return { tools, connection: null, oauthRequired, oauthUrl };
}
MCPConnection.decrementCycleCount(this.serverName);
} catch (listError) {
MCPConnection.decrementCycleCount(this.serverName);
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
}
@ -265,6 +268,10 @@ export class MCPConnectionFactory {
if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
return tokens;
} catch (error) {
if (error instanceof ReauthenticationRequiredError) {
logger.info(`${this.logPrefix} ${error.message}, will trigger OAuth flow`);
return null;
}
logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
return null;
}
@ -306,11 +313,21 @@ export class MCPConnectionFactory {
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
if (existingFlow?.status === 'PENDING') {
const pendingAge = existingFlow.createdAt
? Date.now() - existingFlow.createdAt
: Infinity;
if (pendingAge < PENDING_STALE_MS) {
logger.debug(
`${this.logPrefix} Recent PENDING OAuth flow exists (${Math.round(pendingAge / 1000)}s old), skipping new initiation`,
);
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
}
logger.debug(
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
`${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will replace`,
);
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
}
const {
@ -326,11 +343,17 @@ export class MCPConnectionFactory {
);
if (existingFlow) {
const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state;
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
if (oldState) {
await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!);
}
}
// Store flow state BEFORE redirecting so the callback can find it
await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata);
const metadataWithUrl = { ...flowMetadata, authorizationUrl };
await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager!);
// Start monitoring in background — createFlow will find the existing PENDING state
// written by initFlow above, so metadata arg is unused (pass {} to make that explicit)
@ -495,11 +518,75 @@ export class MCPConnectionFactory {
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow) {
const flowMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined;
if (existingFlow.status === 'PENDING') {
const pendingAge = existingFlow.createdAt
? Date.now() - existingFlow.createdAt
: Infinity;
if (pendingAge < PENDING_STALE_MS) {
logger.debug(
`${this.logPrefix} Found recent PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), joining instead of creating new one`,
);
const storedAuthUrl = flowMeta?.authorizationUrl;
if (storedAuthUrl && typeof this.oauthStart === 'function') {
logger.info(
`${this.logPrefix} Re-issuing stored authorization URL to caller while joining PENDING flow`,
);
await this.oauthStart(storedAuthUrl);
}
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal);
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(
`${this.logPrefix} Joined existing OAuth flow completed for ${this.serverName}`,
);
return {
tokens,
clientInfo: flowMeta?.clientInfo,
metadata: flowMeta?.metadata,
};
}
logger.debug(
`${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will delete and start fresh`,
);
}
if (existingFlow.status === 'COMPLETED') {
const completedAge = existingFlow.completedAt
? Date.now() - existingFlow.completedAt
: Infinity;
const cachedTokens = existingFlow.result as MCPOAuthTokens | null | undefined;
const isTokenExpired =
cachedTokens?.expires_at != null &&
normalizeExpiresAt(cachedTokens.expires_at) < Date.now();
if (completedAge <= PENDING_STALE_MS && cachedTokens !== undefined && !isTokenExpired) {
logger.debug(
`${this.logPrefix} Found non-stale COMPLETED OAuth flow, reusing cached tokens`,
);
return {
tokens: cachedTokens,
clientInfo: flowMeta?.clientInfo,
metadata: flowMeta?.metadata,
};
}
}
logger.debug(
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
);
try {
const oldState = flowMeta?.state;
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
if (oldState) {
await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager);
}
} catch (error) {
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
}
@ -519,7 +606,9 @@ export class MCPConnectionFactory {
);
// Store flow state BEFORE redirecting so the callback can find it
await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata);
const metadataWithUrl = { ...flowMetadata, authorizationUrl };
await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager);
if (typeof this.oauthStart === 'function') {
logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);

View file

@ -1,10 +1,10 @@
import { logger } from '@librechat/data-schemas';
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import { MCPConnection } from './connection';
import type * as t from './types';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPConnection } from './connection';
import { mcpConfig } from './mcpConfig';
/**
@ -21,6 +21,8 @@ export abstract class UserConnectionManager {
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
/** Last activity timestamp for users (not per server) */
protected userLastActivity: Map<string, number> = new Map();
/** In-flight connection promises keyed by `userId:serverName` — coalesces concurrent attempts */
protected pendingConnections: Map<string, Promise<MCPConnection>> = new Map();
/** Updates the last activity timestamp for a user */
protected updateUserLastActivity(userId: string): void {
@ -31,29 +33,64 @@ export abstract class UserConnectionManager {
);
}
/** Gets or creates a connection for a specific user */
public async getUserConnection({
serverName,
forceNew,
user,
flowManager,
customUserVars,
requestBody,
tokenMethods,
oauthStart,
oauthEnd,
signal,
returnOnOAuth = false,
connectionTimeout,
}: {
serverName: string;
forceNew?: boolean;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
/** Gets or creates a connection for a specific user, coalescing concurrent attempts */
public async getUserConnection(
opts: {
serverName: string;
forceNew?: boolean;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
): Promise<MCPConnection> {
const { serverName, forceNew, user } = opts;
const userId = user?.id;
if (!userId) {
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
}
const lockKey = `${userId}:${serverName}`;
if (!forceNew) {
const pending = this.pendingConnections.get(lockKey);
if (pending) {
logger.debug(`[MCP][User: ${userId}][${serverName}] Joining in-flight connection attempt`);
return pending;
}
}
const connectionPromise = this.createUserConnectionInternal(opts, userId);
if (!forceNew) {
this.pendingConnections.set(lockKey, connectionPromise);
}
try {
return await connectionPromise;
} finally {
if (!forceNew && this.pendingConnections.get(lockKey) === connectionPromise) {
this.pendingConnections.delete(lockKey);
}
}
}
private async createUserConnectionInternal(
{
serverName,
forceNew,
user,
flowManager,
customUserVars,
requestBody,
tokenMethods,
oauthStart,
oauthEnd,
signal,
returnOnOAuth = false,
connectionTimeout,
}: {
serverName: string;
forceNew?: boolean;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
userId: string,
): Promise<MCPConnection> {
if (await this.appConnections!.has(serverName)) {
throw new McpError(
ErrorCode.InvalidRequest,
@ -65,6 +102,9 @@ export abstract class UserConnectionManager {
const userServerMap = this.userConnections.get(userId);
let connection = forceNew ? undefined : userServerMap?.get(serverName);
if (forceNew) {
MCPConnection.clearCooldown(serverName);
}
const now = Date.now();
// Check if user is idle
@ -185,6 +225,7 @@ export abstract class UserConnectionManager {
/** Disconnects and removes a specific user connection */
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
this.pendingConnections.delete(`${userId}:${serverName}`);
const userMap = this.userConnections.get(userId);
const connection = userMap?.get(serverName);
if (connection) {
@ -212,6 +253,12 @@ export abstract class UserConnectionManager {
);
}
await Promise.allSettled(disconnectPromises);
// Clean up any pending connection promises for this user
for (const key of this.pendingConnections.keys()) {
if (key.startsWith(`${userId}:`)) {
this.pendingConnections.delete(key);
}
}
// Ensure user activity timestamp is removed
this.userLastActivity.delete(userId);
logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);

View file

@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => {
});
});
});
/**
* Tests for circuit breaker logic.
*
* Uses standalone implementations that mirror the static/private circuit breaker
* methods in MCPConnection. Same approach as the error detection tests above.
*/
describe('MCPConnection Circuit Breaker', () => {
/** 5 cycles within 60s triggers a 30s cooldown */
const CB_MAX_CYCLES = 5;
const CB_CYCLE_WINDOW_MS = 60_000;
const CB_CYCLE_COOLDOWN_MS = 30_000;
/** 3 failed rounds within 120s triggers exponential backoff (30s - 300s) */
const CB_MAX_FAILED_ROUNDS = 3;
const CB_FAILED_WINDOW_MS = 120_000;
const CB_BASE_BACKOFF_MS = 30_000;
const CB_MAX_BACKOFF_MS = 300_000;
interface CircuitBreakerState {
cycleCount: number;
cycleWindowStart: number;
cooldownUntil: number;
failedRounds: number;
failedWindowStart: number;
failedBackoffUntil: number;
}
function createCB(): CircuitBreakerState {
return {
cycleCount: 0,
cycleWindowStart: Date.now(),
cooldownUntil: 0,
failedRounds: 0,
failedWindowStart: Date.now(),
failedBackoffUntil: 0,
};
}
function isCircuitOpen(cb: CircuitBreakerState): boolean {
const now = Date.now();
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
}
function recordCycle(cb: CircuitBreakerState): void {
const now = Date.now();
if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
cb.cycleCount = 0;
cb.cycleWindowStart = now;
}
cb.cycleCount++;
if (cb.cycleCount >= CB_MAX_CYCLES) {
cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS;
cb.cycleCount = 0;
cb.cycleWindowStart = now;
}
}
function recordFailedRound(cb: CircuitBreakerState): void {
const now = Date.now();
if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
cb.failedRounds = 0;
cb.failedWindowStart = now;
}
cb.failedRounds++;
if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) {
const backoff = Math.min(
CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS),
CB_MAX_BACKOFF_MS,
);
cb.failedBackoffUntil = now + backoff;
}
}
function resetFailedRounds(cb: CircuitBreakerState): void {
cb.failedRounds = 0;
cb.failedWindowStart = Date.now();
cb.failedBackoffUntil = 0;
}
beforeEach(() => {
jest.useFakeTimers();
});
afterEach(() => {
jest.useRealTimers();
});
describe('cycle tracking', () => {
it('should not trigger cooldown for fewer than 5 cycles', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
recordCycle(cb);
}
expect(isCircuitOpen(cb)).toBe(false);
});
it('should trigger 30s cooldown after 5 cycles within 60s', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_CYCLES; i++) {
recordCycle(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
jest.advanceTimersByTime(29_000);
expect(isCircuitOpen(cb)).toBe(true);
jest.advanceTimersByTime(1_000);
expect(isCircuitOpen(cb)).toBe(false);
});
it('should reset cycle count when window expires', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
recordCycle(cb);
}
jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1);
recordCycle(cb);
expect(isCircuitOpen(cb)).toBe(false);
});
});
describe('failed round tracking', () => {
it('should not trigger backoff for fewer than 3 failures', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) {
recordFailedRound(cb);
}
expect(isCircuitOpen(cb)).toBe(false);
});
it('should trigger 30s backoff after 3 failures within 120s', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
recordFailedRound(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
jest.advanceTimersByTime(CB_BASE_BACKOFF_MS);
expect(isCircuitOpen(cb)).toBe(false);
});
it('should use exponential backoff based on failure count', () => {
jest.setSystemTime(Date.now());
const cb = createCB();
for (let i = 0; i < 3; i++) {
recordFailedRound(cb);
}
expect(cb.failedBackoffUntil - Date.now()).toBe(30_000);
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(60_000);
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(120_000);
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(240_000);
// capped at 300s
recordFailedRound(cb);
expect(cb.failedBackoffUntil - Date.now()).toBe(300_000);
});
it('should reset failed window when window expires', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
recordFailedRound(cb);
recordFailedRound(cb);
jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1);
recordFailedRound(cb);
expect(isCircuitOpen(cb)).toBe(false);
});
});
describe('resetFailedRounds', () => {
it('should clear failed round state on successful connection', () => {
const now = Date.now();
jest.setSystemTime(now);
const cb = createCB();
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
recordFailedRound(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
resetFailedRounds(cb);
expect(isCircuitOpen(cb)).toBe(false);
expect(cb.failedRounds).toBe(0);
expect(cb.failedBackoffUntil).toBe(0);
});
});
describe('clearCooldown (registry deletion)', () => {
it('should allow connections after clearing circuit breaker state', () => {
const now = Date.now();
jest.setSystemTime(now);
const registry = new Map<string, CircuitBreakerState>();
const serverName = 'test-server';
const cb = createCB();
registry.set(serverName, cb);
for (let i = 0; i < CB_MAX_CYCLES; i++) {
recordCycle(cb);
}
expect(isCircuitOpen(cb)).toBe(true);
registry.delete(serverName);
const newCb = createCB();
expect(isCircuitOpen(newCb)).toBe(false);
});
});
});

View file

@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle streamable-http', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();
@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle SSE', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test-sse');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();
@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test-regression');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();
@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery integration', () => {
});
afterEach(async () => {
MCPConnection.clearCooldown('test-sse-recovery');
await safeDisconnect(conn);
conn = null;
jest.restoreAllMocks();

View file

@ -275,7 +275,7 @@ describe('MCPConnectionFactory', () => {
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
'flow123',
'mcp_oauth',
mockFlowData.flowMetadata,
expect.objectContaining(mockFlowData.flowMetadata),
);
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
@ -550,7 +550,7 @@ describe('MCPConnectionFactory', () => {
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
'flow123',
'mcp_oauth',
mockFlowData.flowMetadata,
expect.objectContaining(mockFlowData.flowMetadata),
);
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock

View file

@ -0,0 +1,232 @@
/**
* Tests for the OAuth callback CSRF fallback logic.
*
* The callback route validates requests via three mechanisms (in order):
* 1. CSRF cookie (HMAC-based, set during initiate)
* 2. Session cookie (bound to authenticated userId)
* 3. Active PENDING flow in FlowStateManager (fallback for SSE/chat flows)
*
* This suite tests mechanism 3 the PENDING flow fallback including
* staleness enforcement and rejection of non-PENDING flows.
*
* These tests exercise the validation functions directly for fast,
* focused coverage. Route-level integration tests using supertest
* are in api/server/routes/__tests__/mcp.spec.js ("CSRF fallback
* via active PENDING flow" describe block).
*/
import { Keyv } from 'keyv';
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
import type { Request, Response } from 'express';
import {
generateOAuthCsrfToken,
OAUTH_SESSION_COOKIE,
validateOAuthSession,
OAUTH_CSRF_COOKIE,
validateOAuthCsrf,
} from '~/oauth/csrf';
import { MockKeyv } from './helpers/oauthTestServer';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
const CSRF_COOKIE_PATH = '/api/mcp';
function makeReq(cookies: Record<string, string> = {}): Request {
return { cookies } as unknown as Request;
}
function makeRes(): Response {
const res = {
clearCookie: jest.fn(),
} as unknown as Response;
return res;
}
/**
* Replicate the callback route's three-tier validation logic.
* Returns which mechanism (if any) authorized the request.
*/
async function validateCallback(
req: Request,
res: Response,
flowId: string,
flowManager: FlowStateManager,
): Promise<'csrf' | 'session' | 'pendingFlow' | false> {
const flowUserId = flowId.split(':')[0];
const hasCsrf = validateOAuthCsrf(req, res, flowId, CSRF_COOKIE_PATH);
if (hasCsrf) {
return 'csrf';
}
const hasSession = validateOAuthSession(req, flowUserId);
if (hasSession) {
return 'session';
}
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
if (pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS) {
return 'pendingFlow';
}
return false;
}
describe('OAuth Callback CSRF Fallback', () => {
let flowManager: FlowStateManager;
beforeEach(() => {
process.env.JWT_SECRET = 'test-secret-for-csrf';
const store = new MockKeyv();
flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 300000, ci: true });
});
afterEach(() => {
delete process.env.JWT_SECRET;
jest.clearAllMocks();
});
describe('CSRF cookie validation (mechanism 1)', () => {
it('should accept valid CSRF cookie', async () => {
const flowId = 'user1:test-server';
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('csrf');
});
it('should reject invalid CSRF cookie', async () => {
const flowId = 'user1:test-server';
const req = makeReq({ [OAUTH_CSRF_COOKIE]: 'wrong-token-value' });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
});
describe('Session cookie validation (mechanism 2)', () => {
it('should accept valid session cookie when CSRF is absent', async () => {
const flowId = 'user1:test-server';
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('session');
});
});
describe('PENDING flow fallback (mechanism 3)', () => {
it('should accept when a fresh PENDING flow exists and no cookies are present', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('pendingFlow');
});
it('should reject when no PENDING flow, no CSRF cookie, and no session cookie', async () => {
const flowId = 'user1:test-server';
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
await flowManager.completeFlow(flowId, 'mcp_oauth', { access_token: 'tok' } as never);
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should reject when only a FAILED flow exists', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
await flowManager.failFlow(flowId, 'mcp_oauth', 'some error');
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
// Artificially age the flow past the staleness threshold
const store = (flowManager as unknown as { keyv: { get: (k: string) => Promise<unknown> } })
.keyv;
const flowState = (await store.get(`mcp_oauth:${flowId}`)) as { createdAt: number };
flowState.createdAt = Date.now() - PENDING_STALE_MS - 1000;
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe(false);
});
it('should accept PENDING flow that is just under the staleness threshold', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
// Flow was just created, well under threshold
const req = makeReq();
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('pendingFlow');
});
});
describe('Priority ordering', () => {
it('should prefer CSRF cookie over PENDING flow', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('csrf');
});
it('should prefer session cookie over PENDING flow when CSRF is absent', async () => {
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
const res = makeRes();
const result = await validateCallback(req, res, flowId, flowManager);
expect(result).toBe('session');
});
});
});

View file

@ -0,0 +1,268 @@
/**
* Tests for MCPConnection OAuth event cycle against a real OAuth-gated MCP server.
*
* Verifies: oauthRequired emission on 401, oauthHandled reconnection,
* oauthFailed rejection, timeout behavior, and token expiry mid-session.
*/
import { MCPConnection } from '~/mcp/connection';
import { createOAuthMCPServer } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
jest.mock('~/auth', () => ({
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
resolveHostnameSSRF: jest.fn(async () => false),
}));
jest.mock('~/mcp/mcpConfig', () => ({
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
}));
async function safeDisconnect(conn: MCPConnection | null): Promise<void> {
if (!conn) {
return;
}
try {
await conn.disconnect();
} catch {
// Ignore disconnect errors during cleanup
}
}
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
redirect: 'manual',
});
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${serverUrl}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const data = (await tokenRes.json()) as { access_token: string };
return data.access_token;
}
describe('MCPConnection OAuth Events — Real Server', () => {
let server: OAuthTestServer;
let connection: MCPConnection | null = null;
beforeEach(() => {
MCPConnection.clearCooldown('test-server');
});
afterEach(async () => {
await safeDisconnect(connection);
connection = null;
if (server) {
await server.close();
}
jest.clearAllMocks();
});
describe('oauthRequired event', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should emit oauthRequired when connecting without a token', async () => {
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: { type: 'streamable-http', url: server.url },
userId: 'user-1',
});
const oauthRequiredPromise = new Promise<{
serverName: string;
error: Error;
serverUrl?: string;
userId?: string;
}>((resolve) => {
connection!.on('oauthRequired', (data) => {
resolve(
data as {
serverName: string;
error: Error;
serverUrl?: string;
userId?: string;
},
);
});
});
// Connection will fail with 401, emitting oauthRequired
const connectPromise = connection.connect().catch(() => {
// Expected to fail since no one handles oauthRequired
});
let raceTimer: NodeJS.Timeout | undefined;
const eventData = await Promise.race([
oauthRequiredPromise,
new Promise<never>((_, reject) => {
raceTimer = setTimeout(
() => reject(new Error('Timed out waiting for oauthRequired')),
10000,
);
}),
]).finally(() => clearTimeout(raceTimer));
expect(eventData.serverName).toBe('test-server');
expect(eventData.error).toBeDefined();
// Emit oauthFailed to unblock connect()
connection.emit('oauthFailed', new Error('test cleanup'));
await connectPromise.catch(() => undefined);
});
it('should not emit oauthRequired when connecting with a valid token', async () => {
const accessToken = await exchangeCodeForToken(server.url);
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: { type: 'streamable-http', url: server.url },
userId: 'user-1',
oauthTokens: {
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens,
});
let oauthFired = false;
connection.on('oauthRequired', () => {
oauthFired = true;
});
await connection.connect();
expect(await connection.isConnected()).toBe(true);
expect(oauthFired).toBe(false);
});
});
describe('oauthHandled reconnection', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should succeed on retry after oauthHandled provides valid tokens', async () => {
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
},
userId: 'user-1',
});
// First connect fails with 401 → oauthRequired fires
let oauthFired = false;
connection.on('oauthRequired', () => {
oauthFired = true;
connection!.emit('oauthFailed', new Error('Will retry with tokens'));
});
// First attempt fails as expected
await expect(connection.connect()).rejects.toThrow();
expect(oauthFired).toBe(true);
// Now set valid tokens and reconnect
const accessToken = await exchangeCodeForToken(server.url);
connection.setOAuthTokens({
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
await connection.connect();
expect(await connection.isConnected()).toBe(true);
});
});
describe('oauthFailed rejection', () => {
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
it('should reject connect() when oauthFailed is emitted', async () => {
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
},
userId: 'user-1',
});
connection.on('oauthRequired', () => {
connection!.emit('oauthFailed', new Error('User denied OAuth'));
});
await expect(connection.connect()).rejects.toThrow();
});
});
describe('Token expiry during session', () => {
it('should detect expired token on reconnect and emit oauthRequired', async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 1000 });
const accessToken = await exchangeCodeForToken(server.url);
connection = new MCPConnection({
serverName: 'test-server',
serverConfig: {
type: 'streamable-http',
url: server.url,
initTimeout: 15000,
},
userId: 'user-1',
oauthTokens: {
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens,
});
// Initial connect should succeed
await connection.connect();
expect(await connection.isConnected()).toBe(true);
await connection.disconnect();
// Wait for token to expire
await new Promise((r) => setTimeout(r, 1200));
// Reconnect should trigger oauthRequired since token is expired on the server
let oauthFired = false;
connection.on('oauthRequired', () => {
oauthFired = true;
connection!.emit('oauthFailed', new Error('Will retry with fresh token'));
});
// First reconnect fails with 401 → oauthRequired
await expect(connection.connect()).rejects.toThrow();
expect(oauthFired).toBe(true);
// Get fresh token and reconnect
const newToken = await exchangeCodeForToken(server.url);
connection.setOAuthTokens({
access_token: newToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
await connection.connect();
expect(await connection.isConnected()).toBe(true);
});
});
});

View file

@ -0,0 +1,538 @@
/**
* OAuth flow tests against a real HTTP server.
*
* Tests MCPOAuthHandler.refreshOAuthTokens and MCPTokenStorage lifecycle
* using a real test OAuth server (not mocked SDK functions).
*/
import { createHash } from 'crypto';
import { Keyv } from 'keyv';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { FlowStateManager } from '~/flow/manager';
import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
describe('MCP OAuth Flow — Real HTTP Server', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('Token refresh with real server', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
});
afterEach(async () => {
await server.close();
});
it('should refresh tokens with stored client info via real /token endpoint', async () => {
// First get initial tokens
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
refresh_token: string;
};
expect(initial.refresh_token).toBeDefined();
// Register a client so we have clientInfo
const regRes = await fetch(`${server.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
});
const clientInfo = (await regRes.json()) as {
client_id: string;
client_secret: string;
};
// Refresh tokens using the real endpoint
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
initial.refresh_token,
{
serverName: 'test-server',
serverUrl: server.url,
clientInfo: {
...clientInfo,
redirect_uris: ['http://localhost/callback'],
},
},
{},
{
token_url: `${server.url}token`,
client_id: clientInfo.client_id,
client_secret: clientInfo.client_secret,
token_exchange_method: 'DefaultPost',
},
);
expect(refreshed.access_token).toBeDefined();
expect(refreshed.access_token).not.toBe(initial.access_token);
expect(refreshed.token_type).toBe('Bearer');
expect(refreshed.obtained_at).toBeDefined();
});
it('should get new refresh token when server rotates', async () => {
const rotatingServer = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
rotateRefreshTokens: true,
});
try {
const code = await rotatingServer.getAuthCode();
const tokenRes = await fetch(`${rotatingServer.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
refresh_token: string;
};
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
initial.refresh_token,
{
serverName: 'test-server',
serverUrl: rotatingServer.url,
},
{},
{
token_url: `${rotatingServer.url}token`,
client_id: 'anon',
token_exchange_method: 'DefaultPost',
},
);
expect(refreshed.access_token).not.toBe(initial.access_token);
expect(refreshed.refresh_token).toBeDefined();
expect(refreshed.refresh_token).not.toBe(initial.refresh_token);
} finally {
await rotatingServer.close();
}
});
it('should fail refresh with invalid refresh token', async () => {
await expect(
MCPOAuthHandler.refreshOAuthTokens(
'invalid-refresh-token',
{
serverName: 'test-server',
serverUrl: server.url,
},
{},
{
token_url: `${server.url}token`,
client_id: 'anon',
token_exchange_method: 'DefaultPost',
},
),
).rejects.toThrow();
});
});
describe('OAuth server metadata discovery', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ issueRefreshTokens: true });
});
afterEach(async () => {
await server.close();
});
it('should expose /.well-known/oauth-authorization-server', async () => {
const res = await fetch(`${server.url}.well-known/oauth-authorization-server`);
expect(res.status).toBe(200);
const metadata = (await res.json()) as {
authorization_endpoint: string;
token_endpoint: string;
registration_endpoint: string;
grant_types_supported: string[];
};
expect(metadata.authorization_endpoint).toContain('/authorize');
expect(metadata.token_endpoint).toContain('/token');
expect(metadata.registration_endpoint).toContain('/register');
expect(metadata.grant_types_supported).toContain('authorization_code');
expect(metadata.grant_types_supported).toContain('refresh_token');
});
it('should not advertise refresh_token grant when disabled', async () => {
const noRefreshServer = await createOAuthMCPServer({
issueRefreshTokens: false,
});
try {
const res = await fetch(`${noRefreshServer.url}.well-known/oauth-authorization-server`);
const metadata = (await res.json()) as { grant_types_supported: string[] };
expect(metadata.grant_types_supported).not.toContain('refresh_token');
} finally {
await noRefreshServer.close();
}
});
});
describe('Dynamic client registration', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer();
});
afterEach(async () => {
await server.close();
});
it('should register a client via /register endpoint', async () => {
const res = await fetch(`${server.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
redirect_uris: ['http://localhost/callback'],
}),
});
expect(res.status).toBe(200);
const client = (await res.json()) as {
client_id: string;
client_secret: string;
redirect_uris: string[];
};
expect(client.client_id).toBeDefined();
expect(client.client_secret).toBeDefined();
expect(client.redirect_uris).toEqual(['http://localhost/callback']);
expect(server.registeredClients.has(client.client_id)).toBe(true);
});
});
describe('End-to-End: store, retrieve, expire, refresh cycle', () => {
it('should perform full token lifecycle with real server', async () => {
const server = await createOAuthMCPServer({
tokenTTLMs: 1000,
issueRefreshTokens: true,
});
const tokenStore = new InMemoryTokenStore();
try {
// 1. Get initial tokens via auth code exchange
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// 2. Store tokens
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'test-srv',
tokens: initial,
createToken: tokenStore.createToken,
});
// 3. Retrieve — should succeed
const valid = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
});
expect(valid).not.toBeNull();
expect(valid!.access_token).toBe(initial.access_token);
expect(valid!.refresh_token).toBe(initial.refresh_token);
// 4. Wait for expiry
await new Promise((r) => setTimeout(r, 1200));
// 5. Retrieve again — should trigger refresh via callback
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const refreshRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!refreshRes.ok) {
throw new Error(`Refresh failed: ${refreshRes.status}`);
}
const data = (await refreshRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token?: string;
};
return {
...data,
obtained_at: Date.now(),
expires_at: Date.now() + data.expires_in * 1000,
};
};
const refreshed = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(refreshed).not.toBeNull();
expect(refreshed!.access_token).not.toBe(initial.access_token);
// 6. Verify the refreshed token works against the server
const mcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${refreshed!.access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(mcpRes.status).toBe(200);
} finally {
await server.close();
}
});
});
describe('completeOAuthFlow via FlowStateManager', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ issueRefreshTokens: true });
});
afterEach(async () => {
await server.close();
});
it('should exchange auth code and complete flow in FlowStateManager', async () => {
const store = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'test-user:test-server';
const code = await server.getAuthCode();
// Initialize the flow with metadata the handler needs
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverUrl: server.url,
clientInfo: {
client_id: 'test-client',
redirect_uris: ['http://localhost/callback'],
},
codeVerifier: 'test-verifier',
metadata: {
token_endpoint: `${server.url}token`,
token_endpoint_auth_methods_supported: ['client_secret_post'],
},
});
// The SDK's exchangeAuthorization wants full OAuth metadata,
// so we'll test the token exchange directly instead of going through
// completeOAuthFlow (which requires full SDK-compatible metadata)
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const tokens = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token?: string;
};
const mcpTokens: MCPOAuthTokens = {
...tokens,
obtained_at: Date.now(),
expires_at: Date.now() + tokens.expires_in * 1000,
};
// Complete the flow
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
expect(completed).toBe(true);
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('COMPLETED');
expect(state?.result?.access_token).toBe(tokens.access_token);
});
it('should fail flow when authorization code is invalid', async () => {
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: 'grant_type=authorization_code&code=invalid-code',
});
expect(tokenRes.status).toBe(400);
const body = (await tokenRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
it('should fail when authorization code is reused', async () => {
const code = await server.getAuthCode();
// First exchange succeeds
const firstRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(firstRes.status).toBe(200);
// Second exchange fails
const secondRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(secondRes.status).toBe(400);
const body = (await secondRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
});
describe('PKCE verification', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
afterEach(async () => {
await server.close();
});
function generatePKCE(): { verifier: string; challenge: string } {
const verifier = 'dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk';
const challenge = createHash('sha256').update(verifier).digest('base64url');
return { verifier, challenge };
}
it('should accept valid code_verifier matching code_challenge', async () => {
const { verifier, challenge } = generatePKCE();
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}&code_verifier=${verifier}`,
});
expect(tokenRes.status).toBe(200);
const data = (await tokenRes.json()) as { access_token: string };
expect(data.access_token).toBeDefined();
});
it('should reject wrong code_verifier', async () => {
const { challenge } = generatePKCE();
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}&code_verifier=wrong-verifier`,
});
expect(tokenRes.status).toBe(400);
const body = (await tokenRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
it('should reject missing code_verifier when code_challenge was provided', async () => {
const { challenge } = generatePKCE();
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(tokenRes.status).toBe(400);
const body = (await tokenRes.json()) as { error: string };
expect(body.error).toBe('invalid_grant');
});
it('should still accept codes without PKCE when no code_challenge was provided', async () => {
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(tokenRes.status).toBe(200);
});
});
});

View file

@ -0,0 +1,516 @@
/**
* Tests for MCP OAuth race condition fixes:
*
* 1. Connection mutex coalesces concurrent getUserConnection() calls
* 2. PENDING OAuth flows are reused, not deleted
* 3. No-refresh-token expiry throws ReauthenticationRequiredError
* 4. completeFlow recovers when flow state was deleted by a race
* 5. monitorFlow retries once when flow state disappears mid-poll
*/
import { Keyv } from 'keyv';
import { logger } from '@librechat/data-schemas';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
import { MockKeyv, createOAuthMCPServer } from './helpers/oauthTestServer';
import { FlowStateManager } from '~/flow/manager';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
jest.mock('~/auth', () => ({
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
resolveHostnameSSRF: jest.fn(async () => false),
}));
jest.mock('~/mcp/mcpConfig', () => ({
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
}));
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('MCP OAuth Race Condition Fixes', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('Fix 1: Connection mutex coalesces concurrent attempts', () => {
it('should return the same pending promise for concurrent getUserConnection calls', async () => {
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
class TestManager extends UserConnectionManager {
public createCallCount = 0;
getPendingConnections() {
return this.pendingConnections;
}
}
const manager = new TestManager();
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn().mockResolvedValue(undefined),
isStale: jest.fn().mockReturnValue(false),
};
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
manager.appConnections = mockAppConnections as never;
const mockConfig = {
type: 'streamable-http',
url: 'http://localhost:9999/',
updatedAt: undefined,
dbId: undefined,
};
jest
.spyOn(
// eslint-disable-next-line @typescript-eslint/no-require-imports
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
'getInstance',
)
.mockReturnValue({
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
});
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
const createSpy = jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
manager.createCallCount++;
await new Promise((r) => setTimeout(r, 100));
return mockConnection as never;
});
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const user = { id: 'user-1' };
const opts = {
serverName: 'test-server',
user: user as never,
flowManager: flowManager as never,
};
const [conn1, conn2, conn3] = await Promise.all([
manager.getUserConnection(opts),
manager.getUserConnection(opts),
manager.getUserConnection(opts),
]);
expect(conn1).toBe(conn2);
expect(conn2).toBe(conn3);
expect(createSpy).toHaveBeenCalledTimes(1);
expect(manager.createCallCount).toBe(1);
createSpy.mockRestore();
});
it('should not coalesce when forceNew is true', async () => {
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
class TestManager extends UserConnectionManager {}
const manager = new TestManager();
let callCount = 0;
const makeConnection = () => ({
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn().mockResolvedValue(undefined),
isStale: jest.fn().mockReturnValue(false),
});
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
manager.appConnections = mockAppConnections as never;
const mockConfig = {
type: 'streamable-http',
url: 'http://localhost:9999/',
updatedAt: undefined,
dbId: undefined,
};
jest
.spyOn(
// eslint-disable-next-line @typescript-eslint/no-require-imports
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
'getInstance',
)
.mockReturnValue({
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
});
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
callCount++;
await new Promise((r) => setTimeout(r, 50));
return makeConnection() as never;
});
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const user = { id: 'user-2' };
const [conn1, conn2] = await Promise.all([
manager.getUserConnection({
serverName: 'test-server',
forceNew: true,
user: user as never,
flowManager: flowManager as never,
}),
manager.getUserConnection({
serverName: 'test-server',
forceNew: true,
user: user as never,
flowManager: flowManager as never,
}),
]);
expect(callCount).toBe(2);
expect(conn1).not.toBe(conn2);
});
});
describe('Fix 2: PENDING flow is reused, not deleted', () => {
it('should join an existing PENDING flow via createFlow instead of deleting it', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'test-flow-pending';
await flowManager.initFlow(flowId, 'mcp_oauth', {
clientInfo: { client_id: 'test-client' },
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('PENDING');
const deleteSpy = jest.spyOn(flowManager, 'deleteFlow');
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
await new Promise((r) => setTimeout(r, 500));
await flowManager.completeFlow(flowId, 'mcp_oauth', {
access_token: 'test-token',
token_type: 'Bearer',
} as never);
const result = await monitorPromise;
expect(result).toEqual(
expect.objectContaining({ access_token: 'test-token', token_type: 'Bearer' }),
);
expect(deleteSpy).not.toHaveBeenCalled();
deleteSpy.mockRestore();
});
it('should delete and recreate FAILED flows', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'test-flow-failed';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
await flowManager.failFlow(flowId, 'mcp_oauth', 'previous error');
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('FAILED');
await flowManager.deleteFlow(flowId, 'mcp_oauth');
const afterDelete = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(afterDelete).toBeUndefined();
});
});
describe('Fix 3: completeFlow handles deleted state gracefully', () => {
it('should return false when state was deleted by race', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'race-deleted-flow';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
await flowManager.deleteFlow(flowId, 'mcp_oauth');
const stateBeforeComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(stateBeforeComplete).toBeUndefined();
const result = await flowManager.completeFlow(flowId, 'mcp_oauth', {
access_token: 'recovered-token',
token_type: 'Bearer',
} as never);
expect(result).toBe(false);
const stateAfterComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(stateAfterComplete).toBeUndefined();
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('cannot recover metadata'),
expect.any(Object),
);
});
it('should reject monitorFlow when state is deleted and not recoverable', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
const flowId = 'monitor-retry-flow';
await flowManager.initFlow(flowId, 'mcp_oauth', {});
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
await new Promise((r) => setTimeout(r, 500));
await flowManager.deleteFlow(flowId, 'mcp_oauth');
await expect(monitorPromise).rejects.toThrow('Flow state not found');
});
});
describe('State mapping cleanup on flow replacement', () => {
it('should delete old state mapping when a flow is replaced', async () => {
const store = new MockKeyv();
const flowManager = new FlowStateManager<MCPOAuthTokens | null>(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'user1:test-server';
const oldState = 'old-random-state-abc123';
const newState = 'new-random-state-xyz789';
// Simulate initial flow with state mapping
await flowManager.initFlow(flowId, 'mcp_oauth', { state: oldState });
await MCPOAuthHandler.storeStateMapping(oldState, flowId, flowManager);
// Old state should resolve
const resolvedBefore = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
expect(resolvedBefore).toBe(flowId);
// Replace the flow: delete old, create new, clean up old state mapping
await flowManager.deleteFlow(flowId, 'mcp_oauth');
await MCPOAuthHandler.deleteStateMapping(oldState, flowManager);
await flowManager.initFlow(flowId, 'mcp_oauth', { state: newState });
await MCPOAuthHandler.storeStateMapping(newState, flowId, flowManager);
// Old state should no longer resolve
const resolvedOld = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
expect(resolvedOld).toBeNull();
// New state should resolve
const resolvedNew = await MCPOAuthHandler.resolveStateToFlowId(newState, flowManager);
expect(resolvedNew).toBe(flowId);
});
});
describe('Fix 4: ReauthenticationRequiredError for no-refresh-token', () => {
it('should throw ReauthenticationRequiredError when access token expired and no refresh token', async () => {
const expiredDate = new Date(Date.now() - 60000);
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
if (filter.type === 'mcp_oauth') {
return {
token: 'enc:expired-access-token',
expiresAt: expiredDate,
createdAt: new Date(Date.now() - 120000),
};
}
if (filter.type === 'mcp_oauth_refresh') {
return null;
}
return null;
});
await expect(
MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
await expect(
MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
}),
).rejects.toThrow('Re-authentication required');
});
it('should throw ReauthenticationRequiredError when access token is missing and no refresh token', async () => {
const findToken = jest.fn().mockResolvedValue(null);
await expect(
MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should not throw when access token is valid', async () => {
const futureDate = new Date(Date.now() + 3600000);
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
if (filter.type === 'mcp_oauth') {
return {
token: 'enc:valid-access-token',
expiresAt: futureDate,
createdAt: new Date(),
};
}
if (filter.type === 'mcp_oauth_refresh') {
return null;
}
return null;
});
const result = await MCPTokenStorage.getTokens({
userId: 'user-1',
serverName: 'test-server',
findToken,
});
expect(result).not.toBeNull();
expect(result?.access_token).toBe('valid-access-token');
});
});
describe('E2E: OAuth-gated MCP server with no refresh tokens', () => {
let server: OAuthTestServer;
beforeEach(async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
afterEach(async () => {
await server.close();
});
it('should start OAuth-gated MCP server that validates Bearer tokens', async () => {
const res = await fetch(server.url, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ jsonrpc: '2.0', method: 'initialize', id: 1 }),
});
expect(res.status).toBe(401);
const body = (await res.json()) as { error: string };
expect(body.error).toBe('invalid_token');
});
it('should issue tokens via authorization code exchange with no refresh token', async () => {
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
redirect: 'manual',
});
expect(authRes.status).toBe(302);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
expect(code).toBeTruthy();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
expect(tokenRes.status).toBe(200);
const tokenBody = (await tokenRes.json()) as {
access_token: string;
token_type: string;
refresh_token?: string;
};
expect(tokenBody.access_token).toBeTruthy();
expect(tokenBody.token_type).toBe('Bearer');
expect(tokenBody.refresh_token).toBeUndefined();
});
it('should allow MCP requests with valid Bearer token', async () => {
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
redirect: 'manual',
});
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const { access_token } = (await tokenRes.json()) as { access_token: string };
const mcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(mcpRes.status).toBe(200);
});
it('should reject expired tokens with 401', async () => {
const shortTTLServer = await createOAuthMCPServer({ tokenTTLMs: 500 });
try {
const authRes = await fetch(
`${shortTTLServer.url}authorize?redirect_uri=http://localhost&state=s1`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
const tokenRes = await fetch(`${shortTTLServer.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const { access_token } = (await tokenRes.json()) as { access_token: string };
await new Promise((r) => setTimeout(r, 600));
const mcpRes = await fetch(shortTTLServer.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${access_token}`,
},
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 2 }),
});
expect(mcpRes.status).toBe(401);
} finally {
await shortTTLServer.close();
}
});
});
});

View file

@ -0,0 +1,654 @@
/**
* Tests for MCP OAuth token expiry re-authentication scenarios.
*
* Reproduces the edge case where:
* 1. Tokens are stored (access + refresh)
* 2. Access token expires
* 3. Refresh attempt fails (server rejects/revokes refresh token)
* 4. System must fall back to full OAuth re-auth via handleOAuthRequired
* 5. The CSRF cookie may be absent (chat/SSE flow), so the PENDING flow fallback is needed
*
* Also tests the happy path: access token expired but refresh succeeds.
*/
import { Keyv } from 'keyv';
import { logger } from '@librechat/data-schemas';
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
import { MockKeyv, InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
describe('MCP OAuth Token Expiry Scenarios', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('Access token expired + refresh token available + refresh succeeds', () => {
let server: OAuthTestServer;
let tokenStore: InMemoryTokenStore;
beforeEach(async () => {
server = await createOAuthMCPServer({
tokenTTLMs: 500,
issueRefreshTokens: true,
});
tokenStore = new InMemoryTokenStore();
});
afterEach(async () => {
await server.close();
});
it('should refresh expired access token via real /token endpoint', async () => {
// Get initial tokens from real server
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// Store expired access token directly (bypassing storeTokens' expiresIn clamping)
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: `enc:${initial.access_token}`,
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: `enc:${initial.refresh_token}`,
expiresIn: 86400,
});
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const res = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!res.ok) {
throw new Error(`Refresh failed: ${res.status}`);
}
const data = (await res.json()) as {
access_token: string;
token_type: string;
expires_in: number;
};
return {
...data,
obtained_at: Date.now(),
expires_at: Date.now() + data.expires_in * 1000,
};
};
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(result).not.toBeNull();
expect(result!.access_token).not.toBe(initial.access_token);
// Verify the refreshed token works against the server
const mcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${result!.access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(mcpRes.status).toBe(200);
});
});
describe('Access token expired + refresh token rejected by server', () => {
let tokenStore: InMemoryTokenStore;
beforeEach(() => {
tokenStore = new InMemoryTokenStore();
});
it('should return null when refresh token is rejected (invalid_grant)', async () => {
const server = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
try {
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// Store expired access token directly
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: `enc:${initial.access_token}`,
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: `enc:${initial.refresh_token}`,
expiresIn: 86400,
});
// Simulate server revoking the refresh token
server.issuedRefreshTokens.clear();
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const res = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!res.ok) {
const body = (await res.json()) as { error: string };
throw new Error(`Token refresh failed: ${body.error}`);
}
const data = (await res.json()) as MCPOAuthTokens;
return data;
};
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(result).toBeNull();
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh tokens'),
expect.any(Error),
);
} finally {
await server.close();
}
});
it('should return null when refresh endpoint returns unauthorized_client', async () => {
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: 'enc:some-refresh-token',
expiresIn: 86400,
});
const refreshCallback = jest
.fn()
.mockRejectedValue(new Error('unauthorized_client: client not authorized for refresh'));
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
expect(result).toBeNull();
expect(logger.info).toHaveBeenCalledWith(
expect.stringContaining('does not support refresh tokens'),
);
});
});
describe('Access token expired + NO refresh token → ReauthenticationRequiredError', () => {
let tokenStore: InMemoryTokenStore;
beforeEach(() => {
tokenStore = new InMemoryTokenStore();
});
it('should throw ReauthenticationRequiredError when no refresh token stored', async () => {
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should throw ReauthenticationRequiredError with correct reason for expired token', async () => {
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
}),
).rejects.toThrow('access token expired');
});
it('should throw ReauthenticationRequiredError with correct reason for missing token', async () => {
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
}),
).rejects.toThrow('access token missing');
});
});
describe('PENDING flow fallback for CSRF-less OAuth callbacks', () => {
it('should allow OAuth completion when PENDING flow exists (simulating chat/SSE path)', async () => {
const store = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
userId: 'user1',
serverUrl: 'https://example.com',
state: 'test-state',
authorizationUrl: 'https://example.com/authorize?state=user1:test-server',
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(state?.status).toBe('PENDING');
const tokens: MCPOAuthTokens = {
access_token: 'new-access-token',
token_type: 'Bearer',
refresh_token: 'new-refresh-token',
obtained_at: Date.now(),
expires_at: Date.now() + 3600000,
};
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', tokens);
expect(completed).toBe(true);
const completedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(completedState?.status).toBe('COMPLETED');
expect((completedState?.result as MCPOAuthTokens | undefined)?.access_token).toBe(
'new-access-token',
);
});
it('should store authorizationUrl in flow metadata for re-issuance', async () => {
const store = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'user1:test-server';
const authUrl = 'https://auth.example.com/authorize?client_id=abc&state=user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
userId: 'user1',
serverUrl: 'https://example.com',
state: 'test-state',
authorizationUrl: authUrl,
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect((state?.metadata as Record<string, unknown>)?.authorizationUrl).toBe(authUrl);
});
});
describe('Full token expiry → refresh failure → re-auth flow', () => {
let server: OAuthTestServer;
let tokenStore: InMemoryTokenStore;
beforeEach(async () => {
server = await createOAuthMCPServer({
tokenTTLMs: 60000,
issueRefreshTokens: true,
});
tokenStore = new InMemoryTokenStore();
});
afterEach(async () => {
await server.close();
});
it('should go through full cycle: get tokens → expire → refresh fails → re-auth needed', async () => {
// Step 1: Get initial tokens
const code = await server.getAuthCode();
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const initial = (await tokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
};
// Step 2: Store tokens with valid expiry first
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'test-srv',
tokens: initial,
createToken: tokenStore.createToken,
});
// Step 3: Verify tokens work
const validResult = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
});
expect(validResult).not.toBeNull();
expect(validResult!.access_token).toBe(initial.access_token);
// Step 4: Simulate token expiry by directly updating the stored token's expiresAt
await tokenStore.updateToken({ userId: 'u1', identifier: 'mcp:test-srv' }, { expiresIn: -1 });
// Step 5: Revoke refresh token on server side (simulating server-side revocation)
server.issuedRefreshTokens.clear();
// Step 6: Try to get tokens — refresh should fail, return null
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
const res = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
});
if (!res.ok) {
const body = (await res.json()) as { error: string };
throw new Error(`Refresh failed: ${body.error}`);
}
const data = (await res.json()) as MCPOAuthTokens;
return data;
};
const expiredResult = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
// Refresh failed → returns null → triggers OAuth re-auth flow
expect(expiredResult).toBeNull();
// Step 7: Simulate the re-auth flow via FlowStateManager
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const flowId = 'u1:test-srv';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-srv',
userId: 'u1',
serverUrl: server.url,
state: 'test-state',
authorizationUrl: `${server.url}authorize?state=${flowId}`,
});
// Step 8: Get a new auth code and exchange for tokens (simulating user re-auth)
const newCode = await server.getAuthCode();
const newTokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${newCode}`,
});
const newTokens = (await newTokenRes.json()) as {
access_token: string;
token_type: string;
expires_in: number;
refresh_token?: string;
};
// Step 9: Complete the flow
const mcpTokens: MCPOAuthTokens = {
...newTokens,
obtained_at: Date.now(),
expires_at: Date.now() + newTokens.expires_in * 1000,
};
await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
// Step 10: Store the new tokens
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'test-srv',
tokens: mcpTokens,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
findToken: tokenStore.findToken,
});
// Step 11: Verify new tokens work
const newResult = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
});
expect(newResult).not.toBeNull();
expect(newResult!.access_token).toBe(newTokens.access_token);
// Step 12: Verify new token works against server
const finalMcpRes = await fetch(server.url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Accept: 'application/json, text/event-stream',
Authorization: `Bearer ${newResult!.access_token}`,
},
body: JSON.stringify({
jsonrpc: '2.0',
method: 'initialize',
id: 1,
params: {
protocolVersion: '2025-03-26',
capabilities: {},
clientInfo: { name: 'test', version: '0.0.1' },
},
}),
});
expect(finalMcpRes.status).toBe(200);
});
});
describe('Concurrent token expiry with connection mutex', () => {
it('should handle multiple concurrent getTokens calls when token is expired', async () => {
const tokenStore = new InMemoryTokenStore();
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:test-srv',
token: 'enc:expired-token',
expiresIn: -1,
});
await tokenStore.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:test-srv:refresh',
token: 'enc:valid-refresh',
expiresIn: 86400,
});
let refreshCallCount = 0;
const refreshCallback = jest.fn().mockImplementation(async () => {
refreshCallCount++;
await new Promise((r) => setTimeout(r, 100));
return {
access_token: `refreshed-token-${refreshCallCount}`,
token_type: 'Bearer',
expires_in: 3600,
obtained_at: Date.now(),
expires_at: Date.now() + 3600000,
};
});
// Fire 3 concurrent getTokens calls via FlowStateManager (like the connection mutex does)
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 30000,
ci: true,
});
const getTokensViaFlow = () =>
flowManager.createFlowWithHandler('u1:test-srv', 'mcp_get_tokens', async () => {
return await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'test-srv',
findToken: tokenStore.findToken,
createToken: tokenStore.createToken,
updateToken: tokenStore.updateToken,
refreshTokens: refreshCallback,
});
});
const [r1, r2, r3] = await Promise.all([
getTokensViaFlow(),
getTokensViaFlow(),
getTokensViaFlow(),
]);
// All should get tokens (either directly or via flow coalescing)
expect(r1).not.toBeNull();
expect(r2).not.toBeNull();
expect(r3).not.toBeNull();
// The refresh callback should only be called once due to flow coalescing
expect(refreshCallback).toHaveBeenCalledTimes(1);
});
});
describe('Stale PENDING flow detection', () => {
it('should treat PENDING flows older than 2 minutes as stale', async () => {
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 300000,
ci: true,
});
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
authorizationUrl: 'https://example.com/auth',
});
// Manually age the flow to 3 minutes
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (state) {
state.createdAt = Date.now() - 3 * 60 * 1000;
await (flowStore as unknown as { set: (k: string, v: unknown) => Promise<void> }).set(
`mcp_oauth:${flowId}`,
state,
);
}
const agedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
expect(agedState?.status).toBe('PENDING');
const age = agedState?.createdAt ? Date.now() - agedState.createdAt : 0;
expect(age).toBeGreaterThan(2 * 60 * 1000);
// A new flow should be created (the stale one would be deleted + recreated)
// This verifies our staleness check threshold
expect(age > PENDING_STALE_MS).toBe(true);
});
it('should not treat recent PENDING flows as stale', async () => {
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
ttl: 300000,
ci: true,
});
const flowId = 'user1:test-server';
await flowManager.initFlow(flowId, 'mcp_oauth', {
serverName: 'test-server',
authorizationUrl: 'https://example.com/auth',
});
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
const age = state?.createdAt ? Date.now() - state.createdAt : Infinity;
expect(age < PENDING_STALE_MS).toBe(true);
});
});
});

View file

@ -0,0 +1,544 @@
/**
* Integration tests for MCPTokenStorage.storeTokens() and MCPTokenStorage.getTokens().
*
* Uses InMemoryTokenStore to exercise encrypt/decrypt round-trips, expiry calculation,
* refresh callback wiring, and ReauthenticationRequiredError paths.
*/
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
import { InMemoryTokenStore } from './helpers/oauthTestServer';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
}));
describe('MCPTokenStorage', () => {
let store: InMemoryTokenStore;
beforeEach(() => {
store = new InMemoryTokenStore();
jest.clearAllMocks();
});
describe('storeTokens', () => {
it('should create new access token with expires_in', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved).not.toBeNull();
expect(saved!.token).toBe('enc:at1');
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
expect(expiresInMs).toBeGreaterThan(3500 * 1000);
expect(expiresInMs).toBeLessThanOrEqual(3600 * 1000);
});
it('should create new access token with expires_at (MCPOAuthTokens format)', async () => {
const expiresAt = Date.now() + 7200 * 1000;
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'at1',
token_type: 'Bearer',
expires_at: expiresAt,
obtained_at: Date.now(),
},
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved).not.toBeNull();
const diff = Math.abs(saved!.expiresAt.getTime() - expiresAt);
expect(diff).toBeLessThan(2000);
});
it('should default to 1-year expiry when none provided', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer' },
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
});
it('should update existing access token', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:old-token',
expiresIn: 3600,
});
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'new-token', token_type: 'Bearer', expires_in: 7200 },
createToken: store.createToken,
updateToken: store.updateToken,
findToken: store.findToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved!.token).toBe('enc:new-token');
});
it('should store refresh token alongside access token', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'at1',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'rt1',
},
createToken: store.createToken,
});
const refreshSaved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
});
expect(refreshSaved).not.toBeNull();
expect(refreshSaved!.token).toBe('enc:rt1');
});
it('should skip refresh token when not in response', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
});
const refreshSaved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
});
expect(refreshSaved).toBeNull();
});
it('should store client info when provided', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] },
});
const clientSaved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth_client',
identifier: 'mcp:srv1:client',
});
expect(clientSaved).not.toBeNull();
expect(clientSaved!.token).toContain('enc:');
expect(clientSaved!.token).toContain('cid');
});
it('should use existingTokens to skip DB lookups', async () => {
const findSpy = jest.fn();
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
createToken: store.createToken,
updateToken: store.updateToken,
findToken: findSpy,
existingTokens: {
accessToken: null,
refreshToken: null,
clientInfoToken: null,
},
});
expect(findSpy).not.toHaveBeenCalled();
});
it('should handle invalid NaN expiry date', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'at1',
token_type: 'Bearer',
expires_at: NaN,
obtained_at: Date.now(),
},
createToken: store.createToken,
});
const saved = await store.findToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
});
expect(saved).not.toBeNull();
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
});
});
describe('getTokens', () => {
it('should return valid non-expired tokens', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:valid-token',
expiresIn: 3600,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result).not.toBeNull();
expect(result!.access_token).toBe('valid-token');
expect(result!.token_type).toBe('Bearer');
});
it('should return tokens with refresh token when available', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:at',
expiresIn: 3600,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result!.refresh_token).toBe('rt');
});
it('should return tokens without refresh token field when none stored', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:at',
expiresIn: 3600,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result!.refresh_token).toBeUndefined();
});
it('should throw ReauthenticationRequiredError when expired and no refresh', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should throw ReauthenticationRequiredError when missing and no refresh', async () => {
await expect(
MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
}),
).rejects.toThrow(ReauthenticationRequiredError);
});
it('should refresh expired access token when refresh token and callback are available', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockResolvedValue({
access_token: 'refreshed-at',
token_type: 'Bearer',
expires_in: 3600,
obtained_at: Date.now(),
expires_at: Date.now() + 3600000,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
updateToken: store.updateToken,
refreshTokens,
});
expect(result).not.toBeNull();
expect(result!.access_token).toBe('refreshed-at');
expect(refreshTokens).toHaveBeenCalledWith(
'rt',
expect.objectContaining({ userId: 'u1', serverName: 'srv1' }),
);
});
it('should return null when refresh fails', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockRejectedValue(new Error('refresh failed'));
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
updateToken: store.updateToken,
refreshTokens,
});
expect(result).toBeNull();
});
it('should return null when no refreshTokens callback provided', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result).toBeNull();
});
it('should return null when no createToken callback provided', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
refreshTokens: jest.fn(),
});
expect(result).toBeNull();
});
it('should pass client info to refreshTokens metadata', async () => {
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_client',
identifier: 'mcp:srv1:client',
token: 'enc:{"client_id":"cid","client_secret":"csec"}',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockResolvedValue({
access_token: 'new-at',
token_type: 'Bearer',
expires_in: 3600,
});
await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
updateToken: store.updateToken,
refreshTokens,
});
expect(refreshTokens).toHaveBeenCalledWith(
'rt',
expect.objectContaining({
clientInfo: expect.objectContaining({ client_id: 'cid' }),
}),
);
});
it('should handle unauthorized_client refresh error', async () => {
const { logger } = await import('@librechat/data-schemas');
await store.createToken({
userId: 'u1',
type: 'mcp_oauth',
identifier: 'mcp:srv1',
token: 'enc:expired-token',
expiresIn: -1,
});
await store.createToken({
userId: 'u1',
type: 'mcp_oauth_refresh',
identifier: 'mcp:srv1:refresh',
token: 'enc:rt',
expiresIn: 86400,
});
const refreshTokens = jest.fn().mockRejectedValue(new Error('unauthorized_client'));
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
createToken: store.createToken,
refreshTokens,
});
expect(result).toBeNull();
expect(logger.info).toHaveBeenCalledWith(
expect.stringContaining('does not support refresh tokens'),
);
});
});
describe('storeTokens + getTokens round-trip', () => {
it('should store and retrieve tokens with full encrypt/decrypt cycle', async () => {
await MCPTokenStorage.storeTokens({
userId: 'u1',
serverName: 'srv1',
tokens: {
access_token: 'my-access-token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'my-refresh-token',
},
createToken: store.createToken,
clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] },
});
const result = await MCPTokenStorage.getTokens({
userId: 'u1',
serverName: 'srv1',
findToken: store.findToken,
});
expect(result!.access_token).toBe('my-access-token');
expect(result!.refresh_token).toBe('my-refresh-token');
expect(result!.token_type).toBe('Bearer');
expect(result!.obtained_at).toBeDefined();
expect(result!.expires_at).toBeDefined();
});
});
});

View file

@ -1439,5 +1439,292 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
}),
);
});
describe('path-based URL origin fallback', () => {
it('retries with origin URL when path-based discovery fails (stored clientInfo path)', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
grant_types: ['authorization_code', 'refresh_token'],
},
};
const originMetadata = {
issuer: 'https://mcp.sentry.dev/',
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
token_endpoint_auth_methods_supported: ['client_secret_post'],
response_types_supported: ['code'],
jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json',
subject_types_supported: ['public'],
id_token_signing_alg_values_supported: ['RS256'],
} as AuthorizationServerMetadata;
// First call (path-based URL) fails, second call (origin URL) succeeds
mockDiscoverAuthorizationServerMetadata
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce(originMetadata);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
// Discovery attempted twice: once with path URL, once with origin URL
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
1,
expect.any(URL),
expect.any(Object),
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
2,
expect.any(URL),
expect.any(Object),
);
const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL;
const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL;
expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp');
expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/');
// Token endpoint from origin discovery metadata is used (string in stored-clientInfo branch)
expect(mockFetch).toHaveBeenCalled();
const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0];
expect(typeof fetchUrl).toBe('string');
expect(fetchUrl).toBe('https://mcp.sentry.dev/oauth/token');
expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' }));
expect(result.access_token).toBe('new-access-token');
});
it('retries with origin URL when path-based discovery fails (no stored clientInfo)', async () => {
// No clientInfo — uses the auto-discovered branch
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
};
const originMetadata = {
issuer: 'https://mcp.sentry.dev/',
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
response_types_supported: ['code'],
jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json',
subject_types_supported: ['public'],
id_token_signing_alg_values_supported: ['RS256'],
} as AuthorizationServerMetadata;
// First call (path-based URL) fails, second call (origin URL) succeeds
mockDiscoverAuthorizationServerMetadata
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce(originMetadata);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
// Discovery attempted twice: once with path URL, once with origin URL
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
1,
expect.any(URL),
expect.any(Object),
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
2,
expect.any(URL),
expect.any(Object),
);
const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL;
const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL;
expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp');
expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/');
// Token endpoint from origin discovery metadata is used (URL object in auto-discovered branch)
expect(mockFetch).toHaveBeenCalled();
const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0];
expect(fetchUrl).toBeInstanceOf(URL);
expect(fetchUrl.toString()).toBe('https://mcp.sentry.dev/oauth/token');
expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' }));
expect(result.access_token).toBe('new-access-token');
});
it('falls back to /token when both path and origin discovery fail', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
grant_types: ['authorization_code', 'refresh_token'],
},
};
// Both path AND origin discovery return undefined
mockDiscoverAuthorizationServerMetadata
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce(undefined);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
// Falls back to /token relative to server URL origin
const [fetchUrl] = mockFetch.mock.calls[0];
expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/token');
expect(result.access_token).toBe('new-access-token');
});
it('does not retry with origin when server URL has no path (root URL)', async () => {
const metadata = {
serverName: 'test-server',
serverUrl: 'https://auth.example.com/',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
},
};
// Root URL discovery fails — no retry
mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({ access_token: 'new-token', expires_in: 3600 }),
} as Response);
await MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {});
// Only one discovery attempt for a root URL
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1);
});
it('retries with origin when path-based discovery throws', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
grant_types: ['authorization_code', 'refresh_token'],
},
};
const originMetadata = {
issuer: 'https://mcp.sentry.dev/',
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
token_endpoint_auth_methods_supported: ['client_secret_post'],
response_types_supported: ['code'],
} as AuthorizationServerMetadata;
// First call throws, second call succeeds
mockDiscoverAuthorizationServerMetadata
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce(originMetadata);
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
expires_in: 3600,
}),
} as Response);
const result = await MCPOAuthHandler.refreshOAuthTokens(
'test-refresh-token',
metadata,
{},
{},
);
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
const [fetchUrl] = mockFetch.mock.calls[0];
expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/oauth/token');
expect(result.access_token).toBe('new-access-token');
});
it('propagates the throw when root URL discovery throws', async () => {
const metadata = {
serverName: 'test-server',
serverUrl: 'https://auth.example.com/',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
},
};
mockDiscoverAuthorizationServerMetadata.mockRejectedValueOnce(
new Error('Discovery failed'),
);
await expect(
MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}),
).rejects.toThrow('Discovery failed');
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1);
});
it('propagates the throw when both path and origin discovery throw', async () => {
const metadata = {
serverName: 'sentry',
serverUrl: 'https://mcp.sentry.dev/mcp',
clientInfo: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
},
};
mockDiscoverAuthorizationServerMetadata
.mockRejectedValueOnce(new Error('Network error'))
.mockRejectedValueOnce(new Error('Origin also failed'));
await expect(
MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}),
).rejects.toThrow('Origin also failed');
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
});
});
});
});

View file

@ -0,0 +1,449 @@
import * as http from 'http';
import * as net from 'net';
import { randomUUID, createHash } from 'crypto';
import { z } from 'zod';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import type { FlowState } from '~/flow/types';
import type { Socket } from 'net';
export class MockKeyv<T = unknown> {
private store: Map<string, FlowState<T>>;
constructor() {
this.store = new Map();
}
async get(key: string): Promise<FlowState<T> | undefined> {
return this.store.get(key);
}
async set(key: string, value: FlowState<T>, _ttl?: number): Promise<true> {
this.store.set(key, value);
return true;
}
async delete(key: string): Promise<boolean> {
return this.store.delete(key);
}
}
export function getFreePort(): Promise<number> {
return new Promise((resolve, reject) => {
const srv = net.createServer();
srv.listen(0, '127.0.0.1', () => {
const addr = srv.address() as net.AddressInfo;
srv.close((err) => (err ? reject(err) : resolve(addr.port)));
});
});
}
export function trackSockets(httpServer: http.Server): () => Promise<void> {
const sockets = new Set<Socket>();
httpServer.on('connection', (socket: Socket) => {
sockets.add(socket);
socket.once('close', () => sockets.delete(socket));
});
return () =>
new Promise<void>((resolve) => {
for (const socket of sockets) {
socket.destroy();
}
sockets.clear();
httpServer.close(() => resolve());
});
}
export interface OAuthTestServerOptions {
tokenTTLMs?: number;
issueRefreshTokens?: boolean;
refreshTokenTTLMs?: number;
rotateRefreshTokens?: boolean;
}
export interface OAuthTestServer {
url: string;
port: number;
close: () => Promise<void>;
issuedTokens: Set<string>;
tokenTTL: number;
tokenIssueTimes: Map<string, number>;
issuedRefreshTokens: Map<string, string>;
registeredClients: Map<string, { client_id: string; client_secret: string }>;
getAuthCode: () => Promise<string>;
}
async function readRequestBody(req: http.IncomingMessage): Promise<string> {
const chunks: Uint8Array[] = [];
for await (const chunk of req) {
chunks.push(chunk as Uint8Array);
}
return Buffer.concat(chunks).toString();
}
function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null {
if (contentType?.includes('application/x-www-form-urlencoded')) {
return new URLSearchParams(body);
}
if (contentType?.includes('application/json')) {
const json = JSON.parse(body) as Record<string, string>;
return new URLSearchParams(json);
}
return new URLSearchParams(body);
}
export async function createOAuthMCPServer(
options: OAuthTestServerOptions = {},
): Promise<OAuthTestServer> {
const {
tokenTTLMs = 60000,
issueRefreshTokens = false,
refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000,
rotateRefreshTokens = false,
} = options;
const sessions = new Map<string, StreamableHTTPServerTransport>();
const issuedTokens = new Set<string>();
const tokenIssueTimes = new Map<string, number>();
const issuedRefreshTokens = new Map<string, string>();
const refreshTokenIssueTimes = new Map<string, number>();
const authCodes = new Map<string, { codeChallenge?: string; codeChallengeMethod?: string }>();
const registeredClients = new Map<string, { client_id: string; client_secret: string }>();
let port = 0;
const httpServer = http.createServer(async (req, res) => {
const url = new URL(req.url ?? '/', `http://${req.headers.host}`);
if (url.pathname === '/.well-known/oauth-authorization-server' && req.method === 'GET') {
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(
JSON.stringify({
issuer: `http://127.0.0.1:${port}`,
authorization_endpoint: `http://127.0.0.1:${port}/authorize`,
token_endpoint: `http://127.0.0.1:${port}/token`,
registration_endpoint: `http://127.0.0.1:${port}/register`,
response_types_supported: ['code'],
grant_types_supported: issueRefreshTokens
? ['authorization_code', 'refresh_token']
: ['authorization_code'],
token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post'],
code_challenge_methods_supported: ['S256'],
}),
);
return;
}
if (url.pathname === '/register' && req.method === 'POST') {
const body = await readRequestBody(req);
const data = JSON.parse(body) as { redirect_uris?: string[] };
const clientId = `client-${randomUUID().slice(0, 8)}`;
const clientSecret = `secret-${randomUUID()}`;
registeredClients.set(clientId, { client_id: clientId, client_secret: clientSecret });
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(
JSON.stringify({
client_id: clientId,
client_secret: clientSecret,
redirect_uris: data.redirect_uris ?? [],
}),
);
return;
}
if (url.pathname === '/authorize') {
const code = randomUUID();
const codeChallenge = url.searchParams.get('code_challenge') ?? undefined;
const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined;
authCodes.set(code, { codeChallenge, codeChallengeMethod });
const redirectUri = url.searchParams.get('redirect_uri') ?? '';
const state = url.searchParams.get('state') ?? '';
res.writeHead(302, {
Location: `${redirectUri}?code=${code}&state=${state}`,
});
res.end();
return;
}
if (url.pathname === '/token' && req.method === 'POST') {
const body = await readRequestBody(req);
const params = parseTokenRequest(body, req.headers['content-type']);
if (!params) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_request' }));
return;
}
const grantType = params.get('grant_type');
if (grantType === 'authorization_code') {
const code = params.get('code');
const codeData = code ? authCodes.get(code) : undefined;
if (!code || !codeData) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
if (codeData.codeChallenge) {
const codeVerifier = params.get('code_verifier');
if (!codeVerifier) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
if (codeData.codeChallengeMethod === 'S256') {
const expected = createHash('sha256').update(codeVerifier).digest('base64url');
if (expected !== codeData.codeChallenge) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
}
}
authCodes.delete(code);
const accessToken = randomUUID();
issuedTokens.add(accessToken);
tokenIssueTimes.set(accessToken, Date.now());
const tokenResponse: Record<string, string | number> = {
access_token: accessToken,
token_type: 'Bearer',
expires_in: Math.ceil(tokenTTLMs / 1000),
};
if (issueRefreshTokens) {
const refreshToken = randomUUID();
issuedRefreshTokens.set(refreshToken, accessToken);
refreshTokenIssueTimes.set(refreshToken, Date.now());
tokenResponse.refresh_token = refreshToken;
}
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(tokenResponse));
return;
}
if (grantType === 'refresh_token' && issueRefreshTokens) {
const refreshToken = params.get('refresh_token');
if (!refreshToken || !issuedRefreshTokens.has(refreshToken)) {
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
const issueTime = refreshTokenIssueTimes.get(refreshToken) ?? 0;
if (Date.now() - issueTime > refreshTokenTTLMs) {
issuedRefreshTokens.delete(refreshToken);
refreshTokenIssueTimes.delete(refreshToken);
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_grant' }));
return;
}
const newAccessToken = randomUUID();
issuedTokens.add(newAccessToken);
tokenIssueTimes.set(newAccessToken, Date.now());
const tokenResponse: Record<string, string | number> = {
access_token: newAccessToken,
token_type: 'Bearer',
expires_in: Math.ceil(tokenTTLMs / 1000),
};
if (rotateRefreshTokens) {
issuedRefreshTokens.delete(refreshToken);
refreshTokenIssueTimes.delete(refreshToken);
const newRefreshToken = randomUUID();
issuedRefreshTokens.set(newRefreshToken, newAccessToken);
refreshTokenIssueTimes.set(newRefreshToken, Date.now());
tokenResponse.refresh_token = newRefreshToken;
}
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(tokenResponse));
return;
}
res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'unsupported_grant_type' }));
return;
}
// All other paths require Bearer token auth
const authHeader = req.headers.authorization;
if (!authHeader || !authHeader.startsWith('Bearer ')) {
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_token' }));
return;
}
const token = authHeader.slice(7);
if (!issuedTokens.has(token)) {
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_token' }));
return;
}
const issueTime = tokenIssueTimes.get(token) ?? 0;
if (Date.now() - issueTime > tokenTTLMs) {
issuedTokens.delete(token);
tokenIssueTimes.delete(token);
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_token' }));
return;
}
// Authenticated MCP request — route to transport
const sid = req.headers['mcp-session-id'] as string | undefined;
let transport = sid ? sessions.get(sid) : undefined;
if (!transport) {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
});
const mcp = new McpServer({ name: 'oauth-test-server', version: '0.0.1' });
mcp.tool('echo', { message: z.string() }, async (args) => ({
content: [{ type: 'text' as const, text: `echo: ${args.message}` }],
}));
await mcp.connect(transport);
}
await transport.handleRequest(req, res);
if (transport.sessionId && !sessions.has(transport.sessionId)) {
sessions.set(transport.sessionId, transport);
transport.onclose = () => sessions.delete(transport!.sessionId!);
}
});
const destroySockets = trackSockets(httpServer);
port = await getFreePort();
await new Promise<void>((resolve) => httpServer.listen(port, '127.0.0.1', resolve));
return {
url: `http://127.0.0.1:${port}/`,
port,
issuedTokens,
tokenTTL: tokenTTLMs,
tokenIssueTimes,
issuedRefreshTokens,
registeredClients,
getAuthCode: async () => {
const authRes = await fetch(
`http://127.0.0.1:${port}/authorize?redirect_uri=http://localhost&state=test`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
return new URL(location).searchParams.get('code') ?? '';
},
close: async () => {
const closing = [...sessions.values()].map((t) => t.close().catch(() => undefined));
sessions.clear();
await Promise.all(closing);
await destroySockets();
},
};
}
export interface InMemoryToken {
userId: string;
type: string;
identifier: string;
token: string;
expiresAt: Date;
createdAt: Date;
metadata?: Map<string, unknown> | Record<string, unknown>;
}
export class InMemoryTokenStore {
private tokens: Map<string, InMemoryToken> = new Map();
private key(filter: { userId?: string; type?: string; identifier?: string }): string {
return `${filter.userId}:${filter.type}:${filter.identifier}`;
}
findToken = async (filter: {
userId?: string;
type?: string;
identifier?: string;
}): Promise<InMemoryToken | null> => {
for (const token of this.tokens.values()) {
const matchUserId = !filter.userId || token.userId === filter.userId;
const matchType = !filter.type || token.type === filter.type;
const matchIdentifier = !filter.identifier || token.identifier === filter.identifier;
if (matchUserId && matchType && matchIdentifier) {
return token;
}
}
return null;
};
createToken = async (data: {
userId: string;
type: string;
identifier: string;
token: string;
expiresIn?: number;
metadata?: Record<string, unknown>;
}): Promise<InMemoryToken> => {
const expiresIn = data.expiresIn ?? 365 * 24 * 60 * 60;
const token: InMemoryToken = {
userId: data.userId,
type: data.type,
identifier: data.identifier,
token: data.token,
expiresAt: new Date(Date.now() + expiresIn * 1000),
createdAt: new Date(),
metadata: data.metadata,
};
this.tokens.set(this.key(data), token);
return token;
};
updateToken = async (
filter: { userId?: string; type?: string; identifier?: string },
data: {
userId?: string;
type?: string;
identifier?: string;
token?: string;
expiresIn?: number;
metadata?: Record<string, unknown>;
},
): Promise<InMemoryToken> => {
const existing = await this.findToken(filter);
if (!existing) {
throw new Error(`Token not found for filter: ${JSON.stringify(filter)}`);
}
const existingKey = this.key(existing);
const expiresIn =
data.expiresIn ?? Math.floor((existing.expiresAt.getTime() - Date.now()) / 1000);
const updated: InMemoryToken = {
...existing,
token: data.token ?? existing.token,
expiresAt: data.expiresIn ? new Date(Date.now() + expiresIn * 1000) : existing.expiresAt,
metadata: data.metadata ?? existing.metadata,
};
this.tokens.set(existingKey, updated);
return updated;
};
deleteToken = async (filter: {
userId: string;
type: string;
identifier: string;
}): Promise<void> => {
this.tokens.delete(this.key(filter));
};
getAll(): InMemoryToken[] {
return [...this.tokens.values()];
}
clear(): void {
this.tokens.clear();
}
}

View file

@ -0,0 +1,668 @@
/**
* Reconnection storm regression tests for PR #12162.
*
* Validates circuit breaker, throttling, cooldown, and timeout fixes using real
* MCP SDK transports (no mocked stubs). A real StreamableHTTP server is spun up
* per test suite and MCPConnection talks to it through a genuine HTTP stack.
*/
import http from 'http';
import { randomUUID } from 'crypto';
import express from 'express';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import type { Socket } from 'net';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker';
import { createOAuthMCPServer } from './helpers/oauthTestServer';
import { MCPConnection } from '~/mcp/connection';
import { mcpConfig } from '~/mcp/mcpConfig';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
/* ------------------------------------------------------------------ */
/* Helpers */
/* ------------------------------------------------------------------ */
interface TestServer {
url: string;
httpServer: http.Server;
close: () => Promise<void>;
}
function trackSockets(httpServer: http.Server): () => Promise<void> {
const sockets = new Set<Socket>();
httpServer.on('connection', (socket: Socket) => {
sockets.add(socket);
socket.once('close', () => sockets.delete(socket));
});
return () =>
new Promise<void>((resolve) => {
for (const socket of sockets) {
socket.destroy();
}
sockets.clear();
httpServer.close(() => resolve());
});
}
function startMCPServer(): Promise<TestServer> {
const app = express();
app.use(express.json());
const transports: Record<string, StreamableHTTPServerTransport> = {};
function createServer(): McpServer {
const server = new McpServer({ name: 'test-server', version: '1.0.0' });
server.tool('echo', 'echoes input', { message: { type: 'string' } as never }, async (args) => {
const msg = (args as Record<string, string>).message ?? '';
return { content: [{ type: 'text', text: msg }] };
});
return server;
}
app.all('/mcp', async (req, res) => {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId && transports[sessionId]) {
await transports[sessionId].handleRequest(req, res, req.body);
return;
}
if (!sessionId && isInitializeRequest(req.body)) {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (sid) => {
transports[sid] = transport;
},
});
transport.onclose = () => {
const sid = transport.sessionId;
if (sid) {
delete transports[sid];
}
};
const server = createServer();
await server.connect(transport);
await transport.handleRequest(req, res, req.body);
return;
}
if (req.method === 'GET') {
res.status(404).send('Not Found');
return;
}
res.status(400).json({
jsonrpc: '2.0',
error: { code: -32000, message: 'Bad Request: No valid session ID provided' },
id: null,
});
});
return new Promise((resolve) => {
const httpServer = app.listen(0, '127.0.0.1', () => {
const destroySockets = trackSockets(httpServer);
const addr = httpServer.address() as { port: number };
resolve({
url: `http://127.0.0.1:${addr.port}/mcp`,
httpServer,
close: async () => {
for (const t of Object.values(transports)) {
t.close().catch(() => {});
}
await destroySockets();
},
});
});
});
}
function createConnection(serverName: string, url: string, initTimeout = 5000): MCPConnection {
return new MCPConnection({
serverName,
serverConfig: { url, type: 'streamable-http', initTimeout } as never,
});
}
async function teardownConnection(conn: MCPConnection): Promise<void> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = true;
conn.removeAllListeners();
await conn.disconnect();
}
afterEach(() => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(MCPConnection as any).circuitBreakers.clear();
});
/* ------------------------------------------------------------------ */
/* Fix #2 — Circuit breaker trips after rapid connect/disconnect */
/* cycles (CB_MAX_CYCLES within window -> cooldown) */
/* ------------------------------------------------------------------ */
describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => {
it('blocks connection after CB_MAX_CYCLES rapid cycles via static circuit breaker', async () => {
const srv = await startMCPServer();
const conn = createConnection('cycling-server', srv.url);
let completedCycles = 0;
let breakerMessage = '';
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
for (let cycle = 0; cycle < maxAttempts; cycle++) {
try {
await conn.connect();
await teardownConnection(conn);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = false;
completedCycles++;
} catch (e) {
breakerMessage = (e as Error).message;
break;
}
}
expect(breakerMessage).toContain('Circuit breaker is open');
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
await srv.close();
});
});
/* ------------------------------------------------------------------ */
/* Fix #3 — SSE 400/405 handled in same branch as 404 */
/* ------------------------------------------------------------------ */
describe('Fix #3: SSE 400/405 handled in same branch as 404', () => {
it('400 with active session triggers reconnection (session lost)', async () => {
const srv = await startMCPServer();
const conn = createConnection('sse-400', srv.url);
await conn.connect();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = true;
const changes: string[] = [];
conn.on('connectionChange', (s: string) => changes.push(s));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const transport = (conn as any).transport;
transport.onerror({ message: 'Failed to open SSE stream', code: 400 });
expect(changes).toContain('error');
await teardownConnection(conn);
await srv.close();
});
it('405 with active session triggers reconnection (session lost)', async () => {
const srv = await startMCPServer();
const conn = createConnection('sse-405', srv.url);
await conn.connect();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = true;
const changes: string[] = [];
conn.on('connectionChange', (s: string) => changes.push(s));
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const transport = (conn as any).transport;
transport.onerror({ message: 'Method Not Allowed', code: 405 });
expect(changes).toContain('error');
await teardownConnection(conn);
await srv.close();
});
});
/* ------------------------------------------------------------------ */
/* Fix #4 — Circuit breaker state persists in static Map across */
/* instance replacements */
/* ------------------------------------------------------------------ */
describe('Fix #4: Circuit breaker state persists across instance replacement', () => {
it('new MCPConnection for same serverName inherits breaker state from static Map', async () => {
const srv = await startMCPServer();
const conn1 = createConnection('replace', srv.url);
await conn1.connect();
await teardownConnection(conn1);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cbAfterConn1 = (MCPConnection as any).circuitBreakers.get('replace');
expect(cbAfterConn1).toBeDefined();
const cyclesAfterConn1 = cbAfterConn1.cycleCount;
expect(cyclesAfterConn1).toBeGreaterThan(0);
const conn2 = createConnection('replace', srv.url);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cbFromConn2 = (conn2 as any).getCircuitBreaker();
expect(cbFromConn2.cycleCount).toBe(cyclesAfterConn1);
await teardownConnection(conn2);
await srv.close();
});
it('clearCooldown resets static state so explicit retry proceeds', () => {
const conn = createConnection('replace', 'http://127.0.0.1:1/mcp');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (conn as any).getCircuitBreaker();
cb.cooldownUntil = Date.now() + 999_999;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((conn as any).isCircuitOpen()).toBe(true);
MCPConnection.clearCooldown('replace');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((conn as any).isCircuitOpen()).toBe(false);
});
});
/* ------------------------------------------------------------------ */
/* Fix #5 — Dead servers now trigger circuit breaker via */
/* recordFailedRound() in the catch path */
/* ------------------------------------------------------------------ */
describe('Fix #5: Dead server triggers circuit breaker', () => {
it('failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => {
const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000);
const spy = jest.spyOn(conn.client, 'connect');
const totalAttempts = mcpConfig.CB_MAX_FAILED_ROUNDS + 2;
const errors: string[] = [];
for (let i = 0; i < totalAttempts; i++) {
try {
await conn.connect();
} catch (e) {
errors.push((e as Error).message);
}
}
expect(spy.mock.calls.length).toBe(mcpConfig.CB_MAX_FAILED_ROUNDS);
expect(errors).toHaveLength(totalAttempts);
expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2);
await conn.disconnect();
});
it('user B is immediately blocked when user A already tripped the breaker for the same server', async () => {
const deadUrl = 'http://127.0.0.1:1/mcp';
const userA = new MCPConnection({
serverName: 'shared-dead',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-A',
});
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
try {
await userA.connect();
} catch {
// expected
}
}
const userB = new MCPConnection({
serverName: 'shared-dead',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-B',
});
const spyB = jest.spyOn(userB.client, 'connect');
let blockedMessage = '';
try {
await userB.connect();
} catch (e) {
blockedMessage = (e as Error).message;
}
expect(blockedMessage).toContain('Circuit breaker is open');
expect(spyB).toHaveBeenCalledTimes(0);
await userA.disconnect();
await userB.disconnect();
});
it('clearCooldown after user retry unblocks all users', async () => {
const deadUrl = 'http://127.0.0.1:1/mcp';
const userA = new MCPConnection({
serverName: 'shared-dead-clear',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-A',
});
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
try {
await userA.connect();
} catch {
// expected
}
}
const userB = new MCPConnection({
serverName: 'shared-dead-clear',
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
userId: 'user-B',
});
try {
await userB.connect();
} catch (e) {
expect((e as Error).message).toContain('Circuit breaker is open');
}
MCPConnection.clearCooldown('shared-dead-clear');
const spyB = jest.spyOn(userB.client, 'connect');
try {
await userB.connect();
} catch {
// expected — server is still dead
}
expect(spyB).toHaveBeenCalledTimes(1);
await userA.disconnect();
await userB.disconnect();
});
});
/* ------------------------------------------------------------------ */
/* Fix #5b — disconnect(false) preserves cycle tracking */
/* ------------------------------------------------------------------ */
describe('Fix #5b: disconnect(false) preserves cycle tracking', () => {
it('connect() passes false to disconnect, which calls recordCycle()', async () => {
const srv = await startMCPServer();
const conn = createConnection('wipe', srv.url);
const spy = jest.spyOn(conn, 'disconnect');
await conn.connect();
expect(spy).toHaveBeenCalledWith(false);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (MCPConnection as any).circuitBreakers.get('wipe');
expect(cb).toBeDefined();
expect(cb.cycleCount).toBeGreaterThan(0);
await teardownConnection(conn);
await srv.close();
});
});
/* ------------------------------------------------------------------ */
/* Fix #6 — OAuth failure uses cooldown-based retry */
/* ------------------------------------------------------------------ */
describe('Fix #6: OAuth failure uses cooldown-based retry', () => {
beforeEach(() => jest.useFakeTimers());
afterEach(() => jest.useRealTimers());
it('isFailed expires after first cooldown of 5 min', () => {
jest.setSystemTime(Date.now());
const tracker = new OAuthReconnectionTracker();
tracker.setFailed('u1', 'srv');
expect(tracker.isFailed('u1', 'srv')).toBe(true);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
});
it('progressive cooldown: 5m, 10m, 20m, 30m (capped)', () => {
jest.setSystemTime(Date.now());
const tracker = new OAuthReconnectionTracker();
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(10 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(20 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(29 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
});
it('removeFailed resets attempt count so next failure starts at 5m', () => {
jest.setSystemTime(Date.now());
const tracker = new OAuthReconnectionTracker();
tracker.setFailed('u1', 'srv');
tracker.setFailed('u1', 'srv');
tracker.setFailed('u1', 'srv');
tracker.removeFailed('u1', 'srv');
tracker.setFailed('u1', 'srv');
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed('u1', 'srv')).toBe(false);
});
});
/* ------------------------------------------------------------------ */
/* Integration: Circuit breaker caps rapid cycling with real transport */
/* ------------------------------------------------------------------ */
describe('Cascade: Circuit breaker caps rapid cycling', () => {
it('breaker trips before double CB_MAX_CYCLES complete against a live server', async () => {
const srv = await startMCPServer();
const conn = createConnection('cascade', srv.url);
const spy = jest.spyOn(conn.client, 'connect');
let completedCycles = 0;
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
for (let i = 0; i < maxAttempts; i++) {
try {
await conn.connect();
await teardownConnection(conn);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = false;
completedCycles++;
} catch (e) {
if ((e as Error).message.includes('Circuit breaker is open')) {
break;
}
throw e;
}
}
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
expect(spy.mock.calls.length).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
await srv.close();
});
it('breaker bounds failures against a killed server', async () => {
const srv = await startMCPServer();
const conn = createConnection('cascade-die', srv.url, 2000);
await conn.connect();
await teardownConnection(conn);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(conn as any).shouldStopReconnecting = false;
await srv.close();
let breakerTripped = false;
for (let i = 0; i < 10; i++) {
try {
await conn.connect();
} catch (e) {
if ((e as Error).message.includes('Circuit breaker is open')) {
breakerTripped = true;
break;
}
}
}
expect(breakerTripped).toBe(true);
}, 30_000);
});
/* ------------------------------------------------------------------ */
/* OAuth: cycle recovery after successful OAuth reconnect */
/* ------------------------------------------------------------------ */
describe('OAuth: cycle budget recovery after successful OAuth', () => {
let oauthServer: OAuthTestServer;
beforeEach(async () => {
oauthServer = await createOAuthMCPServer({ tokenTTLMs: 60000 });
});
afterEach(async () => {
await oauthServer.close();
});
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
redirect: 'manual',
});
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code') ?? '';
const tokenRes = await fetch(`${serverUrl}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}`,
});
const data = (await tokenRes.json()) as { access_token: string };
return data.access_token;
}
it('should decrement cycle count after successful OAuth recovery', async () => {
const serverName = 'oauth-cycle-test';
MCPConnection.clearCooldown(serverName);
const conn = new MCPConnection({
serverName,
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
userId: 'user-1',
});
// When oauthRequired fires, get a token and emit oauthHandled
// This triggers the oauthRecovery path inside connectClient
conn.on('oauthRequired', async () => {
const accessToken = await exchangeCodeForToken(oauthServer.url);
conn.setOAuthTokens({
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
conn.emit('oauthHandled');
});
// connect() → 401 → oauthRequired → oauthHandled → connectClient returns
// connect() sees not connected → throws "Connection not established"
await expect(conn.connect()).rejects.toThrow('Connection not established');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
const cyclesBeforeRetry = cb.cycleCount;
// Retry — should succeed and decrement cycle count via oauthRecovery
await conn.connect();
expect(await conn.isConnected()).toBe(true);
const cyclesAfterSuccess = cb.cycleCount;
// The retry adds +1 cycle (disconnect(false)) then -1 (oauthRecovery decrement)
// So cyclesAfterSuccess should equal cyclesBeforeRetry, not cyclesBeforeRetry + 1
expect(cyclesAfterSuccess).toBe(cyclesBeforeRetry);
await teardownConnection(conn);
});
it('should allow more OAuth reconnects than non-OAuth before breaker trips', async () => {
const serverName = 'oauth-budget';
MCPConnection.clearCooldown(serverName);
// Each OAuth flow: connect (+1) → 401 → oauthHandled → retry connect (+1) → success (-1) = net 1
// Without the decrement it would be net 2 per flow, tripping the breaker after ~2 users
let successfulFlows = 0;
for (let i = 0; i < 10; i++) {
const conn = new MCPConnection({
serverName,
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
userId: `user-${i}`,
});
conn.on('oauthRequired', async () => {
const accessToken = await exchangeCodeForToken(oauthServer.url);
conn.setOAuthTokens({
access_token: accessToken,
token_type: 'Bearer',
} as MCPOAuthTokens);
conn.emit('oauthHandled');
});
try {
// First connect: 401 → oauthHandled → returns without connection
await conn.connect().catch(() => {});
// Retry: succeeds with token, decrements cycle
await conn.connect();
successfulFlows++;
await teardownConnection(conn);
} catch (e) {
conn.removeAllListeners();
if ((e as Error).message.includes('Circuit breaker is open')) {
break;
}
}
}
// With the oauthRecovery decrement, each flow is net ~1 cycle instead of ~2,
// so we should get more successful flows before the breaker trips
expect(successfulFlows).toBeGreaterThanOrEqual(3);
});
it('should not decrement cycle count when OAuth fails', async () => {
const serverName = 'oauth-failed-no-decrement';
MCPConnection.clearCooldown(serverName);
const conn = new MCPConnection({
serverName,
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
userId: 'user-1',
});
conn.on('oauthRequired', () => {
conn.emit('oauthFailed', new Error('user denied'));
});
await expect(conn.connect()).rejects.toThrow();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
const cyclesAfterFailure = cb.cycleCount;
// connect() recorded +1 cycle, oauthFailed should NOT decrement
expect(cyclesAfterFailure).toBeGreaterThanOrEqual(1);
conn.removeAllListeners();
});
});
/* ------------------------------------------------------------------ */
/* Sanity: Real transport works end-to-end */
/* ------------------------------------------------------------------ */
describe('Sanity: Real MCP SDK transport works correctly', () => {
it('connects, lists tools, and disconnects cleanly', async () => {
const srv = await startMCPServer();
const conn = createConnection('sanity', srv.url);
await conn.connect();
expect(await conn.isConnected()).toBe(true);
const tools = await conn.fetchTools();
expect(tools).toEqual(expect.arrayContaining([expect.objectContaining({ name: 'echo' })]));
await teardownConnection(conn);
await srv.close();
});
});

View file

@ -71,6 +71,17 @@ const FIVE_MINUTES = 5 * 60 * 1000;
const DEFAULT_TIMEOUT = 60000;
/** SSE connections through proxies may need longer initial handshake time */
const SSE_CONNECT_TIMEOUT = 120000;
const DEFAULT_INIT_TIMEOUT = 30000;
interface CircuitBreakerState {
cycleCount: number;
cycleWindowStart: number;
cooldownUntil: number;
failedRounds: number;
failedWindowStart: number;
failedBackoffUntil: number;
}
/** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */
const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES;
@ -262,6 +273,7 @@ export class MCPConnection extends EventEmitter {
private oauthTokens?: MCPOAuthTokens | null;
private requestHeaders?: Record<string, string> | null;
private oauthRequired = false;
private oauthRecovery = false;
private readonly useSSRFProtection: boolean;
iconPath?: string;
timeout?: number;
@ -274,6 +286,88 @@ export class MCPConnection extends EventEmitter {
*/
public readonly createdAt: number;
private static circuitBreakers: Map<string, CircuitBreakerState> = new Map();
public static clearCooldown(serverName: string): void {
MCPConnection.circuitBreakers.delete(serverName);
logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`);
}
private getCircuitBreaker(): CircuitBreakerState {
let cb = MCPConnection.circuitBreakers.get(this.serverName);
if (!cb) {
cb = {
cycleCount: 0,
cycleWindowStart: Date.now(),
cooldownUntil: 0,
failedRounds: 0,
failedWindowStart: Date.now(),
failedBackoffUntil: 0,
};
MCPConnection.circuitBreakers.set(this.serverName, cb);
}
return cb;
}
private isCircuitOpen(): boolean {
const cb = this.getCircuitBreaker();
const now = Date.now();
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
}
private recordCycle(): void {
const cb = this.getCircuitBreaker();
const now = Date.now();
if (now - cb.cycleWindowStart > mcpConfig.CB_CYCLE_WINDOW_MS) {
cb.cycleCount = 0;
cb.cycleWindowStart = now;
}
cb.cycleCount++;
if (cb.cycleCount >= mcpConfig.CB_MAX_CYCLES) {
cb.cooldownUntil = now + mcpConfig.CB_CYCLE_COOLDOWN_MS;
cb.cycleCount = 0;
cb.cycleWindowStart = now;
logger.warn(
`${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${mcpConfig.CB_CYCLE_COOLDOWN_MS}ms`,
);
}
}
private recordFailedRound(): void {
const cb = this.getCircuitBreaker();
const now = Date.now();
if (now - cb.failedWindowStart > mcpConfig.CB_FAILED_WINDOW_MS) {
cb.failedRounds = 0;
cb.failedWindowStart = now;
}
cb.failedRounds++;
if (cb.failedRounds >= mcpConfig.CB_MAX_FAILED_ROUNDS) {
const backoff = Math.min(
mcpConfig.CB_BASE_BACKOFF_MS *
Math.pow(2, cb.failedRounds - mcpConfig.CB_MAX_FAILED_ROUNDS),
mcpConfig.CB_MAX_BACKOFF_MS,
);
cb.failedBackoffUntil = now + backoff;
logger.warn(
`${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`,
);
}
}
private resetFailedRounds(): void {
const cb = this.getCircuitBreaker();
cb.failedRounds = 0;
cb.failedWindowStart = Date.now();
cb.failedBackoffUntil = 0;
}
public static decrementCycleCount(serverName: string): void {
const cb = MCPConnection.circuitBreakers.get(serverName);
if (cb && cb.cycleCount > 0) {
cb.cycleCount--;
}
}
setRequestHeaders(headers: Record<string, string> | null): void {
if (!headers) {
return;
@ -686,6 +780,12 @@ export class MCPConnection extends EventEmitter {
return;
}
if (this.isCircuitOpen()) {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`);
}
this.emit('connectionChange', 'connecting');
this.connectPromise = (async () => {
@ -703,7 +803,7 @@ export class MCPConnection extends EventEmitter {
this.transport = await runOutsideTracing(() => this.constructTransport(this.options));
this.patchTransportSend();
const connectTimeout = this.options.initTimeout ?? 120000;
const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT;
await runOutsideTracing(() =>
withTimeout(
this.client.connect(this.transport!),
@ -716,6 +816,14 @@ export class MCPConnection extends EventEmitter {
this.connectionState = 'connected';
this.emit('connectionChange', 'connected');
this.reconnectAttempts = 0;
this.resetFailedRounds();
if (this.oauthRecovery) {
MCPConnection.decrementCycleCount(this.serverName);
this.oauthRecovery = false;
logger.debug(
`${this.getLogPrefix()} OAuth recovery: decremented cycle count after successful reconnect`,
);
}
} catch (error) {
// Check if it's a rate limit error - stop immediately to avoid making it worse
if (this.isRateLimitError(error)) {
@ -799,9 +907,8 @@ export class MCPConnection extends EventEmitter {
try {
// Wait for OAuth to be handled
await oauthHandledPromise;
// Reset the oauthRequired flag
this.oauthRequired = false;
// Don't throw the error - just return so connection can be retried
this.oauthRecovery = true;
logger.info(
`${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
);
@ -817,6 +924,7 @@ export class MCPConnection extends EventEmitter {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
this.recordFailedRound();
throw error;
} finally {
this.connectPromise = null;
@ -866,7 +974,8 @@ export class MCPConnection extends EventEmitter {
async connect(): Promise<void> {
try {
await this.disconnect();
// preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling
await this.disconnect(false);
await this.connectClient();
if (!(await this.isConnected())) {
throw new Error('Connection not established');
@ -906,7 +1015,7 @@ export class MCPConnection extends EventEmitter {
isTransient,
} = extractSSEErrorMessage(error);
if (errorCode === 404) {
if (errorCode === 400 || errorCode === 404 || errorCode === 405) {
const hasSession =
'sessionId' in transport &&
(transport as { sessionId?: string }).sessionId != null &&
@ -914,14 +1023,14 @@ export class MCPConnection extends EventEmitter {
if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) {
logger.warn(
`${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`,
`${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`,
);
return;
}
if (hasSession) {
logger.warn(
`${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`,
`${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`,
);
}
}
@ -992,7 +1101,7 @@ export class MCPConnection extends EventEmitter {
await Promise.all(closing);
}
public async disconnect(): Promise<void> {
public async disconnect(resetCycleTracking = true): Promise<void> {
try {
if (this.transport) {
await this.client.close();
@ -1006,6 +1115,9 @@ export class MCPConnection extends EventEmitter {
this.emit('connectionChange', 'disconnected');
} finally {
this.connectPromise = null;
if (!resetCycleTracking) {
this.recordCycle();
}
}
}

View file

@ -12,4 +12,18 @@ export const mcpConfig = {
USER_CONNECTION_IDLE_TIMEOUT: math(
process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000,
),
/** Max connect/disconnect cycles before the circuit breaker trips. Default: 7 */
CB_MAX_CYCLES: math(process.env.MCP_CB_MAX_CYCLES ?? 7),
/** Sliding window (ms) for counting cycles. Default: 45s */
CB_CYCLE_WINDOW_MS: math(process.env.MCP_CB_CYCLE_WINDOW_MS ?? 45_000),
/** Cooldown (ms) after the cycle breaker trips. Default: 15s */
CB_CYCLE_COOLDOWN_MS: math(process.env.MCP_CB_CYCLE_COOLDOWN_MS ?? 15_000),
/** Max consecutive failed connection rounds before backoff. Default: 3 */
CB_MAX_FAILED_ROUNDS: math(process.env.MCP_CB_MAX_FAILED_ROUNDS ?? 3),
/** Sliding window (ms) for counting failed rounds. Default: 120s */
CB_FAILED_WINDOW_MS: math(process.env.MCP_CB_FAILED_WINDOW_MS ?? 120_000),
/** Base backoff (ms) after failed round threshold is reached. Default: 30s */
CB_BASE_BACKOFF_MS: math(process.env.MCP_CB_BASE_BACKOFF_MS ?? 30_000),
/** Max backoff cap (ms) for exponential backoff. Default: 300s */
CB_MAX_BACKOFF_MS: math(process.env.MCP_CB_MAX_BACKOFF_MS ?? 300_000),
};

View file

@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
it('should not reconnect servers with expired tokens', async () => {
it('should not reconnect servers with expired tokens and no refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
} as unknown as MCPOAuthTokens);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() - 3600000),
} as unknown as MCPOAuthTokens;
}
return null;
});
await reconnectionManager.reconnectServers(userId);
@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
it('should reconnect servers with expired access token but valid refresh token', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() - 3600000),
} as unknown as MCPOAuthTokens;
}
if (identifier === 'mcp:server1:refresh') {
return {
userId,
identifier,
} as unknown as MCPOAuthTokens;
}
return null;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
});
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server1:refresh') {
return {
userId,
identifier,
} as unknown as MCPOAuthTokens;
}
return null;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
});
});
describe('reconnectServer', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should return true on successful reconnection', async () => {
const userId = 'user-123';
const serverName = 'server1';
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
const result = await reconnectionManager.reconnectServer(userId, serverName);
expect(result).toBe(true);
});
it('should return false on failed reconnection', async () => {
const userId = 'user-123';
const serverName = 'server1';
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
const result = await reconnectionManager.reconnectServer(userId, serverName);
expect(result).toBe(false);
});
it('should return false when MCPManager is not available', async () => {
const userId = 'user-123';
const serverName = 'server1';
(OAuthReconnectionManager as unknown as { instance: null }).instance = null;
(MCPManager.getInstance as jest.Mock).mockImplementation(() => {
throw new Error('MCPManager has not been initialized.');
});
const managerWithoutMCP = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
const result = await managerWithoutMCP.reconnectServer(userId, serverName);
expect(result).toBe(false);
});
});
describe('reconnection staggering', () => {
let reconnectionTracker: OAuthReconnectionTracker;

View file

@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
}
}
/**
* Attempts to reconnect a single OAuth MCP server.
* @returns true if reconnection succeeded, false otherwise.
*/
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
if (this.mcpManager == null) {
return false;
}
this.reconnectionsTracker.setActive(userId, serverName);
try {
await this.tryReconnect(userId, serverName);
return !this.reconnectionsTracker.isFailed(userId, serverName);
} catch {
return false;
}
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
}
}
// if the server has no tokens for the user, don't attempt to reconnect
// if the server has a valid (non-expired) access token, allow reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
if (accessToken == null) {
if (accessToken != null) {
const now = new Date();
if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
return true;
}
}
// if the access token is expired or TTL-deleted, fall back to refresh token
const refreshToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}:refresh`,
});
if (refreshToken == null) {
return false;
}
// if the token has expired, don't attempt to reconnect
const now = new Date();
if (accessToken.expiresAt && accessToken.expiresAt < now) {
return false;
}
// …otherwise, we're good to go with the reconnect attempt
return true;
}
}

View file

@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
});
});
describe('cooldown-based retry', () => {
beforeEach(() => {
jest.useFakeTimers();
});
afterEach(() => {
jest.useRealTimers();
});
it('should return true from isFailed within first cooldown period (5 min)', () => {
const now = Date.now();
jest.setSystemTime(now);
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(4 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should return false from isFailed after first cooldown elapses (5 min)', () => {
const now = Date.now();
jest.setSystemTime(now);
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
const now = Date.now();
jest.setSystemTime(now);
// First failure: 5 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Second failure: 10 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(9 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Third failure: 20 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(19 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Fourth failure: 30 min cooldown
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(29 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should cap cooldown at 30 min for attempts beyond 4', () => {
const now = Date.now();
jest.setSystemTime(now);
for (let i = 0; i < 5; i++) {
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(30 * 60 * 1000);
}
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(29 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(true);
jest.advanceTimersByTime(1 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should fully reset metadata on removeFailed', () => {
const now = Date.now();
jest.setSystemTime(now);
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, serverName);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
tracker.setFailed(userId, serverName);
jest.advanceTimersByTime(5 * 60 * 1000);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
});
describe('timestamp tracking edge cases', () => {
beforeEach(() => {
jest.useFakeTimers();

View file

@ -1,6 +1,12 @@
interface FailedMeta {
attempts: number;
lastFailedAt: number;
}
const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000];
export class OAuthReconnectionTracker {
/** Map of userId -> Set of serverNames that have failed reconnection */
private failed: Map<string, Set<string>> = new Map();
private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
/** Map of userId -> Set of serverNames that are actively reconnecting */
private active: Map<string, Set<string>> = new Map();
/** Map of userId:serverName -> timestamp when reconnection started */
@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
public isFailed(userId: string, serverName: string): boolean {
return this.failed.get(userId)?.has(serverName) ?? false;
const meta = this.failedMeta.get(userId)?.get(serverName);
if (!meta) {
return false;
}
const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1);
const cooldown = COOLDOWN_SCHEDULE_MS[idx];
const elapsed = Date.now() - meta.lastFailedAt;
if (elapsed >= cooldown) {
return false;
}
return true;
}
/** Check if server is in the active set (original simple check) */
@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
}
public setFailed(userId: string, serverName: string): void {
if (!this.failed.has(userId)) {
this.failed.set(userId, new Set());
if (!this.failedMeta.has(userId)) {
this.failedMeta.set(userId, new Map());
}
this.failed.get(userId)?.add(serverName);
const userMap = this.failedMeta.get(userId)!;
const existing = userMap.get(serverName);
userMap.set(serverName, {
attempts: (existing?.attempts ?? 0) + 1,
lastFailedAt: Date.now(),
});
}
public setActive(userId: string, serverName: string): void {
@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
}
public removeFailed(userId: string, serverName: string): void {
const userServers = this.failed.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.failed.delete(userId);
const userMap = this.failedMeta.get(userId);
userMap?.delete(serverName);
if (userMap?.size === 0) {
this.failedMeta.delete(userId);
}
}
@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
activeTimestamps: number;
} {
return {
usersWithFailedServers: this.failed.size,
usersWithFailedServers: this.failedMeta.size,
usersWithActiveReconnections: this.active.size,
activeTimestamps: this.activeTimestamps.size,
};

View file

@ -161,20 +161,7 @@ export class MCPOAuthHandler {
logger.debug(
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
);
let rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, {
fetchFn,
});
// If discovery failed and we're using a path-based URL, try the base URL
if (!rawMetadata && authServerUrl.pathname !== '/') {
const baseUrl = new URL(authServerUrl.origin);
logger.debug(
`[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
);
rawMetadata = await discoverAuthorizationServerMetadata(baseUrl, {
fetchFn,
});
}
const rawMetadata = await this.discoverWithOriginFallback(authServerUrl, fetchFn);
if (!rawMetadata) {
/**
@ -221,6 +208,39 @@ export class MCPOAuthHandler {
};
}
/**
* Discovers OAuth authorization server metadata, retrying with just the origin
* when discovery fails for a path-based URL. Shared implementation used by
* `discoverMetadata` and both `refreshOAuthTokens` branches.
*/
private static async discoverWithOriginFallback(
serverUrl: URL,
fetchFn: FetchLike,
): ReturnType<typeof discoverAuthorizationServerMetadata> {
let metadata: Awaited<ReturnType<typeof discoverAuthorizationServerMetadata>>;
try {
metadata = await discoverAuthorizationServerMetadata(serverUrl, { fetchFn });
} catch (err) {
if (serverUrl.pathname === '/') {
throw err;
}
const baseUrl = new URL(serverUrl.origin);
logger.debug(
`[MCPOAuth] Discovery threw for path URL, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
{ error: err },
);
return discoverAuthorizationServerMetadata(baseUrl, { fetchFn });
}
if (!metadata && serverUrl.pathname !== '/') {
const baseUrl = new URL(serverUrl.origin);
logger.debug(
`[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
);
return discoverAuthorizationServerMetadata(baseUrl, { fetchFn });
}
return metadata;
}
/**
* Registers an OAuth client dynamically
*/
@ -406,8 +426,8 @@ export class MCPOAuthHandler {
scope: config.scope,
});
/** Add state parameter with flowId to the authorization URL */
authorizationUrl.searchParams.set('state', flowId);
/** Add cryptographic state parameter to the authorization URL */
authorizationUrl.searchParams.set('state', state);
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
const flowMetadata: MCPOAuthFlowMetadata = {
@ -485,8 +505,8 @@ export class MCPOAuthHandler {
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
);
/** Add state parameter with flowId to the authorization URL */
authorizationUrl.searchParams.set('state', flowId);
/** Add cryptographic state parameter to the authorization URL */
authorizationUrl.searchParams.set('state', state);
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
if (resourceMetadata?.resource != null && resourceMetadata.resource) {
@ -652,6 +672,44 @@ export class MCPOAuthHandler {
return randomBytes(32).toString('base64url');
}
private static readonly STATE_MAP_TYPE = 'mcp_oauth_state';
/**
* Stores a mapping from the opaque OAuth state parameter to the flowId.
* This allows the callback to resolve the flowId from an unguessable state
* value, preventing attackers from forging callback requests.
*/
static async storeStateMapping(
state: string,
flowId: string,
flowManager: FlowStateManager<MCPOAuthTokens | null>,
): Promise<void> {
await flowManager.initFlow(state, this.STATE_MAP_TYPE, { flowId });
}
/**
* Resolves an opaque OAuth state parameter back to the original flowId.
* Returns null if the state is not found (expired or never stored).
*/
static async resolveStateToFlowId(
state: string,
flowManager: FlowStateManager<MCPOAuthTokens | null>,
): Promise<string | null> {
const mapping = await flowManager.getFlowState(state, this.STATE_MAP_TYPE);
return (mapping?.metadata?.flowId as string) ?? null;
}
/**
* Deletes an orphaned state mapping when a flow is replaced.
* Prevents old authorization URLs from resolving after a flow restart.
*/
static async deleteStateMapping(
state: string,
flowManager: FlowStateManager<MCPOAuthTokens | null>,
): Promise<void> {
await flowManager.deleteFlow(state, this.STATE_MAP_TYPE);
}
/**
* Gets the default redirect URI for a server
*/
@ -735,9 +793,10 @@ export class MCPOAuthHandler {
throw new Error('No token URL available for refresh');
} else {
/** Auto-discover OAuth configuration for refresh */
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
fetchFn: this.createOAuthFetch(oauthHeaders),
});
const serverUrl = new URL(metadata.serverUrl);
const fetchFn = this.createOAuthFetch(oauthHeaders);
const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn);
if (!oauthMetadata) {
/**
* No metadata discovered - use fallback /token endpoint.
@ -911,9 +970,9 @@ export class MCPOAuthHandler {
}
/** Auto-discover OAuth configuration for refresh */
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
fetchFn: this.createOAuthFetch(oauthHeaders),
});
const serverUrl = new URL(metadata.serverUrl);
const fetchFn = this.createOAuthFetch(oauthHeaders);
const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn);
let tokenUrl: URL;
if (!oauthMetadata?.token_endpoint) {

View file

@ -4,6 +4,15 @@ import type { TokenMethods, IToken } from '@librechat/data-schemas';
import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types';
import { isSystemUserId } from '~/mcp/enum';
export class ReauthenticationRequiredError extends Error {
constructor(serverName: string, reason: 'expired' | 'missing') {
super(
`Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`,
);
this.name = 'ReauthenticationRequiredError';
}
}
interface StoreTokensParams {
userId: string;
serverName: string;
@ -27,7 +36,12 @@ interface GetTokensParams {
findToken: TokenMethods['findToken'];
refreshTokens?: (
refreshToken: string,
metadata: { userId: string; serverName: string; identifier: string },
metadata: {
userId: string;
serverName: string;
identifier: string;
clientInfo?: OAuthClientInformation;
},
) => Promise<MCPOAuthTokens>;
createToken?: TokenMethods['createToken'];
updateToken?: TokenMethods['updateToken'];
@ -273,10 +287,11 @@ export class MCPTokenStorage {
});
if (!refreshTokenData) {
const reason = isMissing ? 'missing' : 'expired';
logger.info(
`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`,
`${logPrefix} Access token ${reason} and no refresh token available — re-authentication required`,
);
return null;
throw new ReauthenticationRequiredError(serverName, reason);
}
if (!refreshTokens) {
@ -395,6 +410,9 @@ export class MCPTokenStorage {
logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`);
return tokens;
} catch (error) {
if (error instanceof ReauthenticationRequiredError) {
throw error;
}
logger.error(`${logPrefix} Failed to retrieve tokens`, error);
return null;
}

View file

@ -88,6 +88,7 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata {
clientInfo?: OAuthClientInformation;
metadata?: OAuthMetadata;
resourceMetadata?: OAuthProtectedResourceMetadata;
authorizationUrl?: string;
}
export interface MCPOAuthTokens extends OAuthTokens {

View file

@ -17,20 +17,20 @@
import * as net from 'net';
import * as http from 'http';
import { Keyv } from 'keyv';
import { Agent } from 'undici';
import { Types } from 'mongoose';
import { randomUUID } from 'crypto';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { Keyv } from 'keyv';
import { Types } from 'mongoose';
import type { IUser } from '@librechat/data-schemas';
import type { Socket } from 'net';
import type * as t from '~/mcp/types';
import { MCPInspectionFailedError } from '~/mcp/errors';
import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache';
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPInspectionFailedError } from '~/mcp/errors';
import { FlowStateManager } from '~/flow/manager';
import { MCPConnection } from '~/mcp/connection';
import { MCPManager } from '~/mcp/MCPManager';

View file

@ -19,7 +19,6 @@ export * from './promise';
export * from './sanitizeTitle';
export * from './tempChatRetention';
export * from './text';
export { default as Tokenizer, countTokens } from './tokenizer';
export * from './yaml';
export * from './http';
export * from './tokens';

View file

@ -65,7 +65,7 @@ const createRealTokenCounter = () => {
let callCount = 0;
const tokenCountFn = (text: string): number => {
callCount++;
return Tokenizer.getTokenCount(text, 'cl100k_base');
return Tokenizer.getTokenCount(text, 'o200k_base');
};
return {
tokenCountFn,
@ -590,9 +590,9 @@ describe('processTextWithTokenLimit', () => {
});
});
describe('direct comparison with REAL tiktoken tokenizer', () => {
beforeEach(() => {
Tokenizer.freeAndResetAllEncoders();
describe('direct comparison with REAL ai-tokenizer', () => {
beforeAll(async () => {
await Tokenizer.initEncoding('o200k_base');
});
it('should produce valid truncation with real tokenizer', async () => {
@ -611,7 +611,7 @@ describe('processTextWithTokenLimit', () => {
expect(result.text.length).toBeLessThan(text.length);
});
it('should use fewer tiktoken calls than old implementation (realistic text)', async () => {
it('should use fewer tokenizer calls than old implementation (realistic text)', async () => {
const oldCounter = createRealTokenCounter();
const newCounter = createRealTokenCounter();
const text = createRealisticText(15000);
@ -623,8 +623,6 @@ describe('processTextWithTokenLimit', () => {
tokenCountFn: oldCounter.tokenCountFn,
});
Tokenizer.freeAndResetAllEncoders();
await processTextWithTokenLimit({
text,
tokenLimit,
@ -634,17 +632,17 @@ describe('processTextWithTokenLimit', () => {
const oldCalls = oldCounter.getCallCount();
const newCalls = newCounter.getCallCount();
console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
console.log(`[Real tokenizer ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
console.log(`[Real tokenizer] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
expect(newCalls).toBeLessThan(oldCalls);
});
it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => {
it('should handle large text with real tokenizer (~20k tokens)', async () => {
const oldCounter = createRealTokenCounter();
const newCounter = createRealTokenCounter();
const text = createRealisticText(120000);
const tokenLimit = 100000;
const text = createRealisticText(20000);
const tokenLimit = 15000;
const startOld = performance.now();
await processTextWithTokenLimitOLD({
@ -654,8 +652,6 @@ describe('processTextWithTokenLimit', () => {
});
const timeOld = performance.now() - startOld;
Tokenizer.freeAndResetAllEncoders();
const startNew = performance.now();
const result = await processTextWithTokenLimit({
text,
@ -667,9 +663,9 @@ describe('processTextWithTokenLimit', () => {
const oldCalls = oldCounter.getCallCount();
const newCalls = newCounter.getCallCount();
console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`);
console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`);
console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`);
console.log(`\n[REAL TOKENIZER - ~20k tokens]`);
console.log(`OLD implementation: ${oldCalls} tokenizer calls, ${timeOld.toFixed(0)}ms`);
console.log(`NEW implementation: ${newCalls} tokenizer calls, ${timeNew.toFixed(0)}ms`);
console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`);
console.log(
@ -684,8 +680,8 @@ describe('processTextWithTokenLimit', () => {
it('should achieve at least 70% reduction with real tokenizer', async () => {
const oldCounter = createRealTokenCounter();
const newCounter = createRealTokenCounter();
const text = createRealisticText(50000);
const tokenLimit = 10000;
const text = createRealisticText(15000);
const tokenLimit = 5000;
await processTextWithTokenLimitOLD({
text,
@ -693,8 +689,6 @@ describe('processTextWithTokenLimit', () => {
tokenCountFn: oldCounter.tokenCountFn,
});
Tokenizer.freeAndResetAllEncoders();
await processTextWithTokenLimit({
text,
tokenLimit,
@ -706,7 +700,7 @@ describe('processTextWithTokenLimit', () => {
const reduction = 1 - newCalls / oldCalls;
console.log(
`[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
`[Real tokenizer 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
);
expect(reduction).toBeGreaterThanOrEqual(0.7);
@ -714,10 +708,6 @@ describe('processTextWithTokenLimit', () => {
});
describe('using countTokens async function from @librechat/api', () => {
beforeEach(() => {
Tokenizer.freeAndResetAllEncoders();
});
it('countTokens should return correct token count', async () => {
const text = 'Hello, world!';
const count = await countTokens(text);
@ -759,8 +749,6 @@ describe('processTextWithTokenLimit', () => {
tokenCountFn: oldCounter.tokenCountFn,
});
Tokenizer.freeAndResetAllEncoders();
await processTextWithTokenLimit({
text,
tokenLimit,
@ -776,11 +764,11 @@ describe('processTextWithTokenLimit', () => {
expect(newCalls).toBeLessThan(oldCalls);
});
it('should handle user reported scenario with countTokens (~120k tokens)', async () => {
it('should handle large text with countTokens (~20k tokens)', async () => {
const oldCounter = createCountTokensCounter();
const newCounter = createCountTokensCounter();
const text = createRealisticText(120000);
const tokenLimit = 100000;
const text = createRealisticText(20000);
const tokenLimit = 15000;
const startOld = performance.now();
await processTextWithTokenLimitOLD({
@ -790,8 +778,6 @@ describe('processTextWithTokenLimit', () => {
});
const timeOld = performance.now() - startOld;
Tokenizer.freeAndResetAllEncoders();
const startNew = performance.now();
const result = await processTextWithTokenLimit({
text,
@ -803,7 +789,7 @@ describe('processTextWithTokenLimit', () => {
const oldCalls = oldCounter.getCallCount();
const newCalls = newCounter.getCallCount();
console.log(`\n[countTokens - User reported scenario: ~120k tokens]`);
console.log(`\n[countTokens - ~20k tokens]`);
console.log(`OLD implementation: ${oldCalls} countTokens calls, ${timeOld.toFixed(0)}ms`);
console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`);
console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
@ -820,8 +806,8 @@ describe('processTextWithTokenLimit', () => {
it('should achieve at least 70% reduction with countTokens', async () => {
const oldCounter = createCountTokensCounter();
const newCounter = createCountTokensCounter();
const text = createRealisticText(50000);
const tokenLimit = 10000;
const text = createRealisticText(15000);
const tokenLimit = 5000;
await processTextWithTokenLimitOLD({
text,
@ -829,8 +815,6 @@ describe('processTextWithTokenLimit', () => {
tokenCountFn: oldCounter.tokenCountFn,
});
Tokenizer.freeAndResetAllEncoders();
await processTextWithTokenLimit({
text,
tokenLimit,
@ -842,7 +826,7 @@ describe('processTextWithTokenLimit', () => {
const reduction = 1 - newCalls / oldCalls;
console.log(
`[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
`[countTokens 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
);
expect(reduction).toBeGreaterThanOrEqual(0.7);

View file

@ -1,12 +1,3 @@
/**
* @file Tokenizer.spec.cjs
*
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
* Make sure to install `tiktoken` and have it configured properly.
*/
import { logger } from '@librechat/data-schemas';
import type { Tiktoken } from 'tiktoken';
import Tokenizer from './tokenizer';
jest.mock('@librechat/data-schemas', () => ({
@ -17,127 +8,49 @@ jest.mock('@librechat/data-schemas', () => ({
describe('Tokenizer', () => {
it('should be a singleton (same instance)', async () => {
const AnotherTokenizer = await import('./tokenizer'); // same path
const AnotherTokenizer = await import('./tokenizer');
expect(Tokenizer).toBe(AnotherTokenizer.default);
});
describe('getTokenizer', () => {
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
// The real `encoding_for_model` will be called internally
// as soon as we pass isModelName = true.
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
// Basic sanity checks
expect(tokenizer).toBeDefined();
// You can optionally check certain properties from `tiktoken` if they exist
// e.g., expect(typeof tokenizer.encode).toBe('function');
describe('initEncoding', () => {
it('should load o200k_base encoding', async () => {
await Tokenizer.initEncoding('o200k_base');
const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base');
expect(count).toBeGreaterThan(0);
});
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
// The real `get_encoding` will be called internally
// as soon as we pass isModelName = false.
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
expect(tokenizer).toBeDefined();
// e.g., expect(typeof tokenizer.encode).toBe('function');
it('should load claude encoding', async () => {
await Tokenizer.initEncoding('claude');
const count = Tokenizer.getTokenCount('Hello, world!', 'claude');
expect(count).toBeGreaterThan(0);
});
it('should return cached tokenizer if previously fetched', () => {
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
// Should be the exact same instance from the cache
expect(tokenizer1).toBe(tokenizer2);
});
});
describe('freeAndResetAllEncoders', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
// By creating two different encodings, we populate the cache
Tokenizer.getTokenizer('cl100k_base', false);
Tokenizer.getTokenizer('r50k_base', false);
// Now free them
Tokenizer.freeAndResetAllEncoders();
// The internal cache is cleared
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
// tokenizerCallsCount is reset to 1
expect(Tokenizer.tokenizerCallsCount).toBe(1);
});
it('should catch and log errors if freeing fails', () => {
// Mock logger.error before the test
const mockLoggerError = jest.spyOn(logger, 'error');
// Set up a problematic tokenizer in the cache
Tokenizer.tokenizersCache['cl100k_base'] = {
free() {
throw new Error('Intentional free error');
},
} as unknown as Tiktoken;
// Should not throw uncaught errors
Tokenizer.freeAndResetAllEncoders();
// Verify logger.error was called with correct arguments
expect(mockLoggerError).toHaveBeenCalledWith(
'[Tokenizer] Free and reset encoders error',
expect.any(Error),
);
// Clean up
mockLoggerError.mockRestore();
Tokenizer.tokenizersCache = {};
it('should deduplicate concurrent init calls', async () => {
const [, , count] = await Promise.all([
Tokenizer.initEncoding('o200k_base'),
Tokenizer.initEncoding('o200k_base'),
Tokenizer.initEncoding('o200k_base').then(() =>
Tokenizer.getTokenCount('test', 'o200k_base'),
),
]);
expect(count).toBeGreaterThan(0);
});
});
describe('getTokenCount', () => {
beforeEach(() => {
jest.clearAllMocks();
Tokenizer.freeAndResetAllEncoders();
beforeAll(async () => {
await Tokenizer.initEncoding('o200k_base');
await Tokenizer.initEncoding('claude');
});
it('should return the number of tokens in the given text', () => {
const text = 'Hello, world!';
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base');
expect(count).toBeGreaterThan(0);
});
it('should reset encoders if an error is thrown', () => {
// We can simulate an error by temporarily overriding the selected tokenizer's `encode` method.
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
const originalEncode = tokenizer.encode;
tokenizer.encode = () => {
throw new Error('Forced error');
};
// Despite the forced error, the code should catch and reset, then re-encode
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
it('should count tokens using claude encoding', () => {
const count = Tokenizer.getTokenCount('Hello, world!', 'claude');
expect(count).toBeGreaterThan(0);
// Restore the original encode
tokenizer.encode = originalEncode;
});
it('should reset tokenizers after 25 calls', () => {
// Spy on freeAndResetAllEncoders
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
// Make 24 calls; should NOT reset yet
for (let i = 0; i < 24; i++) {
Tokenizer.getTokenCount('test text', 'cl100k_base');
}
expect(resetSpy).not.toHaveBeenCalled();
// 25th call triggers the reset
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
expect(resetSpy).toHaveBeenCalledTimes(1);
});
});
});

View file

@ -1,74 +1,46 @@
import { logger } from '@librechat/data-schemas';
import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken';
import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken';
import { Tokenizer as AiTokenizer } from 'ai-tokenizer';
interface TokenizerOptions {
debug?: boolean;
}
export type EncodingName = 'o200k_base' | 'claude';
type EncodingData = ConstructorParameters<typeof AiTokenizer>[0];
class Tokenizer {
tokenizersCache: Record<string, Tiktoken>;
tokenizerCallsCount: number;
private options?: TokenizerOptions;
private tokenizersCache: Partial<Record<EncodingName, AiTokenizer>> = {};
private loadingPromises: Partial<Record<EncodingName, Promise<void>>> = {};
constructor() {
this.tokenizersCache = {};
this.tokenizerCallsCount = 0;
}
getTokenizer(
encoding: TiktokenModel | TiktokenEncoding,
isModelName = false,
extendSpecialTokens: Record<string, number> = {},
): Tiktoken {
let tokenizer: Tiktoken;
/** Pre-loads an encoding so that subsequent getTokenCount calls are accurate. */
async initEncoding(encoding: EncodingName): Promise<void> {
if (this.tokenizersCache[encoding]) {
tokenizer = this.tokenizersCache[encoding];
} else {
if (isModelName) {
tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens);
}
this.tokenizersCache[encoding] = tokenizer;
return;
}
return tokenizer;
if (this.loadingPromises[encoding]) {
return this.loadingPromises[encoding];
}
this.loadingPromises[encoding] = (async () => {
const data: EncodingData =
encoding === 'claude'
? await import('ai-tokenizer/encoding/claude')
: await import('ai-tokenizer/encoding/o200k_base');
this.tokenizersCache[encoding] = new AiTokenizer(data);
})();
return this.loadingPromises[encoding];
}
freeAndResetAllEncoders(): void {
getTokenCount(text: string, encoding: EncodingName = 'o200k_base'): number {
const tokenizer = this.tokenizersCache[encoding];
if (!tokenizer) {
this.initEncoding(encoding);
return Math.ceil(text.length / 4);
}
try {
Object.keys(this.tokenizersCache).forEach((key) => {
if (this.tokenizersCache[key]) {
this.tokenizersCache[key].free();
delete this.tokenizersCache[key];
}
});
this.tokenizerCallsCount = 1;
} catch (error) {
logger.error('[Tokenizer] Free and reset encoders error', error);
}
}
resetTokenizersIfNecessary(): void {
if (this.tokenizerCallsCount >= 25) {
if (this.options?.debug) {
logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
}
this.freeAndResetAllEncoders();
}
this.tokenizerCallsCount++;
}
getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number {
this.resetTokenizersIfNecessary();
try {
const tokenizer = this.getTokenizer(encoding);
return tokenizer.encode(text, 'all').length;
return tokenizer.count(text);
} catch (error) {
logger.error('[Tokenizer] Error getting token count:', error);
this.freeAndResetAllEncoders();
const tokenizer = this.getTokenizer(encoding);
return tokenizer.encode(text, 'all').length;
delete this.tokenizersCache[encoding];
delete this.loadingPromises[encoding];
this.initEncoding(encoding);
return Math.ceil(text.length / 4);
}
}
}
@ -76,13 +48,13 @@ class Tokenizer {
const TokenizerSingleton = new Tokenizer();
/**
* Counts the number of tokens in a given text using tiktoken.
* This is an async wrapper around Tokenizer.getTokenCount for compatibility.
* @param text - The text to be tokenized. Defaults to an empty string if not provided.
* Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding.
* @param text - The text to count tokens in. Defaults to an empty string.
* @returns The number of tokens in the provided text.
*/
export async function countTokens(text = ''): Promise<number> {
return TokenizerSingleton.getTokenCount(text, 'cl100k_base');
await TokenizerSingleton.initEncoding('o200k_base');
return TokenizerSingleton.getTokenCount(text, 'o200k_base');
}
export default TokenizerSingleton;

View file

@ -593,42 +593,3 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
return tokenConfig;
}
export const tiktokenModels = new Set([
'text-davinci-003',
'text-davinci-002',
'text-davinci-001',
'text-curie-001',
'text-babbage-001',
'text-ada-001',
'davinci',
'curie',
'babbage',
'ada',
'code-davinci-002',
'code-davinci-001',
'code-cushman-002',
'code-cushman-001',
'davinci-codex',
'cushman-codex',
'text-davinci-edit-001',
'code-davinci-edit-001',
'text-embedding-ada-002',
'text-similarity-davinci-001',
'text-similarity-curie-001',
'text-similarity-babbage-001',
'text-similarity-ada-001',
'text-search-davinci-doc-001',
'text-search-curie-doc-001',
'text-search-babbage-doc-001',
'text-search-ada-doc-001',
'code-search-babbage-code-001',
'code-search-ada-code-001',
'gpt2',
'gpt-4',
'gpt-4-0314',
'gpt-4-32k',
'gpt-4-32k-0314',
'gpt-3.5-turbo',
'gpt-3.5-turbo-0301',
]);