LibreChat/api/server/middleware/moderateText.js

148 lines
4 KiB
JavaScript

const OpenAI = require('openai');
const { ErrorTypes, ViolationTypes } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config');
const { isEnabled } = require('~/server/utils');
const denyRequest = require('./denyRequest');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
/**
 * Fallback moderation actions. `moderateText` shallow-merges
 * `customConfig.moderation.actions` on top of these.
 * - violation: passed as the last argument to `logViolation`; values <= 0 skip that call.
 * - blockMessage: when falsy, violations are not blocked and the request proceeds.
 * - log: when truthy, violations are written to the warn log.
 */
const DEFAULT_ACTIONS = Object.freeze({
  violation: 2,
  blockMessage: true,
  log: true,
});

/**
 * Optional per-category default thresholds, consulted by `checkViolations`
 * when the custom config sets no threshold for a category. It is currently
 * empty, so unconfigured categories fall back to the hard-coded 0.7.
 * @type {Map<string, number>}
 */
const DEFAULT_THRESHOLDS = new Map();
/**
 * Shape a raw violation for logging and for the error payload sent to the client.
 * The score is rounded to two decimals, `threshold` is passed through unchanged,
 * and `severity` is derived from the *unrounded* score via `getSeverityLevel`.
 *
 * @param {{category: string, score: number, threshold: number}} violation
 * @returns {{category: string, score: number, threshold: number, severity: string}}
 */
function formatViolation(violation) {
  const { category, score, threshold } = violation;
  const roundedScore = Math.round(score * 100) / 100;
  const severity = getSeverityLevel(score);
  return { category, score: roundedScore, threshold, severity };
}
/**
 * Bucket a moderation score into a coarse severity label.
 * Scores >= 0.9 are 'HIGH', scores >= 0.7 are 'MEDIUM', and everything else
 * is 'LOW'. NaN fails both comparisons and is therefore 'LOW'.
 *
 * @param {number} score
 * @returns {'HIGH' | 'MEDIUM' | 'LOW'}
 */
function getSeverityLevel(score) {
  // Checked in order; the first floor the score reaches wins.
  const bands = [
    [0.9, 'HIGH'],
    [0.7, 'MEDIUM'],
  ];
  const match = bands.find(([floor]) => score >= floor);
  return match === undefined ? 'LOW' : match[1];
}
/**
 * Render violations as one log line of the form
 * `userId=<id> violations=[<category>:<score>><threshold> | ...]`.
 *
 * @param {Array<{category: string, score: number, threshold: number}>} violations
 * @param {string} [userId='unknown'] - The default applies only when `undefined` is passed.
 * @returns {string}
 */
function formatViolationsLog(violations, userId = 'unknown') {
  const describeEntry = ({ category, score, threshold }) => `${category}:${score}>${threshold}`;
  const summary = violations.map(describeEntry).join(' | ');
  return `userId=${userId} violations=[${summary}]`;
}
/**
 * Express-style middleware that screens `req.body.text` with the OpenAI moderation API.
 *
 * - Calls `next()` immediately unless `OPENAI_MODERATION` is enabled.
 * - Denies when `OPENAI_MODERATION_API_KEY` is missing, or when `text` is not a
 *   non-empty string.
 * - Scores the text per category (see `checkViolations`). With no violations it
 *   calls `next()`.
 * - On violations, merges `customConfig.moderation.actions` over `DEFAULT_ACTIONS`.
 *   It optionally logs, then either lets the request through (`blockMessage` falsy)
 *   or records the violation and denies the request.
 * - Any error (config load, API failure, malformed response) denies the request,
 *   so moderation fails closed.
 *
 * @param {object} req - Request; reads `req.body.text` and `req.user?.id`.
 * @param {object} res - Response, forwarded to `denyRequest` / `logViolation`.
 * @param {Function} next - Continuation invoked when the text is allowed.
 * @returns {Promise<unknown>} Whatever `next()` or `denyRequest(...)` returns.
 */
async function moderateText(req, res, next) {
  if (!isEnabled(process.env.OPENAI_MODERATION)) {
    return next();
  }

  const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
  if (!moderationKey) {
    logger.error('Missing OpenAI moderation API key');
    return denyRequest(req, res, { message: 'Moderation configuration error' });
  }

  const { text } = req.body;
  if (!text?.length || typeof text !== 'string') {
    return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
  }

  try {
    const customConfig = await getCustomConfig();

    // Lazily create a single client and reuse it across requests.
    if (!moderateText.openai) {
      moderateText.openai = new OpenAI({ apiKey: moderationKey });
    }

    const response = await moderateText.openai.moderations.create({
      model: 'omni-moderation-latest',
      input: text,
    });

    if (!Array.isArray(response.results)) {
      throw new Error('Invalid moderation API response format');
    }

    const violations = checkViolations(response.results, customConfig).map(formatViolation);
    if (violations.length === 0) {
      return next();
    }

    const actions = { ...DEFAULT_ACTIONS, ...customConfig?.moderation?.actions };

    if (actions.log) {
      const userId = req.user?.id || 'anonymous';
      logger.warn(
        `[moderateText] Content moderation violations: ${formatViolationsLog(violations, userId)}`,
      );
    }

    if (!actions.blockMessage) {
      return next();
    }

    if (actions.violation > 0) {
      // Awaited so a rejected promise (if logViolation is async) cannot escape as
      // an unhandled rejection. A failure to record the violation is logged but
      // must not change the outcome: the message is still blocked below.
      try {
        await logViolation(req, res, ViolationTypes.MODERATION, { violations }, actions.violation);
      } catch (logError) {
        logger.error('[moderateText] Failed to record moderation violation:', logError);
      }
    }

    return denyRequest(req, res, {
      type: ErrorTypes.MODERATION,
      message: `Content violates moderation policies: ${violations
        .map((v) => v.category)
        .join(', ')}`,
      violations,
    });
  } catch (error) {
    // The openai v4+ SDK exposes the HTTP status on `error.status`.
    // `error.response?.status` is kept as a fallback for response-style errors.
    // Optional chaining guards against non-object throw values.
    const errorDetails =
      process.env.NODE_ENV === 'production'
        ? { message: error?.message }
        : {
            error: error?.message,
            stack: error?.stack,
            status: error?.status ?? error?.response?.status,
          };
    logger.error('Moderation error:', errorDetails);
    return denyRequest(req, res, {
      type: ErrorTypes.MODERATION,
      message: 'Content moderation check failed',
    });
  }
}
/**
 * Collect every moderation category whose score meets or exceeds its threshold.
 *
 * Per-category settings come from `customConfig.moderation.categories[category]`:
 * - `enabled: false` skips the category entirely.
 * - `threshold` overrides the default. Nullish coalescing is used so that an
 *   explicit `0` ("always flag") is honored rather than silently replaced.
 *   The previous `||` chain treated `0` as unset.
 * Otherwise the threshold comes from `DEFAULT_THRESHOLDS`, then 0.7.
 *
 * @param {Array<{category_scores: Object<string, number>}>} results - `results` array of the moderation API response.
 * @param {object} [customConfig] - Loaded custom config; may be null/undefined.
 * @returns {Array<{category: string, score: number, threshold: number}>}
 */
function checkViolations(results, customConfig) {
  const violations = [];
  const categories = customConfig?.moderation?.categories || {};
  for (const result of results) {
    for (const [category, score] of Object.entries(result.category_scores)) {
      const categoryConfig = categories[category];
      if (categoryConfig?.enabled === false) {
        continue;
      }
      const threshold = categoryConfig?.threshold ?? DEFAULT_THRESHOLDS.get(category) ?? 0.7;
      if (score >= threshold) {
        violations.push({ category, score, threshold });
      }
    }
  }
  return violations;
}
// Sole export: the `(req, res, next)` middleware defined above.
module.exports = moderateText;