feat(api): initial Redis support; fix(SearchBar): proper debounce (#1039)

* refactor: use keyv for search caching with 1 min expirations

* feat: keyvRedis; chore: bump keyv, bun.lockb, add jsconfig for vscode file resolution

* feat: api/search redis support

* refactor(redis) use ioredis cluster for keyv
fix(OpenID): when redis is configured, use redis memory store for express-session

* fix: revert using uri for keyvredis

* fix(SearchBar): properly debounce search queries, fix weird render behaviors

* refactor: add authentication to search endpoint and show error messages in results

* feat: redis support for violation logs

* fix(logViolation): ensure a number is always being stored in cache

* feat(concurrentLimiter): uses clearPendingReq, clears pendingReq on abort, redis support

* fix(api/search/enable): query only when authenticated

* feat(ModelService): redis support

* feat(checkBan): redis support

* refactor(api/search): consolidate keyv logic

* fix(ci): add default empty value for REDIS_URI

* refactor(keyvRedis): use condition to initialize keyvRedis assignment

* refactor(connectDb): handle disconnected state (should create a new conn)

* fix(ci/e2e): handle case where cleanUp did not successfully run

* fix(getDefaultEndpoint): return endpoint from localStorage if defined and endpointsConfig is default

* ci(e2e): remove afterAll messages as startup/cleanUp will clear messages

* ci(e2e): remove teardown for CI until further notice

* chore: bump playwright/test

* ci(e2e): reinstate teardown as CI issue is specific to github env

* fix(ci): click settings menu trigger by testid
This commit is contained in:
Danny Avila 2023-10-11 17:05:47 -04:00 committed by GitHub
parent 4ac0c04e83
commit 5145121eb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 461 additions and 171 deletions

View file

@ -1,29 +1,48 @@
const Keyv = require('keyv');
const { pendingReqFile } = require('./keyvFiles');
const { LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
const keyv = new Keyv({ store: pendingReqFile, namespace: 'pendingRequests' });
const getLogStores = require('./getLogStores');
const { isEnabled } = require('../server/utils');
const { USE_REDIS, LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
const ttl = 1000 * 60 * 1;
/**
* Clear pending requests from the cache.
* Clear or decrement pending requests from the cache.
* Checks the environmental variable LIMIT_CONCURRENT_MESSAGES;
* if the rule is enabled ('true'), pending requests in the cache are cleared.
* if the rule is enabled ('true'), it either decrements the count of pending requests
* or deletes the key if the count is less than or equal to 1.
*
* @module clearPendingReq
* @requires keyv
* @requires keyvFiles
* @requires ./getLogStores
* @requires ../server/utils
* @requires process
*
* @async
* @function
* @returns {Promise<void>} A promise that either clears 'pendingRequests' from store or resolves with no value.
* @param {Object} params - The parameters object.
* @param {string} params.userId - The user ID for which the pending requests are to be cleared or decremented.
* @param {Object} [params.cache] - An optional cache object to use. If not provided, a default cache will be fetched using getLogStores.
* @returns {Promise<void>} A promise that either decrements the 'pendingRequests' count, deletes the key from the store, or resolves with no value.
*/
const clearPendingReq = async () => {
if (LIMIT_CONCURRENT_MESSAGES?.toLowerCase() !== 'true') {
const clearPendingReq = async ({ userId, cache: _cache }) => {
if (!userId) {
return;
} else if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
return;
}
await keyv.clear();
const namespace = 'pending_req';
const cache = _cache ?? getLogStores(namespace);
if (!cache) {
return;
}
const key = `${USE_REDIS ? namespace : ''}:${userId ?? ''}`;
const currentReq = +((await cache.get(key)) ?? 0);
if (currentReq && currentReq >= 1) {
await cache.set(key, currentReq - 1, ttl);
} else {
await cache.delete(key);
}
};
module.exports = clearPendingReq;

View file

@ -1,26 +1,37 @@
const Keyv = require('keyv');
const keyvMongo = require('./keyvMongo');
const { math } = require('../server/utils');
const keyvRedis = require('./keyvRedis');
const { math, isEnabled } = require('../server/utils');
const { logFile, violationFile } = require('./keyvFiles');
const { BAN_DURATION } = process.env ?? {};
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
const duration = math(BAN_DURATION, 7200000);
const createViolationInstance = (namespace) => {
const config = isEnabled(USE_REDIS) ? { store: keyvRedis } : { store: violationFile, namespace };
return new Keyv(config);
};
// Serve cache from memory so no need to clear it on startup/exit
const pending_req = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'pending_req' });
const namespaces = {
ban: new Keyv({ store: keyvMongo, ttl: duration, namespace: 'bans' }),
pending_req,
ban: new Keyv({ store: keyvMongo, namespace: 'bans', duration }),
general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: new Keyv({ store: violationFile, namespace: 'concurrent' }),
non_browser: new Keyv({ store: violationFile, namespace: 'non_browser' }),
message_limit: new Keyv({ store: violationFile, namespace: 'message_limit' }),
token_balance: new Keyv({ store: violationFile, namespace: 'token_balance' }),
registrations: new Keyv({ store: violationFile, namespace: 'registrations' }),
logins: new Keyv({ store: violationFile, namespace: 'logins' }),
concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance('token_balance'),
registrations: createViolationInstance('registrations'),
logins: createViolationInstance('logins'),
};
/**
* Returns either the logs of violations specified by type if a type is provided
* or it returns the general log if no type is specified. If an invalid type is passed,
* an error will be thrown.
* Returns the keyv cache specified by type.
* If an invalid type is passed, an error will be thrown.
*
* @module getLogStores
* @requires keyv - a simple key-value storage that allows you to easily switch out storage adapters.
@ -31,11 +42,10 @@ const namespaces = {
* @throws Will throw an error if an invalid violation type is passed.
*/
const getLogStores = (type) => {
if (!type) {
if (!type || !namespaces[type]) {
throw new Error(`Invalid store type: ${type}`);
}
const logs = namespaces[type];
return logs;
return namespaces[type];
};
module.exports = getLogStores;

3
api/cache/index.js vendored
View file

@ -1,6 +1,5 @@
const keyvFiles = require('./keyvFiles');
const getLogStores = require('./getLogStores');
const logViolation = require('./logViolation');
const clearPendingReq = require('./clearPendingReq');
module.exports = { ...keyvFiles, getLogStores, logViolation, clearPendingReq };
module.exports = { ...keyvFiles, getLogStores, logViolation };

14
api/cache/keyvRedis.js vendored Normal file
View file

@ -0,0 +1,14 @@
const KeyvRedis = require('@keyv/redis');
const { REDIS_URI } = process.env;
let keyvRedis;
if (REDIS_URI) {
keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false });
keyvRedis.on('error', (err) => console.error('KeyvRedis connection error:', err));
} else {
// console.log('REDIS_URI not provided. Redis module will not be initialized.');
}
module.exports = keyvRedis;

View file

@ -1,5 +1,6 @@
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
const { isEnabled } = require('../server/utils');
/**
* Logs the violation.
@ -17,10 +18,11 @@ const logViolation = async (req, res, type, errorMessage, score = 1) => {
}
const logs = getLogStores('general');
const violationLogs = getLogStores(type);
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;
const userViolations = (await violationLogs.get(userId)) ?? 0;
const violationCount = userViolations + score;
await violationLogs.set(userId, violationCount);
const userViolations = (await violationLogs.get(key)) ?? 0;
const violationCount = +userViolations + +score;
await violationLogs.set(key, violationCount);
errorMessage.user_id = userId;
errorMessage.prev_count = userViolations;
@ -28,10 +30,10 @@ const logViolation = async (req, res, type, errorMessage, score = 1) => {
errorMessage.date = new Date().toISOString();
await banViolation(req, res, errorMessage);
const userLogs = (await logs.get(userId)) ?? [];
const userLogs = (await logs.get(key)) ?? [];
userLogs.push(errorMessage);
delete errorMessage.user_id;
await logs.set(userId, userLogs);
await logs.set(key, userLogs);
};
module.exports = logViolation;

4
api/cache/redis.js vendored Normal file
View file

@ -0,0 +1,4 @@
const Redis = require('ioredis');
const { REDIS_URI } = process.env ?? {};
const redis = new Redis.Cluster(REDIS_URI);
module.exports = redis;

13
api/jsconfig.json Normal file
View file

@ -0,0 +1,13 @@
{
"compilerOptions": {
"target": "ES6",
"module": "CommonJS",
// "checkJs": true, // Report errors in JavaScript files
"baseUrl": "./",
"paths": {
"*": ["*", "node_modules/*"],
"~/*": ["./*"]
}
},
"exclude": ["node_modules"]
}

View file

@ -18,11 +18,12 @@ if (!cached) {
}
async function connectDb() {
if (cached.conn) {
if (cached.conn && cached.conn?._readyState === 1) {
return cached.conn;
}
if (!cached.promise) {
const disconnected = cached.conn && cached.conn?._readyState !== 1;
if (!cached.promise || disconnected) {
const opts = {
useNewUrlParser: true,
useUnifiedTopology: true,

View file

@ -24,11 +24,13 @@
"@anthropic-ai/sdk": "^0.5.4",
"@azure/search-documents": "^11.3.2",
"@keyv/mongo": "^2.1.8",
"@keyv/redis": "^2.8.0",
"@waylaidwanderer/chatgpt-api": "^1.37.2",
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
"cohere-ai": "^6.0.0",
"connect-redis": "^7.1.0",
"cookie": "^0.5.0",
"cors": "^2.8.5",
"dotenv": "^16.0.3",
@ -39,10 +41,11 @@
"googleapis": "^118.0.0",
"handlebars": "^4.7.7",
"html": "^1.0.0",
"ioredis": "^5.3.2",
"jose": "^4.15.2",
"js-yaml": "^4.1.0",
"jsonwebtoken": "^9.0.0",
"keyv": "^4.5.3",
"keyv": "^4.5.4",
"keyv-file": "^0.2.0",
"langchain": "^0.0.153",
"lodash": "^4.17.21",

View file

@ -1,5 +1,6 @@
const { sendMessage, sendError, countTokens, isEnabled } = require('../utils');
const { saveMessage, getConvo, getConvoTitle } = require('../../models');
const { sendMessage, sendError, countTokens } = require('../utils');
const clearPendingReq = require('../../cache/clearPendingReq');
const spendTokens = require('../../models/spendTokens');
const abortControllers = require('./abortControllers');
@ -20,6 +21,9 @@ async function abortMessage(req, res) {
const handleAbort = () => {
return async (req, res) => {
try {
if (isEnabled(process.env.LIMIT_CONCURRENT_MESSAGES)) {
await clearPendingReq({ userId: req.user.id });
}
return await abortMessage(req, res);
} catch (err) {
console.error(err);

View file

@ -3,8 +3,11 @@ const uap = require('ua-parser-js');
const { getLogStores } = require('../../cache');
const denyRequest = require('./denyRequest');
const { isEnabled, removePorts } = require('../utils');
const keyvRedis = require('../../cache/keyvRedis');
const banCache = new Keyv({ namespace: 'bans', ttl: 0 });
const banCache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'bans', ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.';
/**
@ -50,9 +53,11 @@ const checkBan = async (req, res, next = () => {}) => {
req.ip = removePorts(req);
const userId = req.user?.id ?? req.user?._id ?? null;
const ipKey = isEnabled(process.env.USE_REDIS) ? `ban_cache:ip:${req.ip}` : req.ip;
const userKey = isEnabled(process.env.USE_REDIS) ? `ban_cache:user:${userId}` : userId;
const cachedIPBan = await banCache.get(req.ip);
const cachedUserBan = await banCache.get(userId);
const cachedIPBan = await banCache.get(ipKey);
const cachedUserBan = await banCache.get(userKey);
const cachedBan = cachedIPBan || cachedUserBan;
if (cachedBan) {
@ -78,13 +83,13 @@ const checkBan = async (req, res, next = () => {}) => {
const timeLeft = Number(isBanned.expiresAt) - Date.now();
if (timeLeft <= 0) {
await banLogs.delete(req.ip);
await banLogs.delete(userId);
await banLogs.delete(ipKey);
await banLogs.delete(userKey);
return next();
}
banCache.set(req.ip, isBanned, timeLeft);
banCache.set(userId, isBanned, timeLeft);
banCache.set(ipKey, isBanned, timeLeft);
banCache.set(userKey, isBanned, timeLeft);
req.banned = true;
return await banResponse(req, res);
};

View file

@ -1,10 +1,13 @@
const Keyv = require('keyv');
const { logViolation } = require('../../cache');
const clearPendingReq = require('../../cache/clearPendingReq');
const { logViolation, getLogStores } = require('../../cache');
const denyRequest = require('./denyRequest');
// Serve cache from memory so no need to clear it on startup/exit
const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
const {
USE_REDIS,
CONCURRENT_MESSAGE_MAX = 1,
CONCURRENT_VIOLATION_SCORE: score,
} = process.env ?? {};
const ttl = 1000 * 60 * 1;
/**
* Middleware to limit concurrent requests for a user.
@ -12,7 +15,7 @@ const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
* This middleware checks if a user has exceeded a specified concurrent request limit.
* If the user exceeds the limit, an error is returned. If the user is within the limit,
* their request count is incremented. After the request is processed, the count is decremented.
* If the `pendingReqCache` store is not available, the middleware will skip its logic.
* If the `cache` store is not available, the middleware will skip its logic.
*
* @function
* @param {Object} req - Express request object containing user information.
@ -21,7 +24,9 @@ const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
*/
const concurrentLimiter = async (req, res, next) => {
if (!pendingReqCache) {
const namespace = 'pending_req';
const cache = getLogStores(namespace);
if (!cache) {
return next();
}
@ -29,12 +34,12 @@ const concurrentLimiter = async (req, res, next) => {
return next();
}
const { CONCURRENT_MESSAGE_MAX = 1, CONCURRENT_VIOLATION_SCORE: score } = process.env;
const userId = req.user?.id ?? req.user?._id ?? '';
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = 'concurrent';
const userId = req.user?.id ?? req.user?._id ?? null;
const pendingRequests = (await pendingReqCache.get(userId)) ?? 0;
const key = `${USE_REDIS ? namespace : ''}:${userId}`;
const pendingRequests = +((await cache.get(key)) ?? 0);
if (pendingRequests >= limit) {
const errorMessage = {
@ -46,22 +51,17 @@ const concurrentLimiter = async (req, res, next) => {
await logViolation(req, res, type, errorMessage, score);
return await denyRequest(req, res, errorMessage);
} else {
await pendingReqCache.set(userId, pendingRequests + 1);
await cache.set(key, pendingRequests + 1, ttl);
}
// Ensure the requests are removed from the store once the request is done
let cleared = false;
const cleanUp = async () => {
if (!pendingReqCache) {
if (cleared) {
return;
}
const currentRequests = await pendingReqCache.get(userId);
if (currentRequests && currentRequests >= 1) {
await pendingReqCache.set(userId, currentRequests - 1);
} else {
await pendingReqCache.delete(userId);
}
cleared = true;
await clearPendingReq({ userId, cache });
};
if (pendingRequests < limit) {
@ -72,10 +72,4 @@ const concurrentLimiter = async (req, res, next) => {
next();
};
// if cache is not served from memory, clear it on exit
// process.on('exit', async () => {
// console.log('Clearing all pending requests before exiting...');
// await pendingReqCache.clear();
// });
module.exports = concurrentLimiter;

View file

@ -1,3 +1,4 @@
const Keyv = require('keyv');
const express = require('express');
const router = express.Router();
const { MeiliSearch } = require('meilisearch');
@ -6,8 +7,15 @@ const { Conversation, getConvosQueried } = require('../../models/Conversation');
const { reduceHits } = require('../../lib/utils/reduceHits');
const { cleanUpPrimaryKeyValue } = require('../../lib/utils/misc');
const requireJwtAuth = require('../middleware/requireJwtAuth');
const keyvRedis = require('../../cache/keyvRedis');
const { isEnabled } = require('../utils');
const cache = new Map();
const expiration = 60 * 1000;
const cache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'search', ttl: expiration });
router.use(requireJwtAuth);
router.get('/sync', async function (req, res) {
await Message.syncWithMeili();
@ -15,24 +23,20 @@ router.get('/sync', async function (req, res) {
res.send('synced');
});
router.get('/', requireJwtAuth, async function (req, res) {
router.get('/', async function (req, res) {
try {
let user = req.user.id;
user = user ?? null;
let user = req.user.id ?? '';
const { q } = req.query;
const pageNumber = req.query.pageNumber || 1;
const key = `${user || ''}${q}`;
if (cache.has(key)) {
const key = `${user}:search:${q}`;
const cached = await cache.get(key);
if (cached) {
console.log('cache hit', key);
const cached = cache.get(key);
const { pages, pageSize, messages } = cached;
res
.status(200)
.send({ conversations: cached[pageNumber], pages, pageNumber, pageSize, messages });
return;
} else {
cache.clear();
}
// const message = await Message.meiliSearch(q);
@ -67,7 +71,7 @@ router.get('/', requireJwtAuth, async function (req, res) {
if (message.conversationId.includes('--')) {
message.conversationId = cleanUpPrimaryKeyValue(message.conversationId);
}
if (result.convoMap[message.conversationId] && !message.error) {
if (result.convoMap[message.conversationId]) {
const convo = result.convoMap[message.conversationId];
const { title, chatGptLabel, model } = convo;
message = { ...message, ...{ title, chatGptLabel, model } };
@ -77,7 +81,7 @@ router.get('/', requireJwtAuth, async function (req, res) {
result.messages = activeMessages;
if (result.cache) {
result.cache.messages = activeMessages;
cache.set(key, result.cache);
cache.set(key, result.cache, expiration);
delete result.cache;
}
delete result.convoMap;

View file

@ -1,9 +1,13 @@
const Keyv = require('keyv');
const axios = require('axios');
const { isEnabled } = require('../utils');
const keyvRedis = require('../../cache/keyvRedis');
// const { getAzureCredentials, genAzureChatCompletion } = require('../../utils/');
const { openAIApiKey, userProvidedOpenAI } = require('./EndpointService').config;
const modelsCache = new Keyv({ namespace: 'models' });
const modelsCache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'models' });
const { OPENROUTER_API_KEY, OPENAI_REVERSE_PROXY, CHATGPT_MODELS, ANTHROPIC_MODELS } =
process.env ?? {};

View file

@ -1,4 +1,5 @@
const session = require('express-session');
const RedisStore = require('connect-redis').default;
const passport = require('passport');
const {
googleLogin,
@ -7,6 +8,7 @@ const {
facebookLogin,
setupOpenId,
} = require('../strategies');
const client = require('../cache/redis');
const configureSocialLogins = (app) => {
if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) {
@ -28,13 +30,15 @@ const configureSocialLogins = (app) => {
process.env.OPENID_SCOPE &&
process.env.OPENID_SESSION_SECRET
) {
app.use(
session({
secret: process.env.OPENID_SESSION_SECRET,
resave: false,
saveUninitialized: false,
}),
);
const sessionOptions = {
secret: process.env.OPENID_SESSION_SECRET,
resave: false,
saveUninitialized: false,
};
if (process.env.USE_REDIS) {
sessionOptions.store = new RedisStore({ client, prefix: 'librechat' });
}
app.use(session(sessionOptions));
app.use(passport.session());
setupOpenId();
}