mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-25 19:56:13 +01:00
feat: Implement moderation middleware with configurable categories and actions
This commit is contained in:
parent
30db34e737
commit
e8dffd35f3
3 changed files with 213 additions and 35 deletions
|
|
@ -1,40 +1,71 @@
|
|||
const OpenAI = require('openai');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { ErrorTypes, ViolationTypes } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const denyRequest = require('./denyRequest');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Middleware to moderate text content using OpenAI's moderation API
|
||||
* @param {Express.Request} req - Express request object
|
||||
* @param {Express.Response} res - Express response object
|
||||
* @param {Express.NextFunction} next - Express next middleware function
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
/**
 * Fallback moderation actions. Each key applies whenever
 * `customConfig.moderation.actions` does not override it.
 */
const DEFAULT_ACTIONS = Object.freeze({
  // Score handed to `logViolation`; a value of 0 or less skips violation logging.
  violation: 2,
  // When true the request is denied; when false it is passed on after logging.
  blockMessage: true,
  // When true a warning listing the violations is written to the logger.
  log: true,
});

// Per-category default thresholds, consulted by `checkViolations` when a
// category has no configured threshold. Nothing in the visible code adds
// entries to this map, so lookups fall through to that function's built-in
// default.
const DEFAULT_THRESHOLDS = new Map();
|
||||
|
||||
/**
 * Builds the reporting shape for a single violation.
 *
 * @param {{ category: string, score: number, threshold: number }} violation - Raw violation entry.
 * @returns {{ category: string, score: number, threshold: number, severity: string }}
 *   The same category and threshold, the score rounded to two decimals, and a severity label.
 */
function formatViolation(violation) {
  const { category, score, threshold } = violation;

  return {
    category,
    score: Math.round(score * 100) / 100,
    threshold,
    // Severity is derived from the raw score, not the rounded one.
    severity: getSeverityLevel(score),
  };
}
|
||||
|
||||
/**
 * Maps a moderation score onto a coarse severity label.
 *
 * @param {number} score - Category score to classify.
 * @returns {'HIGH' | 'MEDIUM' | 'LOW'} 'HIGH' at 0.9 and above, 'MEDIUM' at 0.7 and above,
 *   'LOW' for everything else (including values that fail both comparisons, such as NaN).
 */
function getSeverityLevel(score) {
  if (score >= 0.9) {
    return 'HIGH';
  }

  if (score >= 0.7) {
    return 'MEDIUM';
  }

  return 'LOW';
}
|
||||
|
||||
/**
 * Renders a list of violations as a single log-friendly line.
 *
 * @param {Array<{ category: string, score: number, threshold: number }>} violations - Entries to render.
 * @param {string} [userId='unknown'] - Identifier placed at the start of the line.
 * @returns {string} `userId=<id> violations=[<category>:<score>><threshold> | ...]`.
 */
