test: add isClientRejection tests and enforced client_id on test server

- Add isClientRejection unit tests: invalid_client, unauthorized_client,
  client_id mismatch, client not found, unknown client, and negative
  cases (timeout, flow state not found, user denied, null, undefined)
- Enhance OAuth test server with enforceClientId option: binds auth
  codes to the client_id that initiated /authorize, rejects token
  exchange with mismatched or unregistered client_id (401 invalid_client)
- Add integration tests proving the test server correctly rejects
  stale client_ids and accepts matching ones at /token
This commit is contained in:
Danny Avila 2026-04-03 17:56:14 -04:00
parent e188ff992b
commit 02a064ffb1
2 changed files with 139 additions and 2 deletions

View file

@ -26,6 +26,7 @@
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { OAuthTestServer } from './helpers/oauthTestServer';
import { InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth';
jest.mock('@librechat/data-schemas', () => ({
@ -395,4 +396,105 @@ describe('MCPOAuthHandler - client registration reuse on reconnection', () => {
spy.mockRestore();
});
});
describe('isClientRejection', () => {
it('should detect invalid_client errors', () => {
expect(MCPConnectionFactory.isClientRejection(new Error('invalid_client'))).toBe(true);
expect(
MCPConnectionFactory.isClientRejection(
new Error('OAuth token exchange failed: invalid_client'),
),
).toBe(true);
});
it('should detect unauthorized_client errors', () => {
expect(MCPConnectionFactory.isClientRejection(new Error('unauthorized_client'))).toBe(true);
});
it('should detect client_id mismatch errors', () => {
expect(
MCPConnectionFactory.isClientRejection(
new Error('Token exchange rejected: client_id mismatch'),
),
).toBe(true);
});
it('should detect client not found errors', () => {
expect(MCPConnectionFactory.isClientRejection(new Error('client not found'))).toBe(true);
expect(MCPConnectionFactory.isClientRejection(new Error('unknown client'))).toBe(true);
});
it('should not match unrelated errors', () => {
expect(MCPConnectionFactory.isClientRejection(new Error('timeout'))).toBe(false);
expect(MCPConnectionFactory.isClientRejection(new Error('Flow state not found'))).toBe(false);
expect(MCPConnectionFactory.isClientRejection(new Error('user denied access'))).toBe(false);
expect(MCPConnectionFactory.isClientRejection(null)).toBe(false);
expect(MCPConnectionFactory.isClientRejection(undefined)).toBe(false);
});
});
describe('Token exchange with enforced client_id', () => {
it('should reject token exchange when client_id does not match registered client', async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true });
// Register a real client via DCR
const regRes = await fetch(`${server.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
});
const registered = (await regRes.json()) as { client_id: string };
// Get an auth code bound to the registered client_id
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
// Try to exchange the code with a DIFFERENT (stale) client_id
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}&client_id=stale-client-id`,
});
expect(tokenRes.status).toBe(401);
const body = (await tokenRes.json()) as { error: string; error_description?: string };
expect(body.error).toBe('invalid_client');
// Verify isClientRejection would match this error
const errorMsg = body.error_description ?? body.error;
expect(MCPConnectionFactory.isClientRejection(new Error(errorMsg))).toBe(true);
});
it('should accept token exchange when client_id matches', async () => {
server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true });
const regRes = await fetch(`${server.url}register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
});
const registered = (await regRes.json()) as { client_id: string };
const authRes = await fetch(
`${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`,
{ redirect: 'manual' },
);
const location = authRes.headers.get('location') ?? '';
const code = new URL(location).searchParams.get('code');
const tokenRes = await fetch(`${server.url}token`, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: `grant_type=authorization_code&code=${code}&client_id=${registered.client_id}`,
});
expect(tokenRes.status).toBe(200);
const body = (await tokenRes.json()) as { access_token: string };
expect(body.access_token).toBeTruthy();
});
});
});

View file

@ -59,6 +59,8 @@ export interface OAuthTestServerOptions {
issueRefreshTokens?: boolean;
refreshTokenTTLMs?: number;
rotateRefreshTokens?: boolean;
/** When true, /token validates client_id against the registered client that initiated /authorize */
enforceClientId?: boolean;
}
export interface OAuthTestServer {
@ -81,6 +83,17 @@ async function readRequestBody(req: http.IncomingMessage): Promise<string> {
return Buffer.concat(chunks).toString();
}
function parseBasicAuth(
header: string | undefined,
): { clientId: string; clientSecret: string } | null {
if (!header?.startsWith('Basic ')) {
return null;
}
const decoded = Buffer.from(header.slice(6), 'base64').toString();
const [clientId, clientSecret] = decoded.split(':');
return clientId ? { clientId, clientSecret: clientSecret ?? '' } : null;
}
function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null {
if (contentType?.includes('application/x-www-form-urlencoded')) {
return new URLSearchParams(body);
@ -100,6 +113,7 @@ export async function createOAuthMCPServer(
issueRefreshTokens = false,
refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000,
rotateRefreshTokens = false,
enforceClientId = false,
} = options;
const sessions = new Map<string, StreamableHTTPServerTransport>();
@ -107,7 +121,10 @@ export async function createOAuthMCPServer(
const tokenIssueTimes = new Map<string, number>();
const issuedRefreshTokens = new Map<string, string>();
const refreshTokenIssueTimes = new Map<string, number>();
const authCodes = new Map<string, { codeChallenge?: string; codeChallengeMethod?: string }>();
const authCodes = new Map<
string,
{ codeChallenge?: string; codeChallengeMethod?: string; clientId?: string }
>();
const registeredClients = new Map<string, { client_id: string; client_secret: string }>();
let port = 0;
@ -155,7 +172,8 @@ export async function createOAuthMCPServer(
const code = randomUUID();
const codeChallenge = url.searchParams.get('code_challenge') ?? undefined;
const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined;
authCodes.set(code, { codeChallenge, codeChallengeMethod });
const clientId = url.searchParams.get('client_id') ?? undefined;
authCodes.set(code, { codeChallenge, codeChallengeMethod, clientId });
const redirectUri = url.searchParams.get('redirect_uri') ?? '';
const state = url.searchParams.get('state') ?? '';
res.writeHead(302, {
@ -202,6 +220,23 @@ export async function createOAuthMCPServer(
}
}
if (enforceClientId && codeData.clientId) {
const requestClientId =
params.get('client_id') ?? parseBasicAuth(req.headers.authorization)?.clientId;
if (!requestClientId || !registeredClients.has(requestClientId)) {
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'invalid_client' }));
return;
}
if (requestClientId !== codeData.clientId) {
res.writeHead(401, { 'Content-Type': 'application/json' });
res.end(
JSON.stringify({ error: 'invalid_client', error_description: 'client_id mismatch' }),
);
return;
}
}
authCodes.delete(code);
const accessToken = randomUUID();