feat: Message Rate Limiters, Violation Logging, & Ban System 🔨 (#903)

* refactor: require Auth middleware in route index files

* feat: concurrent message limiter

* feat: complete concurrent message limiter with caching

* refactor: SSE response methods separated from handleText

* fix(abortMiddleware): fix req and res order to standard, use endpointOption in req.body

* chore: minor name changes

* refactor: add isUUID condition to saveMessage

* fix(concurrentLimiter): logic correctly handles the max number of concurrent messages and res closing/finalization

* chore: bump keyv and remove console.log from Message

* fix(concurrentLimiter): ensure messages are only saved in later message children

* refactor(concurrentLimiter): use KeyvFile instead, could make other stores configurable in the future

* feat: add denyRequest function for error responses

* feat(utils): add isStringTruthy function

Introduce the isStringTruthy function to the utilities module to check if a string value is a case-insensitive match for 'true'

* feat: add optional message rate limiters by IP and userId

* feat: add optional message rate limiters by IP and userId to edit route

* refactor: rename isStringTruthy to isTrue for brevity

* refactor(getError): use map to make code cleaner

* refactor: use memory for concurrent rate limiter to prevent clearing on startup/exit, add multiple log files, fix error message for concurrent violation

* feat: check if errorMessage is object, stringify if so

* chore: send object to denyRequest which will stringify it

* feat: log excessive requests

* fix(getError): correctly pluralize messages

* refactor(limiters): make type consistent between logs and errorMessage

* refactor(cache): move files out of lib/db into separate cache dir
>> feat: add getLogStores function so Keyv instance is not redundantly created on every violation
feat: separate violation logging to own function with logViolation

* fix: cache/index.js export, properly record userViolations

* refactor(messageLimiters): use new logging method, add logging to registrations

* refactor(logViolation): make userLogs an array of logs per user

* feat: add logging to login limiter

* refactor: pass req as first param to logViolation and record offending IP

* refactor: rename isTrue helper fn to isEnabled

* feat: add simple non_browser check and log violation

* fix: open handles in unit tests, remove KeyvMongo as not used and properly mock global fetch

* chore: adjust nodemon ignore paths to properly ignore logs

* feat: add math helper function for safe use of eval

* refactor(api/convos): use middleware at top of file to avoid redundancy

* feat: add delete all static method for Sessions

* fix: redirect to login on refresh if user is not found, or the session is not found but hasn't expired (ban case)

* refactor(getLogStores): adjust return type

* feat: add ban violation and check ban logic
refactor(logViolation): pass both req and res objects

* feat: add removePorts helper function

* refactor: rename getError to getMessageError and add getLoginError for displaying different login errors

* fix(AuthContext): fix type issue and remove unused code

* refactor(bans): ban by ip and user id, send response based on origin

* chore: add frontend ban messages

* refactor(routes/oauth): add ban check to handler, also consolidate logic to avoid redundancy

* feat: add ban check to AI messaging routes

* feat: add ban check to login/registration

* fix(ci/api): mock KeyvMongo to avoid tests hanging

* docs: update .env.example
> refactor(banViolation): calculate interval rate crossover, early return if duration is invalid
ci(banViolation): add tests to ensure users are only banned when expected

* docs: improve wording for mod system

* feat: add configurable env variables for violation scores

* chore: add jsdoc for uaParser.js

* chore: improve ban text log

* chore: update bun test scripts

* refactor(math.js): add fallback values

* fix(KeyvMongo/banLogs): refactor keyv instances to top of files to avoid memory leaks, refactor ban logic to use getLogStores instead
refactor(getLogStores): get a single log store by type

* fix(ci): refactor tests due to banLogs changes, also make sure to clear and revoke sessions even if ban duration is 0

* fix(banViolation.js): getLogStores import

* feat: handle 500 code error at login

* fix(middleware): handle case where user.id is _id and not just id

* ci: add ban secrets for backend unit tests

* refactor: logout user upon ban

* chore: log session delete message only if deletedCount > 0

* refactor: change default ban duration (2h) and make logic more clear in JSDOC

* fix: login and registration limiters will now return rate limiting error

* fix: userId not parsable as non ObjectId string

* feat: add useTimeout hook to properly clear timeouts when invoking functions within them
refactor(AuthContext): cleanup code by using new hook and defining types in ~/common

* fix: login error message for rate limits

* docs: add info for automated mod system and rate limiters, update other docs accordingly

* chore: bump data-provider version
This commit is contained in:
Danny Avila 2023-09-13 10:57:07 -04:00 committed by GitHub
parent db803cd640
commit 7b2cedf5ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
69 changed files with 2180 additions and 1062 deletions

View file

@ -13,12 +13,44 @@ APP_TITLE=LibreChat
HOST=localhost
PORT=3080
# Automated Moderation System
# The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions
# like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching
# a set threshold, the user and their IP are temporarily banned. This system ensures platform security
# by monitoring and penalizing rapid or suspicious activities.
BAN_VIOLATIONS=true # Whether or not to enable banning users for violations (they will still be logged)
BAN_DURATION=1000 * 60 * 60 * 2 # how long the user and associated IP are banned for
BAN_INTERVAL=20 # a user will be banned everytime their score reaches/crosses over the interval threshold
# The score for each violation
LOGIN_VIOLATION_SCORE=1
REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
# Login and registration rate limiting.
LOGIN_MAX=7 # The max amount of logins allowed per IP per LOGIN_WINDOW
LOGIN_WINDOW=5 # in minutes, determines how long an IP is banned for after LOGIN_MAX logins
LOGIN_WINDOW=5 # in minutes, determines the window of time for LOGIN_MAX logins
REGISTER_MAX=5 # The max amount of registrations allowed per IP per REGISTER_WINDOW
REGISTER_WINDOW=60 # in minutes, determines how long an IP is banned for after REGISTER_MAX registrations
REGISTER_WINDOW=60 # in minutes, determines the window of time for REGISTER_MAX registrations
# Message rate limiting (per user & IP)
LIMIT_CONCURRENT_MESSAGES=true # Whether to limit the amount of messages a user can send per request
CONCURRENT_MESSAGE_MAX=2 # The max amount of messages a user can send per request
LIMIT_MESSAGE_IP=true # Whether to limit the amount of messages an IP can send per MESSAGE_IP_WINDOW
MESSAGE_IP_MAX=40 # The max amount of messages an IP can send per MESSAGE_IP_WINDOW
MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_MAX messages
# Note: You can utilize both limiters, but default is to limit by IP only.
LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW
MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW
MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages
# Change this to proxy any API request.
# It's useful if your machine has difficulty calling the original API server.

View file

@ -18,6 +18,9 @@ jobs:
JWT_SECRET: ${{ secrets.JWT_SECRET }}
CREDS_KEY: ${{ secrets.CREDS_KEY }}
CREDS_IV: ${{ secrets.CREDS_IV }}
BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }}
BAN_DURATION: ${{ secrets.BAN_DURATION }}
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
NODE_ENV: ci
steps:
- uses: actions/checkout@v2

1
.gitignore vendored
View file

@ -3,6 +3,7 @@
# Logs
data-node
meili_data
data/
logs
*.log

View file

