This commit is contained in:
Danny Avila 2026-04-05 01:24:59 +09:00 committed by GitHub
commit 44071a38ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 222 additions and 39 deletions

View file

@ -45,7 +45,7 @@ export class MCPConnectionFactory {
/** Creates a new MCP connection with optional OAuth support */
static async create(
basic: t.BasicConnectionOptions,
oauth?: t.OAuthConnectionOptions,
oauth?: t.OAuthConnectionOptions | t.UserConnectionContext,
): Promise<MCPConnection> {
const factory = new this(basic, oauth);
return factory.createConnection();
@ -232,6 +232,17 @@ export class MCPConnectionFactory {
let cleanupOAuthHandlers: (() => void) | null = null;
if (this.useOAuth) {
cleanupOAuthHandlers = this.handleOAuthEvents(connection);
} else {
const nonOAuthHandler = () => {
logger.info(
`${this.logPrefix} Server does not use OAuth — treating 401/403 as auth failure, not OAuth`,
);
connection.emit('oauthFailed', new Error('Server does not use OAuth'));
};
connection.on('oauthRequired', nonOAuthHandler);
cleanupOAuthHandlers = () => {
connection.removeListener('oauthRequired', nonOAuthHandler);
};
}
try {

View file

@ -18,7 +18,7 @@ import { preProcessGraphTokens } from '~/utils/graph';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils/env';
import { isUserSourced } from './utils';
import { isUserSourced, isOAuthServer } from './utils';
/**
* Centralized manager for MCP server connections and tool execution.
@ -102,7 +102,7 @@ export class MCPManager extends UserConnectionManager {
return { tools: null, oauthRequired: false, oauthUrl: null };
}
const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata);
const useOAuth = isOAuthServer(serverConfig);
const registry = MCPServersRegistry.getInstance();
const useSSRFProtection = registry.shouldEnableSSRFProtection();

View file

@ -4,7 +4,7 @@ import type * as t from './types';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { isUserSourced } from './utils';
import { isUserSourced, isOAuthServer } from './utils';
import { MCPConnection } from './connection';
import { mcpConfig } from './mcpConfig';
@ -35,14 +35,7 @@ export abstract class UserConnectionManager {
}
/** Gets or creates a connection for a specific user, coalescing concurrent attempts */
public async getUserConnection(
opts: {
serverName: string;
forceNew?: boolean;
/** Pre-resolved config for config-source servers not in YAML/DB */
serverConfig?: t.ParsedServerConfig;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
): Promise<MCPConnection> {
public async getUserConnection(opts: t.UserMCPConnectionOptions): Promise<MCPConnection> {
const { serverName, forceNew, user } = opts;
const userId = user?.id;
if (!userId) {
@ -89,11 +82,7 @@ export abstract class UserConnectionManager {
returnOnOAuth = false,
connectionTimeout,
serverConfig: providedConfig,
}: {
serverName: string;
forceNew?: boolean;
serverConfig?: t.ParsedServerConfig;
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
}: t.UserMCPConnectionOptions,
userId: string,
): Promise<MCPConnection> {
if (await this.appConnections!.has(serverName)) {
@ -161,28 +150,38 @@ export abstract class UserConnectionManager {
try {
const registry = MCPServersRegistry.getInstance();
connection = await MCPConnectionFactory.create(
{
serverConfig: config,
serverName: serverName,
dbSourced: isUserSourced(config),
useSSRFProtection: registry.shouldEnableSSRFProtection(),
allowedDomains: registry.getAllowedDomains(),
},
{
useOAuth: true,
user: user,
customUserVars: customUserVars,
flowManager: flowManager,
tokenMethods: tokenMethods,
signal: signal,
oauthStart: oauthStart,
oauthEnd: oauthEnd,
returnOnOAuth: returnOnOAuth,
requestBody: requestBody,
connectionTimeout: connectionTimeout,
},
);
const basic: t.BasicConnectionOptions = {
serverConfig: config,
serverName: serverName,
dbSourced: isUserSourced(config),
useSSRFProtection: registry.shouldEnableSSRFProtection(),
allowedDomains: registry.getAllowedDomains(),
};
const useOAuth = isOAuthServer(config);
if (useOAuth && !flowManager) {
throw new McpError(
ErrorCode.InvalidRequest,
`[MCP][User: ${userId}] OAuth server "${serverName}" requires a flowManager`,
);
}
const oauthOptions: t.OAuthConnectionOptions | t.UserConnectionContext = useOAuth
? {
useOAuth: true as const,
user,
customUserVars,
flowManager: flowManager,
tokenMethods,
signal,
oauthStart,
oauthEnd,
returnOnOAuth,
requestBody,
connectionTimeout,
}
: { user, customUserVars, requestBody, connectionTimeout };
connection = await MCPConnectionFactory.create(basic, oauthOptions);
if (!(await connection?.isConnected())) {
throw new Error('Failed to establish connection after initialization attempt.');

View file

@ -94,6 +94,36 @@ describe('MCPConnectionFactory', () => {
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
it('should register fallback oauthRequired handler for non-OAuth connections', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
};
mockConnectionInstance.isConnected.mockResolvedValue(true);
await MCPConnectionFactory.create(basicOptions);
expect(mockConnectionInstance.on).toHaveBeenCalledWith('oauthRequired', expect.any(Function));
const onCall = (mockConnectionInstance.on as jest.Mock).mock.calls.find(
([event]: [string]) => event === 'oauthRequired',
);
const handler = onCall![1] as () => void;
handler();
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({ message: 'Server does not use OAuth' }),
);
expect(mockConnectionInstance.removeListener).toHaveBeenCalledWith(
'oauthRequired',
expect.any(Function),
);
});
it('should create a connection with OAuth', async () => {
const basicOptions = {
serverName: 'test-server',

View file

@ -925,6 +925,44 @@ describe('MCPManager', () => {
);
});
it('should use isOAuthServer in discoverServerTools', async () => {
const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser;
const mockFlowManager = {
createFlow: jest.fn(),
getFlowState: jest.fn(),
deleteFlow: jest.fn(),
};
mockAppConnections({
get: jest.fn().mockResolvedValue(null),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'streamable-http',
url: 'http://private-mcp.svc:5446/mcp',
requiresOAuth: false,
});
(MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({
tools: mockTools,
connection: null,
oauthRequired: false,
oauthUrl: null,
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.discoverServerTools({
serverName,
user: mockUser,
flowManager: mockFlowManager as unknown as t.ToolDiscoveryOptions['flowManager'],
});
expect(MCPConnectionFactory.discoverTools).toHaveBeenCalledWith(
expect.objectContaining({ serverName }),
expect.not.objectContaining({ useOAuth: true }),
);
});
it('should discover tools with OAuth when user and flowManager provided', async () => {
const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser;
const mockFlowManager = {
@ -966,4 +1004,89 @@ describe('MCPManager', () => {
);
});
});
describe('getUserConnection - useOAuth derivation', () => {
const mockUser = { id: userId, email: 'test@example.com' } as unknown as IUser;
const mockFlowManager = {
createFlow: jest.fn(),
getFlowState: jest.fn(),
deleteFlow: jest.fn(),
};
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
isStale: jest.fn().mockReturnValue(false),
disconnect: jest.fn(),
} as unknown as MCPConnection;
it('should pass useOAuth: true for servers with requiresOAuth', async () => {
mockAppConnections({
has: jest.fn().mockResolvedValue(false),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'sse',
url: 'https://oauth-mcp.example.com',
requiresOAuth: true,
});
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.getUserConnection({
serverName,
user: mockUser,
flowManager: mockFlowManager as unknown as t.UserMCPConnectionOptions['flowManager'],
});
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
expect.objectContaining({ serverName }),
expect.objectContaining({ useOAuth: true }),
);
});
it('should not pass useOAuth for servers with requiresOAuth: false', async () => {
mockAppConnections({
has: jest.fn().mockResolvedValue(false),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'streamable-http',
url: 'http://private-mcp.svc:5446/mcp',
requiresOAuth: false,
});
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
const manager = await MCPManager.createInstance(newMCPServersConfig());
await manager.getUserConnection({
serverName,
user: mockUser,
});
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
expect.objectContaining({ serverName }),
expect.not.objectContaining({ useOAuth: true }),
);
});
it('should throw when OAuth server lacks flowManager', async () => {
mockAppConnections({
has: jest.fn().mockResolvedValue(false),
});
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
type: 'sse',
url: 'https://oauth-mcp.example.com',
requiresOAuth: true,
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
await expect(
manager.getUserConnection({
serverName,
user: mockUser,
}),
).rejects.toThrow('requires a flowManager');
});
});
});

View file

@ -202,6 +202,19 @@ export interface OAuthConnectionOptions extends UserConnectionContext {
returnOnOAuth?: boolean;
}
/** Options accepted by UserConnectionManager.getUserConnection — OAuth fields are optional. */
export interface UserMCPConnectionOptions extends UserConnectionContext {
serverName: string;
forceNew?: boolean;
serverConfig?: ParsedServerConfig;
flowManager?: FlowStateManager<o.MCPOAuthTokens | null>;
tokenMethods?: TokenMethods;
signal?: AbortSignal;
oauthStart?: (authURL: string) => Promise<void>;
oauthEnd?: () => Promise<void>;
returnOnOAuth?: boolean;
}
export interface ToolDiscoveryOptions {
serverName: string;
user?: IUser;

View file

@ -3,6 +3,13 @@ import type { ParsedServerConfig } from '~/mcp/types';
export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
/** Whether a server requires OAuth (has `requiresOAuth` or `oauthMetadata`). */
export function isOAuthServer(
config: Pick<ParsedServerConfig, 'requiresOAuth' | 'oauthMetadata'>,
): boolean {
return Boolean(config.requiresOAuth || config.oauthMetadata);
}
/** Checks that `customUserVars` is present AND non-empty (guards against truthy `{}`) */
export function hasCustomUserVars(config: Pick<ParsedServerConfig, 'customUserVars'>): boolean {
return !!config.customUserVars && Object.keys(config.customUserVars).length > 0;