feat: Implement moderation middleware with configurable categories and actions

Marco Beretta 2024-12-01 17:39:58 +01:00
parent 30db34e737
commit e8dffd35f3
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
3 changed files with 213 additions and 35 deletions
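
For orientation, a minimal sketch of the moderation block the new schema accepts in the custom config (category names, action names, and defaults are taken from the schema added below; the specific values shown are illustrative, not part of this commit):

moderation: {
  categories: {
    violence: { enabled: true, threshold: 0.8 },
    'self-harm': { enabled: true, threshold: 0.5 },
    harassment: { enabled: false },
  },
  actions: { violation: 2, blockMessage: true, log: true },
}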

View file

@@ -78,6 +78,7 @@ const namespaces = {
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
[ViolationTypes.MODERATION]: createViolationInstance(ViolationTypes.MODERATION),
logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,

View file

@@ -1,40 +1,71 @@
const OpenAI = require('openai');
const { ErrorTypes } = require('librechat-data-provider');
const { ErrorTypes, ViolationTypes } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config');
const { isEnabled } = require('~/server/utils');
const denyRequest = require('./denyRequest');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
/**
* Middleware to moderate text content using OpenAI's moderation API
* @param {Express.Request} req - Express request object
* @param {Express.Response} res - Express response object
* @param {Express.NextFunction} next - Express next middleware function
* @returns {Promise<void>}
*/
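// Default actions applied when the custom config defines no moderation.actions:
// violation is the score passed to logViolation, blockMessage denies the request,
// and log writes a warn-level entry listing the violating categories.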
const DEFAULT_ACTIONS = Object.freeze({
violation: 2,
blockMessage: true,
log: true,
});
// Optional per-category default thresholds; categories without an entry fall back to 0.7
const DEFAULT_THRESHOLDS = new Map();
function formatViolation(violation) {
return {
category: violation.category,
score: Math.round(violation.score * 100) / 100,
threshold: violation.threshold,
severity: getSeverityLevel(violation.score),
};
}
function getSeverityLevel(score) {
if (score >= 0.9) {
return 'HIGH';
}
if (score >= 0.7) {
return 'MEDIUM';
}
return 'LOW';
}
function formatViolationsLog(violations, userId = 'unknown') {
const violationsStr = violations
.map((v) => `${v.category}:${v.score}>${v.threshold}`)
.join(' | ');
return `userId=${userId} violations=[${violationsStr}]`;
}
async function moderateText(req, res, next) {
if (!isEnabled(process.env.OPENAI_MODERATION)) {
return next();
}
const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
if (!moderationKey) {
logger.error('Missing OpenAI moderation API key');
return denyRequest(req, res, { message: 'Moderation configuration error' });
}
const { text } = req.body;
if (!text?.length || typeof text !== 'string') {
return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
}
try {
const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
const customConfig = await getCustomConfig();
if (!moderationKey) {
logger.error('Missing OpenAI moderation API key');
return denyRequest(req, res, { message: 'Moderation configuration error' });
if (!moderateText.openai) {
moderateText.openai = new OpenAI({ apiKey: moderationKey });
}
const openai = new OpenAI({
apiKey: moderationKey,
});
const { text } = req.body;
if (!text || typeof text !== 'string') {
return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
}
const response = await openai.moderations.create({
const response = await moderateText.openai.moderations.create({
model: 'omni-moderation-latest',
input: text,
});
@@ -43,22 +74,47 @@ async function moderateText(req, res, next) {
throw new Error('Invalid moderation API response format');
}
const flagged = response.results.some((result) => result.flagged);
const violations = checkViolations(response.results, customConfig).map(formatViolation);
if (flagged) {
return denyRequest(req, res, {
type: ErrorTypes.MODERATION,
message: 'Content violates moderation policies',
});
if (violations.length === 0) {
return next();
}
next();
} catch (error) {
logger.error('Moderation error:', {
error: error.message,
stack: error.stack,
status: error.response?.status,
const actions = Object.assign({}, DEFAULT_ACTIONS, customConfig?.moderation?.actions);
if (actions.log) {
const userId = req.user?.id || 'anonymous';
logger.warn(
'[moderateText] Content moderation violations: ' + formatViolationsLog(violations, userId),
);
}
if (!actions.blockMessage) {
return next();
}
if (actions.violation > 0) {
logViolation(req, res, ViolationTypes.MODERATION, { violations }, actions.violation);
}
return denyRequest(req, res, {
type: ErrorTypes.MODERATION,
message: `Content violates moderation policies: ${violations
.map((v) => v.category)
.join(', ')}`,
violations: violations,
});
} catch (error) {
const errorDetails =
process.env.NODE_ENV === 'production'
? { message: error.message }
: {
error: error.message,
stack: error.stack,
status: error.response?.status,
};
logger.error('Moderation error:', errorDetails);
return denyRequest(req, res, {
type: ErrorTypes.MODERATION,
@@ -67,4 +123,26 @@ async function moderateText(req, res, next) {
}
}
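// Compares each category score returned by the moderation API against its configured
// threshold (0.7 when a category defines none) and collects the categories that meet
// or exceed it; categories configured with enabled: false are skipped.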
function checkViolations(results, customConfig) {
const violations = [];
const categories = customConfig?.moderation?.categories || {};
for (const result of results) {
for (const [category, score] of Object.entries(result.category_scores)) {
const categoryConfig = categories[category];
if (categoryConfig?.enabled === false) {
continue;
}
const threshold = categoryConfig?.threshold ?? DEFAULT_THRESHOLDS.get(category) ?? 0.7;
if (score >= threshold) {
violations.push({ category, score, threshold });
}
}
}
return violations;
}
module.exports = moderateText;
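A hedged usage sketch (the require path, route path, and placeholder handler are assumptions, not part of this commit); the middleware short-circuits with next() unless OPENAI_MODERATION is enabled:

const express = require('express');
const moderateText = require('~/server/middleware/moderateText'); // path assumed
const router = express.Router();
// Moderation runs before the request handler; the handler here is a placeholder.
router.post('/api/ask', moderateText, (req, res) => res.json({ ok: true }));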

View file

@@ -437,6 +437,100 @@ export const rateLimitSchema = z.object({
.optional(),
});
const moderationSchema = z
.object({
categories: z
.object({
sexual: z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'sexual/minors': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0),
})
.optional(),
harassment: z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'harassment/threatening': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
hate: z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'hate/threatening': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
illicit: z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'illicit/violent': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'self-harm': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'self-harm/intent': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'self-harm/instructions': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
violence: z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
'violence/graphic': z
.object({
enabled: z.boolean().default(true),
threshold: z.number().min(0).max(1).default(0.7),
})
.optional(),
})
.optional(),
actions: z
.object({
violation: z.number().default(2),
blockMessage: z.boolean().default(true),
log: z.boolean().default(false),
})
.optional(),
})
.optional();
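// Illustrative note: with zod, the per-field defaults above apply whenever a category
// or actions object is present, e.g.
//   moderationSchema.parse({ categories: { violence: {} }, actions: {} })
// yields categories.violence = { enabled: true, threshold: 0.7 } and
// actions = { violation: 2, blockMessage: true, log: false }.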
export enum EImageOutputType {
PNG = 'png',
WEBP = 'webp',
@@ -487,6 +581,7 @@ export const configSchema = z.object({
prompts: true,
}),
fileStrategy: fileSourceSchema.default(FileSources.local),
moderation: moderationSchema.optional(),
registration: z
.object({
socialLogins: z.array(z.string()).optional(),
@@ -931,6 +1026,10 @@ export enum ViolationTypes {
* Verify Conversation Access violation.
*/
CONVO_ACCESS = 'convo_access',
/**
* Verify moderation LLM violation.
*/
MODERATION = 'moderation',
}
/**