diff --git a/api/server/middleware/checkDomainAllowed.js b/api/server/middleware/checkDomainAllowed.js new file mode 100644 index 0000000000..895ce99a56 --- /dev/null +++ b/api/server/middleware/checkDomainAllowed.js @@ -0,0 +1,25 @@ +const { isDomainAllowed } = require('~/server/services/AuthService'); +const { logger } = require('~/config'); + +/** + * Checks the domain's social login is allowed + * + * @async + * @function + * @param {Object} req - Express request object. + * @param {Object} res - Express response object. + * @param {Function} next - Next middleware function. + * + * @returns {Promise} - Returns a Promise which when resolved calls next middleware if the domain's email is allowed + */ +const checkDomainAllowed = async (req, res, next = () => {}) => { + const email = req?.user?.email; + if (email && !(await isDomainAllowed(email))) { + logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`); + return res.redirect('/login'); + } else { + return next(); + } +}; + +module.exports = checkDomainAllowed; diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index b9960a237a..15ec991352 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -1,5 +1,6 @@ const abortMiddleware = require('./abortMiddleware'); const checkBan = require('./checkBan'); +const checkDomainAllowed = require('./checkDomainAllowed'); const uaParser = require('./uaParser'); const setHeaders = require('./setHeaders'); const loginLimiter = require('./loginLimiter'); @@ -38,4 +39,5 @@ module.exports = { validateModel, moderateText, noIndex, + checkDomainAllowed, }; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index e85d83d888..0749436865 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -4,7 +4,7 @@ const passport = require('passport'); const express = require('express'); const router = express.Router(); const { setAuthTokens } = require('~/server/services/AuthService'); -const { loginLimiter, checkBan } = require('~/server/middleware'); +const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware'); const { logger } = require('~/config'); const domains = { @@ -16,6 +16,7 @@ router.use(loginLimiter); const oauthHandler = async (req, res) => { try { + await checkDomainAllowed(req, res); await checkBan(req, res); if (req.banned) { return;