function formatViolationsLog(violations, userId = 'unknown') {
  const violationsStr = violations
    .map(({ category, score, threshold }) => `${category}:${score}>${threshold}`)
    .join(' | ');

  return `userId=${userId} violations=[${violationsStr}]`;
}
|
||||
|
||||
async function moderateText(req, res, next) {
|
||||
if (!isEnabled(process.env.OPENAI_MODERATION)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
|
||||
if (!moderationKey) {
|
||||
logger.error('Missing OpenAI moderation API key');
|
||||
return denyRequest(req, res, { message: 'Moderation configuration error' });
|
||||
}
|
||||
|
||||
const { text } = req.body;
|
||||
if (!text?.length || typeof text !== 'string') {
|
||||
return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
|
||||
}
|
||||
|
||||
try {
|
||||
const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
|
||||
const customConfig = await getCustomConfig();
|
||||
|
||||
if (!moderationKey) {
|
||||
logger.error('Missing OpenAI moderation API key');
|
||||
return denyRequest(req, res, { message: 'Moderation configuration error' });
|
||||
if (!moderateText.openai) {
|
||||
moderateText.openai = new OpenAI({ apiKey: moderationKey });
|
||||
}
|
||||
|
||||
const openai = new OpenAI({
|
||||
apiKey: moderationKey,
|
||||
});
|
||||
|
||||
const { text } = req.body;
|
||||
|
||||
if (!text || typeof text !== 'string') {
|
||||
return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
|
||||
}
|
||||
|
||||
const response = await openai.moderations.create({
|
||||
const response = await moderateText.openai.moderations.create({
|
||||
model: 'omni-moderation-latest',
|
||||
input: text,
|
||||
});
|
||||
|
|
@ -43,22 +74,47 @@ async function moderateText(req, res, next) {
|
|||
throw new Error('Invalid moderation API response format');
|
||||
}
|
||||
|
||||
const flagged = response.results.some((result) => result.flagged);
|
||||
const violations = checkViolations(response.results, customConfig).map(formatViolation);
|
||||
|
||||
if (flagged) {
|
||||
return denyRequest(req, res, {
|
||||
type: ErrorTypes.MODERATION,
|
||||
message: 'Content violates moderation policies',
|
||||
});
|
||||
if (violations.length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('Moderation error:', {
|
||||
error: error.message,
|
||||
stack: error.stack,
|
||||
status: error.response?.status,
|
||||
const actions = Object.assign({}, DEFAULT_ACTIONS, customConfig?.moderation?.actions);
|
||||
|
||||
if (actions.log) {
|
||||
const userId = req.user?.id || 'anonymous';
|
||||
logger.warn(
|
||||
'[moderateText] Content moderation violations: ' + formatViolationsLog(violations, userId),
|
||||
);
|
||||
}
|
||||
|
||||
if (!actions.blockMessage) {
|
||||
return next();
|
||||
}
|
||||
|
||||
if (actions.violation > 0) {
|
||||
logViolation(req, res, ViolationTypes.MODERATION, { violations }, actions.violation);
|
||||
}
|
||||
|
||||
return denyRequest(req, res, {
|
||||
type: ErrorTypes.MODERATION,
|
||||
message: `Content violates moderation policies: ${violations
|
||||
.map((v) => v.category)
|
||||
.join(', ')}`,
|
||||
violations: violations,
|
||||
});
|
||||
} catch (error) {
|
||||
const errorDetails =
|
||||
process.env.NODE_ENV === 'production'
|
||||
? { message: error.message }
|
||||
: {
|
||||
error: error.message,
|
||||
stack: error.stack,
|
||||
status: error.response?.status,
|
||||
};
|
||||
|
||||
logger.error('Moderation error:', errorDetails);
|
||||
|
||||
return denyRequest(req, res, {
|
||||
type: ErrorTypes.MODERATION,
|
||||
|
|
@ -67,4 +123,26 @@ async function moderateText(req, res, next) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Collects every category score that meets or exceeds its threshold.
 *
 * Threshold resolution per category: the configured
 * `customConfig.moderation.categories[category].threshold`, then the
 * `DEFAULT_THRESHOLDS` entry, then 0.7. Categories configured with
 * `enabled: false` are skipped entirely.
 *
 * @param {Array<{ category_scores: Record<string, number> }>} results - Moderation results to inspect.
 * @param {object} [customConfig] - Custom config; only `moderation.categories` is read.
 * @returns {Array<{ category: string, score: number, threshold: number }>} Violations in encounter order.
 * @throws {TypeError} If a result has no `category_scores` object. This is left
 *   to propagate so a malformed response is never treated as "no violations".
 */
function checkViolations(results, customConfig) {
  const violations = [];
  const categories = customConfig?.moderation?.categories || {};

  for (const result of results) {
    for (const [category, score] of Object.entries(result.category_scores)) {
      const categoryConfig = categories[category];

      if (categoryConfig?.enabled === false) {
        continue;
      }

      // `??` rather than `||`: an explicitly configured threshold of 0 is a
      // valid value and must not be replaced by the fallbacks.
      const threshold = categoryConfig?.threshold ?? DEFAULT_THRESHOLDS.get(category) ?? 0.7;

      if (score >= threshold) {
        violations.push({ category, score, threshold });
      }
    }
  }

  return violations;
}
|
||||
|
||||
module.exports = moderateText;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue