mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
Merge bc038213db into 8ed0bcf5ca
This commit is contained in:
commit
44071a38ec
7 changed files with 222 additions and 39 deletions
|
|
@ -45,7 +45,7 @@ export class MCPConnectionFactory {
|
|||
/** Creates a new MCP connection with optional OAuth support */
|
||||
static async create(
|
||||
basic: t.BasicConnectionOptions,
|
||||
oauth?: t.OAuthConnectionOptions,
|
||||
oauth?: t.OAuthConnectionOptions | t.UserConnectionContext,
|
||||
): Promise<MCPConnection> {
|
||||
const factory = new this(basic, oauth);
|
||||
return factory.createConnection();
|
||||
|
|
@ -232,6 +232,17 @@ export class MCPConnectionFactory {
|
|||
let cleanupOAuthHandlers: (() => void) | null = null;
|
||||
if (this.useOAuth) {
|
||||
cleanupOAuthHandlers = this.handleOAuthEvents(connection);
|
||||
} else {
|
||||
const nonOAuthHandler = () => {
|
||||
logger.info(
|
||||
`${this.logPrefix} Server does not use OAuth — treating 401/403 as auth failure, not OAuth`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('Server does not use OAuth'));
|
||||
};
|
||||
connection.on('oauthRequired', nonOAuthHandler);
|
||||
cleanupOAuthHandlers = () => {
|
||||
connection.removeListener('oauthRequired', nonOAuthHandler);
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import { preProcessGraphTokens } from '~/utils/graph';
|
|||
import { formatToolContent } from './parsers';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils/env';
|
||||
import { isUserSourced } from './utils';
|
||||
import { isUserSourced, isOAuthServer } from './utils';
|
||||
|
||||
/**
|
||||
* Centralized manager for MCP server connections and tool execution.
|
||||
|
|
@ -102,7 +102,7 @@ export class MCPManager extends UserConnectionManager {
|
|||
return { tools: null, oauthRequired: false, oauthUrl: null };
|
||||
}
|
||||
|
||||
const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata);
|
||||
const useOAuth = isOAuthServer(serverConfig);
|
||||
|
||||
const registry = MCPServersRegistry.getInstance();
|
||||
const useSSRFProtection = registry.shouldEnableSSRFProtection();
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import type * as t from './types';
|
|||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { isUserSourced } from './utils';
|
||||
import { isUserSourced, isOAuthServer } from './utils';
|
||||
import { MCPConnection } from './connection';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
|
|
@ -35,14 +35,7 @@ export abstract class UserConnectionManager {
|
|||
}
|
||||
|
||||
/** Gets or creates a connection for a specific user, coalescing concurrent attempts */
|
||||
public async getUserConnection(
|
||||
opts: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
/** Pre-resolved config for config-source servers not in YAML/DB */
|
||||
serverConfig?: t.ParsedServerConfig;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
|
||||
): Promise<MCPConnection> {
|
||||
public async getUserConnection(opts: t.UserMCPConnectionOptions): Promise<MCPConnection> {
|
||||
const { serverName, forceNew, user } = opts;
|
||||
const userId = user?.id;
|
||||
if (!userId) {
|
||||
|
|
@ -89,11 +82,7 @@ export abstract class UserConnectionManager {
|
|||
returnOnOAuth = false,
|
||||
connectionTimeout,
|
||||
serverConfig: providedConfig,
|
||||
}: {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
serverConfig?: t.ParsedServerConfig;
|
||||
} & Omit<t.OAuthConnectionOptions, 'useOAuth'>,
|
||||
}: t.UserMCPConnectionOptions,
|
||||
userId: string,
|
||||
): Promise<MCPConnection> {
|
||||
if (await this.appConnections!.has(serverName)) {
|
||||
|
|
@ -161,28 +150,38 @@ export abstract class UserConnectionManager {
|
|||
|
||||
try {
|
||||
const registry = MCPServersRegistry.getInstance();
|
||||
connection = await MCPConnectionFactory.create(
|
||||
{
|
||||
serverConfig: config,
|
||||
serverName: serverName,
|
||||
dbSourced: isUserSourced(config),
|
||||
useSSRFProtection: registry.shouldEnableSSRFProtection(),
|
||||
allowedDomains: registry.getAllowedDomains(),
|
||||
},
|
||||
{
|
||||
useOAuth: true,
|
||||
user: user,
|
||||
customUserVars: customUserVars,
|
||||
flowManager: flowManager,
|
||||
tokenMethods: tokenMethods,
|
||||
signal: signal,
|
||||
oauthStart: oauthStart,
|
||||
oauthEnd: oauthEnd,
|
||||
returnOnOAuth: returnOnOAuth,
|
||||
requestBody: requestBody,
|
||||
connectionTimeout: connectionTimeout,
|
||||
},
|
||||
);
|
||||
const basic: t.BasicConnectionOptions = {
|
||||
serverConfig: config,
|
||||
serverName: serverName,
|
||||
dbSourced: isUserSourced(config),
|
||||
useSSRFProtection: registry.shouldEnableSSRFProtection(),
|
||||
allowedDomains: registry.getAllowedDomains(),
|
||||
};
|
||||
|
||||
const useOAuth = isOAuthServer(config);
|
||||
if (useOAuth && !flowManager) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
`[MCP][User: ${userId}] OAuth server "${serverName}" requires a flowManager`,
|
||||
);
|
||||
}
|
||||
const oauthOptions: t.OAuthConnectionOptions | t.UserConnectionContext = useOAuth
|
||||
? {
|
||||
useOAuth: true as const,
|
||||
user,
|
||||
customUserVars,
|
||||
flowManager: flowManager,
|
||||
tokenMethods,
|
||||
signal,
|
||||
oauthStart,
|
||||
oauthEnd,
|
||||
returnOnOAuth,
|
||||
requestBody,
|
||||
connectionTimeout,
|
||||
}
|
||||
: { user, customUserVars, requestBody, connectionTimeout };
|
||||
|
||||
connection = await MCPConnectionFactory.create(basic, oauthOptions);
|
||||
|
||||
if (!(await connection?.isConnected())) {
|
||||
throw new Error('Failed to establish connection after initialization attempt.');
|
||||
|
|
|
|||
|
|
@ -94,6 +94,36 @@ describe('MCPConnectionFactory', () => {
|
|||
expect(mockConnectionInstance.connect).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should register fallback oauthRequired handler for non-OAuth connections', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
};
|
||||
|
||||
mockConnectionInstance.isConnected.mockResolvedValue(true);
|
||||
|
||||
await MCPConnectionFactory.create(basicOptions);
|
||||
|
||||
expect(mockConnectionInstance.on).toHaveBeenCalledWith('oauthRequired', expect.any(Function));
|
||||
|
||||
const onCall = (mockConnectionInstance.on as jest.Mock).mock.calls.find(
|
||||
([event]: [string]) => event === 'oauthRequired',
|
||||
);
|
||||
|
||||
const handler = onCall![1] as () => void;
|
||||
handler();
|
||||
|
||||
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
|
||||
'oauthFailed',
|
||||
expect.objectContaining({ message: 'Server does not use OAuth' }),
|
||||
);
|
||||
|
||||
expect(mockConnectionInstance.removeListener).toHaveBeenCalledWith(
|
||||
'oauthRequired',
|
||||
expect.any(Function),
|
||||
);
|
||||
});
|
||||
|
||||
it('should create a connection with OAuth', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
|
|
|
|||
|
|
@ -925,6 +925,44 @@ describe('MCPManager', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should use isOAuthServer in discoverServerTools', async () => {
|
||||
const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser;
|
||||
const mockFlowManager = {
|
||||
createFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
deleteFlow: jest.fn(),
|
||||
};
|
||||
|
||||
mockAppConnections({
|
||||
get: jest.fn().mockResolvedValue(null),
|
||||
});
|
||||
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
|
||||
type: 'streamable-http',
|
||||
url: 'http://private-mcp.svc:5446/mcp',
|
||||
requiresOAuth: false,
|
||||
});
|
||||
|
||||
(MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({
|
||||
tools: mockTools,
|
||||
connection: null,
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.discoverServerTools({
|
||||
serverName,
|
||||
user: mockUser,
|
||||
flowManager: mockFlowManager as unknown as t.ToolDiscoveryOptions['flowManager'],
|
||||
});
|
||||
|
||||
expect(MCPConnectionFactory.discoverTools).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName }),
|
||||
expect.not.objectContaining({ useOAuth: true }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover tools with OAuth when user and flowManager provided', async () => {
|
||||
const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser;
|
||||
const mockFlowManager = {
|
||||
|
|
@ -966,4 +1004,89 @@ describe('MCPManager', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getUserConnection - useOAuth derivation', () => {
|
||||
const mockUser = { id: userId, email: 'test@example.com' } as unknown as IUser;
|
||||
const mockFlowManager = {
|
||||
createFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
deleteFlow: jest.fn(),
|
||||
};
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
isStale: jest.fn().mockReturnValue(false),
|
||||
disconnect: jest.fn(),
|
||||
} as unknown as MCPConnection;
|
||||
|
||||
it('should pass useOAuth: true for servers with requiresOAuth', async () => {
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockResolvedValue(false),
|
||||
});
|
||||
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
|
||||
type: 'sse',
|
||||
url: 'https://oauth-mcp.example.com',
|
||||
requiresOAuth: true,
|
||||
});
|
||||
|
||||
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.getUserConnection({
|
||||
serverName,
|
||||
user: mockUser,
|
||||
flowManager: mockFlowManager as unknown as t.UserMCPConnectionOptions['flowManager'],
|
||||
});
|
||||
|
||||
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName }),
|
||||
expect.objectContaining({ useOAuth: true }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not pass useOAuth for servers with requiresOAuth: false', async () => {
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockResolvedValue(false),
|
||||
});
|
||||
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
|
||||
type: 'streamable-http',
|
||||
url: 'http://private-mcp.svc:5446/mcp',
|
||||
requiresOAuth: false,
|
||||
});
|
||||
|
||||
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await manager.getUserConnection({
|
||||
serverName,
|
||||
user: mockUser,
|
||||
});
|
||||
|
||||
expect(MCPConnectionFactory.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName }),
|
||||
expect.not.objectContaining({ useOAuth: true }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw when OAuth server lacks flowManager', async () => {
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockResolvedValue(false),
|
||||
});
|
||||
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({
|
||||
type: 'sse',
|
||||
url: 'https://oauth-mcp.example.com',
|
||||
requiresOAuth: true,
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
await expect(
|
||||
manager.getUserConnection({
|
||||
serverName,
|
||||
user: mockUser,
|
||||
}),
|
||||
).rejects.toThrow('requires a flowManager');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -202,6 +202,19 @@ export interface OAuthConnectionOptions extends UserConnectionContext {
|
|||
returnOnOAuth?: boolean;
|
||||
}
|
||||
|
||||
/** Options accepted by UserConnectionManager.getUserConnection — OAuth fields are optional. */
|
||||
export interface UserMCPConnectionOptions extends UserConnectionContext {
|
||||
serverName: string;
|
||||
forceNew?: boolean;
|
||||
serverConfig?: ParsedServerConfig;
|
||||
flowManager?: FlowStateManager<o.MCPOAuthTokens | null>;
|
||||
tokenMethods?: TokenMethods;
|
||||
signal?: AbortSignal;
|
||||
oauthStart?: (authURL: string) => Promise<void>;
|
||||
oauthEnd?: () => Promise<void>;
|
||||
returnOnOAuth?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolDiscoveryOptions {
|
||||
serverName: string;
|
||||
user?: IUser;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,13 @@ import type { ParsedServerConfig } from '~/mcp/types';
|
|||
|
||||
export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
|
||||
|
||||
/** Whether a server requires OAuth (has `requiresOAuth` or `oauthMetadata`). */
|
||||
export function isOAuthServer(
|
||||
config: Pick<ParsedServerConfig, 'requiresOAuth' | 'oauthMetadata'>,
|
||||
): boolean {
|
||||
return Boolean(config.requiresOAuth || config.oauthMetadata);
|
||||
}
|
||||
|
||||
/** Checks that `customUserVars` is present AND non-empty (guards against truthy `{}`) */
|
||||
export function hasCustomUserVars(config: Pick<ParsedServerConfig, 'customUserVars'>): boolean {
|
||||
return !!config.customUserVars && Object.keys(config.customUserVars).length > 0;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue