From e8dffd35f3596818b4330e871481247028a73f16 Mon Sep 17 00:00:00 2001
From: Marco Beretta <81851188+berry-13@users.noreply.github.com>
Date: Sun, 1 Dec 2024 17:39:58 +0100
Subject: [PATCH] feat: Implement moderation middleware with configurable categories and actions

---
 api/cache/getLogStores.js             |   1 +
 api/server/middleware/moderateText.js | 149 ++++++++++++++++++++------
 packages/data-provider/src/config.ts  |  99 +++++++++++++++++
 3 files changed, 214 insertions(+), 35 deletions(-)

diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js
index 1fdaee9006..3b4dd714ed 100644
--- a/api/cache/getLogStores.js
+++ b/api/cache/getLogStores.js
@@ -78,6 +78,7 @@ const namespaces = {
   [ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
     ViolationTypes.ILLEGAL_MODEL_REQUEST,
   ),
+  [ViolationTypes.MODERATION]: createViolationInstance(ViolationTypes.MODERATION),
   logins: createViolationInstance('logins'),
   [CacheKeys.ABORT_KEYS]: abortKeys,
   [CacheKeys.TOKEN_CONFIG]: tokenConfig,
diff --git a/api/server/middleware/moderateText.js b/api/server/middleware/moderateText.js
index 7346f97ce7..52505c79b2 100644
--- a/api/server/middleware/moderateText.js
+++ b/api/server/middleware/moderateText.js
@@ -1,40 +1,71 @@
 const OpenAI = require('openai');
-const { ErrorTypes } = require('librechat-data-provider');
+const { ErrorTypes, ViolationTypes } = require('librechat-data-provider');
+const { getCustomConfig } = require('~/server/services/Config');
 const { isEnabled } = require('~/server/utils');
 const denyRequest = require('./denyRequest');
+const { logViolation } = require('~/cache');
 const { logger } = require('~/config');
 
-/**
- * Middleware to moderate text content using OpenAI's moderation API
- * @param {Express.Request} req - Express request object
- * @param {Express.Response} res - Express response object
- * @param {Express.NextFunction} next - Express next middleware function
- * @returns {Promise}
- */
+const DEFAULT_ACTIONS = Object.freeze({
+  violation: 2,
+  blockMessage: true,
+  log: true,
+});
+
+// Pre-compile threshold map for faster lookups
+const DEFAULT_THRESHOLDS = new Map();
+
+function formatViolation(violation) {
+  return {
+    category: violation.category,
+    score: Math.round(violation.score * 100) / 100,
+    threshold: violation.threshold,
+    severity: getSeverityLevel(violation.score),
+  };
+}
+
+function getSeverityLevel(score) {
+  if (score >= 0.9) {
+    return 'HIGH';
+  }
+  if (score >= 0.7) {
+    return 'MEDIUM';
+  }
+  return 'LOW';
+}
+
+function formatViolationsLog(violations, userId = 'unknown') {
+  const violationsStr = violations
+    .map((v) => `${v.category}:${v.score}>${v.threshold}`)
+    .join(' | ');
+
+  return `userId=${userId} violations=[${violationsStr}]`;
+}
+
 async function moderateText(req, res, next) {
   if (!isEnabled(process.env.OPENAI_MODERATION)) {
     return next();
   }
 
+  const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
+  if (!moderationKey) {
+    logger.error('Missing OpenAI moderation API key');
+    return denyRequest(req, res, { message: 'Moderation configuration error' });
+  }
+
+  const { text } = req.body;
+  if (!text?.length || typeof text !== 'string') {
+    return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
+  }
+
   try {
-    const moderationKey = process.env.OPENAI_MODERATION_API_KEY;
+    const customConfig = await getCustomConfig();
 
-    if (!moderationKey) {
-      logger.error('Missing OpenAI moderation API key');
-      return denyRequest(req, res, { message: 'Moderation configuration error' });
+    if (!moderateText.openai) {
+      moderateText.openai = new OpenAI({ apiKey: moderationKey });
     }
 
-    const openai = new OpenAI({
-      apiKey: moderationKey,
-    });
-
-    const { text } = req.body;
-
-    if (!text || typeof text !== 'string') {
-      return denyRequest(req, res, { type: ErrorTypes.VALIDATION, message: 'Invalid text input' });
-    }
-
-    const response = await openai.moderations.create({
+    const response = await moderateText.openai.moderations.create({
       model: 'omni-moderation-latest',
       input: text,
     });
@@ -43,22 +74,47 @@ async function moderateText(req, res, next) {
       throw new Error('Invalid moderation API response format');
     }
 
-    const flagged = response.results.some((result) => result.flagged);
+    const violations = checkViolations(response.results, customConfig).map(formatViolation);
 
-    if (flagged) {
-      return denyRequest(req, res, {
-        type: ErrorTypes.MODERATION,
-        message: 'Content violates moderation policies',
-      });
+    if (violations.length === 0) {
+      return next();
     }
 
-    next();
-  } catch (error) {
-    logger.error('Moderation error:', {
-      error: error.message,
-      stack: error.stack,
-      status: error.response?.status,
+    const actions = Object.assign({}, DEFAULT_ACTIONS, customConfig?.moderation?.actions);
+
+    if (actions.log) {
+      const userId = req.user?.id || 'anonymous';
+      logger.warn(
+        '[moderateText] Content moderation violations: ' + formatViolationsLog(violations, userId),
+      );
+    }
+
+    if (!actions.blockMessage) {
+      return next();
+    }
+
+    if (actions.violation > 0) {
+      logViolation(req, res, ViolationTypes.MODERATION, { violations }, actions.violation);
+    }
+
+    return denyRequest(req, res, {
+      type: ErrorTypes.MODERATION,
+      message: `Content violates moderation policies: ${violations
+        .map((v) => v.category)
+        .join(', ')}`,
+      violations: violations,
     });
+  } catch (error) {
+    const errorDetails =
+      process.env.NODE_ENV === 'production'
+        ? { message: error.message }
+        : {
+          error: error.message,
+          stack: error.stack,
+          status: error.response?.status,
+        };
+
+    logger.error('Moderation error:', errorDetails);
 
     return denyRequest(req, res, {
       type: ErrorTypes.MODERATION,
@@ -67,4 +123,27 @@ async function moderateText(req, res, next) {
   }
 }
 
+function checkViolations(results, customConfig) {
+  const violations = [];
+  const categories = customConfig?.moderation?.categories || {};
+
+  for (const result of results) {
+    for (const [category, score] of Object.entries(result.category_scores)) {
+      const categoryConfig = categories[category];
+
+      if (categoryConfig?.enabled === false) {
+        continue;
+      }
+
+      // Use `??`, not `||`: a configured threshold of 0 is valid and must not become 0.7.
+      const threshold = categoryConfig?.threshold ?? DEFAULT_THRESHOLDS.get(category) ?? 0.7;
+
+      if (score >= threshold) {
+        violations.push({ category, score, threshold });
+      }
+    }
+  }
+  return violations;
+}
+
 module.exports = moderateText;
diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts
index 04f3faf077..399e939317 100644
--- a/packages/data-provider/src/config.ts
+++ b/packages/data-provider/src/config.ts
@@ -437,6 +437,100 @@ export const rateLimitSchema = z.object({
     .optional(),
 });
 
+const moderationSchema = z
+  .object({
+    categories: z
+      .object({
+        sexual: z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'sexual/minors': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0),
+          })
+          .optional(),
+        harassment: z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'harassment/threatening': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        hate: z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'hate/threatening': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        illicit: z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'illicit/violent': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'self-harm': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'self-harm/intent': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'self-harm/instructions': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        violence: z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+        'violence/graphic': z
+          .object({
+            enabled: z.boolean().default(true),
+            threshold: z.number().min(0).max(1).default(0.7),
+          })
+          .optional(),
+      })
+      .optional(),
+    actions: z
+      .object({
+        violation: z.number().default(2),
+        blockMessage: z.boolean().default(true),
+        log: z.boolean().default(false),
+      })
+      .optional(),
+  })
+  .optional();
+
 export enum EImageOutputType {
   PNG = 'png',
   WEBP = 'webp',
@@ -487,6 +581,7 @@ export const configSchema = z.object({
     prompts: true,
   }),
   fileStrategy: fileSourceSchema.default(FileSources.local),
+  moderation: moderationSchema.optional(),
   registration: z
     .object({
       socialLogins: z.array(z.string()).optional(),
@@ -931,6 +1026,10 @@ export enum ViolationTypes {
    * Verify Conversation Access violation.
    */
   CONVO_ACCESS = 'convo_access',
+  /**
+   * Verify moderation LLM violation.
+   */
+  MODERATION = 'moderation',
 }
 
 /**