mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-20 10:20:15 +01:00
✨ feat: Implement role extraction and user group update logic in OpenID strategy
This commit is contained in:
parent
d3764fd9fe
commit
ca3237c7be
2 changed files with 82 additions and 48 deletions
|
|
@ -106,6 +106,71 @@ function convertToUsername(input, defaultValue = '') {
|
|||
return defaultValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts roles from the specified token using configuration from environment variables.
|
||||
* @param {object} tokenset - The token set returned by the OpenID provider.
|
||||
* @returns {Array} The roles extracted from the token.
|
||||
*/
|
||||
function extractRoles(tokenset) {
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const token =
|
||||
requiredRoleTokenKind === 'access'
|
||||
? jwtDecode(tokenset.access_token)
|
||||
: jwtDecode(tokenset.id_token);
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
const roles = pathParts.reduce((acc, key) => {
|
||||
if (!acc || !(key in acc)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return acc[key];
|
||||
}, token);
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
return roles;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the user's groups based on the provided roles.
|
||||
* It removes any existing OpenID group references and then adds the groups
|
||||
* that match the roles from the external group collection.
|
||||
*
|
||||
* @param {object} user - The user object.
|
||||
* @param {Array} roles - The roles extracted from the token.
|
||||
* @returns {Promise<Array>} The updated groups array.
|
||||
*/
|
||||
async function updateUserGroups(user, roles) {
|
||||
user.groups = user.groups || [];
|
||||
// Remove existing OpenID group references.
|
||||
const currentOpenIdGroups = await findGroup({
|
||||
_id: { $in: user.groups },
|
||||
provider: 'openid',
|
||||
});
|
||||
const currentOpenIdGroupIds = new Set(
|
||||
currentOpenIdGroups.map((g) => g._id.toString()),
|
||||
);
|
||||
user.groups = user.groups.filter(
|
||||
(id) => !currentOpenIdGroupIds.has(id.toString()),
|
||||
);
|
||||
|
||||
// Look up groups matching the roles.
|
||||
const matchingGroups = await findGroup({
|
||||
provider: 'openid',
|
||||
externalId: { $in: roles },
|
||||
});
|
||||
matchingGroups.forEach((group) => {
|
||||
if (!user.groups.some((id) => id.toString() === group._id.toString())) {
|
||||
user.groups.push(group._id);
|
||||
}
|
||||
});
|
||||
return user.groups;
|
||||
}
|
||||
|
||||
async function setupOpenId() {
|
||||
try {
|
||||
if (process.env.PROXY) {
|
||||
|
|
@ -135,8 +200,6 @@ async function setupOpenId() {
|
|||
}
|
||||
const client = new issuer.Client(clientMetadata);
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
const openidLogin = new OpenIDStrategy(
|
||||
{
|
||||
client,
|
||||
|
|
@ -146,8 +209,13 @@ async function setupOpenId() {
|
|||
},
|
||||
async (tokenset, userinfo, done) => {
|
||||
try {
|
||||
logger.info(`[openidStrategy] verify login openidId: ${userinfo.sub}`);
|
||||
logger.debug('[openidStrategy] verify login tokenset and userinfo', { tokenset, userinfo });
|
||||
logger.info(
|
||||
`[openidStrategy] verify login openidId: ${userinfo.sub}`,
|
||||
);
|
||||
logger.debug('[openidStrategy] verify login tokenset and userinfo', {
|
||||
tokenset,
|
||||
userinfo,
|
||||
});
|
||||
|
||||
let user = await findUser({ openidId: userinfo.sub });
|
||||
logger.info(
|
||||
|
|
@ -165,56 +233,15 @@ async function setupOpenId() {
|
|||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
// Check for the required role using extracted roles.
|
||||
let roles = [];
|
||||
if (requiredRole) {
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access') {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id') {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
}
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
roles = extractRoles(tokenset);
|
||||
if (!roles.includes(requiredRole)) {
|
||||
return done(null, false, {
|
||||
message: `You must have the "${requiredRole}" role to log in.`,
|
||||
});
|
||||
}
|
||||
|
||||
if (!user.groups) {
|
||||
user.groups = [];
|
||||
}
|
||||
// Remove existing OpenID group references.
|
||||
const currentOpenIdGroups = await findGroup({
|
||||
_id: { $in: user.groups },
|
||||
provider: 'openid',
|
||||
});
|
||||
const currentOpenIdGroupIds = new Set(currentOpenIdGroups.map(g => g._id.toString()));
|
||||
user.groups = user.groups.filter(id => !currentOpenIdGroupIds.has(id.toString()));
|
||||
|
||||
// Look up groups in the Group collection matching the roles.
|
||||
const matchingGroups = await findGroup({
|
||||
provider: 'openid',
|
||||
externalId: { $in: roles },
|
||||
});
|
||||
matchingGroups.forEach(group => {
|
||||
if (!user.groups.some(id => id.toString() === group._id.toString())) {
|
||||
user.groups.push(group._id);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let username = '';
|
||||
|
|
@ -266,6 +293,10 @@ async function setupOpenId() {
|
|||
}
|
||||
}
|
||||
|
||||
if (requiredRole) {
|
||||
await updateUserGroups(user, roles);
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ jest.mock('~/models/userMethods', () => ({
|
|||
createUser: jest.fn(),
|
||||
updateUser: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/groupMethods', () => ({
|
||||
findGroup: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
jest.mock('~/server/utils/crypto', () => ({
|
||||
hashToken: jest.fn().mockResolvedValue('hashed-token'),
|
||||
}));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue