const OpenAI = require('openai');
const { ErrorTypes, ViolationTypes } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config');
const { isEnabled } = require('~/server/utils');
const denyRequest = require('./denyRequest');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');

const DEFAULT_ACTIONS = Object.freeze({
  violation: 2,
  blockMessage: true,
  log: true,
});

// Optional per-category default thresholds; left empty here, so checkViolations falls back to 0.7
const DEFAULT_THRESHOLDS = new Map();
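
/**
 * Illustrative shape of the custom config this middleware reads (not the
 * authoritative schema, which lives with the config loader). The lookups below
 * use `moderation.actions` and `moderation.categories[<category>].enabled/.threshold`:
 *
 *   moderation: {
 *     actions: { violation: 1, blockMessage: true, log: true },
 *     categories: {
 *       harassment: { enabled: true, threshold: 0.8 },
 *       'self-harm': { enabled: false },
 *     },
 *   }
 */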

function formatViolation(violation) {
  return {
    category: violation.category,
    score: Math.round(violation.score * 100) / 100,
    threshold: violation.threshold,
    severity: getSeverityLevel(violation.score),
  };
}

function getSeverityLevel(score) {
  if (score >= 0.9) {
    return 'HIGH';
  }
  if (score >= 0.7) {
    return 'MEDIUM';
  }
  return 'LOW';
}
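
// Example output of formatViolation (hypothetical values):
//   { category: 'harassment', score: 0.93, threshold: 0.7, severity: 'HIGH' }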

function formatViolationsLog(violations, userId = 'unknown') {
  const violationsStr = violations
    .map((v) => `${v.category}:${v.score}>${v.threshold}`)
    .join(' | ');

  return `userId=${userId} violations=[${violationsStr}]`;
}
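
// Produces log entries such as (hypothetical values):
//   userId=abc123 violations=[harassment:0.93>0.7 | violence:0.75>0.7]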

/**
 * Express-style middleware that screens `req.body.text` against the OpenAI
 * moderation endpoint before the request reaches downstream handlers.
 * Enabled via the OPENAI_MODERATION env flag; requires OPENAI_MODERATION_API_KEY.
 */
async function moderateText(req, res, next) {
  if (!isEnabled(process.env.OPENAI_MODERATION)) {
    return next();
  }

  const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
  if (!moderationKey) {
    logger.error('Missing OpenAI moderation API key');
    return denyRequest(req, res, { message: 'Moderation configuration error' });
  }

  const { text } = req.body;
  // Reject missing, empty, or non-string input before calling the moderation API
  if (typeof text !== 'string' || text.length === 0) {
    return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
  }

  try {
    const customConfig = await getCustomConfig();

    // Lazily create and cache the OpenAI client on the function object
    if (!moderateText.openai) {
      moderateText.openai = new OpenAI({ apiKey: moderationKey });
    }

    const response = await moderateText.openai.moderations.create({
      model: 'omni-moderation-latest',
      input: text,
    });

    if (!Array.isArray(response.results)) {
      throw new Error('Invalid moderation API response format');
    }
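
    // Each result exposes `category_scores`, an object mapping category name to a
    // probability in [0, 1], e.g. (hypothetical values):
    //   { harassment: 0.93, violence: 0.12, 'self-harm': 0.01, ... }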

    const violations = checkViolations(response.results, customConfig).map(formatViolation);

    if (violations.length === 0) {
      return next();
    }

    const actions = Object.assign({}, DEFAULT_ACTIONS, customConfig?.moderation?.actions);

    if (actions.log) {
      const userId = req.user?.id || 'anonymous';
      logger.warn(
        '[moderateText] Content moderation violations: ' + formatViolationsLog(violations, userId),
      );
    }

    if (!actions.blockMessage) {
      return next();
    }

    if (actions.violation > 0) {
      logViolation(req, res, ViolationTypes.MODERATION, { violations }, actions.violation);
    }

    return denyRequest(req, res, {
      type: ErrorTypes.MODERATION,
      message: `Content violates moderation policies: ${violations
        .map((v) => v.category)
        .join(', ')}`,
      violations,
    });
  } catch (error) {
    const errorDetails =
      process.env.NODE_ENV === 'production'
        ? { message: error.message }
        : {
            error: error.message,
            stack: error.stack,
            status: error.response?.status,
          };

    logger.error('Moderation error:', errorDetails);

    return denyRequest(req, res, {
      type: ErrorTypes.MODERATION,
      message: 'Content moderation check failed',
    });
  }
}

/**
 * Compares each category score from the moderation results against the configured
 * (or default) threshold and returns the list of threshold violations.
 */
function checkViolations(results, customConfig) {
  const violations = [];
  const categories = customConfig?.moderation?.categories || {};

  for (const result of results) {
    for (const [category, score] of Object.entries(result.category_scores)) {
      const categoryConfig = categories[category];

      if (categoryConfig?.enabled === false) {
        continue;
      }

      // `??` (not `||`) so an explicitly configured threshold of 0 is honored
      const threshold = categoryConfig?.threshold ?? DEFAULT_THRESHOLDS.get(category) ?? 0.7;

      if (score >= threshold) {
        violations.push({ category, score, threshold });
      }
    }
  }

  return violations;
}
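
// For example (hypothetical values), with the default 0.7 threshold a result of
//   { category_scores: { harassment: 0.93, violence: 0.12 } }
// yields [{ category: 'harassment', score: 0.93, threshold: 0.7 }].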

module.exports = moderateText;
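
// Illustrative mounting only; the actual route wiring, require path, and route
// names below are assumptions, not taken from this file:
//   const moderateText = require('~/server/middleware/moderateText');
//   router.post('/api/ask/:endpoint', moderateText, askController);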