From 02a064ffb195c1ccb5d0d429101a52ad8f196443 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 3 Apr 2026 17:56:14 -0400 Subject: [PATCH] test: add isClientRejection tests and enforced client_id on test server - Add isClientRejection unit tests: invalid_client, unauthorized_client, client_id mismatch, client not found, unknown client, and negative cases (timeout, flow state not found, user denied, null, undefined) - Enhance OAuth test server with enforceClientId option: binds auth codes to the client_id that initiated /authorize, rejects token exchange with mismatched or unregistered client_id (401 invalid_client) - Add integration tests proving the test server correctly rejects stale client_ids and accepts matching ones at /token --- .../MCPOAuthClientRegistrationReuse.test.ts | 102 ++++++++++++++++++ .../mcp/__tests__/helpers/oauthTestServer.ts | 39 ++++++- 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts b/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts index 7456b56600..e9667cf335 100644 --- a/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts +++ b/packages/api/src/mcp/__tests__/MCPOAuthClientRegistrationReuse.test.ts @@ -26,6 +26,7 @@ import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; import type { OAuthTestServer } from './helpers/oauthTestServer'; import { InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth'; jest.mock('@librechat/data-schemas', () => ({ @@ -395,4 +396,105 @@ describe('MCPOAuthHandler - client registration reuse on reconnection', () => { spy.mockRestore(); }); }); + + describe('isClientRejection', () => { + it('should detect invalid_client errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('invalid_client'))).toBe(true); + expect( + MCPConnectionFactory.isClientRejection( + new Error('OAuth token exchange failed: invalid_client'), + ), + ).toBe(true); + }); + + it('should detect unauthorized_client errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('unauthorized_client'))).toBe(true); + }); + + it('should detect client_id mismatch errors', () => { + expect( + MCPConnectionFactory.isClientRejection( + new Error('Token exchange rejected: client_id mismatch'), + ), + ).toBe(true); + }); + + it('should detect client not found errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('client not found'))).toBe(true); + expect(MCPConnectionFactory.isClientRejection(new Error('unknown client'))).toBe(true); + }); + + it('should not match unrelated errors', () => { + expect(MCPConnectionFactory.isClientRejection(new Error('timeout'))).toBe(false); + expect(MCPConnectionFactory.isClientRejection(new Error('Flow state not found'))).toBe(false); + expect(MCPConnectionFactory.isClientRejection(new Error('user denied access'))).toBe(false); + expect(MCPConnectionFactory.isClientRejection(null)).toBe(false); + expect(MCPConnectionFactory.isClientRejection(undefined)).toBe(false); + }); + }); + + describe('Token exchange with enforced client_id', () => { + it('should reject token exchange when client_id does not match registered client', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true }); + + // Register a real client via DCR + const regRes = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const registered = (await regRes.json()) as { client_id: string }; + + // Get an auth code bound to the registered client_id + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + // Try to exchange the code with a DIFFERENT (stale) client_id + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&client_id=stale-client-id`, + }); + + expect(tokenRes.status).toBe(401); + const body = (await tokenRes.json()) as { error: string; error_description?: string }; + expect(body.error).toBe('invalid_client'); + + // Verify isClientRejection would match this error + const errorMsg = body.error_description ?? body.error; + expect(MCPConnectionFactory.isClientRejection(new Error(errorMsg))).toBe(true); + }); + + it('should accept token exchange when client_id matches', async () => { + server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true }); + + const regRes = await fetch(`${server.url}register`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }), + }); + const registered = (await regRes.json()) as { client_id: string }; + + const authRes = await fetch( + `${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`, + { redirect: 'manual' }, + ); + const location = authRes.headers.get('location') ?? ''; + const code = new URL(location).searchParams.get('code'); + + const tokenRes = await fetch(`${server.url}token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: `grant_type=authorization_code&code=${code}&client_id=${registered.client_id}`, + }); + + expect(tokenRes.status).toBe(200); + const body = (await tokenRes.json()) as { access_token: string }; + expect(body.access_token).toBeTruthy(); + }); + }); }); diff --git a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts index 3b68b2ded4..770c1148e0 100644 --- a/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts +++ b/packages/api/src/mcp/__tests__/helpers/oauthTestServer.ts @@ -59,6 +59,8 @@ export interface OAuthTestServerOptions { issueRefreshTokens?: boolean; refreshTokenTTLMs?: number; rotateRefreshTokens?: boolean; + /** When true, /token validates client_id against the registered client that initiated /authorize */ + enforceClientId?: boolean; } export interface OAuthTestServer { @@ -81,6 +83,17 @@ async function readRequestBody(req: http.IncomingMessage): Promise { return Buffer.concat(chunks).toString(); } +function parseBasicAuth( + header: string | undefined, +): { clientId: string; clientSecret: string } | null { + if (!header?.startsWith('Basic ')) { + return null; + } + const decoded = Buffer.from(header.slice(6), 'base64').toString(); + const [clientId, clientSecret] = decoded.split(':'); + return clientId ? { clientId, clientSecret: clientSecret ?? '' } : null; +} + function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null { if (contentType?.includes('application/x-www-form-urlencoded')) { return new URLSearchParams(body); @@ -100,6 +113,7 @@ export async function createOAuthMCPServer( issueRefreshTokens = false, refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000, rotateRefreshTokens = false, + enforceClientId = false, } = options; const sessions = new Map(); @@ -107,7 +121,10 @@ export async function createOAuthMCPServer( const tokenIssueTimes = new Map(); const issuedRefreshTokens = new Map(); const refreshTokenIssueTimes = new Map(); - const authCodes = new Map(); + const authCodes = new Map< + string, + { codeChallenge?: string; codeChallengeMethod?: string; clientId?: string } + >(); const registeredClients = new Map(); let port = 0; @@ -155,7 +172,8 @@ export async function createOAuthMCPServer( const code = randomUUID(); const codeChallenge = url.searchParams.get('code_challenge') ?? undefined; const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined; - authCodes.set(code, { codeChallenge, codeChallengeMethod }); + const clientId = url.searchParams.get('client_id') ?? undefined; + authCodes.set(code, { codeChallenge, codeChallengeMethod, clientId }); const redirectUri = url.searchParams.get('redirect_uri') ?? ''; const state = url.searchParams.get('state') ?? ''; res.writeHead(302, { @@ -202,6 +220,23 @@ export async function createOAuthMCPServer( } } + if (enforceClientId && codeData.clientId) { + const requestClientId = + params.get('client_id') ?? parseBasicAuth(req.headers.authorization)?.clientId; + if (!requestClientId || !registeredClients.has(requestClientId)) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'invalid_client' })); + return; + } + if (requestClientId !== codeData.clientId) { + res.writeHead(401, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ error: 'invalid_client', error_description: 'client_id mismatch' }), + ); + return; + } + } + authCodes.delete(code); const accessToken = randomUUID();