mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
fix: address review findings — tests, types, normalization, docs
- Add deleteTokens method to InMemoryTokenStore matching TokenMethods contract; update test call site from deleteToken to deleteTokens - Add MCPConnectionFactory test: returnOnOAuth flow fails with invalid_client → clearStaleClientIfRejected invoked automatically - Add mcp.spec.js tests: OAuth error with CSRF → failFlow called; OAuth error without cookies → failFlow NOT called (DoS prevention) - Add JSDoc to isClientRejection with RFC 6749 and vendor attribution - Add inline comment explaining findToken/deleteTokens coupling guard - Normalize issuer comparison: strip trailing slashes to prevent spurious re-registrations from URL formatting differences - Fix dead-code: use local reusedStoredClient variable in PENDING join return instead of re-reading flowMeta
This commit is contained in:
parent
b5231547bb
commit
c20266c4f9
6 changed files with 157 additions and 6 deletions
|
|
@ -311,6 +311,60 @@ describe('MCP Routes', () => {
|
|||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`);
|
||||
});
|
||||
|
||||
describe('OAuth error callback failFlow', () => {
|
||||
it('should fail the flow when OAuth error is received with valid CSRF cookie', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING', createdAt: Date.now() }),
|
||||
failFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValueOnce({});
|
||||
require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager);
|
||||
MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId);
|
||||
|
||||
const csrfToken = generateTestCsrfToken(flowId);
|
||||
const response = await request(app)
|
||||
.get('/api/mcp/test-server/oauth/callback')
|
||||
.set('Cookie', [`oauth_csrf=${csrfToken}`])
|
||||
.query({
|
||||
error: 'invalid_client',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`);
|
||||
expect(mockFlowManager.failFlow).toHaveBeenCalledWith(
|
||||
flowId,
|
||||
'mcp_oauth',
|
||||
'invalid_client',
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT fail the flow when OAuth error is received without cookies (DoS prevention)', async () => {
|
||||
const flowId = 'test-user-id:test-server';
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({ status: 'PENDING', createdAt: Date.now() }),
|
||||
failFlow: jest.fn(),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValueOnce({});
|
||||
require('~/config').getFlowStateManager.mockReturnValueOnce(mockFlowManager);
|
||||
MCPOAuthHandler.resolveStateToFlowId.mockResolvedValueOnce(flowId);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
error: 'invalid_client',
|
||||
state: flowId,
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_client`);
|
||||
expect(mockFlowManager.failFlow).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should redirect to error page when code is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
state: 'test-user-id:test-server',
|
||||
|
|
|
|||
|
|
@ -351,6 +351,7 @@ export class MCPConnectionFactory {
|
|||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
this.allowedDomains,
|
||||
// Only reuse stored client when deleteTokens is available for stale-client cleanup
|
||||
this.tokenMethods?.deleteTokens ? this.tokenMethods.findToken : undefined,
|
||||
);
|
||||
|
||||
|
|
@ -491,7 +492,11 @@ export class MCPConnectionFactory {
|
|||
});
|
||||
}
|
||||
|
||||
/** Determines if an error indicates the OAuth client registration was rejected */
|
||||
/**
|
||||
* Checks whether an error indicates the OAuth client registration was rejected.
|
||||
* Includes RFC 6749 §5.2 standard codes (`invalid_client`, `unauthorized_client`)
|
||||
* and known vendor-specific patterns (Okta: `client_id mismatch`, Auth0: `client not found`).
|
||||
*/
|
||||
static isClientRejection(error: unknown): boolean {
|
||||
if (!error || typeof error !== 'object') {
|
||||
return false;
|
||||
|
|
@ -608,7 +613,7 @@ export class MCPConnectionFactory {
|
|||
tokens,
|
||||
clientInfo: flowMeta?.clientInfo,
|
||||
metadata: flowMeta?.metadata,
|
||||
reusedStoredClient: flowMeta?.reusedStoredClient === true,
|
||||
reusedStoredClient,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import type { MCPOAuthTokens } from '~/mcp/oauth';
|
|||
import type * as t from '~/mcp/types';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
|
||||
jest.mock('~/mcp/connection');
|
||||
|
|
@ -24,6 +24,7 @@ const mockLogger = logger as jest.Mocked<typeof logger>;
|
|||
const mockProcessMCPEnv = processMCPEnv as jest.MockedFunction<typeof processMCPEnv>;
|
||||
const mockMCPConnection = MCPConnection as jest.MockedClass<typeof MCPConnection>;
|
||||
const mockMCPOAuthHandler = MCPOAuthHandler as jest.Mocked<typeof MCPOAuthHandler>;
|
||||
const mockMCPTokenStorage = MCPTokenStorage as jest.Mocked<typeof MCPTokenStorage>;
|
||||
|
||||
describe('MCPConnectionFactory', () => {
|
||||
let mockUser: IUser | undefined;
|
||||
|
|
@ -293,6 +294,78 @@ describe('MCPConnectionFactory', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should clear stale client registration when returnOnOAuth flow fails with client rejection', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: {
|
||||
...mockServerConfig,
|
||||
url: 'https://api.example.com',
|
||||
type: 'sse' as const,
|
||||
} as t.SSEOptions,
|
||||
};
|
||||
|
||||
const deleteTokensSpy = jest.fn().mockResolvedValue({ acknowledged: true, deletedCount: 1 });
|
||||
const oauthOptions = {
|
||||
useOAuth: true as const,
|
||||
user: mockUser,
|
||||
flowManager: mockFlowManager,
|
||||
returnOnOAuth: true,
|
||||
oauthStart: jest.fn(),
|
||||
tokenMethods: {
|
||||
findToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
updateToken: jest.fn(),
|
||||
deleteTokens: deleteTokensSpy,
|
||||
},
|
||||
};
|
||||
|
||||
const mockFlowData = {
|
||||
authorizationUrl: 'https://auth.example.com',
|
||||
flowId: 'flow123',
|
||||
flowMetadata: {
|
||||
serverName: 'test-server',
|
||||
userId: 'user123',
|
||||
serverUrl: 'https://api.example.com',
|
||||
state: 'random-state',
|
||||
clientInfo: { client_id: 'stale-client' },
|
||||
reusedStoredClient: true,
|
||||
},
|
||||
};
|
||||
|
||||
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
|
||||
mockMCPTokenStorage.deleteClientRegistration.mockResolvedValue(undefined);
|
||||
// createFlow rejects with invalid_client — simulating stale client rejection
|
||||
mockFlowManager.createFlow.mockRejectedValue(new Error('invalid_client'));
|
||||
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||
|
||||
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
|
||||
mockConnectionInstance.on.mockImplementation((event, handler) => {
|
||||
if (event === 'oauthRequired') {
|
||||
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
|
||||
}
|
||||
return mockConnectionInstance;
|
||||
});
|
||||
|
||||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
// Wait for the background .catch() handler to run
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
// deleteClientRegistration should have been called via clearStaleClientIfRejected
|
||||
expect(mockMCPTokenStorage.deleteClientRegistration).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
userId: 'user123',
|
||||
serverName: 'test-server',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ describe('MCPOAuthHandler - client registration reuse on reconnection', () => {
|
|||
await MCPTokenStorage.deleteClientRegistration({
|
||||
userId: 'user-1',
|
||||
serverName: 'test-server',
|
||||
deleteTokens: tokenStore.deleteToken,
|
||||
deleteTokens: tokenStore.deleteTokens,
|
||||
});
|
||||
|
||||
// Second attempt (retry after failure): should do a fresh DCR
|
||||
|
|
|
|||
|
|
@ -474,6 +474,25 @@ export class InMemoryTokenStore {
|
|||
this.tokens.delete(this.key(filter));
|
||||
};
|
||||
|
||||
deleteTokens = async (query: {
|
||||
userId?: string;
|
||||
type?: string;
|
||||
identifier?: string;
|
||||
}): Promise<{ acknowledged: boolean; deletedCount: number }> => {
|
||||
let deletedCount = 0;
|
||||
for (const [key, token] of this.tokens.entries()) {
|
||||
const match =
|
||||
(!query.userId || token.userId === query.userId) &&
|
||||
(!query.type || token.type === query.type) &&
|
||||
(!query.identifier || token.identifier === query.identifier);
|
||||
if (match) {
|
||||
this.tokens.delete(key);
|
||||
deletedCount++;
|
||||
}
|
||||
}
|
||||
return { acknowledged: true, deletedCount };
|
||||
};
|
||||
|
||||
getAll(): InMemoryToken[] {
|
||||
return [...this.tokens.values()];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -514,9 +514,9 @@ export class MCPOAuthHandler {
|
|||
.redirect_uris?.[0];
|
||||
const storedIssuer =
|
||||
typeof existing.clientMetadata?.issuer === 'string'
|
||||
? existing.clientMetadata.issuer
|
||||
? existing.clientMetadata.issuer.replace(/\/+$/, '')
|
||||
: null;
|
||||
const currentIssuer = metadata.issuer ?? authServerUrl.toString();
|
||||
const currentIssuer = (metadata.issuer ?? authServerUrl.toString()).replace(/\/+$/, '');
|
||||
|
||||
if (!storedRedirectUri || storedRedirectUri !== redirectUri) {
|
||||
logger.debug(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue