mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 08:25:23 +02:00
Merge branch 'main' into claude/fix-mcp-accent-support-UBEjT
This commit is contained in:
commit
e3671a6835
47 changed files with 5749 additions and 437 deletions
21
.env.example
21
.env.example
|
|
@ -850,3 +850,24 @@ OPENWEATHER_API_KEY=
|
|||
# Skip code challenge method validation (e.g., for AWS Cognito that supports S256 but doesn't advertise it)
|
||||
# When set to true, forces S256 code challenge even if not advertised in .well-known/openid-configuration
|
||||
# MCP_SKIP_CODE_CHALLENGE_CHECK=false
|
||||
|
||||
# Circuit breaker: max connect/disconnect cycles before tripping (per server)
|
||||
# MCP_CB_MAX_CYCLES=7
|
||||
|
||||
# Circuit breaker: sliding window (ms) for counting cycles
|
||||
# MCP_CB_CYCLE_WINDOW_MS=45000
|
||||
|
||||
# Circuit breaker: cooldown (ms) after the cycle breaker trips
|
||||
# MCP_CB_CYCLE_COOLDOWN_MS=15000
|
||||
|
||||
# Circuit breaker: max consecutive failed connection rounds before backoff
|
||||
# MCP_CB_MAX_FAILED_ROUNDS=3
|
||||
|
||||
# Circuit breaker: sliding window (ms) for counting failed rounds
|
||||
# MCP_CB_FAILED_WINDOW_MS=120000
|
||||
|
||||
# Circuit breaker: base backoff (ms) after failed round threshold is reached
|
||||
# MCP_CB_BASE_BACKOFF_MS=30000
|
||||
|
||||
# Circuit breaker: max backoff cap (ms) for exponential backoff
|
||||
# MCP_CB_MAX_BACKOFF_MS=300000
|
||||
|
|
|
|||
10
AGENTS.md
10
AGENTS.md
|
|
@ -149,7 +149,15 @@ Multi-line imports count total character length across all lines. Consolidate va
|
|||
- Run tests from their workspace directory: `cd api && npx jest <pattern>`, `cd packages/api && npx jest <pattern>`, etc.
|
||||
- Frontend tests: `__tests__` directories alongside components; use `test/layout-test-utils` for rendering.
|
||||
- Cover loading, success, and error states for UI/data flows.
|
||||
- Mock data-provider hooks and external dependencies.
|
||||
|
||||
### Philosophy
|
||||
|
||||
- **Real logic over mocks.** Exercise actual code paths with real dependencies. Mocking is a last resort.
|
||||
- **Spies over mocks.** Assert that real functions are called with expected arguments and frequency without replacing underlying logic.
|
||||
- **MongoDB**: use `mongodb-memory-server` for a real in-memory MongoDB instance. Test actual queries and schema validation, not mocked DB calls.
|
||||
- **MCP**: use real `@modelcontextprotocol/sdk` exports for servers, transports, and tool definitions. Mirror real scenarios, don't stub SDK internals.
|
||||
- Only mock what you cannot control: external HTTP APIs, rate-limited services, non-deterministic system calls.
|
||||
- Heavy mocking is a code smell, not a testing strategy.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const DALLE3 = require('../DALLE3');
|
||||
const { ProxyAgent } = require('undici');
|
||||
|
||||
jest.mock('tiktoken');
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
describe('DALLE3 Proxy Configuration', () => {
|
||||
|
|
|
|||
|
|
@ -14,15 +14,6 @@ jest.mock('@librechat/data-schemas', () => {
|
|||
};
|
||||
});
|
||||
|
||||
jest.mock('tiktoken', () => {
|
||||
return {
|
||||
encoding_for_model: jest.fn().mockReturnValue({
|
||||
encode: jest.fn(),
|
||||
decode: jest.fn(),
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
const generate = jest.fn();
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@
|
|||
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"ai-tokenizer": "^1.0.6",
|
||||
"axios": "^1.13.5",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.8.1",
|
||||
|
|
@ -106,7 +107,6 @@
|
|||
"pdfjs-dist": "^5.4.624",
|
||||
"rate-limit-redis": "^4.2.0",
|
||||
"sharp": "^0.33.5",
|
||||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"undici": "^7.18.2",
|
||||
|
|
|
|||
|
|
@ -1172,7 +1172,11 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
}
|
||||
|
||||
/** Anthropic Claude models use a distinct BPE tokenizer; all others default to o200k_base. */
|
||||
getEncoding() {
|
||||
if (this.model && this.model.toLowerCase().includes('claude')) {
|
||||
return 'claude';
|
||||
}
|
||||
return 'o200k_base';
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ jest.mock('@librechat/api', () => {
|
|||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
resolveStateToFlowId: jest.fn(async (state) => state),
|
||||
storeStateMapping: jest.fn(),
|
||||
deleteStateMapping: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
|
|
@ -180,7 +183,10 @@ describe('MCP Routes', () => {
|
|||
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
||||
authorizationUrl: 'https://oauth.example.com/auth',
|
||||
flowId: 'test-user-id:test-server',
|
||||
flowMetadata: { state: 'random-state-value' },
|
||||
});
|
||||
MCPOAuthHandler.storeStateMapping.mockResolvedValue();
|
||||
mockFlowManager.initFlow = jest.fn().mockResolvedValue();
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({
|
||||
userId: 'test-user-id',
|
||||
|
|
@ -367,6 +373,121 @@ describe('MCP Routes', () => {
|
|||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||
});
|
||||
|
||||
describe('CSRF fallback via active PENDING flow', () => {
|
||||
it('should proceed when a fresh PENDING flow exists and no cookies are present', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now(),
|
||||
}),
|
||||
completeFlow: jest.fn().mockResolvedValue(true),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: {},
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue({
|
||||
access_token: 'test-token',
|
||||
});
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({});
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getOAuthReconnectionManager.mockReturnValue({
|
||||
clearReconnection: jest.fn(),
|
||||
});
|
||||
require('~/server/services/Config/mcp').updateMCPServerTools.mockResolvedValue();
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
|
||||
it('should reject when no PENDING flow exists and no cookies are present', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue(null),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'COMPLETED',
|
||||
createdAt: Date.now(),
|
||||
}),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now() - 3 * 60 * 1000,
|
||||
}),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.query({ code: 'test-code', state: flowId });
|
||||
|
||||
const basePath = getBasePath();
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(
|
||||
`${basePath}/oauth/error?error=csrf_validation_failed`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
// mockRegistryInstance is defined at the top of the file
|
||||
const mockFlowManager = {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ const {
|
|||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
setOAuthSession,
|
||||
PENDING_STALE_MS,
|
||||
getUserMCPAuthMap,
|
||||
validateOAuthCsrf,
|
||||
OAUTH_CSRF_COOKIE,
|
||||
|
|
@ -91,7 +92,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
|||
}
|
||||
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
||||
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: oauthFlowId,
|
||||
flowMetadata,
|
||||
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
serverName,
|
||||
serverUrl,
|
||||
userId,
|
||||
|
|
@ -101,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async
|
|||
|
||||
logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl });
|
||||
|
||||
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, oauthFlowId, flowManager);
|
||||
setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
res.redirect(authorizationUrl);
|
||||
} catch (error) {
|
||||
|
|
@ -143,30 +149,52 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
||||
}
|
||||
|
||||
const flowId = state;
|
||||
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
const flowId = await MCPOAuthHandler.resolveStateToFlowId(state, flowManager);
|
||||
if (!flowId) {
|
||||
logger.error('[MCP OAuth] Could not resolve state to flow ID', { state });
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
logger.debug('[MCP OAuth] Resolved flow ID from state', { flowId });
|
||||
|
||||
const flowParts = flowId.split(':');
|
||||
if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) {
|
||||
logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId });
|
||||
logger.error('[MCP OAuth] Invalid flow ID format', { flowId });
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
const [flowUserId] = flowParts;
|
||||
if (
|
||||
!validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) &&
|
||||
!validateOAuthSession(req, flowUserId)
|
||||
) {
|
||||
logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', {
|
||||
flowId,
|
||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||
});
|
||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||
|
||||
const hasCsrf = validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH);
|
||||
const hasSession = !hasCsrf && validateOAuthSession(req, flowUserId);
|
||||
let hasActiveFlow = false;
|
||||
if (!hasCsrf && !hasSession) {
|
||||
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
|
||||
hasActiveFlow = pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS;
|
||||
if (hasActiveFlow) {
|
||||
logger.debug(
|
||||
'[MCP OAuth] CSRF/session cookies absent, validating via active PENDING flow',
|
||||
{
|
||||
flowId,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
if (!hasCsrf && !hasSession && !hasActiveFlow) {
|
||||
logger.error(
|
||||
'[MCP OAuth] CSRF validation failed: no valid CSRF cookie, session cookie, or active flow',
|
||||
{
|
||||
flowId,
|
||||
hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE],
|
||||
hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE],
|
||||
},
|
||||
);
|
||||
return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId);
|
||||
const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager);
|
||||
|
|
@ -281,7 +309,13 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
const toolFlowId = flowState.metadata?.toolFlowId;
|
||||
if (toolFlowId) {
|
||||
logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId });
|
||||
await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
||||
const completed = await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens);
|
||||
if (!completed) {
|
||||
logger.warn(
|
||||
'[MCP OAuth] Tool flow state not found during completion — waiter will time out',
|
||||
{ toolFlowId },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Redirect to success page with flowId and serverName */
|
||||
|
|
|
|||
|
|
@ -34,6 +34,55 @@ const { reinitMCPServer } = require('./Tools/mcp');
|
|||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const MAX_CACHE_SIZE = 1000;
|
||||
const lastReconnectAttempts = new Map();
|
||||
const RECONNECT_THROTTLE_MS = 10_000;
|
||||
|
||||
const missingToolCache = new Map();
|
||||
const MISSING_TOOL_TTL_MS = 10_000;
|
||||
|
||||
function evictStale(map, ttl) {
|
||||
if (map.size <= MAX_CACHE_SIZE) {
|
||||
return;
|
||||
}
|
||||
const now = Date.now();
|
||||
for (const [key, timestamp] of map) {
|
||||
if (now - timestamp >= ttl) {
|
||||
map.delete(key);
|
||||
}
|
||||
if (map.size <= MAX_CACHE_SIZE) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const unavailableMsg =
|
||||
"This tool's MCP server is temporarily unavailable. Please try again shortly.";
|
||||
|
||||
/**
|
||||
* @param {string} toolName
|
||||
* @param {string} serverName
|
||||
*/
|
||||
function createUnavailableToolStub(toolName, serverName) {
|
||||
const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`;
|
||||
const _call = async () => [unavailableMsg, null];
|
||||
const toolInstance = tool(_call, {
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string', description: 'Input for the tool' },
|
||||
},
|
||||
required: [],
|
||||
},
|
||||
name: normalizedToolKey,
|
||||
description: unavailableMsg,
|
||||
responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
|
||||
});
|
||||
toolInstance.mcp = true;
|
||||
toolInstance.mcpRawServerName = serverName;
|
||||
return toolInstance;
|
||||
}
|
||||
|
||||
function isEmptyObjectSchema(jsonSchema) {
|
||||
return (
|
||||
jsonSchema != null &&
|
||||
|
|
@ -211,6 +260,17 @@ async function reconnectServer({
|
|||
logger.debug(
|
||||
`[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`,
|
||||
);
|
||||
|
||||
const throttleKey = `${user.id}:${serverName}`;
|
||||
const now = Date.now();
|
||||
const lastAttempt = lastReconnectAttempts.get(throttleKey) ?? 0;
|
||||
if (now - lastAttempt < RECONNECT_THROTTLE_MS) {
|
||||
logger.debug(`[MCP][reconnectServer] Throttled reconnect for ${serverName}`);
|
||||
return null;
|
||||
}
|
||||
lastReconnectAttempts.set(throttleKey, now);
|
||||
evictStale(lastReconnectAttempts, RECONNECT_THROTTLE_MS);
|
||||
|
||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
|
|
@ -267,7 +327,7 @@ async function reconnectServer({
|
|||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
connectionTimeout: Time.THIRTY_SECONDS,
|
||||
});
|
||||
} finally {
|
||||
// Clean up abort handler to prevent memory leaks
|
||||
|
|
@ -330,9 +390,13 @@ async function createMCPTools({
|
|||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
if (result === null) {
|
||||
logger.debug(`[MCP][${serverName}] Reconnect throttled, skipping tool creation.`);
|
||||
return [];
|
||||
}
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
return;
|
||||
return [];
|
||||
}
|
||||
|
||||
const serverTools = [];
|
||||
|
|
@ -402,6 +466,14 @@ async function createMCPTool({
|
|||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
const cachedAt = missingToolCache.get(toolKey);
|
||||
if (cachedAt && Date.now() - cachedAt < MISSING_TOOL_TTL_MS) {
|
||||
logger.debug(
|
||||
`[MCP][${serverName}][${toolName}] Tool in negative cache, returning unavailable stub.`,
|
||||
);
|
||||
return createUnavailableToolStub(toolName, serverName);
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
|
||||
);
|
||||
|
|
@ -415,11 +487,18 @@ async function createMCPTool({
|
|||
streamId,
|
||||
});
|
||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||
|
||||
if (!toolDefinition) {
|
||||
missingToolCache.set(toolKey, Date.now());
|
||||
evictStale(missingToolCache, MISSING_TOOL_TTL_MS);
|
||||
}
|
||||
}
|
||||
|
||||
if (!toolDefinition) {
|
||||
logger.warn(`[MCP][${serverName}][${toolName}] Tool definition not found, cannot create tool.`);
|
||||
return;
|
||||
logger.warn(
|
||||
`[MCP][${serverName}][${toolName}] Tool definition not found, returning unavailable stub.`,
|
||||
);
|
||||
return createUnavailableToolStub(toolName, serverName);
|
||||
}
|
||||
|
||||
return createToolInstance({
|
||||
|
|
@ -720,4 +799,5 @@ module.exports = {
|
|||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
createUnavailableToolStub,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ const {
|
|||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
createUnavailableToolStub,
|
||||
} = require('./MCP');
|
||||
|
||||
jest.mock('./Config', () => ({
|
||||
|
|
@ -1098,6 +1099,188 @@ describe('User parameter passing tests', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('createUnavailableToolStub', () => {
|
||||
it('should return a tool whose _call returns a valid CONTENT_AND_ARTIFACT two-tuple', async () => {
|
||||
const stub = createUnavailableToolStub('myTool', 'myServer');
|
||||
// invoke() goes through langchain's base tool, which checks responseFormat.
|
||||
// CONTENT_AND_ARTIFACT requires [content, artifact] — a bare string would throw:
|
||||
// "Tool response format is "content_and_artifact" but the output was not a two-tuple"
|
||||
const result = await stub.invoke({});
|
||||
// If we reach here without throwing, the two-tuple format is correct.
|
||||
// invoke() returns the content portion of [content, artifact] as a string.
|
||||
expect(result).toContain('temporarily unavailable');
|
||||
});
|
||||
});
|
||||
|
||||
describe('negative tool cache and throttle interaction', () => {
|
||||
it('should cache tool as missing even when throttled (cross-user dedup)', async () => {
|
||||
const mockUser = { id: 'throttle-test-user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// First call: reconnect succeeds but tool not found
|
||||
mockReinitMCPServer.mockResolvedValueOnce({
|
||||
availableTools: {},
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: `missing-tool${D}cache-dedup-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
// Second call within 10s for DIFFERENT tool on same server:
|
||||
// reconnect is throttled (returns null), tool is still cached as missing.
|
||||
// This is intentional: the cache acts as cross-user dedup since the
|
||||
// throttle is per-user-per-server and can't prevent N different users
|
||||
// from each triggering their own reconnect.
|
||||
const result2 = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: `other-tool${D}cache-dedup-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
expect(result2).toBeDefined();
|
||||
expect(result2.name).toContain('other-tool');
|
||||
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should prevent user B from triggering reconnect when user A already cached the tool', async () => {
|
||||
const userA = { id: 'cache-user-A' };
|
||||
const userB = { id: 'cache-user-B' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// User A: real reconnect, tool not found → cached
|
||||
mockReinitMCPServer.mockResolvedValueOnce({
|
||||
availableTools: {},
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: userA,
|
||||
toolKey: `shared-tool${D}cross-user-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||
|
||||
// User B requests the SAME tool within 10s.
|
||||
// The negative cache is keyed by toolKey (no user prefix), so user B
|
||||
// gets a cache hit and no reconnect fires. This is the cross-user
|
||||
// storm protection: without this, user B's unthrottled first request
|
||||
// would trigger a second reconnect to the same server.
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: userB,
|
||||
toolKey: `shared-tool${D}cross-user-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.name).toContain('shared-tool');
|
||||
// reinitMCPServer still called only once — user B hit the cache
|
||||
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should prevent user B from triggering reconnect for throttle-cached tools', async () => {
|
||||
const userA = { id: 'storm-user-A' };
|
||||
const userB = { id: 'storm-user-B' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// User A: real reconnect for tool-1, tool not found → cached
|
||||
mockReinitMCPServer.mockResolvedValueOnce({
|
||||
availableTools: {},
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: userA,
|
||||
toolKey: `tool-1${D}storm-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
// User A: tool-2 on same server within 10s → throttled → cached from throttle
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: userA,
|
||||
toolKey: `tool-2${D}storm-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||
|
||||
// User B requests tool-2 — gets cache hit from the throttle-cached entry.
|
||||
// Without this caching, user B would trigger a real reconnect since
|
||||
// user B has their own throttle key and hasn't reconnected yet.
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: userB,
|
||||
toolKey: `tool-2${D}storm-server`,
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: undefined,
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.name).toContain('tool-2');
|
||||
// Still only 1 real reconnect — user B was protected by the cache
|
||||
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMCPTools throttle handling', () => {
|
||||
it('should return empty array with debug log when reconnect is throttled', async () => {
|
||||
const mockUser = { id: 'throttle-tools-user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// First call: real reconnect
|
||||
mockReinitMCPServer.mockResolvedValueOnce({
|
||||
tools: [{ name: 'tool1' }],
|
||||
availableTools: {
|
||||
[`tool1${D}throttle-tools-server`]: {
|
||||
function: { description: 'Tool 1', parameters: {} },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await createMCPTools({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
serverName: 'throttle-tools-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
});
|
||||
|
||||
// Second call within 10s — throttled
|
||||
const result = await createMCPTools({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
serverName: 'throttle-tools-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
});
|
||||
|
||||
expect(result).toEqual([]);
|
||||
// reinitMCPServer called only once — second was throttled
|
||||
expect(mockReinitMCPServer).toHaveBeenCalledTimes(1);
|
||||
// Should log at debug level (not warn) for throttled case
|
||||
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Reconnect throttled'));
|
||||
});
|
||||
});
|
||||
|
||||
describe('User parameter integrity', () => {
|
||||
it('should preserve user object properties through the call chain', async () => {
|
||||
const complexUser = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
// --- Mocks ---
|
||||
jest.mock('tiktoken');
|
||||
jest.mock('fs');
|
||||
jest.mock('path');
|
||||
jest.mock('node-fetch');
|
||||
|
|
|
|||
23
package-lock.json
generated
23
package-lock.json
generated
|
|
@ -66,6 +66,7 @@
|
|||
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"ai-tokenizer": "^1.0.6",
|
||||
"axios": "^1.13.5",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.8.1",
|
||||
|
|
@ -121,7 +122,6 @@
|
|||
"pdfjs-dist": "^5.4.624",
|
||||
"rate-limit-redis": "^4.2.0",
|
||||
"sharp": "^0.33.5",
|
||||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"undici": "^7.18.2",
|
||||
|
|
@ -22230,6 +22230,20 @@
|
|||
"node": ">= 14"
|
||||
}
|
||||
},
|
||||
"node_modules/ai-tokenizer": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/ai-tokenizer/-/ai-tokenizer-1.0.6.tgz",
|
||||
"integrity": "sha512-GaakQFxen0pRH/HIA4v68ZM40llCH27HUYUSBLK+gVuZ57e53pYJe1xFvSTj4sJJjbWU92m1X6NjPWyeWkFDow==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"ai": "^5.0.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"ai": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/ajv": {
|
||||
"version": "8.18.0",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
|
|
@ -41485,11 +41499,6 @@
|
|||
"node": ">=0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/tiktoken": {
|
||||
"version": "1.0.15",
|
||||
"resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.15.tgz",
|
||||
"integrity": "sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw=="
|
||||
},
|
||||
"node_modules/timers-browserify": {
|
||||
"version": "2.0.12",
|
||||
"resolved": "https://registry.npmjs.org/timers-browserify/-/timers-browserify-2.0.12.tgz",
|
||||
|
|
@ -44200,6 +44209,7 @@
|
|||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"ai-tokenizer": "^1.0.6",
|
||||
"axios": "^1.13.5",
|
||||
"connect-redis": "^8.1.0",
|
||||
"eventsource": "^3.0.2",
|
||||
|
|
@ -44222,7 +44232,6 @@
|
|||
"node-fetch": "2.7.0",
|
||||
"pdfjs-dist": "^5.4.624",
|
||||
"rate-limit-redis": "^4.2.0",
|
||||
"tiktoken": "^1.0.15",
|
||||
"undici": "^7.18.2",
|
||||
"zod": "^3.22.4"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ export default {
|
|||
'\\.dev\\.ts$',
|
||||
'\\.helper\\.ts$',
|
||||
'\\.helper\\.d\\.ts$',
|
||||
'/__tests__/helpers/',
|
||||
],
|
||||
coverageReporters: ['text', 'cobertura'],
|
||||
testResultsProcessor: 'jest-junit',
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@
|
|||
"build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
|
||||
"build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
|
||||
"build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.|__tests__/helpers/\"",
|
||||
"test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
"test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
|
||||
"test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
|
|
@ -94,6 +94,7 @@
|
|||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"ai-tokenizer": "^1.0.6",
|
||||
"axios": "^1.13.5",
|
||||
"connect-redis": "^8.1.0",
|
||||
"eventsource": "^3.0.2",
|
||||
|
|
@ -116,7 +117,6 @@
|
|||
"node-fetch": "2.7.0",
|
||||
"pdfjs-dist": "^5.4.624",
|
||||
"rate-limit-redis": "^4.2.0",
|
||||
"tiktoken": "^1.0.15",
|
||||
"undici": "^7.18.2",
|
||||
"zod": "^3.22.4"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@ jest.mock('winston', () => ({
|
|||
}));
|
||||
|
||||
// Mock the Tokenizer
|
||||
jest.mock('~/utils', () => ({
|
||||
Tokenizer: {
|
||||
jest.mock('~/utils/tokenizer', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getTokenCount: jest.fn((text: string) => text.length), // Simple mock: 1 char = 1 token
|
||||
},
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
|
|||
import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
|
||||
import type { Response as ServerResponse } from 'express';
|
||||
import { GenerationJobManager } from '~/stream/GenerationJobManager';
|
||||
import { Tokenizer, resolveHeaders, createSafeUser } from '~/utils';
|
||||
import { resolveHeaders, createSafeUser } from '~/utils';
|
||||
import Tokenizer from '~/utils/tokenizer';
|
||||
|
||||
type RequiredMemoryMethods = Pick<
|
||||
MemoryMethods,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,18 @@ import { logger } from '@librechat/data-schemas';
|
|||
import type { StoredDataNoRaw } from 'keyv';
|
||||
import type { FlowState, FlowMetadata, FlowManagerOptions } from './types';
|
||||
|
||||
export const PENDING_STALE_MS = 2 * 60 * 1000;
|
||||
|
||||
const SECONDS_THRESHOLD = 1e10;
|
||||
|
||||
/**
|
||||
* Normalizes an expiration timestamp to milliseconds.
|
||||
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
|
||||
*/
|
||||
export function normalizeExpiresAt(timestamp: number): number {
|
||||
return timestamp < SECONDS_THRESHOLD ? timestamp * 1000 : timestamp;
|
||||
}
|
||||
|
||||
export class FlowStateManager<T = unknown> {
|
||||
private keyv: Keyv;
|
||||
private ttl: number;
|
||||
|
|
@ -45,32 +57,8 @@ export class FlowStateManager<T = unknown> {
|
|||
return `${type}:${flowId}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes an expiration timestamp to milliseconds.
|
||||
* Detects whether the input is in seconds or milliseconds based on magnitude.
|
||||
* Timestamps below 10 billion are assumed to be in seconds (valid until ~2286).
|
||||
* @param timestamp - The expiration timestamp (in seconds or milliseconds)
|
||||
* @returns The timestamp normalized to milliseconds
|
||||
*/
|
||||
private normalizeExpirationTimestamp(timestamp: number): number {
|
||||
const SECONDS_THRESHOLD = 1e10;
|
||||
if (timestamp < SECONDS_THRESHOLD) {
|
||||
return timestamp * 1000;
|
||||
}
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a flow's token has expired based on its expires_at field
|
||||
* @param flowState - The flow state to check
|
||||
* @returns true if the token has expired, false otherwise (including if no expires_at exists)
|
||||
*/
|
||||
private isTokenExpired(flowState: FlowState<T> | undefined): boolean {
|
||||
if (!flowState?.result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typeof flowState.result !== 'object') {
|
||||
if (!flowState?.result || typeof flowState.result !== 'object') {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -79,13 +67,11 @@ export class FlowStateManager<T = unknown> {
|
|||
}
|
||||
|
||||
const expiresAt = (flowState.result as { expires_at: unknown }).expires_at;
|
||||
|
||||
if (typeof expiresAt !== 'number' || !Number.isFinite(expiresAt)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const normalizedExpiresAt = this.normalizeExpirationTimestamp(expiresAt);
|
||||
return normalizedExpiresAt < Date.now();
|
||||
return normalizeExpiresAt(expiresAt) < Date.now();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -149,6 +135,8 @@ export class FlowStateManager<T = unknown> {
|
|||
let elapsedTime = 0;
|
||||
let isCleanedUp = false;
|
||||
let intervalId: NodeJS.Timeout | null = null;
|
||||
let missingStateRetried = false;
|
||||
let isRetrying = false;
|
||||
|
||||
// Cleanup function to avoid duplicate cleanup
|
||||
const cleanup = () => {
|
||||
|
|
@ -188,16 +176,29 @@ export class FlowStateManager<T = unknown> {
|
|||
}
|
||||
|
||||
intervalId = setInterval(async () => {
|
||||
if (isCleanedUp) return;
|
||||
if (isCleanedUp || isRetrying) return;
|
||||
|
||||
try {
|
||||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
let flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
||||
if (!flowState) {
|
||||
cleanup();
|
||||
logger.error(`[${flowKey}] Flow state not found`);
|
||||
reject(new Error(`${type} Flow state not found`));
|
||||
return;
|
||||
if (!missingStateRetried) {
|
||||
missingStateRetried = true;
|
||||
isRetrying = true;
|
||||
logger.warn(
|
||||
`[${flowKey}] Flow state not found, retrying once after 500ms (race recovery)`,
|
||||
);
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
isRetrying = false;
|
||||
}
|
||||
|
||||
if (!flowState) {
|
||||
cleanup();
|
||||
logger.error(`[${flowKey}] Flow state not found after retry`);
|
||||
reject(new Error(`${type} Flow state not found`));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (signal?.aborted) {
|
||||
|
|
@ -251,10 +252,10 @@ export class FlowStateManager<T = unknown> {
|
|||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
||||
if (!flowState) {
|
||||
logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', {
|
||||
flowId,
|
||||
type,
|
||||
});
|
||||
logger.warn(
|
||||
'[FlowStateManager] Flow state not found during completion — cannot recover metadata, skipping',
|
||||
{ flowId, type },
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -297,7 +298,7 @@ export class FlowStateManager<T = unknown> {
|
|||
async isFlowStale(
|
||||
flowId: string,
|
||||
type: string,
|
||||
staleThresholdMs: number = 2 * 60 * 1000,
|
||||
staleThresholdMs: number = PENDING_STALE_MS,
|
||||
): Promise<{ isStale: boolean; age: number; status?: string }> {
|
||||
const flowKey = this.getFlowKey(flowId, type);
|
||||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ export * from './mcp/errors';
|
|||
/* Utilities */
|
||||
export * from './mcp/utils';
|
||||
export * from './utils';
|
||||
export { default as Tokenizer, countTokens } from './utils/tokenizer';
|
||||
export type { EncodingName } from './utils/tokenizer';
|
||||
export * from './db/utils';
|
||||
/* OAuth */
|
||||
export * from './oauth';
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import { logger } from '@librechat/data-schemas';
|
|||
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
|
||||
import type { MCPOAuthTokens, OAuthMetadata, MCPOAuthFlowMetadata } from '~/mcp/oauth';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { FlowMetadata } from '~/flow/types';
|
||||
import type * as t from './types';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { PENDING_STALE_MS, normalizeExpiresAt } from '~/flow/manager';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import { MCPConnection } from './connection';
|
||||
|
|
@ -104,6 +104,7 @@ export class MCPConnectionFactory {
|
|||
return { tools, connection, oauthRequired: false, oauthUrl: null };
|
||||
}
|
||||
} catch {
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
logger.debug(
|
||||
`${this.logPrefix} [Discovery] Connection failed, attempting unauthenticated tool listing`,
|
||||
);
|
||||
|
|
@ -125,7 +126,9 @@ export class MCPConnectionFactory {
|
|||
}
|
||||
return { tools, connection: null, oauthRequired, oauthUrl };
|
||||
}
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
} catch (listError) {
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
logger.debug(`${this.logPrefix} [Discovery] Unauthenticated tool listing failed:`, listError);
|
||||
}
|
||||
|
||||
|
|
@ -265,6 +268,10 @@ export class MCPConnectionFactory {
|
|||
if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
|
||||
return tokens;
|
||||
} catch (error) {
|
||||
if (error instanceof ReauthenticationRequiredError) {
|
||||
logger.info(`${this.logPrefix} ${error.message}, will trigger OAuth flow`);
|
||||
return null;
|
||||
}
|
||||
logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
|
||||
return null;
|
||||
}
|
||||
|
|
@ -306,11 +313,21 @@ export class MCPConnectionFactory {
|
|||
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow?.status === 'PENDING') {
|
||||
const pendingAge = existingFlow.createdAt
|
||||
? Date.now() - existingFlow.createdAt
|
||||
: Infinity;
|
||||
|
||||
if (pendingAge < PENDING_STALE_MS) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Recent PENDING OAuth flow exists (${Math.round(pendingAge / 1000)}s old), skipping new initiation`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
|
||||
`${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will replace`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
|
|
@ -326,11 +343,17 @@ export class MCPConnectionFactory {
|
|||
);
|
||||
|
||||
if (existingFlow) {
|
||||
const oldState = (existingFlow.metadata as MCPOAuthFlowMetadata)?.state;
|
||||
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
|
||||
if (oldState) {
|
||||
await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager!);
|
||||
}
|
||||
}
|
||||
|
||||
// Store flow state BEFORE redirecting so the callback can find it
|
||||
await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', flowMetadata);
|
||||
const metadataWithUrl = { ...flowMetadata, authorizationUrl };
|
||||
await this.flowManager!.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
|
||||
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager!);
|
||||
|
||||
// Start monitoring in background — createFlow will find the existing PENDING state
|
||||
// written by initFlow above, so metadata arg is unused (pass {} to make that explicit)
|
||||
|
|
@ -495,11 +518,75 @@ export class MCPConnectionFactory {
|
|||
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow) {
|
||||
const flowMeta = existingFlow.metadata as MCPOAuthFlowMetadata | undefined;
|
||||
|
||||
if (existingFlow.status === 'PENDING') {
|
||||
const pendingAge = existingFlow.createdAt
|
||||
? Date.now() - existingFlow.createdAt
|
||||
: Infinity;
|
||||
|
||||
if (pendingAge < PENDING_STALE_MS) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found recent PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), joining instead of creating new one`,
|
||||
);
|
||||
|
||||
const storedAuthUrl = flowMeta?.authorizationUrl;
|
||||
if (storedAuthUrl && typeof this.oauthStart === 'function') {
|
||||
logger.info(
|
||||
`${this.logPrefix} Re-issuing stored authorization URL to caller while joining PENDING flow`,
|
||||
);
|
||||
await this.oauthStart(storedAuthUrl);
|
||||
}
|
||||
|
||||
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth', {}, this.signal);
|
||||
if (typeof this.oauthEnd === 'function') {
|
||||
await this.oauthEnd();
|
||||
}
|
||||
logger.info(
|
||||
`${this.logPrefix} Joined existing OAuth flow completed for ${this.serverName}`,
|
||||
);
|
||||
return {
|
||||
tokens,
|
||||
clientInfo: flowMeta?.clientInfo,
|
||||
metadata: flowMeta?.metadata,
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found stale PENDING OAuth flow (${Math.round(pendingAge / 1000)}s old), will delete and start fresh`,
|
||||
);
|
||||
}
|
||||
|
||||
if (existingFlow.status === 'COMPLETED') {
|
||||
const completedAge = existingFlow.completedAt
|
||||
? Date.now() - existingFlow.completedAt
|
||||
: Infinity;
|
||||
const cachedTokens = existingFlow.result as MCPOAuthTokens | null | undefined;
|
||||
const isTokenExpired =
|
||||
cachedTokens?.expires_at != null &&
|
||||
normalizeExpiresAt(cachedTokens.expires_at) < Date.now();
|
||||
|
||||
if (completedAge <= PENDING_STALE_MS && cachedTokens !== undefined && !isTokenExpired) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found non-stale COMPLETED OAuth flow, reusing cached tokens`,
|
||||
);
|
||||
return {
|
||||
tokens: cachedTokens,
|
||||
clientInfo: flowMeta?.clientInfo,
|
||||
metadata: flowMeta?.metadata,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
|
||||
);
|
||||
try {
|
||||
const oldState = flowMeta?.state;
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
if (oldState) {
|
||||
await MCPOAuthHandler.deleteStateMapping(oldState, this.flowManager);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
|
||||
}
|
||||
|
|
@ -519,7 +606,9 @@ export class MCPConnectionFactory {
|
|||
);
|
||||
|
||||
// Store flow state BEFORE redirecting so the callback can find it
|
||||
await this.flowManager.initFlow(newFlowId, 'mcp_oauth', flowMetadata as FlowMetadata);
|
||||
const metadataWithUrl = { ...flowMetadata, authorizationUrl };
|
||||
await this.flowManager.initFlow(newFlowId, 'mcp_oauth', metadataWithUrl);
|
||||
await MCPOAuthHandler.storeStateMapping(flowMetadata.state, newFlowId, this.flowManager);
|
||||
|
||||
if (typeof this.oauthStart === 'function') {
|
||||
logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { MCPConnection } from './connection';
|
||||
import type * as t from './types';
|
||||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPConnection } from './connection';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
/**
|
||||
|
|
@ -21,6 +21,8 @@ export abstract class UserConnectionManager {
|
|||
protected userConnections: Map<string, Map<string, MCPConnection>> = new Map();
|
||||
/** Last activity timestamp for users (not per server) */
|
||||
protected userLastActivity: Map<string, number> = new Map();
|
||||
/** In-flight connection promises keyed by `userId:serverName` — coalesces concurrent attempts */
|
||||
protected pendingConnections: Map<string, Promise<MCPConnection>> = new Map();
|
||||
|
||||
/** Updates the last activity timestamp for a user */
|
||||
protected updateUserLastActivity(userId: string): void {
|
||||
|
|
@ -31,29 +33,64 @@ export abstract class UserConnectionManager {
|
|||
);
|
||||
}
|
||||
|
||||
/** Gets or creates a connection for a specific user */
|
||||
public async getUserConnection({
|
||||
serverName,
|
||||
forceNew,
|
||||
user,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
requestBody,
|
||||
tokenMethods,
|
||||
oauthStart,
|
||||
oauthEnd,
|
||||
signal,
|
||||
returnOnOAuth = false,
|
||||
connectionTimeout,
|
||||
}: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>): Promise<MCPConnection> {
|
||||
/** Gets or creates a connection for a specific user, coalescing concurrent attempts */
|
||||
public async getUserConnection(
|
||||
opts: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
|
||||
): Promise<MCPConnection> {
|
||||
const { serverName, forceNew, user } = opts;
|
||||
const userId = user?.id;
|
||||
if (!userId) {
|
||||
throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`);
|
||||
}
|
||||
|
||||
const lockKey = `${userId}:${serverName}`;
|
||||
|
||||
if (!forceNew) {
|
||||
const pending = this.pendingConnections.get(lockKey);
|
||||
if (pending) {
|
||||
logger.debug(`[MCP][User: ${userId}][${serverName}] Joining in-flight connection attempt`);
|
||||
return pending;
|
||||
}
|
||||
}
|
||||
|
||||
const connectionPromise = this.createUserConnectionInternal(opts, userId);
|
||||
|
||||
if (!forceNew) {
|
||||
this.pendingConnections.set(lockKey, connectionPromise);
|
||||
}
|
||||
|
||||
try {
|
||||
return await connectionPromise;
|
||||
} finally {
|
||||
if (!forceNew && this.pendingConnections.get(lockKey) === connectionPromise) {
|
||||
this.pendingConnections.delete(lockKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async createUserConnectionInternal(
|
||||
{
|
||||
serverName,
|
||||
forceNew,
|
||||
user,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
requestBody,
|
||||
tokenMethods,
|
||||
oauthStart,
|
||||
oauthEnd,
|
||||
signal,
|
||||
returnOnOAuth = false,
|
||||
connectionTimeout,
|
||||
}: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
|
||||
userId: string,
|
||||
): Promise<MCPConnection> {
|
||||
if (await this.appConnections!.has(serverName)) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
|
|
@ -65,6 +102,9 @@ export abstract class UserConnectionManager {
|
|||
|
||||
const userServerMap = this.userConnections.get(userId);
|
||||
let connection = forceNew ? undefined : userServerMap?.get(serverName);
|
||||
if (forceNew) {
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
}
|
||||
const now = Date.now();
|
||||
|
||||
// Check if user is idle
|
||||
|
|
@ -185,6 +225,7 @@ export abstract class UserConnectionManager {
|
|||
|
||||
/** Disconnects and removes a specific user connection */
|
||||
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
|
||||
this.pendingConnections.delete(`${userId}:${serverName}`);
|
||||
const userMap = this.userConnections.get(userId);
|
||||
const connection = userMap?.get(serverName);
|
||||
if (connection) {
|
||||
|
|
@ -212,6 +253,12 @@ export abstract class UserConnectionManager {
|
|||
);
|
||||
}
|
||||
await Promise.allSettled(disconnectPromises);
|
||||
// Clean up any pending connection promises for this user
|
||||
for (const key of this.pendingConnections.keys()) {
|
||||
if (key.startsWith(`${userId}:`)) {
|
||||
this.pendingConnections.delete(key);
|
||||
}
|
||||
}
|
||||
// Ensure user activity timestamp is removed
|
||||
this.userLastActivity.delete(userId);
|
||||
logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
|
||||
|
|
|
|||
|
|
@ -559,3 +559,242 @@ describe('extractSSEErrorMessage', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Tests for circuit breaker logic.
|
||||
*
|
||||
* Uses standalone implementations that mirror the static/private circuit breaker
|
||||
* methods in MCPConnection. Same approach as the error detection tests above.
|
||||
*/
|
||||
describe('MCPConnection Circuit Breaker', () => {
|
||||
/** 5 cycles within 60s triggers a 30s cooldown */
|
||||
const CB_MAX_CYCLES = 5;
|
||||
const CB_CYCLE_WINDOW_MS = 60_000;
|
||||
const CB_CYCLE_COOLDOWN_MS = 30_000;
|
||||
|
||||
/** 3 failed rounds within 120s triggers exponential backoff (30s - 300s) */
|
||||
const CB_MAX_FAILED_ROUNDS = 3;
|
||||
const CB_FAILED_WINDOW_MS = 120_000;
|
||||
const CB_BASE_BACKOFF_MS = 30_000;
|
||||
const CB_MAX_BACKOFF_MS = 300_000;
|
||||
|
||||
interface CircuitBreakerState {
|
||||
cycleCount: number;
|
||||
cycleWindowStart: number;
|
||||
cooldownUntil: number;
|
||||
failedRounds: number;
|
||||
failedWindowStart: number;
|
||||
failedBackoffUntil: number;
|
||||
}
|
||||
|
||||
function createCB(): CircuitBreakerState {
|
||||
return {
|
||||
cycleCount: 0,
|
||||
cycleWindowStart: Date.now(),
|
||||
cooldownUntil: 0,
|
||||
failedRounds: 0,
|
||||
failedWindowStart: Date.now(),
|
||||
failedBackoffUntil: 0,
|
||||
};
|
||||
}
|
||||
|
||||
function isCircuitOpen(cb: CircuitBreakerState): boolean {
|
||||
const now = Date.now();
|
||||
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
|
||||
}
|
||||
|
||||
function recordCycle(cb: CircuitBreakerState): void {
|
||||
const now = Date.now();
|
||||
if (now - cb.cycleWindowStart > CB_CYCLE_WINDOW_MS) {
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
cb.cycleCount++;
|
||||
if (cb.cycleCount >= CB_MAX_CYCLES) {
|
||||
cb.cooldownUntil = now + CB_CYCLE_COOLDOWN_MS;
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
}
|
||||
|
||||
function recordFailedRound(cb: CircuitBreakerState): void {
|
||||
const now = Date.now();
|
||||
if (now - cb.failedWindowStart > CB_FAILED_WINDOW_MS) {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = now;
|
||||
}
|
||||
cb.failedRounds++;
|
||||
if (cb.failedRounds >= CB_MAX_FAILED_ROUNDS) {
|
||||
const backoff = Math.min(
|
||||
CB_BASE_BACKOFF_MS * Math.pow(2, cb.failedRounds - CB_MAX_FAILED_ROUNDS),
|
||||
CB_MAX_BACKOFF_MS,
|
||||
);
|
||||
cb.failedBackoffUntil = now + backoff;
|
||||
}
|
||||
}
|
||||
|
||||
function resetFailedRounds(cb: CircuitBreakerState): void {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = Date.now();
|
||||
cb.failedBackoffUntil = 0;
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
describe('cycle tracking', () => {
|
||||
it('should not trigger cooldown for fewer than 5 cycles', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should trigger 30s cooldown after 5 cycles within 60s', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_CYCLES; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(29_000);
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(1_000);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should reset cycle count when window expires', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_CYCLES - 1; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
|
||||
jest.advanceTimersByTime(CB_CYCLE_WINDOW_MS + 1);
|
||||
|
||||
recordCycle(cb);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('failed round tracking', () => {
|
||||
it('should not trigger backoff for fewer than 3 failures', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_FAILED_ROUNDS - 1; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should trigger 30s backoff after 3 failures within 120s', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(CB_BASE_BACKOFF_MS);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
|
||||
it('should use exponential backoff based on failure count', () => {
|
||||
jest.setSystemTime(Date.now());
|
||||
|
||||
const cb = createCB();
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(30_000);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(60_000);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(120_000);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(240_000);
|
||||
|
||||
// capped at 300s
|
||||
recordFailedRound(cb);
|
||||
expect(cb.failedBackoffUntil - Date.now()).toBe(300_000);
|
||||
});
|
||||
|
||||
it('should reset failed window when window expires', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
recordFailedRound(cb);
|
||||
recordFailedRound(cb);
|
||||
|
||||
jest.advanceTimersByTime(CB_FAILED_WINDOW_MS + 1);
|
||||
|
||||
recordFailedRound(cb);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('resetFailedRounds', () => {
|
||||
it('should clear failed round state on successful connection', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const cb = createCB();
|
||||
for (let i = 0; i < CB_MAX_FAILED_ROUNDS; i++) {
|
||||
recordFailedRound(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
resetFailedRounds(cb);
|
||||
expect(isCircuitOpen(cb)).toBe(false);
|
||||
expect(cb.failedRounds).toBe(0);
|
||||
expect(cb.failedBackoffUntil).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCooldown (registry deletion)', () => {
|
||||
it('should allow connections after clearing circuit breaker state', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
const registry = new Map<string, CircuitBreakerState>();
|
||||
const serverName = 'test-server';
|
||||
|
||||
const cb = createCB();
|
||||
registry.set(serverName, cb);
|
||||
|
||||
for (let i = 0; i < CB_MAX_CYCLES; i++) {
|
||||
recordCycle(cb);
|
||||
}
|
||||
expect(isCircuitOpen(cb)).toBe(true);
|
||||
|
||||
registry.delete(serverName);
|
||||
|
||||
const newCb = createCB();
|
||||
expect(isCircuitOpen(newCb)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -207,6 +207,7 @@ describe('MCPConnection Agent lifecycle – streamable-http', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
@ -366,6 +367,7 @@ describe('MCPConnection Agent lifecycle – SSE', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test-sse');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
@ -453,6 +455,7 @@ describe('Regression: old per-request Agent pattern leaks agents', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test-regression');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
@ -675,6 +678,7 @@ describe('MCPConnection SSE GET stream recovery – integration', () => {
|
|||
});
|
||||
|
||||
afterEach(async () => {
|
||||
MCPConnection.clearCooldown('test-sse-recovery');
|
||||
await safeDisconnect(conn);
|
||||
conn = null;
|
||||
jest.restoreAllMocks();
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ describe('MCPConnectionFactory', () => {
|
|||
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
|
||||
'flow123',
|
||||
'mcp_oauth',
|
||||
mockFlowData.flowMetadata,
|
||||
expect.objectContaining(mockFlowData.flowMetadata),
|
||||
);
|
||||
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
|
||||
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
|
||||
|
|
@ -550,7 +550,7 @@ describe('MCPConnectionFactory', () => {
|
|||
expect(mockFlowManager.initFlow).toHaveBeenCalledWith(
|
||||
'flow123',
|
||||
'mcp_oauth',
|
||||
mockFlowData.flowMetadata,
|
||||
expect.objectContaining(mockFlowData.flowMetadata),
|
||||
);
|
||||
const initCallOrder = mockFlowManager.initFlow.mock.invocationCallOrder[0];
|
||||
const oauthStartCallOrder = (oauthOptions.oauthStart as jest.Mock).mock
|
||||
|
|
|
|||
232
packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts
Normal file
232
packages/api/src/mcp/__tests__/MCPOAuthCSRFFallback.test.ts
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
/**
|
||||
* Tests for the OAuth callback CSRF fallback logic.
|
||||
*
|
||||
* The callback route validates requests via three mechanisms (in order):
|
||||
* 1. CSRF cookie (HMAC-based, set during initiate)
|
||||
* 2. Session cookie (bound to authenticated userId)
|
||||
* 3. Active PENDING flow in FlowStateManager (fallback for SSE/chat flows)
|
||||
*
|
||||
* This suite tests mechanism 3 — the PENDING flow fallback — including
|
||||
* staleness enforcement and rejection of non-PENDING flows.
|
||||
*
|
||||
* These tests exercise the validation functions directly for fast,
|
||||
* focused coverage. Route-level integration tests using supertest
|
||||
* are in api/server/routes/__tests__/mcp.spec.js ("CSRF fallback
|
||||
* via active PENDING flow" describe block).
|
||||
*/
|
||||
|
||||
import { Keyv } from 'keyv';
|
||||
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
|
||||
import type { Request, Response } from 'express';
|
||||
import {
|
||||
generateOAuthCsrfToken,
|
||||
OAUTH_SESSION_COOKIE,
|
||||
validateOAuthSession,
|
||||
OAUTH_CSRF_COOKIE,
|
||||
validateOAuthCsrf,
|
||||
} from '~/oauth/csrf';
|
||||
import { MockKeyv } from './helpers/oauthTestServer';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
const CSRF_COOKIE_PATH = '/api/mcp';
|
||||
|
||||
function makeReq(cookies: Record<string, string> = {}): Request {
|
||||
return { cookies } as unknown as Request;
|
||||
}
|
||||
|
||||
function makeRes(): Response {
|
||||
const res = {
|
||||
clearCookie: jest.fn(),
|
||||
} as unknown as Response;
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replicate the callback route's three-tier validation logic.
|
||||
* Returns which mechanism (if any) authorized the request.
|
||||
*/
|
||||
async function validateCallback(
|
||||
req: Request,
|
||||
res: Response,
|
||||
flowId: string,
|
||||
flowManager: FlowStateManager,
|
||||
): Promise<'csrf' | 'session' | 'pendingFlow' | false> {
|
||||
const flowUserId = flowId.split(':')[0];
|
||||
|
||||
const hasCsrf = validateOAuthCsrf(req, res, flowId, CSRF_COOKIE_PATH);
|
||||
if (hasCsrf) {
|
||||
return 'csrf';
|
||||
}
|
||||
|
||||
const hasSession = validateOAuthSession(req, flowUserId);
|
||||
if (hasSession) {
|
||||
return 'session';
|
||||
}
|
||||
|
||||
const pendingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
const pendingAge = pendingFlow?.createdAt ? Date.now() - pendingFlow.createdAt : Infinity;
|
||||
if (pendingFlow?.status === 'PENDING' && pendingAge < PENDING_STALE_MS) {
|
||||
return 'pendingFlow';
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
describe('OAuth Callback CSRF Fallback', () => {
|
||||
let flowManager: FlowStateManager;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.JWT_SECRET = 'test-secret-for-csrf';
|
||||
const store = new MockKeyv();
|
||||
flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 300000, ci: true });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
delete process.env.JWT_SECRET;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('CSRF cookie validation (mechanism 1)', () => {
|
||||
it('should accept valid CSRF cookie', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('csrf');
|
||||
});
|
||||
|
||||
it('should reject invalid CSRF cookie', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const req = makeReq({ [OAUTH_CSRF_COOKIE]: 'wrong-token-value' });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Session cookie validation (mechanism 2)', () => {
|
||||
it('should accept valid session cookie when CSRF is absent', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('session');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PENDING flow fallback (mechanism 3)', () => {
|
||||
it('should accept when a fresh PENDING flow exists and no cookies are present', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('pendingFlow');
|
||||
});
|
||||
|
||||
it('should reject when no PENDING flow, no CSRF cookie, and no session cookie', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject when only a COMPLETED flow exists (not PENDING)', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', { access_token: 'tok' } as never);
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject when only a FAILED flow exists', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
await flowManager.failFlow(flowId, 'mcp_oauth', 'some error');
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should reject when PENDING flow is stale (older than PENDING_STALE_MS)', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
// Artificially age the flow past the staleness threshold
|
||||
const store = (flowManager as unknown as { keyv: { get: (k: string) => Promise<unknown> } })
|
||||
.keyv;
|
||||
const flowState = (await store.get(`mcp_oauth:${flowId}`)) as { createdAt: number };
|
||||
flowState.createdAt = Date.now() - PENDING_STALE_MS - 1000;
|
||||
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should accept PENDING flow that is just under the staleness threshold', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
// Flow was just created, well under threshold
|
||||
const req = makeReq();
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('pendingFlow');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Priority ordering', () => {
|
||||
it('should prefer CSRF cookie over PENDING flow', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
const csrfToken = generateOAuthCsrfToken(flowId, 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_CSRF_COOKIE]: csrfToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('csrf');
|
||||
});
|
||||
|
||||
it('should prefer session cookie over PENDING flow when CSRF is absent', async () => {
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { serverName: 'test-server' });
|
||||
|
||||
const sessionToken = generateOAuthCsrfToken('user1', 'test-secret-for-csrf');
|
||||
const req = makeReq({ [OAUTH_SESSION_COOKIE]: sessionToken });
|
||||
const res = makeRes();
|
||||
|
||||
const result = await validateCallback(req, res, flowId, flowManager);
|
||||
expect(result).toBe('session');
|
||||
});
|
||||
});
|
||||
});
|
||||
268
packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts
Normal file
268
packages/api/src/mcp/__tests__/MCPOAuthConnectionEvents.test.ts
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
/**
|
||||
* Tests for MCPConnection OAuth event cycle against a real OAuth-gated MCP server.
|
||||
*
|
||||
* Verifies: oauthRequired emission on 401, oauthHandled reconnection,
|
||||
* oauthFailed rejection, timeout behavior, and token expiry mid-session.
|
||||
*/
|
||||
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
jest.mock('~/auth', () => ({
|
||||
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
|
||||
resolveHostnameSSRF: jest.fn(async () => false),
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/mcpConfig', () => ({
|
||||
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
|
||||
}));
|
||||
|
||||
async function safeDisconnect(conn: MCPConnection | null): Promise<void> {
|
||||
if (!conn) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await conn.disconnect();
|
||||
} catch {
|
||||
// Ignore disconnect errors during cleanup
|
||||
}
|
||||
}
|
||||
|
||||
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
|
||||
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${serverUrl}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const data = (await tokenRes.json()) as { access_token: string };
|
||||
return data.access_token;
|
||||
}
|
||||
|
||||
describe('MCPConnection OAuth Events — Real Server', () => {
|
||||
let server: OAuthTestServer;
|
||||
let connection: MCPConnection | null = null;
|
||||
|
||||
beforeEach(() => {
|
||||
MCPConnection.clearCooldown('test-server');
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await safeDisconnect(connection);
|
||||
connection = null;
|
||||
if (server) {
|
||||
await server.close();
|
||||
}
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('oauthRequired event', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should emit oauthRequired when connecting without a token', async () => {
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
const oauthRequiredPromise = new Promise<{
|
||||
serverName: string;
|
||||
error: Error;
|
||||
serverUrl?: string;
|
||||
userId?: string;
|
||||
}>((resolve) => {
|
||||
connection!.on('oauthRequired', (data) => {
|
||||
resolve(
|
||||
data as {
|
||||
serverName: string;
|
||||
error: Error;
|
||||
serverUrl?: string;
|
||||
userId?: string;
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// Connection will fail with 401, emitting oauthRequired
|
||||
const connectPromise = connection.connect().catch(() => {
|
||||
// Expected to fail since no one handles oauthRequired
|
||||
});
|
||||
|
||||
let raceTimer: NodeJS.Timeout | undefined;
|
||||
const eventData = await Promise.race([
|
||||
oauthRequiredPromise,
|
||||
new Promise<never>((_, reject) => {
|
||||
raceTimer = setTimeout(
|
||||
() => reject(new Error('Timed out waiting for oauthRequired')),
|
||||
10000,
|
||||
);
|
||||
}),
|
||||
]).finally(() => clearTimeout(raceTimer));
|
||||
|
||||
expect(eventData.serverName).toBe('test-server');
|
||||
expect(eventData.error).toBeDefined();
|
||||
|
||||
// Emit oauthFailed to unblock connect()
|
||||
connection.emit('oauthFailed', new Error('test cleanup'));
|
||||
await connectPromise.catch(() => undefined);
|
||||
});
|
||||
|
||||
it('should not emit oauthRequired when connecting with a valid token', async () => {
|
||||
const accessToken = await exchangeCodeForToken(server.url);
|
||||
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: { type: 'streamable-http', url: server.url },
|
||||
userId: 'user-1',
|
||||
oauthTokens: {
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens,
|
||||
});
|
||||
|
||||
let oauthFired = false;
|
||||
connection.on('oauthRequired', () => {
|
||||
oauthFired = true;
|
||||
});
|
||||
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
expect(oauthFired).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('oauthHandled reconnection', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should succeed on retry after oauthHandled provides valid tokens', async () => {
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
},
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
// First connect fails with 401 → oauthRequired fires
|
||||
let oauthFired = false;
|
||||
connection.on('oauthRequired', () => {
|
||||
oauthFired = true;
|
||||
connection!.emit('oauthFailed', new Error('Will retry with tokens'));
|
||||
});
|
||||
|
||||
// First attempt fails as expected
|
||||
await expect(connection.connect()).rejects.toThrow();
|
||||
expect(oauthFired).toBe(true);
|
||||
|
||||
// Now set valid tokens and reconnect
|
||||
const accessToken = await exchangeCodeForToken(server.url);
|
||||
connection.setOAuthTokens({
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('oauthFailed rejection', () => {
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
it('should reject connect() when oauthFailed is emitted', async () => {
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
},
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
connection.on('oauthRequired', () => {
|
||||
connection!.emit('oauthFailed', new Error('User denied OAuth'));
|
||||
});
|
||||
|
||||
await expect(connection.connect()).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Token expiry during session', () => {
|
||||
it('should detect expired token on reconnect and emit oauthRequired', async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 1000 });
|
||||
|
||||
const accessToken = await exchangeCodeForToken(server.url);
|
||||
|
||||
connection = new MCPConnection({
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
type: 'streamable-http',
|
||||
url: server.url,
|
||||
initTimeout: 15000,
|
||||
},
|
||||
userId: 'user-1',
|
||||
oauthTokens: {
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens,
|
||||
});
|
||||
|
||||
// Initial connect should succeed
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
await connection.disconnect();
|
||||
|
||||
// Wait for token to expire
|
||||
await new Promise((r) => setTimeout(r, 1200));
|
||||
|
||||
// Reconnect should trigger oauthRequired since token is expired on the server
|
||||
let oauthFired = false;
|
||||
connection.on('oauthRequired', () => {
|
||||
oauthFired = true;
|
||||
connection!.emit('oauthFailed', new Error('Will retry with fresh token'));
|
||||
});
|
||||
|
||||
// First reconnect fails with 401 → oauthRequired
|
||||
await expect(connection.connect()).rejects.toThrow();
|
||||
expect(oauthFired).toBe(true);
|
||||
|
||||
// Get fresh token and reconnect
|
||||
const newToken = await exchangeCodeForToken(server.url);
|
||||
connection.setOAuthTokens({
|
||||
access_token: newToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
|
||||
await connection.connect();
|
||||
expect(await connection.isConnected()).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
538
packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts
Normal file
538
packages/api/src/mcp/__tests__/MCPOAuthFlow.test.ts
Normal file
|
|
@ -0,0 +1,538 @@
|
|||
/**
|
||||
* OAuth flow tests against a real HTTP server.
|
||||
*
|
||||
* Tests MCPOAuthHandler.refreshOAuthTokens and MCPTokenStorage lifecycle
|
||||
* using a real test OAuth server (not mocked SDK functions).
|
||||
*/
|
||||
|
||||
import { createHash } from 'crypto';
|
||||
import { Keyv } from 'keyv';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { createOAuthMCPServer, MockKeyv, InMemoryTokenStore } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
describe('MCP OAuth Flow — Real HTTP Server', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Token refresh with real server', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should refresh tokens with stored client info via real /token endpoint', async () => {
|
||||
// First get initial tokens
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
expect(initial.refresh_token).toBeDefined();
|
||||
|
||||
// Register a client so we have clientInfo
|
||||
const regRes = await fetch(`${server.url}register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
|
||||
});
|
||||
const clientInfo = (await regRes.json()) as {
|
||||
client_id: string;
|
||||
client_secret: string;
|
||||
};
|
||||
|
||||
// Refresh tokens using the real endpoint
|
||||
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
initial.refresh_token,
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: server.url,
|
||||
clientInfo: {
|
||||
...clientInfo,
|
||||
redirect_uris: ['http://localhost/callback'],
|
||||
},
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: `${server.url}token`,
|
||||
client_id: clientInfo.client_id,
|
||||
client_secret: clientInfo.client_secret,
|
||||
token_exchange_method: 'DefaultPost',
|
||||
},
|
||||
);
|
||||
|
||||
expect(refreshed.access_token).toBeDefined();
|
||||
expect(refreshed.access_token).not.toBe(initial.access_token);
|
||||
expect(refreshed.token_type).toBe('Bearer');
|
||||
expect(refreshed.obtained_at).toBeDefined();
|
||||
});
|
||||
|
||||
it('should get new refresh token when server rotates', async () => {
|
||||
const rotatingServer = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
rotateRefreshTokens: true,
|
||||
});
|
||||
|
||||
try {
|
||||
const code = await rotatingServer.getAuthCode();
|
||||
const tokenRes = await fetch(`${rotatingServer.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
const refreshed = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
initial.refresh_token,
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: rotatingServer.url,
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: `${rotatingServer.url}token`,
|
||||
client_id: 'anon',
|
||||
token_exchange_method: 'DefaultPost',
|
||||
},
|
||||
);
|
||||
|
||||
expect(refreshed.access_token).not.toBe(initial.access_token);
|
||||
expect(refreshed.refresh_token).toBeDefined();
|
||||
expect(refreshed.refresh_token).not.toBe(initial.refresh_token);
|
||||
} finally {
|
||||
await rotatingServer.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should fail refresh with invalid refresh token', async () => {
|
||||
await expect(
|
||||
MCPOAuthHandler.refreshOAuthTokens(
|
||||
'invalid-refresh-token',
|
||||
{
|
||||
serverName: 'test-server',
|
||||
serverUrl: server.url,
|
||||
},
|
||||
{},
|
||||
{
|
||||
token_url: `${server.url}token`,
|
||||
client_id: 'anon',
|
||||
token_exchange_method: 'DefaultPost',
|
||||
},
|
||||
),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('OAuth server metadata discovery', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ issueRefreshTokens: true });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should expose /.well-known/oauth-authorization-server', async () => {
|
||||
const res = await fetch(`${server.url}.well-known/oauth-authorization-server`);
|
||||
expect(res.status).toBe(200);
|
||||
|
||||
const metadata = (await res.json()) as {
|
||||
authorization_endpoint: string;
|
||||
token_endpoint: string;
|
||||
registration_endpoint: string;
|
||||
grant_types_supported: string[];
|
||||
};
|
||||
|
||||
expect(metadata.authorization_endpoint).toContain('/authorize');
|
||||
expect(metadata.token_endpoint).toContain('/token');
|
||||
expect(metadata.registration_endpoint).toContain('/register');
|
||||
expect(metadata.grant_types_supported).toContain('authorization_code');
|
||||
expect(metadata.grant_types_supported).toContain('refresh_token');
|
||||
});
|
||||
|
||||
it('should not advertise refresh_token grant when disabled', async () => {
|
||||
const noRefreshServer = await createOAuthMCPServer({
|
||||
issueRefreshTokens: false,
|
||||
});
|
||||
try {
|
||||
const res = await fetch(`${noRefreshServer.url}.well-known/oauth-authorization-server`);
|
||||
const metadata = (await res.json()) as { grant_types_supported: string[] };
|
||||
expect(metadata.grant_types_supported).not.toContain('refresh_token');
|
||||
} finally {
|
||||
await noRefreshServer.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('Dynamic client registration', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should register a client via /register endpoint', async () => {
|
||||
const res = await fetch(`${server.url}register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
redirect_uris: ['http://localhost/callback'],
|
||||
}),
|
||||
});
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
const client = (await res.json()) as {
|
||||
client_id: string;
|
||||
client_secret: string;
|
||||
redirect_uris: string[];
|
||||
};
|
||||
|
||||
expect(client.client_id).toBeDefined();
|
||||
expect(client.client_secret).toBeDefined();
|
||||
expect(client.redirect_uris).toEqual(['http://localhost/callback']);
|
||||
expect(server.registeredClients.has(client.client_id)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('End-to-End: store, retrieve, expire, refresh cycle', () => {
|
||||
it('should perform full token lifecycle with real server', async () => {
|
||||
const server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 1000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
const tokenStore = new InMemoryTokenStore();
|
||||
|
||||
try {
|
||||
// 1. Get initial tokens via auth code exchange
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// 2. Store tokens
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
tokens: initial,
|
||||
createToken: tokenStore.createToken,
|
||||
});
|
||||
|
||||
// 3. Retrieve — should succeed
|
||||
const valid = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
expect(valid).not.toBeNull();
|
||||
expect(valid!.access_token).toBe(initial.access_token);
|
||||
expect(valid!.refresh_token).toBe(initial.refresh_token);
|
||||
|
||||
// 4. Wait for expiry
|
||||
await new Promise((r) => setTimeout(r, 1200));
|
||||
|
||||
// 5. Retrieve again — should trigger refresh via callback
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const refreshRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
|
||||
if (!refreshRes.ok) {
|
||||
throw new Error(`Refresh failed: ${refreshRes.status}`);
|
||||
}
|
||||
|
||||
const data = (await refreshRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
return {
|
||||
...data,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + data.expires_in * 1000,
|
||||
};
|
||||
};
|
||||
|
||||
const refreshed = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(refreshed).not.toBeNull();
|
||||
expect(refreshed!.access_token).not.toBe(initial.access_token);
|
||||
|
||||
// 6. Verify the refreshed token works against the server
|
||||
const mcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${refreshed!.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
expect(mcpRes.status).toBe(200);
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('completeOAuthFlow via FlowStateManager', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ issueRefreshTokens: true });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should exchange auth code and complete flow in FlowStateManager', async () => {
|
||||
const store = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'test-user:test-server';
|
||||
const code = await server.getAuthCode();
|
||||
|
||||
// Initialize the flow with metadata the handler needs
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverUrl: server.url,
|
||||
clientInfo: {
|
||||
client_id: 'test-client',
|
||||
redirect_uris: ['http://localhost/callback'],
|
||||
},
|
||||
codeVerifier: 'test-verifier',
|
||||
metadata: {
|
||||
token_endpoint: `${server.url}token`,
|
||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
},
|
||||
});
|
||||
|
||||
// The SDK's exchangeAuthorization wants full OAuth metadata,
|
||||
// so we'll test the token exchange directly instead of going through
|
||||
// completeOAuthFlow (which requires full SDK-compatible metadata)
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
const tokens = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
const mcpTokens: MCPOAuthTokens = {
|
||||
...tokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + tokens.expires_in * 1000,
|
||||
};
|
||||
|
||||
// Complete the flow
|
||||
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
|
||||
expect(completed).toBe(true);
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('COMPLETED');
|
||||
expect(state?.result?.access_token).toBe(tokens.access_token);
|
||||
});
|
||||
|
||||
it('should fail flow when authorization code is invalid', async () => {
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: 'grant_type=authorization_code&code=invalid-code',
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(400);
|
||||
const body = (await tokenRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
|
||||
it('should fail when authorization code is reused', async () => {
|
||||
const code = await server.getAuthCode();
|
||||
|
||||
// First exchange succeeds
|
||||
const firstRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
expect(firstRes.status).toBe(200);
|
||||
|
||||
// Second exchange fails
|
||||
const secondRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
expect(secondRes.status).toBe(400);
|
||||
const body = (await secondRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PKCE verification', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
function generatePKCE(): { verifier: string; challenge: string } {
|
||||
const verifier = 'dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk';
|
||||
const challenge = createHash('sha256').update(verifier).digest('base64url');
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
||||
it('should accept valid code_verifier matching code_challenge', async () => {
|
||||
const { verifier, challenge } = generatePKCE();
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}&code_verifier=${verifier}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
const data = (await tokenRes.json()) as { access_token: string };
|
||||
expect(data.access_token).toBeDefined();
|
||||
});
|
||||
|
||||
it('should reject wrong code_verifier', async () => {
|
||||
const { challenge } = generatePKCE();
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}&code_verifier=wrong-verifier`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(400);
|
||||
const body = (await tokenRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
|
||||
it('should reject missing code_verifier when code_challenge was provided', async () => {
|
||||
const { challenge } = generatePKCE();
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost&state=test&code_challenge=${challenge}&code_challenge_method=S256`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(400);
|
||||
const body = (await tokenRes.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_grant');
|
||||
});
|
||||
|
||||
it('should still accept codes without PKCE when no code_challenge was provided', async () => {
|
||||
const code = await server.getAuthCode();
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
});
|
||||
});
|
||||
});
|
||||
516
packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts
Normal file
516
packages/api/src/mcp/__tests__/MCPOAuthRaceCondition.test.ts
Normal file
|
|
@ -0,0 +1,516 @@
|
|||
/**
|
||||
* Tests for MCP OAuth race condition fixes:
|
||||
*
|
||||
* 1. Connection mutex coalesces concurrent getUserConnection() calls
|
||||
* 2. PENDING OAuth flows are reused, not deleted
|
||||
* 3. No-refresh-token expiry throws ReauthenticationRequiredError
|
||||
* 4. completeFlow recovers when flow state was deleted by a race
|
||||
* 5. monitorFlow retries once when flow state disappears mid-poll
|
||||
*/
|
||||
|
||||
import { Keyv } from 'keyv';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import { MCPTokenStorage, MCPOAuthHandler, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { MockKeyv, createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
jest.mock('~/auth', () => ({
|
||||
createSSRFSafeUndiciConnect: jest.fn(() => undefined),
|
||||
resolveHostnameSSRF: jest.fn(async () => false),
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/mcpConfig', () => ({
|
||||
mcpConfig: { CONNECTION_CHECK_TTL: 0, USER_CONNECTION_IDLE_TIMEOUT: 30 * 60 * 1000 },
|
||||
}));
|
||||
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
|
||||
describe('MCP OAuth Race Condition Fixes', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Fix 1: Connection mutex coalesces concurrent attempts', () => {
|
||||
it('should return the same pending promise for concurrent getUserConnection calls', async () => {
|
||||
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
|
||||
|
||||
class TestManager extends UserConnectionManager {
|
||||
public createCallCount = 0;
|
||||
|
||||
getPendingConnections() {
|
||||
return this.pendingConnections;
|
||||
}
|
||||
}
|
||||
|
||||
const manager = new TestManager();
|
||||
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
isStale: jest.fn().mockReturnValue(false),
|
||||
};
|
||||
|
||||
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
|
||||
manager.appConnections = mockAppConnections as never;
|
||||
|
||||
const mockConfig = {
|
||||
type: 'streamable-http',
|
||||
url: 'http://localhost:9999/',
|
||||
updatedAt: undefined,
|
||||
dbId: undefined,
|
||||
};
|
||||
|
||||
jest
|
||||
.spyOn(
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
|
||||
'getInstance',
|
||||
)
|
||||
.mockReturnValue({
|
||||
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
});
|
||||
|
||||
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
|
||||
const createSpy = jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
|
||||
manager.createCallCount++;
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
return mockConnection as never;
|
||||
});
|
||||
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
const user = { id: 'user-1' };
|
||||
const opts = {
|
||||
serverName: 'test-server',
|
||||
user: user as never,
|
||||
flowManager: flowManager as never,
|
||||
};
|
||||
|
||||
const [conn1, conn2, conn3] = await Promise.all([
|
||||
manager.getUserConnection(opts),
|
||||
manager.getUserConnection(opts),
|
||||
manager.getUserConnection(opts),
|
||||
]);
|
||||
|
||||
expect(conn1).toBe(conn2);
|
||||
expect(conn2).toBe(conn3);
|
||||
expect(createSpy).toHaveBeenCalledTimes(1);
|
||||
expect(manager.createCallCount).toBe(1);
|
||||
|
||||
createSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should not coalesce when forceNew is true', async () => {
|
||||
const { UserConnectionManager } = await import('~/mcp/UserConnectionManager');
|
||||
|
||||
class TestManager extends UserConnectionManager {}
|
||||
|
||||
const manager = new TestManager();
|
||||
|
||||
let callCount = 0;
|
||||
const makeConnection = () => ({
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
isStale: jest.fn().mockReturnValue(false),
|
||||
});
|
||||
|
||||
const mockAppConnections = { has: jest.fn().mockResolvedValue(false) };
|
||||
manager.appConnections = mockAppConnections as never;
|
||||
|
||||
const mockConfig = {
|
||||
type: 'streamable-http',
|
||||
url: 'http://localhost:9999/',
|
||||
updatedAt: undefined,
|
||||
dbId: undefined,
|
||||
};
|
||||
|
||||
jest
|
||||
.spyOn(
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
require('~/mcp/registry/MCPServersRegistry').MCPServersRegistry,
|
||||
'getInstance',
|
||||
)
|
||||
.mockReturnValue({
|
||||
getServerConfig: jest.fn().mockResolvedValue(mockConfig),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
});
|
||||
|
||||
const { MCPConnectionFactory } = await import('~/mcp/MCPConnectionFactory');
|
||||
jest.spyOn(MCPConnectionFactory, 'create').mockImplementation(async () => {
|
||||
callCount++;
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
return makeConnection() as never;
|
||||
});
|
||||
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
const user = { id: 'user-2' };
|
||||
|
||||
const [conn1, conn2] = await Promise.all([
|
||||
manager.getUserConnection({
|
||||
serverName: 'test-server',
|
||||
forceNew: true,
|
||||
user: user as never,
|
||||
flowManager: flowManager as never,
|
||||
}),
|
||||
manager.getUserConnection({
|
||||
serverName: 'test-server',
|
||||
forceNew: true,
|
||||
user: user as never,
|
||||
flowManager: flowManager as never,
|
||||
}),
|
||||
]);
|
||||
|
||||
expect(callCount).toBe(2);
|
||||
expect(conn1).not.toBe(conn2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fix 2: PENDING flow is reused, not deleted', () => {
|
||||
it('should join an existing PENDING flow via createFlow instead of deleting it', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'test-flow-pending';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
clientInfo: { client_id: 'test-client' },
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('PENDING');
|
||||
|
||||
const deleteSpy = jest.spyOn(flowManager, 'deleteFlow');
|
||||
|
||||
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
|
||||
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', {
|
||||
access_token: 'test-token',
|
||||
token_type: 'Bearer',
|
||||
} as never);
|
||||
|
||||
const result = await monitorPromise;
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({ access_token: 'test-token', token_type: 'Bearer' }),
|
||||
);
|
||||
expect(deleteSpy).not.toHaveBeenCalled();
|
||||
|
||||
deleteSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should delete and recreate FAILED flows', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'test-flow-failed';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
await flowManager.failFlow(flowId, 'mcp_oauth', 'previous error');
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('FAILED');
|
||||
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
const afterDelete = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(afterDelete).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fix 3: completeFlow handles deleted state gracefully', () => {
|
||||
it('should return false when state was deleted by race', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'race-deleted-flow';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
const stateBeforeComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(stateBeforeComplete).toBeUndefined();
|
||||
|
||||
const result = await flowManager.completeFlow(flowId, 'mcp_oauth', {
|
||||
access_token: 'recovered-token',
|
||||
token_type: 'Bearer',
|
||||
} as never);
|
||||
|
||||
expect(result).toBe(false);
|
||||
|
||||
const stateAfterComplete = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(stateAfterComplete).toBeUndefined();
|
||||
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('cannot recover metadata'),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject monitorFlow when state is deleted and not recoverable', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
|
||||
|
||||
const flowId = 'monitor-retry-flow';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {});
|
||||
|
||||
const monitorPromise = flowManager.createFlow(flowId, 'mcp_oauth', {});
|
||||
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
await expect(monitorPromise).rejects.toThrow('Flow state not found');
|
||||
});
|
||||
});
|
||||
|
||||
describe('State mapping cleanup on flow replacement', () => {
|
||||
it('should delete old state mapping when a flow is replaced', async () => {
|
||||
const store = new MockKeyv();
|
||||
const flowManager = new FlowStateManager<MCPOAuthTokens | null>(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
const oldState = 'old-random-state-abc123';
|
||||
const newState = 'new-random-state-xyz789';
|
||||
|
||||
// Simulate initial flow with state mapping
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { state: oldState });
|
||||
await MCPOAuthHandler.storeStateMapping(oldState, flowId, flowManager);
|
||||
|
||||
// Old state should resolve
|
||||
const resolvedBefore = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
|
||||
expect(resolvedBefore).toBe(flowId);
|
||||
|
||||
// Replace the flow: delete old, create new, clean up old state mapping
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
await MCPOAuthHandler.deleteStateMapping(oldState, flowManager);
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', { state: newState });
|
||||
await MCPOAuthHandler.storeStateMapping(newState, flowId, flowManager);
|
||||
|
||||
// Old state should no longer resolve
|
||||
const resolvedOld = await MCPOAuthHandler.resolveStateToFlowId(oldState, flowManager);
|
||||
expect(resolvedOld).toBeNull();
|
||||
|
||||
// New state should resolve
|
||||
const resolvedNew = await MCPOAuthHandler.resolveStateToFlowId(newState, flowManager);
|
||||
expect(resolvedNew).toBe(flowId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fix 4: ReauthenticationRequiredError for no-refresh-token', () => {
|
||||
it('should throw ReauthenticationRequiredError when access token expired and no refresh token', async () => {
|
||||
const expiredDate = new Date(Date.now() - 60000);
|
||||
|
||||
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
|
||||
if (filter.type === 'mcp_oauth') {
|
||||
return {
|
||||
token: 'enc:expired-access-token',
|
||||
expiresAt: expiredDate,
|
||||
createdAt: new Date(Date.now() - 120000),
|
||||
};
|
||||
}
|
||||
if (filter.type === 'mcp_oauth_refresh') {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
}),
|
||||
).rejects.toThrow('Re-authentication required');
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when access token is missing and no refresh token', async () => {
|
||||
const findToken = jest.fn().mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should not throw when access token is valid', async () => {
|
||||
const futureDate = new Date(Date.now() + 3600000);
|
||||
|
||||
const findToken = jest.fn().mockImplementation(async (filter: { type?: string }) => {
|
||||
if (filter.type === 'mcp_oauth') {
|
||||
return {
|
||||
token: 'enc:valid-access-token',
|
||||
expiresAt: futureDate,
|
||||
createdAt: new Date(),
|
||||
};
|
||||
}
|
||||
if (filter.type === 'mcp_oauth_refresh') {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
findToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.access_token).toBe('valid-access-token');
|
||||
});
|
||||
});
|
||||
|
||||
describe('E2E: OAuth-gated MCP server with no refresh tokens', () => {
|
||||
let server: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should start OAuth-gated MCP server that validates Bearer tokens', async () => {
|
||||
const res = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'initialize', id: 1 }),
|
||||
});
|
||||
|
||||
expect(res.status).toBe(401);
|
||||
const body = (await res.json()) as { error: string };
|
||||
expect(body.error).toBe('invalid_token');
|
||||
});
|
||||
|
||||
it('should issue tokens via authorization code exchange with no refresh token', async () => {
|
||||
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
|
||||
expect(authRes.status).toBe(302);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
expect(code).toBeTruthy();
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
const tokenBody = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
refresh_token?: string;
|
||||
};
|
||||
expect(tokenBody.access_token).toBeTruthy();
|
||||
expect(tokenBody.token_type).toBe('Bearer');
|
||||
expect(tokenBody.refresh_token).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should allow MCP requests with valid Bearer token', async () => {
|
||||
const authRes = await fetch(`${server.url}authorize?redirect_uri=http://localhost&state=s1`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
const { access_token } = (await tokenRes.json()) as { access_token: string };
|
||||
|
||||
const mcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
expect(mcpRes.status).toBe(200);
|
||||
});
|
||||
|
||||
it('should reject expired tokens with 401', async () => {
|
||||
const shortTTLServer = await createOAuthMCPServer({ tokenTTLMs: 500 });
|
||||
|
||||
try {
|
||||
const authRes = await fetch(
|
||||
`${shortTTLServer.url}authorize?redirect_uri=http://localhost&state=s1`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
|
||||
const tokenRes = await fetch(`${shortTTLServer.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
|
||||
const { access_token } = (await tokenRes.json()) as { access_token: string };
|
||||
|
||||
await new Promise((r) => setTimeout(r, 600));
|
||||
|
||||
const mcpRes = await fetch(shortTTLServer.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${access_token}`,
|
||||
},
|
||||
body: JSON.stringify({ jsonrpc: '2.0', method: 'ping', id: 2 }),
|
||||
});
|
||||
|
||||
expect(mcpRes.status).toBe(401);
|
||||
} finally {
|
||||
await shortTTLServer.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
654
packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts
Normal file
654
packages/api/src/mcp/__tests__/MCPOAuthTokenExpiry.test.ts
Normal file
|
|
@ -0,0 +1,654 @@
|
|||
/**
|
||||
* Tests for MCP OAuth token expiry → re-authentication scenarios.
|
||||
*
|
||||
* Reproduces the edge case where:
|
||||
* 1. Tokens are stored (access + refresh)
|
||||
* 2. Access token expires
|
||||
* 3. Refresh attempt fails (server rejects/revokes refresh token)
|
||||
* 4. System must fall back to full OAuth re-auth via handleOAuthRequired
|
||||
* 5. The CSRF cookie may be absent (chat/SSE flow), so the PENDING flow fallback is needed
|
||||
*
|
||||
* Also tests the happy path: access token expired but refresh succeeds.
|
||||
*/
|
||||
|
||||
import { Keyv } from 'keyv';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { FlowStateManager, PENDING_STALE_MS } from '~/flow/manager';
|
||||
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { MockKeyv, InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
describe('MCP OAuth Token Expiry Scenarios', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Access token expired + refresh token available + refresh succeeds', () => {
|
||||
let server: OAuthTestServer;
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 500,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should refresh expired access token via real /token endpoint', async () => {
|
||||
// Get initial tokens from real server
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// Store expired access token directly (bypassing storeTokens' expiresIn clamping)
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: `enc:${initial.access_token}`,
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: `enc:${initial.refresh_token}`,
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const res = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(`Refresh failed: ${res.status}`);
|
||||
}
|
||||
const data = (await res.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
};
|
||||
return {
|
||||
...data,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + data.expires_in * 1000,
|
||||
};
|
||||
};
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.access_token).not.toBe(initial.access_token);
|
||||
|
||||
// Verify the refreshed token works against the server
|
||||
const mcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${result!.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
expect(mcpRes.status).toBe(200);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Access token expired + refresh token rejected by server', () => {
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(() => {
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
it('should return null when refresh token is rejected (invalid_grant)', async () => {
|
||||
const server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
|
||||
try {
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// Store expired access token directly
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: `enc:${initial.access_token}`,
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: `enc:${initial.refresh_token}`,
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
// Simulate server revoking the refresh token
|
||||
server.issuedRefreshTokens.clear();
|
||||
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const res = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
if (!res.ok) {
|
||||
const body = (await res.json()) as { error: string };
|
||||
throw new Error(`Token refresh failed: ${body.error}`);
|
||||
}
|
||||
const data = (await res.json()) as MCPOAuthTokens;
|
||||
return data;
|
||||
};
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to refresh tokens'),
|
||||
expect.any(Error),
|
||||
);
|
||||
} finally {
|
||||
await server.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should return null when refresh endpoint returns unauthorized_client', async () => {
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: 'enc:some-refresh-token',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshCallback = jest
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('unauthorized_client: client not authorized for refresh'));
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('does not support refresh tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Access token expired + NO refresh token → ReauthenticationRequiredError', () => {
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(() => {
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when no refresh token stored', async () => {
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError with correct reason for expired token', async () => {
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
}),
|
||||
).rejects.toThrow('access token expired');
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError with correct reason for missing token', async () => {
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
}),
|
||||
).rejects.toThrow('access token missing');
|
||||
});
|
||||
});
|
||||
|
||||
describe('PENDING flow fallback for CSRF-less OAuth callbacks', () => {
|
||||
it('should allow OAuth completion when PENDING flow exists (simulating chat/SSE path)', async () => {
|
||||
const store = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
userId: 'user1',
|
||||
serverUrl: 'https://example.com',
|
||||
state: 'test-state',
|
||||
authorizationUrl: 'https://example.com/authorize?state=user1:test-server',
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(state?.status).toBe('PENDING');
|
||||
|
||||
const tokens: MCPOAuthTokens = {
|
||||
access_token: 'new-access-token',
|
||||
token_type: 'Bearer',
|
||||
refresh_token: 'new-refresh-token',
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + 3600000,
|
||||
};
|
||||
|
||||
const completed = await flowManager.completeFlow(flowId, 'mcp_oauth', tokens);
|
||||
expect(completed).toBe(true);
|
||||
|
||||
const completedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(completedState?.status).toBe('COMPLETED');
|
||||
expect((completedState?.result as MCPOAuthTokens | undefined)?.access_token).toBe(
|
||||
'new-access-token',
|
||||
);
|
||||
});
|
||||
|
||||
it('should store authorizationUrl in flow metadata for re-issuance', async () => {
|
||||
const store = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(store as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
const authUrl = 'https://auth.example.com/authorize?client_id=abc&state=user1:test-server';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
userId: 'user1',
|
||||
serverUrl: 'https://example.com',
|
||||
state: 'test-state',
|
||||
authorizationUrl: authUrl,
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect((state?.metadata as Record<string, unknown>)?.authorizationUrl).toBe(authUrl);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Full token expiry → refresh failure → re-auth flow', () => {
|
||||
let server: OAuthTestServer;
|
||||
let tokenStore: InMemoryTokenStore;
|
||||
|
||||
beforeEach(async () => {
|
||||
server = await createOAuthMCPServer({
|
||||
tokenTTLMs: 60000,
|
||||
issueRefreshTokens: true,
|
||||
});
|
||||
tokenStore = new InMemoryTokenStore();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await server.close();
|
||||
});
|
||||
|
||||
it('should go through full cycle: get tokens → expire → refresh fails → re-auth needed', async () => {
|
||||
// Step 1: Get initial tokens
|
||||
const code = await server.getAuthCode();
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const initial = (await tokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
// Step 2: Store tokens with valid expiry first
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
tokens: initial,
|
||||
createToken: tokenStore.createToken,
|
||||
});
|
||||
|
||||
// Step 3: Verify tokens work
|
||||
const validResult = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
expect(validResult).not.toBeNull();
|
||||
expect(validResult!.access_token).toBe(initial.access_token);
|
||||
|
||||
// Step 4: Simulate token expiry by directly updating the stored token's expiresAt
|
||||
await tokenStore.updateToken({ userId: 'u1', identifier: 'mcp:test-srv' }, { expiresIn: -1 });
|
||||
|
||||
// Step 5: Revoke refresh token on server side (simulating server-side revocation)
|
||||
server.issuedRefreshTokens.clear();
|
||||
|
||||
// Step 6: Try to get tokens — refresh should fail, return null
|
||||
const refreshCallback = async (refreshToken: string): Promise<MCPOAuthTokens> => {
|
||||
const res = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=refresh_token&refresh_token=${refreshToken}`,
|
||||
});
|
||||
if (!res.ok) {
|
||||
const body = (await res.json()) as { error: string };
|
||||
throw new Error(`Refresh failed: ${body.error}`);
|
||||
}
|
||||
const data = (await res.json()) as MCPOAuthTokens;
|
||||
return data;
|
||||
};
|
||||
|
||||
const expiredResult = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
|
||||
// Refresh failed → returns null → triggers OAuth re-auth flow
|
||||
expect(expiredResult).toBeNull();
|
||||
|
||||
// Step 7: Simulate the re-auth flow via FlowStateManager
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
const flowId = 'u1:test-srv';
|
||||
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-srv',
|
||||
userId: 'u1',
|
||||
serverUrl: server.url,
|
||||
state: 'test-state',
|
||||
authorizationUrl: `${server.url}authorize?state=${flowId}`,
|
||||
});
|
||||
|
||||
// Step 8: Get a new auth code and exchange for tokens (simulating user re-auth)
|
||||
const newCode = await server.getAuthCode();
|
||||
const newTokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${newCode}`,
|
||||
});
|
||||
const newTokens = (await newTokenRes.json()) as {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
// Step 9: Complete the flow
|
||||
const mcpTokens: MCPOAuthTokens = {
|
||||
...newTokens,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + newTokens.expires_in * 1000,
|
||||
};
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', mcpTokens);
|
||||
|
||||
// Step 10: Store the new tokens
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
tokens: mcpTokens,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
|
||||
// Step 11: Verify new tokens work
|
||||
const newResult = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
});
|
||||
expect(newResult).not.toBeNull();
|
||||
expect(newResult!.access_token).toBe(newTokens.access_token);
|
||||
|
||||
// Step 12: Verify new token works against server
|
||||
const finalMcpRes = await fetch(server.url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'application/json, text/event-stream',
|
||||
Authorization: `Bearer ${newResult!.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'initialize',
|
||||
id: 1,
|
||||
params: {
|
||||
protocolVersion: '2025-03-26',
|
||||
capabilities: {},
|
||||
clientInfo: { name: 'test', version: '0.0.1' },
|
||||
},
|
||||
}),
|
||||
});
|
||||
expect(finalMcpRes.status).toBe(200);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Concurrent token expiry with connection mutex', () => {
|
||||
it('should handle multiple concurrent getTokens calls when token is expired', async () => {
|
||||
const tokenStore = new InMemoryTokenStore();
|
||||
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:test-srv',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await tokenStore.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:test-srv:refresh',
|
||||
token: 'enc:valid-refresh',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
let refreshCallCount = 0;
|
||||
const refreshCallback = jest.fn().mockImplementation(async () => {
|
||||
refreshCallCount++;
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
return {
|
||||
access_token: `refreshed-token-${refreshCallCount}`,
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + 3600000,
|
||||
};
|
||||
});
|
||||
|
||||
// Fire 3 concurrent getTokens calls via FlowStateManager (like the connection mutex does)
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 30000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const getTokensViaFlow = () =>
|
||||
flowManager.createFlowWithHandler('u1:test-srv', 'mcp_get_tokens', async () => {
|
||||
return await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'test-srv',
|
||||
findToken: tokenStore.findToken,
|
||||
createToken: tokenStore.createToken,
|
||||
updateToken: tokenStore.updateToken,
|
||||
refreshTokens: refreshCallback,
|
||||
});
|
||||
});
|
||||
|
||||
const [r1, r2, r3] = await Promise.all([
|
||||
getTokensViaFlow(),
|
||||
getTokensViaFlow(),
|
||||
getTokensViaFlow(),
|
||||
]);
|
||||
|
||||
// All should get tokens (either directly or via flow coalescing)
|
||||
expect(r1).not.toBeNull();
|
||||
expect(r2).not.toBeNull();
|
||||
expect(r3).not.toBeNull();
|
||||
|
||||
// The refresh callback should only be called once due to flow coalescing
|
||||
expect(refreshCallback).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Stale PENDING flow detection', () => {
|
||||
it('should treat PENDING flows older than 2 minutes as stale', async () => {
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 300000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
authorizationUrl: 'https://example.com/auth',
|
||||
});
|
||||
|
||||
// Manually age the flow to 3 minutes
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (state) {
|
||||
state.createdAt = Date.now() - 3 * 60 * 1000;
|
||||
await (flowStore as unknown as { set: (k: string, v: unknown) => Promise<void> }).set(
|
||||
`mcp_oauth:${flowId}`,
|
||||
state,
|
||||
);
|
||||
}
|
||||
|
||||
const agedState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
expect(agedState?.status).toBe('PENDING');
|
||||
|
||||
const age = agedState?.createdAt ? Date.now() - agedState.createdAt : 0;
|
||||
expect(age).toBeGreaterThan(2 * 60 * 1000);
|
||||
|
||||
// A new flow should be created (the stale one would be deleted + recreated)
|
||||
// This verifies our staleness check threshold
|
||||
expect(age > PENDING_STALE_MS).toBe(true);
|
||||
});
|
||||
|
||||
it('should not treat recent PENDING flows as stale', async () => {
|
||||
const flowStore = new MockKeyv<MCPOAuthTokens | null>();
|
||||
const flowManager = new FlowStateManager(flowStore as unknown as Keyv, {
|
||||
ttl: 300000,
|
||||
ci: true,
|
||||
});
|
||||
|
||||
const flowId = 'user1:test-server';
|
||||
await flowManager.initFlow(flowId, 'mcp_oauth', {
|
||||
serverName: 'test-server',
|
||||
authorizationUrl: 'https://example.com/auth',
|
||||
});
|
||||
|
||||
const state = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
const age = state?.createdAt ? Date.now() - state.createdAt : Infinity;
|
||||
|
||||
expect(age < PENDING_STALE_MS).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
544
packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts
Normal file
544
packages/api/src/mcp/__tests__/MCPOAuthTokenStorage.test.ts
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
/**
|
||||
* Integration tests for MCPTokenStorage.storeTokens() and MCPTokenStorage.getTokens().
|
||||
*
|
||||
* Uses InMemoryTokenStore to exercise encrypt/decrypt round-trips, expiry calculation,
|
||||
* refresh callback wiring, and ReauthenticationRequiredError paths.
|
||||
*/
|
||||
|
||||
import { MCPTokenStorage, ReauthenticationRequiredError } from '~/mcp/oauth';
|
||||
import { InMemoryTokenStore } from './helpers/oauthTestServer';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
encryptV2: jest.fn(async (val: string) => `enc:${val}`),
|
||||
decryptV2: jest.fn(async (val: string) => val.replace(/^enc:/, '')),
|
||||
}));
|
||||
|
||||
describe('MCPTokenStorage', () => {
|
||||
let store: InMemoryTokenStore;
|
||||
|
||||
beforeEach(() => {
|
||||
store = new InMemoryTokenStore();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('storeTokens', () => {
|
||||
it('should create new access token with expires_in', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved).not.toBeNull();
|
||||
expect(saved!.token).toBe('enc:at1');
|
||||
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
|
||||
expect(expiresInMs).toBeGreaterThan(3500 * 1000);
|
||||
expect(expiresInMs).toBeLessThanOrEqual(3600 * 1000);
|
||||
});
|
||||
|
||||
it('should create new access token with expires_at (MCPOAuthTokens format)', async () => {
|
||||
const expiresAt = Date.now() + 7200 * 1000;
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'at1',
|
||||
token_type: 'Bearer',
|
||||
expires_at: expiresAt,
|
||||
obtained_at: Date.now(),
|
||||
},
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved).not.toBeNull();
|
||||
const diff = Math.abs(saved!.expiresAt.getTime() - expiresAt);
|
||||
expect(diff).toBeLessThan(2000);
|
||||
});
|
||||
|
||||
it('should default to 1-year expiry when none provided', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer' },
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
|
||||
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
|
||||
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
|
||||
});
|
||||
|
||||
it('should update existing access token', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:old-token',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'new-token', token_type: 'Bearer', expires_in: 7200 },
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved!.token).toBe('enc:new-token');
|
||||
});
|
||||
|
||||
it('should store refresh token alongside access token', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'at1',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'rt1',
|
||||
},
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const refreshSaved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
});
|
||||
expect(refreshSaved).not.toBeNull();
|
||||
expect(refreshSaved!.token).toBe('enc:rt1');
|
||||
});
|
||||
|
||||
it('should skip refresh token when not in response', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const refreshSaved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
});
|
||||
expect(refreshSaved).toBeNull();
|
||||
});
|
||||
|
||||
it('should store client info when provided', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
clientInfo: { client_id: 'cid', client_secret: 'csec', redirect_uris: [] },
|
||||
});
|
||||
|
||||
const clientSaved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: 'mcp:srv1:client',
|
||||
});
|
||||
expect(clientSaved).not.toBeNull();
|
||||
expect(clientSaved!.token).toContain('enc:');
|
||||
expect(clientSaved!.token).toContain('cid');
|
||||
});
|
||||
|
||||
it('should use existingTokens to skip DB lookups', async () => {
|
||||
const findSpy = jest.fn();
|
||||
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: { access_token: 'at1', token_type: 'Bearer', expires_in: 3600 },
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
findToken: findSpy,
|
||||
existingTokens: {
|
||||
accessToken: null,
|
||||
refreshToken: null,
|
||||
clientInfoToken: null,
|
||||
},
|
||||
});
|
||||
|
||||
expect(findSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle invalid NaN expiry date', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'at1',
|
||||
token_type: 'Bearer',
|
||||
expires_at: NaN,
|
||||
obtained_at: Date.now(),
|
||||
},
|
||||
createToken: store.createToken,
|
||||
});
|
||||
|
||||
const saved = await store.findToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
});
|
||||
expect(saved).not.toBeNull();
|
||||
const oneYearMs = 365 * 24 * 60 * 60 * 1000;
|
||||
const expiresInMs = saved!.expiresAt.getTime() - Date.now();
|
||||
expect(expiresInMs).toBeGreaterThan(oneYearMs - 5000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokens', () => {
|
||||
it('should return valid non-expired tokens', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:valid-token',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.access_token).toBe('valid-token');
|
||||
expect(result!.token_type).toBe('Bearer');
|
||||
});
|
||||
|
||||
it('should return tokens with refresh token when available', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:at',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result!.refresh_token).toBe('rt');
|
||||
});
|
||||
|
||||
it('should return tokens without refresh token field when none stored', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:at',
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result!.refresh_token).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when expired and no refresh', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should throw ReauthenticationRequiredError when missing and no refresh', async () => {
|
||||
await expect(
|
||||
MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
}),
|
||||
).rejects.toThrow(ReauthenticationRequiredError);
|
||||
});
|
||||
|
||||
it('should refresh expired access token when refresh token and callback are available', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockResolvedValue({
|
||||
access_token: 'refreshed-at',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
obtained_at: Date.now(),
|
||||
expires_at: Date.now() + 3600000,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.access_token).toBe('refreshed-at');
|
||||
expect(refreshTokens).toHaveBeenCalledWith(
|
||||
'rt',
|
||||
expect.objectContaining({ userId: 'u1', serverName: 'srv1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null when refresh fails', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockRejectedValue(new Error('refresh failed'));
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when no refreshTokens callback provided', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when no createToken callback provided', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
refreshTokens: jest.fn(),
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should pass client info to refreshTokens metadata', async () => {
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_client',
|
||||
identifier: 'mcp:srv1:client',
|
||||
token: 'enc:{"client_id":"cid","client_secret":"csec"}',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockResolvedValue({
|
||||
access_token: 'new-at',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
});
|
||||
|
||||
await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
updateToken: store.updateToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(refreshTokens).toHaveBeenCalledWith(
|
||||
'rt',
|
||||
expect.objectContaining({
|
||||
clientInfo: expect.objectContaining({ client_id: 'cid' }),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle unauthorized_client refresh error', async () => {
|
||||
const { logger } = await import('@librechat/data-schemas');
|
||||
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth',
|
||||
identifier: 'mcp:srv1',
|
||||
token: 'enc:expired-token',
|
||||
expiresIn: -1,
|
||||
});
|
||||
await store.createToken({
|
||||
userId: 'u1',
|
||||
type: 'mcp_oauth_refresh',
|
||||
identifier: 'mcp:srv1:refresh',
|
||||
token: 'enc:rt',
|
||||
expiresIn: 86400,
|
||||
});
|
||||
|
||||
const refreshTokens = jest.fn().mockRejectedValue(new Error('unauthorized_client'));
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
createToken: store.createToken,
|
||||
refreshTokens,
|
||||
});
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('does not support refresh tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('storeTokens + getTokens round-trip', () => {
|
||||
it('should store and retrieve tokens with full encrypt/decrypt cycle', async () => {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
tokens: {
|
||||
access_token: 'my-access-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'my-refresh-token',
|
||||
},
|
||||
createToken: store.createToken,
|
||||
clientInfo: { client_id: 'cid', client_secret: 'sec', redirect_uris: [] },
|
||||
});
|
||||
|
||||
const result = await MCPTokenStorage.getTokens({
|
||||
userId: 'u1',
|
||||
serverName: 'srv1',
|
||||
findToken: store.findToken,
|
||||
});
|
||||
|
||||
expect(result!.access_token).toBe('my-access-token');
|
||||
expect(result!.refresh_token).toBe('my-refresh-token');
|
||||
expect(result!.token_type).toBe('Bearer');
|
||||
expect(result!.obtained_at).toBeDefined();
|
||||
expect(result!.expires_at).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1439,5 +1439,292 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => {
|
|||
}),
|
||||
);
|
||||
});
|
||||
|
||||
describe('path-based URL origin fallback', () => {
|
||||
it('retries with origin URL when path-based discovery fails (stored clientInfo path)', async () => {
|
||||
const metadata = {
|
||||
serverName: 'sentry',
|
||||
serverUrl: 'https://mcp.sentry.dev/mcp',
|
||||
clientInfo: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
},
|
||||
};
|
||||
|
||||
const originMetadata = {
|
||||
issuer: 'https://mcp.sentry.dev/',
|
||||
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
|
||||
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
|
||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
response_types_supported: ['code'],
|
||||
jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json',
|
||||
subject_types_supported: ['public'],
|
||||
id_token_signing_alg_values_supported: ['RS256'],
|
||||
} as AuthorizationServerMetadata;
|
||||
|
||||
// First call (path-based URL) fails, second call (origin URL) succeeds
|
||||
mockDiscoverAuthorizationServerMetadata
|
||||
.mockResolvedValueOnce(undefined)
|
||||
.mockResolvedValueOnce(originMetadata);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
access_token: 'new-access-token',
|
||||
refresh_token: 'new-refresh-token',
|
||||
expires_in: 3600,
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const result = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
'test-refresh-token',
|
||||
metadata,
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
// Discovery attempted twice: once with path URL, once with origin URL
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.any(URL),
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.any(URL),
|
||||
expect.any(Object),
|
||||
);
|
||||
const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL;
|
||||
const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL;
|
||||
expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp');
|
||||
expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/');
|
||||
|
||||
// Token endpoint from origin discovery metadata is used (string in stored-clientInfo branch)
|
||||
expect(mockFetch).toHaveBeenCalled();
|
||||
const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0];
|
||||
expect(typeof fetchUrl).toBe('string');
|
||||
expect(fetchUrl).toBe('https://mcp.sentry.dev/oauth/token');
|
||||
expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' }));
|
||||
expect(result.access_token).toBe('new-access-token');
|
||||
});
|
||||
|
||||
it('retries with origin URL when path-based discovery fails (no stored clientInfo)', async () => {
|
||||
// No clientInfo — uses the auto-discovered branch
|
||||
const metadata = {
|
||||
serverName: 'sentry',
|
||||
serverUrl: 'https://mcp.sentry.dev/mcp',
|
||||
};
|
||||
|
||||
const originMetadata = {
|
||||
issuer: 'https://mcp.sentry.dev/',
|
||||
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
|
||||
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
|
||||
response_types_supported: ['code'],
|
||||
jwks_uri: 'https://mcp.sentry.dev/.well-known/jwks.json',
|
||||
subject_types_supported: ['public'],
|
||||
id_token_signing_alg_values_supported: ['RS256'],
|
||||
} as AuthorizationServerMetadata;
|
||||
|
||||
// First call (path-based URL) fails, second call (origin URL) succeeds
|
||||
mockDiscoverAuthorizationServerMetadata
|
||||
.mockResolvedValueOnce(undefined)
|
||||
.mockResolvedValueOnce(originMetadata);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
access_token: 'new-access-token',
|
||||
refresh_token: 'new-refresh-token',
|
||||
expires_in: 3600,
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const result = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
'test-refresh-token',
|
||||
metadata,
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
// Discovery attempted twice: once with path URL, once with origin URL
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.any(URL),
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.any(URL),
|
||||
expect.any(Object),
|
||||
);
|
||||
const firstDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[0][0] as URL;
|
||||
const secondDiscoveryUrl = mockDiscoverAuthorizationServerMetadata.mock.calls[1][0] as URL;
|
||||
expect(firstDiscoveryUrl.href).toBe('https://mcp.sentry.dev/mcp');
|
||||
expect(secondDiscoveryUrl.href).toBe('https://mcp.sentry.dev/');
|
||||
|
||||
// Token endpoint from origin discovery metadata is used (URL object in auto-discovered branch)
|
||||
expect(mockFetch).toHaveBeenCalled();
|
||||
const [fetchUrl, fetchOptions] = mockFetch.mock.calls[0];
|
||||
expect(fetchUrl).toBeInstanceOf(URL);
|
||||
expect(fetchUrl.toString()).toBe('https://mcp.sentry.dev/oauth/token');
|
||||
expect(fetchOptions).toEqual(expect.objectContaining({ method: 'POST' }));
|
||||
expect(result.access_token).toBe('new-access-token');
|
||||
});
|
||||
|
||||
it('falls back to /token when both path and origin discovery fail', async () => {
|
||||
const metadata = {
|
||||
serverName: 'sentry',
|
||||
serverUrl: 'https://mcp.sentry.dev/mcp',
|
||||
clientInfo: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
},
|
||||
};
|
||||
|
||||
// Both path AND origin discovery return undefined
|
||||
mockDiscoverAuthorizationServerMetadata
|
||||
.mockResolvedValueOnce(undefined)
|
||||
.mockResolvedValueOnce(undefined);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
access_token: 'new-access-token',
|
||||
refresh_token: 'new-refresh-token',
|
||||
expires_in: 3600,
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const result = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
'test-refresh-token',
|
||||
metadata,
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Falls back to /token relative to server URL origin
|
||||
const [fetchUrl] = mockFetch.mock.calls[0];
|
||||
expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/token');
|
||||
expect(result.access_token).toBe('new-access-token');
|
||||
});
|
||||
|
||||
it('does not retry with origin when server URL has no path (root URL)', async () => {
|
||||
const metadata = {
|
||||
serverName: 'test-server',
|
||||
serverUrl: 'https://auth.example.com/',
|
||||
clientInfo: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
},
|
||||
};
|
||||
|
||||
// Root URL discovery fails — no retry
|
||||
mockDiscoverAuthorizationServerMetadata.mockResolvedValueOnce(undefined);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ access_token: 'new-token', expires_in: 3600 }),
|
||||
} as Response);
|
||||
|
||||
await MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {});
|
||||
|
||||
// Only one discovery attempt for a root URL
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('retries with origin when path-based discovery throws', async () => {
|
||||
const metadata = {
|
||||
serverName: 'sentry',
|
||||
serverUrl: 'https://mcp.sentry.dev/mcp',
|
||||
clientInfo: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
},
|
||||
};
|
||||
|
||||
const originMetadata = {
|
||||
issuer: 'https://mcp.sentry.dev/',
|
||||
authorization_endpoint: 'https://mcp.sentry.dev/oauth/authorize',
|
||||
token_endpoint: 'https://mcp.sentry.dev/oauth/token',
|
||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
response_types_supported: ['code'],
|
||||
} as AuthorizationServerMetadata;
|
||||
|
||||
// First call throws, second call succeeds
|
||||
mockDiscoverAuthorizationServerMetadata
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockResolvedValueOnce(originMetadata);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
access_token: 'new-access-token',
|
||||
refresh_token: 'new-refresh-token',
|
||||
expires_in: 3600,
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
const result = await MCPOAuthHandler.refreshOAuthTokens(
|
||||
'test-refresh-token',
|
||||
metadata,
|
||||
{},
|
||||
{},
|
||||
);
|
||||
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
|
||||
const [fetchUrl] = mockFetch.mock.calls[0];
|
||||
expect(String(fetchUrl)).toBe('https://mcp.sentry.dev/oauth/token');
|
||||
expect(result.access_token).toBe('new-access-token');
|
||||
});
|
||||
|
||||
it('propagates the throw when root URL discovery throws', async () => {
|
||||
const metadata = {
|
||||
serverName: 'test-server',
|
||||
serverUrl: 'https://auth.example.com/',
|
||||
clientInfo: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
},
|
||||
};
|
||||
|
||||
mockDiscoverAuthorizationServerMetadata.mockRejectedValueOnce(
|
||||
new Error('Discovery failed'),
|
||||
);
|
||||
|
||||
await expect(
|
||||
MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}),
|
||||
).rejects.toThrow('Discovery failed');
|
||||
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('propagates the throw when both path and origin discovery throw', async () => {
|
||||
const metadata = {
|
||||
serverName: 'sentry',
|
||||
serverUrl: 'https://mcp.sentry.dev/mcp',
|
||||
clientInfo: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
},
|
||||
};
|
||||
|
||||
mockDiscoverAuthorizationServerMetadata
|
||||
.mockRejectedValueOnce(new Error('Network error'))
|
||||
.mockRejectedValueOnce(new Error('Origin also failed'));
|
||||
|
||||
await expect(
|
||||
MCPOAuthHandler.refreshOAuthTokens('test-refresh-token', metadata, {}, {}),
|
||||
).rejects.toThrow('Origin also failed');
|
||||
|
||||
expect(mockDiscoverAuthorizationServerMetadata).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
449
packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts
Normal file
449
packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import { randomUUID, createHash } from 'crypto';
|
||||
import { z } from 'zod';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import type { FlowState } from '~/flow/types';
|
||||
import type { Socket } from 'net';
|
||||
|
||||
export class MockKeyv<T = unknown> {
|
||||
private store: Map<string, FlowState<T>>;
|
||||
|
||||
constructor() {
|
||||
this.store = new Map();
|
||||
}
|
||||
|
||||
async get(key: string): Promise<FlowState<T> | undefined> {
|
||||
return this.store.get(key);
|
||||
}
|
||||
|
||||
async set(key: string, value: FlowState<T>, _ttl?: number): Promise<true> {
|
||||
this.store.set(key, value);
|
||||
return true;
|
||||
}
|
||||
|
||||
async delete(key: string): Promise<boolean> {
|
||||
return this.store.delete(key);
|
||||
}
|
||||
}
|
||||
|
||||
export function getFreePort(): Promise<number> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const srv = net.createServer();
|
||||
srv.listen(0, '127.0.0.1', () => {
|
||||
const addr = srv.address() as net.AddressInfo;
|
||||
srv.close((err) => (err ? reject(err) : resolve(addr.port)));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export function trackSockets(httpServer: http.Server): () => Promise<void> {
|
||||
const sockets = new Set<Socket>();
|
||||
httpServer.on('connection', (socket: Socket) => {
|
||||
sockets.add(socket);
|
||||
socket.once('close', () => sockets.delete(socket));
|
||||
});
|
||||
return () =>
|
||||
new Promise<void>((resolve) => {
|
||||
for (const socket of sockets) {
|
||||
socket.destroy();
|
||||
}
|
||||
sockets.clear();
|
||||
httpServer.close(() => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
export interface OAuthTestServerOptions {
|
||||
tokenTTLMs?: number;
|
||||
issueRefreshTokens?: boolean;
|
||||
refreshTokenTTLMs?: number;
|
||||
rotateRefreshTokens?: boolean;
|
||||
}
|
||||
|
||||
export interface OAuthTestServer {
|
||||
url: string;
|
||||
port: number;
|
||||
close: () => Promise<void>;
|
||||
issuedTokens: Set<string>;
|
||||
tokenTTL: number;
|
||||
tokenIssueTimes: Map<string, number>;
|
||||
issuedRefreshTokens: Map<string, string>;
|
||||
registeredClients: Map<string, { client_id: string; client_secret: string }>;
|
||||
getAuthCode: () => Promise<string>;
|
||||
}
|
||||
|
||||
async function readRequestBody(req: http.IncomingMessage): Promise<string> {
|
||||
const chunks: Uint8Array[] = [];
|
||||
for await (const chunk of req) {
|
||||
chunks.push(chunk as Uint8Array);
|
||||
}
|
||||
return Buffer.concat(chunks).toString();
|
||||
}
|
||||
|
||||
function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null {
|
||||
if (contentType?.includes('application/x-www-form-urlencoded')) {
|
||||
return new URLSearchParams(body);
|
||||
}
|
||||
if (contentType?.includes('application/json')) {
|
||||
const json = JSON.parse(body) as Record<string, string>;
|
||||
return new URLSearchParams(json);
|
||||
}
|
||||
return new URLSearchParams(body);
|
||||
}
|
||||
|
||||
export async function createOAuthMCPServer(
|
||||
options: OAuthTestServerOptions = {},
|
||||
): Promise<OAuthTestServer> {
|
||||
const {
|
||||
tokenTTLMs = 60000,
|
||||
issueRefreshTokens = false,
|
||||
refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000,
|
||||
rotateRefreshTokens = false,
|
||||
} = options;
|
||||
|
||||
const sessions = new Map<string, StreamableHTTPServerTransport>();
|
||||
const issuedTokens = new Set<string>();
|
||||
const tokenIssueTimes = new Map<string, number>();
|
||||
const issuedRefreshTokens = new Map<string, string>();
|
||||
const refreshTokenIssueTimes = new Map<string, number>();
|
||||
const authCodes = new Map<string, { codeChallenge?: string; codeChallengeMethod?: string }>();
|
||||
const registeredClients = new Map<string, { client_id: string; client_secret: string }>();
|
||||
|
||||
let port = 0;
|
||||
|
||||
const httpServer = http.createServer(async (req, res) => {
|
||||
const url = new URL(req.url ?? '/', `http://${req.headers.host}`);
|
||||
|
||||
if (url.pathname === '/.well-known/oauth-authorization-server' && req.method === 'GET') {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(
|
||||
JSON.stringify({
|
||||
issuer: `http://127.0.0.1:${port}`,
|
||||
authorization_endpoint: `http://127.0.0.1:${port}/authorize`,
|
||||
token_endpoint: `http://127.0.0.1:${port}/token`,
|
||||
registration_endpoint: `http://127.0.0.1:${port}/register`,
|
||||
response_types_supported: ['code'],
|
||||
grant_types_supported: issueRefreshTokens
|
||||
? ['authorization_code', 'refresh_token']
|
||||
: ['authorization_code'],
|
||||
token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post'],
|
||||
code_challenge_methods_supported: ['S256'],
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (url.pathname === '/register' && req.method === 'POST') {
|
||||
const body = await readRequestBody(req);
|
||||
const data = JSON.parse(body) as { redirect_uris?: string[] };
|
||||
const clientId = `client-${randomUUID().slice(0, 8)}`;
|
||||
const clientSecret = `secret-${randomUUID()}`;
|
||||
registeredClients.set(clientId, { client_id: clientId, client_secret: clientSecret });
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(
|
||||
JSON.stringify({
|
||||
client_id: clientId,
|
||||
client_secret: clientSecret,
|
||||
redirect_uris: data.redirect_uris ?? [],
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (url.pathname === '/authorize') {
|
||||
const code = randomUUID();
|
||||
const codeChallenge = url.searchParams.get('code_challenge') ?? undefined;
|
||||
const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined;
|
||||
authCodes.set(code, { codeChallenge, codeChallengeMethod });
|
||||
const redirectUri = url.searchParams.get('redirect_uri') ?? '';
|
||||
const state = url.searchParams.get('state') ?? '';
|
||||
res.writeHead(302, {
|
||||
Location: `${redirectUri}?code=${code}&state=${state}`,
|
||||
});
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
|
||||
if (url.pathname === '/token' && req.method === 'POST') {
|
||||
const body = await readRequestBody(req);
|
||||
const params = parseTokenRequest(body, req.headers['content-type']);
|
||||
if (!params) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_request' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const grantType = params.get('grant_type');
|
||||
|
||||
if (grantType === 'authorization_code') {
|
||||
const code = params.get('code');
|
||||
const codeData = code ? authCodes.get(code) : undefined;
|
||||
if (!code || !codeData) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
|
||||
if (codeData.codeChallenge) {
|
||||
const codeVerifier = params.get('code_verifier');
|
||||
if (!codeVerifier) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
if (codeData.codeChallengeMethod === 'S256') {
|
||||
const expected = createHash('sha256').update(codeVerifier).digest('base64url');
|
||||
if (expected !== codeData.codeChallenge) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
authCodes.delete(code);
|
||||
|
||||
const accessToken = randomUUID();
|
||||
issuedTokens.add(accessToken);
|
||||
tokenIssueTimes.set(accessToken, Date.now());
|
||||
|
||||
const tokenResponse: Record<string, string | number> = {
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
expires_in: Math.ceil(tokenTTLMs / 1000),
|
||||
};
|
||||
|
||||
if (issueRefreshTokens) {
|
||||
const refreshToken = randomUUID();
|
||||
issuedRefreshTokens.set(refreshToken, accessToken);
|
||||
refreshTokenIssueTimes.set(refreshToken, Date.now());
|
||||
tokenResponse.refresh_token = refreshToken;
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(tokenResponse));
|
||||
return;
|
||||
}
|
||||
|
||||
if (grantType === 'refresh_token' && issueRefreshTokens) {
|
||||
const refreshToken = params.get('refresh_token');
|
||||
if (!refreshToken || !issuedRefreshTokens.has(refreshToken)) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const issueTime = refreshTokenIssueTimes.get(refreshToken) ?? 0;
|
||||
if (Date.now() - issueTime > refreshTokenTTLMs) {
|
||||
issuedRefreshTokens.delete(refreshToken);
|
||||
refreshTokenIssueTimes.delete(refreshToken);
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_grant' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const newAccessToken = randomUUID();
|
||||
issuedTokens.add(newAccessToken);
|
||||
tokenIssueTimes.set(newAccessToken, Date.now());
|
||||
|
||||
const tokenResponse: Record<string, string | number> = {
|
||||
access_token: newAccessToken,
|
||||
token_type: 'Bearer',
|
||||
expires_in: Math.ceil(tokenTTLMs / 1000),
|
||||
};
|
||||
|
||||
if (rotateRefreshTokens) {
|
||||
issuedRefreshTokens.delete(refreshToken);
|
||||
refreshTokenIssueTimes.delete(refreshToken);
|
||||
const newRefreshToken = randomUUID();
|
||||
issuedRefreshTokens.set(newRefreshToken, newAccessToken);
|
||||
refreshTokenIssueTimes.set(newRefreshToken, Date.now());
|
||||
tokenResponse.refresh_token = newRefreshToken;
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(tokenResponse));
|
||||
return;
|
||||
}
|
||||
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'unsupported_grant_type' }));
|
||||
return;
|
||||
}
|
||||
|
||||
// All other paths require Bearer token auth
|
||||
const authHeader = req.headers.authorization;
|
||||
if (!authHeader || !authHeader.startsWith('Bearer ')) {
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_token' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const token = authHeader.slice(7);
|
||||
if (!issuedTokens.has(token)) {
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_token' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const issueTime = tokenIssueTimes.get(token) ?? 0;
|
||||
if (Date.now() - issueTime > tokenTTLMs) {
|
||||
issuedTokens.delete(token);
|
||||
tokenIssueTimes.delete(token);
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_token' }));
|
||||
return;
|
||||
}
|
||||
|
||||
// Authenticated MCP request — route to transport
|
||||
const sid = req.headers['mcp-session-id'] as string | undefined;
|
||||
let transport = sid ? sessions.get(sid) : undefined;
|
||||
|
||||
if (!transport) {
|
||||
transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID(),
|
||||
});
|
||||
const mcp = new McpServer({ name: 'oauth-test-server', version: '0.0.1' });
|
||||
mcp.tool('echo', { message: z.string() }, async (args) => ({
|
||||
content: [{ type: 'text' as const, text: `echo: ${args.message}` }],
|
||||
}));
|
||||
await mcp.connect(transport);
|
||||
}
|
||||
|
||||
await transport.handleRequest(req, res);
|
||||
|
||||
if (transport.sessionId && !sessions.has(transport.sessionId)) {
|
||||
sessions.set(transport.sessionId, transport);
|
||||
transport.onclose = () => sessions.delete(transport!.sessionId!);
|
||||
}
|
||||
});
|
||||
|
||||
const destroySockets = trackSockets(httpServer);
|
||||
port = await getFreePort();
|
||||
await new Promise<void>((resolve) => httpServer.listen(port, '127.0.0.1', resolve));
|
||||
|
||||
return {
|
||||
url: `http://127.0.0.1:${port}/`,
|
||||
port,
|
||||
issuedTokens,
|
||||
tokenTTL: tokenTTLMs,
|
||||
tokenIssueTimes,
|
||||
issuedRefreshTokens,
|
||||
registeredClients,
|
||||
getAuthCode: async () => {
|
||||
const authRes = await fetch(
|
||||
`http://127.0.0.1:${port}/authorize?redirect_uri=http://localhost&state=test`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
return new URL(location).searchParams.get('code') ?? '';
|
||||
},
|
||||
close: async () => {
|
||||
const closing = [...sessions.values()].map((t) => t.close().catch(() => undefined));
|
||||
sessions.clear();
|
||||
await Promise.all(closing);
|
||||
await destroySockets();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export interface InMemoryToken {
|
||||
userId: string;
|
||||
type: string;
|
||||
identifier: string;
|
||||
token: string;
|
||||
expiresAt: Date;
|
||||
createdAt: Date;
|
||||
metadata?: Map<string, unknown> | Record<string, unknown>;
|
||||
}
|
||||
|
||||
export class InMemoryTokenStore {
|
||||
private tokens: Map<string, InMemoryToken> = new Map();
|
||||
|
||||
private key(filter: { userId?: string; type?: string; identifier?: string }): string {
|
||||
return `${filter.userId}:${filter.type}:${filter.identifier}`;
|
||||
}
|
||||
|
||||
findToken = async (filter: {
|
||||
userId?: string;
|
||||
type?: string;
|
||||
identifier?: string;
|
||||
}): Promise<InMemoryToken | null> => {
|
||||
for (const token of this.tokens.values()) {
|
||||
const matchUserId = !filter.userId || token.userId === filter.userId;
|
||||
const matchType = !filter.type || token.type === filter.type;
|
||||
const matchIdentifier = !filter.identifier || token.identifier === filter.identifier;
|
||||
if (matchUserId && matchType && matchIdentifier) {
|
||||
return token;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
createToken = async (data: {
|
||||
userId: string;
|
||||
type: string;
|
||||
identifier: string;
|
||||
token: string;
|
||||
expiresIn?: number;
|
||||
metadata?: Record<string, unknown>;
|
||||
}): Promise<InMemoryToken> => {
|
||||
const expiresIn = data.expiresIn ?? 365 * 24 * 60 * 60;
|
||||
const token: InMemoryToken = {
|
||||
userId: data.userId,
|
||||
type: data.type,
|
||||
identifier: data.identifier,
|
||||
token: data.token,
|
||||
expiresAt: new Date(Date.now() + expiresIn * 1000),
|
||||
createdAt: new Date(),
|
||||
metadata: data.metadata,
|
||||
};
|
||||
this.tokens.set(this.key(data), token);
|
||||
return token;
|
||||
};
|
||||
|
||||
updateToken = async (
|
||||
filter: { userId?: string; type?: string; identifier?: string },
|
||||
data: {
|
||||
userId?: string;
|
||||
type?: string;
|
||||
identifier?: string;
|
||||
token?: string;
|
||||
expiresIn?: number;
|
||||
metadata?: Record<string, unknown>;
|
||||
},
|
||||
): Promise<InMemoryToken> => {
|
||||
const existing = await this.findToken(filter);
|
||||
if (!existing) {
|
||||
throw new Error(`Token not found for filter: ${JSON.stringify(filter)}`);
|
||||
}
|
||||
const existingKey = this.key(existing);
|
||||
const expiresIn =
|
||||
data.expiresIn ?? Math.floor((existing.expiresAt.getTime() - Date.now()) / 1000);
|
||||
const updated: InMemoryToken = {
|
||||
...existing,
|
||||
token: data.token ?? existing.token,
|
||||
expiresAt: data.expiresIn ? new Date(Date.now() + expiresIn * 1000) : existing.expiresAt,
|
||||
metadata: data.metadata ?? existing.metadata,
|
||||
};
|
||||
this.tokens.set(existingKey, updated);
|
||||
return updated;
|
||||
};
|
||||
|
||||
deleteToken = async (filter: {
|
||||
userId: string;
|
||||
type: string;
|
||||
identifier: string;
|
||||
}): Promise<void> => {
|
||||
this.tokens.delete(this.key(filter));
|
||||
};
|
||||
|
||||
getAll(): InMemoryToken[] {
|
||||
return [...this.tokens.values()];
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.tokens.clear();
|
||||
}
|
||||
}
|
||||
668
packages/api/src/mcp/__tests__/reconnection-storm.test.ts
Normal file
668
packages/api/src/mcp/__tests__/reconnection-storm.test.ts
Normal file
|
|
@ -0,0 +1,668 @@
|
|||
/**
|
||||
* Reconnection storm regression tests for PR #12162.
|
||||
*
|
||||
* Validates circuit breaker, throttling, cooldown, and timeout fixes using real
|
||||
* MCP SDK transports (no mocked stubs). A real StreamableHTTP server is spun up
|
||||
* per test suite and MCPConnection talks to it through a genuine HTTP stack.
|
||||
*/
|
||||
import http from 'http';
|
||||
import { randomUUID } from 'crypto';
|
||||
import express from 'express';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import type { Socket } from 'net';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import { OAuthReconnectionTracker } from '~/mcp/oauth/OAuthReconnectionTracker';
|
||||
import { createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { mcpConfig } from '~/mcp/mcpConfig';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
interface TestServer {
|
||||
url: string;
|
||||
httpServer: http.Server;
|
||||
close: () => Promise<void>;
|
||||
}
|
||||
|
||||
function trackSockets(httpServer: http.Server): () => Promise<void> {
|
||||
const sockets = new Set<Socket>();
|
||||
httpServer.on('connection', (socket: Socket) => {
|
||||
sockets.add(socket);
|
||||
socket.once('close', () => sockets.delete(socket));
|
||||
});
|
||||
return () =>
|
||||
new Promise<void>((resolve) => {
|
||||
for (const socket of sockets) {
|
||||
socket.destroy();
|
||||
}
|
||||
sockets.clear();
|
||||
httpServer.close(() => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
function startMCPServer(): Promise<TestServer> {
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
|
||||
const transports: Record<string, StreamableHTTPServerTransport> = {};
|
||||
|
||||
function createServer(): McpServer {
|
||||
const server = new McpServer({ name: 'test-server', version: '1.0.0' });
|
||||
server.tool('echo', 'echoes input', { message: { type: 'string' } as never }, async (args) => {
|
||||
const msg = (args as Record<string, string>).message ?? '';
|
||||
return { content: [{ type: 'text', text: msg }] };
|
||||
});
|
||||
return server;
|
||||
}
|
||||
|
||||
app.all('/mcp', async (req, res) => {
|
||||
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
||||
|
||||
if (sessionId && transports[sessionId]) {
|
||||
await transports[sessionId].handleRequest(req, res, req.body);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!sessionId && isInitializeRequest(req.body)) {
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID(),
|
||||
onsessioninitialized: (sid) => {
|
||||
transports[sid] = transport;
|
||||
},
|
||||
});
|
||||
transport.onclose = () => {
|
||||
const sid = transport.sessionId;
|
||||
if (sid) {
|
||||
delete transports[sid];
|
||||
}
|
||||
};
|
||||
const server = createServer();
|
||||
await server.connect(transport);
|
||||
await transport.handleRequest(req, res, req.body);
|
||||
return;
|
||||
}
|
||||
|
||||
if (req.method === 'GET') {
|
||||
res.status(404).send('Not Found');
|
||||
return;
|
||||
}
|
||||
|
||||
res.status(400).json({
|
||||
jsonrpc: '2.0',
|
||||
error: { code: -32000, message: 'Bad Request: No valid session ID provided' },
|
||||
id: null,
|
||||
});
|
||||
});
|
||||
|
||||
return new Promise((resolve) => {
|
||||
const httpServer = app.listen(0, '127.0.0.1', () => {
|
||||
const destroySockets = trackSockets(httpServer);
|
||||
const addr = httpServer.address() as { port: number };
|
||||
resolve({
|
||||
url: `http://127.0.0.1:${addr.port}/mcp`,
|
||||
httpServer,
|
||||
close: async () => {
|
||||
for (const t of Object.values(transports)) {
|
||||
t.close().catch(() => {});
|
||||
}
|
||||
await destroySockets();
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function createConnection(serverName: string, url: string, initTimeout = 5000): MCPConnection {
|
||||
return new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { url, type: 'streamable-http', initTimeout } as never,
|
||||
});
|
||||
}
|
||||
|
||||
async function teardownConnection(conn: MCPConnection): Promise<void> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(conn as any).shouldStopReconnecting = true;
|
||||
conn.removeAllListeners();
|
||||
await conn.disconnect();
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(MCPConnection as any).circuitBreakers.clear();
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #2 — Circuit breaker trips after rapid connect/disconnect */
|
||||
/* cycles (CB_MAX_CYCLES within window -> cooldown) */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #2: Circuit breaker stops rapid reconnect cycling', () => {
|
||||
it('blocks connection after CB_MAX_CYCLES rapid cycles via static circuit breaker', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('cycling-server', srv.url);
|
||||
|
||||
let completedCycles = 0;
|
||||
let breakerMessage = '';
|
||||
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
|
||||
for (let cycle = 0; cycle < maxAttempts; cycle++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
await teardownConnection(conn);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(conn as any).shouldStopReconnecting = false;
|
||||
completedCycles++;
|
||||
} catch (e) {
|
||||
breakerMessage = (e as Error).message;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
expect(breakerMessage).toContain('Circuit breaker is open');
|
||||
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
|
||||
|
||||
await srv.close();
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #3 — SSE 400/405 handled in same branch as 404 */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #3: SSE 400/405 handled in same branch as 404', () => {
|
||||
it('400 with active session triggers reconnection (session lost)', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('sse-400', srv.url);
|
||||
await conn.connect();
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(conn as any).shouldStopReconnecting = true;
|
||||
|
||||
const changes: string[] = [];
|
||||
conn.on('connectionChange', (s: string) => changes.push(s));
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const transport = (conn as any).transport;
|
||||
transport.onerror({ message: 'Failed to open SSE stream', code: 400 });
|
||||
|
||||
expect(changes).toContain('error');
|
||||
|
||||
await teardownConnection(conn);
|
||||
await srv.close();
|
||||
});
|
||||
|
||||
it('405 with active session triggers reconnection (session lost)', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('sse-405', srv.url);
|
||||
await conn.connect();
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(conn as any).shouldStopReconnecting = true;
|
||||
|
||||
const changes: string[] = [];
|
||||
conn.on('connectionChange', (s: string) => changes.push(s));
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const transport = (conn as any).transport;
|
||||
transport.onerror({ message: 'Method Not Allowed', code: 405 });
|
||||
|
||||
expect(changes).toContain('error');
|
||||
|
||||
await teardownConnection(conn);
|
||||
await srv.close();
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #4 — Circuit breaker state persists in static Map across */
|
||||
/* instance replacements */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #4: Circuit breaker state persists across instance replacement', () => {
|
||||
it('new MCPConnection for same serverName inherits breaker state from static Map', async () => {
|
||||
const srv = await startMCPServer();
|
||||
|
||||
const conn1 = createConnection('replace', srv.url);
|
||||
await conn1.connect();
|
||||
await teardownConnection(conn1);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cbAfterConn1 = (MCPConnection as any).circuitBreakers.get('replace');
|
||||
expect(cbAfterConn1).toBeDefined();
|
||||
const cyclesAfterConn1 = cbAfterConn1.cycleCount;
|
||||
expect(cyclesAfterConn1).toBeGreaterThan(0);
|
||||
|
||||
const conn2 = createConnection('replace', srv.url);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cbFromConn2 = (conn2 as any).getCircuitBreaker();
|
||||
expect(cbFromConn2.cycleCount).toBe(cyclesAfterConn1);
|
||||
|
||||
await teardownConnection(conn2);
|
||||
await srv.close();
|
||||
});
|
||||
|
||||
it('clearCooldown resets static state so explicit retry proceeds', () => {
|
||||
const conn = createConnection('replace', 'http://127.0.0.1:1/mcp');
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cb = (conn as any).getCircuitBreaker();
|
||||
cb.cooldownUntil = Date.now() + 999_999;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((conn as any).isCircuitOpen()).toBe(true);
|
||||
|
||||
MCPConnection.clearCooldown('replace');
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((conn as any).isCircuitOpen()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #5 — Dead servers now trigger circuit breaker via */
|
||||
/* recordFailedRound() in the catch path */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #5: Dead server triggers circuit breaker', () => {
|
||||
it('failures trigger backoff, blocking subsequent attempts before they reach the SDK', async () => {
|
||||
const conn = createConnection('dead', 'http://127.0.0.1:1/mcp', 1000);
|
||||
const spy = jest.spyOn(conn.client, 'connect');
|
||||
|
||||
const totalAttempts = mcpConfig.CB_MAX_FAILED_ROUNDS + 2;
|
||||
const errors: string[] = [];
|
||||
for (let i = 0; i < totalAttempts; i++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
} catch (e) {
|
||||
errors.push((e as Error).message);
|
||||
}
|
||||
}
|
||||
|
||||
expect(spy.mock.calls.length).toBe(mcpConfig.CB_MAX_FAILED_ROUNDS);
|
||||
expect(errors).toHaveLength(totalAttempts);
|
||||
expect(errors.filter((m) => m.includes('Circuit breaker is open'))).toHaveLength(2);
|
||||
|
||||
await conn.disconnect();
|
||||
});
|
||||
|
||||
it('user B is immediately blocked when user A already tripped the breaker for the same server', async () => {
|
||||
const deadUrl = 'http://127.0.0.1:1/mcp';
|
||||
|
||||
const userA = new MCPConnection({
|
||||
serverName: 'shared-dead',
|
||||
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
|
||||
userId: 'user-A',
|
||||
});
|
||||
|
||||
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
|
||||
try {
|
||||
await userA.connect();
|
||||
} catch {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
const userB = new MCPConnection({
|
||||
serverName: 'shared-dead',
|
||||
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
|
||||
userId: 'user-B',
|
||||
});
|
||||
const spyB = jest.spyOn(userB.client, 'connect');
|
||||
|
||||
let blockedMessage = '';
|
||||
try {
|
||||
await userB.connect();
|
||||
} catch (e) {
|
||||
blockedMessage = (e as Error).message;
|
||||
}
|
||||
|
||||
expect(blockedMessage).toContain('Circuit breaker is open');
|
||||
expect(spyB).toHaveBeenCalledTimes(0);
|
||||
|
||||
await userA.disconnect();
|
||||
await userB.disconnect();
|
||||
});
|
||||
|
||||
it('clearCooldown after user retry unblocks all users', async () => {
|
||||
const deadUrl = 'http://127.0.0.1:1/mcp';
|
||||
|
||||
const userA = new MCPConnection({
|
||||
serverName: 'shared-dead-clear',
|
||||
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
|
||||
userId: 'user-A',
|
||||
});
|
||||
for (let i = 0; i < mcpConfig.CB_MAX_FAILED_ROUNDS; i++) {
|
||||
try {
|
||||
await userA.connect();
|
||||
} catch {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
const userB = new MCPConnection({
|
||||
serverName: 'shared-dead-clear',
|
||||
serverConfig: { url: deadUrl, type: 'streamable-http', initTimeout: 1000 } as never,
|
||||
userId: 'user-B',
|
||||
});
|
||||
try {
|
||||
await userB.connect();
|
||||
} catch (e) {
|
||||
expect((e as Error).message).toContain('Circuit breaker is open');
|
||||
}
|
||||
|
||||
MCPConnection.clearCooldown('shared-dead-clear');
|
||||
|
||||
const spyB = jest.spyOn(userB.client, 'connect');
|
||||
try {
|
||||
await userB.connect();
|
||||
} catch {
|
||||
// expected — server is still dead
|
||||
}
|
||||
|
||||
expect(spyB).toHaveBeenCalledTimes(1);
|
||||
|
||||
await userA.disconnect();
|
||||
await userB.disconnect();
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #5b — disconnect(false) preserves cycle tracking */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #5b: disconnect(false) preserves cycle tracking', () => {
|
||||
it('connect() passes false to disconnect, which calls recordCycle()', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('wipe', srv.url);
|
||||
const spy = jest.spyOn(conn, 'disconnect');
|
||||
|
||||
await conn.connect();
|
||||
expect(spy).toHaveBeenCalledWith(false);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cb = (MCPConnection as any).circuitBreakers.get('wipe');
|
||||
expect(cb).toBeDefined();
|
||||
expect(cb.cycleCount).toBeGreaterThan(0);
|
||||
|
||||
await teardownConnection(conn);
|
||||
await srv.close();
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Fix #6 — OAuth failure uses cooldown-based retry */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Fix #6: OAuth failure uses cooldown-based retry', () => {
|
||||
beforeEach(() => jest.useFakeTimers());
|
||||
afterEach(() => jest.useRealTimers());
|
||||
|
||||
it('isFailed expires after first cooldown of 5 min', () => {
|
||||
jest.setSystemTime(Date.now());
|
||||
const tracker = new OAuthReconnectionTracker();
|
||||
tracker.setFailed('u1', 'srv');
|
||||
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(true);
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(false);
|
||||
});
|
||||
|
||||
it('progressive cooldown: 5m, 10m, 20m, 30m (capped)', () => {
|
||||
jest.setSystemTime(Date.now());
|
||||
const tracker = new OAuthReconnectionTracker();
|
||||
|
||||
tracker.setFailed('u1', 'srv');
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(false);
|
||||
|
||||
tracker.setFailed('u1', 'srv');
|
||||
jest.advanceTimersByTime(10 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(false);
|
||||
|
||||
tracker.setFailed('u1', 'srv');
|
||||
jest.advanceTimersByTime(20 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(false);
|
||||
|
||||
tracker.setFailed('u1', 'srv');
|
||||
jest.advanceTimersByTime(29 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(false);
|
||||
});
|
||||
|
||||
it('removeFailed resets attempt count so next failure starts at 5m', () => {
|
||||
jest.setSystemTime(Date.now());
|
||||
const tracker = new OAuthReconnectionTracker();
|
||||
|
||||
tracker.setFailed('u1', 'srv');
|
||||
tracker.setFailed('u1', 'srv');
|
||||
tracker.setFailed('u1', 'srv');
|
||||
tracker.removeFailed('u1', 'srv');
|
||||
|
||||
tracker.setFailed('u1', 'srv');
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed('u1', 'srv')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Integration: Circuit breaker caps rapid cycling with real transport */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Cascade: Circuit breaker caps rapid cycling', () => {
|
||||
it('breaker trips before double CB_MAX_CYCLES complete against a live server', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('cascade', srv.url);
|
||||
const spy = jest.spyOn(conn.client, 'connect');
|
||||
|
||||
let completedCycles = 0;
|
||||
const maxAttempts = mcpConfig.CB_MAX_CYCLES * 2;
|
||||
for (let i = 0; i < maxAttempts; i++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
await teardownConnection(conn);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(conn as any).shouldStopReconnecting = false;
|
||||
completedCycles++;
|
||||
} catch (e) {
|
||||
if ((e as Error).message.includes('Circuit breaker is open')) {
|
||||
break;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
expect(completedCycles).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
|
||||
expect(spy.mock.calls.length).toBeLessThanOrEqual(mcpConfig.CB_MAX_CYCLES);
|
||||
|
||||
await srv.close();
|
||||
});
|
||||
|
||||
it('breaker bounds failures against a killed server', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('cascade-die', srv.url, 2000);
|
||||
|
||||
await conn.connect();
|
||||
await teardownConnection(conn);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(conn as any).shouldStopReconnecting = false;
|
||||
await srv.close();
|
||||
|
||||
let breakerTripped = false;
|
||||
for (let i = 0; i < 10; i++) {
|
||||
try {
|
||||
await conn.connect();
|
||||
} catch (e) {
|
||||
if ((e as Error).message.includes('Circuit breaker is open')) {
|
||||
breakerTripped = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expect(breakerTripped).toBe(true);
|
||||
}, 30_000);
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* OAuth: cycle recovery after successful OAuth reconnect */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('OAuth: cycle budget recovery after successful OAuth', () => {
|
||||
let oauthServer: OAuthTestServer;
|
||||
|
||||
beforeEach(async () => {
|
||||
oauthServer = await createOAuthMCPServer({ tokenTTLMs: 60000 });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await oauthServer.close();
|
||||
});
|
||||
|
||||
async function exchangeCodeForToken(serverUrl: string): Promise<string> {
|
||||
const authRes = await fetch(`${serverUrl}authorize?redirect_uri=http://localhost&state=test`, {
|
||||
redirect: 'manual',
|
||||
});
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code') ?? '';
|
||||
const tokenRes = await fetch(`${serverUrl}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}`,
|
||||
});
|
||||
const data = (await tokenRes.json()) as { access_token: string };
|
||||
return data.access_token;
|
||||
}
|
||||
|
||||
it('should decrement cycle count after successful OAuth recovery', async () => {
|
||||
const serverName = 'oauth-cycle-test';
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
|
||||
const conn = new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
// When oauthRequired fires, get a token and emit oauthHandled
|
||||
// This triggers the oauthRecovery path inside connectClient
|
||||
conn.on('oauthRequired', async () => {
|
||||
const accessToken = await exchangeCodeForToken(oauthServer.url);
|
||||
conn.setOAuthTokens({
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
conn.emit('oauthHandled');
|
||||
});
|
||||
|
||||
// connect() → 401 → oauthRequired → oauthHandled → connectClient returns
|
||||
// connect() sees not connected → throws "Connection not established"
|
||||
await expect(conn.connect()).rejects.toThrow('Connection not established');
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
|
||||
const cyclesBeforeRetry = cb.cycleCount;
|
||||
|
||||
// Retry — should succeed and decrement cycle count via oauthRecovery
|
||||
await conn.connect();
|
||||
expect(await conn.isConnected()).toBe(true);
|
||||
|
||||
const cyclesAfterSuccess = cb.cycleCount;
|
||||
// The retry adds +1 cycle (disconnect(false)) then -1 (oauthRecovery decrement)
|
||||
// So cyclesAfterSuccess should equal cyclesBeforeRetry, not cyclesBeforeRetry + 1
|
||||
expect(cyclesAfterSuccess).toBe(cyclesBeforeRetry);
|
||||
|
||||
await teardownConnection(conn);
|
||||
});
|
||||
|
||||
it('should allow more OAuth reconnects than non-OAuth before breaker trips', async () => {
|
||||
const serverName = 'oauth-budget';
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
|
||||
// Each OAuth flow: connect (+1) → 401 → oauthHandled → retry connect (+1) → success (-1) = net 1
|
||||
// Without the decrement it would be net 2 per flow, tripping the breaker after ~2 users
|
||||
let successfulFlows = 0;
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const conn = new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
|
||||
userId: `user-${i}`,
|
||||
});
|
||||
|
||||
conn.on('oauthRequired', async () => {
|
||||
const accessToken = await exchangeCodeForToken(oauthServer.url);
|
||||
conn.setOAuthTokens({
|
||||
access_token: accessToken,
|
||||
token_type: 'Bearer',
|
||||
} as MCPOAuthTokens);
|
||||
conn.emit('oauthHandled');
|
||||
});
|
||||
|
||||
try {
|
||||
// First connect: 401 → oauthHandled → returns without connection
|
||||
await conn.connect().catch(() => {});
|
||||
// Retry: succeeds with token, decrements cycle
|
||||
await conn.connect();
|
||||
successfulFlows++;
|
||||
await teardownConnection(conn);
|
||||
} catch (e) {
|
||||
conn.removeAllListeners();
|
||||
if ((e as Error).message.includes('Circuit breaker is open')) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// With the oauthRecovery decrement, each flow is net ~1 cycle instead of ~2,
|
||||
// so we should get more successful flows before the breaker trips
|
||||
expect(successfulFlows).toBeGreaterThanOrEqual(3);
|
||||
});
|
||||
|
||||
it('should not decrement cycle count when OAuth fails', async () => {
|
||||
const serverName = 'oauth-failed-no-decrement';
|
||||
MCPConnection.clearCooldown(serverName);
|
||||
|
||||
const conn = new MCPConnection({
|
||||
serverName,
|
||||
serverConfig: { type: 'streamable-http', url: oauthServer.url, initTimeout: 10000 },
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
conn.on('oauthRequired', () => {
|
||||
conn.emit('oauthFailed', new Error('user denied'));
|
||||
});
|
||||
|
||||
await expect(conn.connect()).rejects.toThrow();
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const cb = (MCPConnection as any).circuitBreakers.get(serverName);
|
||||
const cyclesAfterFailure = cb.cycleCount;
|
||||
|
||||
// connect() recorded +1 cycle, oauthFailed should NOT decrement
|
||||
expect(cyclesAfterFailure).toBeGreaterThanOrEqual(1);
|
||||
|
||||
conn.removeAllListeners();
|
||||
});
|
||||
});
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Sanity: Real transport works end-to-end */
|
||||
/* ------------------------------------------------------------------ */
|
||||
describe('Sanity: Real MCP SDK transport works correctly', () => {
|
||||
it('connects, lists tools, and disconnects cleanly', async () => {
|
||||
const srv = await startMCPServer();
|
||||
const conn = createConnection('sanity', srv.url);
|
||||
|
||||
await conn.connect();
|
||||
expect(await conn.isConnected()).toBe(true);
|
||||
|
||||
const tools = await conn.fetchTools();
|
||||
expect(tools).toEqual(expect.arrayContaining([expect.objectContaining({ name: 'echo' })]));
|
||||
|
||||
await teardownConnection(conn);
|
||||
await srv.close();
|
||||
});
|
||||
});
|
||||
|
|
@ -71,6 +71,17 @@ const FIVE_MINUTES = 5 * 60 * 1000;
|
|||
const DEFAULT_TIMEOUT = 60000;
|
||||
/** SSE connections through proxies may need longer initial handshake time */
|
||||
const SSE_CONNECT_TIMEOUT = 120000;
|
||||
const DEFAULT_INIT_TIMEOUT = 30000;
|
||||
|
||||
interface CircuitBreakerState {
|
||||
cycleCount: number;
|
||||
cycleWindowStart: number;
|
||||
cooldownUntil: number;
|
||||
failedRounds: number;
|
||||
failedWindowStart: number;
|
||||
failedBackoffUntil: number;
|
||||
}
|
||||
|
||||
/** Default body timeout for Streamable HTTP GET SSE streams that idle between server pushes */
|
||||
const DEFAULT_SSE_READ_TIMEOUT = FIVE_MINUTES;
|
||||
|
||||
|
|
@ -262,6 +273,7 @@ export class MCPConnection extends EventEmitter {
|
|||
private oauthTokens?: MCPOAuthTokens | null;
|
||||
private requestHeaders?: Record<string, string> | null;
|
||||
private oauthRequired = false;
|
||||
private oauthRecovery = false;
|
||||
private readonly useSSRFProtection: boolean;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
|
|
@ -274,6 +286,88 @@ export class MCPConnection extends EventEmitter {
|
|||
*/
|
||||
public readonly createdAt: number;
|
||||
|
||||
private static circuitBreakers: Map<string, CircuitBreakerState> = new Map();
|
||||
|
||||
public static clearCooldown(serverName: string): void {
|
||||
MCPConnection.circuitBreakers.delete(serverName);
|
||||
logger.debug(`[MCP][${serverName}] Circuit breaker state cleared`);
|
||||
}
|
||||
|
||||
private getCircuitBreaker(): CircuitBreakerState {
|
||||
let cb = MCPConnection.circuitBreakers.get(this.serverName);
|
||||
if (!cb) {
|
||||
cb = {
|
||||
cycleCount: 0,
|
||||
cycleWindowStart: Date.now(),
|
||||
cooldownUntil: 0,
|
||||
failedRounds: 0,
|
||||
failedWindowStart: Date.now(),
|
||||
failedBackoffUntil: 0,
|
||||
};
|
||||
MCPConnection.circuitBreakers.set(this.serverName, cb);
|
||||
}
|
||||
return cb;
|
||||
}
|
||||
|
||||
private isCircuitOpen(): boolean {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
return now < cb.cooldownUntil || now < cb.failedBackoffUntil;
|
||||
}
|
||||
|
||||
private recordCycle(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
if (now - cb.cycleWindowStart > mcpConfig.CB_CYCLE_WINDOW_MS) {
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
}
|
||||
cb.cycleCount++;
|
||||
if (cb.cycleCount >= mcpConfig.CB_MAX_CYCLES) {
|
||||
cb.cooldownUntil = now + mcpConfig.CB_CYCLE_COOLDOWN_MS;
|
||||
cb.cycleCount = 0;
|
||||
cb.cycleWindowStart = now;
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} Circuit breaker: too many cycles, cooling down for ${mcpConfig.CB_CYCLE_COOLDOWN_MS}ms`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private recordFailedRound(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
const now = Date.now();
|
||||
if (now - cb.failedWindowStart > mcpConfig.CB_FAILED_WINDOW_MS) {
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = now;
|
||||
}
|
||||
cb.failedRounds++;
|
||||
if (cb.failedRounds >= mcpConfig.CB_MAX_FAILED_ROUNDS) {
|
||||
const backoff = Math.min(
|
||||
mcpConfig.CB_BASE_BACKOFF_MS *
|
||||
Math.pow(2, cb.failedRounds - mcpConfig.CB_MAX_FAILED_ROUNDS),
|
||||
mcpConfig.CB_MAX_BACKOFF_MS,
|
||||
);
|
||||
cb.failedBackoffUntil = now + backoff;
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} Circuit breaker: too many failures, backing off for ${backoff}ms`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private resetFailedRounds(): void {
|
||||
const cb = this.getCircuitBreaker();
|
||||
cb.failedRounds = 0;
|
||||
cb.failedWindowStart = Date.now();
|
||||
cb.failedBackoffUntil = 0;
|
||||
}
|
||||
|
||||
public static decrementCycleCount(serverName: string): void {
|
||||
const cb = MCPConnection.circuitBreakers.get(serverName);
|
||||
if (cb && cb.cycleCount > 0) {
|
||||
cb.cycleCount--;
|
||||
}
|
||||
}
|
||||
|
||||
setRequestHeaders(headers: Record<string, string> | null): void {
|
||||
if (!headers) {
|
||||
return;
|
||||
|
|
@ -686,6 +780,12 @@ export class MCPConnection extends EventEmitter {
|
|||
return;
|
||||
}
|
||||
|
||||
if (this.isCircuitOpen()) {
|
||||
this.connectionState = 'error';
|
||||
this.emit('connectionChange', 'error');
|
||||
throw new Error(`${this.getLogPrefix()} Circuit breaker is open, connection attempt blocked`);
|
||||
}
|
||||
|
||||
this.emit('connectionChange', 'connecting');
|
||||
|
||||
this.connectPromise = (async () => {
|
||||
|
|
@ -703,7 +803,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.transport = await runOutsideTracing(() => this.constructTransport(this.options));
|
||||
this.patchTransportSend();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
const connectTimeout = this.options.initTimeout ?? DEFAULT_INIT_TIMEOUT;
|
||||
await runOutsideTracing(() =>
|
||||
withTimeout(
|
||||
this.client.connect(this.transport!),
|
||||
|
|
@ -716,6 +816,14 @@ export class MCPConnection extends EventEmitter {
|
|||
this.connectionState = 'connected';
|
||||
this.emit('connectionChange', 'connected');
|
||||
this.reconnectAttempts = 0;
|
||||
this.resetFailedRounds();
|
||||
if (this.oauthRecovery) {
|
||||
MCPConnection.decrementCycleCount(this.serverName);
|
||||
this.oauthRecovery = false;
|
||||
logger.debug(
|
||||
`${this.getLogPrefix()} OAuth recovery: decremented cycle count after successful reconnect`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
// Check if it's a rate limit error - stop immediately to avoid making it worse
|
||||
if (this.isRateLimitError(error)) {
|
||||
|
|
@ -799,9 +907,8 @@ export class MCPConnection extends EventEmitter {
|
|||
try {
|
||||
// Wait for OAuth to be handled
|
||||
await oauthHandledPromise;
|
||||
// Reset the oauthRequired flag
|
||||
this.oauthRequired = false;
|
||||
// Don't throw the error - just return so connection can be retried
|
||||
this.oauthRecovery = true;
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
|
||||
);
|
||||
|
|
@ -817,6 +924,7 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
this.connectionState = 'error';
|
||||
this.emit('connectionChange', 'error');
|
||||
this.recordFailedRound();
|
||||
throw error;
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
|
|
@ -866,7 +974,8 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
async connect(): Promise<void> {
|
||||
try {
|
||||
await this.disconnect();
|
||||
// preserve cycle tracking across reconnects so the circuit breaker can detect rapid cycling
|
||||
await this.disconnect(false);
|
||||
await this.connectClient();
|
||||
if (!(await this.isConnected())) {
|
||||
throw new Error('Connection not established');
|
||||
|
|
@ -906,7 +1015,7 @@ export class MCPConnection extends EventEmitter {
|
|||
isTransient,
|
||||
} = extractSSEErrorMessage(error);
|
||||
|
||||
if (errorCode === 404) {
|
||||
if (errorCode === 400 || errorCode === 404 || errorCode === 405) {
|
||||
const hasSession =
|
||||
'sessionId' in transport &&
|
||||
(transport as { sessionId?: string }).sessionId != null &&
|
||||
|
|
@ -914,14 +1023,14 @@ export class MCPConnection extends EventEmitter {
|
|||
|
||||
if (!hasSession && errorMessage.toLowerCase().includes('failed to open sse stream')) {
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} SSE stream not available (404), no session. Ignoring.`,
|
||||
`${this.getLogPrefix()} SSE stream not available (${errorCode}), no session. Ignoring.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (hasSession) {
|
||||
logger.warn(
|
||||
`${this.getLogPrefix()} 404 with active session — session lost, triggering reconnection.`,
|
||||
`${this.getLogPrefix()} ${errorCode} with active session — session lost, triggering reconnection.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -992,7 +1101,7 @@ export class MCPConnection extends EventEmitter {
|
|||
await Promise.all(closing);
|
||||
}
|
||||
|
||||
public async disconnect(): Promise<void> {
|
||||
public async disconnect(resetCycleTracking = true): Promise<void> {
|
||||
try {
|
||||
if (this.transport) {
|
||||
await this.client.close();
|
||||
|
|
@ -1006,6 +1115,9 @@ export class MCPConnection extends EventEmitter {
|
|||
this.emit('connectionChange', 'disconnected');
|
||||
} finally {
|
||||
this.connectPromise = null;
|
||||
if (!resetCycleTracking) {
|
||||
this.recordCycle();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,4 +12,18 @@ export const mcpConfig = {
|
|||
USER_CONNECTION_IDLE_TIMEOUT: math(
|
||||
process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000,
|
||||
),
|
||||
/** Max connect/disconnect cycles before the circuit breaker trips. Default: 7 */
|
||||
CB_MAX_CYCLES: math(process.env.MCP_CB_MAX_CYCLES ?? 7),
|
||||
/** Sliding window (ms) for counting cycles. Default: 45s */
|
||||
CB_CYCLE_WINDOW_MS: math(process.env.MCP_CB_CYCLE_WINDOW_MS ?? 45_000),
|
||||
/** Cooldown (ms) after the cycle breaker trips. Default: 15s */
|
||||
CB_CYCLE_COOLDOWN_MS: math(process.env.MCP_CB_CYCLE_COOLDOWN_MS ?? 15_000),
|
||||
/** Max consecutive failed connection rounds before backoff. Default: 3 */
|
||||
CB_MAX_FAILED_ROUNDS: math(process.env.MCP_CB_MAX_FAILED_ROUNDS ?? 3),
|
||||
/** Sliding window (ms) for counting failed rounds. Default: 120s */
|
||||
CB_FAILED_WINDOW_MS: math(process.env.MCP_CB_FAILED_WINDOW_MS ?? 120_000),
|
||||
/** Base backoff (ms) after failed round threshold is reached. Default: 30s */
|
||||
CB_BASE_BACKOFF_MS: math(process.env.MCP_CB_BASE_BACKOFF_MS ?? 30_000),
|
||||
/** Max backoff cap (ms) for exponential backoff. Default: 300s */
|
||||
CB_MAX_BACKOFF_MS: math(process.env.MCP_CB_MAX_BACKOFF_MS ?? 300_000),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -253,17 +253,21 @@ describe('OAuthReconnectionManager', () => {
|
|||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
|
||||
it('should not reconnect servers with expired tokens', async () => {
|
||||
it('should not reconnect servers with expired tokens and no refresh token', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
|
||||
} as unknown as MCPOAuthTokens);
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() - 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
|
|
@ -272,6 +276,87 @@ describe('OAuthReconnectionManager', () => {
|
|||
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reconnect servers with expired access token but valid refresh token', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() - 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
if (identifier === 'mcp:server1:refresh') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should reconnect when access token is TTL-deleted but refresh token exists', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server1:refresh') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server1' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
|
|
@ -336,6 +421,69 @@ describe('OAuthReconnectionManager', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('reconnectServer', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return true on successful reconnection', async () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
const result = await reconnectionManager.reconnectServer(userId, serverName);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false on failed reconnection', async () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
|
||||
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
const result = await reconnectionManager.reconnectServer(userId, serverName);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false when MCPManager is not available', async () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
|
||||
(OAuthReconnectionManager as unknown as { instance: null }).instance = null;
|
||||
(MCPManager.getInstance as jest.Mock).mockImplementation(() => {
|
||||
throw new Error('MCPManager has not been initialized.');
|
||||
});
|
||||
|
||||
const managerWithoutMCP = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
|
||||
const result = await managerWithoutMCP.reconnectServer(userId, serverName);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnection staggering', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
|
|
|
|||
|
|
@ -96,6 +96,24 @@ export class OAuthReconnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to reconnect a single OAuth MCP server.
|
||||
* @returns true if reconnection succeeded, false otherwise.
|
||||
*/
|
||||
public async reconnectServer(userId: string, serverName: string): Promise<boolean> {
|
||||
if (this.mcpManager == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
this.reconnectionsTracker.setActive(userId, serverName);
|
||||
try {
|
||||
await this.tryReconnect(userId, serverName);
|
||||
return !this.reconnectionsTracker.isFailed(userId, serverName);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public clearReconnection(userId: string, serverName: string) {
|
||||
this.reconnectionsTracker.removeFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
|
|
@ -174,23 +192,31 @@ export class OAuthReconnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
// if the server has no tokens for the user, don't attempt to reconnect
|
||||
// if the server has a valid (non-expired) access token, allow reconnect
|
||||
const accessToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}`,
|
||||
});
|
||||
if (accessToken == null) {
|
||||
|
||||
if (accessToken != null) {
|
||||
const now = new Date();
|
||||
if (!accessToken.expiresAt || accessToken.expiresAt >= now) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// if the access token is expired or TTL-deleted, fall back to refresh token
|
||||
const refreshToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}:refresh`,
|
||||
});
|
||||
|
||||
if (refreshToken == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the token has expired, don't attempt to reconnect
|
||||
const now = new Date();
|
||||
if (accessToken.expiresAt && accessToken.expiresAt < now) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// …otherwise, we're good to go with the reconnect attempt
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -397,6 +397,101 @@ describe('OAuthReconnectTracker', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('cooldown-based retry', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('should return true from isFailed within first cooldown period (5 min)', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(4 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false from isFailed after first cooldown elapses (5 min)', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should use progressive cooldown schedule (5m, 10m, 20m, 30m)', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
// First failure: 5 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Second failure: 10 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(9 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Third failure: 20 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(19 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Fourth failure: 30 min cooldown
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(29 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should cap cooldown at 30 min for attempts beyond 4', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(30 * 60 * 1000);
|
||||
}
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(29 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
jest.advanceTimersByTime(1 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should fully reset metadata on removeFailed', () => {
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, serverName);
|
||||
|
||||
tracker.removeFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
jest.advanceTimersByTime(5 * 60 * 1000);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('timestamp tracking edge cases', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
interface FailedMeta {
|
||||
attempts: number;
|
||||
lastFailedAt: number;
|
||||
}
|
||||
|
||||
const COOLDOWN_SCHEDULE_MS = [5 * 60 * 1000, 10 * 60 * 1000, 20 * 60 * 1000, 30 * 60 * 1000];
|
||||
|
||||
export class OAuthReconnectionTracker {
|
||||
/** Map of userId -> Set of serverNames that have failed reconnection */
|
||||
private failed: Map<string, Set<string>> = new Map();
|
||||
private failedMeta: Map<string, Map<string, FailedMeta>> = new Map();
|
||||
/** Map of userId -> Set of serverNames that are actively reconnecting */
|
||||
private active: Map<string, Set<string>> = new Map();
|
||||
/** Map of userId:serverName -> timestamp when reconnection started */
|
||||
|
|
@ -9,7 +15,17 @@ export class OAuthReconnectionTracker {
|
|||
private readonly RECONNECTION_TIMEOUT_MS = 3 * 60 * 1000; // 3 minutes
|
||||
|
||||
public isFailed(userId: string, serverName: string): boolean {
|
||||
return this.failed.get(userId)?.has(serverName) ?? false;
|
||||
const meta = this.failedMeta.get(userId)?.get(serverName);
|
||||
if (!meta) {
|
||||
return false;
|
||||
}
|
||||
const idx = Math.min(meta.attempts - 1, COOLDOWN_SCHEDULE_MS.length - 1);
|
||||
const cooldown = COOLDOWN_SCHEDULE_MS[idx];
|
||||
const elapsed = Date.now() - meta.lastFailedAt;
|
||||
if (elapsed >= cooldown) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Check if server is in the active set (original simple check) */
|
||||
|
|
@ -48,11 +64,15 @@ export class OAuthReconnectionTracker {
|
|||
}
|
||||
|
||||
public setFailed(userId: string, serverName: string): void {
|
||||
if (!this.failed.has(userId)) {
|
||||
this.failed.set(userId, new Set());
|
||||
if (!this.failedMeta.has(userId)) {
|
||||
this.failedMeta.set(userId, new Map());
|
||||
}
|
||||
|
||||
this.failed.get(userId)?.add(serverName);
|
||||
const userMap = this.failedMeta.get(userId)!;
|
||||
const existing = userMap.get(serverName);
|
||||
userMap.set(serverName, {
|
||||
attempts: (existing?.attempts ?? 0) + 1,
|
||||
lastFailedAt: Date.now(),
|
||||
});
|
||||
}
|
||||
|
||||
public setActive(userId: string, serverName: string): void {
|
||||
|
|
@ -68,10 +88,10 @@ export class OAuthReconnectionTracker {
|
|||
}
|
||||
|
||||
public removeFailed(userId: string, serverName: string): void {
|
||||
const userServers = this.failed.get(userId);
|
||||
userServers?.delete(serverName);
|
||||
if (userServers?.size === 0) {
|
||||
this.failed.delete(userId);
|
||||
const userMap = this.failedMeta.get(userId);
|
||||
userMap?.delete(serverName);
|
||||
if (userMap?.size === 0) {
|
||||
this.failedMeta.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,7 +114,7 @@ export class OAuthReconnectionTracker {
|
|||
activeTimestamps: number;
|
||||
} {
|
||||
return {
|
||||
usersWithFailedServers: this.failed.size,
|
||||
usersWithFailedServers: this.failedMeta.size,
|
||||
usersWithActiveReconnections: this.active.size,
|
||||
activeTimestamps: this.activeTimestamps.size,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -161,20 +161,7 @@ export class MCPOAuthHandler {
|
|||
logger.debug(
|
||||
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
let rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, {
|
||||
fetchFn,
|
||||
});
|
||||
|
||||
// If discovery failed and we're using a path-based URL, try the base URL
|
||||
if (!rawMetadata && authServerUrl.pathname !== '/') {
|
||||
const baseUrl = new URL(authServerUrl.origin);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
|
||||
);
|
||||
rawMetadata = await discoverAuthorizationServerMetadata(baseUrl, {
|
||||
fetchFn,
|
||||
});
|
||||
}
|
||||
const rawMetadata = await this.discoverWithOriginFallback(authServerUrl, fetchFn);
|
||||
|
||||
if (!rawMetadata) {
|
||||
/**
|
||||
|
|
@ -221,6 +208,39 @@ export class MCPOAuthHandler {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers OAuth authorization server metadata, retrying with just the origin
|
||||
* when discovery fails for a path-based URL. Shared implementation used by
|
||||
* `discoverMetadata` and both `refreshOAuthTokens` branches.
|
||||
*/
|
||||
private static async discoverWithOriginFallback(
|
||||
serverUrl: URL,
|
||||
fetchFn: FetchLike,
|
||||
): ReturnType<typeof discoverAuthorizationServerMetadata> {
|
||||
let metadata: Awaited<ReturnType<typeof discoverAuthorizationServerMetadata>>;
|
||||
try {
|
||||
metadata = await discoverAuthorizationServerMetadata(serverUrl, { fetchFn });
|
||||
} catch (err) {
|
||||
if (serverUrl.pathname === '/') {
|
||||
throw err;
|
||||
}
|
||||
const baseUrl = new URL(serverUrl.origin);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Discovery threw for path URL, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
|
||||
{ error: err },
|
||||
);
|
||||
return discoverAuthorizationServerMetadata(baseUrl, { fetchFn });
|
||||
}
|
||||
if (!metadata && serverUrl.pathname !== '/') {
|
||||
const baseUrl = new URL(serverUrl.origin);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Discovery failed with path, trying base URL: ${sanitizeUrlForLogging(baseUrl)}`,
|
||||
);
|
||||
return discoverAuthorizationServerMetadata(baseUrl, { fetchFn });
|
||||
}
|
||||
return metadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers an OAuth client dynamically
|
||||
*/
|
||||
|
|
@ -406,8 +426,8 @@ export class MCPOAuthHandler {
|
|||
scope: config.scope,
|
||||
});
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
/** Add cryptographic state parameter to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', state);
|
||||
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
|
||||
|
||||
const flowMetadata: MCPOAuthFlowMetadata = {
|
||||
|
|
@ -485,8 +505,8 @@ export class MCPOAuthHandler {
|
|||
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
|
||||
);
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
/** Add cryptographic state parameter to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', state);
|
||||
logger.debug(`[MCPOAuth] Added state parameter to authorization URL`);
|
||||
|
||||
if (resourceMetadata?.resource != null && resourceMetadata.resource) {
|
||||
|
|
@ -652,6 +672,44 @@ export class MCPOAuthHandler {
|
|||
return randomBytes(32).toString('base64url');
|
||||
}
|
||||
|
||||
private static readonly STATE_MAP_TYPE = 'mcp_oauth_state';
|
||||
|
||||
/**
|
||||
* Stores a mapping from the opaque OAuth state parameter to the flowId.
|
||||
* This allows the callback to resolve the flowId from an unguessable state
|
||||
* value, preventing attackers from forging callback requests.
|
||||
*/
|
||||
static async storeStateMapping(
|
||||
state: string,
|
||||
flowId: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
): Promise<void> {
|
||||
await flowManager.initFlow(state, this.STATE_MAP_TYPE, { flowId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves an opaque OAuth state parameter back to the original flowId.
|
||||
* Returns null if the state is not found (expired or never stored).
|
||||
*/
|
||||
static async resolveStateToFlowId(
|
||||
state: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
): Promise<string | null> {
|
||||
const mapping = await flowManager.getFlowState(state, this.STATE_MAP_TYPE);
|
||||
return (mapping?.metadata?.flowId as string) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes an orphaned state mapping when a flow is replaced.
|
||||
* Prevents old authorization URLs from resolving after a flow restart.
|
||||
*/
|
||||
static async deleteStateMapping(
|
||||
state: string,
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
): Promise<void> {
|
||||
await flowManager.deleteFlow(state, this.STATE_MAP_TYPE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the default redirect URI for a server
|
||||
*/
|
||||
|
|
@ -735,9 +793,10 @@ export class MCPOAuthHandler {
|
|||
throw new Error('No token URL available for refresh');
|
||||
} else {
|
||||
/** Auto-discover OAuth configuration for refresh */
|
||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
|
||||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
const serverUrl = new URL(metadata.serverUrl);
|
||||
const fetchFn = this.createOAuthFetch(oauthHeaders);
|
||||
const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn);
|
||||
|
||||
if (!oauthMetadata) {
|
||||
/**
|
||||
* No metadata discovered - use fallback /token endpoint.
|
||||
|
|
@ -911,9 +970,9 @@ export class MCPOAuthHandler {
|
|||
}
|
||||
|
||||
/** Auto-discover OAuth configuration for refresh */
|
||||
const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, {
|
||||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
const serverUrl = new URL(metadata.serverUrl);
|
||||
const fetchFn = this.createOAuthFetch(oauthHeaders);
|
||||
const oauthMetadata = await this.discoverWithOriginFallback(serverUrl, fetchFn);
|
||||
|
||||
let tokenUrl: URL;
|
||||
if (!oauthMetadata?.token_endpoint) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,15 @@ import type { TokenMethods, IToken } from '@librechat/data-schemas';
|
|||
import type { MCPOAuthTokens, ExtendedOAuthTokens, OAuthMetadata } from './types';
|
||||
import { isSystemUserId } from '~/mcp/enum';
|
||||
|
||||
export class ReauthenticationRequiredError extends Error {
|
||||
constructor(serverName: string, reason: 'expired' | 'missing') {
|
||||
super(
|
||||
`Re-authentication required for "${serverName}": access token ${reason} and no refresh token available`,
|
||||
);
|
||||
this.name = 'ReauthenticationRequiredError';
|
||||
}
|
||||
}
|
||||
|
||||
interface StoreTokensParams {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
|
|
@ -27,7 +36,12 @@ interface GetTokensParams {
|
|||
findToken: TokenMethods['findToken'];
|
||||
refreshTokens?: (
|
||||
refreshToken: string,
|
||||
metadata: { userId: string; serverName: string; identifier: string },
|
||||
metadata: {
|
||||
userId: string;
|
||||
serverName: string;
|
||||
identifier: string;
|
||||
clientInfo?: OAuthClientInformation;
|
||||
},
|
||||
) => Promise<MCPOAuthTokens>;
|
||||
createToken?: TokenMethods['createToken'];
|
||||
updateToken?: TokenMethods['updateToken'];
|
||||
|
|
@ -273,10 +287,11 @@ export class MCPTokenStorage {
|
|||
});
|
||||
|
||||
if (!refreshTokenData) {
|
||||
const reason = isMissing ? 'missing' : 'expired';
|
||||
logger.info(
|
||||
`${logPrefix} Access token ${isMissing ? 'missing' : 'expired'} and no refresh token available`,
|
||||
`${logPrefix} Access token ${reason} and no refresh token available — re-authentication required`,
|
||||
);
|
||||
return null;
|
||||
throw new ReauthenticationRequiredError(serverName, reason);
|
||||
}
|
||||
|
||||
if (!refreshTokens) {
|
||||
|
|
@ -395,6 +410,9 @@ export class MCPTokenStorage {
|
|||
logger.debug(`${logPrefix} Loaded existing OAuth tokens from storage`);
|
||||
return tokens;
|
||||
} catch (error) {
|
||||
if (error instanceof ReauthenticationRequiredError) {
|
||||
throw error;
|
||||
}
|
||||
logger.error(`${logPrefix} Failed to retrieve tokens`, error);
|
||||
return null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ export interface MCPOAuthFlowMetadata extends FlowMetadata {
|
|||
clientInfo?: OAuthClientInformation;
|
||||
metadata?: OAuthMetadata;
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
authorizationUrl?: string;
|
||||
}
|
||||
|
||||
export interface MCPOAuthTokens extends OAuthTokens {
|
||||
|
|
|
|||
|
|
@ -17,20 +17,20 @@
|
|||
|
||||
import * as net from 'net';
|
||||
import * as http from 'http';
|
||||
import { Keyv } from 'keyv';
|
||||
import { Agent } from 'undici';
|
||||
import { Types } from 'mongoose';
|
||||
import { randomUUID } from 'crypto';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import { Keyv } from 'keyv';
|
||||
import { Types } from 'mongoose';
|
||||
import type { IUser } from '@librechat/data-schemas';
|
||||
import type { Socket } from 'net';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { MCPInspectionFailedError } from '~/mcp/errors';
|
||||
import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache';
|
||||
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
|
||||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPInspectionFailedError } from '~/mcp/errors';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ export * from './promise';
|
|||
export * from './sanitizeTitle';
|
||||
export * from './tempChatRetention';
|
||||
export * from './text';
|
||||
export { default as Tokenizer, countTokens } from './tokenizer';
|
||||
export * from './yaml';
|
||||
export * from './http';
|
||||
export * from './tokens';
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ const createRealTokenCounter = () => {
|
|||
let callCount = 0;
|
||||
const tokenCountFn = (text: string): number => {
|
||||
callCount++;
|
||||
return Tokenizer.getTokenCount(text, 'cl100k_base');
|
||||
return Tokenizer.getTokenCount(text, 'o200k_base');
|
||||
};
|
||||
return {
|
||||
tokenCountFn,
|
||||
|
|
@ -590,9 +590,9 @@ describe('processTextWithTokenLimit', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('direct comparison with REAL tiktoken tokenizer', () => {
|
||||
beforeEach(() => {
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
describe('direct comparison with REAL ai-tokenizer', () => {
|
||||
beforeAll(async () => {
|
||||
await Tokenizer.initEncoding('o200k_base');
|
||||
});
|
||||
|
||||
it('should produce valid truncation with real tokenizer', async () => {
|
||||
|
|
@ -611,7 +611,7 @@ describe('processTextWithTokenLimit', () => {
|
|||
expect(result.text.length).toBeLessThan(text.length);
|
||||
});
|
||||
|
||||
it('should use fewer tiktoken calls than old implementation (realistic text)', async () => {
|
||||
it('should use fewer tokenizer calls than old implementation (realistic text)', async () => {
|
||||
const oldCounter = createRealTokenCounter();
|
||||
const newCounter = createRealTokenCounter();
|
||||
const text = createRealisticText(15000);
|
||||
|
|
@ -623,8 +623,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
tokenCountFn: oldCounter.tokenCountFn,
|
||||
});
|
||||
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
await processTextWithTokenLimit({
|
||||
text,
|
||||
tokenLimit,
|
||||
|
|
@ -634,17 +632,17 @@ describe('processTextWithTokenLimit', () => {
|
|||
const oldCalls = oldCounter.getCallCount();
|
||||
const newCalls = newCounter.getCallCount();
|
||||
|
||||
console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
|
||||
console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
|
||||
console.log(`[Real tokenizer ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
|
||||
console.log(`[Real tokenizer] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
|
||||
|
||||
expect(newCalls).toBeLessThan(oldCalls);
|
||||
});
|
||||
|
||||
it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => {
|
||||
it('should handle large text with real tokenizer (~20k tokens)', async () => {
|
||||
const oldCounter = createRealTokenCounter();
|
||||
const newCounter = createRealTokenCounter();
|
||||
const text = createRealisticText(120000);
|
||||
const tokenLimit = 100000;
|
||||
const text = createRealisticText(20000);
|
||||
const tokenLimit = 15000;
|
||||
|
||||
const startOld = performance.now();
|
||||
await processTextWithTokenLimitOLD({
|
||||
|
|
@ -654,8 +652,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
});
|
||||
const timeOld = performance.now() - startOld;
|
||||
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
const startNew = performance.now();
|
||||
const result = await processTextWithTokenLimit({
|
||||
text,
|
||||
|
|
@ -667,9 +663,9 @@ describe('processTextWithTokenLimit', () => {
|
|||
const oldCalls = oldCounter.getCallCount();
|
||||
const newCalls = newCounter.getCallCount();
|
||||
|
||||
console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`);
|
||||
console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`);
|
||||
console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`);
|
||||
console.log(`\n[REAL TOKENIZER - ~20k tokens]`);
|
||||
console.log(`OLD implementation: ${oldCalls} tokenizer calls, ${timeOld.toFixed(0)}ms`);
|
||||
console.log(`NEW implementation: ${newCalls} tokenizer calls, ${timeNew.toFixed(0)}ms`);
|
||||
console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
|
||||
console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`);
|
||||
console.log(
|
||||
|
|
@ -684,8 +680,8 @@ describe('processTextWithTokenLimit', () => {
|
|||
it('should achieve at least 70% reduction with real tokenizer', async () => {
|
||||
const oldCounter = createRealTokenCounter();
|
||||
const newCounter = createRealTokenCounter();
|
||||
const text = createRealisticText(50000);
|
||||
const tokenLimit = 10000;
|
||||
const text = createRealisticText(15000);
|
||||
const tokenLimit = 5000;
|
||||
|
||||
await processTextWithTokenLimitOLD({
|
||||
text,
|
||||
|
|
@ -693,8 +689,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
tokenCountFn: oldCounter.tokenCountFn,
|
||||
});
|
||||
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
await processTextWithTokenLimit({
|
||||
text,
|
||||
tokenLimit,
|
||||
|
|
@ -706,7 +700,7 @@ describe('processTextWithTokenLimit', () => {
|
|||
const reduction = 1 - newCalls / oldCalls;
|
||||
|
||||
console.log(
|
||||
`[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
|
||||
`[Real tokenizer 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
|
||||
);
|
||||
|
||||
expect(reduction).toBeGreaterThanOrEqual(0.7);
|
||||
|
|
@ -714,10 +708,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
});
|
||||
|
||||
describe('using countTokens async function from @librechat/api', () => {
|
||||
beforeEach(() => {
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
});
|
||||
|
||||
it('countTokens should return correct token count', async () => {
|
||||
const text = 'Hello, world!';
|
||||
const count = await countTokens(text);
|
||||
|
|
@ -759,8 +749,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
tokenCountFn: oldCounter.tokenCountFn,
|
||||
});
|
||||
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
await processTextWithTokenLimit({
|
||||
text,
|
||||
tokenLimit,
|
||||
|
|
@ -776,11 +764,11 @@ describe('processTextWithTokenLimit', () => {
|
|||
expect(newCalls).toBeLessThan(oldCalls);
|
||||
});
|
||||
|
||||
it('should handle user reported scenario with countTokens (~120k tokens)', async () => {
|
||||
it('should handle large text with countTokens (~20k tokens)', async () => {
|
||||
const oldCounter = createCountTokensCounter();
|
||||
const newCounter = createCountTokensCounter();
|
||||
const text = createRealisticText(120000);
|
||||
const tokenLimit = 100000;
|
||||
const text = createRealisticText(20000);
|
||||
const tokenLimit = 15000;
|
||||
|
||||
const startOld = performance.now();
|
||||
await processTextWithTokenLimitOLD({
|
||||
|
|
@ -790,8 +778,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
});
|
||||
const timeOld = performance.now() - startOld;
|
||||
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
const startNew = performance.now();
|
||||
const result = await processTextWithTokenLimit({
|
||||
text,
|
||||
|
|
@ -803,7 +789,7 @@ describe('processTextWithTokenLimit', () => {
|
|||
const oldCalls = oldCounter.getCallCount();
|
||||
const newCalls = newCounter.getCallCount();
|
||||
|
||||
console.log(`\n[countTokens - User reported scenario: ~120k tokens]`);
|
||||
console.log(`\n[countTokens - ~20k tokens]`);
|
||||
console.log(`OLD implementation: ${oldCalls} countTokens calls, ${timeOld.toFixed(0)}ms`);
|
||||
console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`);
|
||||
console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
|
||||
|
|
@ -820,8 +806,8 @@ describe('processTextWithTokenLimit', () => {
|
|||
it('should achieve at least 70% reduction with countTokens', async () => {
|
||||
const oldCounter = createCountTokensCounter();
|
||||
const newCounter = createCountTokensCounter();
|
||||
const text = createRealisticText(50000);
|
||||
const tokenLimit = 10000;
|
||||
const text = createRealisticText(15000);
|
||||
const tokenLimit = 5000;
|
||||
|
||||
await processTextWithTokenLimitOLD({
|
||||
text,
|
||||
|
|
@ -829,8 +815,6 @@ describe('processTextWithTokenLimit', () => {
|
|||
tokenCountFn: oldCounter.tokenCountFn,
|
||||
});
|
||||
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
await processTextWithTokenLimit({
|
||||
text,
|
||||
tokenLimit,
|
||||
|
|
@ -842,7 +826,7 @@ describe('processTextWithTokenLimit', () => {
|
|||
const reduction = 1 - newCalls / oldCalls;
|
||||
|
||||
console.log(
|
||||
`[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
|
||||
`[countTokens 15k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
|
||||
);
|
||||
|
||||
expect(reduction).toBeGreaterThanOrEqual(0.7);
|
||||
|
|
|
|||
|
|
@ -1,12 +1,3 @@
|
|||
/**
|
||||
* @file Tokenizer.spec.cjs
|
||||
*
|
||||
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
|
||||
* Make sure to install `tiktoken` and have it configured properly.
|
||||
*/
|
||||
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { Tiktoken } from 'tiktoken';
|
||||
import Tokenizer from './tokenizer';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
|
|
@ -17,127 +8,49 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
|
||||
describe('Tokenizer', () => {
|
||||
it('should be a singleton (same instance)', async () => {
|
||||
const AnotherTokenizer = await import('./tokenizer'); // same path
|
||||
const AnotherTokenizer = await import('./tokenizer');
|
||||
expect(Tokenizer).toBe(AnotherTokenizer.default);
|
||||
});
|
||||
|
||||
describe('getTokenizer', () => {
|
||||
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
|
||||
// The real `encoding_for_model` will be called internally
|
||||
// as soon as we pass isModelName = true.
|
||||
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
|
||||
|
||||
// Basic sanity checks
|
||||
expect(tokenizer).toBeDefined();
|
||||
// You can optionally check certain properties from `tiktoken` if they exist
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
describe('initEncoding', () => {
|
||||
it('should load o200k_base encoding', async () => {
|
||||
await Tokenizer.initEncoding('o200k_base');
|
||||
const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
|
||||
// The real `get_encoding` will be called internally
|
||||
// as soon as we pass isModelName = false.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
|
||||
expect(tokenizer).toBeDefined();
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
it('should load claude encoding', async () => {
|
||||
await Tokenizer.initEncoding('claude');
|
||||
const count = Tokenizer.getTokenCount('Hello, world!', 'claude');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should return cached tokenizer if previously fetched', () => {
|
||||
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
// Should be the exact same instance from the cache
|
||||
expect(tokenizer1).toBe(tokenizer2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('freeAndResetAllEncoders', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
|
||||
// By creating two different encodings, we populate the cache
|
||||
Tokenizer.getTokenizer('cl100k_base', false);
|
||||
Tokenizer.getTokenizer('r50k_base', false);
|
||||
|
||||
// Now free them
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// The internal cache is cleared
|
||||
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
|
||||
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
|
||||
|
||||
// tokenizerCallsCount is reset to 1
|
||||
expect(Tokenizer.tokenizerCallsCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should catch and log errors if freeing fails', () => {
|
||||
// Mock logger.error before the test
|
||||
const mockLoggerError = jest.spyOn(logger, 'error');
|
||||
|
||||
// Set up a problematic tokenizer in the cache
|
||||
Tokenizer.tokenizersCache['cl100k_base'] = {
|
||||
free() {
|
||||
throw new Error('Intentional free error');
|
||||
},
|
||||
} as unknown as Tiktoken;
|
||||
|
||||
// Should not throw uncaught errors
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// Verify logger.error was called with correct arguments
|
||||
expect(mockLoggerError).toHaveBeenCalledWith(
|
||||
'[Tokenizer] Free and reset encoders error',
|
||||
expect.any(Error),
|
||||
);
|
||||
|
||||
// Clean up
|
||||
mockLoggerError.mockRestore();
|
||||
Tokenizer.tokenizersCache = {};
|
||||
it('should deduplicate concurrent init calls', async () => {
|
||||
const [, , count] = await Promise.all([
|
||||
Tokenizer.initEncoding('o200k_base'),
|
||||
Tokenizer.initEncoding('o200k_base'),
|
||||
Tokenizer.initEncoding('o200k_base').then(() =>
|
||||
Tokenizer.getTokenCount('test', 'o200k_base'),
|
||||
),
|
||||
]);
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCount', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
beforeAll(async () => {
|
||||
await Tokenizer.initEncoding('o200k_base');
|
||||
await Tokenizer.initEncoding('claude');
|
||||
});
|
||||
|
||||
it('should return the number of tokens in the given text', () => {
|
||||
const text = 'Hello, world!';
|
||||
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
|
||||
const count = Tokenizer.getTokenCount('Hello, world!', 'o200k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should reset encoders if an error is thrown', () => {
|
||||
// We can simulate an error by temporarily overriding the selected tokenizer's `encode` method.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const originalEncode = tokenizer.encode;
|
||||
tokenizer.encode = () => {
|
||||
throw new Error('Forced error');
|
||||
};
|
||||
|
||||
// Despite the forced error, the code should catch and reset, then re-encode
|
||||
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
|
||||
it('should count tokens using claude encoding', () => {
|
||||
const count = Tokenizer.getTokenCount('Hello, world!', 'claude');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
|
||||
// Restore the original encode
|
||||
tokenizer.encode = originalEncode;
|
||||
});
|
||||
|
||||
it('should reset tokenizers after 25 calls', () => {
|
||||
// Spy on freeAndResetAllEncoders
|
||||
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
|
||||
|
||||
// Make 24 calls; should NOT reset yet
|
||||
for (let i = 0; i < 24; i++) {
|
||||
Tokenizer.getTokenCount('test text', 'cl100k_base');
|
||||
}
|
||||
expect(resetSpy).not.toHaveBeenCalled();
|
||||
|
||||
// 25th call triggers the reset
|
||||
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
|
||||
expect(resetSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,74 +1,46 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken';
|
||||
import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken';
|
||||
import { Tokenizer as AiTokenizer } from 'ai-tokenizer';
|
||||
|
||||
interface TokenizerOptions {
|
||||
debug?: boolean;
|
||||
}
|
||||
export type EncodingName = 'o200k_base' | 'claude';
|
||||
|
||||
type EncodingData = ConstructorParameters<typeof AiTokenizer>[0];
|
||||
|
||||
class Tokenizer {
|
||||
tokenizersCache: Record<string, Tiktoken>;
|
||||
tokenizerCallsCount: number;
|
||||
private options?: TokenizerOptions;
|
||||
private tokenizersCache: Partial<Record<EncodingName, AiTokenizer>> = {};
|
||||
private loadingPromises: Partial<Record<EncodingName, Promise<void>>> = {};
|
||||
|
||||
constructor() {
|
||||
this.tokenizersCache = {};
|
||||
this.tokenizerCallsCount = 0;
|
||||
}
|
||||
|
||||
getTokenizer(
|
||||
encoding: TiktokenModel | TiktokenEncoding,
|
||||
isModelName = false,
|
||||
extendSpecialTokens: Record<string, number> = {},
|
||||
): Tiktoken {
|
||||
let tokenizer: Tiktoken;
|
||||
/** Pre-loads an encoding so that subsequent getTokenCount calls are accurate. */
|
||||
async initEncoding(encoding: EncodingName): Promise<void> {
|
||||
if (this.tokenizersCache[encoding]) {
|
||||
tokenizer = this.tokenizersCache[encoding];
|
||||
} else {
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens);
|
||||
}
|
||||
this.tokenizersCache[encoding] = tokenizer;
|
||||
return;
|
||||
}
|
||||
return tokenizer;
|
||||
if (this.loadingPromises[encoding]) {
|
||||
return this.loadingPromises[encoding];
|
||||
}
|
||||
this.loadingPromises[encoding] = (async () => {
|
||||
const data: EncodingData =
|
||||
encoding === 'claude'
|
||||
? await import('ai-tokenizer/encoding/claude')
|
||||
: await import('ai-tokenizer/encoding/o200k_base');
|
||||
this.tokenizersCache[encoding] = new AiTokenizer(data);
|
||||
})();
|
||||
return this.loadingPromises[encoding];
|
||||
}
|
||||
|
||||
freeAndResetAllEncoders(): void {
|
||||
getTokenCount(text: string, encoding: EncodingName = 'o200k_base'): number {
|
||||
const tokenizer = this.tokenizersCache[encoding];
|
||||
if (!tokenizer) {
|
||||
this.initEncoding(encoding);
|
||||
return Math.ceil(text.length / 4);
|
||||
}
|
||||
try {
|
||||
Object.keys(this.tokenizersCache).forEach((key) => {
|
||||
if (this.tokenizersCache[key]) {
|
||||
this.tokenizersCache[key].free();
|
||||
delete this.tokenizersCache[key];
|
||||
}
|
||||
});
|
||||
this.tokenizerCallsCount = 1;
|
||||
} catch (error) {
|
||||
logger.error('[Tokenizer] Free and reset encoders error', error);
|
||||
}
|
||||
}
|
||||
|
||||
resetTokenizersIfNecessary(): void {
|
||||
if (this.tokenizerCallsCount >= 25) {
|
||||
if (this.options?.debug) {
|
||||
logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
|
||||
}
|
||||
this.freeAndResetAllEncoders();
|
||||
}
|
||||
this.tokenizerCallsCount++;
|
||||
}
|
||||
|
||||
getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number {
|
||||
this.resetTokenizersIfNecessary();
|
||||
try {
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
return tokenizer.count(text);
|
||||
} catch (error) {
|
||||
logger.error('[Tokenizer] Error getting token count:', error);
|
||||
this.freeAndResetAllEncoders();
|
||||
const tokenizer = this.getTokenizer(encoding);
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
delete this.tokenizersCache[encoding];
|
||||
delete this.loadingPromises[encoding];
|
||||
this.initEncoding(encoding);
|
||||
return Math.ceil(text.length / 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -76,13 +48,13 @@ class Tokenizer {
|
|||
const TokenizerSingleton = new Tokenizer();
|
||||
|
||||
/**
|
||||
* Counts the number of tokens in a given text using tiktoken.
|
||||
* This is an async wrapper around Tokenizer.getTokenCount for compatibility.
|
||||
* @param text - The text to be tokenized. Defaults to an empty string if not provided.
|
||||
* Counts the number of tokens in a given text using ai-tokenizer with o200k_base encoding.
|
||||
* @param text - The text to count tokens in. Defaults to an empty string.
|
||||
* @returns The number of tokens in the provided text.
|
||||
*/
|
||||
export async function countTokens(text = ''): Promise<number> {
|
||||
return TokenizerSingleton.getTokenCount(text, 'cl100k_base');
|
||||
await TokenizerSingleton.initEncoding('o200k_base');
|
||||
return TokenizerSingleton.getTokenCount(text, 'o200k_base');
|
||||
}
|
||||
|
||||
export default TokenizerSingleton;
|
||||
|
|
|
|||
|
|
@ -593,42 +593,3 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
|
|||
|
||||
return tokenConfig;
|
||||
}
|
||||
|
||||
export const tiktokenModels = new Set([
|
||||
'text-davinci-003',
|
||||
'text-davinci-002',
|
||||
'text-davinci-001',
|
||||
'text-curie-001',
|
||||
'text-babbage-001',
|
||||
'text-ada-001',
|
||||
'davinci',
|
||||
'curie',
|
||||
'babbage',
|
||||
'ada',
|
||||
'code-davinci-002',
|
||||
'code-davinci-001',
|
||||
'code-cushman-002',
|
||||
'code-cushman-001',
|
||||
'davinci-codex',
|
||||
'cushman-codex',
|
||||
'text-davinci-edit-001',
|
||||
'code-davinci-edit-001',
|
||||
'text-embedding-ada-002',
|
||||
'text-similarity-davinci-001',
|
||||
'text-similarity-curie-001',
|
||||
'text-similarity-babbage-001',
|
||||
'text-similarity-ada-001',
|
||||
'text-search-davinci-doc-001',
|
||||
'text-search-curie-doc-001',
|
||||
'text-search-babbage-doc-001',
|
||||
'text-search-ada-doc-001',
|
||||
'code-search-babbage-code-001',
|
||||
'code-search-ada-code-001',
|
||||
'gpt2',
|
||||
'gpt-4',
|
||||
'gpt-4-0314',
|
||||
'gpt-4-32k',
|
||||
'gpt-4-32k-0314',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-0301',
|
||||
]);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue