mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 06:00:56 +02:00
💫 feat: MCP OAuth Auto-Reconnect (#9646)
* add oauth reconnect tracker * add connection tracker to mcp manager * reconnect oauth mcp servers function * call reconnection in auth controller * make sure to check connection in panel * wait for isConnected * add const for poll interval * add logging to tryReconnect * check expiration * check mcp manager is not null * check mcp manager is not null * add test for reconnecting mcp server * unify logic inside OAuthReconnectionManager * test reconnection manager, adjust * chore: reorder import statements in index.js * chore: imports * chore: imports * chore: imports * chore: imports * chore: imports * chore: imports and use types explicitly --------- Co-authored-by: Danny Avila <danny@librechat.ai>
This commit is contained in:
parent
0e94d97bfb
commit
d04da60b3b
13 changed files with 830 additions and 13 deletions
|
@ -1,6 +1,6 @@
|
||||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
|
||||||
const { EventSource } = require('eventsource');
|
const { EventSource } = require('eventsource');
|
||||||
const { Time } = require('librechat-data-provider');
|
const { Time } = require('librechat-data-provider');
|
||||||
|
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
|
||||||
const logger = require('./winston');
|
const logger = require('./winston');
|
||||||
|
|
||||||
global.EventSource = EventSource;
|
global.EventSource = EventSource;
|
||||||
|
@ -26,4 +26,6 @@ module.exports = {
|
||||||
createMCPManager: MCPManager.createInstance,
|
createMCPManager: MCPManager.createInstance,
|
||||||
getMCPManager: MCPManager.getInstance,
|
getMCPManager: MCPManager.getInstance,
|
||||||
getFlowStateManager,
|
getFlowStateManager,
|
||||||
|
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
|
||||||
|
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
|
||||||
};
|
};
|
||||||
|
|
|
@ -11,8 +11,9 @@ const {
|
||||||
registerUser,
|
registerUser,
|
||||||
} = require('~/server/services/AuthService');
|
} = require('~/server/services/AuthService');
|
||||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||||
const { getOpenIdConfig } = require('~/strategies');
|
|
||||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||||
|
const { getOAuthReconnectionManager } = require('~/config');
|
||||||
|
const { getOpenIdConfig } = require('~/strategies');
|
||||||
|
|
||||||
const registrationController = async (req, res) => {
|
const registrationController = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
|
@ -107,6 +108,14 @@ const refreshController = async (req, res) => {
|
||||||
|
|
||||||
if (session && session.expiration > new Date()) {
|
if (session && session.expiration > new Date()) {
|
||||||
const token = await setAuthTokens(userId, res, session);
|
const token = await setAuthTokens(userId, res, session);
|
||||||
|
|
||||||
|
// trigger OAuth MCP server reconnection asynchronously (best effort)
|
||||||
|
void getOAuthReconnectionManager()
|
||||||
|
.reconnectServers(userId)
|
||||||
|
.catch((err) => {
|
||||||
|
logger.error('Error reconnecting OAuth MCP servers:', err);
|
||||||
|
});
|
||||||
|
|
||||||
res.status(200).send({ token, user });
|
res.status(200).send({ token, user });
|
||||||
} else if (req?.query?.retry) {
|
} else if (req?.query?.retry) {
|
||||||
// Retrying from a refresh token request that failed (401)
|
// Retrying from a refresh token request that failed (401)
|
||||||
|
|
|
@ -12,6 +12,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||||
const mongoSanitize = require('express-mongo-sanitize');
|
const mongoSanitize = require('express-mongo-sanitize');
|
||||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||||
const { connectDb, indexSync } = require('~/db');
|
const { connectDb, indexSync } = require('~/db');
|
||||||
|
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||||
const { updateInterfacePermissions } = require('~/models/interface');
|
const { updateInterfacePermissions } = require('~/models/interface');
|
||||||
|
@ -154,7 +155,7 @@ const startServer = async () => {
|
||||||
res.send(updatedIndexHtml);
|
res.send(updatedIndexHtml);
|
||||||
});
|
});
|
||||||
|
|
||||||
app.listen(port, host, () => {
|
app.listen(port, host, async () => {
|
||||||
if (host === '0.0.0.0') {
|
if (host === '0.0.0.0') {
|
||||||
logger.info(
|
logger.info(
|
||||||
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
|
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
|
||||||
|
@ -163,7 +164,9 @@ const startServer = async () => {
|
||||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
initializeMCPs().then(() => checkMigrations());
|
await initializeMCPs();
|
||||||
|
await initializeOAuthReconnectManager();
|
||||||
|
await checkMigrations();
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
const { Router } = require('express');
|
const { Router } = require('express');
|
||||||
const { logger } = require('@librechat/data-schemas');
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||||
|
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
|
||||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
|
||||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||||
const { requireJwtAuth } = require('~/server/middleware');
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
const { findPluginAuthsByKeys } = require('~/models');
|
const { findPluginAuthsByKeys } = require('~/models');
|
||||||
|
@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// clear any reconnection attempts
|
||||||
|
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||||
|
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
|
||||||
|
|
||||||
const tools = await userConnection.fetchTools();
|
const tools = await userConnection.fetchTools();
|
||||||
await updateMCPUserTools({
|
await updateMCPUserTools({
|
||||||
userId: flowState.userId,
|
userId: flowState.userId,
|
||||||
|
|
|
@ -20,8 +20,8 @@ const {
|
||||||
ContentTypes,
|
ContentTypes,
|
||||||
isAssistantsEndpoint,
|
isAssistantsEndpoint,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
|
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||||
const { findToken, createToken, updateToken } = require('~/models');
|
const { findToken, createToken, updateToken } = require('~/models');
|
||||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
|
||||||
const { getCachedTools, getAppConfig } = require('./Config');
|
const { getCachedTools, getAppConfig } = require('./Config');
|
||||||
const { reinitMCPServer } = require('./Tools/mcp');
|
const { reinitMCPServer } = require('./Tools/mcp');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
|
@ -538,13 +538,20 @@ async function getServerConnectionStatus(
|
||||||
const baseConnectionState = getConnectionState();
|
const baseConnectionState = getConnectionState();
|
||||||
let finalConnectionState = baseConnectionState;
|
let finalConnectionState = baseConnectionState;
|
||||||
|
|
||||||
|
// connection state overrides specific to OAuth servers
|
||||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
// check if server is actively being reconnected
|
||||||
|
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||||
if (hasFailedFlow) {
|
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
|
||||||
finalConnectionState = 'error';
|
|
||||||
} else if (hasActiveFlow) {
|
|
||||||
finalConnectionState = 'connecting';
|
finalConnectionState = 'connecting';
|
||||||
|
} else {
|
||||||
|
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||||
|
|
||||||
|
if (hasFailedFlow) {
|
||||||
|
finalConnectionState = 'error';
|
||||||
|
} else if (hasActiveFlow) {
|
||||||
|
finalConnectionState = 'connecting';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
|
||||||
jest.mock('~/config', () => ({
|
jest.mock('~/config', () => ({
|
||||||
getMCPManager: jest.fn(),
|
getMCPManager: jest.fn(),
|
||||||
getFlowStateManager: jest.fn(),
|
getFlowStateManager: jest.fn(),
|
||||||
|
getOAuthReconnectionManager: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('~/cache', () => ({
|
jest.mock('~/cache', () => ({
|
||||||
|
@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
let mockGetMCPManager;
|
let mockGetMCPManager;
|
||||||
let mockGetFlowStateManager;
|
let mockGetFlowStateManager;
|
||||||
let mockGetLogStores;
|
let mockGetLogStores;
|
||||||
|
let mockGetOAuthReconnectionManager;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks();
|
||||||
|
@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
mockGetMCPManager = require('~/config').getMCPManager;
|
mockGetMCPManager = require('~/config').getMCPManager;
|
||||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||||
mockGetLogStores = require('~/cache').getLogStores;
|
mockGetLogStores = require('~/cache').getLogStores;
|
||||||
|
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getMCPSetupData', () => {
|
describe('getMCPSetupData', () => {
|
||||||
|
@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
const userConnections = new Map();
|
const userConnections = new Map();
|
||||||
const oauthServers = new Set([mockServerName]);
|
const oauthServers = new Set([mockServerName]);
|
||||||
|
|
||||||
|
// Mock OAuthReconnectionManager
|
||||||
|
const mockOAuthReconnectionManager = {
|
||||||
|
isReconnecting: jest.fn(() => false),
|
||||||
|
};
|
||||||
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||||
|
|
||||||
const result = await getServerConnectionStatus(
|
const result = await getServerConnectionStatus(
|
||||||
mockUserId,
|
mockUserId,
|
||||||
mockServerName,
|
mockServerName,
|
||||||
|
@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
const userConnections = new Map();
|
const userConnections = new Map();
|
||||||
const oauthServers = new Set([mockServerName]);
|
const oauthServers = new Set([mockServerName]);
|
||||||
|
|
||||||
|
// Mock OAuthReconnectionManager
|
||||||
|
const mockOAuthReconnectionManager = {
|
||||||
|
isReconnecting: jest.fn(() => false),
|
||||||
|
};
|
||||||
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||||
|
|
||||||
// Mock flow state to return failed flow
|
// Mock flow state to return failed flow
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
getFlowState: jest.fn(() => ({
|
getFlowState: jest.fn(() => ({
|
||||||
|
@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
const userConnections = new Map();
|
const userConnections = new Map();
|
||||||
const oauthServers = new Set([mockServerName]);
|
const oauthServers = new Set([mockServerName]);
|
||||||
|
|
||||||
|
// Mock OAuthReconnectionManager
|
||||||
|
const mockOAuthReconnectionManager = {
|
||||||
|
isReconnecting: jest.fn(() => false),
|
||||||
|
};
|
||||||
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||||
|
|
||||||
// Mock flow state to return active flow
|
// Mock flow state to return active flow
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
getFlowState: jest.fn(() => ({
|
getFlowState: jest.fn(() => ({
|
||||||
|
@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
const userConnections = new Map();
|
const userConnections = new Map();
|
||||||
const oauthServers = new Set([mockServerName]);
|
const oauthServers = new Set([mockServerName]);
|
||||||
|
|
||||||
|
// Mock OAuthReconnectionManager
|
||||||
|
const mockOAuthReconnectionManager = {
|
||||||
|
isReconnecting: jest.fn(() => false),
|
||||||
|
};
|
||||||
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||||
|
|
||||||
// Mock flow state to return no flow
|
// Mock flow state to return no flow
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
getFlowState: jest.fn(() => null),
|
getFlowState: jest.fn(() => null),
|
||||||
|
@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return connecting state when OAuth server is reconnecting', async () => {
|
||||||
|
const appConnections = new Map();
|
||||||
|
const userConnections = new Map();
|
||||||
|
const oauthServers = new Set([mockServerName]);
|
||||||
|
|
||||||
|
// Mock OAuthReconnectionManager to return true for isReconnecting
|
||||||
|
const mockOAuthReconnectionManager = {
|
||||||
|
isReconnecting: jest.fn(() => true),
|
||||||
|
};
|
||||||
|
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||||
|
|
||||||
|
const result = await getServerConnectionStatus(
|
||||||
|
mockUserId,
|
||||||
|
mockServerName,
|
||||||
|
appConnections,
|
||||||
|
userConnections,
|
||||||
|
oauthServers,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
requiresOAuth: true,
|
||||||
|
connectionState: 'connecting',
|
||||||
|
});
|
||||||
|
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
|
||||||
|
mockUserId,
|
||||||
|
mockServerName,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should not check OAuth flow status when server is connected', async () => {
|
it('should not check OAuth flow status when server is connected', async () => {
|
||||||
const mockFlowManager = {
|
const mockFlowManager = {
|
||||||
getFlowState: jest.fn(),
|
getFlowState: jest.fn(),
|
||||||
|
|
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
const { logger } = require('@librechat/data-schemas');
|
||||||
|
const { CacheKeys } = require('librechat-data-provider');
|
||||||
|
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
|
||||||
|
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||||
|
const { getLogStores } = require('~/cache');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize OAuth reconnect manager
|
||||||
|
*/
|
||||||
|
async function initializeOAuthReconnectManager() {
|
||||||
|
try {
|
||||||
|
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||||
|
const tokenMethods = {
|
||||||
|
findToken,
|
||||||
|
updateToken,
|
||||||
|
createToken,
|
||||||
|
deleteTokens,
|
||||||
|
};
|
||||||
|
await createOAuthReconnectionManager(flowManager, tokenMethods);
|
||||||
|
logger.info(`OAuth reconnect manager initialized successfully.`);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to initialize OAuth reconnect manager:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = initializeOAuthReconnectManager;
|
|
@ -1,4 +1,4 @@
|
||||||
import React, { useState, useMemo, useCallback } from 'react';
|
import React, { useState, useMemo, useCallback, useEffect } from 'react';
|
||||||
import { ChevronLeft, Trash2 } from 'lucide-react';
|
import { ChevronLeft, Trash2 } from 'lucide-react';
|
||||||
import { useQueryClient } from '@tanstack/react-query';
|
import { useQueryClient } from '@tanstack/react-query';
|
||||||
import { Button, useToastContext } from '@librechat/client';
|
import { Button, useToastContext } from '@librechat/client';
|
||||||
|
@ -12,6 +12,8 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks';
|
||||||
import { useGetStartupConfig } from '~/data-provider';
|
import { useGetStartupConfig } from '~/data-provider';
|
||||||
import MCPPanelSkeleton from './MCPPanelSkeleton';
|
import MCPPanelSkeleton from './MCPPanelSkeleton';
|
||||||
|
|
||||||
|
const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms
|
||||||
|
|
||||||
function MCPPanelContent() {
|
function MCPPanelContent() {
|
||||||
const localize = useLocalize();
|
const localize = useLocalize();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
|
@ -26,6 +28,29 @@ function MCPPanelContent() {
|
||||||
null,
|
null,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Check if any connections are in 'connecting' state
|
||||||
|
const hasConnectingServers = useMemo(() => {
|
||||||
|
if (!connectionStatus) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return Object.values(connectionStatus).some(
|
||||||
|
(status) => status?.connectionState === 'connecting',
|
||||||
|
);
|
||||||
|
}, [connectionStatus]);
|
||||||
|
|
||||||
|
// Set up polling when servers are connecting
|
||||||
|
useEffect(() => {
|
||||||
|
if (!hasConnectingServers) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const intervalId = setInterval(() => {
|
||||||
|
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
|
||||||
|
}, POLL_FOR_CONNECTION_STATUS_INTERVAL);
|
||||||
|
|
||||||
|
return () => clearInterval(intervalId);
|
||||||
|
}, [hasConnectingServers, queryClient]);
|
||||||
|
|
||||||
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
|
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
|
||||||
onSuccess: async () => {
|
onSuccess: async () => {
|
||||||
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });
|
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });
|
||||||
|
|
|
@ -14,6 +14,7 @@ export * from './utils';
|
||||||
export * from './db/utils';
|
export * from './db/utils';
|
||||||
/* OAuth */
|
/* OAuth */
|
||||||
export * from './oauth';
|
export * from './oauth';
|
||||||
|
export * from './mcp/oauth/OAuthReconnectionManager';
|
||||||
/* Crypto */
|
/* Crypto */
|
||||||
export * from './crypto';
|
export * from './crypto';
|
||||||
/* Flow */
|
/* Flow */
|
||||||
|
|
294
packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
Normal file
294
packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
Normal file
|
@ -0,0 +1,294 @@
|
||||||
|
import { TokenMethods } from '@librechat/data-schemas';
|
||||||
|
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
|
||||||
|
import { MCPManager } from '../MCPManager';
|
||||||
|
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
|
||||||
|
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||||
|
|
||||||
|
jest.mock('@librechat/data-schemas', () => ({
|
||||||
|
logger: {
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
debug: jest.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('../MCPManager');
|
||||||
|
|
||||||
|
describe('OAuthReconnectionManager', () => {
|
||||||
|
let flowManager: jest.Mocked<FlowStateManager<null>>;
|
||||||
|
let tokenMethods: jest.Mocked<TokenMethods>;
|
||||||
|
let mockMCPManager: jest.Mocked<MCPManager>;
|
||||||
|
let reconnectionManager: OAuthReconnectionManager;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
// Reset singleton instance
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
(OAuthReconnectionManager as any).instance = null;
|
||||||
|
|
||||||
|
// Setup mock flow manager
|
||||||
|
flowManager = {
|
||||||
|
createFlow: jest.fn(),
|
||||||
|
completeFlow: jest.fn(),
|
||||||
|
failFlow: jest.fn(),
|
||||||
|
deleteFlow: jest.fn(),
|
||||||
|
getFlow: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<FlowStateManager<null>>;
|
||||||
|
|
||||||
|
// Setup mock token methods
|
||||||
|
tokenMethods = {
|
||||||
|
findToken: jest.fn(),
|
||||||
|
createToken: jest.fn(),
|
||||||
|
updateToken: jest.fn(),
|
||||||
|
deleteToken: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<TokenMethods>;
|
||||||
|
|
||||||
|
// Setup mock MCP Manager
|
||||||
|
mockMCPManager = {
|
||||||
|
getOAuthServers: jest.fn(),
|
||||||
|
getUserConnection: jest.fn(),
|
||||||
|
getUserConnections: jest.fn(),
|
||||||
|
disconnectUserConnection: jest.fn(),
|
||||||
|
getRawConfig: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<MCPManager>;
|
||||||
|
|
||||||
|
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Singleton Pattern', () => {
|
||||||
|
it('should create instance successfully', async () => {
|
||||||
|
const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
|
||||||
|
expect(instance).toBeInstanceOf(OAuthReconnectionManager);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw error when creating instance twice', async () => {
|
||||||
|
await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
|
||||||
|
await expect(
|
||||||
|
OAuthReconnectionManager.createInstance(flowManager, tokenMethods),
|
||||||
|
).rejects.toThrow('OAuthReconnectionManager already initialized');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw error when getting instance before creation', () => {
|
||||||
|
expect(() => OAuthReconnectionManager.getInstance()).toThrow(
|
||||||
|
'OAuthReconnectionManager not initialized',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('isReconnecting', () => {
|
||||||
|
let reconnectionTracker: OAuthReconnectionTracker;
|
||||||
|
beforeEach(async () => {
|
||||||
|
reconnectionTracker = new OAuthReconnectionTracker();
|
||||||
|
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||||
|
flowManager,
|
||||||
|
tokenMethods,
|
||||||
|
reconnectionTracker,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return true when server is actively reconnecting', () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const serverName = 'test-server';
|
||||||
|
|
||||||
|
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
|
||||||
|
|
||||||
|
reconnectionTracker.setActive(userId, serverName);
|
||||||
|
const result = reconnectionManager.isReconnecting(userId, serverName);
|
||||||
|
expect(result).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false when server is not reconnecting', () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const serverName = 'test-server';
|
||||||
|
|
||||||
|
const result = reconnectionManager.isReconnecting(userId, serverName);
|
||||||
|
expect(result).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('clearReconnection', () => {
|
||||||
|
let reconnectionTracker: OAuthReconnectionTracker;
|
||||||
|
beforeEach(async () => {
|
||||||
|
reconnectionTracker = new OAuthReconnectionTracker();
|
||||||
|
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||||
|
flowManager,
|
||||||
|
tokenMethods,
|
||||||
|
reconnectionTracker,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should clear both failed and active reconnection states', () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const serverName = 'test-server';
|
||||||
|
|
||||||
|
reconnectionTracker.setFailed(userId, serverName);
|
||||||
|
reconnectionTracker.setActive(userId, serverName);
|
||||||
|
|
||||||
|
reconnectionManager.clearReconnection(userId, serverName);
|
||||||
|
|
||||||
|
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
|
||||||
|
expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('reconnectServers', () => {
|
||||||
|
let reconnectionTracker: OAuthReconnectionTracker;
|
||||||
|
beforeEach(async () => {
|
||||||
|
reconnectionTracker = new OAuthReconnectionTracker();
|
||||||
|
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||||
|
flowManager,
|
||||||
|
tokenMethods,
|
||||||
|
reconnectionTracker,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reconnect eligible servers', async () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||||
|
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||||
|
|
||||||
|
// server1: has failed reconnection
|
||||||
|
reconnectionTracker.setFailed(userId, 'server1');
|
||||||
|
|
||||||
|
// server2: already connected
|
||||||
|
const mockConnection = {
|
||||||
|
isConnected: jest.fn().mockResolvedValue(true),
|
||||||
|
};
|
||||||
|
const userConnections = new Map([['server2', mockConnection]]);
|
||||||
|
mockMCPManager.getUserConnections.mockReturnValue(
|
||||||
|
userConnections as unknown as Map<string, MCPConnection>,
|
||||||
|
);
|
||||||
|
|
||||||
|
// server3: has valid token and not connected
|
||||||
|
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||||
|
if (identifier === 'mcp:server3') {
|
||||||
|
return {
|
||||||
|
userId,
|
||||||
|
identifier,
|
||||||
|
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
|
||||||
|
} as unknown as MCPOAuthTokens;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock successful reconnection
|
||||||
|
const mockNewConnection = {
|
||||||
|
isConnected: jest.fn().mockResolvedValue(true),
|
||||||
|
disconnect: jest.fn(),
|
||||||
|
};
|
||||||
|
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||||
|
mockNewConnection as unknown as MCPConnection,
|
||||||
|
);
|
||||||
|
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
|
||||||
|
|
||||||
|
await reconnectionManager.reconnectServers(userId);
|
||||||
|
|
||||||
|
// Verify server3 was marked as active
|
||||||
|
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true);
|
||||||
|
|
||||||
|
// Wait for async tryReconnect to complete
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||||
|
|
||||||
|
// Verify reconnection was attempted for server3
|
||||||
|
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({
|
||||||
|
serverName: 'server3',
|
||||||
|
user: { id: userId },
|
||||||
|
flowManager,
|
||||||
|
tokenMethods,
|
||||||
|
forceNew: false,
|
||||||
|
connectionTimeout: 5000,
|
||||||
|
returnOnOAuth: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify successful reconnection cleared the states
|
||||||
|
expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false);
|
||||||
|
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle failed reconnection attempts', async () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const oauthServers = new Set(['server1']);
|
||||||
|
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||||
|
|
||||||
|
// server1: has valid token
|
||||||
|
tokenMethods.findToken.mockResolvedValue({
|
||||||
|
userId,
|
||||||
|
identifier: 'mcp:server1',
|
||||||
|
expiresAt: new Date(Date.now() + 3600000),
|
||||||
|
} as unknown as MCPOAuthTokens);
|
||||||
|
|
||||||
|
// Mock failed connection
|
||||||
|
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||||
|
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||||
|
|
||||||
|
await reconnectionManager.reconnectServers(userId);
|
||||||
|
|
||||||
|
// Wait for async tryReconnect to complete
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||||
|
|
||||||
|
// Verify failure handling
|
||||||
|
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
|
||||||
|
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||||
|
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not reconnect servers with expired tokens', async () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const oauthServers = new Set(['server1']);
|
||||||
|
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||||
|
|
||||||
|
// server1: has expired token
|
||||||
|
tokenMethods.findToken.mockResolvedValue({
|
||||||
|
userId,
|
||||||
|
identifier: 'mcp:server1',
|
||||||
|
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
|
||||||
|
} as unknown as MCPOAuthTokens);
|
||||||
|
|
||||||
|
await reconnectionManager.reconnectServers(userId);
|
||||||
|
|
||||||
|
// Verify no reconnection attempt was made
|
||||||
|
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||||
|
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle connection that returns but is not connected', async () => {
|
||||||
|
const userId = 'user-123';
|
||||||
|
const oauthServers = new Set(['server1']);
|
||||||
|
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||||
|
|
||||||
|
tokenMethods.findToken.mockResolvedValue({
|
||||||
|
userId,
|
||||||
|
identifier: 'mcp:server1',
|
||||||
|
expiresAt: new Date(Date.now() + 3600000),
|
||||||
|
} as unknown as MCPOAuthTokens);
|
||||||
|
|
||||||
|
// Mock connection that returns but is not connected
|
||||||
|
const mockConnection = {
|
||||||
|
isConnected: jest.fn().mockResolvedValue(false),
|
||||||
|
disconnect: jest.fn(),
|
||||||
|
};
|
||||||
|
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||||
|
mockConnection as unknown as MCPConnection,
|
||||||
|
);
|
||||||
|
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||||
|
|
||||||
|
await reconnectionManager.reconnectServers(userId);
|
||||||
|
|
||||||
|
// Wait for async tryReconnect to complete
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||||
|
|
||||||
|
// Verify failure handling
|
||||||
|
expect(mockConnection.disconnect).toHaveBeenCalled();
|
||||||
|
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
|
||||||
|
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||||
|
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
163
packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
Normal file
163
packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
import { logger } from '@librechat/data-schemas';
|
||||||
|
import type { TokenMethods } from '@librechat/data-schemas';
|
||||||
|
import type { TUser } from 'librechat-data-provider';
|
||||||
|
import type { MCPOAuthTokens } from './types';
|
||||||
|
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||||
|
import { FlowStateManager } from '~/flow/manager';
|
||||||
|
import { MCPManager } from '~/mcp/MCPManager';
|
||||||
|
|
||||||
|
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||||
|
|
||||||
|
export class OAuthReconnectionManager {
|
||||||
|
private static instance: OAuthReconnectionManager | null = null;
|
||||||
|
|
||||||
|
protected readonly flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||||
|
protected readonly tokenMethods: TokenMethods;
|
||||||
|
|
||||||
|
private readonly reconnectionsTracker: OAuthReconnectionTracker;
|
||||||
|
|
||||||
|
public static getInstance(): OAuthReconnectionManager {
|
||||||
|
if (!OAuthReconnectionManager.instance) {
|
||||||
|
throw new Error('OAuthReconnectionManager not initialized');
|
||||||
|
}
|
||||||
|
return OAuthReconnectionManager.instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static async createInstance(
|
||||||
|
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||||
|
tokenMethods: TokenMethods,
|
||||||
|
reconnections?: OAuthReconnectionTracker,
|
||||||
|
): Promise<OAuthReconnectionManager> {
|
||||||
|
if (OAuthReconnectionManager.instance != null) {
|
||||||
|
throw new Error('OAuthReconnectionManager already initialized');
|
||||||
|
}
|
||||||
|
|
||||||
|
const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections);
|
||||||
|
OAuthReconnectionManager.instance = manager;
|
||||||
|
|
||||||
|
return manager;
|
||||||
|
}
|
||||||
|
|
||||||
|
public constructor(
|
||||||
|
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||||
|
tokenMethods: TokenMethods,
|
||||||
|
reconnections?: OAuthReconnectionTracker,
|
||||||
|
) {
|
||||||
|
this.flowManager = flowManager;
|
||||||
|
this.tokenMethods = tokenMethods;
|
||||||
|
this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker();
|
||||||
|
}
|
||||||
|
|
||||||
|
public isReconnecting(userId: string, serverName: string): boolean {
|
||||||
|
return this.reconnectionsTracker.isActive(userId, serverName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async reconnectServers(userId: string) {
|
||||||
|
const mcpManager = MCPManager.getInstance();
|
||||||
|
|
||||||
|
// 1. derive the servers to reconnect
|
||||||
|
const serversToReconnect = [];
|
||||||
|
for (const serverName of mcpManager.getOAuthServers() ?? []) {
|
||||||
|
const canReconnect = await this.canReconnect(userId, serverName);
|
||||||
|
if (canReconnect) {
|
||||||
|
serversToReconnect.push(serverName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. mark the servers as reconnecting
|
||||||
|
for (const serverName of serversToReconnect) {
|
||||||
|
this.reconnectionsTracker.setActive(userId, serverName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. attempt to reconnect the servers
|
||||||
|
for (const serverName of serversToReconnect) {
|
||||||
|
void this.tryReconnect(userId, serverName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public clearReconnection(userId: string, serverName: string) {
|
||||||
|
this.reconnectionsTracker.removeFailed(userId, serverName);
|
||||||
|
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async tryReconnect(userId: string, serverName: string) {
|
||||||
|
const mcpManager = MCPManager.getInstance();
|
||||||
|
|
||||||
|
const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`;
|
||||||
|
|
||||||
|
logger.info(`${logPrefix} Attempting reconnection`);
|
||||||
|
|
||||||
|
const config = mcpManager.getRawConfig(serverName);
|
||||||
|
|
||||||
|
const cleanupOnFailedReconnect = () => {
|
||||||
|
this.reconnectionsTracker.setFailed(userId, serverName);
|
||||||
|
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||||
|
mcpManager.disconnectUserConnection(userId, serverName);
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
// attempt to get connection (this will use existing tokens and refresh if needed)
|
||||||
|
const connection = await mcpManager.getUserConnection({
|
||||||
|
serverName,
|
||||||
|
user: { id: userId } as TUser,
|
||||||
|
flowManager: this.flowManager,
|
||||||
|
tokenMethods: this.tokenMethods,
|
||||||
|
// don't force new connection, let it reuse existing or create new as needed
|
||||||
|
forceNew: false,
|
||||||
|
// set a reasonable timeout for reconnection attempts
|
||||||
|
connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS,
|
||||||
|
// don't trigger OAuth flow during reconnection
|
||||||
|
returnOnOAuth: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (connection && (await connection.isConnected())) {
|
||||||
|
logger.info(`${logPrefix} Successfully reconnected`);
|
||||||
|
this.clearReconnection(userId, serverName);
|
||||||
|
} else {
|
||||||
|
logger.warn(`${logPrefix} Failed to reconnect`);
|
||||||
|
await connection?.disconnect();
|
||||||
|
cleanupOnFailedReconnect();
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(`${logPrefix} Failed to reconnect: ${error}`);
|
||||||
|
cleanupOnFailedReconnect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async canReconnect(userId: string, serverName: string) {
|
||||||
|
const mcpManager = MCPManager.getInstance();
|
||||||
|
|
||||||
|
// if the server has failed reconnection, don't attempt to reconnect
|
||||||
|
if (this.reconnectionsTracker.isFailed(userId, serverName)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the server is already connected, don't attempt to reconnect
|
||||||
|
const existingConnections = mcpManager.getUserConnections(userId);
|
||||||
|
if (existingConnections?.has(serverName)) {
|
||||||
|
const isConnected = await existingConnections.get(serverName)?.isConnected();
|
||||||
|
if (isConnected) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the server has no tokens for the user, don't attempt to reconnect
|
||||||
|
const accessToken = await this.tokenMethods.findToken({
|
||||||
|
userId,
|
||||||
|
type: 'mcp_oauth',
|
||||||
|
identifier: `mcp:${serverName}`,
|
||||||
|
});
|
||||||
|
if (accessToken == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the token has expired, don't attempt to reconnect
|
||||||
|
const now = new Date();
|
||||||
|
if (accessToken.expiresAt && accessToken.expiresAt < now) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// …otherwise, we're good to go with the reconnect attempt
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
181
packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
Normal file
181
packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||||
|
|
||||||
|
describe('OAuthReconnectTracker', () => {
|
||||||
|
let tracker: OAuthReconnectionTracker;
|
||||||
|
const userId = 'user123';
|
||||||
|
const serverName = 'test-server';
|
||||||
|
const anotherServer = 'another-server';
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
tracker = new OAuthReconnectionTracker();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setFailed', () => {
|
||||||
|
it('should record a failed reconnection attempt', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should track multiple servers for the same user', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
tracker.setFailed(userId, anotherServer);
|
||||||
|
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should track different users independently', () => {
|
||||||
|
const anotherUserId = 'user456';
|
||||||
|
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
tracker.setFailed(anotherUserId, serverName);
|
||||||
|
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
expect(tracker.isFailed(anotherUserId, serverName)).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('isFailed', () => {
|
||||||
|
it('should return false when no failed attempt is recorded', () => {
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return true after a failed attempt is recorded', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false for a different server even after another server failed', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
expect(tracker.isFailed(userId, anotherServer)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('removeFailed', () => {
|
||||||
|
it('should clear a failed reconnect record', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
|
||||||
|
tracker.removeFailed(userId, serverName);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should only clear the specific server for the user', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
tracker.setFailed(userId, anotherServer);
|
||||||
|
|
||||||
|
tracker.removeFailed(userId, serverName);
|
||||||
|
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle clearing non-existent records gracefully', () => {
|
||||||
|
expect(() => tracker.removeFailed(userId, serverName)).not.toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setActive', () => {
|
||||||
|
it('should mark a server as reconnecting', () => {
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should track multiple reconnecting servers', () => {
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
tracker.setActive(userId, anotherServer);
|
||||||
|
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||||
|
expect(tracker.isActive(userId, anotherServer)).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('isActive', () => {
|
||||||
|
it('should return false when server is not reconnecting', () => {
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return true when server is marked as reconnecting', () => {
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle non-existent user gracefully', () => {
|
||||||
|
expect(tracker.isActive('non-existent-user', serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('removeActive', () => {
|
||||||
|
it('should clear reconnecting state for a server', () => {
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||||
|
|
||||||
|
tracker.removeActive(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should only clear specific server state', () => {
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
tracker.setActive(userId, anotherServer);
|
||||||
|
|
||||||
|
tracker.removeActive(userId, serverName);
|
||||||
|
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
expect(tracker.isActive(userId, anotherServer)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle clearing non-existent state gracefully', () => {
|
||||||
|
expect(() => tracker.removeActive(userId, serverName)).not.toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('cleanup behavior', () => {
|
||||||
|
it('should clean up empty user sets for failed reconnects', () => {
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
tracker.removeFailed(userId, serverName);
|
||||||
|
|
||||||
|
// Record and clear another user to ensure internal cleanup
|
||||||
|
const anotherUserId = 'user456';
|
||||||
|
tracker.setFailed(anotherUserId, serverName);
|
||||||
|
|
||||||
|
// Original user should still be able to reconnect
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should clean up empty user sets for active reconnections', () => {
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
tracker.removeActive(userId, serverName);
|
||||||
|
|
||||||
|
// Mark another user to ensure internal cleanup
|
||||||
|
const anotherUserId = 'user456';
|
||||||
|
tracker.setActive(anotherUserId, serverName);
|
||||||
|
|
||||||
|
// Original user should not be reconnecting
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('combined state management', () => {
|
||||||
|
it('should handle both failed and reconnecting states independently', () => {
|
||||||
|
// Mark as reconnecting
|
||||||
|
tracker.setActive(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
|
||||||
|
// Record failed attempt
|
||||||
|
tracker.setFailed(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
|
||||||
|
// Clear reconnecting state
|
||||||
|
tracker.removeActive(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||||
|
|
||||||
|
// Clear failed state
|
||||||
|
tracker.removeFailed(userId, serverName);
|
||||||
|
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||||
|
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
46
packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
Normal file
46
packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
export class OAuthReconnectionTracker {
|
||||||
|
// Map of userId -> Set of serverNames that have failed reconnection
|
||||||
|
private failed: Map<string, Set<string>> = new Map();
|
||||||
|
// Map of userId -> Set of serverNames that are actively reconnecting
|
||||||
|
private active: Map<string, Set<string>> = new Map();
|
||||||
|
|
||||||
|
public isFailed(userId: string, serverName: string): boolean {
|
||||||
|
return this.failed.get(userId)?.has(serverName) ?? false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public isActive(userId: string, serverName: string): boolean {
|
||||||
|
return this.active.get(userId)?.has(serverName) ?? false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public setFailed(userId: string, serverName: string): void {
|
||||||
|
if (!this.failed.has(userId)) {
|
||||||
|
this.failed.set(userId, new Set());
|
||||||
|
}
|
||||||
|
|
||||||
|
this.failed.get(userId)?.add(serverName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public setActive(userId: string, serverName: string): void {
|
||||||
|
if (!this.active.has(userId)) {
|
||||||
|
this.active.set(userId, new Set());
|
||||||
|
}
|
||||||
|
|
||||||
|
this.active.get(userId)?.add(serverName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeFailed(userId: string, serverName: string): void {
|
||||||
|
const userServers = this.failed.get(userId);
|
||||||
|
userServers?.delete(serverName);
|
||||||
|
if (userServers?.size === 0) {
|
||||||
|
this.failed.delete(userId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeActive(userId: string, serverName: string): void {
|
||||||
|
const userServers = this.active.get(userId);
|
||||||
|
userServers?.delete(serverName);
|
||||||
|
if (userServers?.size === 0) {
|
||||||
|
this.active.delete(userId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue