From ca3237c7be25cf32f65fd7b6da9b2ca4aca32e29 Mon Sep 17 00:00:00 2001 From: Ruben Talstra Date: Sat, 22 Feb 2025 12:18:00 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Implement=20role=20extracti?= =?UTF-8?q?on=20and=20user=20group=20update=20logic=20in=20OpenID=20strate?= =?UTF-8?q?gy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/strategies/openidStrategy.js | 127 ++++++++++++++++---------- api/strategies/openidStrategy.spec.js | 3 + 2 files changed, 82 insertions(+), 48 deletions(-) diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index abc41041ec..8e44a7f627 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -106,6 +106,71 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Extracts roles from the specified token using configuration from environment variables. + * @param {object} tokenset - The token set returned by the OpenID provider. + * @returns {Array} The roles extracted from the token. + */ +function extractRoles(tokenset) { + const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; + const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; + const token = + requiredRoleTokenKind === 'access' + ? jwtDecode(tokenset.access_token) + : jwtDecode(tokenset.id_token); + const pathParts = requiredRoleParameterPath.split('.'); + let found = true; + const roles = pathParts.reduce((acc, key) => { + if (!acc || !(key in acc)) { + found = false; + return []; + } + return acc[key]; + }, token); + if (!found) { + logger.error( + `[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`, + ); + } + return roles; +} + +/** + * Updates the user's groups based on the provided roles. + * It removes any existing OpenID group references and then adds the groups + * that match the roles from the external group collection. + * + * @param {object} user - The user object. + * @param {Array} roles - The roles extracted from the token. + * @returns {Promise} The updated groups array. + */ +async function updateUserGroups(user, roles) { + user.groups = user.groups || []; + // Remove existing OpenID group references. + const currentOpenIdGroups = await findGroup({ + _id: { $in: user.groups }, + provider: 'openid', + }); + const currentOpenIdGroupIds = new Set( + currentOpenIdGroups.map((g) => g._id.toString()), + ); + user.groups = user.groups.filter( + (id) => !currentOpenIdGroupIds.has(id.toString()), + ); + + // Look up groups matching the roles. + const matchingGroups = await findGroup({ + provider: 'openid', + externalId: { $in: roles }, + }); + matchingGroups.forEach((group) => { + if (!user.groups.some((id) => id.toString() === group._id.toString())) { + user.groups.push(group._id); + } + }); + return user.groups; +} + async function setupOpenId() { try { if (process.env.PROXY) { @@ -135,8 +200,6 @@ async function setupOpenId() { } const client = new issuer.Client(clientMetadata); const requiredRole = process.env.OPENID_REQUIRED_ROLE; - const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; - const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; const openidLogin = new OpenIDStrategy( { client, @@ -146,8 +209,13 @@ async function setupOpenId() { }, async (tokenset, userinfo, done) => { try { - logger.info(`[openidStrategy] verify login openidId: ${userinfo.sub}`); - logger.debug('[openidStrategy] verify login tokenset and userinfo', { tokenset, userinfo }); + logger.info( + `[openidStrategy] verify login openidId: ${userinfo.sub}`, + ); + logger.debug('[openidStrategy] verify login tokenset and userinfo', { + tokenset, + userinfo, + }); let user = await findUser({ openidId: userinfo.sub }); logger.info( @@ -165,56 +233,15 @@ async function setupOpenId() { const fullName = getFullName(userinfo); + // Check for the required role using extracted roles. + let roles = []; if (requiredRole) { - let decodedToken = ''; - if (requiredRoleTokenKind === 'access') { - decodedToken = jwtDecode(tokenset.access_token); - } else if (requiredRoleTokenKind === 'id') { - decodedToken = jwtDecode(tokenset.id_token); - } - const pathParts = requiredRoleParameterPath.split('.'); - let found = true; - let roles = pathParts.reduce((o, key) => { - if (o === null || o === undefined || !(key in o)) { - found = false; - return []; - } - return o[key]; - }, decodedToken); - - if (!found) { - logger.error( - `[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`, - ); - } - + roles = extractRoles(tokenset); if (!roles.includes(requiredRole)) { return done(null, false, { message: `You must have the "${requiredRole}" role to log in.`, }); } - - if (!user.groups) { - user.groups = []; - } - // Remove existing OpenID group references. - const currentOpenIdGroups = await findGroup({ - _id: { $in: user.groups }, - provider: 'openid', - }); - const currentOpenIdGroupIds = new Set(currentOpenIdGroups.map(g => g._id.toString())); - user.groups = user.groups.filter(id => !currentOpenIdGroupIds.has(id.toString())); - - // Look up groups in the Group collection matching the roles. - const matchingGroups = await findGroup({ - provider: 'openid', - externalId: { $in: roles }, - }); - matchingGroups.forEach(group => { - if (!user.groups.some(id => id.toString() === group._id.toString())) { - user.groups.push(group._id); - } - }); } let username = ''; @@ -266,6 +293,10 @@ async function setupOpenId() { } } + if (requiredRole) { + await updateUserGroups(user, roles); + } + user = await updateUser(user._id, user); logger.info( diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index cea7c5e4a6..ab82ca67ad 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -19,6 +19,9 @@ jest.mock('~/models/userMethods', () => ({ createUser: jest.fn(), updateUser: jest.fn(), })); +jest.mock('~/models/groupMethods', () => ({ + findGroup: jest.fn().mockResolvedValue([]), +})); jest.mock('~/server/utils/crypto', () => ({ hashToken: jest.fn().mockResolvedValue('hashed-token'), }));