@ -96,7 +96,8 @@ Keep up with the latest updates by visiting the releases page - [Releases](https
* [Using official ChatGPT Plugins](docs/features/plugins/chatgpt_plugins_openapi.md)
* [Third-Party Tools](docs/features/third-party.md)
* [Automated Moderation](docs/features/mod_system.md)
* [Third-Party Tools](docs/features/third_party.md)
* [Proxy](docs/features/proxy.md)
* [Bing Jailbreak](docs/features/bing_jailbreak.md)
</details>

View file

@ -1,7 +1,14 @@
const fs = require('fs');
const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin');
jest.mock('node-fetch');
global.fetch = jest.fn().mockImplementationOnce(() => {
return new Promise((resolve) => {
resolve({
ok: true,
json: () => Promise.resolve({ key: 'value' }),
});
});
});
jest.mock('fs', () => ({
promises: {
readFile: jest.fn(),

68
api/cache/banViolation.js vendored Normal file
View file

@ -0,0 +1,68 @@
const Session = require('../models/Session');
const getLogStores = require('./getLogStores');
const { isEnabled, math, removePorts } = require('../server/utils');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
const interval = math(BAN_INTERVAL, 20);
/**
* Bans a user based on violation criteria.
*
* If the user's violation count is a multiple of the BAN_INTERVAL, the user will be banned.
* The duration of the ban is determined by the BAN_DURATION environment variable.
* If BAN_DURATION is not set or invalid, the user will not be banned.
* Sessions will be deleted and the refreshToken cookie will be cleared even with
* an invalid or nill duration, which is a "soft" ban; the user can remain active until
* access token expiry.
*
* @async
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {Object} errorMessage - Object containing user violation details.
* @param {string} errorMessage.type - Type of the violation.
* @param {string} errorMessage.user_id - ID of the user who committed the violation.
* @param {number} errorMessage.violation_count - Number of violations committed by the user.
*
* @returns {Promise<void>}
*
*/
const banViolation = async (req, res, errorMessage) => {
if (!isEnabled(BAN_VIOLATIONS)) {
return;
}
if (!errorMessage) {
return;
}
const { type, user_id, prev_count, violation_count } = errorMessage;
const prevThreshold = Math.floor(prev_count / interval);
const currentThreshold = Math.floor(violation_count / interval);
if (prevThreshold >= currentThreshold) {
return;
}
await Session.deleteAllUserSessions(user_id);
res.clearCookie('refreshToken');
const banLogs = getLogStores('ban');
const duration = banLogs.opts.ttl;
if (duration <= 0) {
return;
}
req.ip = removePorts(req);
console.log(`[BAN] Banning user ${user_id} @ ${req.ip} for ${duration / 1000 / 60} minutes`);
const expiresAt = Date.now() + duration;
await banLogs.set(user_id, { type, violation_count, duration, expiresAt });
await banLogs.set(req.ip, { type, user_id, violation_count, duration, expiresAt });
errorMessage.ban = true;
errorMessage.ban_duration = duration;
return;
};
module.exports = banViolation;

155
api/cache/banViolation.spec.js vendored Normal file
View file

@ -0,0 +1,155 @@
const banViolation = require('./banViolation');
jest.mock('keyv');
jest.mock('../models/Session');
// Mocking the getLogStores function
jest.mock('./getLogStores', () => {
return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const math = require('../server/utils/math');
const mockGet = jest.fn();
const mockSet = jest.fn();
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = mockGet;
set = mockSet;
}
return new KeyvMongo('', {
namespace: 'bans',
ttl: math(process.env.BAN_DURATION, 7200000),
});
});
});
describe('banViolation', () => {
let req, res, errorMessage;
beforeEach(() => {
req = {
ip: '127.0.0.1',
cookies: {
refreshToken: 'someToken',
},
};
res = {
clearCookie: jest.fn(),
};
errorMessage = {
type: 'someViolation',
user_id: '12345',
prev_count: 0,
violation_count: 0,
};
process.env.BAN_VIOLATIONS = 'true';
process.env.BAN_DURATION = '7200000'; // 2 hours in ms
process.env.BAN_INTERVAL = '20';
});
afterEach(() => {
jest.clearAllMocks();
});
it('should not ban if BAN_VIOLATIONS are not enabled', async () => {
process.env.BAN_VIOLATIONS = 'false';
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeFalsy();
});
it('should not ban if errorMessage is not provided', async () => {
await banViolation(req, res, null);
expect(errorMessage.ban).toBeFalsy();
});
it('[1/3] should ban if violation_count crosses the interval threshold: 19 -> 39', async () => {
errorMessage.prev_count = 19;
errorMessage.violation_count = 39;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeTruthy();
});
it('[2/3] should ban if violation_count crosses the interval threshold: 19 -> 20', async () => {
errorMessage.prev_count = 19;
errorMessage.violation_count = 20;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeTruthy();
});
const randomValueAbove = Math.floor(20 + Math.random() * 100);
it(`[3/3] should ban if violation_count crosses the interval threshold: 19 -> ${randomValueAbove}`, async () => {
errorMessage.prev_count = 19;
errorMessage.violation_count = randomValueAbove;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeTruthy();
});
it('should handle invalid BAN_INTERVAL and default to 20', async () => {
process.env.BAN_INTERVAL = 'invalid';
errorMessage.prev_count = 19;
errorMessage.violation_count = 39;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeTruthy();
});
it('should ban if BAN_DURATION is invalid as default is 2 hours', async () => {
process.env.BAN_DURATION = 'invalid';
errorMessage.prev_count = 19;
errorMessage.violation_count = 39;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeTruthy();
});
it('should not ban if BAN_DURATION is 0 but should clear cookies', async () => {
process.env.BAN_DURATION = '0';
errorMessage.prev_count = 19;
errorMessage.violation_count = 39;
await banViolation(req, res, errorMessage);
expect(res.clearCookie).toHaveBeenCalledWith('refreshToken');
});
it('should not ban if violation_count does not change', async () => {
errorMessage.prev_count = 0;
errorMessage.violation_count = 0;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeFalsy();
});
it('[1/2] should not ban if violation_count does not cross the interval threshold: 0 -> 19', async () => {
errorMessage.prev_count = 0;
errorMessage.violation_count = 19;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeFalsy();
});
const randomValueUnder = Math.floor(1 + Math.random() * 19);
it(`[2/2] should not ban if violation_count does not cross the interval threshold: 0 -> ${randomValueUnder}`, async () => {
errorMessage.prev_count = 0;
errorMessage.violation_count = randomValueUnder;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeFalsy();
});
it('[EDGE CASE] should not ban if violation_count is lower', async () => {
errorMessage.prev_count = 0;
errorMessage.violation_count = -10;
await banViolation(req, res, errorMessage);
expect(errorMessage.ban).toBeFalsy();
});
});

29
api/cache/clearPendingReq.js vendored Normal file
View file

@ -0,0 +1,29 @@
const Keyv = require('keyv');
const { pendingReqFile } = require('./keyvFiles');
const { LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
const keyv = new Keyv({ store: pendingReqFile, namespace: 'pendingRequests' });
/**
* Clear pending requests from the cache.
* Checks the environmental variable LIMIT_CONCURRENT_MESSAGES;
* if the rule is enabled ('true'), pending requests in the cache are cleared.
*
* @module clearPendingReq
* @requires keyv
* @requires keyvFiles
* @requires process
*
* @async
* @function
* @returns {Promise<void>} A promise that either clears 'pendingRequests' from store or resolves with no value.
*/
const clearPendingReq = async () => {
if (LIMIT_CONCURRENT_MESSAGES?.toLowerCase() !== 'true') {
return;
}
await keyv.clear();
};
module.exports = clearPendingReq;

40
api/cache/getLogStores.js vendored Normal file
View file

@ -0,0 +1,40 @@
const Keyv = require('keyv');
const keyvMongo = require('./keyvMongo');
const { math } = require('../server/utils');
const { logFile, violationFile } = require('./keyvFiles');
const { BAN_DURATION } = process.env ?? {};
const duration = math(BAN_DURATION, 7200000);
const namespaces = {
ban: new Keyv({ store: keyvMongo, ttl: duration, namespace: 'bans' }),
general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: new Keyv({ store: violationFile, namespace: 'concurrent' }),
non_browser: new Keyv({ store: violationFile, namespace: 'non_browser' }),
message_limit: new Keyv({ store: violationFile, namespace: 'message_limit' }),
registrations: new Keyv({ store: violationFile, namespace: 'registrations' }),
logins: new Keyv({ store: violationFile, namespace: 'logins' }),
};
/**
* Returns either the logs of violations specified by type if a type is provided
* or it returns the general log if no type is specified. If an invalid type is passed,
* an error will be thrown.
*
* @module getLogStores
* @requires keyv - a simple key-value storage that allows you to easily switch out storage adapters.
* @requires keyvFiles - a module that includes the logFile and violationFile.
*
* @param {string} type - The type of violation, which can be 'concurrent', 'message_limit', 'registrations' or 'logins'.
* @returns {Keyv} - If a valid type is passed, returns an object containing the logs for violations of the specified type.
* @throws Will throw an error if an invalid violation type is passed.
*/
const getLogStores = (type) => {
if (!type) {
throw new Error(`Invalid store type: ${type}`);
}
const logs = namespaces[type];
return logs;
};
module.exports = getLogStores;

6
api/cache/index.js vendored Normal file
View file

@ -0,0 +1,6 @@
const keyvFiles = require('./keyvFiles');
const getLogStores = require('./getLogStores');
const logViolation = require('./logViolation');
const clearPendingReq = require('./clearPendingReq');
module.exports = { ...keyvFiles, getLogStores, logViolation, clearPendingReq };

11
api/cache/keyvFiles.js vendored Normal file
View file

@ -0,0 +1,11 @@
const { KeyvFile } = require('keyv-file');
const logFile = new KeyvFile({ filename: './data/logs.json' });
const pendingReqFile = new KeyvFile({ filename: './data/pendingReqCache.json' });
const violationFile = new KeyvFile({ filename: './data/violations.json' });
module.exports = {
logFile,
pendingReqFile,
violationFile,
};

7
api/cache/keyvMongo.js vendored Normal file
View file

@ -0,0 +1,7 @@
const KeyvMongo = require('@keyv/mongo');
const { MONGO_URI } = process.env ?? {};
const keyvMongo = new KeyvMongo(MONGO_URI, { collection: 'logs' });
keyvMongo.on('error', (err) => console.error('KeyvMongo connection error:', err));
module.exports = keyvMongo;

36
api/cache/logViolation.js vendored Normal file
View file

@ -0,0 +1,36 @@
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
/**
* Logs the violation.
*
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {string} type - The type of violation.
* @param {Object} errorMessage - The error message to log.
* @param {number} [score=1] - The severity of the violation. Defaults to 1
*/
const logViolation = async (req, res, type, errorMessage, score = 1) => {
const userId = req.user?.id ?? req.user?._id;
if (!userId) {
return;
}
const logs = getLogStores('general');
const violationLogs = getLogStores(type);
const userViolations = (await violationLogs.get(userId)) ?? 0;
const violationCount = userViolations + score;
await violationLogs.set(userId, violationCount);
errorMessage.user_id = userId;
errorMessage.prev_count = userViolations;
errorMessage.violation_count = violationCount;
errorMessage.date = new Date().toISOString();
await banViolation(req, res, errorMessage);
const userLogs = (await logs.get(userId)) ?? [];
userLogs.push(errorMessage);
await logs.set(userId, userLogs);
};
module.exports = logViolation;

View file

@ -3,5 +3,5 @@ module.exports = {
clearMocks: true,
roots: ['<rootDir>'],
coverageDirectory: 'coverage',
setupFiles: ['./test/jestSetup.js'],
setupFiles: ['./test/jestSetup.js', './test/__mocks__/KeyvMongo.js'],
};

4
api/lib/db/index.js Normal file
View file

@ -0,0 +1,4 @@
const connectDb = require('./connectDb');
const indexSync = require('./indexSync');
module.exports = { connectDb, indexSync };

View file

@ -1,5 +1,8 @@
const { z } = require('zod');
const Message = require('./schema/messageSchema');
const idSchema = z.string().uuid();
module.exports = {
Message,
@ -22,8 +25,9 @@ module.exports = {
model = null,
}) {
try {
if (!conversationId) {
return console.log('Message not saved: no conversationId');
const validConvoId = idSchema.safeParse(conversationId);
if (!validConvoId.success) {
return;
}
// may also need to update the conversation here
await Message.findOneAndUpdate(

View file

@ -54,6 +54,21 @@ sessionSchema.methods.generateRefreshToken = async function () {
}
};
sessionSchema.statics.deleteAllUserSessions = async function (userId) {
try {
if (!userId) {
return;
}
const result = await this.deleteMany({ user: userId });
if (result && result?.deletedCount > 0) {
console.log(`Deleted ${result.deletedCount} sessions for user ${userId}.`);
}
} catch (error) {
console.log('Error in deleting user sessions:', error);
throw error;
}
};
const Session = mongoose.model('Session', sessionSchema);
module.exports = Session;

View file

@ -6,6 +6,7 @@
"start": "echo 'please run this from the root directory'",
"server-dev": "echo 'please run this from the root directory'",
"test": "cross-env NODE_ENV=test jest",
"b:test": "NODE_ENV=test bun jest",
"test:ci": "jest --ci"
},
"repository": {
@ -45,7 +46,7 @@
"joi": "^17.9.2",
"js-yaml": "^4.1.0",
"jsonwebtoken": "^9.0.0",
"keyv": "^4.5.2",
"keyv": "^4.5.3",
"keyv-file": "^0.2.0",
"langchain": "^0.0.144",
"lodash": "^4.17.21",
@ -64,6 +65,7 @@
"pino": "^8.12.1",
"sanitize": "^2.1.2",
"sharp": "^0.32.5",
"ua-parser-js": "^1.0.36",
"zod": "^3.22.2"
},
"devDependencies": {

View file

@ -80,7 +80,7 @@ const refreshController = async (req, res) => {
const userId = payload.id;
const user = await User.findOne({ _id: userId });
if (!user) {
return res.status(401).send('User not found');
return res.status(401).redirect('/login');
}
if (process.env.NODE_ENV === 'development') {
@ -99,6 +99,8 @@ const refreshController = async (req, res) => {
const token = await setAuthTokens(userId, res, session._id);
const userObj = user.toJSON();
res.status(200).send({ token, user: userObj });
} else if (payload.exp > Date.now() / 1000) {
res.status(403).redirect('/login');
} else {
res.status(401).send('Refresh token expired or not found for this user');
}

View file

@ -1,16 +1,16 @@
const express = require('express');
const mongoSanitize = require('express-mongo-sanitize');
const connectDb = require('../lib/db/connectDb');
const indexSync = require('../lib/db/indexSync');
const { connectDb, indexSync } = require('../lib/db');
const path = require('path');
const cors = require('cors');
const routes = require('./routes');
const errorController = require('./controllers/ErrorController');
const passport = require('passport');
const configureSocialLogins = require('./socialLogins');
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
const port = Number(process.env.PORT) || 3080;
const host = process.env.HOST || 'localhost';
const port = Number(PORT) || 3080;
const host = HOST || 'localhost';
const projectPath = path.join(__dirname, '..', '..', 'client');
const { jwtLogin, passportLogin } = require('../strategies');
@ -31,7 +31,7 @@ const startServer = async () => {
app.set('trust proxy', 1); // trust first proxy
app.use(cors());
if (!process.env.ALLOW_SOCIAL_LOGIN) {
if (!ALLOW_SOCIAL_LOGIN) {
console.warn(
'Social logins are disabled. Set Envrionment Variable "ALLOW_SOCIAL_LOGIN" to true to enable them.',
);
@ -42,7 +42,7 @@ const startServer = async () => {
passport.use(await jwtLogin());
passport.use(passportLogin());
if (process.env.ALLOW_SOCIAL_LOGIN === 'true') {
if (ALLOW_SOCIAL_LOGIN?.toLowerCase() === 'true') {
configureSocialLogins(app);
}

View file

@ -1,6 +1,5 @@
const crypto = require('crypto');
const { saveMessage, getConvo, getConvoTitle } = require('../../models');
const { sendMessage, handleError } = require('../utils');
const { sendMessage, sendError } = require('../utils');
const abortControllers = require('./abortControllers');
async function abortMessage(req, res) {
@ -27,8 +26,9 @@ const handleAbort = () => {
};
};
const createAbortController = (res, req, endpointOption, getAbortData) => {
const createAbortController = (req, res, getAbortData) => {
const abortController = new AbortController();
const { endpointOption } = req.body;
const onStart = (userMessage) => {
sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
@ -73,25 +73,23 @@ const handleAbortError = async (res, req, error, data) => {
const { sender, conversationId, messageId, parentMessageId, partialText } = data;
const respondWithError = async () => {
const errorMessage = {
const options = {
sender,
messageId: messageId ?? crypto.randomUUID(),
messageId,
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
final: true,
text: error.message,
isCreatedByUser: false,
shouldSaveMessage: true,
};
if (abortControllers.has(conversationId)) {
const { abortController } = abortControllers.get(conversationId);
abortController.abort();
abortControllers.delete(conversationId);
}
await saveMessage(errorMessage);
handleError(res, errorMessage);
const callback = async () => {
if (abortControllers.has(conversationId)) {
const { abortController } = abortControllers.get(conversationId);
abortController.abort();
abortControllers.delete(conversationId);
}
};
await sendError(res, options, callback);
};
if (partialText && partialText.length > 5) {

View file

@ -0,0 +1,92 @@
const Keyv = require('keyv');
const uap = require('ua-parser-js');
const { getLogStores } = require('../../cache');
const denyRequest = require('./denyRequest');
const { isEnabled, removePorts } = require('../utils');
const banCache = new Keyv({ namespace: 'bans', ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.';
/**
* Respond to the request if the user is banned.
*
* @async
* @function
* @param {Object} req - Express Request object.
* @param {Object} res - Express Response object.
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
*
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
*/
const banResponse = async (req, res) => {
const ua = uap(req.headers['user-agent']);
const { baseUrl } = req;
if (!ua.browser.name) {
return res.status(403).json({ message });
} else if (baseUrl === '/api/ask' || baseUrl === '/api/edit') {
return await denyRequest(req, res, { type: 'ban' });
}
return res.status(403).json({ message });
};
/**
* Checks if the source IP or user is banned or not.
*
* @async
* @function
* @param {Object} req - Express request object.
* @param {Object} res - Express response object.
* @param {Function} next - Next middleware function.
*
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`.
*/
const checkBan = async (req, res, next = () => {}) => {
const { BAN_VIOLATIONS } = process.env ?? {};
if (!isEnabled(BAN_VIOLATIONS)) {
return next();
}
req.ip = removePorts(req);
const userId = req.user?.id ?? req.user?._id ?? null;
const cachedIPBan = await banCache.get(req.ip);
const cachedUserBan = await banCache.get(userId);
const cachedBan = cachedIPBan || cachedUserBan;
if (cachedBan) {
req.banned = true;
return await banResponse(req, res);
}
const banLogs = getLogStores('ban');
const duration = banLogs.opts.ttl;
if (duration <= 0) {
return next();
}
const ipBan = await banLogs.get(req.ip);
const userBan = await banLogs.get(userId);
const isBanned = ipBan || userBan;
if (!isBanned) {
return next();
}
const timeLeft = Number(isBanned.expiresAt) - Date.now();
if (timeLeft <= 0) {
await banLogs.delete(req.ip);
await banLogs.delete(userId);
return next();
}
banCache.set(req.ip, isBanned, timeLeft);
banCache.set(userId, isBanned, timeLeft);
req.banned = true;
return await banResponse(req, res);
};
module.exports = checkBan;

View file

@ -0,0 +1,81 @@
const Keyv = require('keyv');
const { logViolation } = require('../../cache');
const denyRequest = require('./denyRequest');
// Serve cache from memory so no need to clear it on startup/exit
const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
/**
* Middleware to limit concurrent requests for a user.
*
* This middleware checks if a user has exceeded a specified concurrent request limit.
* If the user exceeds the limit, an error is returned. If the user is within the limit,
* their request count is incremented. After the request is processed, the count is decremented.
* If the `pendingReqCache` store is not available, the middleware will skip its logic.
*
* @function
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {function} next - Express next middleware function.
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
*/
const concurrentLimiter = async (req, res, next) => {
if (!pendingReqCache) {
return next();
}
if (Object.keys(req?.body ?? {}).length === 1 && req?.body?.abortKey) {
return next();
}
const { CONCURRENT_MESSAGE_MAX = 1, CONCURRENT_VIOLATION_SCORE: score } = process.env;
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = 'concurrent';
const userId = req.user?.id ?? req.user?._id ?? null;
const pendingRequests = (await pendingReqCache.get(userId)) ?? 0;
if (pendingRequests >= limit) {
const errorMessage = {
type,
limit,
pendingRequests,
};
await logViolation(req, res, type, errorMessage, score);
return await denyRequest(req, res, errorMessage);
} else {
await pendingReqCache.set(userId, pendingRequests + 1);
}
// Ensure the requests are removed from the store once the request is done
const cleanUp = async () => {
if (!pendingReqCache) {
return;
}
const currentRequests = await pendingReqCache.get(userId);
if (currentRequests && currentRequests >= 1) {
await pendingReqCache.set(userId, currentRequests - 1);
} else {
await pendingReqCache.delete(userId);
}
};
if (pendingRequests < limit) {
res.on('finish', cleanUp);
res.on('close', cleanUp);
}
next();
};
// if cache is not served from memory, clear it on exit
// process.on('exit', async () => {
// console.log('Clearing all pending requests before exiting...');
// await pendingReqCache.clear();
// });
module.exports = concurrentLimiter;

View file

@ -0,0 +1,58 @@
const crypto = require('crypto');
const { sendMessage, sendError } = require('../utils');
const { getResponseSender } = require('../routes/endpoints/schemas');
const { saveMessage } = require('../../models');
/**
* Denies a request by sending an error message and optionally saves the user's message.
*
* @async
* @function
* @param {Object} req - Express request object.
* @param {Object} req.body - The body of the request.
* @param {string} [req.body.messageId] - The ID of the message.
* @param {string} [req.body.conversationId] - The ID of the conversation.
* @param {string} [req.body.parentMessageId] - The ID of the parent message.
* @param {string} req.body.text - The text of the message.
* @param {Object} res - Express response object.
* @param {string} errorMessage - The error message to be sent.
* @returns {Promise<Object>} A promise that resolves with the error response.
* @throws {Error} Throws an error if there's an issue saving the message or sending the error.
*/
const denyRequest = async (req, res, errorMessage) => {
let responseText = errorMessage;
if (typeof errorMessage === 'object') {
responseText = JSON.stringify(errorMessage);
}
const { messageId, conversationId: _convoId, parentMessageId, text } = req.body;
const conversationId = _convoId ?? crypto.randomUUID();
const userMessage = {
sender: 'User',
messageId: messageId ?? crypto.randomUUID(),
parentMessageId,
conversationId,
isCreatedByUser: true,
text,
};
sendMessage(res, { message: userMessage, created: true });
const shouldSaveMessage =
_convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000';
if (shouldSaveMessage) {
await saveMessage(userMessage);
}
return await sendError(res, {
sender: getResponseSender(req.body),
messageId: crypto.randomUUID(),
conversationId,
parentMessageId: userMessage.messageId,
text: responseText,
shouldSaveMessage,
});
};
module.exports = denyRequest;

View file

@ -1,22 +1,30 @@
const abortMiddleware = require('./abortMiddleware');
const checkBan = require('./checkBan');
const uaParser = require('./uaParser');
const setHeaders = require('./setHeaders');
const loginLimiter = require('./loginLimiter');
const requireJwtAuth = require('./requireJwtAuth');
const registerLimiter = require('./registerLimiter');
const messageLimiters = require('./messageLimiters');
const requireLocalAuth = require('./requireLocalAuth');
const validateEndpoint = require('./validateEndpoint');
const concurrentLimiter = require('./concurrentLimiter');
const validateMessageReq = require('./validateMessageReq');
const buildEndpointOption = require('./buildEndpointOption');
const validateRegistration = require('./validateRegistration');
module.exports = {
...abortMiddleware,
...messageLimiters,
checkBan,
uaParser,
setHeaders,
loginLimiter,
requireJwtAuth,
registerLimiter,
requireLocalAuth,
validateEndpoint,
concurrentLimiter,
validateMessageReq,
buildEndpointOption,
validateRegistration,

View file

@ -1,16 +1,30 @@
const rateLimit = require('express-rate-limit');
const windowMs = (process.env?.LOGIN_WINDOW ?? 5) * 60 * 1000; // default: 5 minutes
const max = process.env?.LOGIN_MAX ?? 7; // default: limit each IP to 7 requests per windowMs
const { logViolation } = require('../../cache');
const { removePorts } = require('../utils');
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000;
const max = LOGIN_MAX;
const windowInMinutes = windowMs / 60000;
const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`;
const handler = async (req, res) => {
const type = 'logins';
const errorMessage = {
type,
max,
windowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
return res.status(429).json({ message });
};
const loginLimiter = rateLimit({
windowMs,
max,
message: `Too many login attempts from this IP, please try again after ${windowInMinutes} minutes.`,
keyGenerator: function (req) {
// Strip out the port number from the IP address
return req.ip.replace(/:\d+[^:]*$/, '');
},
handler,
keyGenerator: removePorts,
});
module.exports = loginLimiter;

View file

@ -0,0 +1,67 @@
const rateLimit = require('express-rate-limit');
const { logViolation } = require('../../cache');
const denyRequest = require('./denyRequest');
const {
MESSAGE_IP_MAX = 40,
MESSAGE_IP_WINDOW = 1,
MESSAGE_USER_MAX = 40,
MESSAGE_USER_WINDOW = 1,
} = process.env;
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
const ipMax = MESSAGE_IP_MAX;
const ipWindowInMinutes = ipWindowMs / 60000;
const userWindowMs = MESSAGE_USER_WINDOW * 60 * 1000;
const userMax = MESSAGE_USER_MAX;
const userWindowInMinutes = userWindowMs / 60000;
/**
* Creates either an IP/User message request rate limiter for excessive requests
* that properly logs and denies the violation.
*
* @param {boolean} [ip=true] - Whether to create an IP limiter or a user limiter.
* @returns {function} A rate limiter function.
*
*/
const createHandler = (ip = true) => {
return async (req, res) => {
const type = 'message_limit';
const errorMessage = {
type,
max: ip ? ipMax : userMax,
limiter: ip ? 'ip' : 'user',
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
};
await logViolation(req, res, type, errorMessage);
return await denyRequest(req, res, errorMessage);
};
};
/**
* Message request rate limiter by IP
*/
const messageIpLimiter = rateLimit({
windowMs: ipWindowMs,
max: ipMax,
handler: createHandler(),
});
/**
* Message request rate limiter by userId
*/
const messageUserLimiter = rateLimit({
windowMs: userWindowMs,
max: userMax,
handler: createHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
});
module.exports = {
messageIpLimiter,
messageUserLimiter,
};

View file

@ -1,16 +1,30 @@
const rateLimit = require('express-rate-limit');
const windowMs = (process.env?.REGISTER_WINDOW ?? 60) * 60 * 1000; // default: 1 hour
const max = process.env?.REGISTER_MAX ?? 5; // default: limit each IP to 5 registrations per windowMs
const { logViolation } = require('../../cache');
const { removePorts } = require('../utils');
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000;
const max = REGISTER_MAX;
const windowInMinutes = windowMs / 60000;
const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`;
const handler = async (req, res) => {
const type = 'registrations';
const errorMessage = {
type,
max,
windowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
return res.status(429).json({ message });
};
const registerLimiter = rateLimit({
windowMs,
max,
message: `Too many accounts created from this IP, please try again after ${windowInMinutes} minutes`,
keyGenerator: function (req) {
// Strip out the port number from the IP address
return req.ip.replace(/:\d+[^:]*$/, '');
},
handler,
keyGenerator: removePorts,
});
module.exports = registerLimiter;

View file

@ -0,0 +1,31 @@
const uap = require('ua-parser-js');
const { handleError } = require('../utils');
const { logViolation } = require('../../cache');
/**
* Middleware to parse User-Agent header and check if it's from a recognized browser.
* If the User-Agent is not recognized as a browser, logs a violation and sends an error response.
*
* @function
* @async
* @param {Object} req - Express request object.
* @param {Object} res - Express response object.
* @param {Function} next - Express next middleware function.
* @returns {void} Sends an error response if the User-Agent is not recognized as a browser.
*
* @example
* app.use(uaParser);
*/
async function uaParser(req, res, next) {
const { NON_BROWSER_VIOLATION_SCORE: score = 20 } = process.env;
const ua = uap(req.headers['user-agent']);
if (!ua.browser.name) {
const type = 'non_browser';
await logViolation(req, res, type, { type }, score);
return handleError(res, { message: 'Illegal request' });
}
next();
}
module.exports = uaParser;

View file

@ -7,138 +7,125 @@ const {
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = data.userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
getIds,
// debug: true,
user: req.user.id,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
onStart,
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveConvo(req.user.id, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'anthropic',
});
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = data.userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
},
);
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
getIds,
// debug: true,
user: req.user.id,
conversationId,
parentMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
onStart,
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveConvo(req.user.id, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'anthropic',
});
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router;

View file

@ -4,9 +4,9 @@ const router = express.Router();
const { browserClient } = require('../../../app/');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils');
const { requireJwtAuth, setHeaders } = require('../../middleware');
const { setHeaders } = require('../../middleware');
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
router.post('/', setHeaders, async (req, res) => {
const {
endpoint,
text,

View file

@ -4,9 +4,9 @@ const router = express.Router();
const { titleConvoBing, askBing } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('../../utils');
const { requireJwtAuth, setHeaders } = require('../../middleware');
const { setHeaders } = require('../../middleware');
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
router.post('/', setHeaders, async (req, res) => {
const {
endpoint,
text,

View file

@ -5,9 +5,9 @@ const { GoogleClient } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress } = require('../../utils');
const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService');
const { requireJwtAuth, setHeaders } = require('../../middleware');
const { setHeaders } = require('../../middleware');
router.post('/', requireJwtAuth, setHeaders, async (req, res) => {
router.post('/', setHeaders, async (req, res) => {
const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;
if (text.length === 0) {
return handleError(res, { text: 'Prompt empty or too short' });

View file

@ -11,218 +11,205 @@ const {
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const newConvo = !conversationId;
const user = req.user.id;
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const newConvo = !conversationId;
const user = req.user.id;
const plugins = [];
const plugins = [];
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
};
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
};
let streaming = null;
let timer = null;
let streaming = null;
let timer = null;
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (timer) {
clearTimeout(timer);
}
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
plugins,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
resolve();
}, 250);
});
},
});
const pluginMap = new Map();
const onAgentAction = async (action, runId) => {
pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, { plugins });
};
const onToolStart = async (tool, input, runId, parentRunId) => {
const pluginName = pluginMap.get(parentRunId);
const latestPlugin = {
runId,
loading: true,
inputs: [input],
latest: pluginName,
outputs: null,
};
if (streaming) {
await streaming;
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(res, { plugins }, extraTokens);
};
const onToolEnd = async (output, runId) => {
if (streaming) {
await streaming;
if (timer) {
clearTimeout(timer);
}
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
if (pluginIndex !== -1) {
plugins[pluginIndex].loading = false;
plugins[pluginIndex].outputs = output;
}
};
const onChainEnd = () => {
saveMessage(userMessage);
sendIntermediateMessage(res, { plugins });
};
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugins: plugins.map((p) => ({ ...p, loading: false })),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getIds,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
onStart,
addMetadata,
getPartialText,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
plugins,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
addTitle(req, {
text,
response,
client,
});
}
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
if (saveDelay < 500) {
saveDelay = 500;
}
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
resolve();
}, 250);
});
},
});
const pluginMap = new Map();
const onAgentAction = async (action, runId) => {
pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, { plugins });
};
const onToolStart = async (tool, input, runId, parentRunId) => {
const pluginName = pluginMap.get(parentRunId);
const latestPlugin = {
runId,
loading: true,
inputs: [input],
latest: pluginName,
outputs: null,
};
if (streaming) {
await streaming;
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(res, { plugins }, extraTokens);
};
const onToolEnd = async (output, runId) => {
if (streaming) {
await streaming;
}
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
if (pluginIndex !== -1) {
plugins[pluginIndex].loading = false;
plugins[pluginIndex].outputs = output;
}
};
const onChainEnd = () => {
saveMessage(userMessage);
sendIntermediateMessage(res, { plugins });
};
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugins: plugins.map((p) => ({ ...p, loading: false })),
userMessage,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getIds,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
onStart,
addMetadata,
getPartialText,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
plugins,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
addTitle(req, {
text,
response,
client,
});
}
},
);
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router;

View file

@ -6,6 +6,33 @@ const bingAI = require('./bingAI');
const gptPlugins = require('./gptPlugins');
const askChatGPTBrowser = require('./askChatGPTBrowser');
const anthropic = require('./anthropic');
const {
uaParser,
checkBan,
requireJwtAuth,
concurrentLimiter,
messageIpLimiter,
messageUserLimiter,
} = require('../../middleware');
const { isEnabled } = require('../../utils');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(['/azureOpenAI', '/openAI'], openAI);
router.use('/google', google);

View file

@ -9,151 +9,138 @@ const {
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const newConvo = !conversationId;
const user = req.user.id;
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('ask log');
console.dir({ text, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let userMessageId;
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const newConvo = !conversationId;
const user = req.user.id;
const addMetadata = (data) => (metadata = data);
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
};
const getIds = (data) => {
userMessage = data.userMessage;
userMessageId = userMessage.messageId;
responseMessageId = data.responseMessageId;
if (!conversationId) {
conversationId = data.conversationId;
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
try {
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
parentMessageId,
conversationId,
overrideParentMessageId,
getIds,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
);
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
addTitle(req, {
text,
response,
client,
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
});
}
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
parentMessageId,
conversationId,
overrideParentMessageId,
getIds,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
);
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
addTitle(req, {
text,
response,
client,
});
}
},
);
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router;

View file

@ -8,6 +8,7 @@ const {
const { loginController } = require('../controllers/auth/LoginController');
const { logoutController } = require('../controllers/auth/LogoutController');
const {
checkBan,
loginLimiter,
registerLimiter,
requireJwtAuth,
@ -19,9 +20,9 @@ const router = express.Router();
//Local
router.post('/logout', requireJwtAuth, logoutController);
router.post('/login', loginLimiter, requireLocalAuth, loginController);
router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController);
router.post('/refresh', refreshController);
router.post('/register', registerLimiter, validateRegistration, registrationController);
router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController);
router.post('/requestPasswordReset', resetPasswordRequestController);
router.post('/resetPassword', resetPasswordController);

View file

@ -4,12 +4,14 @@ const { getConvo, saveConvo } = require('../../models');
const { getConvosByPage, deleteConvos } = require('../../models/Conversation');
const requireJwtAuth = require('../middleware/requireJwtAuth');
router.get('/', requireJwtAuth, async (req, res) => {
router.use(requireJwtAuth);
router.get('/', async (req, res) => {
const pageNumber = req.query.pageNumber || 1;
res.status(200).send(await getConvosByPage(req.user.id, pageNumber));
});
router.get('/:conversationId', requireJwtAuth, async (req, res) => {
router.get('/:conversationId', async (req, res) => {
const { conversationId } = req.params;
const convo = await getConvo(req.user.id, conversationId);
@ -20,7 +22,7 @@ router.get('/:conversationId', requireJwtAuth, async (req, res) => {
}
});
router.post('/clear', requireJwtAuth, async (req, res) => {
router.post('/clear', async (req, res) => {
let filter = {};
const { conversationId, source } = req.body.arg;
if (conversationId) {
@ -43,7 +45,7 @@ router.post('/clear', requireJwtAuth, async (req, res) => {
}
});
router.post('/update', requireJwtAuth, async (req, res) => {
router.post('/update', async (req, res) => {
const update = req.body.arg;
try {

View file

@ -7,140 +7,127 @@ const {
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
responseMessageId = data.responseMessageId;
};
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
responseMessageId = data.responseMessageId;
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user: req.user.id,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
getIds,
onStart,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
text: partialText,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
});
}
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
try {
const getAbortData = () => ({
conversationId,
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
const { abortController, onStart } = createAbortController(req, res, getAbortData);
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user: req.user.id,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId ?? userMessageId,
}),
getIds,
onStart,
addMetadata,
abortController,
});
if (metadata) {
response = { ...response, ...metadata };
}
},
);
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router;

View file

@ -10,183 +10,170 @@ const {
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const user = req.user.id;
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
responseMessageId = data.responseMessageId;
};
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
responseMessageId = data.responseMessageId;
};
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (plugin.loading === true) {
plugin.loading = false;
}
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage(userMessage);
if (plugin.loading === true) {
plugin.loading = false;
}
sendIntermediateMessage(res, { plugin });
// console.log('PLUGIN ACTION', formattedAction);
};
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(userMessage);
sendIntermediateMessage(res, { plugin });
// console.log('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getIds,
onAgentAction,
onChainEnd,
onStart,
addMetadata,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
plugin,
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
});
}
if (metadata) {
response = { ...response, ...metadata };
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugin = { ...plugin, loading: false };
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage(userMessage);
}
},
);
sendIntermediateMessage(res, { plugin });
// console.log('PLUGIN ACTION', formattedAction);
};
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(userMessage);
sendIntermediateMessage(res, { plugin });
// console.log('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getIds,
onAgentAction,
onChainEnd,
onStart,
addMetadata,
...endpointOption,
onProgress: progressCallback.call(null, {
res,
text,
plugin,
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
if (metadata) {
response = { ...response, ...metadata };
}
console.log('CLIENT RESPONSE');
console.dir(response, { depth: null });
response.plugin = { ...plugin, loading: false };
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router;

View file

@ -3,11 +3,36 @@ const router = express.Router();
const openAI = require('./openAI');
const gptPlugins = require('./gptPlugins');
const anthropic = require('./anthropic');
// const google = require('./google');
const {
checkBan,
uaParser,
requireJwtAuth,
concurrentLimiter,
messageIpLimiter,
messageUserLimiter,
} = require('../../middleware');
const { isEnabled } = require('../../utils');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(['/azureOpenAI', '/openAI'], openAI);
router.use('/gptPlugins', gptPlugins);
router.use('/anthropic', anthropic);
// router.use('/google', google);
module.exports = router;

View file

@ -9,140 +9,127 @@ const {
createAbortController,
handleAbortError,
setHeaders,
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
router.post('/abort', requireJwtAuth, handleAbort());
router.post('/abort', handleAbort());
router.post(
'/',
requireJwtAuth,
validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
console.log('edit log');
console.dir({ text, generation, isContinued, conversationId, endpointOption }, { depth: null });
let metadata;
let userMessage;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const userMessageId = parentMessageId;
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
responseMessageId = data.responseMessageId;
};
const addMetadata = (data) => (metadata = data);
const getIds = (data) => {
userMessage = data.userMessage;
responseMessageId = data.responseMessageId;
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(
res,
req,
endpointOption,
getAbortData,
);
try {
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user: req.user.id,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getIds,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
text,
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: getResponseSender(endpointOption),
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
if (metadata) {
response = { ...response, ...metadata };
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
});
}
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
);
await saveMessage(response);
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
const getAbortData = () => ({
sender: getResponseSender(endpointOption),
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
const { client } = await initializeClient(req, endpointOption);
let response = await client.sendMessage(text, {
user: req.user.id,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getIds,
onStart,
addMetadata,
abortController,
onProgress: progressCallback.call(null, {
res,
text,
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
if (metadata) {
response = { ...response, ...metadata };
}
},
);
console.log(
'promptTokens, completionTokens:',
response.promptTokens,
response.completionTokens,
);
await saveMessage(response);
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
res.end();
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender: getResponseSender(endpointOption),
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
});
module.exports = router;

View file

@ -3,8 +3,24 @@ const express = require('express');
const router = express.Router();
const config = require('../../../config/loader');
const { setAuthTokens } = require('../services/AuthService');
const { loginLimiter, checkBan } = require('../middleware');
const domains = config.domains;
router.use(loginLimiter);
const oauthHandler = async (req, res) => {
try {
await checkBan(req, res);
if (req.banned) {
return;
}
await setAuthTokens(req.user._id, res);
res.redirect(domains.client);
} catch (err) {
console.error('Error in setting authentication tokens:', err);
}
};
/**
* Google Routes
*/
@ -24,14 +40,7 @@ router.get(
session: false,
scope: ['openid', 'profile', 'email'],
}),
async (req, res) => {
try {
await setAuthTokens(req.user._id, res);
res.redirect(domains.client);
} catch (err) {
console.error('Error in setting authentication tokens:', err);
}
},
oauthHandler,
);
router.get(
@ -52,14 +61,7 @@ router.get(
scope: ['public_profile'],
profileFields: ['id', 'email', 'name'],
}),
async (req, res) => {
try {
await setAuthTokens(req.user._id, res);
res.redirect(domains.client);
} catch (err) {
console.error('Error in setting authentication tokens:', err);
}
},
oauthHandler,
);
router.get(
@ -76,14 +78,7 @@ router.get(
failureMessage: true,
session: false,
}),
async (req, res) => {
try {
await setAuthTokens(req.user._id, res);
res.redirect(domains.client);
} catch (err) {
console.error('Error in setting authentication tokens:', err);
}
},
oauthHandler,
);
router.get(
@ -102,14 +97,7 @@ router.get(
session: false,
scope: ['user:email', 'read:user'],
}),
async (req, res) => {
try {
await setAuthTokens(req.user._id, res);
res.redirect(domains.client);
} catch (err) {
console.error('Error in setting authentication tokens:', err);
}
},
oauthHandler,
);
router.get(
'/discord',
@ -127,14 +115,7 @@ router.get(
session: false,
scope: ['identify', 'email'],
}),
async (req, res) => {
try {
await setAuthTokens(req.user._id, res);
res.redirect(domains.client);
} catch (err) {
console.error('Error in setting authentication tokens:', err);
}
},
oauthHandler,
);
module.exports = router;

View file

@ -1,22 +1,11 @@
const partialRight = require('lodash/partialRight');
const citationRegex = /\[\^\d+?\^]/g;
const { getCitations, citeText } = require('./citations');
const { sendMessage } = require('./streamResponse');
const cursor = '<span className="result-streaming">█</span>';
const citationRegex = /\[\^\d+?\^]/g;
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
const handleError = (res, message) => {
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
res.end();
};
const sendMessage = (res, message, event = 'message') => {
if (message.length === 0) {
return;
}
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
};
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
let i = 0;
let code = '';
@ -148,10 +137,27 @@ function formatAction(action) {
return formattedAction;
}
/**
* Checks if the given string value is truthy by comparing it to the string 'true' (case-insensitive).
*
* @function
* @param {string|null|undefined} value - The string value to check.
* @returns {boolean} Returns `true` if the value is a case-insensitive match for the string 'true', otherwise returns `false`.
* @example
*
* isEnabled("True"); // returns true
* isEnabled("TRUE"); // returns true
* isEnabled("false"); // returns false
* isEnabled(null); // returns false
* isEnabled(); // returns false
*/
function isEnabled(value) {
return value?.toLowerCase()?.trim() === 'true';
}
module.exports = {
handleError,
sendMessage,
createOnProgress,
isEnabled,
handleText,
formatSteps,
formatAction,

View file

@ -1,11 +1,17 @@
const cryptoUtils = require('./crypto');
const streamResponse = require('./streamResponse');
const removePorts = require('./removePorts');
const handleText = require('./handleText');
const cryptoUtils = require('./crypto');
const citations = require('./citations');
const sendEmail = require('./sendEmail');
const math = require('./math');
module.exports = {
...streamResponse,
...cryptoUtils,
...handleText,
...citations,
removePorts,
sendEmail,
math,
};

48
api/server/utils/math.js Normal file
View file

@ -0,0 +1,48 @@
/**
* Evaluates a mathematical expression provided as a string and returns the result.
*
* If the input is already a number, it returns the number as is.
* If the input is not a string or contains invalid characters, an error is thrown.
* If the evaluated result is not a number, an error is thrown.
*
* @param {string|number} str - The mathematical expression to evaluate, or a number.
* @param {number} [fallbackValue] - The default value to return if the input is not a string or number, or if the evaluated result is not a number.
*
* @returns {number} The result of the evaluated expression or the input number.
*
* @throws {Error} Throws an error if the input is not a string or number, contains invalid characters, or does not evaluate to a number.
*/
function math(str, fallbackValue) {
const fallback = typeof fallbackValue !== 'undefined' && typeof fallbackValue === 'number';
if (typeof str !== 'string' && typeof str === 'number') {
return str;
} else if (typeof str !== 'string') {
if (fallback) {
return fallbackValue;
}
throw new Error(`str is ${typeof str}, but should be a string`);
}
const validStr = /^[+\-\d.\s*/%()]+$/.test(str);
if (!validStr) {
if (fallback) {
return fallbackValue;
}
throw new Error('Invalid characters in string');
}
const value = eval(str);
if (typeof value !== 'number') {
if (fallback) {
return fallbackValue;
}
console.error('str', str);
throw new Error(`str did not evaluate to a number but to a ${typeof value}`);
}
return value;
}
module.exports = math;

View file

@ -0,0 +1 @@
module.exports = (req) => req.ip.replace(/:\d+[^:]*$/, '');

View file

@ -0,0 +1,63 @@
const crypto = require('crypto');
const { saveMessage } = require('../../models');
/**
* Sends error data in Server Sent Events format and ends the response.
* @param {object} res - The server response.
* @param {string} message - The error message.
*/
const handleError = (res, message) => {
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
res.end();
};
/**
* Sends message data in Server Sent Events format.
* @param {object} res - - The server response.
* @param {string} message - The message to be sent.
* @param {string} event - [Optional] The type of event. Default is 'message'.
*/
const sendMessage = (res, message, event = 'message') => {
if (message.length === 0) {
return;
}
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
};
/**
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
* @async
* @param {object} res - The server response.
* @param {object} options - The options for handling the error containing message properties.
* @param {function} callback - [Optional] The callback function to be executed.
*/
const sendError = async (res, options, callback) => {
const { sender, conversationId, messageId, parentMessageId, text, shouldSaveMessage } = options;
const errorMessage = {
sender,
messageId: messageId ?? crypto.randomUUID(),
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
final: true,
text,
isCreatedByUser: false,
};
if (callback && typeof callback === 'function') {
await callback();
}
if (shouldSaveMessage) {
await saveMessage(errorMessage);
}
handleError(res, errorMessage);
};
module.exports = {
handleError,
sendMessage,
sendError,
};

View file

@ -7,3 +7,7 @@ CREDS_IV=cd02538f4be2fa37aba9420b5924389f
# For testing the ChatAgent
OPENAI_API_KEY=your-api-key
BAN_VIOLATIONS=true
BAN_DURATION=7200000
BAN_INTERVAL=20

View file

@ -0,0 +1,30 @@
const mockGet = jest.fn();
const mockSet = jest.fn();
jest.mock('@keyv/mongo', () => {
const EventEmitter = require('events');
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = mockGet;
set = mockSet;
}
return KeyvMongo;
});

View file

@ -9,6 +9,7 @@
"dev": "cross-env NODE_ENV=development dotenv -e ../.env -- vite",
"preview-prod": "cross-env NODE_ENV=development dotenv -e ../.env -- vite preview",
"test": "cross-env NODE_ENV=test jest --watch",
"b:test": "NODE_ENV=test bun jest --watch",
"test:ci": "cross-env NODE_ENV=test jest --ci",
"b:build": "NODE_ENV=production bun vite build",
"b:dev": "NODE_ENV=development bun vite"

View file

@ -1,4 +1,11 @@
import type { TConversation, TMessage, TPreset, TMutation } from 'librechat-data-provider';
import type {
TConversation,
TMessage,
TPreset,
TMutation,
TLoginUser,
TUser,
} from 'librechat-data-provider';
export type TSetOption = (param: number | string) => (newValue: number | string | boolean) => void;
export type TSetExample = (
@ -146,3 +153,28 @@ export type TDialogProps = {
open: boolean;
onOpenChange: (open: boolean) => void;
};
export type TResError = {
response: { data: { message: string } };
message: string;
};
export type TAuthContext = {
user: TUser | undefined;
token: string | undefined;
isAuthenticated: boolean;
error: string | undefined;
login: (data: TLoginUser) => void;
logout: () => void;
};
export type TUserContext = {
user?: TUser | undefined;
token: string | undefined;
isAuthenticated: boolean;
redirect?: string;
};
export type TAuthConfig = {
loginRedirect: string;
};

View file

@ -5,6 +5,7 @@ import { useNavigate } from 'react-router-dom';
import { useLocalize } from '~/hooks';
import { useGetStartupConfig } from 'librechat-data-provider';
import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components';
import { getLoginError } from '~/utils';
function Login() {
const { login, error, isAuthenticated } = useAuthContext();
@ -30,9 +31,7 @@ function Login() {
className="relative mt-4 rounded border border-red-400 bg-red-100 px-4 py-3 text-red-700"
role="alert"
>
{error?.includes('429')
? localize('com_auth_error_login_rl')
: localize('com_auth_error_login')}
{localize(getLoginError(error))}
</div>
)}
<LoginForm onSubmit={login} />

View file

@ -1,20 +1,28 @@
import { Fragment } from 'react';
import type { TResPlugin } from 'librechat-data-provider';
import type { TMessageContent, TText, TDisplayProps } from '~/common';
import { cn, getError } from '~/utils';
import { useAuthContext } from '~/hooks';
import { cn, getMessageError } from '~/utils';
import EditMessage from './EditMessage';
import Container from './Container';
import Markdown from './Markdown';
import Plugin from './Plugin';
// Error Message Component
const ErrorMessage = ({ text }: TText) => (
<Container>
<div className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-100">
{getError(text)}
</div>
</Container>
);
const ErrorMessage = ({ text }: TText) => {
const { logout } = useAuthContext();
if (text.includes('ban')) {
logout();
return null;
}
return (
<Container>
<div className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-100">
{getMessageError(text)}
</div>
</Container>
);
};
// Display Message Component
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => (

View file

@ -1,7 +1,7 @@
import {
useMemo,
useState,
useEffect,
useMemo,
ReactNode,
useCallback,
createContext,
@ -17,33 +17,14 @@ import {
useRefreshTokenMutation,
TLoginUser,
} from 'librechat-data-provider';
import { TAuthConfig, TUserContext, TAuthContext, TResError } from '~/common';
import { useNavigate } from 'react-router-dom';
import useTimeout from './useTimeout';
export type TAuthContext = {
user: TUser | undefined;
token: string | undefined;
isAuthenticated: boolean;
error: string | undefined;
login: (data: TLoginUser) => void;
logout: () => void;
};
export type TUserContext = {
user?: TUser | undefined;
token: string | undefined;
isAuthenticated: boolean;
redirect?: string;
};
export type TAuthConfig = {
loginRedirect: string;
};
//@ts-ignore - index expression is not of type number
window['errorTimeout'] = undefined;
const AuthContext = createContext<TAuthContext | undefined>(undefined);
const AuthContextProvider = ({
authConfig,
// authConfig,
children,
}: {
authConfig?: TAuthConfig;
@ -61,16 +42,7 @@ const AuthContextProvider = ({
const userQuery = useGetUserQuery({ enabled: !!token });
const refreshToken = useRefreshTokenMutation();
// This seems to prevent the error flashing issue
const doSetError = (error: string | undefined) => {
if (error) {
console.log(error);
// set timeout to ensure we don't get a flash of the error message
window['errorTimeout'] = setTimeout(() => {
setError(error);
}, 400);
}
};
const doSetError = useTimeout({ callback: (error) => setError(error as string | undefined) });
const setUserContext = useCallback(
(userContext: TUserContext) => {
@ -89,19 +61,15 @@ const AuthContextProvider = ({
[navigate],
);
const getCookieValue = (key: string) => {
const keyValue = document.cookie.match('(^|;) ?' + key + '=([^;]*)(;|$)');
return keyValue ? keyValue[2] : null;
};
const login = (data: TLoginUser) => {
loginUser.mutate(data, {
onSuccess: (data: TLoginResponse) => {
const { user, token } = data;
setUserContext({ token, isAuthenticated: true, user, redirect: '/chat/new' });
},
onError: (error) => {
doSetError((error as Error).message);
onError: (error: TResError | unknown) => {
const resError = error as TResError;
doSetError(resError.message);
navigate('/login', { replace: true });
},
});
@ -119,6 +87,12 @@ const AuthContextProvider = ({
},
onError: (error) => {
doSetError((error as Error).message);
setUserContext({
token: undefined,
isAuthenticated: false,
user: undefined,
redirect: '/login',
});
},
});
}, [setUserContext, logoutUser]);

View file

@ -2,6 +2,7 @@ export * from './AuthContext';
export * from './ThemeContext';
export * from './ScreenshotContext';
export * from './ApiErrorBoundaryContext';
export { default as useTimeout } from './useTimeout';
export { default as useUserKey } from './useUserKey';
export { default as useDebounce } from './useDebounce';
export { default as useLocalize } from './useLocalize';

View file

@ -0,0 +1,39 @@
import { useEffect, useRef } from 'react';
type TUseTimeoutParams = {
callback: (error: string | number | boolean | null) => void;
delay?: number | undefined;
};
type TTimeout = ReturnType<typeof setTimeout> | null;
function useTimeout({ callback, delay = 400 }: TUseTimeoutParams) {
const timeout = useRef<TTimeout>(null);
const callOnTimeout = (value: string | undefined) => {
// Clear existing timeout
if (timeout.current !== null) {
clearTimeout(timeout.current);
}
// Set new timeout
if (value) {
console.log(value);
timeout.current = setTimeout(() => {
callback(value);
}, delay);
}
};
// Clear timeout when the component unmounts
useEffect(() => {
return () => {
if (timeout.current !== null) {
clearTimeout(timeout.current);
}
};
}, []);
return callOnTimeout;
}
export default useTimeout;

View file

@ -52,7 +52,11 @@ export default {
com_auth_error_login:
'Unable to login with the information provided. Please check your credentials and try again.',
com_auth_error_login_rl:
'Too many login attempts from this IP in a short amount of time. Please try again later.',
'Too many login attempts in a short amount of time. Please try again later.',
com_auth_error_login_ban:
'Your account has been temporarily banned due to violations of our service.',
com_auth_error_login_server:
'There was an internal server error. Please wait a few moments and try again.',
com_auth_no_account: 'Don\'t have an account?',
com_auth_sign_up: 'Sign up',
com_auth_sign_in: 'Sign in',

View file

@ -1,28 +0,0 @@
const isJson = (str: string) => {
try {
JSON.parse(str);
} catch (e) {
return false;
}
return true;
};
const getError = (text: string) => {
const errorMessage = text.length > 512 ? text.slice(0, 512) + '...' : text;
const match = text.match(/\{[^{}]*\}/);
const jsonString = match ? match[0] : '';
if (isJson(jsonString)) {
const json = JSON.parse(jsonString);
if (json.code === 'invalid_api_key') {
return 'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.';
} else if (json.type === 'insufficient_quota') {
return 'We apologize for any inconvenience caused. The default API key has reached its limit. To continue using this service, please set up your own API key. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.';
} else {
return `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`;
}
} else {
return `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`;
}
};
export default getError;

View file

@ -0,0 +1,18 @@
const getLoginError = (errorText: string) => {
const defaultError = 'com_auth_error_login';
if (!errorText) {
return defaultError;
}
if (errorText?.includes('429')) {
return 'com_auth_error_login_rl';
} else if (errorText?.includes('403')) {
return 'com_auth_error_login_ban';
} else if (errorText?.includes('500')) {
return 'com_auth_error_login_server';
} else {
return defaultError;
}
};
export default getLoginError;

View file

@ -0,0 +1,62 @@
const isJson = (str: string) => {
try {
JSON.parse(str);
} catch (e) {
return false;
}
return true;
};
type TConcurrent = {
limit: number;
};
type TMessageLimit = {
max: number;
windowInMinutes: number;
};
const errorMessages = {
ban: 'Your account has been temporarily banned due to violations of our service.',
invalid_api_key:
'Invalid API key. Please check your API key and try again. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.',
insufficient_quota:
'We apologize for any inconvenience caused. The default API key has reached its limit. To continue using this service, please set up your own API key. You can do this by clicking on the model logo in the left corner of the textbox and selecting "Set Token" for the current selected endpoint. Thank you for your understanding.',
concurrent: (json: TConcurrent) => {
const { limit } = json;
const plural = limit > 1 ? 's' : '';
return `Only ${limit} message${plural} at a time. Please allow any other responses to complete before sending another message, or wait one minute.`;
},
message_limit: (json: TMessageLimit) => {
const { max, windowInMinutes } = json;
const plural = max > 1 ? 's' : '';
return `You hit the message limit. You have a cap of ${max} message${plural} per ${
windowInMinutes > 1 ? `${windowInMinutes} minutes` : 'minute'
}.`;
},
};
const getMessageError = (text: string) => {
const errorMessage = text.length > 512 ? text.slice(0, 512) + '...' : text;
const match = text.match(/\{[^{}]*\}/);
const jsonString = match ? match[0] : '';
const defaultResponse = `Something went wrong. Here's the specific error message we encountered: ${errorMessage}`;
if (!isJson(jsonString)) {
return defaultResponse;
}
const json = JSON.parse(jsonString);
const errorKey = json.code || json.type;
const keyExists = errorKey && errorMessages[errorKey];
if (keyExists && typeof errorMessages[errorKey] === 'function') {
return errorMessages[errorKey](json);
} else if (keyExists) {
return errorMessages[errorKey];
} else {
return defaultResponse;
}
};
export default getMessageError;

View file

@ -2,10 +2,11 @@ import { clsx } from 'clsx';
import { twMerge } from 'tailwind-merge';
export * from './languages';
export { default as getError } from './getError';
export { default as buildTree } from './buildTree';
export { default as getLoginError } from './getLoginError';
export { default as cleanupPreset } from './cleanupPreset';
export { default as validateIframe } from './validateIframe';
export { default as getMessageError } from './getMessageError';
export { default as getLocalStorageItems } from './getLocalStorageItems';
export { default as getDefaultConversation } from './getDefaultConversation';

View file

@ -0,0 +1,67 @@
## Automated Moderation System (optional)
The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching a set threshold, the user and their IP are temporarily banned. This system ensures platform security by monitoring and penalizing rapid or suspicious activities.
In production, you should have Cloudflare or some other DDoS protection in place to really protect the server from excessive requests, but these changes will largely protect you from the single or several bad actors targeting your deployed instance for proxying.
### Notes
- Uses Caching for basic security and violation logging (bans, concurrent messages, exceeding rate limits)
- In the near future, I will add **Redis** support for production instances, which can be easily injected into the current caching setup
- Exceeding any of the rate limiters (login/registration/messaging) is considered a violation, default score is 1
- Non-browser origin is a violation
- Default score for each violation is configurable
- Enabling any of the limiters and/or bans enables caching/logging
- Violation logs can be found in the data folder, which is created when logging begins: `librechat/data`
- **Only violations are logged**
- `violations.json` keeps track of the total count for each violation per user
- `logs.json` records each individual violation per user
- Ban logs are stored in MongoDB under the `logs` collection. They are transient as they only exist for the ban duration
- If you would like to remove a ban manually, you would have to remove them from the database manually and restart the server
- **Redis** support is also planned for this.
### Rate Limiters
The project's current rate limiters are as follows (see below under setup for default values):
- Login and registration rate limiting
- [optional] Concurrent Message limiting (only X messages at a time per user)
- [optional] Message limiting (how often a user can send a message, configurable by IP and User)
### Setup
The following are all of the related env variables to make use of and configure the mod system. Note this is also found in the [/.env.example](/.env.example) file, to be set in your own `.env` file.
```bash
BAN_VIOLATIONS=true # Whether or not to enable banning users for violations (they will still be logged)
BAN_DURATION=1000 * 60 * 60 * 2 # how long the user and associated IP are banned for
BAN_INTERVAL=20 # a user will be banned everytime their score reaches/crosses over the interval threshold
# The score for each violation
LOGIN_VIOLATION_SCORE=1
REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
# Login and registration rate limiting.
LOGIN_MAX=7 # The max amount of logins allowed per IP per LOGIN_WINDOW
LOGIN_WINDOW=5 # in minutes, determines the window of time for LOGIN_MAX logins
REGISTER_MAX=5 # The max amount of registrations allowed per IP per REGISTER_WINDOW
REGISTER_WINDOW=60 # in minutes, determines the window of time for REGISTER_MAX registrations
# Message rate limiting (per user & IP)
LIMIT_CONCURRENT_MESSAGES=true # Whether to limit the amount of messages a user can send per request
CONCURRENT_MESSAGE_MAX=2 # The max amount of messages a user can send per request
LIMIT_MESSAGE_IP=true # Whether to limit the amount of messages an IP can send per MESSAGE_IP_WINDOW
MESSAGE_IP_MAX=40 # The max amount of messages an IP can send per MESSAGE_IP_WINDOW
MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_MAX messages
# Note: You can utilize both limiters, but default is to limit by IP only.
LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW
MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW
MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages
```

View file

@ -9,15 +9,29 @@ In order for the auth system to function properly, there are some environment va
In /.env, you will need to set the following variables:
```bash
# Change this to a secure string
# Change the secrets to a secure, random string
JWT_SECRET=secret
JWT_REFRESH_SECRET=refresh_secret
# Set the expiration delay for the secure cookie with the JWT token
# Delay is in millisecond e.g. 7 days is 1000*60*60*24*7
SESSION_EXPIRY=1000 * 60 * 60 * 24 * 7
# Delay is in milliseconds e.g. 7 days is 1000*60*60*24*7
# Recommended session expiry is 15 minutes. Make it longer if you want the user to be able to revist the page without logging in for a longer duration of time.
# Recommended refresh token expiry is 7 days
SESSION_EXPIRY=1000 * 60 * 15
REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7
DOMAIN_SERVER=http://localhost:3080
DOMAIN_CLIENT=http://localhost:3080
```
## Automated Moderation System (optional)
The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching a set threshold, the user and their IP are temporarily banned. This system ensures platform security by monitoring and penalizing rapid or suspicious activities.
To set up the mod system, review [the setup guide](../features/mod_system.md).
*Please Note: If you are wanting this to work in development mode, you will need to create a file called `.env.development` in the root directory and set `DOMAIN_CLIENT` to `http://localhost:3090` or whatever port is provided by vite when runnning `npm run frontend-dev`*
Important: When you run the app for the first time, you need to create a new account by clicking on "Sign up" on the login page. The first account you make will be the admin account. The admin account doesn't have any special features right now, but it might be useful if you want to make an admin dashboard to manage other users later.

View file

@ -102,7 +102,8 @@ nav:
- Azure Cognitive Search: 'features/plugins/azure_cognitive_search.md'
- Make Your Own Plugin: 'features/plugins/make_your_own.md'
- Using official ChatGPT Plugins: 'features/plugins/chatgpt_plugins_openapi.md'
- Third-Party Tools: 'features/third-party.md'
- Automated Moderation: 'features/mod_system.md'
- Third-Party Tools: 'features/third_party.md'
- Proxy: 'features/proxy.md'
- Bing Jailbreak: 'features/bing_jailbreak.md'
- Cloud Deployment:

25
package-lock.json generated
View file

@ -72,7 +72,7 @@
"joi": "^17.9.2",
"js-yaml": "^4.1.0",
"jsonwebtoken": "^9.0.0",
"keyv": "^4.5.2",
"keyv": "^4.5.3",
"keyv-file": "^0.2.0",
"langchain": "^0.0.144",
"lodash": "^4.17.21",
@ -91,6 +91,7 @@
"pino": "^8.12.1",
"sanitize": "^2.1.2",
"sharp": "^0.32.5",
"ua-parser-js": "^1.0.36",
"zod": "^3.22.2"
},
"devDependencies": {
@ -25069,6 +25070,28 @@
"node": ">=14.17"
}
},
"node_modules/ua-parser-js": {
"version": "1.0.36",
"resolved": "https://registry.npmjs.org/ua-parser-js/-/ua-parser-js-1.0.36.tgz",
"integrity": "sha512-znuyCIXzl8ciS3+y3fHJI/2OhQIXbXw9MWC/o3qwyR+RGppjZHrM27CGFSKCJXi2Kctiz537iOu2KnXs1lMQhw==",
"funding": [
{
"type": "opencollective",
"url": "https://opencollective.com/ua-parser-js"
},
{
"type": "paypal",
"url": "https://paypal.me/faisalman"
},
{
"type": "github",
"url": "https://github.com/sponsors/faisalman"
}
],
"engines": {
"node": "*"
}
},
"node_modules/uglify-js": {
"version": "3.17.4",
"resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.17.4.tgz",

View file

@ -49,8 +49,8 @@
"b:data-provider": "cd packages/data-provider && bun run b:build",
"b:client": "bun run b:data-provider && cd client && bun run b:build",
"b:client:dev": "cd client && bun run b:dev",
"b:test:client": "cd client && bun run test",
"b:test:api": "cd api && bun run test"
"b:test:client": "cd client && bun run b:test",
"b:test:api": "cd api && bun run b:test"
},
"repository": {
"type": "git",
@ -92,7 +92,7 @@
"nodemonConfig": {
"ignore": [
"api/data/",
"data",
"data/",
"client/",
"admin/",
"packages/"

View file

@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.1.7",
"version": "0.1.8",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",