From 7b2cedf5ff197b93ce25ef8e1e195e0d2ec83267 Mon Sep 17 00:00:00 2001 From: Danny Avila <110412045+danny-avila@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:57:07 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20Message=20Rate=20Limiters,=20Violation?= =?UTF-8?q?=20Logging,=20&=20Ban=20System=20=F0=9F=94=A8=20(#903)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: require Auth middleware in route index files * feat: concurrent message limiter * feat: complete concurrent message limiter with caching * refactor: SSE response methods separated from handleText * fix(abortMiddleware): fix req and res order to standard, use endpointOption in req.body * chore: minor name changes * refactor: add isUUID condition to saveMessage * fix(concurrentLimiter): logic correctly handles the max number of concurrent messages and res closing/finalization * chore: bump keyv and remove console.log from Message * fix(concurrentLimiter): ensure messages are only saved in later message children * refactor(concurrentLimiter): use KeyvFile instead, could make other stores configurable in the future * feat: add denyRequest function for error responses * feat(utils): add isStringTruthy function Introduce the isStringTruthy function to the utilities module to check if a string value is a case-insensitive match for 'true' * feat: add optional message rate limiters by IP and userId * feat: add optional message rate limiters by IP and userId to edit route * refactor: rename isStringTruthy to isTrue for brevity * refactor(getError): use map to make code cleaner * refactor: use memory for concurrent rate limiter to prevent clearing on startup/exit, add multiple log files, fix error message for concurrent violation * feat: check if errorMessage is object, stringify if so * chore: send object to denyRequest which will stringify it * feat: log excessive requests * fix(getError): correctly pluralize messages * refactor(limiters): make type consistent between logs and errorMessage * refactor(cache): move files out of lib/db into separate cache dir >> feat: add getLogStores function so Keyv instance is not redundantly created on every violation feat: separate violation logging to own function with logViolation * fix: cache/index.js export, properly record userViolations * refactor(messageLimiters): use new logging method, add logging to registrations * refactor(logViolation): make userLogs an array of logs per user * feat: add logging to login limiter * refactor: pass req as first param to logViolation and record offending IP * refactor: rename isTrue helper fn to isEnabled * feat: add simple non_browser check and log violation * fix: open handles in unit tests, remove KeyvMongo as not used and properly mock global fetch * chore: adjust nodemon ignore paths to properly ignore logs * feat: add math helper function for safe use of eval * refactor(api/convos): use middleware at top of file to avoid redundancy * feat: add delete all static method for Sessions * fix: redirect to login on refresh if user is not found, or the session is not found but hasn't expired (ban case) * refactor(getLogStores): adjust return type * feat: add ban violation and check ban logic refactor(logViolation): pass both req and res objects * feat: add removePorts helper function * refactor: rename getError to getMessageError and add getLoginError for displaying different login errors * fix(AuthContext): fix type issue and remove unused code * refactor(bans): ban by ip and user id, send response based on origin * chore: add frontend ban messages * refactor(routes/oauth): add ban check to handler, also consolidate logic to avoid redundancy * feat: add ban check to AI messaging routes * feat: add ban check to login/registration * fix(ci/api): mock KeyvMongo to avoid tests hanging * docs: update .env.example > refactor(banViolation): calculate interval rate crossover, early return if duration is invalid ci(banViolation): add tests to ensure users are only banned when expected * docs: improve wording for mod system * feat: add configurable env variables for violation scores * chore: add jsdoc for uaParser.js * chore: improve ban text log * chore: update bun test scripts * refactor(math.js): add fallback values * fix(KeyvMongo/banLogs): refactor keyv instances to top of files to avoid memory leaks, refactor ban logic to use getLogStores instead refactor(getLogStores): get a single log store by type * fix(ci): refactor tests due to banLogs changes, also make sure to clear and revoke sessions even if ban duration is 0 * fix(banViolation.js): getLogStores import * feat: handle 500 code error at login * fix(middleware): handle case where user.id is _id and not just id * ci: add ban secrets for backend unit tests * refactor: logout user upon ban * chore: log session delete message only if deletedCount > 0 * refactor: change default ban duration (2h) and make logic more clear in JSDOC * fix: login and registration limiters will now return rate limiting error * fix: userId not parsable as non ObjectId string * feat: add useTimeout hook to properly clear timeouts when invoking functions within them refactor(AuthContext): cleanup code by using new hook and defining types in ~/common * fix: login error message for rate limits * docs: add info for automated mod system and rate limiters, update other docs accordingly * chore: bump data-provider version --- .env.example | 36 +- .github/workflows/backend-review.yml | 3 + .gitignore | 1 + README.md | 3 +- .../tools/dynamic/OpenAPIPlugin.spec.js | 9 +- api/cache/banViolation.js | 68 ++++ api/cache/banViolation.spec.js | 155 +++++++ api/cache/clearPendingReq.js | 29 ++ api/cache/getLogStores.js | 40 ++ api/cache/index.js | 6 + api/cache/keyvFiles.js | 11 + api/cache/keyvMongo.js | 7 + api/cache/logViolation.js | 36 ++ api/jest.config.js | 2 +- api/lib/db/index.js | 4 + api/models/Message.js | 8 +- api/models/Session.js | 15 + api/package.json | 4 +- api/server/controllers/AuthController.js | 4 +- api/server/index.js | 12 +- api/server/middleware/abortMiddleware.js | 32 +- api/server/middleware/checkBan.js | 92 +++++ api/server/middleware/concurrentLimiter.js | 81 ++++ api/server/middleware/denyRequest.js | 58 +++ api/server/middleware/index.js | 8 + api/server/middleware/loginLimiter.js | 28 +- api/server/middleware/messageLimiters.js | 67 ++++ api/server/middleware/registerLimiter.js | 28 +- api/server/middleware/uaParser.js | 31 ++ api/server/routes/ask/anthropic.js | 235 ++++++----- api/server/routes/ask/askChatGPTBrowser.js | 4 +- api/server/routes/ask/bingAI.js | 4 +- api/server/routes/ask/google.js | 4 +- api/server/routes/ask/gptPlugins.js | 377 +++++++++--------- api/server/routes/ask/index.js | 27 ++ api/server/routes/ask/openAI.js | 253 ++++++------ api/server/routes/auth.js | 5 +- api/server/routes/convos.js | 10 +- api/server/routes/edit/anthropic.js | 225 +++++------ api/server/routes/edit/gptPlugins.js | 309 +++++++------- api/server/routes/edit/index.js | 29 +- api/server/routes/edit/openAI.js | 231 +++++------ api/server/routes/oauth.js | 61 +-- api/server/utils/handleText.js | 36 +- api/server/utils/index.js | 8 +- api/server/utils/math.js | 48 +++ api/server/utils/removePorts.js | 1 + api/server/utils/streamResponse.js | 63 +++ api/test/.env.test.example | 4 + api/test/__mocks__/KeyvMongo.js | 30 ++ client/package.json | 1 + client/src/common/types.ts | 34 +- client/src/components/Auth/Login.tsx | 5 +- .../Messages/Content/MessageContent.tsx | 26 +- client/src/hooks/AuthContext.tsx | 54 +-- client/src/hooks/index.ts | 1 + client/src/hooks/useTimeout.tsx | 39 ++ client/src/localization/languages/Eng.tsx | 6 +- client/src/utils/getError.ts | 28 -- client/src/utils/getLoginError.ts | 18 + client/src/utils/getMessageError.ts | 62 +++ client/src/utils/index.ts | 3 +- docs/features/mod_system.md | 67 ++++ .../{third-party.md => third_party.md} | 0 docs/install/user_auth_system.md | 20 +- mkdocs.yml | 3 +- package-lock.json | 25 +- package.json | 6 +- packages/data-provider/package.json | 2 +- 69 files changed, 2180 insertions(+), 1062 deletions(-) create mode 100644 api/cache/banViolation.js create mode 100644 api/cache/banViolation.spec.js create mode 100644 api/cache/clearPendingReq.js create mode 100644 api/cache/getLogStores.js create mode 100644 api/cache/index.js create mode 100644 api/cache/keyvFiles.js create mode 100644 api/cache/keyvMongo.js create mode 100644 api/cache/logViolation.js create mode 100644 api/lib/db/index.js create mode 100644 api/server/middleware/checkBan.js create mode 100644 api/server/middleware/concurrentLimiter.js create mode 100644 api/server/middleware/denyRequest.js create mode 100644 api/server/middleware/messageLimiters.js create mode 100644 api/server/middleware/uaParser.js create mode 100644 api/server/utils/math.js create mode 100644 api/server/utils/removePorts.js create mode 100644 api/server/utils/streamResponse.js create mode 100644 api/test/__mocks__/KeyvMongo.js create mode 100644 client/src/hooks/useTimeout.tsx delete mode 100644 client/src/utils/getError.ts create mode 100644 client/src/utils/getLoginError.ts create mode 100644 client/src/utils/getMessageError.ts create mode 100644 docs/features/mod_system.md rename docs/features/{third-party.md => third_party.md} (100%) diff --git a/.env.example b/.env.example index d4788e71a..385cbbc7e 100644 --- a/.env.example +++ b/.env.example @@ -13,12 +13,44 @@ APP_TITLE=LibreChat HOST=localhost PORT=3080 +# Automated Moderation System +# The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions +# like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching +# a set threshold, the user and their IP are temporarily banned. This system ensures platform security +# by monitoring and penalizing rapid or suspicious activities. + +BAN_VIOLATIONS=true # Whether or not to enable banning users for violations (they will still be logged) +BAN_DURATION=1000 * 60 * 60 * 2 # how long the user and associated IP are banned for +BAN_INTERVAL=20 # a user will be banned everytime their score reaches/crosses over the interval threshold + +# The score for each violation + +LOGIN_VIOLATION_SCORE=1 +REGISTRATION_VIOLATION_SCORE=1 +CONCURRENT_VIOLATION_SCORE=1 +MESSAGE_VIOLATION_SCORE=1 +NON_BROWSER_VIOLATION_SCORE=20 + # Login and registration rate limiting. LOGIN_MAX=7 # The max amount of logins allowed per IP per LOGIN_WINDOW -LOGIN_WINDOW=5 # in minutes, determines how long an IP is banned for after LOGIN_MAX logins +LOGIN_WINDOW=5 # in minutes, determines the window of time for LOGIN_MAX logins REGISTER_MAX=5 # The max amount of registrations allowed per IP per REGISTER_WINDOW -REGISTER_WINDOW=60 # in minutes, determines how long an IP is banned for after REGISTER_MAX registrations +REGISTER_WINDOW=60 # in minutes, determines the window of time for REGISTER_MAX registrations + +# Message rate limiting (per user & IP) + +LIMIT_CONCURRENT_MESSAGES=true # Whether to limit the amount of messages a user can send per request +CONCURRENT_MESSAGE_MAX=2 # The max amount of messages a user can send per request + +LIMIT_MESSAGE_IP=true # Whether to limit the amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_MAX=40 # The max amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_MAX messages + +# Note: You can utilize both limiters, but default is to limit by IP only. +LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages # Change this to proxy any API request. # It's useful if your machine has difficulty calling the original API server. diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index e5c86caa4..a005c10ef 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -18,6 +18,9 @@ jobs: JWT_SECRET: ${{ secrets.JWT_SECRET }} CREDS_KEY: ${{ secrets.CREDS_KEY }} CREDS_IV: ${{ secrets.CREDS_IV }} + BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }} + BAN_DURATION: ${{ secrets.BAN_DURATION }} + BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }} NODE_ENV: ci steps: - uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index c294b25b4..52ce79baa 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ # Logs data-node meili_data +data/ logs *.log diff --git a/README.md b/README.md index bfe3cf4b2..b5fca4c44 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,8 @@ Keep up with the latest updates by visiting the releases page - [Releases](https * [Using official ChatGPT Plugins](docs/features/plugins/chatgpt_plugins_openapi.md) - * [Third-Party Tools](docs/features/third-party.md) + * [Automated Moderation](docs/features/mod_system.md) + * [Third-Party Tools](docs/features/third_party.md) * [Proxy](docs/features/proxy.md) * [Bing Jailbreak](docs/features/bing_jailbreak.md) diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js index 5fe7f1cb3..83bc5e939 100644 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js +++ b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js @@ -1,7 +1,14 @@ const fs = require('fs'); const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin'); -jest.mock('node-fetch'); +global.fetch = jest.fn().mockImplementationOnce(() => { + return new Promise((resolve) => { + resolve({ + ok: true, + json: () => Promise.resolve({ key: 'value' }), + }); + }); +}); jest.mock('fs', () => ({ promises: { readFile: jest.fn(), diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js new file mode 100644 index 000000000..f00296d3b --- /dev/null +++ b/api/cache/banViolation.js @@ -0,0 +1,68 @@ +const Session = require('../models/Session'); +const getLogStores = require('./getLogStores'); +const { isEnabled, math, removePorts } = require('../server/utils'); +const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; +const interval = math(BAN_INTERVAL, 20); + +/** + * Bans a user based on violation criteria. + * + * If the user's violation count is a multiple of the BAN_INTERVAL, the user will be banned. + * The duration of the ban is determined by the BAN_DURATION environment variable. + * If BAN_DURATION is not set or invalid, the user will not be banned. + * Sessions will be deleted and the refreshToken cookie will be cleared even with + * an invalid or nill duration, which is a "soft" ban; the user can remain active until + * access token expiry. + * + * @async + * @param {Object} req - Express request object containing user information. + * @param {Object} res - Express response object. + * @param {Object} errorMessage - Object containing user violation details. + * @param {string} errorMessage.type - Type of the violation. + * @param {string} errorMessage.user_id - ID of the user who committed the violation. + * @param {number} errorMessage.violation_count - Number of violations committed by the user. + * + * @returns {Promise} + * + */ +const banViolation = async (req, res, errorMessage) => { + if (!isEnabled(BAN_VIOLATIONS)) { + return; + } + + if (!errorMessage) { + return; + } + + const { type, user_id, prev_count, violation_count } = errorMessage; + + const prevThreshold = Math.floor(prev_count / interval); + const currentThreshold = Math.floor(violation_count / interval); + + if (prevThreshold >= currentThreshold) { + return; + } + + await Session.deleteAllUserSessions(user_id); + res.clearCookie('refreshToken'); + + const banLogs = getLogStores('ban'); + const duration = banLogs.opts.ttl; + + if (duration <= 0) { + return; + } + + req.ip = removePorts(req); + console.log(`[BAN] Banning user ${user_id} @ ${req.ip} for ${duration / 1000 / 60} minutes`); + const expiresAt = Date.now() + duration; + await banLogs.set(user_id, { type, violation_count, duration, expiresAt }); + await banLogs.set(req.ip, { type, user_id, violation_count, duration, expiresAt }); + + errorMessage.ban = true; + errorMessage.ban_duration = duration; + + return; +}; + +module.exports = banViolation; diff --git a/api/cache/banViolation.spec.js b/api/cache/banViolation.spec.js new file mode 100644 index 000000000..ba8e78a1e --- /dev/null +++ b/api/cache/banViolation.spec.js @@ -0,0 +1,155 @@ +const banViolation = require('./banViolation'); + +jest.mock('keyv'); +jest.mock('../models/Session'); +// Mocking the getLogStores function +jest.mock('./getLogStores', () => { + return jest.fn().mockImplementation(() => { + const EventEmitter = require('events'); + const math = require('../server/utils/math'); + const mockGet = jest.fn(); + const mockSet = jest.fn(); + class KeyvMongo extends EventEmitter { + constructor(url = 'mongodb://127.0.0.1:27017', options) { + super(); + this.ttlSupport = false; + url = url ?? {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + this.opts = { + url, + collection: 'keyv', + ...url, + ...options, + }; + } + + get = mockGet; + set = mockSet; + } + + return new KeyvMongo('', { + namespace: 'bans', + ttl: math(process.env.BAN_DURATION, 7200000), + }); + }); +}); + +describe('banViolation', () => { + let req, res, errorMessage; + + beforeEach(() => { + req = { + ip: '127.0.0.1', + cookies: { + refreshToken: 'someToken', + }, + }; + res = { + clearCookie: jest.fn(), + }; + errorMessage = { + type: 'someViolation', + user_id: '12345', + prev_count: 0, + violation_count: 0, + }; + process.env.BAN_VIOLATIONS = 'true'; + process.env.BAN_DURATION = '7200000'; // 2 hours in ms + process.env.BAN_INTERVAL = '20'; + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should not ban if BAN_VIOLATIONS are not enabled', async () => { + process.env.BAN_VIOLATIONS = 'false'; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('should not ban if errorMessage is not provided', async () => { + await banViolation(req, res, null); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('[1/3] should ban if violation_count crosses the interval threshold: 19 -> 39', async () => { + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('[2/3] should ban if violation_count crosses the interval threshold: 19 -> 20', async () => { + errorMessage.prev_count = 19; + errorMessage.violation_count = 20; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + const randomValueAbove = Math.floor(20 + Math.random() * 100); + it(`[3/3] should ban if violation_count crosses the interval threshold: 19 -> ${randomValueAbove}`, async () => { + errorMessage.prev_count = 19; + errorMessage.violation_count = randomValueAbove; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('should handle invalid BAN_INTERVAL and default to 20', async () => { + process.env.BAN_INTERVAL = 'invalid'; + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('should ban if BAN_DURATION is invalid as default is 2 hours', async () => { + process.env.BAN_DURATION = 'invalid'; + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeTruthy(); + }); + + it('should not ban if BAN_DURATION is 0 but should clear cookies', async () => { + process.env.BAN_DURATION = '0'; + errorMessage.prev_count = 19; + errorMessage.violation_count = 39; + await banViolation(req, res, errorMessage); + expect(res.clearCookie).toHaveBeenCalledWith('refreshToken'); + }); + + it('should not ban if violation_count does not change', async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = 0; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('[1/2] should not ban if violation_count does not cross the interval threshold: 0 -> 19', async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = 19; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + const randomValueUnder = Math.floor(1 + Math.random() * 19); + it(`[2/2] should not ban if violation_count does not cross the interval threshold: 0 -> ${randomValueUnder}`, async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = randomValueUnder; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); + + it('[EDGE CASE] should not ban if violation_count is lower', async () => { + errorMessage.prev_count = 0; + errorMessage.violation_count = -10; + await banViolation(req, res, errorMessage); + expect(errorMessage.ban).toBeFalsy(); + }); +}); diff --git a/api/cache/clearPendingReq.js b/api/cache/clearPendingReq.js new file mode 100644 index 000000000..d31d51d78 --- /dev/null +++ b/api/cache/clearPendingReq.js @@ -0,0 +1,29 @@ +const Keyv = require('keyv'); +const { pendingReqFile } = require('./keyvFiles'); +const { LIMIT_CONCURRENT_MESSAGES } = process.env ?? {}; + +const keyv = new Keyv({ store: pendingReqFile, namespace: 'pendingRequests' }); + +/** + * Clear pending requests from the cache. + * Checks the environmental variable LIMIT_CONCURRENT_MESSAGES; + * if the rule is enabled ('true'), pending requests in the cache are cleared. + * + * @module clearPendingReq + * @requires keyv + * @requires keyvFiles + * @requires process + * + * @async + * @function + * @returns {Promise} A promise that either clears 'pendingRequests' from store or resolves with no value. + */ +const clearPendingReq = async () => { + if (LIMIT_CONCURRENT_MESSAGES?.toLowerCase() !== 'true') { + return; + } + + await keyv.clear(); +}; + +module.exports = clearPendingReq; diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js new file mode 100644 index 000000000..5bc703fe5 --- /dev/null +++ b/api/cache/getLogStores.js @@ -0,0 +1,40 @@ +const Keyv = require('keyv'); +const keyvMongo = require('./keyvMongo'); +const { math } = require('../server/utils'); +const { logFile, violationFile } = require('./keyvFiles'); +const { BAN_DURATION } = process.env ?? {}; + +const duration = math(BAN_DURATION, 7200000); + +const namespaces = { + ban: new Keyv({ store: keyvMongo, ttl: duration, namespace: 'bans' }), + general: new Keyv({ store: logFile, namespace: 'violations' }), + concurrent: new Keyv({ store: violationFile, namespace: 'concurrent' }), + non_browser: new Keyv({ store: violationFile, namespace: 'non_browser' }), + message_limit: new Keyv({ store: violationFile, namespace: 'message_limit' }), + registrations: new Keyv({ store: violationFile, namespace: 'registrations' }), + logins: new Keyv({ store: violationFile, namespace: 'logins' }), +}; + +/** + * Returns either the logs of violations specified by type if a type is provided + * or it returns the general log if no type is specified. If an invalid type is passed, + * an error will be thrown. + * + * @module getLogStores + * @requires keyv - a simple key-value storage that allows you to easily switch out storage adapters. + * @requires keyvFiles - a module that includes the logFile and violationFile. + * + * @param {string} type - The type of violation, which can be 'concurrent', 'message_limit', 'registrations' or 'logins'. + * @returns {Keyv} - If a valid type is passed, returns an object containing the logs for violations of the specified type. + * @throws Will throw an error if an invalid violation type is passed. + */ +const getLogStores = (type) => { + if (!type) { + throw new Error(`Invalid store type: ${type}`); + } + const logs = namespaces[type]; + return logs; +}; + +module.exports = getLogStores; diff --git a/api/cache/index.js b/api/cache/index.js new file mode 100644 index 000000000..1edbf981d --- /dev/null +++ b/api/cache/index.js @@ -0,0 +1,6 @@ +const keyvFiles = require('./keyvFiles'); +const getLogStores = require('./getLogStores'); +const logViolation = require('./logViolation'); +const clearPendingReq = require('./clearPendingReq'); + +module.exports = { ...keyvFiles, getLogStores, logViolation, clearPendingReq }; diff --git a/api/cache/keyvFiles.js b/api/cache/keyvFiles.js new file mode 100644 index 000000000..f969174b7 --- /dev/null +++ b/api/cache/keyvFiles.js @@ -0,0 +1,11 @@ +const { KeyvFile } = require('keyv-file'); + +const logFile = new KeyvFile({ filename: './data/logs.json' }); +const pendingReqFile = new KeyvFile({ filename: './data/pendingReqCache.json' }); +const violationFile = new KeyvFile({ filename: './data/violations.json' }); + +module.exports = { + logFile, + pendingReqFile, + violationFile, +}; diff --git a/api/cache/keyvMongo.js b/api/cache/keyvMongo.js new file mode 100644 index 000000000..429329adc --- /dev/null +++ b/api/cache/keyvMongo.js @@ -0,0 +1,7 @@ +const KeyvMongo = require('@keyv/mongo'); +const { MONGO_URI } = process.env ?? {}; + +const keyvMongo = new KeyvMongo(MONGO_URI, { collection: 'logs' }); +keyvMongo.on('error', (err) => console.error('KeyvMongo connection error:', err)); + +module.exports = keyvMongo; diff --git a/api/cache/logViolation.js b/api/cache/logViolation.js new file mode 100644 index 000000000..0e35cf185 --- /dev/null +++ b/api/cache/logViolation.js @@ -0,0 +1,36 @@ +const getLogStores = require('./getLogStores'); +const banViolation = require('./banViolation'); + +/** + * Logs the violation. + * + * @param {Object} req - Express request object containing user information. + * @param {Object} res - Express response object. + * @param {string} type - The type of violation. + * @param {Object} errorMessage - The error message to log. + * @param {number} [score=1] - The severity of the violation. Defaults to 1 + */ +const logViolation = async (req, res, type, errorMessage, score = 1) => { + const userId = req.user?.id ?? req.user?._id; + if (!userId) { + return; + } + const logs = getLogStores('general'); + const violationLogs = getLogStores(type); + + const userViolations = (await violationLogs.get(userId)) ?? 0; + const violationCount = userViolations + score; + await violationLogs.set(userId, violationCount); + + errorMessage.user_id = userId; + errorMessage.prev_count = userViolations; + errorMessage.violation_count = violationCount; + errorMessage.date = new Date().toISOString(); + + await banViolation(req, res, errorMessage); + const userLogs = (await logs.get(userId)) ?? []; + userLogs.push(errorMessage); + await logs.set(userId, userLogs); +}; + +module.exports = logViolation; diff --git a/api/jest.config.js b/api/jest.config.js index a877e7598..a2147b221 100644 --- a/api/jest.config.js +++ b/api/jest.config.js @@ -3,5 +3,5 @@ module.exports = { clearMocks: true, roots: [''], coverageDirectory: 'coverage', - setupFiles: ['./test/jestSetup.js'], + setupFiles: ['./test/jestSetup.js', './test/__mocks__/KeyvMongo.js'], }; diff --git a/api/lib/db/index.js b/api/lib/db/index.js new file mode 100644 index 000000000..fa7a460d0 --- /dev/null +++ b/api/lib/db/index.js @@ -0,0 +1,4 @@ +const connectDb = require('./connectDb'); +const indexSync = require('./indexSync'); + +module.exports = { connectDb, indexSync }; diff --git a/api/models/Message.js b/api/models/Message.js index 64dba3dbf..adcdd9e56 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -1,5 +1,8 @@ +const { z } = require('zod'); const Message = require('./schema/messageSchema'); +const idSchema = z.string().uuid(); + module.exports = { Message, @@ -22,8 +25,9 @@ module.exports = { model = null, }) { try { - if (!conversationId) { - return console.log('Message not saved: no conversationId'); + const validConvoId = idSchema.safeParse(conversationId); + if (!validConvoId.success) { + return; } // may also need to update the conversation here await Message.findOneAndUpdate( diff --git a/api/models/Session.js b/api/models/Session.js index e1b9898bb..697fa6634 100644 --- a/api/models/Session.js +++ b/api/models/Session.js @@ -54,6 +54,21 @@ sessionSchema.methods.generateRefreshToken = async function () { } }; +sessionSchema.statics.deleteAllUserSessions = async function (userId) { + try { + if (!userId) { + return; + } + const result = await this.deleteMany({ user: userId }); + if (result && result?.deletedCount > 0) { + console.log(`Deleted ${result.deletedCount} sessions for user ${userId}.`); + } + } catch (error) { + console.log('Error in deleting user sessions:', error); + throw error; + } +}; + const Session = mongoose.model('Session', sessionSchema); module.exports = Session; diff --git a/api/package.json b/api/package.json index d548a74fa..5d1d1d602 100644 --- a/api/package.json +++ b/api/package.json @@ -6,6 +6,7 @@ "start": "echo 'please run this from the root directory'", "server-dev": "echo 'please run this from the root directory'", "test": "cross-env NODE_ENV=test jest", + "b:test": "NODE_ENV=test bun jest", "test:ci": "jest --ci" }, "repository": { @@ -45,7 +46,7 @@ "joi": "^17.9.2", "js-yaml": "^4.1.0", "jsonwebtoken": "^9.0.0", - "keyv": "^4.5.2", + "keyv": "^4.5.3", "keyv-file": "^0.2.0", "langchain": "^0.0.144", "lodash": "^4.17.21", @@ -64,6 +65,7 @@ "pino": "^8.12.1", "sanitize": "^2.1.2", "sharp": "^0.32.5", + "ua-parser-js": "^1.0.36", "zod": "^3.22.2" }, "devDependencies": { diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index fbf85bec2..361c3464b 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -80,7 +80,7 @@ const refreshController = async (req, res) => { const userId = payload.id; const user = await User.findOne({ _id: userId }); if (!user) { - return res.status(401).send('User not found'); + return res.status(401).redirect('/login'); } if (process.env.NODE_ENV === 'development') { @@ -99,6 +99,8 @@ const refreshController = async (req, res) => { const token = await setAuthTokens(userId, res, session._id); const userObj = user.toJSON(); res.status(200).send({ token, user: userObj }); + } else if (payload.exp > Date.now() / 1000) { + res.status(403).redirect('/login'); } else { res.status(401).send('Refresh token expired or not found for this user'); } diff --git a/api/server/index.js b/api/server/index.js index 8a6f43819..496f0ac42 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -1,16 +1,16 @@ const express = require('express'); const mongoSanitize = require('express-mongo-sanitize'); -const connectDb = require('../lib/db/connectDb'); -const indexSync = require('../lib/db/indexSync'); +const { connectDb, indexSync } = require('../lib/db'); const path = require('path'); const cors = require('cors'); const routes = require('./routes'); const errorController = require('./controllers/ErrorController'); const passport = require('passport'); const configureSocialLogins = require('./socialLogins'); +const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {}; -const port = Number(process.env.PORT) || 3080; -const host = process.env.HOST || 'localhost'; +const port = Number(PORT) || 3080; +const host = HOST || 'localhost'; const projectPath = path.join(__dirname, '..', '..', 'client'); const { jwtLogin, passportLogin } = require('../strategies'); @@ -31,7 +31,7 @@ const startServer = async () => { app.set('trust proxy', 1); // trust first proxy app.use(cors()); - if (!process.env.ALLOW_SOCIAL_LOGIN) { + if (!ALLOW_SOCIAL_LOGIN) { console.warn( 'Social logins are disabled. Set Envrionment Variable "ALLOW_SOCIAL_LOGIN" to true to enable them.', ); @@ -42,7 +42,7 @@ const startServer = async () => { passport.use(await jwtLogin()); passport.use(passportLogin()); - if (process.env.ALLOW_SOCIAL_LOGIN === 'true') { + if (ALLOW_SOCIAL_LOGIN?.toLowerCase() === 'true') { configureSocialLogins(app); } diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 54e11b7f5..80cf26ba4 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,6 +1,5 @@ -const crypto = require('crypto'); const { saveMessage, getConvo, getConvoTitle } = require('../../models'); -const { sendMessage, handleError } = require('../utils'); +const { sendMessage, sendError } = require('../utils'); const abortControllers = require('./abortControllers'); async function abortMessage(req, res) { @@ -27,8 +26,9 @@ const handleAbort = () => { }; }; -const createAbortController = (res, req, endpointOption, getAbortData) => { +const createAbortController = (req, res, getAbortData) => { const abortController = new AbortController(); + const { endpointOption } = req.body; const onStart = (userMessage) => { sendMessage(res, { message: userMessage, created: true }); const abortKey = userMessage?.conversationId ?? req.user.id; @@ -73,25 +73,23 @@ const handleAbortError = async (res, req, error, data) => { const { sender, conversationId, messageId, parentMessageId, partialText } = data; const respondWithError = async () => { - const errorMessage = { + const options = { sender, - messageId: messageId ?? crypto.randomUUID(), + messageId, conversationId, parentMessageId, - unfinished: false, - cancelled: false, - error: true, - final: true, text: error.message, - isCreatedByUser: false, + shouldSaveMessage: true, }; - if (abortControllers.has(conversationId)) { - const { abortController } = abortControllers.get(conversationId); - abortController.abort(); - abortControllers.delete(conversationId); - } - await saveMessage(errorMessage); - handleError(res, errorMessage); + const callback = async () => { + if (abortControllers.has(conversationId)) { + const { abortController } = abortControllers.get(conversationId); + abortController.abort(); + abortControllers.delete(conversationId); + } + }; + + await sendError(res, options, callback); }; if (partialText && partialText.length > 5) { diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js new file mode 100644 index 000000000..294f4a668 --- /dev/null +++ b/api/server/middleware/checkBan.js @@ -0,0 +1,92 @@ +const Keyv = require('keyv'); +const uap = require('ua-parser-js'); +const { getLogStores } = require('../../cache'); +const denyRequest = require('./denyRequest'); +const { isEnabled, removePorts } = require('../utils'); + +const banCache = new Keyv({ namespace: 'bans', ttl: 0 }); +const message = 'Your account has been temporarily banned due to violations of our service.'; + +/** + * Respond to the request if the user is banned. + * + * @async + * @function + * @param {Object} req - Express Request object. + * @param {Object} res - Express Response object. + * @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request. + * + * @returns {Promise} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function. + */ +const banResponse = async (req, res) => { + const ua = uap(req.headers['user-agent']); + const { baseUrl } = req; + if (!ua.browser.name) { + return res.status(403).json({ message }); + } else if (baseUrl === '/api/ask' || baseUrl === '/api/edit') { + return await denyRequest(req, res, { type: 'ban' }); + } + + return res.status(403).json({ message }); +}; + +/** + * Checks if the source IP or user is banned or not. + * + * @async + * @function + * @param {Object} req - Express request object. + * @param {Object} res - Express response object. + * @param {Function} next - Next middleware function. + * + * @returns {Promise} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`. + */ +const checkBan = async (req, res, next = () => {}) => { + const { BAN_VIOLATIONS } = process.env ?? {}; + + if (!isEnabled(BAN_VIOLATIONS)) { + return next(); + } + + req.ip = removePorts(req); + const userId = req.user?.id ?? req.user?._id ?? null; + + const cachedIPBan = await banCache.get(req.ip); + const cachedUserBan = await banCache.get(userId); + const cachedBan = cachedIPBan || cachedUserBan; + + if (cachedBan) { + req.banned = true; + return await banResponse(req, res); + } + + const banLogs = getLogStores('ban'); + const duration = banLogs.opts.ttl; + + if (duration <= 0) { + return next(); + } + + const ipBan = await banLogs.get(req.ip); + const userBan = await banLogs.get(userId); + const isBanned = ipBan || userBan; + + if (!isBanned) { + return next(); + } + + const timeLeft = Number(isBanned.expiresAt) - Date.now(); + + if (timeLeft <= 0) { + await banLogs.delete(req.ip); + await banLogs.delete(userId); + return next(); + } + + banCache.set(req.ip, isBanned, timeLeft); + banCache.set(userId, isBanned, timeLeft); + req.banned = true; + return await banResponse(req, res); +}; + +module.exports = checkBan; diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js new file mode 100644 index 000000000..d110b1b86 --- /dev/null +++ b/api/server/middleware/concurrentLimiter.js @@ -0,0 +1,81 @@ +const Keyv = require('keyv'); +const { logViolation } = require('../../cache'); + +const denyRequest = require('./denyRequest'); + +// Serve cache from memory so no need to clear it on startup/exit +const pendingReqCache = new Keyv({ namespace: 'pendingRequests' }); + +/** + * Middleware to limit concurrent requests for a user. + * + * This middleware checks if a user has exceeded a specified concurrent request limit. + * If the user exceeds the limit, an error is returned. If the user is within the limit, + * their request count is incremented. After the request is processed, the count is decremented. + * If the `pendingReqCache` store is not available, the middleware will skip its logic. + * + * @function + * @param {Object} req - Express request object containing user information. + * @param {Object} res - Express response object. + * @param {function} next - Express next middleware function. + * @throws {Error} Throws an error if the user exceeds the concurrent request limit. + */ +const concurrentLimiter = async (req, res, next) => { + if (!pendingReqCache) { + return next(); + } + + if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) { + return next(); + } + + const { CONCURRENT_MESSAGE_MAX = 1, CONCURRENT_VIOLATION_SCORE: score } = process.env; + const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1); + const type = 'concurrent'; + + const userId = req.user?.id ?? req.user?._id ?? null; + const pendingRequests = (await pendingReqCache.get(userId)) ?? 0; + + if (pendingRequests >= limit) { + const errorMessage = { + type, + limit, + pendingRequests, + }; + + await logViolation(req, res, type, errorMessage, score); + return await denyRequest(req, res, errorMessage); + } else { + await pendingReqCache.set(userId, pendingRequests + 1); + } + + // Ensure the requests are removed from the store once the request is done + const cleanUp = async () => { + if (!pendingReqCache) { + return; + } + + const currentRequests = await pendingReqCache.get(userId); + + if (currentRequests && currentRequests >= 1) { + await pendingReqCache.set(userId, currentRequests - 1); + } else { + await pendingReqCache.delete(userId); + } + }; + + if (pendingRequests < limit) { + res.on('finish', cleanUp); + res.on('close', cleanUp); + } + + next(); +}; + +// if cache is not served from memory, clear it on exit +// process.on('exit', async () => { +// console.log('Clearing all pending requests before exiting...'); +// await pendingReqCache.clear(); +// }); + +module.exports = concurrentLimiter; diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js new file mode 100644 index 000000000..64ca86c63 --- /dev/null +++ b/api/server/middleware/denyRequest.js @@ -0,0 +1,58 @@ +const crypto = require('crypto'); +const { sendMessage, sendError } = require('../utils'); +const { getResponseSender } = require('../routes/endpoints/schemas'); +const { saveMessage } = require('../../models'); + +/** + * Denies a request by sending an error message and optionally saves the user's message. + * + * @async + * @function + * @param {Object} req - Express request object. + * @param {Object} req.body - The body of the request. + * @param {string} [req.body.messageId] - The ID of the message. + * @param {string} [req.body.conversationId] - The ID of the conversation. + * @param {string} [req.body.parentMessageId] - The ID of the parent message. + * @param {string} req.body.text - The text of the message. + * @param {Object} res - Express response object. + * @param {string} errorMessage - The error message to be sent. + * @returns {Promise} A promise that resolves with the error response. + * @throws {Error} Throws an error if there's an issue saving the message or sending the error. + */ +const denyRequest = async (req, res, errorMessage) => { + let responseText = errorMessage; + if (typeof errorMessage === 'object') { + responseText = JSON.stringify(errorMessage); + } + + const { messageId, conversationId: _convoId, parentMessageId, text } = req.body; + const conversationId = _convoId ?? crypto.randomUUID(); + + const userMessage = { + sender: 'User', + messageId: messageId ?? crypto.randomUUID(), + parentMessageId, + conversationId, + isCreatedByUser: true, + text, + }; + sendMessage(res, { message: userMessage, created: true }); + + const shouldSaveMessage = + _convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000'; + + if (shouldSaveMessage) { + await saveMessage(userMessage); + } + + return await sendError(res, { + sender: getResponseSender(req.body), + messageId: crypto.randomUUID(), + conversationId, + parentMessageId: userMessage.messageId, + text: responseText, + shouldSaveMessage, + }); +}; + +module.exports = denyRequest; diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index eb1f53870..553f2c663 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -1,22 +1,30 @@ const abortMiddleware = require('./abortMiddleware'); +const checkBan = require('./checkBan'); +const uaParser = require('./uaParser'); const setHeaders = require('./setHeaders'); const loginLimiter = require('./loginLimiter'); const requireJwtAuth = require('./requireJwtAuth'); const registerLimiter = require('./registerLimiter'); +const messageLimiters = require('./messageLimiters'); const requireLocalAuth = require('./requireLocalAuth'); const validateEndpoint = require('./validateEndpoint'); +const concurrentLimiter = require('./concurrentLimiter'); const validateMessageReq = require('./validateMessageReq'); const buildEndpointOption = require('./buildEndpointOption'); const validateRegistration = require('./validateRegistration'); module.exports = { ...abortMiddleware, + ...messageLimiters, + checkBan, + uaParser, setHeaders, loginLimiter, requireJwtAuth, registerLimiter, requireLocalAuth, validateEndpoint, + concurrentLimiter, validateMessageReq, buildEndpointOption, validateRegistration, diff --git a/api/server/middleware/loginLimiter.js b/api/server/middleware/loginLimiter.js index bca07d0a7..bdc95e287 100644 --- a/api/server/middleware/loginLimiter.js +++ b/api/server/middleware/loginLimiter.js @@ -1,16 +1,30 @@ const rateLimit = require('express-rate-limit'); -const windowMs = (process.env?.LOGIN_WINDOW ?? 5) * 60 * 1000; // default: 5 minutes -const max = process.env?.LOGIN_MAX ?? 7; // default: limit each IP to 7 requests per windowMs +const { logViolation } = require('../../cache'); +const { removePorts } = require('../utils'); + +const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env; +const windowMs = LOGIN_WINDOW * 60 * 1000; +const max = LOGIN_MAX; const windowInMinutes = windowMs / 60000; +const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`; + +const handler = async (req, res) => { + const type = 'logins'; + const errorMessage = { + type, + max, + windowInMinutes, + }; + + await logViolation(req, res, type, errorMessage, score); + return res.status(429).json({ message }); +}; const loginLimiter = rateLimit({ windowMs, max, - message: `Too many login attempts from this IP, please try again after ${windowInMinutes} minutes.`, - keyGenerator: function (req) { - // Strip out the port number from the IP address - return req.ip.replace(/:\d+[^:]*$/, ''); - }, + handler, + keyGenerator: removePorts, }); module.exports = loginLimiter; diff --git a/api/server/middleware/messageLimiters.js b/api/server/middleware/messageLimiters.js new file mode 100644 index 000000000..63bac7e18 --- /dev/null +++ b/api/server/middleware/messageLimiters.js @@ -0,0 +1,67 @@ +const rateLimit = require('express-rate-limit'); +const { logViolation } = require('../../cache'); +const denyRequest = require('./denyRequest'); + +const { + MESSAGE_IP_MAX = 40, + MESSAGE_IP_WINDOW = 1, + MESSAGE_USER_MAX = 40, + MESSAGE_USER_WINDOW = 1, +} = process.env; + +const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000; +const ipMax = MESSAGE_IP_MAX; +const ipWindowInMinutes = ipWindowMs / 60000; + +const userWindowMs = MESSAGE_USER_WINDOW * 60 * 1000; +const userMax = MESSAGE_USER_MAX; +const userWindowInMinutes = userWindowMs / 60000; + +/** + * Creates either an IP/User message request rate limiter for excessive requests + * that properly logs and denies the violation. + * + * @param {boolean} [ip=true] - Whether to create an IP limiter or a user limiter. + * @returns {function} A rate limiter function. + * + */ +const createHandler = (ip = true) => { + return async (req, res) => { + const type = 'message_limit'; + const errorMessage = { + type, + max: ip ? ipMax : userMax, + limiter: ip ? 'ip' : 'user', + windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes, + }; + + await logViolation(req, res, type, errorMessage); + return await denyRequest(req, res, errorMessage); + }; +}; + +/** + * Message request rate limiter by IP + */ +const messageIpLimiter = rateLimit({ + windowMs: ipWindowMs, + max: ipMax, + handler: createHandler(), +}); + +/** + * Message request rate limiter by userId + */ +const messageUserLimiter = rateLimit({ + windowMs: userWindowMs, + max: userMax, + handler: createHandler(false), + keyGenerator: function (req) { + return req.user?.id; // Use the user ID or NULL if not available + }, +}); + +module.exports = { + messageIpLimiter, + messageUserLimiter, +}; diff --git a/api/server/middleware/registerLimiter.js b/api/server/middleware/registerLimiter.js index df2d3d1ca..e19e261cb 100644 --- a/api/server/middleware/registerLimiter.js +++ b/api/server/middleware/registerLimiter.js @@ -1,16 +1,30 @@ const rateLimit = require('express-rate-limit'); -const windowMs = (process.env?.REGISTER_WINDOW ?? 60) * 60 * 1000; // default: 1 hour -const max = process.env?.REGISTER_MAX ?? 5; // default: limit each IP to 5 registrations per windowMs +const { logViolation } = require('../../cache'); +const { removePorts } = require('../utils'); + +const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env; +const windowMs = REGISTER_WINDOW * 60 * 1000; +const max = REGISTER_MAX; const windowInMinutes = windowMs / 60000; +const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`; + +const handler = async (req, res) => { + const type = 'registrations'; + const errorMessage = { + type, + max, + windowInMinutes, + }; + + await logViolation(req, res, type, errorMessage, score); + return res.status(429).json({ message }); +}; const registerLimiter = rateLimit({ windowMs, max, - message: `Too many accounts created from this IP, please try again after ${windowInMinutes} minutes`, - keyGenerator: function (req) { - // Strip out the port number from the IP address - return req.ip.replace(/:\d+[^:]*$/, ''); - }, + handler, + keyGenerator: removePorts, }); module.exports = registerLimiter; diff --git a/api/server/middleware/uaParser.js b/api/server/middleware/uaParser.js new file mode 100644 index 000000000..f5b726dd3 --- /dev/null +++ b/api/server/middleware/uaParser.js @@ -0,0 +1,31 @@ +const uap = require('ua-parser-js'); +const { handleError } = require('../utils'); +const { logViolation } = require('../../cache'); + +/** + * Middleware to parse User-Agent header and check if it's from a recognized browser. + * If the User-Agent is not recognized as a browser, logs a violation and sends an error response. + * + * @function + * @async + * @param {Object} req - Express request object. + * @param {Object} res - Express response object. + * @param {Function} next - Express next middleware function. + * @returns {void} Sends an error response if the User-Agent is not recognized as a browser. + * + * @example + * app.use(uaParser); + */ +async function uaParser(req, res, next) { + const { NON_BROWSER_VIOLATION_SCORE: score = 20 } = process.env; + const ua = uap(req.headers['user-agent']); + + if (!ua.browser.name) { + const type = 'non_browser'; + await logViolation(req, res, type, { type }, score); + return handleError(res, { message: 'Illegal request' }); + } + next(); +} + +module.exports = uaParser; diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index 637fc090a..673fd185d 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -7,138 +7,125 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { sendMessage, createOnProgress } = require('../../utils'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = data.userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; - - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - unfinished: true, - cancelled: false, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - try { - const getAbortData = () => ({ - conversationId, - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - getIds, - // debug: true, - user: req.user.id, - conversationId, - parentMessageId, - overrideParentMessageId, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId ?? userMessageId, - }), - onStart, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - await saveConvo(req.user.id, { - ...endpointOption, - ...endpointOption.modelOptions, - conversationId, - endpoint: 'anthropic', - }); - - await saveMessage(response); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - // TODO: add anthropic titling - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = data.userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; } - }, -); + }; + + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + unfinished: true, + cancelled: false, + error: false, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + try { + const getAbortData = () => ({ + conversationId, + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + getIds, + // debug: true, + user: req.user.id, + conversationId, + parentMessageId, + overrideParentMessageId, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId ?? userMessageId, + }), + onStart, + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + await saveConvo(req.user.id, { + ...endpointOption, + ...endpointOption.modelOptions, + conversationId, + endpoint: 'anthropic', + }); + + await saveMessage(response); + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + // TODO: add anthropic titling + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index b590314f9..1c916265c 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -4,9 +4,9 @@ const router = express.Router(); const { browserClient } = require('../../../app/'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils'); -const { requireJwtAuth, setHeaders } = require('../../middleware'); +const { setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, setHeaders, async (req, res) => { +router.post('/', setHeaders, async (req, res) => { const { endpoint, text, diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index 10bf55442..f3047c285 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -4,9 +4,9 @@ const router = express.Router(); const { titleConvoBing, askBing } = require('../../../app'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils'); -const { requireJwtAuth, setHeaders } = require('../../middleware'); +const { setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, setHeaders, async (req, res) => { +router.post('/', setHeaders, async (req, res) => { const { endpoint, text, diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index aa77c3129..5742120b0 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -5,9 +5,9 @@ const { GoogleClient } = require('../../../app'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress } = require('../../utils'); const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService'); -const { requireJwtAuth, setHeaders } = require('../../middleware'); +const { setHeaders } = require('../../middleware'); -router.post('/', requireJwtAuth, setHeaders, async (req, res) => { +router.post('/', setHeaders, async (req, res) => { const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body; if (text.length === 0) { return handleError(res, { text: 'Prompt empty or too short' }); diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 890228cef..330f9404d 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -11,218 +11,205 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const newConvo = !conversationId; - const user = req.user.id; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const newConvo = !conversationId; + const user = req.user.id; - const plugins = []; + const plugins = []; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; + } + }; - let streaming = null; - let timer = null; + let streaming = null; + let timer = null; - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (timer) { - clearTimeout(timer); - } - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - error: false, - plugins, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - - streaming = new Promise((resolve) => { - timer = setTimeout(() => { - resolve(); - }, 250); - }); - }, - }); - - const pluginMap = new Map(); - const onAgentAction = async (action, runId) => { - pluginMap.set(runId, action.tool); - sendIntermediateMessage(res, { plugins }); - }; - - const onToolStart = async (tool, input, runId, parentRunId) => { - const pluginName = pluginMap.get(parentRunId); - const latestPlugin = { - runId, - loading: true, - inputs: [input], - latest: pluginName, - outputs: null, - }; - - if (streaming) { - await streaming; - } - const extraTokens = ':::plugin:::\n'; - plugins.push(latestPlugin); - sendIntermediateMessage(res, { plugins }, extraTokens); - }; - - const onToolEnd = async (output, runId) => { - if (streaming) { - await streaming; + if (timer) { + clearTimeout(timer); } - const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); - - if (pluginIndex !== -1) { - plugins[pluginIndex].loading = false; - plugins[pluginIndex].outputs = output; - } - }; - - const onChainEnd = () => { - saveMessage(userMessage); - sendIntermediateMessage(res, { plugins }); - }; - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugins: plugins.map((p) => ({ ...p, loading: false })), - userMessage, - }); - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user, - conversationId, - parentMessageId, - overrideParentMessageId, - getIds, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - onStart, - addMetadata, - getPartialText, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId || userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, plugins, - }), - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - if (metadata) { - response = { ...response, ...metadata }; - } - - console.log('CLIENT RESPONSE'); - console.dir(response, { depth: null }); - response.plugins = plugins.map((p) => ({ ...p, loading: false })); - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { - addTitle(req, { - text, - response, - client, }); } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + + if (saveDelay < 500) { + saveDelay = 500; + } + + streaming = new Promise((resolve) => { + timer = setTimeout(() => { + resolve(); + }, 250); + }); + }, + }); + + const pluginMap = new Map(); + const onAgentAction = async (action, runId) => { + pluginMap.set(runId, action.tool); + sendIntermediateMessage(res, { plugins }); + }; + + const onToolStart = async (tool, input, runId, parentRunId) => { + const pluginName = pluginMap.get(parentRunId); + const latestPlugin = { + runId, + loading: true, + inputs: [input], + latest: pluginName, + outputs: null, + }; + + if (streaming) { + await streaming; + } + const extraTokens = ':::plugin:::\n'; + plugins.push(latestPlugin); + sendIntermediateMessage(res, { plugins }, extraTokens); + }; + + const onToolEnd = async (output, runId) => { + if (streaming) { + await streaming; + } + + const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); + + if (pluginIndex !== -1) { + plugins[pluginIndex].loading = false; + plugins[pluginIndex].outputs = output; + } + }; + + const onChainEnd = () => { + saveMessage(userMessage); + sendIntermediateMessage(res, { plugins }); + }; + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugins: plugins.map((p) => ({ ...p, loading: false })), + userMessage, + }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + conversationId, + parentMessageId, + overrideParentMessageId, + getIds, + onAgentAction, + onChainEnd, + onToolStart, + onToolEnd, + onStart, + addMetadata, + getPartialText, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + plugins, + }), + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log('CLIENT RESPONSE'); + console.dir(response, { depth: null }); + response.plugins = plugins.map((p) => ({ ...p, loading: false })); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { + addTitle(req, { + text, + response, + client, }); } - }, -); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js index 77da50f68..d87daa6a8 100644 --- a/api/server/routes/ask/index.js +++ b/api/server/routes/ask/index.js @@ -6,6 +6,33 @@ const bingAI = require('./bingAI'); const gptPlugins = require('./gptPlugins'); const askChatGPTBrowser = require('./askChatGPTBrowser'); const anthropic = require('./anthropic'); +const { + uaParser, + checkBan, + requireJwtAuth, + concurrentLimiter, + messageIpLimiter, + messageUserLimiter, +} = require('../../middleware'); +const { isEnabled } = require('../../utils'); + +const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; + +router.use(requireJwtAuth); +router.use(checkBan); +router.use(uaParser); + +if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { + router.use(concurrentLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_IP)) { + router.use(messageIpLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_USER)) { + router.use(messageUserLimiter); +} router.use(['/azureOpenAI', '/openAI'], openAI); router.use('/google', google); diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index 5e5f44508..fb662809d 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -9,151 +9,138 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('ask log'); - console.dir({ text, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const newConvo = !conversationId; - const user = req.user.id; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('ask log'); + console.dir({ text, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const newConvo = !conversationId; + const user = req.user.id; - const addMetadata = (data) => (metadata = data); + const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - userMessageId = userMessage.messageId; - responseMessageId = data.responseMessageId; - if (!conversationId) { - conversationId = data.conversationId; - } - }; + const getIds = (data) => { + userMessage = data.userMessage; + userMessageId = userMessage.messageId; + responseMessageId = data.responseMessageId; + if (!conversationId) { + conversationId = data.conversationId; + } + }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user, - parentMessageId, - conversationId, - overrideParentMessageId, - getIds, - onStart, - addMetadata, - abortController, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId || userMessageId, - }), - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - if (metadata) { - response = { ...response, ...metadata }; - } - - console.log( - 'promptTokens, completionTokens:', - response.promptTokens, - response.completionTokens, - ); - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { - addTitle(req, { - text, - response, - client, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + error: false, }); } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + parentMessageId, + conversationId, + overrideParentMessageId, + getIds, + onStart, + addMetadata, + abortController, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + }), + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log( + 'promptTokens, completionTokens:', + response.promptTokens, + response.completionTokens, + ); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) { + addTitle(req, { + text, + response, + client, }); } - }, -); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index 1ccbcb34b..862a098fa 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -8,6 +8,7 @@ const { const { loginController } = require('../controllers/auth/LoginController'); const { logoutController } = require('../controllers/auth/LogoutController'); const { + checkBan, loginLimiter, registerLimiter, requireJwtAuth, @@ -19,9 +20,9 @@ const router = express.Router(); //Local router.post('/logout', requireJwtAuth, logoutController); -router.post('/login', loginLimiter, requireLocalAuth, loginController); +router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController); router.post('/refresh', refreshController); -router.post('/register', registerLimiter, validateRegistration, registrationController); +router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController); router.post('/requestPasswordReset', resetPasswordRequestController); router.post('/resetPassword', resetPasswordController); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 66b3ffc0a..d4b919d30 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -4,12 +4,14 @@ const { getConvo, saveConvo } = require('../../models'); const { getConvosByPage, deleteConvos } = require('../../models/Conversation'); const requireJwtAuth = require('../middleware/requireJwtAuth'); -router.get('/', requireJwtAuth, async (req, res) => { +router.use(requireJwtAuth); + +router.get('/', async (req, res) => { const pageNumber = req.query.pageNumber || 1; res.status(200).send(await getConvosByPage(req.user.id, pageNumber)); }); -router.get('/:conversationId', requireJwtAuth, async (req, res) => { +router.get('/:conversationId', async (req, res) => { const { conversationId } = req.params; const convo = await getConvo(req.user.id, conversationId); @@ -20,7 +22,7 @@ router.get('/:conversationId', requireJwtAuth, async (req, res) => { } }); -router.post('/clear', requireJwtAuth, async (req, res) => { +router.post('/clear', async (req, res) => { let filter = {}; const { conversationId, source } = req.body.arg; if (conversationId) { @@ -43,7 +45,7 @@ router.post('/clear', requireJwtAuth, async (req, res) => { } }); -router.post('/update', requireJwtAuth, async (req, res) => { +router.post('/update', async (req, res) => { const update = req.body.arg; try { diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 0fe67fb56..5695d67cc 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -7,140 +7,127 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); const { saveMessage, getConvoTitle, getConvo } = require('../../../models'); const { sendMessage, createOnProgress } = require('../../utils'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('edit log'); - console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const userMessageId = parentMessageId; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - responseMessageId = data.responseMessageId; - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - unfinished: true, - cancelled: false, - isEdited: true, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - try { - const getAbortData = () => ({ - conversationId, - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user: req.user.id, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId ?? userMessageId, - }), - getIds, - onStart, - addMetadata, - abortController, - }); - - if (metadata) { - response = { ...response, ...metadata }; + text: partialText, + unfinished: true, + cancelled: false, + isEdited: true, + error: false, + }); } - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; + if (saveDelay < 500) { + saveDelay = 500; } + }, + }); + try { + const getAbortData = () => ({ + conversationId, + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); - await saveMessage(response); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + const { abortController, onStart } = createAbortController(req, res, getAbortData); - // TODO: add anthropic titling - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user: req.user.id, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId ?? userMessageId, + }), + getIds, + onStart, + addMetadata, + abortController, + }); + + if (metadata) { + response = { ...response, ...metadata }; } - }, -); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + await saveMessage(response); + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + // TODO: add anthropic titling + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index c05592323..b180c844f 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -10,183 +10,170 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('edit log'); - console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const userMessageId = parentMessageId; - const user = req.user.id; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; + const user = req.user.id; - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; + const plugin = { + loading: true, + inputs: [], + latest: null, + outputs: null, + }; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - responseMessageId = data.responseMessageId; - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (plugin.loading === true) { - plugin.loading = false; - } - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - isEdited: true, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const onAgentAction = (action, start = false) => { - const formattedAction = formatAction(action); - plugin.inputs.push(formattedAction); - plugin.latest = formattedAction.plugin; - if (!start) { - saveMessage(userMessage); + if (plugin.loading === true) { + plugin.loading = false; } - sendIntermediateMessage(res, { plugin }); - // console.log('PLUGIN ACTION', formattedAction); - }; - const onChainEnd = (data) => { - let { intermediateSteps: steps } = data; - plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; - plugin.loading = false; - saveMessage(userMessage); - sendIntermediateMessage(res, { plugin }); - // console.log('CHAIN END', plugin.outputs); - }; - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - userMessage, - }); - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getIds, - onAgentAction, - onChainEnd, - onStart, - addMetadata, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - plugin, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId || userMessageId, - }), - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + isEdited: true, + error: false, + }); } - if (metadata) { - response = { ...response, ...metadata }; + if (saveDelay < 500) { + saveDelay = 500; } + }, + }); - console.log('CLIENT RESPONSE'); - console.dir(response, { depth: null }); - response.plugin = { ...plugin, loading: false }; - await saveMessage(response); - - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const onAgentAction = (action, start = false) => { + const formattedAction = formatAction(action); + plugin.inputs.push(formattedAction); + plugin.latest = formattedAction.plugin; + if (!start) { + saveMessage(userMessage); } - }, -); + sendIntermediateMessage(res, { plugin }); + // console.log('PLUGIN ACTION', formattedAction); + }; + + const onChainEnd = (data) => { + let { intermediateSteps: steps } = data; + plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; + plugin.loading = false; + saveMessage(userMessage); + sendIntermediateMessage(res, { plugin }); + // console.log('CHAIN END', plugin.outputs); + }; + + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugin: { ...plugin, loading: false }, + userMessage, + }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getIds, + onAgentAction, + onChainEnd, + onStart, + addMetadata, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + plugin, + parentMessageId: overrideParentMessageId || userMessageId, + }), + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + console.log('CLIENT RESPONSE'); + console.dir(response, { depth: null }); + response.plugin = { ...plugin, loading: false }; + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js index 7eda18b8a..dcf5ff553 100644 --- a/api/server/routes/edit/index.js +++ b/api/server/routes/edit/index.js @@ -3,11 +3,36 @@ const router = express.Router(); const openAI = require('./openAI'); const gptPlugins = require('./gptPlugins'); const anthropic = require('./anthropic'); -// const google = require('./google'); +const { + checkBan, + uaParser, + requireJwtAuth, + concurrentLimiter, + messageIpLimiter, + messageUserLimiter, +} = require('../../middleware'); +const { isEnabled } = require('../../utils'); + +const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; + +router.use(requireJwtAuth); +router.use(checkBan); +router.use(uaParser); + +if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { + router.use(concurrentLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_IP)) { + router.use(messageIpLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_USER)) { + router.use(messageUserLimiter); +} router.use(['/azureOpenAI', '/openAI'], openAI); router.use('/gptPlugins', gptPlugins); router.use('/anthropic', anthropic); -// router.use('/google', google); module.exports = router; diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index 1a9dc3ff2..8af7ee206 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -9,140 +9,127 @@ const { createAbortController, handleAbortError, setHeaders, - requireJwtAuth, validateEndpoint, buildEndpointOption, } = require('../../middleware'); -router.post('/abort', requireJwtAuth, handleAbort()); +router.post('/abort', handleAbort()); -router.post( - '/', - requireJwtAuth, - validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - console.log('edit log'); - console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); - let metadata; - let userMessage; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const userMessageId = parentMessageId; +router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + console.log('edit log'); + console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null }); + let metadata; + let userMessage; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const userMessageId = parentMessageId; - const addMetadata = (data) => (metadata = data); - const getIds = (data) => { - userMessage = data.userMessage; - responseMessageId = data.responseMessageId; - }; + const addMetadata = (data) => (metadata = data); + const getIds = (data) => { + userMessage = data.userMessage; + responseMessageId = data.responseMessageId; + }; - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender: getResponseSender(endpointOption), - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - cancelled: false, - isEdited: true, - error: false, - }); - } - - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const getAbortData = () => ({ - sender: getResponseSender(endpointOption), - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - userMessage, - }); - - const { abortController, onStart } = createAbortController( - res, - req, - endpointOption, - getAbortData, - ); - - try { - const { client } = await initializeClient(req, endpointOption); - - let response = await client.sendMessage(text, { - user: req.user.id, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getIds, - onStart, - addMetadata, - abortController, - onProgress: progressCallback.call(null, { - res, - text, + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: getResponseSender(endpointOption), + conversationId, parentMessageId: overrideParentMessageId || userMessageId, - }), - }); - - if (metadata) { - response = { ...response, ...metadata }; + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + cancelled: false, + isEdited: true, + error: false, + }); } - console.log( - 'promptTokens, completionTokens:', - response.promptTokens, - response.completionTokens, - ); - await saveMessage(response); + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); - sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), - final: true, - conversation: await getConvo(req.user.id, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender: getResponseSender(endpointOption), - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); + const getAbortData = () => ({ + sender: getResponseSender(endpointOption), + conversationId, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + userMessage, + }); + + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + const { client } = await initializeClient(req, endpointOption); + + let response = await client.sendMessage(text, { + user: req.user.id, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getIds, + onStart, + addMetadata, + abortController, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + }), + }); + + if (metadata) { + response = { ...response, ...metadata }; } - }, -); + + console.log( + 'promptTokens, completionTokens:', + response.promptTokens, + response.completionTokens, + ); + await saveMessage(response); + + sendMessage(res, { + title: await getConvoTitle(req.user.id, conversationId), + final: true, + conversation: await getConvo(req.user.id, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender: getResponseSender(endpointOption), + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, + }); + } +}); module.exports = router; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 556603e9e..f64930c75 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -3,8 +3,24 @@ const express = require('express'); const router = express.Router(); const config = require('../../../config/loader'); const { setAuthTokens } = require('../services/AuthService'); +const { loginLimiter, checkBan } = require('../middleware'); const domains = config.domains; +router.use(loginLimiter); + +const oauthHandler = async (req, res) => { + try { + await checkBan(req, res); + if (req.banned) { + return; + } + await setAuthTokens(req.user._id, res); + res.redirect(domains.client); + } catch (err) { + console.error('Error in setting authentication tokens:', err); + } +}; + /** * Google Routes */ @@ -24,14 +40,7 @@ router.get( session: false, scope: ['openid', 'profile', 'email'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( @@ -52,14 +61,7 @@ router.get( scope: ['public_profile'], profileFields: ['id', 'email', 'name'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( @@ -76,14 +78,7 @@ router.get( failureMessage: true, session: false, }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( @@ -102,14 +97,7 @@ router.get( session: false, scope: ['user:email', 'read:user'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); router.get( '/discord', @@ -127,14 +115,7 @@ router.get( session: false, scope: ['identify', 'email'], }), - async (req, res) => { - try { - await setAuthTokens(req.user._id, res); - res.redirect(domains.client); - } catch (err) { - console.error('Error in setting authentication tokens:', err); - } - }, + oauthHandler, ); module.exports = router; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 4715610d7..3ae18e98c 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,22 +1,11 @@ const partialRight = require('lodash/partialRight'); -const citationRegex = /\[\^\d+?\^]/g; const { getCitations, citeText } = require('./citations'); +const { sendMessage } = require('./streamResponse'); const cursor = ''; +const citationRegex = /\[\^\d+?\^]/g; const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); -const handleError = (res, message) => { - res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); - res.end(); -}; - -const sendMessage = (res, message, event = 'message') => { - if (message.length === 0) { - return; - } - res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); -}; - const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { let i = 0; let code = ''; @@ -148,10 +137,27 @@ function formatAction(action) { return formattedAction; } +/** + * Checks if the given string value is truthy by comparing it to the string 'true' (case-insensitive). + * + * @function + * @param {string|null|undefined} value - The string value to check. + * @returns {boolean} Returns `true` if the value is a case-insensitive match for the string 'true', otherwise returns `false`. + * @example + * + * isEnabled("True"); // returns true + * isEnabled("TRUE"); // returns true + * isEnabled("false"); // returns false + * isEnabled(null); // returns false + * isEnabled(); // returns false + */ +function isEnabled(value) { + return value?.toLowerCase()?.trim() === 'true'; +} + module.exports = { - handleError, - sendMessage, createOnProgress, + isEnabled, handleText, formatSteps, formatAction, diff --git a/api/server/utils/index.js b/api/server/utils/index.js index e76d5b436..ba21583f5 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -1,11 +1,17 @@ -const cryptoUtils = require('./crypto'); +const streamResponse = require('./streamResponse'); +const removePorts = require('./removePorts'); const handleText = require('./handleText'); +const cryptoUtils = require('./crypto'); const citations = require('./citations'); const sendEmail = require('./sendEmail'); +const math = require('./math'); module.exports = { + ...streamResponse, ...cryptoUtils, ...handleText, ...citations, + removePorts, sendEmail, + math, }; diff --git a/api/server/utils/math.js b/api/server/utils/math.js new file mode 100644 index 000000000..12c12c8cc --- /dev/null +++ b/api/server/utils/math.js @@ -0,0 +1,48 @@ +/** + * Evaluates a mathematical expression provided as a string and returns the result. + * + * If the input is already a number, it returns the number as is. + * If the input is not a string or contains invalid characters, an error is thrown. + * If the evaluated result is not a number, an error is thrown. + * + * @param {string|number} str - The mathematical expression to evaluate, or a number. + * @param {number} [fallbackValue] - The default value to return if the input is not a string or number, or if the evaluated result is not a number. + * + * @returns {number} The result of the evaluated expression or the input number. + * + * @throws {Error} Throws an error if the input is not a string or number, contains invalid characters, or does not evaluate to a number. + */ +function math(str, fallbackValue) { + const fallback = typeof fallbackValue !== 'undefined' && typeof fallbackValue === 'number'; + if (typeof str !== 'string' && typeof str === 'number') { + return str; + } else if (typeof str !== 'string') { + if (fallback) { + return fallbackValue; + } + throw new Error(`str is ${typeof str}, but should be a string`); + } + + const validStr = /^[+\-\d.\s*/%()]+$/.test(str); + + if (!validStr) { + if (fallback) { + return fallbackValue; + } + throw new Error('Invalid characters in string'); + } + + const value = eval(str); + + if (typeof value !== 'number') { + if (fallback) { + return fallbackValue; + } + console.error('str', str); + throw new Error(`str did not evaluate to a number but to a ${typeof value}`); + } + + return value; +} + +module.exports = math; diff --git a/api/server/utils/removePorts.js b/api/server/utils/removePorts.js new file mode 100644 index 000000000..db3e5e1db --- /dev/null +++ b/api/server/utils/removePorts.js @@ -0,0 +1 @@ +module.exports = (req) => req.ip.replace(/:\d+[^:]*$/, ''); diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js new file mode 100644 index 000000000..26cb0c238 --- /dev/null +++ b/api/server/utils/streamResponse.js @@ -0,0 +1,63 @@ +const crypto = require('crypto'); +const { saveMessage } = require('../../models'); + +/** + * Sends error data in Server Sent Events format and ends the response. + * @param {object} res - The server response. + * @param {string} message - The error message. + */ +const handleError = (res, message) => { + res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); + res.end(); +}; + +/** + * Sends message data in Server Sent Events format. + * @param {object} res - - The server response. + * @param {string} message - The message to be sent. + * @param {string} event - [Optional] The type of event. Default is 'message'. + */ +const sendMessage = (res, message, event = 'message') => { + if (message.length === 0) { + return; + } + res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); +}; + +/** + * Processes an error with provided options, saves the error message and sends a corresponding SSE response + * @async + * @param {object} res - The server response. + * @param {object} options - The options for handling the error containing message properties. + * @param {function} callback - [Optional] The callback function to be executed. + */ +const sendError = async (res, options, callback) => { + const { sender, conversationId, messageId, parentMessageId, text, shouldSaveMessage } = options; + const errorMessage = { + sender, + messageId: messageId ?? crypto.randomUUID(), + conversationId, + parentMessageId, + unfinished: false, + cancelled: false, + error: true, + final: true, + text, + isCreatedByUser: false, + }; + if (callback && typeof callback === 'function') { + await callback(); + } + + if (shouldSaveMessage) { + await saveMessage(errorMessage); + } + + handleError(res, errorMessage); +}; + +module.exports = { + handleError, + sendMessage, + sendError, +}; diff --git a/api/test/.env.test.example b/api/test/.env.test.example index e7a3fc48e..16730f672 100644 --- a/api/test/.env.test.example +++ b/api/test/.env.test.example @@ -7,3 +7,7 @@ CREDS_IV=cd02538f4be2fa37aba9420b5924389f # For testing the ChatAgent OPENAI_API_KEY=your-api-key + +BAN_VIOLATIONS=true +BAN_DURATION=7200000 +BAN_INTERVAL=20 diff --git a/api/test/__mocks__/KeyvMongo.js b/api/test/__mocks__/KeyvMongo.js new file mode 100644 index 000000000..f88bc144b --- /dev/null +++ b/api/test/__mocks__/KeyvMongo.js @@ -0,0 +1,30 @@ +const mockGet = jest.fn(); +const mockSet = jest.fn(); + +jest.mock('@keyv/mongo', () => { + const EventEmitter = require('events'); + class KeyvMongo extends EventEmitter { + constructor(url = 'mongodb://127.0.0.1:27017', options) { + super(); + this.ttlSupport = false; + url = url ?? {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + this.opts = { + url, + collection: 'keyv', + ...url, + ...options, + }; + } + + get = mockGet; + set = mockSet; + } + + return KeyvMongo; +}); diff --git a/client/package.json b/client/package.json index df34c5a26..a335e25e4 100644 --- a/client/package.json +++ b/client/package.json @@ -9,6 +9,7 @@ "dev": "cross-env NODE_ENV=development dotenv -e ../.env -- vite", "preview-prod": "cross-env NODE_ENV=development dotenv -e ../.env -- vite preview", "test": "cross-env NODE_ENV=test jest --watch", + "b:test": "NODE_ENV=test bun jest --watch", "test:ci": "cross-env NODE_ENV=test jest --ci", "b:build": "NODE_ENV=production bun vite build", "b:dev": "NODE_ENV=development bun vite" diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 63635d84b..a2ab5c88d 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -1,4 +1,11 @@ -import type { TConversation, TMessage, TPreset, TMutation } from 'librechat-data-provider'; +import type { + TConversation, + TMessage, + TPreset, + TMutation, + TLoginUser, + TUser, +} from 'librechat-data-provider'; export type TSetOption = (param: number | string) => (newValue: number | string | boolean) => void; export type TSetExample = ( @@ -146,3 +153,28 @@ export type TDialogProps = { open: boolean; onOpenChange: (open: boolean) => void; }; + +export type TResError = { + response: { data: { message: string } }; + message: string; +}; + +export type TAuthContext = { + user: TUser | undefined; + token: string | undefined; + isAuthenticated: boolean; + error: string | undefined; + login: (data: TLoginUser) => void; + logout: () => void; +}; + +export type TUserContext = { + user?: TUser | undefined; + token: string | undefined; + isAuthenticated: boolean; + redirect?: string; +}; + +export type TAuthConfig = { + loginRedirect: string; +}; diff --git a/client/src/components/Auth/Login.tsx b/client/src/components/Auth/Login.tsx index 6e2616c24..f75530b32 100644 --- a/client/src/components/Auth/Login.tsx +++ b/client/src/components/Auth/Login.tsx @@ -5,6 +5,7 @@ import { useNavigate } from 'react-router-dom'; import { useLocalize } from '~/hooks'; import { useGetStartupConfig } from 'librechat-data-provider'; import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components'; +import { getLoginError } from '~/utils'; function Login() { const { login, error, isAuthenticated } = useAuthContext(); @@ -30,9 +31,7 @@ function Login() { className="relative mt-4 rounded border border-red-400 bg-red-100 px-4 py-3 text-red-700" role="alert" > - {error?.includes('429') - ? localize('com_auth_error_login_rl') - : localize('com_auth_error_login')} + {localize(getLoginError(error))} )} diff --git a/client/src/components/Messages/Content/MessageContent.tsx b/client/src/components/Messages/Content/MessageContent.tsx index 9da0a831b..879dd6098 100644 --- a/client/src/components/Messages/Content/MessageContent.tsx +++ b/client/src/components/Messages/Content/MessageContent.tsx @@ -1,20 +1,28 @@ import { Fragment } from 'react'; import type { TResPlugin } from 'librechat-data-provider'; import type { TMessageContent, TText, TDisplayProps } from '~/common'; -import { cn, getError } from '~/utils'; +import { useAuthContext } from '~/hooks'; +import { cn, getMessageError } from '~/utils'; import EditMessage from './EditMessage'; import Container from './Container'; import Markdown from './Markdown'; import Plugin from './Plugin'; -// Error Message Component -const ErrorMessage = ({ text }: TText) => ( - -
- {getError(text)} -
-
-); +const ErrorMessage = ({ text }: TText) => { + const { logout } = useAuthContext(); + + if (text.includes('ban')) { + logout(); + return null; + } + return ( + +
+ {getMessageError(text)} +
+
+ ); +}; // Display Message Component const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => ( diff --git a/client/src/hooks/AuthContext.tsx b/client/src/hooks/AuthContext.tsx index 3c6c8aff8..6f67df8d9 100644 --- a/client/src/hooks/AuthContext.tsx +++ b/client/src/hooks/AuthContext.tsx @@ -1,7 +1,7 @@ import { + useMemo, useState, useEffect, - useMemo, ReactNode, useCallback, createContext, @@ -17,33 +17,14 @@ import { useRefreshTokenMutation, TLoginUser, } from 'librechat-data-provider'; +import { TAuthConfig, TUserContext, TAuthContext, TResError } from '~/common'; import { useNavigate } from 'react-router-dom'; +import useTimeout from './useTimeout'; -export type TAuthContext = { - user: TUser | undefined; - token: string | undefined; - isAuthenticated: boolean; - error: string | undefined; - login: (data: TLoginUser) => void; - logout: () => void; -}; - -export type TUserContext = { - user?: TUser | undefined; - token: string | undefined; - isAuthenticated: boolean; - redirect?: string; -}; - -export type TAuthConfig = { - loginRedirect: string; -}; -//@ts-ignore - index expression is not of type number -window['errorTimeout'] = undefined; const AuthContext = createContext(undefined); const AuthContextProvider = ({ - authConfig, + // authConfig, children, }: { authConfig?: TAuthConfig; @@ -61,16 +42,7 @@ const AuthContextProvider = ({ const userQuery = useGetUserQuery({ enabled: !!token }); const refreshToken = useRefreshTokenMutation(); - // This seems to prevent the error flashing issue - const doSetError = (error: string | undefined) => { - if (error) { - console.log(error); - // set timeout to ensure we don't get a flash of the error message - window['errorTimeout'] = setTimeout(() => { - setError(error); - }, 400); - } - }; + const doSetError = useTimeout({ callback: (error) => setError(error as string | undefined) }); const setUserContext = useCallback( (userContext: TUserContext) => { @@ -89,19 +61,15 @@ const AuthContextProvider = ({ [navigate], ); - const getCookieValue = (key: string) => { - const keyValue = document.cookie.match('(^|;) ?' + key + '=([^;]*)(;|$)'); - return keyValue ? keyValue[2] : null; - }; - const login = (data: TLoginUser) => { loginUser.mutate(data, { onSuccess: (data: TLoginResponse) => { const { user, token } = data; setUserContext({ token, isAuthenticated: true, user, redirect: '/chat/new' }); }, - onError: (error) => { - doSetError((error as Error).message); + onError: (error: TResError | unknown) => { + const resError = error as TResError; + doSetError(resError.message); navigate('/login', { replace: true }); }, }); @@ -119,6 +87,12 @@ const AuthContextProvider = ({ }, onError: (error) => { doSetError((error as Error).message); + setUserContext({ + token: undefined, + isAuthenticated: false, + user: undefined, + redirect: '/login', + }); }, }); }, [setUserContext, logoutUser]); diff --git a/client/src/hooks/index.ts b/client/src/hooks/index.ts index 746f44cdd..25f2755a9 100644 --- a/client/src/hooks/index.ts +++ b/client/src/hooks/index.ts @@ -2,6 +2,7 @@ export * from './AuthContext'; export * from './ThemeContext'; export * from './ScreenshotContext'; export * from './ApiErrorBoundaryContext'; +export { default as useTimeout } from './useTimeout'; export { default as useUserKey } from './useUserKey'; export { default as useDebounce } from './useDebounce'; export { default as useLocalize } from './useLocalize'; diff --git a/client/src/hooks/useTimeout.tsx b/client/src/hooks/useTimeout.tsx new file mode 100644 index 000000000..e058e9ca8 --- /dev/null +++ b/client/src/hooks/useTimeout.tsx @@ -0,0 +1,39 @@ +import { useEffect, useRef } from 'react'; + +type TUseTimeoutParams = { + callback: (error: string | number | boolean | null) => void; + delay?: number | undefined; +}; +type TTimeout = ReturnType | null; + +function useTimeout({ callback, delay = 400 }: TUseTimeoutParams) { + const timeout = useRef(null); + + const callOnTimeout = (value: string | undefined) => { + // Clear existing timeout + if (timeout.current !== null) { + clearTimeout(timeout.current); + } + + // Set new timeout + if (value) { + console.log(value); + timeout.current = setTimeout(() => { + callback(value); + }, delay); + } + }; + + // Clear timeout when the component unmounts + useEffect(() => { + return () => { + if (timeout.current !== null) { + clearTimeout(timeout.current); + } + }; + }, []); + + return callOnTimeout; +} + +export default useTimeout; diff --git a/client/src/localization/languages/Eng.tsx b/client/src/localization/languages/Eng.tsx index 937bcd676..f8bd24531 100644 --- a/client/src/localization/languages/Eng.tsx +++ b/client/src/localization/languages/Eng.tsx @@ -52,7 +52,11 @@ export default { com_auth_error_login: 'Unable to login with the information provided. Please check your credentials and try again.', com_auth_error_login_rl: - 'Too many login attempts from this IP in a short amount of time. Please try again later.', + 'Too many login attempts in a short amount of time. Please try again later.', + com_auth_error_login_ban: + 'Your account has been temporarily banned due to violations of our service.', + com_auth_error_login_server: + 'There was an internal server error. Please wait a few moments and try again.', com_auth_no_account: 'Don\'t have an account?', com_auth_sign_up: 'Sign up', com_auth_sign_in: 'Sign in', diff --git a/client/src/utils/getError.ts b/client/src/utils/getError.ts deleted file mode 100644 index e41cc3951..000000000 --- a/client/src/utils/getError.ts +++ /dev/null @@ -1,28 +0,0 @@ -const isJson = (str: string) => { - try { - JSON.parse(str); - } catch (e) { - return false; - } - return true; -}; - -const getError = (text: string) => { - const errorMessage = text.length > 512 ? text.slice(0, 512) + '...' : text; - const match = text.match(/\{[^{}]*\}/); - const jsonString = match ? match[0] : ''; - if (isJson(jsonString)) { - const json = JSON.parse(jsonString); - if (json.code === 'invalid_api_key') { - return 'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.'; - } else if (json.type === 'insufficient_quota') { - return 'We apologize for any inconvenience caused. The default API key has reached its limit. To continue using this service, please set up your own API key. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.'; - } else { - return `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`; - } - } else { - return `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`; - } -}; - -export default getError; diff --git a/client/src/utils/getLoginError.ts b/client/src/utils/getLoginError.ts new file mode 100644 index 000000000..6bd3c1ba8 --- /dev/null +++ b/client/src/utils/getLoginError.ts @@ -0,0 +1,18 @@ +const getLoginError = (errorText: string) => { + const defaultError = 'com_auth_error_login'; + if (!errorText) { + return defaultError; + } + + if (errorText?.includes('429')) { + return 'com_auth_error_login_rl'; + } else if (errorText?.includes('403')) { + return 'com_auth_error_login_ban'; + } else if (errorText?.includes('500')) { + return 'com_auth_error_login_server'; + } else { + return defaultError; + } +}; + +export default getLoginError; diff --git a/client/src/utils/getMessageError.ts b/client/src/utils/getMessageError.ts new file mode 100644 index 000000000..4d2be10e4 --- /dev/null +++ b/client/src/utils/getMessageError.ts @@ -0,0 +1,62 @@ +const isJson = (str: string) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; + +type TConcurrent = { + limit: number; +}; + +type TMessageLimit = { + max: number; + windowInMinutes: number; +}; + +const errorMessages = { + ban: 'Your account has been temporarily banned due to violations of our service.', + invalid_api_key: + 'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.', + insufficient_quota: + 'We apologize for any inconvenience caused. The default API key has reached its limit. To continue using this service, please set up your own API key. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.', + concurrent: (json: TConcurrent) => { + const { limit } = json; + const plural = limit > 1 ? 's' : ''; + return `Only ${limit} message${plural} at a time. Please allow any other responses to complete before sending another message, or wait one minute.`; + }, + message_limit: (json: TMessageLimit) => { + const { max, windowInMinutes } = json; + const plural = max > 1 ? 's' : ''; + return `You hit the message limit. You have a cap of ${max} message${plural} per ${ + windowInMinutes > 1 ? `${windowInMinutes} minutes` : 'minute' + }.`; + }, +}; + +const getMessageError = (text: string) => { + const errorMessage = text.length > 512 ? text.slice(0, 512) + '...' : text; + const match = text.match(/\{[^{}]*\}/); + const jsonString = match ? match[0] : ''; + const defaultResponse = `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`; + + if (!isJson(jsonString)) { + return defaultResponse; + } + + const json = JSON.parse(jsonString); + const errorKey = json.code || json.type; + const keyExists = errorKey && errorMessages[errorKey]; + + if (keyExists && typeof errorMessages[errorKey] === 'function') { + return errorMessages[errorKey](json); + } else if (keyExists) { + return errorMessages[errorKey]; + } else { + return defaultResponse; + } +}; + +export default getMessageError; diff --git a/client/src/utils/index.ts b/client/src/utils/index.ts index dff114016..b8a199b2f 100644 --- a/client/src/utils/index.ts +++ b/client/src/utils/index.ts @@ -2,10 +2,11 @@ import { clsx } from 'clsx'; import { twMerge } from 'tailwind-merge'; export * from './languages'; -export { default as getError } from './getError'; export { default as buildTree } from './buildTree'; +export { default as getLoginError } from './getLoginError'; export { default as cleanupPreset } from './cleanupPreset'; export { default as validateIframe } from './validateIframe'; +export { default as getMessageError } from './getMessageError'; export { default as getLocalStorageItems } from './getLocalStorageItems'; export { default as getDefaultConversation } from './getDefaultConversation'; diff --git a/docs/features/mod_system.md b/docs/features/mod_system.md new file mode 100644 index 000000000..107c61cd3 --- /dev/null +++ b/docs/features/mod_system.md @@ -0,0 +1,67 @@ +## Automated Moderation System (optional) +The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching a set threshold, the user and their IP are temporarily banned. This system ensures platform security by monitoring and penalizing rapid or suspicious activities. + +In production, you should have Cloudflare or some other DDoS protection in place to really protect the server from excessive requests, but these changes will largely protect you from the single or several bad actors targeting your deployed instance for proxying. + +### Notes + +- Uses Caching for basic security and violation logging (bans, concurrent messages, exceeding rate limits) + - In the near future, I will add **Redis** support for production instances, which can be easily injected into the current caching setup +- Exceeding any of the rate limiters (login/registration/messaging) is considered a violation, default score is 1 +- Non-browser origin is a violation +- Default score for each violation is configurable +- Enabling any of the limiters and/or bans enables caching/logging +- Violation logs can be found in the data folder, which is created when logging begins: `librechat/data` + - **Only violations are logged** + - `violations.json` keeps track of the total count for each violation per user + - `logs.json` records each individual violation per user +- Ban logs are stored in MongoDB under the `logs` collection. They are transient as they only exist for the ban duration + - If you would like to remove a ban manually, you would have to remove them from the database manually and restart the server + - **Redis** support is also planned for this. + +### Rate Limiters + +The project's current rate limiters are as follows (see below under setup for default values): + +- Login and registration rate limiting +- [optional] Concurrent Message limiting (only X messages at a time per user) +- [optional] Message limiting (how often a user can send a message, configurable by IP and User) + +### Setup + +The following are all of the related env variables to make use of and configure the mod system. Note this is also found in the [/.env.example](/.env.example) file, to be set in your own `.env` file. + +```bash +BAN_VIOLATIONS=true # Whether or not to enable banning users for violations (they will still be logged) +BAN_DURATION=1000 * 60 * 60 * 2 # how long the user and associated IP are banned for +BAN_INTERVAL=20 # a user will be banned everytime their score reaches/crosses over the interval threshold + +# The score for each violation + +LOGIN_VIOLATION_SCORE=1 +REGISTRATION_VIOLATION_SCORE=1 +CONCURRENT_VIOLATION_SCORE=1 +MESSAGE_VIOLATION_SCORE=1 +NON_BROWSER_VIOLATION_SCORE=20 + +# Login and registration rate limiting. + +LOGIN_MAX=7 # The max amount of logins allowed per IP per LOGIN_WINDOW +LOGIN_WINDOW=5 # in minutes, determines the window of time for LOGIN_MAX logins +REGISTER_MAX=5 # The max amount of registrations allowed per IP per REGISTER_WINDOW +REGISTER_WINDOW=60 # in minutes, determines the window of time for REGISTER_MAX registrations + +# Message rate limiting (per user & IP) + +LIMIT_CONCURRENT_MESSAGES=true # Whether to limit the amount of messages a user can send per request +CONCURRENT_MESSAGE_MAX=2 # The max amount of messages a user can send per request + +LIMIT_MESSAGE_IP=true # Whether to limit the amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_MAX=40 # The max amount of messages an IP can send per MESSAGE_IP_WINDOW +MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_MAX messages + +# Note: You can utilize both limiters, but default is to limit by IP only. +LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW +MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages +``` \ No newline at end of file diff --git a/docs/features/third-party.md b/docs/features/third_party.md similarity index 100% rename from docs/features/third-party.md rename to docs/features/third_party.md diff --git a/docs/install/user_auth_system.md b/docs/install/user_auth_system.md index 92ad4d2b1..b605f44de 100644 --- a/docs/install/user_auth_system.md +++ b/docs/install/user_auth_system.md @@ -9,15 +9,29 @@ In order for the auth system to function properly, there are some environment va In /.env, you will need to set the following variables: ```bash -# Change this to a secure string +# Change the secrets to a secure, random string JWT_SECRET=secret +JWT_REFRESH_SECRET=refresh_secret + # Set the expiration delay for the secure cookie with the JWT token -# Delay is in millisecond e.g. 7 days is 1000*60*60*24*7 -SESSION_EXPIRY=1000 * 60 * 60 * 24 * 7 +# Delay is in milliseconds e.g. 7 days is 1000*60*60*24*7 + +# Recommended session expiry is 15 minutes. Make it longer if you want the user to be able to revist the page without logging in for a longer duration of time. + +# Recommended refresh token expiry is 7 days +SESSION_EXPIRY=1000 * 60 * 15 +REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7 + DOMAIN_SERVER=http://localhost:3080 DOMAIN_CLIENT=http://localhost:3080 ``` +## Automated Moderation System (optional) + +The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching a set threshold, the user and their IP are temporarily banned. This system ensures platform security by monitoring and penalizing rapid or suspicious activities. + +To set up the mod system, review [the setup guide](../features/mod_system.md). + *Please Note: If you are wanting this to work in development mode, you will need to create a file called `.env.development` in the root directory and set `DOMAIN_CLIENT` to `http://localhost:3090` or whatever port is provided by vite when runnning `npm run frontend-dev`* Important: When you run the app for the first time, you need to create a new account by clicking on "Sign up" on the login page. The first account you make will be the admin account. The admin account doesn't have any special features right now, but it might be useful if you want to make an admin dashboard to manage other users later. diff --git a/mkdocs.yml b/mkdocs.yml index b036ecbae..b10915e60 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -102,7 +102,8 @@ nav: - Azure Cognitive Search: 'features/plugins/azure_cognitive_search.md' - Make Your Own Plugin: 'features/plugins/make_your_own.md' - Using official ChatGPT Plugins: 'features/plugins/chatgpt_plugins_openapi.md' - - Third-Party Tools: 'features/third-party.md' + - Automated Moderation: 'features/mod_system.md' + - Third-Party Tools: 'features/third_party.md' - Proxy: 'features/proxy.md' - Bing Jailbreak: 'features/bing_jailbreak.md' - Cloud Deployment: diff --git a/package-lock.json b/package-lock.json index 849219d91..4336e7446 100644 --- a/package-lock.json +++ b/package-lock.json @@ -72,7 +72,7 @@ "joi": "^17.9.2", "js-yaml": "^4.1.0", "jsonwebtoken": "^9.0.0", - "keyv": "^4.5.2", + "keyv": "^4.5.3", "keyv-file": "^0.2.0", "langchain": "^0.0.144", "lodash": "^4.17.21", @@ -91,6 +91,7 @@ "pino": "^8.12.1", "sanitize": "^2.1.2", "sharp": "^0.32.5", + "ua-parser-js": "^1.0.36", "zod": "^3.22.2" }, "devDependencies": { @@ -25069,6 +25070,28 @@ "node": ">=14.17" } }, + "node_modules/ua-parser-js": { + "version": "1.0.36", + "resolved": "https://registry.npmjs.org/ua-parser-js/-/ua-parser-js-1.0.36.tgz", + "integrity": "sha512-znuyCIXzl8ciS3+y3fHJI/2OhQIXbXw9MWC/o3qwyR+RGppjZHrM27CGFSKCJXi2Kctiz537iOu2KnXs1lMQhw==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/ua-parser-js" + }, + { + "type": "paypal", + "url": "https://paypal.me/faisalman" + }, + { + "type": "github", + "url": "https://github.com/sponsors/faisalman" + } + ], + "engines": { + "node": "*" + } + }, "node_modules/uglify-js": { "version": "3.17.4", "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.17.4.tgz", diff --git a/package.json b/package.json index bff7ba80e..22397f172 100644 --- a/package.json +++ b/package.json @@ -49,8 +49,8 @@ "b:data-provider": "cd packages/data-provider && bun run b:build", "b:client": "bun run b:data-provider && cd client && bun run b:build", "b:client:dev": "cd client && bun run b:dev", - "b:test:client": "cd client && bun run test", - "b:test:api": "cd api && bun run test" + "b:test:client": "cd client && bun run b:test", + "b:test:api": "cd api && bun run b:test" }, "repository": { "type": "git", @@ -92,7 +92,7 @@ "nodemonConfig": { "ignore": [ "api/data/", - "data", + "data/", "client/", "admin/", "packages/" diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index aa69af2e2..5db9e26d3 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.1.7", + "version": "0.1.8", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js",