mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
test: add isClientRejection tests and enforced client_id on test server
- Add isClientRejection unit tests: invalid_client, unauthorized_client, client_id mismatch, client not found, unknown client, and negative cases (timeout, flow state not found, user denied, null, undefined) - Enhance OAuth test server with enforceClientId option: binds auth codes to the client_id that initiated /authorize, rejects token exchange with mismatched or unregistered client_id (401 invalid_client) - Add integration tests proving the test server correctly rejects stale client_ids and accepts matching ones at /token
This commit is contained in:
parent
e188ff992b
commit
02a064ffb1
2 changed files with 139 additions and 2 deletions
|
|
@ -26,6 +26,7 @@
|
|||
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { OAuthTestServer } from './helpers/oauthTestServer';
|
||||
import { InMemoryTokenStore, createOAuthMCPServer } from './helpers/oauthTestServer';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPOAuthHandler, MCPTokenStorage } from '~/mcp/oauth';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
|
|
@ -395,4 +396,105 @@ describe('MCPOAuthHandler - client registration reuse on reconnection', () => {
|
|||
spy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('isClientRejection', () => {
|
||||
it('should detect invalid_client errors', () => {
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('invalid_client'))).toBe(true);
|
||||
expect(
|
||||
MCPConnectionFactory.isClientRejection(
|
||||
new Error('OAuth token exchange failed: invalid_client'),
|
||||
),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect unauthorized_client errors', () => {
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('unauthorized_client'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect client_id mismatch errors', () => {
|
||||
expect(
|
||||
MCPConnectionFactory.isClientRejection(
|
||||
new Error('Token exchange rejected: client_id mismatch'),
|
||||
),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect client not found errors', () => {
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('client not found'))).toBe(true);
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('unknown client'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should not match unrelated errors', () => {
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('timeout'))).toBe(false);
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('Flow state not found'))).toBe(false);
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error('user denied access'))).toBe(false);
|
||||
expect(MCPConnectionFactory.isClientRejection(null)).toBe(false);
|
||||
expect(MCPConnectionFactory.isClientRejection(undefined)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Token exchange with enforced client_id', () => {
|
||||
it('should reject token exchange when client_id does not match registered client', async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true });
|
||||
|
||||
// Register a real client via DCR
|
||||
const regRes = await fetch(`${server.url}register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
|
||||
});
|
||||
const registered = (await regRes.json()) as { client_id: string };
|
||||
|
||||
// Get an auth code bound to the registered client_id
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
|
||||
// Try to exchange the code with a DIFFERENT (stale) client_id
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}&client_id=stale-client-id`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(401);
|
||||
const body = (await tokenRes.json()) as { error: string; error_description?: string };
|
||||
expect(body.error).toBe('invalid_client');
|
||||
|
||||
// Verify isClientRejection would match this error
|
||||
const errorMsg = body.error_description ?? body.error;
|
||||
expect(MCPConnectionFactory.isClientRejection(new Error(errorMsg))).toBe(true);
|
||||
});
|
||||
|
||||
it('should accept token exchange when client_id matches', async () => {
|
||||
server = await createOAuthMCPServer({ tokenTTLMs: 60000, enforceClientId: true });
|
||||
|
||||
const regRes = await fetch(`${server.url}register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ redirect_uris: ['http://localhost/callback'] }),
|
||||
});
|
||||
const registered = (await regRes.json()) as { client_id: string };
|
||||
|
||||
const authRes = await fetch(
|
||||
`${server.url}authorize?redirect_uri=http://localhost/callback&state=s1&client_id=${registered.client_id}`,
|
||||
{ redirect: 'manual' },
|
||||
);
|
||||
const location = authRes.headers.get('location') ?? '';
|
||||
const code = new URL(location).searchParams.get('code');
|
||||
|
||||
const tokenRes = await fetch(`${server.url}token`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: `grant_type=authorization_code&code=${code}&client_id=${registered.client_id}`,
|
||||
});
|
||||
|
||||
expect(tokenRes.status).toBe(200);
|
||||
const body = (await tokenRes.json()) as { access_token: string };
|
||||
expect(body.access_token).toBeTruthy();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ export interface OAuthTestServerOptions {
|
|||
issueRefreshTokens?: boolean;
|
||||
refreshTokenTTLMs?: number;
|
||||
rotateRefreshTokens?: boolean;
|
||||
/** When true, /token validates client_id against the registered client that initiated /authorize */
|
||||
enforceClientId?: boolean;
|
||||
}
|
||||
|
||||
export interface OAuthTestServer {
|
||||
|
|
@ -81,6 +83,17 @@ async function readRequestBody(req: http.IncomingMessage): Promise<string> {
|
|||
return Buffer.concat(chunks).toString();
|
||||
}
|
||||
|
||||
function parseBasicAuth(
|
||||
header: string | undefined,
|
||||
): { clientId: string; clientSecret: string } | null {
|
||||
if (!header?.startsWith('Basic ')) {
|
||||
return null;
|
||||
}
|
||||
const decoded = Buffer.from(header.slice(6), 'base64').toString();
|
||||
const [clientId, clientSecret] = decoded.split(':');
|
||||
return clientId ? { clientId, clientSecret: clientSecret ?? '' } : null;
|
||||
}
|
||||
|
||||
function parseTokenRequest(body: string, contentType: string | undefined): URLSearchParams | null {
|
||||
if (contentType?.includes('application/x-www-form-urlencoded')) {
|
||||
return new URLSearchParams(body);
|
||||
|
|
@ -100,6 +113,7 @@ export async function createOAuthMCPServer(
|
|||
issueRefreshTokens = false,
|
||||
refreshTokenTTLMs = 365 * 24 * 60 * 60 * 1000,
|
||||
rotateRefreshTokens = false,
|
||||
enforceClientId = false,
|
||||
} = options;
|
||||
|
||||
const sessions = new Map<string, StreamableHTTPServerTransport>();
|
||||
|
|
@ -107,7 +121,10 @@ export async function createOAuthMCPServer(
|
|||
const tokenIssueTimes = new Map<string, number>();
|
||||
const issuedRefreshTokens = new Map<string, string>();
|
||||
const refreshTokenIssueTimes = new Map<string, number>();
|
||||
const authCodes = new Map<string, { codeChallenge?: string; codeChallengeMethod?: string }>();
|
||||
const authCodes = new Map<
|
||||
string,
|
||||
{ codeChallenge?: string; codeChallengeMethod?: string; clientId?: string }
|
||||
>();
|
||||
const registeredClients = new Map<string, { client_id: string; client_secret: string }>();
|
||||
|
||||
let port = 0;
|
||||
|
|
@ -155,7 +172,8 @@ export async function createOAuthMCPServer(
|
|||
const code = randomUUID();
|
||||
const codeChallenge = url.searchParams.get('code_challenge') ?? undefined;
|
||||
const codeChallengeMethod = url.searchParams.get('code_challenge_method') ?? undefined;
|
||||
authCodes.set(code, { codeChallenge, codeChallengeMethod });
|
||||
const clientId = url.searchParams.get('client_id') ?? undefined;
|
||||
authCodes.set(code, { codeChallenge, codeChallengeMethod, clientId });
|
||||
const redirectUri = url.searchParams.get('redirect_uri') ?? '';
|
||||
const state = url.searchParams.get('state') ?? '';
|
||||
res.writeHead(302, {
|
||||
|
|
@ -202,6 +220,23 @@ export async function createOAuthMCPServer(
|
|||
}
|
||||
}
|
||||
|
||||
if (enforceClientId && codeData.clientId) {
|
||||
const requestClientId =
|
||||
params.get('client_id') ?? parseBasicAuth(req.headers.authorization)?.clientId;
|
||||
if (!requestClientId || !registeredClients.has(requestClientId)) {
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ error: 'invalid_client' }));
|
||||
return;
|
||||
}
|
||||
if (requestClientId !== codeData.clientId) {
|
||||
res.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
res.end(
|
||||
JSON.stringify({ error: 'invalid_client', error_description: 'client_id mismatch' }),
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
authCodes.delete(code);
|
||||
|
||||
const accessToken = randomUUID();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue