Move usermethods and models to data-schema

This commit is contained in:
Cha 2025-05-29 16:37:31 +08:00 committed by Danny Avila
parent 4808c5be48
commit 4049b5572c
No known key found for this signature in database
GPG key ID: BF31EEB2C5CA0956
93 changed files with 2396 additions and 1267 deletions

View file

@ -1,6 +1,5 @@
const mongoose = require('mongoose');
const crypto = require('node:crypto');
const { agentSchema } = require('@librechat/data-schemas');
const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
require('librechat-data-provider').Constants;
@ -15,7 +14,7 @@ const getLogStores = require('~/cache/getLogStores');
const { getActions } = require('./Action');
const { logger } = require('~/config');
const Agent = mongoose.model('agent', agentSchema);
const db = require('~/lib/db/connectDb');
/**
* Create an agent with the provided data.
@ -36,7 +35,7 @@ const createAgent = async (agentData) => {
},
],
};
return (await Agent.create(initialAgentData)).toObject();
return (await db.models.Agent.create(initialAgentData)).toObject();
};
/**
@ -47,7 +46,7 @@ const createAgent = async (agentData) => {
* @param {string} searchParameter.author - The user ID of the agent's author.
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
*/
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
const getAgent = async (searchParameter) => await db.models.Agent.findOne(searchParameter).lean();
/**
* Load an agent based on the provided ID
@ -269,6 +268,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
const { updatingUserId = null, forceVersion = false } = options;
const mongoOptions = { new: true, upsert: false };
const Agent = db.models?.Agent;
const currentAgent = await Agent.findOne(searchParameter);
if (currentAgent) {
const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject();
@ -362,6 +362,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
* @returns {Promise<Agent>} The updated agent.
*/
const addAgentResourceFile = async ({ req, agent_id, tool_resource, file_id }) => {
const Agent = db.models?.Agent;
const searchParameter = { id: agent_id };
let agent = await getAgent(searchParameter);
if (!agent) {
@ -427,7 +428,7 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => {
}
const updatePullData = { $pull: pullOps };
const agentAfterPull = await Agent.findOneAndUpdate(searchParameter, updatePullData, {
const agentAfterPull = await db.models.Agent.findOneAndUpdate(searchParameter, updatePullData, {
new: true,
}).lean();
@ -457,7 +458,7 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => {
* @returns {Promise<void>} Resolves when the agent has been successfully deleted.
*/
const deleteAgent = async (searchParameter) => {
const agent = await Agent.findOneAndDelete(searchParameter);
const agent = await db.models.Agent.findOneAndDelete(searchParameter);
if (agent) {
await removeAgentFromAllProjects(agent.id);
}
@ -481,9 +482,8 @@ const getListAgents = async (searchParameter) => {
delete globalQuery.author;
query = { $or: [globalQuery, query] };
}
const agents = (
await Agent.find(query, {
await db.models.Agent.find(query, {
id: 1,
_id: 0,
name: 1,
@ -580,6 +580,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds
* @throws {Error} If the agent is not found or the specified version does not exist.
*/
const revertAgentVersion = async (searchParameter, versionIndex) => {
const Agent = db.models?.Agent;
const agent = await Agent.findOne(searchParameter);
if (!agent) {
throw new Error('Agent not found');
@ -662,7 +663,6 @@ const generateActionMetadataHash = async (actionIds, actions) => {
*/
module.exports = {
Agent,
getAgent,
loadAgent,
createAgent,

View file

@ -10,7 +10,6 @@ const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const {
Agent,
addAgentResourceFile,
removeAgentResourceFiles,
createAgent,
@ -20,6 +19,9 @@ const {
getListAgents,
updateAgentProjects,
} = require('./Agent');
const db = require('~/lib/db/connectDb');
let Agent;
describe('Agent Resource File Operations', () => {
let mongoServer;
@ -27,7 +29,9 @@ describe('Agent Resource File Operations', () => {
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
await db.connectDb(mongoUri);
Agent = db.models.Agent;
});
afterAll(async () => {
@ -55,6 +59,7 @@ describe('Agent Resource File Operations', () => {
test('should add tool_resource to tools if missing', async () => {
const agent = await createBasicAgent();
const fileId = uuidv4();
const toolResource = 'file_search';

View file

@ -1,7 +1,4 @@
const mongoose = require('mongoose');
const { assistantSchema } = require('@librechat/data-schemas');
const Assistant = mongoose.model('assistant', assistantSchema);
const db = require('~/lib/db/connectDb');
/**
* Update an assistant with new data without overwriting existing properties,
@ -15,7 +12,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
*/
const updateAssistantDoc = async (searchParams, updateData) => {
const options = { new: true, upsert: true };
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
return await db.models.Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
};
/**
@ -26,7 +23,7 @@ const updateAssistantDoc = async (searchParams, updateData) => {
* @param {string} searchParams.user - The user ID of the assistant's author.
* @returns {Promise<AssistantDocument|null>} The assistant document as a plain object, or null if not found.
*/
const getAssistant = async (searchParams) => await Assistant.findOne(searchParams).lean();
const getAssistant = async (searchParams) => await db.models.Assistant.findOne(searchParams).lean();
/**
* Retrieves all assistants that match the given search parameters.
@ -36,7 +33,7 @@ const getAssistant = async (searchParams) => await Assistant.findOne(searchParam
* @returns {Promise<Array<AssistantDocument>>} A promise that resolves to an array of assistant documents as plain objects.
*/
const getAssistants = async (searchParams, select = null) => {
let query = Assistant.find(searchParams);
let query = db.models.Assistant.find(searchParams);
if (select) {
query = query.select(select);
@ -54,7 +51,7 @@ const getAssistants = async (searchParams, select = null) => {
* @returns {Promise<void>} Resolves when the assistant has been successfully deleted.
*/
const deleteAssistant = async (searchParams) => {
return await Assistant.findOneAndDelete(searchParams);
return await db.models.Assistant.findOneAndDelete(searchParams);
};
module.exports = {

View file

@ -1,4 +0,0 @@
const mongoose = require('mongoose');
const { balanceSchema } = require('@librechat/data-schemas');
module.exports = mongoose.model('Balance', balanceSchema);

View file

@ -1,8 +1,5 @@
const mongoose = require('mongoose');
const logger = require('~/config/winston');
const { bannerSchema } = require('@librechat/data-schemas');
const Banner = mongoose.model('Banner', bannerSchema);
const db = require('~/lib/db/connectDb');
/**
* Retrieves the current active banner.
@ -11,7 +8,7 @@ const Banner = mongoose.model('Banner', bannerSchema);
const getBanner = async (user) => {
try {
const now = new Date();
const banner = await Banner.findOne({
const banner = await db.models.Banner.findOne({
displayFrom: { $lte: now },
$or: [{ displayTo: { $gte: now } }, { displayTo: null }],
type: 'banner',
@ -28,4 +25,4 @@ const getBanner = async (user) => {
}
};
module.exports = { Banner, getBanner };
module.exports = { getBanner };

View file

@ -1,7 +1,6 @@
const Conversation = require('./schema/convoSchema');
const { getMessages, deleteMessages } = require('./Message');
const logger = require('~/config/winston');
const db = require('~/lib/db/connectDb');
/**
* Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
* @param {string} conversationId - The conversation's ID.
@ -9,7 +8,7 @@ const logger = require('~/config/winston');
*/
const searchConversation = async (conversationId) => {
try {
return await Conversation.findOne({ conversationId }, 'conversationId user').lean();
return await db.models.Conversation.findOne({ conversationId }, 'conversationId user').lean();
} catch (error) {
logger.error('[searchConversation] Error searching conversation', error);
throw new Error('Error searching conversation');
@ -24,7 +23,7 @@ const searchConversation = async (conversationId) => {
*/
const getConvo = async (user, conversationId) => {
try {
return await Conversation.findOne({ user, conversationId }).lean();
return await db.models.Conversation.findOne({ user, conversationId }).lean();
} catch (error) {
logger.error('[getConvo] Error getting single conversation', error);
return { message: 'Error getting single conversation' };
@ -41,7 +40,7 @@ const deleteNullOrEmptyConversations = async () => {
],
};
const result = await Conversation.deleteMany(filter);
const result = await db.models.Conversation.deleteMany(filter);
// Delete associated messages
const messageDeleteResult = await deleteMessages(filter);
@ -67,7 +66,7 @@ const deleteNullOrEmptyConversations = async () => {
*/
const getConvoFiles = async (conversationId) => {
try {
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
return (await db.models.Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
} catch (error) {
logger.error('[getConvoFiles] Error getting conversation files', error);
throw new Error('Error getting conversation files');
@ -75,7 +74,6 @@ const getConvoFiles = async (conversationId) => {
};
module.exports = {
Conversation,
getConvoFiles,
searchConversation,
deleteNullOrEmptyConversations,
@ -114,7 +112,7 @@ module.exports = {
}
/** Note: the resulting Model object is necessary for Meilisearch operations */
const conversation = await Conversation.findOneAndUpdate(
const conversation = await db.models.Conversation.findOneAndUpdate(
{ conversationId, user: req.user.id },
updateOperation,
{
@ -143,7 +141,7 @@ module.exports = {
},
}));
const result = await Conversation.bulkWrite(bulkOps);
const result = await db.models.Conversation.bulkWrite(bulkOps);
return result;
} catch (error) {
logger.error('[saveBulkConversations] Error saving conversations in bulk', error);
@ -155,7 +153,7 @@ module.exports = {
{ cursor, limit = 25, isArchived = false, tags, search, order = 'desc' } = {},
) => {
const filters = [{ user }];
const { Conversation } = db.models;
if (isArchived) {
filters.push({ isArchived: true });
} else {
@ -219,7 +217,7 @@ module.exports = {
const conversationIds = convoIds.map((convo) => convo.conversationId);
const results = await Conversation.find({
const results = await db.models.Conversation.find({
user,
conversationId: { $in: conversationIds },
$or: [{ expiredAt: { $exists: false } }, { expiredAt: null }],
@ -288,7 +286,7 @@ module.exports = {
deleteConvos: async (user, filter) => {
try {
const userFilter = { ...filter, user };
const { Conversation } = db.models;
const conversations = await Conversation.find(userFilter).select('conversationId');
const conversationIds = conversations.map((c) => c.conversationId);

View file

@ -1,10 +1,5 @@
const mongoose = require('mongoose');
const Conversation = require('./schema/convoSchema');
const logger = require('~/config/winston');
const { conversationTagSchema } = require('@librechat/data-schemas');
const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema);
const db = require('~/lib/db/connectDb');
/**
* Retrieves all conversation tags for a user.
@ -13,7 +8,7 @@ const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema)
*/
const getConversationTags = async (user) => {
try {
return await ConversationTag.find({ user }).sort({ position: 1 }).lean();
return await db.models.ConversationTag.find({ user }).sort({ position: 1 }).lean();
} catch (error) {
logger.error('[getConversationTags] Error getting conversation tags', error);
throw new Error('Error getting conversation tags');
@ -34,6 +29,7 @@ const createConversationTag = async (user, data) => {
try {
const { tag, description, addToConversation, conversationId } = data;
const { ConversationTag, Conversation } = db.models;
const existingTag = await ConversationTag.findOne({ user, tag }).lean();
if (existingTag) {
return existingTag;
@ -88,6 +84,7 @@ const updateConversationTag = async (user, oldTag, data) => {
try {
const { tag: newTag, description, position } = data;
const { ConversationTag, Conversation } = db.models;
const existingTag = await ConversationTag.findOne({ user, tag: oldTag }).lean();
if (!existingTag) {
return null;
@ -140,15 +137,15 @@ const adjustPositions = async (user, oldPosition, newPosition) => {
const position =
oldPosition < newPosition
? {
$gt: Math.min(oldPosition, newPosition),
$lte: Math.max(oldPosition, newPosition),
}
$gt: Math.min(oldPosition, newPosition),
$lte: Math.max(oldPosition, newPosition),
}
: {
$gte: Math.min(oldPosition, newPosition),
$lt: Math.max(oldPosition, newPosition),
};
$gte: Math.min(oldPosition, newPosition),
$lt: Math.max(oldPosition, newPosition),
};
await ConversationTag.updateMany(
await db.models.ConversationTag.updateMany(
{
user,
position,
@ -165,6 +162,7 @@ const adjustPositions = async (user, oldPosition, newPosition) => {
*/
const deleteConversationTag = async (user, tag) => {
try {
const { ConversationTag, Conversation } = db.models;
const deletedTag = await ConversationTag.findOneAndDelete({ user, tag }).lean();
if (!deletedTag) {
return null;
@ -193,6 +191,7 @@ const deleteConversationTag = async (user, tag) => {
*/
const updateTagsForConversation = async (user, conversationId, tags) => {
try {
const { ConversationTag, Conversation } = db.models;
const conversation = await Conversation.findOne({ user, conversationId }).lean();
if (!conversation) {
throw new Error('Conversation not found');

View file

@ -1,9 +1,7 @@
const mongoose = require('mongoose');
const { EToolResources } = require('librechat-data-provider');
const { fileSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const File = mongoose.model('File', fileSchema);
const db = require('~/lib/db/connectDb');
/**
* Finds a file by its file_id with additional query options.
@ -12,7 +10,7 @@ const File = mongoose.model('File', fileSchema);
* @returns {Promise<MongoFile>} A promise that resolves to the file document or null.
*/
const findFileById = async (file_id, options = {}) => {
return await File.findOne({ file_id, ...options }).lean();
return await db.models.File.findOne({ file_id, ...options }).lean();
};
/**
@ -25,7 +23,7 @@ const findFileById = async (file_id, options = {}) => {
*/
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const sortOptions = { updatedAt: -1, ..._sortOptions };
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
return await db.models.File.find(filter).select(selectFields).sort(sortOptions).lean();
};
/**
@ -81,7 +79,7 @@ const createFile = async (data, disableTTL) => {
delete fileData.expiresAt;
}
return await File.findOneAndUpdate({ file_id: data.file_id }, fileData, {
return await db.models.File.findOneAndUpdate({ file_id: data.file_id }, fileData, {
new: true,
upsert: true,
}).lean();
@ -98,7 +96,7 @@ const updateFile = async (data) => {
$set: update,
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
};
return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
return await db.models.File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
};
/**
@ -112,7 +110,7 @@ const updateFileUsage = async (data) => {
$inc: { usage: inc },
$unset: { expiresAt: '', temp_file_id: '' },
};
return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
return await db.models.File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
};
/**
@ -121,7 +119,7 @@ const updateFileUsage = async (data) => {
* @returns {Promise<MongoFile>} A promise that resolves to the deleted file document or null.
*/
const deleteFile = async (file_id) => {
return await File.findOneAndDelete({ file_id }).lean();
return await db.models.File.findOneAndDelete({ file_id }).lean();
};
/**
@ -130,7 +128,7 @@ const deleteFile = async (file_id) => {
* @returns {Promise<MongoFile>} A promise that resolves to the deleted file document or null.
*/
const deleteFileByFilter = async (filter) => {
return await File.findOneAndDelete(filter).lean();
return await db.models.File.findOneAndDelete(filter).lean();
};
/**
@ -143,7 +141,7 @@ const deleteFiles = async (file_ids, user) => {
if (user) {
deleteQuery = { user: user };
}
return await File.deleteMany(deleteQuery);
return await db.models.File.deleteMany(deleteQuery);
};
/**
@ -169,7 +167,6 @@ async function batchUpdateFiles(updates) {
}
module.exports = {
File,
findFileById,
getFiles,
getToolFilesByIds,

View file

@ -1,4 +0,0 @@
const mongoose = require('mongoose');
const { keySchema } = require('@librechat/data-schemas');
module.exports = mongoose.model('Key', keySchema);

View file

@ -1,7 +1,6 @@
const { z } = require('zod');
const Message = require('./schema/messageSchema');
const { logger } = require('~/config');
const db = require('~/lib/db/connectDb');
const idSchema = z.string().uuid();
/**
@ -68,8 +67,7 @@ async function saveMessage(req, params, metadata) {
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
update.tokenCount = 0;
}
const message = await Message.findOneAndUpdate(
const message = await db.models.Message.findOneAndUpdate(
{ messageId: params.messageId, user: req.user.id },
update,
{ upsert: true, new: true },
@ -87,7 +85,7 @@ async function saveMessage(req, params, metadata) {
try {
// Try to find the existing message with this ID
const existingMessage = await Message.findOne({
const existingMessage = await db.models.Message.findOne({
messageId: params.messageId,
user: req.user.id,
});
@ -140,8 +138,7 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) {
upsert: true,
},
}));
const result = await Message.bulkWrite(bulkOps);
const result = await db.models.Message.bulkWrite(bulkOps);
return result;
} catch (err) {
logger.error('Error saving messages in bulk:', err);
@ -183,7 +180,7 @@ async function recordMessage({
...rest,
};
return await Message.findOneAndUpdate({ user, messageId }, message, {
return await db.models.Message.findOneAndUpdate({ user, messageId }, message, {
upsert: true,
new: true,
});
@ -207,7 +204,7 @@ async function recordMessage({
*/
async function updateMessageText(req, { messageId, text }) {
try {
await Message.updateOne({ messageId, user: req.user.id }, { text });
await db.models?.Message.updateOne({ messageId, user: req.user.id }, { text });
} catch (err) {
logger.error('Error updating message text:', err);
throw err;
@ -235,7 +232,7 @@ async function updateMessageText(req, { messageId, text }) {
async function updateMessage(req, message, metadata) {
try {
const { messageId, ...update } = message;
const updatedMessage = await Message.findOneAndUpdate(
const updatedMessage = await db.models.Message.findOneAndUpdate(
{ messageId, user: req.user.id },
update,
{
@ -279,10 +276,10 @@ async function updateMessage(req, message, metadata) {
*/
async function deleteMessagesSince(req, { messageId, conversationId }) {
try {
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
const message = await db.models.Message.findOne({ messageId, user: req.user.id }).lean();
if (message) {
const query = Message.find({ conversationId, user: req.user.id });
const query = db.models.Message.find({ conversationId, user: req.user.id });
return await query.deleteMany({
createdAt: { $gt: message.createdAt },
});
@ -306,10 +303,10 @@ async function deleteMessagesSince(req, { messageId, conversationId }) {
async function getMessages(filter, select) {
try {
if (select) {
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
return await db.models.Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
}
return await Message.find(filter).sort({ createdAt: 1 }).lean();
return await db.models.Message.find(filter).sort({ createdAt: 1 }).lean();
} catch (err) {
logger.error('Error getting messages:', err);
throw err;
@ -326,7 +323,7 @@ async function getMessages(filter, select) {
*/
async function getMessage({ user, messageId }) {
try {
return await Message.findOne({
return await db.models.Message.findOne({
user,
messageId,
}).lean();
@ -347,7 +344,7 @@ async function getMessage({ user, messageId }) {
*/
async function deleteMessages(filter) {
try {
return await Message.deleteMany(filter);
return await db.models.Message.deleteMany(filter);
} catch (err) {
logger.error('Error deleting messages:', err);
throw err;
@ -355,7 +352,6 @@ async function deleteMessages(filter) {
}
module.exports = {
Message,
saveMessage,
bulkSaveMessages,
recordMessage,

View file

@ -1,5 +1,6 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const db = require('~/lib/db/connectDb');
jest.mock('mongoose');
@ -20,14 +21,28 @@ const mockSchema = {
deleteMany: jest.fn(),
};
mongoose.model.mockReturnValue(mockSchema);
jest.mock('~/models/schema/messageSchema', () => mockSchema);
jest.mock('~/config/winston', () => ({
error: jest.fn(),
}));
const mockModels = {
Message: {
findOneAndUpdate: mockSchema.findOneAndUpdate,
updateOne: mockSchema.updateOne,
findOne: mockSchema.findOne,
find: mockSchema.find,
deleteMany: mockSchema.deleteMany,
},
};
jest.mock('~/lib/db/connectDb', () => {
return {
get models() {
return mockModels;
},
};
});
const {
saveMessage,
getMessages,
@ -153,7 +168,7 @@ describe('Message Operations', () => {
});
describe('Conversation Hijacking Prevention', () => {
it('should not allow editing a message in another user\'s conversation', async () => {
it("should not allow editing a message in another user's conversation", async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
@ -175,7 +190,7 @@ describe('Message Operations', () => {
);
});
it('should not allow deleting messages from another user\'s conversation', async () => {
it("should not allow deleting messages from another user's conversation", async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
@ -193,7 +208,7 @@ describe('Message Operations', () => {
});
});
it('should not allow inserting a new message into another user\'s conversation', async () => {
it("should not allow inserting a new message into another user's conversation", async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = uuidv4(); // Use a valid UUID

View file

@ -1,9 +1,9 @@
const Preset = require('./schema/presetSchema');
const { logger } = require('~/config');
const db = require('~/lib/db/connectDb');
const getPreset = async (user, presetId) => {
try {
return await Preset.findOne({ user, presetId }).lean();
return await db.models.Preset.findOne({ user, presetId }).lean();
} catch (error) {
logger.error('[getPreset] Error getting single preset', error);
return { message: 'Error getting single preset' };
@ -11,11 +11,10 @@ const getPreset = async (user, presetId) => {
};
module.exports = {
Preset,
getPreset,
getPresets: async (user, filter) => {
try {
const presets = await Preset.find({ ...filter, user }).lean();
const presets = await db.models.Preset.find({ ...filter, user }).lean();
const defaultValue = 10000;
presets.sort((a, b) => {
@ -40,6 +39,7 @@ module.exports = {
const setter = { $set: {} };
const { user: _, ...cleanPreset } = preset;
const update = { presetId, ...cleanPreset };
const Preset = db.models.Preset;
if (preset.tools && Array.isArray(preset.tools)) {
update.tools =
preset.tools
@ -77,7 +77,7 @@ module.exports = {
deletePresets: async (user, filter) => {
// let toRemove = await Preset.find({ ...filter, user }).select('presetId');
// const ids = toRemove.map((instance) => instance.presetId);
let deleteCount = await Preset.deleteMany({ ...filter, user });
let deleteCount = await db.models.Preset.deleteMany({ ...filter, user });
return deleteCount;
},
};

View file

@ -1,8 +1,5 @@
const { model } = require('mongoose');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { projectSchema } = require('@librechat/data-schemas');
const Project = model('Project', projectSchema);
const db = require('~/lib/db/connectDb');
/**
* Retrieve a project by ID and convert the found project document to a plain object.
@ -12,7 +9,7 @@ const Project = model('Project', projectSchema);
* @returns {Promise<IMongoProject>} A plain object representing the project document, or `null` if no project is found.
*/
const getProjectById = async function (projectId, fieldsToSelect = null) {
const query = Project.findById(projectId);
const query = db.models.Project.findById(projectId);
if (fieldsToSelect) {
query.select(fieldsToSelect);
@ -39,7 +36,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) {
select: fieldsToSelect,
};
return await Project.findOneAndUpdate(query, update, options);
return await db.models.Project.findOneAndUpdate(query, update, options);
};
/**
@ -50,7 +47,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) {
* @returns {Promise<IMongoProject>} The updated project document.
*/
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
return await Project.findByIdAndUpdate(
return await db.models.Project.findByIdAndUpdate(
projectId,
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
{ new: true },
@ -65,7 +62,7 @@ const addGroupIdsToProject = async function (projectId, promptGroupIds) {
* @returns {Promise<IMongoProject>} The updated project document.
*/
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
return await Project.findByIdAndUpdate(
return await db.models.Project.findByIdAndUpdate(
projectId,
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
{ new: true },
@ -79,7 +76,7 @@ const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
* @returns {Promise<void>}
*/
const removeGroupFromAllProjects = async (promptGroupId) => {
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
await db.models.Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
};
/**
@ -90,7 +87,7 @@ const removeGroupFromAllProjects = async (promptGroupId) => {
* @returns {Promise<IMongoProject>} The updated project document.
*/
const addAgentIdsToProject = async function (projectId, agentIds) {
return await Project.findByIdAndUpdate(
return await db.models.Project.findByIdAndUpdate(
projectId,
{ $addToSet: { agentIds: { $each: agentIds } } },
{ new: true },
@ -105,7 +102,7 @@ const addAgentIdsToProject = async function (projectId, agentIds) {
* @returns {Promise<IMongoProject>} The updated project document.
*/
const removeAgentIdsFromProject = async function (projectId, agentIds) {
return await Project.findByIdAndUpdate(
return await db.models.Project.findByIdAndUpdate(
projectId,
{ $pull: { agentIds: { $in: agentIds } } },
{ new: true },
@ -119,7 +116,7 @@ const removeAgentIdsFromProject = async function (projectId, agentIds) {
* @returns {Promise<void>}
*/
const removeAgentFromAllProjects = async (agentId) => {
await Project.updateMany({}, { $pull: { agentIds: agentId } });
await db.models.Project.updateMany({}, { $pull: { agentIds: agentId } });
};
module.exports = {

View file

@ -1,4 +1,3 @@
const mongoose = require('mongoose');
const { ObjectId } = require('mongodb');
const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider');
const {
@ -7,12 +6,9 @@ const {
removeGroupIdsFromProject,
removeGroupFromAllProjects,
} = require('./Project');
const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas');
const { escapeRegExp } = require('~/server/utils');
const { logger } = require('~/config');
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
const Prompt = mongoose.model('Prompt', promptSchema);
const db = require('~/lib/db/connectDb');
/**
* Create a pipeline for the aggregation to get prompt groups
@ -137,7 +133,7 @@ const getAllPromptGroups = async (req, filter) => {
}
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
return await db.models.PromptGroup.aggregate(promptGroupsPipeline).exec();
} catch (error) {
console.error('Error getting all prompt groups', error);
return { message: 'Error getting all prompt groups' };
@ -237,7 +233,7 @@ const deletePromptGroup = async ({ _id, author, role }) => {
throw new Error('Prompt group not found');
}
await Prompt.deleteMany(groupQuery);
await db.models.Prompt.deleteMany(groupQuery);
await removeGroupFromAllProjects(_id);
return { message: 'Prompt group deleted successfully' };
};
@ -254,6 +250,7 @@ module.exports = {
createPromptGroup: async (saveData) => {
try {
const { prompt, group, author, authorName } = saveData;
const { Prompt, PromptGroup } = db.models;
let newPromptGroup = await PromptGroup.findOneAndUpdate(
{ ...group, author, authorName, productionId: null },
@ -309,6 +306,7 @@ module.exports = {
/** @type {TPrompt} */
let newPrompt;
const { Prompt } = db.models;
try {
newPrompt = await Prompt.create(newPromptData);
} catch (error) {
@ -328,7 +326,7 @@ module.exports = {
},
getPrompts: async (filter) => {
try {
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
return await db.models.Prompt.find(filter).sort({ createdAt: -1 }).lean();
} catch (error) {
logger.error('Error getting prompts', error);
return { message: 'Error getting prompts' };
@ -339,7 +337,7 @@ module.exports = {
if (filter.groupId) {
filter.groupId = new ObjectId(filter.groupId);
}
return await Prompt.findOne(filter).lean();
return await db.models.Prompt.findOne(filter).lean();
} catch (error) {
logger.error('Error getting prompt', error);
return { message: 'Error getting prompt' };
@ -352,7 +350,7 @@ module.exports = {
*/
getRandomPromptGroups: async (filter) => {
try {
const result = await PromptGroup.aggregate([
const result = await db.models.PromptGroup.aggregate([
{
$match: {
category: { $ne: '' },
@ -385,7 +383,7 @@ module.exports = {
},
getPromptGroupsWithPrompts: async (filter) => {
try {
return await PromptGroup.findOne(filter)
return await db.models.PromptGroup.findOne(filter)
.populate({
path: 'prompts',
select: '-_id -__v -user',
@ -399,7 +397,7 @@ module.exports = {
},
getPromptGroup: async (filter) => {
try {
return await PromptGroup.findOne(filter).lean();
return await db.models.PromptGroup.findOne(filter).lean();
} catch (error) {
logger.error('Error getting prompt group', error);
return { message: 'Error getting prompt group' };
@ -420,6 +418,7 @@ module.exports = {
*/
deletePrompt: async ({ promptId, groupId, author, role }) => {
const query = { _id: promptId, groupId, author };
const { Prompt, PromptGroup } = db.models;
if (role === SystemRoles.ADMIN) {
delete query.author;
}
@ -484,7 +483,7 @@ module.exports = {
}
const updateData = { ...data, ...updateOps };
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
const updatedDoc = await db.models.PromptGroup.findOneAndUpdate(filter, updateData, {
new: true,
upsert: false,
});
@ -506,6 +505,7 @@ module.exports = {
*/
makePromptProduction: async (promptId) => {
try {
const { Prompt, PromptGroup } = db.models;
const prompt = await Prompt.findById(promptId).lean();
if (!prompt) {
@ -530,7 +530,7 @@ module.exports = {
},
updatePromptLabels: async (_id, labels) => {
try {
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
const response = await db.models.Prompt.updateOne({ _id }, { $set: { labels } });
if (response.matchedCount === 0) {
return { message: 'Prompt not found' };
}

View file

@ -1,4 +1,3 @@
const mongoose = require('mongoose');
const {
CacheKeys,
SystemRoles,
@ -8,10 +7,8 @@ const {
removeNullishValues,
} = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { roleSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const Role = mongoose.model('Role', roleSchema);
const db = require('~/lib/db/connectDb');
/**
* Retrieve a role by name and convert the found role document to a plain object.
@ -24,6 +21,7 @@ const Role = mongoose.model('Role', roleSchema);
*/
const getRoleByName = async function (roleName, fieldsToSelect = null) {
const cache = getLogStores(CacheKeys.ROLES);
const { Role } = db.models;
try {
const cachedRole = await cache.get(roleName);
if (cachedRole) {
@ -57,7 +55,7 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) {
const updateRoleByName = async function (roleName, updates) {
const cache = getLogStores(CacheKeys.ROLES);
try {
const role = await Role.findOneAndUpdate(
const role = await db.models.Role.findOneAndUpdate(
{ name: roleName },
{ $set: updates },
{ new: true, lean: true },
@ -78,6 +76,7 @@ const updateRoleByName = async function (roleName, updates) {
* @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
*/
async function updateAccessPermissions(roleName, permissionsUpdate) {
const { Role } = db.models;
// Filter and clean the permission updates based on our schema definition.
const updates = {};
for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
@ -181,6 +180,7 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
* @returns {Promise<void>}
*/
const initializeRoles = async function () {
const { Role } = db.models;
for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) {
let role = await Role.findOne({ name: roleName });
const defaultPerms = roleDefaults[roleName].permissions;
@ -210,6 +210,7 @@ const initializeRoles = async function () {
* @returns {Promise<number>} Number of roles migrated.
*/
const migrateRoleSchema = async function (roleName) {
const { Role } = db.models;
try {
// Get roles to migrate
let roles;
@ -282,7 +283,6 @@ const migrateRoleSchema = async function (roleName) {
};
module.exports = {
Role,
getRoleByName,
initializeRoles,
updateRoleByName,

View file

@ -6,9 +6,11 @@ const {
roleDefaults,
PermissionTypes,
} = require('librechat-data-provider');
const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
const { getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
const getLogStores = require('~/cache/getLogStores');
const db = require('~/lib/db/connectDb');
// Mock the cache
jest.mock('~/cache/getLogStores', () =>
jest.fn().mockReturnValue({
@ -19,11 +21,14 @@ jest.mock('~/cache/getLogStores', () =>
);
let mongoServer;
let Role;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
await db.connectDb(mongoUri);
Role = db.models.Role;
});
afterAll(async () => {

View file

@ -1,275 +0,0 @@
const mongoose = require('mongoose');
const signPayload = require('~/server/services/signPayload');
const { hashToken } = require('~/server/utils/crypto');
const { sessionSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const Session = mongoose.model('Session', sessionSchema);
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
/**
* Error class for Session-related errors
*/
class SessionError extends Error {
constructor(message, code = 'SESSION_ERROR') {
super(message);
this.name = 'SessionError';
this.code = code;
}
}
/**
* Creates a new session for a user
* @param {string} userId - The ID of the user
* @param {Object} options - Additional options for session creation
* @param {Date} options.expiration - Custom expiration date
* @returns {Promise<{session: Session, refreshToken: string}>}
* @throws {SessionError}
*/
const createSession = async (userId, options = {}) => {
if (!userId) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
try {
const session = new Session({
user: userId,
expiration: options.expiration || new Date(Date.now() + expires),
});
const refreshToken = await generateRefreshToken(session);
return { session, refreshToken };
} catch (error) {
logger.error('[createSession] Error creating session:', error);
throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED');
}
};
/**
* Finds a session by various parameters
* @param {Object} params - Search parameters
* @param {string} [params.refreshToken] - The refresh token to search by
* @param {string} [params.userId] - The user ID to search by
* @param {string} [params.sessionId] - The session ID to search by
* @param {Object} [options] - Additional options
* @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents
* @returns {Promise<Session|null>}
* @throws {SessionError}
*/
const findSession = async (params, options = { lean: true }) => {
try {
const query = {};
if (!params.refreshToken && !params.userId && !params.sessionId) {
throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS');
}
if (params.refreshToken) {
const tokenHash = await hashToken(params.refreshToken);
query.refreshTokenHash = tokenHash;
}
if (params.userId) {
query.user = params.userId;
}
if (params.sessionId) {
const sessionId = params.sessionId.sessionId || params.sessionId;
if (!mongoose.Types.ObjectId.isValid(sessionId)) {
throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID');
}
query._id = sessionId;
}
// Add expiration check to only return valid sessions
query.expiration = { $gt: new Date() };
const sessionQuery = Session.findOne(query);
if (options.lean) {
return await sessionQuery.lean();
}
return await sessionQuery.exec();
} catch (error) {
logger.error('[findSession] Error finding session:', error);
throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED');
}
};
/**
* Updates session expiration
* @param {Session|string} session - The session or session ID to update
* @param {Date} [newExpiration] - Optional new expiration date
* @returns {Promise<Session>}
* @throws {SessionError}
*/
const updateExpiration = async (session, newExpiration) => {
try {
const sessionDoc = typeof session === 'string' ? await Session.findById(session) : session;
if (!sessionDoc) {
throw new SessionError('Session not found', 'SESSION_NOT_FOUND');
}
sessionDoc.expiration = newExpiration || new Date(Date.now() + expires);
return await sessionDoc.save();
} catch (error) {
logger.error('[updateExpiration] Error updating session:', error);
throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED');
}
};
/**
* Deletes a session by refresh token or session ID
* @param {Object} params - Delete parameters
* @param {string} [params.refreshToken] - The refresh token of the session to delete
* @param {string} [params.sessionId] - The ID of the session to delete
* @returns {Promise<Object>}
* @throws {SessionError}
*/
const deleteSession = async (params) => {
try {
if (!params.refreshToken && !params.sessionId) {
throw new SessionError(
'Either refreshToken or sessionId is required',
'INVALID_DELETE_PARAMS',
);
}
const query = {};
if (params.refreshToken) {
query.refreshTokenHash = await hashToken(params.refreshToken);
}
if (params.sessionId) {
query._id = params.sessionId;
}
const result = await Session.deleteOne(query);
if (result.deletedCount === 0) {
logger.warn('[deleteSession] No session found to delete');
}
return result;
} catch (error) {
logger.error('[deleteSession] Error deleting session:', error);
throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED');
}
};
/**
* Deletes all sessions for a user
* @param {string} userId - The ID of the user
* @param {Object} [options] - Additional options
* @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session
* @param {string} [options.currentSessionId] - The ID of the current session to exclude
* @returns {Promise<Object>}
* @throws {SessionError}
*/
const deleteAllUserSessions = async (userId, options = {}) => {
try {
if (!userId) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
// Extract userId if it's passed as an object
const userIdString = userId.userId || userId;
if (!mongoose.Types.ObjectId.isValid(userIdString)) {
throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT');
}
const query = { user: userIdString };
if (options.excludeCurrentSession && options.currentSessionId) {
query._id = { $ne: options.currentSessionId };
}
const result = await Session.deleteMany(query);
if (result.deletedCount > 0) {
logger.debug(
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`,
);
}
return result;
} catch (error) {
logger.error('[deleteAllUserSessions] Error deleting user sessions:', error);
throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED');
}
};
/**
* Generates a refresh token for a session
* @param {Session} session - The session to generate a token for
* @returns {Promise<string>}
* @throws {SessionError}
*/
const generateRefreshToken = async (session) => {
if (!session || !session.user) {
throw new SessionError('Invalid session object', 'INVALID_SESSION');
}
try {
const expiresIn = session.expiration ? session.expiration.getTime() : Date.now() + expires;
if (!session.expiration) {
session.expiration = new Date(expiresIn);
}
const refreshToken = await signPayload({
payload: {
id: session.user,
sessionId: session._id,
},
secret: process.env.JWT_REFRESH_SECRET,
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
});
session.refreshTokenHash = await hashToken(refreshToken);
await session.save();
return refreshToken;
} catch (error) {
logger.error('[generateRefreshToken] Error generating refresh token:', error);
throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED');
}
};
/**
* Counts active sessions for a user
* @param {string} userId - The ID of the user
* @returns {Promise<number>}
* @throws {SessionError}
*/
const countActiveSessions = async (userId) => {
try {
if (!userId) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
return await Session.countDocuments({
user: userId,
expiration: { $gt: new Date() },
});
} catch (error) {
logger.error('[countActiveSessions] Error counting active sessions:', error);
throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED');
}
};
module.exports = {
createSession,
findSession,
updateExpiration,
deleteSession,
deleteAllUserSessions,
generateRefreshToken,
countActiveSessions,
SessionError,
};

View file

@ -1,9 +1,6 @@
const mongoose = require('mongoose');
const { nanoid } = require('nanoid');
const { Constants } = require('librechat-data-provider');
const { Conversation } = require('~/models/Conversation');
const { shareSchema } = require('@librechat/data-schemas');
const SharedLink = mongoose.model('SharedLink', shareSchema);
const db = require('~/lib/db/connectDb');
const { getMessages } = require('./Message');
const logger = require('~/config/winston');
@ -76,7 +73,7 @@ function anonymizeMessages(messages, newConvoId) {
async function getSharedMessages(shareId) {
try {
const share = await SharedLink.findOne({ shareId, isPublic: true })
const share = await db.models.SharedLink.findOne({ shareId, isPublic: true })
.populate({
path: 'messages',
select: '-_id -__v -user',
@ -151,7 +148,7 @@ async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortD
query.conversationId = { $in: query.conversationId };
}
const sharedLinks = await SharedLink.find(query)
const sharedLinks = await db.models.SharedLink.find(query)
.sort(sort)
.limit(pageSize + 1)
.select('-__v -user')
@ -184,7 +181,7 @@ async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortD
async function deleteAllSharedLinks(user) {
try {
const result = await SharedLink.deleteMany({ user });
const result = await db.models.SharedLink.deleteMany({ user });
return {
message: 'All shared links deleted successfully',
deletedCount: result.deletedCount,
@ -202,7 +199,7 @@ async function createSharedLink(user, conversationId) {
if (!user || !conversationId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
const { SharedLink, Conversation } = db.models;
try {
const [existingShare, conversationMessages] = await Promise.all([
SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(),
@ -244,7 +241,7 @@ async function getSharedLink(user, conversationId) {
}
try {
const share = await SharedLink.findOne({ conversationId, user, isPublic: true })
const share = await db.models.SharedLink.findOne({ conversationId, user, isPublic: true })
.select('shareId -_id')
.lean();
@ -268,6 +265,7 @@ async function updateSharedLink(user, shareId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
const { SharedLink } = db.models;
try {
const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean();
@ -318,7 +316,7 @@ async function deleteSharedLink(user, shareId) {
}
try {
const result = await SharedLink.findOneAndDelete({ shareId, user }).lean();
const result = await db.models.SharedLink.findOneAndDelete({ shareId, user }).lean();
if (!result) {
return null;
@ -340,7 +338,6 @@ async function deleteSharedLink(user, shareId) {
}
module.exports = {
SharedLink,
getSharedLink,
getSharedLinks,
createSharedLink,

View file

@ -1,13 +1,6 @@
const mongoose = require('mongoose');
const { encryptV2 } = require('~/server/utils/crypto');
const { tokenSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
/**
* Token model.
* @type {mongoose.Model}
*/
const Token = mongoose.model('Token', tokenSchema);
const db = require('~/lib/db/connectDb');
/**
* Fixes the indexes for the Token collection from legacy TTL indexes to the new expiresAt index.
*/
@ -20,7 +13,7 @@ async function fixIndexes() {
) {
return;
}
const indexes = await Token.collection.indexes();
const indexes = await db.models.Token.collection.indexes();
logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2));
const unwantedTTLIndexes = indexes.filter(
(index) => index.key.createdAt === 1 && index.expireAfterSeconds !== undefined,
@ -31,7 +24,7 @@ async function fixIndexes() {
}
for (const index of unwantedTTLIndexes) {
logger.debug(`Dropping unwanted Token index: ${index.name}`);
await Token.collection.dropIndex(index.name);
await db.models.Token.collection.dropIndex(index.name);
logger.debug(`Dropped Token index: ${index.name}`);
}
logger.debug('Token index cleanup completed successfully.');
@ -42,118 +35,6 @@ async function fixIndexes() {
fixIndexes();
/**
* Creates a new Token instance.
* @param {Object} tokenData - The data for the new Token.
* @param {mongoose.Types.ObjectId} tokenData.userId - The user's ID. It is required.
* @param {String} tokenData.email - The user's email.
* @param {String} tokenData.token - The token. It is required.
* @param {Number} tokenData.expiresIn - The number of seconds until the token expires.
* @returns {Promise<mongoose.Document>} The new Token instance.
* @throws Will throw an error if token creation fails.
*/
async function createToken(tokenData) {
try {
const currentTime = new Date();
const expiresAt = new Date(currentTime.getTime() + tokenData.expiresIn * 1000);
const newTokenData = {
...tokenData,
createdAt: currentTime,
expiresAt,
};
return await Token.create(newTokenData);
} catch (error) {
logger.debug('An error occurred while creating token:', error);
throw error;
}
}
/**
* Finds a Token document that matches the provided query.
* @param {Object} query - The query to match against.
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
* @param {String} query.token - The token value.
* @param {String} [query.email] - The email of the user.
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
* @returns {Promise<Object|null>} The matched Token document, or null if not found.
* @throws Will throw an error if the find operation fails.
*/
async function findToken(query) {
try {
const conditions = [];
if (query.userId) {
conditions.push({ userId: query.userId });
}
if (query.token) {
conditions.push({ token: query.token });
}
if (query.email) {
conditions.push({ email: query.email });
}
if (query.identifier) {
conditions.push({ identifier: query.identifier });
}
const token = await Token.findOne({
$and: conditions,
}).lean();
return token;
} catch (error) {
logger.debug('An error occurred while finding token:', error);
throw error;
}
}
/**
* Updates a Token document that matches the provided query.
* @param {Object} query - The query to match against.
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
* @param {String} query.token - The token value.
* @param {String} [query.email] - The email of the user.
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
* @param {Object} updateData - The data to update the Token with.
* @returns {Promise<mongoose.Document|null>} The updated Token document, or null if not found.
* @throws Will throw an error if the update operation fails.
*/
async function updateToken(query, updateData) {
try {
return await Token.findOneAndUpdate(query, updateData, { new: true });
} catch (error) {
logger.debug('An error occurred while updating token:', error);
throw error;
}
}
/**
* Deletes all Token documents that match the provided token, user ID, or email.
* @param {Object} query - The query to match against.
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
* @param {String} query.token - The token value.
* @param {String} [query.email] - The email of the user.
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
* @returns {Promise<Object>} The result of the delete operation.
* @throws Will throw an error if the delete operation fails.
*/
async function deleteTokens(query) {
try {
return await Token.deleteMany({
$or: [
{ userId: query.userId },
{ token: query.token },
{ email: query.email },
{ identifier: query.identifier },
],
});
} catch (error) {
logger.debug('An error occurred while deleting tokens:', error);
throw error;
}
}
/**
* Handles the OAuth token by creating or updating the token.
* @param {object} fields
@ -182,18 +63,15 @@ async function handleOAuthToken({
expiresIn: parseInt(expiresIn, 10) || 3600,
};
const existingToken = await findToken({ userId, identifier });
const {Token} = db.models;
const existingToken = await Token.findToken({ userId, identifier });
if (existingToken) {
return await updateToken({ identifier }, tokenData);
return await Token.updateToken({ identifier }, tokenData);
} else {
return await createToken(tokenData);
return await Token.createToken(tokenData);
}
}
module.exports = {
findToken,
createToken,
updateToken,
deleteTokens,
handleOAuthToken,
};

View file

@ -1,7 +1,4 @@
const mongoose = require('mongoose');
const { toolCallSchema } = require('@librechat/data-schemas');
const ToolCall = mongoose.model('ToolCall', toolCallSchema);
const db = require('~/lib/db/connectDb');
/**
* Create a new tool call
* @param {IToolCallData} toolCallData - The tool call data
@ -9,7 +6,7 @@ const ToolCall = mongoose.model('ToolCall', toolCallSchema);
*/
async function createToolCall(toolCallData) {
try {
return await ToolCall.create(toolCallData);
return await db.models.ToolCall.create(toolCallData);
} catch (error) {
throw new Error(`Error creating tool call: ${error.message}`);
}
@ -22,7 +19,7 @@ async function createToolCall(toolCallData) {
*/
async function getToolCallById(id) {
try {
return await ToolCall.findById(id).lean();
return await db.models.ToolCall.findById(id).lean();
} catch (error) {
throw new Error(`Error fetching tool call: ${error.message}`);
}
@ -36,7 +33,7 @@ async function getToolCallById(id) {
*/
async function getToolCallsByMessage(messageId, userId) {
try {
return await ToolCall.find({ messageId, user: userId }).lean();
return await db.models.ToolCall.find({ messageId, user: userId }).lean();
} catch (error) {
throw new Error(`Error fetching tool calls: ${error.message}`);
}
@ -50,7 +47,7 @@ async function getToolCallsByMessage(messageId, userId) {
*/
async function getToolCallsByConvo(conversationId, userId) {
try {
return await ToolCall.find({ conversationId, user: userId }).lean();
return await db.models.ToolCall.find({ conversationId, user: userId }).lean();
} catch (error) {
throw new Error(`Error fetching tool calls: ${error.message}`);
}
@ -64,7 +61,7 @@ async function getToolCallsByConvo(conversationId, userId) {
*/
async function updateToolCall(id, updateData) {
try {
return await ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
return await db.models.ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
} catch (error) {
throw new Error(`Error updating tool call: ${error.message}`);
}
@ -82,7 +79,7 @@ async function deleteToolCalls(userId, conversationId) {
if (conversationId) {
query.conversationId = conversationId;
}
return await ToolCall.deleteMany(query);
return await db.models.ToolCall.deleteMany(query);
} catch (error) {
throw new Error(`Error deleting tool call: ${error.message}`);
}

View file

@ -3,7 +3,7 @@ const { transactionSchema } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { logger } = require('~/config');
const Balance = require('./Balance');
const db = require('~/lib/db/connectDb');
const cancelRate = 1.15;
@ -23,6 +23,7 @@ const updateBalance = async ({ user, incrementValue, setValues }) => {
let maxRetries = 10; // Number of times to retry on conflict
let delay = 50; // Initial retry delay in ms
let lastError = null;
const { Balance } = db.models;
for (let attempt = 1; attempt <= maxRetries; attempt++) {
let currentBalanceDoc;
@ -140,19 +141,19 @@ const updateBalance = async ({ user, incrementValue, setValues }) => {
};
/** Method to calculate and set the tokenValue for a transaction */
transactionSchema.methods.calculateTokenValue = function () {
if (!this.valueKey || !this.tokenType) {
this.tokenValue = this.rawAmount;
function calculateTokenValue(txn) {
if (!txn.valueKey || !txn.tokenType) {
txn.tokenValue = txn.rawAmount;
}
const { valueKey, tokenType, model, endpointTokenConfig } = this;
const { valueKey, tokenType, model, endpointTokenConfig } = txn;
const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig }));
this.rate = multiplier;
this.tokenValue = this.rawAmount * multiplier;
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
this.rate *= cancelRate;
txn.rate = multiplier;
txn.tokenValue = txn.rawAmount * multiplier;
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
txn.rate *= cancelRate;
}
};
}
/**
* New static method to create an auto-refill transaction that does NOT trigger a balance update.
@ -163,13 +164,14 @@ transactionSchema.methods.calculateTokenValue = function () {
* @param {number} txData.rawAmount - The raw amount of tokens.
* @returns {Promise<object>} - The created transaction.
*/
transactionSchema.statics.createAutoRefillTransaction = async function (txData) {
async function createAutoRefillTransaction(txData) {
const Transaction = db.models.Transaction;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
const transaction = new this(txData);
const transaction = new Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
transaction.calculateTokenValue();
calculateTokenValue(transaction);
await transaction.save();
const balanceResponse = await updateBalance({
@ -185,21 +187,20 @@ transactionSchema.statics.createAutoRefillTransaction = async function (txData)
logger.debug('[Balance.check] Auto-refill performed', result);
result.transaction = transaction;
return result;
};
}
/**
* Static method to create a transaction and update the balance
* @param {txData} txData - Transaction data.
*/
transactionSchema.statics.create = async function (txData) {
const Transaction = this;
async function createTransaction(txData) {
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
const transaction = new Transaction(txData);
const transaction = new db.models.Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
transaction.calculateTokenValue();
calculateTokenValue(transaction);
await transaction.save();
@ -209,7 +210,6 @@ transactionSchema.statics.create = async function (txData) {
}
let incrementValue = transaction.tokenValue;
const balanceResponse = await updateBalance({
user: transaction.user,
incrementValue,
@ -221,21 +221,19 @@ transactionSchema.statics.create = async function (txData) {
balance: balanceResponse.tokenCredits,
[transaction.tokenType]: incrementValue,
};
};
}
/**
* Static method to create a structured transaction and update the balance
* @param {txData} txData - Transaction data.
*/
transactionSchema.statics.createStructured = async function (txData) {
const Transaction = this;
const transaction = new Transaction({
async function createStructuredTransaction(txData) {
const transaction = new db.models.Transaction({
...txData,
endpointTokenConfig: txData.endpointTokenConfig,
});
transaction.calculateStructuredTokenValue();
calculateStructuredTokenValue(transaction);
await transaction.save();
@ -257,71 +255,69 @@ transactionSchema.statics.createStructured = async function (txData) {
balance: balanceResponse.tokenCredits,
[transaction.tokenType]: incrementValue,
};
};
}
/** Method to calculate token value for structured tokens */
transactionSchema.methods.calculateStructuredTokenValue = function () {
if (!this.tokenType) {
this.tokenValue = this.rawAmount;
function calculateStructuredTokenValue(txn) {
if (!txn.tokenType) {
txn.tokenValue = txn.rawAmount;
return;
}
const { model, endpointTokenConfig } = this;
const { model, endpointTokenConfig } = txn;
if (this.tokenType === 'prompt') {
if (txn.tokenType === 'prompt') {
const inputMultiplier = getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig });
const writeMultiplier =
getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier;
const readMultiplier =
getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? inputMultiplier;
this.rateDetail = {
txn.rateDetail = {
input: inputMultiplier,
write: writeMultiplier,
read: readMultiplier,
};
const totalPromptTokens =
Math.abs(this.inputTokens || 0) +
Math.abs(this.writeTokens || 0) +
Math.abs(this.readTokens || 0);
Math.abs(txn.inputTokens || 0) +
Math.abs(txn.writeTokens || 0) +
Math.abs(txn.readTokens || 0);
if (totalPromptTokens > 0) {
this.rate =
(Math.abs(inputMultiplier * (this.inputTokens || 0)) +
Math.abs(writeMultiplier * (this.writeTokens || 0)) +
Math.abs(readMultiplier * (this.readTokens || 0))) /
txn.rate =
(Math.abs(inputMultiplier * (txn.inputTokens || 0)) +
Math.abs(writeMultiplier * (txn.writeTokens || 0)) +
Math.abs(readMultiplier * (txn.readTokens || 0))) /
totalPromptTokens;
} else {
this.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
txn.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
}
this.tokenValue = -(
Math.abs(this.inputTokens || 0) * inputMultiplier +
Math.abs(this.writeTokens || 0) * writeMultiplier +
Math.abs(this.readTokens || 0) * readMultiplier
txn.tokenValue = -(
Math.abs(txn.inputTokens || 0) * inputMultiplier +
Math.abs(txn.writeTokens || 0) * writeMultiplier +
Math.abs(txn.readTokens || 0) * readMultiplier
);
this.rawAmount = -totalPromptTokens;
} else if (this.tokenType === 'completion') {
const multiplier = getMultiplier({ tokenType: this.tokenType, model, endpointTokenConfig });
this.rate = Math.abs(multiplier);
this.tokenValue = -Math.abs(this.rawAmount) * multiplier;
this.rawAmount = -Math.abs(this.rawAmount);
txn.rawAmount = -totalPromptTokens;
} else if (txn.tokenType === 'completion') {
const multiplier = getMultiplier({ tokenType: txn.tokenType, model, endpointTokenConfig });
txn.rate = Math.abs(multiplier);
txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier;
txn.rawAmount = -Math.abs(txn.rawAmount);
}
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
this.rate *= cancelRate;
if (this.rateDetail) {
this.rateDetail = Object.fromEntries(
Object.entries(this.rateDetail).map(([k, v]) => [k, v * cancelRate]),
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
txn.rate *= cancelRate;
if (txn.rateDetail) {
txn.rateDetail = Object.fromEntries(
Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]),
);
}
}
};
const Transaction = mongoose.model('Transaction', transactionSchema);
}
/**
* Queries and retrieves transactions based on a given filter.
@ -333,11 +329,16 @@ const Transaction = mongoose.model('Transaction', transactionSchema);
*/
async function getTransactions(filter) {
try {
return await Transaction.find(filter).lean();
return await db.models.Transaction.find(filter).lean();
} catch (error) {
logger.error('Error querying transactions:', error);
throw error;
}
}
module.exports = { Transaction, getTransactions };
module.exports = {
getTransactions,
createTransaction,
createAutoRefillTransaction,
createStructuredTransaction,
};

View file

@ -3,18 +3,22 @@ const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
const db = require('~/lib/db/connectDb');
const { createTransaction } = require('./Transaction');
// Mock the custom config module so we can control the balance flag.
jest.mock('~/server/services/Config');
let mongoServer;
let Balance;
let Transaction;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
await db.connectDb(mongoUri);
Balance = db.models.Balance;
Transaction = db.models.Transaction;
});
afterAll(async () => {
@ -368,7 +372,7 @@ describe('NaN Handling Tests', () => {
};
// Act
const result = await Transaction.create(txData);
const result = await createTransaction(txData);
// Assert: No transaction should be created and balance remains unchanged.
expect(result).toBeUndefined();

View file

@ -1,6 +0,0 @@
const mongoose = require('mongoose');
const { userSchema } = require('@librechat/data-schemas');
const User = mongoose.model('User', userSchema);
module.exports = User;

View file

@ -1,9 +1,9 @@
const { ViolationTypes } = require('librechat-data-provider');
const { Transaction } = require('./Transaction');
const { createAutoRefillTransaction } = require('./Transaction');
const { logViolation } = require('~/cache');
const { getMultiplier } = require('./tx');
const { logger } = require('~/config');
const Balance = require('./Balance');
const db = require('~/lib/db/connectDb');
function isInvalidDate(date) {
return isNaN(date);
@ -26,7 +26,7 @@ const checkBalanceRecord = async function ({
const tokenCost = amount * multiplier;
// Retrieve the balance record
let record = await Balance.findOne({ user }).lean();
let record = await db.models.Balance.findOne({ user }).lean();
if (!record) {
logger.debug('[Balance.check] No balance record found for user', { user });
return {
@ -60,7 +60,7 @@ const checkBalanceRecord = async function ({
) {
try {
/** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
const result = await Transaction.createAutoRefillTransaction({
const result = await createAutoRefillTransaction({
user: user,
tokenType: 'credits',
context: 'autoRefill',

View file

@ -1,6 +1,7 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Message, getMessages, bulkSaveMessages } = require('./Message');
const { getMessages, bulkSaveMessages } = require('./Message');
const db = require('~/lib/db/connectDb');
// Original version of buildTree function
function buildTree({ messages, fileMap }) {
@ -42,11 +43,13 @@ function buildTree({ messages, fileMap }) {
}
let mongod;
let Message;
beforeAll(async () => {
mongod = await MongoMemoryServer.create();
const uri = mongod.getUri();
await mongoose.connect(uri);
await db.connectDb(uri);
Message = db.models.Message;
});
afterAll(async () => {

View file

@ -1,13 +1,4 @@
const {
comparePassword,
deleteUserById,
generateToken,
getUserById,
updateUser,
createUser,
countUsers,
findUser,
} = require('./userMethods');
const { comparePassword } = require('./userMethods');
const {
findFileById,
createFile,
@ -26,32 +17,11 @@ const {
deleteMessagesSince,
deleteMessages,
} = require('./Message');
const {
createSession,
findSession,
updateExpiration,
deleteSession,
deleteAllUserSessions,
generateRefreshToken,
countActiveSessions,
} = require('./Session');
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
const { createToken, findToken, updateToken, deleteTokens } = require('./Token');
const Balance = require('./Balance');
const User = require('./User');
const Key = require('./Key');
module.exports = {
comparePassword,
deleteUserById,
generateToken,
getUserById,
updateUser,
createUser,
countUsers,
findUser,
findFileById,
createFile,
updateFile,
@ -77,21 +47,4 @@ module.exports = {
getPresets,
savePreset,
deletePresets,
createToken,
findToken,
updateToken,
deleteTokens,
createSession,
findSession,
updateExpiration,
deleteSession,
deleteAllUserSessions,
generateRefreshToken,
countActiveSessions,
User,
Key,
Balance,
};

View file

@ -1,7 +1,7 @@
const mongoose = require('mongoose');
const { getRandomValues, hashToken } = require('~/server/utils/crypto');
const { createToken, findToken } = require('./Token');
const logger = require('~/config/winston');
const db = require('~/lib/db/connectDb');
/**
* @module inviteUser
@ -23,7 +23,7 @@ const createInvite = async (email) => {
const fakeUserId = new mongoose.Types.ObjectId();
await createToken({
await db.models.Token.createToken({
userId: fakeUserId,
email,
token: hash,
@ -50,7 +50,7 @@ const getInvite = async (encodedToken, email) => {
try {
const token = decodeURIComponent(encodedToken);
const hash = await hashToken(token);
const invite = await findToken({ token: hash, email });
const invite = await db.models.Token.findToken({ token: hash, email });
if (!invite) {
throw new Error('Invite not found or email does not match');

View file

@ -1,18 +0,0 @@
const mongoose = require('mongoose');
const mongoMeili = require('../plugins/mongoMeili');
const { convoSchema } = require('@librechat/data-schemas');
if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
convoSchema.plugin(mongoMeili, {
host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY,
/** Note: Will get created automatically if it doesn't exist already */
indexName: 'convos',
primaryKey: 'conversationId',
});
}
const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema);
module.exports = Conversation;

View file

@ -1,4 +1,3 @@
const mongoose = require('mongoose');
const mongoMeili = require('~/models/plugins/mongoMeili');
const { messageSchema } = require('@librechat/data-schemas');
@ -11,6 +10,6 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
});
}
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
// const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
module.exports = Message;

View file

@ -1,6 +0,0 @@
const mongoose = require('mongoose');
const { presetSchema } = require('@librechat/data-schemas');
const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema);
module.exports = Preset;

View file

@ -1,6 +1,6 @@
const { Transaction } = require('./Transaction');
const { logger } = require('~/config');
const db = require('~/lib/db/connectDb');
const { createTransaction, createStructuredTransaction } = require('./Transaction');
/**
* Creates up to two transactions to record the spending of tokens.
*
@ -33,7 +33,7 @@ const spendTokens = async (txData, tokenUsage) => {
let prompt, completion;
try {
if (promptTokens !== undefined) {
prompt = await Transaction.create({
prompt = await createTransaction({
...txData,
tokenType: 'prompt',
rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0),
@ -41,7 +41,7 @@ const spendTokens = async (txData, tokenUsage) => {
}
if (completionTokens !== undefined) {
completion = await Transaction.create({
completion = await createTransaction({
...txData,
tokenType: 'completion',
rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
@ -101,7 +101,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => {
try {
if (promptTokens) {
const { input = 0, write = 0, read = 0 } = promptTokens;
prompt = await Transaction.createStructured({
prompt = await createStructuredTransaction({
...txData,
tokenType: 'prompt',
inputTokens: -input,
@ -111,7 +111,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => {
}
if (completionTokens) {
completion = await Transaction.create({
completion = await createTransaction({
...txData,
tokenType: 'completion',
rawAmount: -completionTokens,

View file

@ -1,8 +1,8 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const db = require('~/lib/db/connectDb');
const { createTransaction, createAutoRefillTransaction } = require('./Transaction');
// Mock the logger to prevent console output during tests
jest.mock('~/config', () => ({
@ -20,10 +20,15 @@ describe('spendTokens', () => {
let mongoServer;
let userId;
let Transaction;
let Balance;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
await db.connectDb(mongoUri);
Balance = db.models.Balance;
Transaction = db.models.Transaction;
});
afterAll(async () => {
@ -197,7 +202,7 @@ describe('spendTokens', () => {
// Check that the transaction records show the adjusted values
const transactionResults = await Promise.all(
transactions.map((t) =>
Transaction.create({
createTransaction({
...txData,
tokenType: t.tokenType,
rawAmount: t.rawAmount,
@ -280,7 +285,7 @@ describe('spendTokens', () => {
// Check the return values from Transaction.create directly
// This is to verify that the incrementValue is not becoming positive
const directResult = await Transaction.create({
const directResult = await createTransaction({
user: userId,
conversationId: 'test-convo-3',
model: 'gpt-4',
@ -607,7 +612,7 @@ describe('spendTokens', () => {
const promises = [];
for (let i = 0; i < numberOfRefills; i++) {
promises.push(
Transaction.createAutoRefillTransaction({
createAutoRefillTransaction({
user: userId,
tokenType: 'credits',
context: 'concurrent-refill-test',

View file

@ -1,159 +1,4 @@
const bcrypt = require('bcryptjs');
const { getBalanceConfig } = require('~/server/services/Config');
const signPayload = require('~/server/services/signPayload');
const Balance = require('./Balance');
const User = require('./User');
/**
* Retrieve a user by ID and convert the found user document to a plain object.
*
* @param {string} userId - The ID of the user to find and return as a plain object.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
*/
const getUserById = async function (userId, fieldsToSelect = null) {
const query = User.findById(userId);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
/**
* Search for a single user based on partial data and return matching user document as plain object.
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
*/
const findUser = async function (searchCriteria, fieldsToSelect = null) {
const query = User.findOne(searchCriteria);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
/**
* Update a user with new data without overwriting existing properties.
*
* @param {string} userId - The ID of the user to update.
* @param {Object} updateData - An object containing the properties to update.
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
*/
const updateUser = async function (userId, updateData) {
const updateOperation = {
$set: updateData,
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
};
return await User.findByIdAndUpdate(userId, updateOperation, {
new: true,
runValidators: true,
}).lean();
};
/**
* Creates a new user, optionally with a TTL of 1 week.
* @param {MongoUser} data - The user data to be created, must contain user_id.
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
* @param {boolean} [returnUser=false] - Whether to return the created user object.
* @returns {Promise<ObjectId|MongoUser>} A promise that resolves to the created user document ID or user object.
* @throws {Error} If a user with the same user_id already exists.
*/
const createUser = async (data, disableTTL = true, returnUser = false) => {
const balance = await getBalanceConfig();
const userData = {
...data,
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
};
if (disableTTL) {
delete userData.expiresAt;
}
const user = await User.create(userData);
// If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance
if (balance?.enabled && balance?.startBalance) {
const update = {
$inc: { tokenCredits: balance.startBalance },
};
if (
balance.autoRefillEnabled &&
balance.refillIntervalValue != null &&
balance.refillIntervalUnit != null &&
balance.refillAmount != null
) {
update.$set = {
autoRefillEnabled: true,
refillIntervalValue: balance.refillIntervalValue,
refillIntervalUnit: balance.refillIntervalUnit,
refillAmount: balance.refillAmount,
};
}
await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean();
}
if (returnUser) {
return user.toObject();
}
return user._id;
};
/**
* Count the number of user documents in the collection based on the provided filter.
*
* @param {Object} [filter={}] - The filter to apply when counting the documents.
* @returns {Promise<number>} The count of documents that match the filter.
*/
const countUsers = async function (filter = {}) {
return await User.countDocuments(filter);
};
/**
* Delete a user by their unique ID.
*
* @param {string} userId - The ID of the user to delete.
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
*/
const deleteUserById = async function (userId) {
try {
const result = await User.deleteOne({ _id: userId });
if (result.deletedCount === 0) {
return { deletedCount: 0, message: 'No user found with that ID.' };
}
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
} catch (error) {
throw new Error('Error deleting user: ' + error.message);
}
};
const { SESSION_EXPIRY } = process.env ?? {};
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
/**
* Generates a JWT token for a given user.
*
* @param {MongoUser} user - The user for whom the token is being generated.
* @returns {Promise<string>} A promise that resolves to a JWT token.
*/
const generateToken = async (user) => {
if (!user) {
throw new Error('No user provided');
}
return await signPayload({
payload: {
id: user._id,
username: user.username,
provider: user.provider,
email: user.email,
},
secret: process.env.JWT_SECRET,
expirationTime: expires / 1000,
});
};
/**
* Compares the provided password with the user's password.
@ -179,11 +24,4 @@ const comparePassword = async (user, candidatePassword) => {
module.exports = {
comparePassword,
deleteUserById,
generateToken,
getUserById,
countUsers,
createUser,
updateUser,
findUser,
};