mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-20 10:20:15 +01:00
Move usermethods and models to data-schema
This commit is contained in:
parent
4808c5be48
commit
4049b5572c
93 changed files with 2396 additions and 1267 deletions
|
|
@ -10,18 +10,45 @@ const mockPluginService = {
|
|||
getUserPluginAuthValue: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('~/models/User', () => {
|
||||
return function () {
|
||||
return mockUser;
|
||||
const mockModels = {
|
||||
User: mockUser,
|
||||
};
|
||||
jest.mock('~/lib/db/connectDb', () => {
|
||||
return {
|
||||
connectDb: jest.fn(),
|
||||
get models() {
|
||||
return mockModels;
|
||||
},
|
||||
};
|
||||
});
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const userModelMock = {
|
||||
createUser: jest.fn(() => mockUser),
|
||||
findUser: jest.fn(),
|
||||
updateUser: jest.fn(),
|
||||
};
|
||||
return {
|
||||
registerModels: jest.fn().mockReturnValue({
|
||||
User: userModelMock,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/models/Message', () => ({
|
||||
Message: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
Conversation: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/File', () => ({
|
||||
File: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/PluginService', () => mockPluginService);
|
||||
|
||||
const { BaseLLM } = require('@langchain/openai');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
|
||||
const User = require('~/models/User');
|
||||
const PluginService = require('~/server/services/PluginService');
|
||||
const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools');
|
||||
const { StructuredSD, availableTools, DALLE3 } = require('../');
|
||||
|
|
@ -36,6 +63,9 @@ describe('Tool Handlers', () => {
|
|||
const mainPlugin = availableTools.find((tool) => tool.pluginKey === pluginKey);
|
||||
const authConfigs = mainPlugin.authConfig;
|
||||
|
||||
const { registerModels } = require('@librechat/data-schemas');
|
||||
let User = registerModels().User;
|
||||
|
||||
beforeAll(async () => {
|
||||
mockUser.save.mockResolvedValue(undefined);
|
||||
|
||||
|
|
@ -52,7 +82,7 @@ describe('Tool Handlers', () => {
|
|||
},
|
||||
);
|
||||
|
||||
fakeUser = new User({
|
||||
fakeUser = await User.createUser({
|
||||
name: 'Fake User',
|
||||
username: 'fakeuser',
|
||||
email: 'fakeuser@example.com',
|
||||
|
|
|
|||
6
api/cache/banViolation.js
vendored
6
api/cache/banViolation.js
vendored
|
|
@ -1,8 +1,8 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, math, removePorts } = require('~/server/utils');
|
||||
const { deleteAllUserSessions } = require('~/models');
|
||||
const getLogStores = require('./getLogStores');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
|
||||
const interval = math(BAN_INTERVAL, 20);
|
||||
|
|
@ -32,7 +32,6 @@ const banViolation = async (req, res, errorMessage) => {
|
|||
if (!isEnabled(BAN_VIOLATIONS)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!errorMessage) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -46,12 +45,11 @@ const banViolation = async (req, res, errorMessage) => {
|
|||
return;
|
||||
}
|
||||
|
||||
await deleteAllUserSessions({ userId: user_id });
|
||||
await db.models.Session.deleteAllUserSessions({ userId: user_id });
|
||||
res.clearCookie('refreshToken');
|
||||
|
||||
const banLogs = getLogStores(ViolationTypes.BAN);
|
||||
const duration = errorMessage.duration || banLogs.opts.ttl;
|
||||
|
||||
if (duration <= 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
35
api/cache/banViolation.spec.js
vendored
35
api/cache/banViolation.spec.js
vendored
|
|
@ -1,7 +1,40 @@
|
|||
const banViolation = require('./banViolation');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const sessionModelMock = {
|
||||
deleteAllUserSessions: jest.fn(),
|
||||
};
|
||||
|
||||
return {
|
||||
registerModels: jest.fn().mockReturnValue({
|
||||
Session: sessionModelMock,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const mockModels = {
|
||||
Session: {
|
||||
deleteAllUserSessions: jest.fn(),
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock('~/lib/db/connectDb', () => {
|
||||
return {
|
||||
connectDb: jest.fn(),
|
||||
get models() {
|
||||
return mockModels;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/server/utils', () => ({
|
||||
isEnabled: jest.fn(() => true), // default to false, override per test if needed
|
||||
math: jest.fn(() => 20), // default to false, override per test if needed
|
||||
removePorts: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('keyv');
|
||||
jest.mock('../models/Session');
|
||||
// jest.mock('../models/Session');
|
||||
// Mocking the getLogStores function
|
||||
jest.mock('./getLogStores', () => {
|
||||
return jest.fn().mockImplementation(() => {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
require('dotenv').config();
|
||||
const mongoose = require('mongoose');
|
||||
const MONGO_URI = process.env.MONGO_URI;
|
||||
const { registerModels } = require('@librechat/data-schemas');
|
||||
|
||||
if (!MONGO_URI) {
|
||||
if (!process.env.MONGO_URI) {
|
||||
throw new Error('Please define the MONGO_URI environment variable');
|
||||
}
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ if (!cached) {
|
|||
cached = global.mongoose = { conn: null, promise: null };
|
||||
}
|
||||
|
||||
async function connectDb() {
|
||||
async function connectDb(mongoUri = process.env.MONGO_URI) {
|
||||
if (cached.conn && cached.conn?._readyState === 1) {
|
||||
return cached.conn;
|
||||
}
|
||||
|
|
@ -34,12 +34,30 @@ async function connectDb() {
|
|||
};
|
||||
|
||||
mongoose.set('strictQuery', true);
|
||||
cached.promise = mongoose.connect(MONGO_URI, opts).then((mongoose) => {
|
||||
cached.promise = mongoose.connect(mongoUri, opts).then((mongoose) => {
|
||||
return mongoose;
|
||||
});
|
||||
}
|
||||
cached.conn = await cached.promise;
|
||||
|
||||
// Register models once
|
||||
if (!cached.models) {
|
||||
cached.models = registerModels(mongoose);
|
||||
}
|
||||
|
||||
return cached.conn;
|
||||
}
|
||||
|
||||
module.exports = connectDb;
|
||||
function getModels() {
|
||||
return cached.models;
|
||||
}
|
||||
module.exports = {
|
||||
connectDb,
|
||||
getModels,
|
||||
get models() {
|
||||
if (!cached.models) {
|
||||
throw new Error('Models not registered. ');
|
||||
}
|
||||
return cached.models;
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const connectDb = require('./connectDb');
|
||||
const { connectDb, getModels} = require('./connectDb');
|
||||
const indexSync = require('./indexSync');
|
||||
|
||||
module.exports = { connectDb, indexSync };
|
||||
module.exports = { connectDb, getModels, indexSync };
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
const { MeiliSearch } = require('meilisearch');
|
||||
const { Conversation } = require('~/models/Conversation');
|
||||
const { Message } = require('~/models/Message');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const searchEnabled = isEnabled(process.env.SEARCH);
|
||||
const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC);
|
||||
|
|
@ -29,7 +28,7 @@ async function indexSync() {
|
|||
if (!searchEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { Message, Conversation } = db.models;
|
||||
try {
|
||||
const client = MeiliSearchClient.getInstance();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
const mongoose = require('mongoose');
|
||||
const crypto = require('node:crypto');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
|
||||
require('librechat-data-provider').Constants;
|
||||
|
|
@ -15,7 +14,7 @@ const getLogStores = require('~/cache/getLogStores');
|
|||
const { getActions } = require('./Action');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const Agent = mongoose.model('agent', agentSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Create an agent with the provided data.
|
||||
|
|
@ -36,7 +35,7 @@ const createAgent = async (agentData) => {
|
|||
},
|
||||
],
|
||||
};
|
||||
return (await Agent.create(initialAgentData)).toObject();
|
||||
return (await db.models.Agent.create(initialAgentData)).toObject();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -47,7 +46,7 @@ const createAgent = async (agentData) => {
|
|||
* @param {string} searchParameter.author - The user ID of the agent's author.
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
|
||||
const getAgent = async (searchParameter) => await db.models.Agent.findOne(searchParameter).lean();
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID
|
||||
|
|
@ -269,6 +268,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
|||
const { updatingUserId = null, forceVersion = false } = options;
|
||||
const mongoOptions = { new: true, upsert: false };
|
||||
|
||||
const Agent = db.models?.Agent;
|
||||
const currentAgent = await Agent.findOne(searchParameter);
|
||||
if (currentAgent) {
|
||||
const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject();
|
||||
|
|
@ -362,6 +362,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
|||
* @returns {Promise<Agent>} The updated agent.
|
||||
*/
|
||||
const addAgentResourceFile = async ({ req, agent_id, tool_resource, file_id }) => {
|
||||
const Agent = db.models?.Agent;
|
||||
const searchParameter = { id: agent_id };
|
||||
let agent = await getAgent(searchParameter);
|
||||
if (!agent) {
|
||||
|
|
@ -427,7 +428,7 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
|||
}
|
||||
|
||||
const updatePullData = { $pull: pullOps };
|
||||
const agentAfterPull = await Agent.findOneAndUpdate(searchParameter, updatePullData, {
|
||||
const agentAfterPull = await db.models.Agent.findOneAndUpdate(searchParameter, updatePullData, {
|
||||
new: true,
|
||||
}).lean();
|
||||
|
||||
|
|
@ -457,7 +458,7 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
|||
* @returns {Promise<void>} Resolves when the agent has been successfully deleted.
|
||||
*/
|
||||
const deleteAgent = async (searchParameter) => {
|
||||
const agent = await Agent.findOneAndDelete(searchParameter);
|
||||
const agent = await db.models.Agent.findOneAndDelete(searchParameter);
|
||||
if (agent) {
|
||||
await removeAgentFromAllProjects(agent.id);
|
||||
}
|
||||
|
|
@ -481,9 +482,8 @@ const getListAgents = async (searchParameter) => {
|
|||
delete globalQuery.author;
|
||||
query = { $or: [globalQuery, query] };
|
||||
}
|
||||
|
||||
const agents = (
|
||||
await Agent.find(query, {
|
||||
await db.models.Agent.find(query, {
|
||||
id: 1,
|
||||
_id: 0,
|
||||
name: 1,
|
||||
|
|
@ -580,6 +580,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds
|
|||
* @throws {Error} If the agent is not found or the specified version does not exist.
|
||||
*/
|
||||
const revertAgentVersion = async (searchParameter, versionIndex) => {
|
||||
const Agent = db.models?.Agent;
|
||||
const agent = await Agent.findOne(searchParameter);
|
||||
if (!agent) {
|
||||
throw new Error('Agent not found');
|
||||
|
|
@ -662,7 +663,6 @@ const generateActionMetadataHash = async (actionIds, actions) => {
|
|||
*/
|
||||
|
||||
module.exports = {
|
||||
Agent,
|
||||
getAgent,
|
||||
loadAgent,
|
||||
createAgent,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ const mongoose = require('mongoose');
|
|||
const { v4: uuidv4 } = require('uuid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
Agent,
|
||||
addAgentResourceFile,
|
||||
removeAgentResourceFiles,
|
||||
createAgent,
|
||||
|
|
@ -20,6 +19,9 @@ const {
|
|||
getListAgents,
|
||||
updateAgentProjects,
|
||||
} = require('./Agent');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
let Agent;
|
||||
|
||||
describe('Agent Resource File Operations', () => {
|
||||
let mongoServer;
|
||||
|
|
@ -27,7 +29,9 @@ describe('Agent Resource File Operations', () => {
|
|||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
await db.connectDb(mongoUri);
|
||||
|
||||
Agent = db.models.Agent;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
|
|
@ -55,6 +59,7 @@ describe('Agent Resource File Operations', () => {
|
|||
|
||||
test('should add tool_resource to tools if missing', async () => {
|
||||
const agent = await createBasicAgent();
|
||||
|
||||
const fileId = uuidv4();
|
||||
const toolResource = 'file_search';
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,4 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { assistantSchema } = require('@librechat/data-schemas');
|
||||
|
||||
const Assistant = mongoose.model('assistant', assistantSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Update an assistant with new data without overwriting existing properties,
|
||||
|
|
@ -15,7 +12,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
|
|||
*/
|
||||
const updateAssistantDoc = async (searchParams, updateData) => {
|
||||
const options = { new: true, upsert: true };
|
||||
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
return await db.models.Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -26,7 +23,7 @@ const updateAssistantDoc = async (searchParams, updateData) => {
|
|||
* @param {string} searchParams.user - The user ID of the assistant's author.
|
||||
* @returns {Promise<AssistantDocument|null>} The assistant document as a plain object, or null if not found.
|
||||
*/
|
||||
const getAssistant = async (searchParams) => await Assistant.findOne(searchParams).lean();
|
||||
const getAssistant = async (searchParams) => await db.models.Assistant.findOne(searchParams).lean();
|
||||
|
||||
/**
|
||||
* Retrieves all assistants that match the given search parameters.
|
||||
|
|
@ -36,7 +33,7 @@ const getAssistant = async (searchParams) => await Assistant.findOne(searchParam
|
|||
* @returns {Promise<Array<AssistantDocument>>} A promise that resolves to an array of assistant documents as plain objects.
|
||||
*/
|
||||
const getAssistants = async (searchParams, select = null) => {
|
||||
let query = Assistant.find(searchParams);
|
||||
let query = db.models.Assistant.find(searchParams);
|
||||
|
||||
if (select) {
|
||||
query = query.select(select);
|
||||
|
|
@ -54,7 +51,7 @@ const getAssistants = async (searchParams, select = null) => {
|
|||
* @returns {Promise<void>} Resolves when the assistant has been successfully deleted.
|
||||
*/
|
||||
const deleteAssistant = async (searchParams) => {
|
||||
return await Assistant.findOneAndDelete(searchParams);
|
||||
return await db.models.Assistant.findOneAndDelete(searchParams);
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { balanceSchema } = require('@librechat/data-schemas');
|
||||
|
||||
module.exports = mongoose.model('Balance', balanceSchema);
|
||||
|
|
@ -1,8 +1,5 @@
|
|||
const mongoose = require('mongoose');
|
||||
const logger = require('~/config/winston');
|
||||
const { bannerSchema } = require('@librechat/data-schemas');
|
||||
|
||||
const Banner = mongoose.model('Banner', bannerSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Retrieves the current active banner.
|
||||
|
|
@ -11,7 +8,7 @@ const Banner = mongoose.model('Banner', bannerSchema);
|
|||
const getBanner = async (user) => {
|
||||
try {
|
||||
const now = new Date();
|
||||
const banner = await Banner.findOne({
|
||||
const banner = await db.models.Banner.findOne({
|
||||
displayFrom: { $lte: now },
|
||||
$or: [{ displayTo: { $gte: now } }, { displayTo: null }],
|
||||
type: 'banner',
|
||||
|
|
@ -28,4 +25,4 @@ const getBanner = async (user) => {
|
|||
}
|
||||
};
|
||||
|
||||
module.exports = { Banner, getBanner };
|
||||
module.exports = { getBanner };
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const Conversation = require('./schema/convoSchema');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
|
|
@ -9,7 +8,7 @@ const logger = require('~/config/winston');
|
|||
*/
|
||||
const searchConversation = async (conversationId) => {
|
||||
try {
|
||||
return await Conversation.findOne({ conversationId }, 'conversationId user').lean();
|
||||
return await db.models.Conversation.findOne({ conversationId }, 'conversationId user').lean();
|
||||
} catch (error) {
|
||||
logger.error('[searchConversation] Error searching conversation', error);
|
||||
throw new Error('Error searching conversation');
|
||||
|
|
@ -24,7 +23,7 @@ const searchConversation = async (conversationId) => {
|
|||
*/
|
||||
const getConvo = async (user, conversationId) => {
|
||||
try {
|
||||
return await Conversation.findOne({ user, conversationId }).lean();
|
||||
return await db.models.Conversation.findOne({ user, conversationId }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getConvo] Error getting single conversation', error);
|
||||
return { message: 'Error getting single conversation' };
|
||||
|
|
@ -41,7 +40,7 @@ const deleteNullOrEmptyConversations = async () => {
|
|||
],
|
||||
};
|
||||
|
||||
const result = await Conversation.deleteMany(filter);
|
||||
const result = await db.models.Conversation.deleteMany(filter);
|
||||
|
||||
// Delete associated messages
|
||||
const messageDeleteResult = await deleteMessages(filter);
|
||||
|
|
@ -67,7 +66,7 @@ const deleteNullOrEmptyConversations = async () => {
|
|||
*/
|
||||
const getConvoFiles = async (conversationId) => {
|
||||
try {
|
||||
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
|
||||
return (await db.models.Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
|
||||
} catch (error) {
|
||||
logger.error('[getConvoFiles] Error getting conversation files', error);
|
||||
throw new Error('Error getting conversation files');
|
||||
|
|
@ -75,7 +74,6 @@ const getConvoFiles = async (conversationId) => {
|
|||
};
|
||||
|
||||
module.exports = {
|
||||
Conversation,
|
||||
getConvoFiles,
|
||||
searchConversation,
|
||||
deleteNullOrEmptyConversations,
|
||||
|
|
@ -114,7 +112,7 @@ module.exports = {
|
|||
}
|
||||
|
||||
/** Note: the resulting Model object is necessary for Meilisearch operations */
|
||||
const conversation = await Conversation.findOneAndUpdate(
|
||||
const conversation = await db.models.Conversation.findOneAndUpdate(
|
||||
{ conversationId, user: req.user.id },
|
||||
updateOperation,
|
||||
{
|
||||
|
|
@ -143,7 +141,7 @@ module.exports = {
|
|||
},
|
||||
}));
|
||||
|
||||
const result = await Conversation.bulkWrite(bulkOps);
|
||||
const result = await db.models.Conversation.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[saveBulkConversations] Error saving conversations in bulk', error);
|
||||
|
|
@ -155,7 +153,7 @@ module.exports = {
|
|||
{ cursor, limit = 25, isArchived = false, tags, search, order = 'desc' } = {},
|
||||
) => {
|
||||
const filters = [{ user }];
|
||||
|
||||
const { Conversation } = db.models;
|
||||
if (isArchived) {
|
||||
filters.push({ isArchived: true });
|
||||
} else {
|
||||
|
|
@ -219,7 +217,7 @@ module.exports = {
|
|||
|
||||
const conversationIds = convoIds.map((convo) => convo.conversationId);
|
||||
|
||||
const results = await Conversation.find({
|
||||
const results = await db.models.Conversation.find({
|
||||
user,
|
||||
conversationId: { $in: conversationIds },
|
||||
$or: [{ expiredAt: { $exists: false } }, { expiredAt: null }],
|
||||
|
|
@ -288,7 +286,7 @@ module.exports = {
|
|||
deleteConvos: async (user, filter) => {
|
||||
try {
|
||||
const userFilter = { ...filter, user };
|
||||
|
||||
const { Conversation } = db.models;
|
||||
const conversations = await Conversation.find(userFilter).select('conversationId');
|
||||
const conversationIds = conversations.map((c) => c.conversationId);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
const mongoose = require('mongoose');
|
||||
const Conversation = require('./schema/convoSchema');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const { conversationTagSchema } = require('@librechat/data-schemas');
|
||||
|
||||
const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Retrieves all conversation tags for a user.
|
||||
|
|
@ -13,7 +8,7 @@ const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema)
|
|||
*/
|
||||
const getConversationTags = async (user) => {
|
||||
try {
|
||||
return await ConversationTag.find({ user }).sort({ position: 1 }).lean();
|
||||
return await db.models.ConversationTag.find({ user }).sort({ position: 1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getConversationTags] Error getting conversation tags', error);
|
||||
throw new Error('Error getting conversation tags');
|
||||
|
|
@ -34,6 +29,7 @@ const createConversationTag = async (user, data) => {
|
|||
try {
|
||||
const { tag, description, addToConversation, conversationId } = data;
|
||||
|
||||
const { ConversationTag, Conversation } = db.models;
|
||||
const existingTag = await ConversationTag.findOne({ user, tag }).lean();
|
||||
if (existingTag) {
|
||||
return existingTag;
|
||||
|
|
@ -88,6 +84,7 @@ const updateConversationTag = async (user, oldTag, data) => {
|
|||
try {
|
||||
const { tag: newTag, description, position } = data;
|
||||
|
||||
const { ConversationTag, Conversation } = db.models;
|
||||
const existingTag = await ConversationTag.findOne({ user, tag: oldTag }).lean();
|
||||
if (!existingTag) {
|
||||
return null;
|
||||
|
|
@ -140,15 +137,15 @@ const adjustPositions = async (user, oldPosition, newPosition) => {
|
|||
const position =
|
||||
oldPosition < newPosition
|
||||
? {
|
||||
$gt: Math.min(oldPosition, newPosition),
|
||||
$lte: Math.max(oldPosition, newPosition),
|
||||
}
|
||||
$gt: Math.min(oldPosition, newPosition),
|
||||
$lte: Math.max(oldPosition, newPosition),
|
||||
}
|
||||
: {
|
||||
$gte: Math.min(oldPosition, newPosition),
|
||||
$lt: Math.max(oldPosition, newPosition),
|
||||
};
|
||||
$gte: Math.min(oldPosition, newPosition),
|
||||
$lt: Math.max(oldPosition, newPosition),
|
||||
};
|
||||
|
||||
await ConversationTag.updateMany(
|
||||
await db.models.ConversationTag.updateMany(
|
||||
{
|
||||
user,
|
||||
position,
|
||||
|
|
@ -165,6 +162,7 @@ const adjustPositions = async (user, oldPosition, newPosition) => {
|
|||
*/
|
||||
const deleteConversationTag = async (user, tag) => {
|
||||
try {
|
||||
const { ConversationTag, Conversation } = db.models;
|
||||
const deletedTag = await ConversationTag.findOneAndDelete({ user, tag }).lean();
|
||||
if (!deletedTag) {
|
||||
return null;
|
||||
|
|
@ -193,6 +191,7 @@ const deleteConversationTag = async (user, tag) => {
|
|||
*/
|
||||
const updateTagsForConversation = async (user, conversationId, tags) => {
|
||||
try {
|
||||
const { ConversationTag, Conversation } = db.models;
|
||||
const conversation = await Conversation.findOne({ user, conversationId }).lean();
|
||||
if (!conversation) {
|
||||
throw new Error('Conversation not found');
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { EToolResources } = require('librechat-data-provider');
|
||||
const { fileSchema } = require('@librechat/data-schemas');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const File = mongoose.model('File', fileSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Finds a file by its file_id with additional query options.
|
||||
|
|
@ -12,7 +10,7 @@ const File = mongoose.model('File', fileSchema);
|
|||
* @returns {Promise<MongoFile>} A promise that resolves to the file document or null.
|
||||
*/
|
||||
const findFileById = async (file_id, options = {}) => {
|
||||
return await File.findOne({ file_id, ...options }).lean();
|
||||
return await db.models.File.findOne({ file_id, ...options }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -25,7 +23,7 @@ const findFileById = async (file_id, options = {}) => {
|
|||
*/
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
|
||||
const sortOptions = { updatedAt: -1, ..._sortOptions };
|
||||
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
return await db.models.File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -81,7 +79,7 @@ const createFile = async (data, disableTTL) => {
|
|||
delete fileData.expiresAt;
|
||||
}
|
||||
|
||||
return await File.findOneAndUpdate({ file_id: data.file_id }, fileData, {
|
||||
return await db.models.File.findOneAndUpdate({ file_id: data.file_id }, fileData, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
}).lean();
|
||||
|
|
@ -98,7 +96,7 @@ const updateFile = async (data) => {
|
|||
$set: update,
|
||||
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
|
||||
};
|
||||
return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
|
||||
return await db.models.File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -112,7 +110,7 @@ const updateFileUsage = async (data) => {
|
|||
$inc: { usage: inc },
|
||||
$unset: { expiresAt: '', temp_file_id: '' },
|
||||
};
|
||||
return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
|
||||
return await db.models.File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -121,7 +119,7 @@ const updateFileUsage = async (data) => {
|
|||
* @returns {Promise<MongoFile>} A promise that resolves to the deleted file document or null.
|
||||
*/
|
||||
const deleteFile = async (file_id) => {
|
||||
return await File.findOneAndDelete({ file_id }).lean();
|
||||
return await db.models.File.findOneAndDelete({ file_id }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -130,7 +128,7 @@ const deleteFile = async (file_id) => {
|
|||
* @returns {Promise<MongoFile>} A promise that resolves to the deleted file document or null.
|
||||
*/
|
||||
const deleteFileByFilter = async (filter) => {
|
||||
return await File.findOneAndDelete(filter).lean();
|
||||
return await db.models.File.findOneAndDelete(filter).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -143,7 +141,7 @@ const deleteFiles = async (file_ids, user) => {
|
|||
if (user) {
|
||||
deleteQuery = { user: user };
|
||||
}
|
||||
return await File.deleteMany(deleteQuery);
|
||||
return await db.models.File.deleteMany(deleteQuery);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -169,7 +167,6 @@ async function batchUpdateFiles(updates) {
|
|||
}
|
||||
|
||||
module.exports = {
|
||||
File,
|
||||
findFileById,
|
||||
getFiles,
|
||||
getToolFilesByIds,
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { keySchema } = require('@librechat/data-schemas');
|
||||
|
||||
module.exports = mongoose.model('Key', keySchema);
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
const { z } = require('zod');
|
||||
const Message = require('./schema/messageSchema');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
/**
|
||||
|
|
@ -68,8 +67,7 @@ async function saveMessage(req, params, metadata) {
|
|||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.tokenCount = 0;
|
||||
}
|
||||
|
||||
const message = await Message.findOneAndUpdate(
|
||||
const message = await db.models.Message.findOneAndUpdate(
|
||||
{ messageId: params.messageId, user: req.user.id },
|
||||
update,
|
||||
{ upsert: true, new: true },
|
||||
|
|
@ -87,7 +85,7 @@ async function saveMessage(req, params, metadata) {
|
|||
|
||||
try {
|
||||
// Try to find the existing message with this ID
|
||||
const existingMessage = await Message.findOne({
|
||||
const existingMessage = await db.models.Message.findOne({
|
||||
messageId: params.messageId,
|
||||
user: req.user.id,
|
||||
});
|
||||
|
|
@ -140,8 +138,7 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) {
|
|||
upsert: true,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
const result = await db.models.Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
|
|
@ -183,7 +180,7 @@ async function recordMessage({
|
|||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
return await db.models.Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
|
|
@ -207,7 +204,7 @@ async function recordMessage({
|
|||
*/
|
||||
async function updateMessageText(req, { messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId, user: req.user.id }, { text });
|
||||
await db.models?.Message.updateOne({ messageId, user: req.user.id }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw err;
|
||||
|
|
@ -235,7 +232,7 @@ async function updateMessageText(req, { messageId, text }) {
|
|||
async function updateMessage(req, message, metadata) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
const updatedMessage = await Message.findOneAndUpdate(
|
||||
const updatedMessage = await db.models.Message.findOneAndUpdate(
|
||||
{ messageId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
|
|
@ -279,10 +276,10 @@ async function updateMessage(req, message, metadata) {
|
|||
*/
|
||||
async function deleteMessagesSince(req, { messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
|
||||
const message = await db.models.Message.findOne({ messageId, user: req.user.id }).lean();
|
||||
|
||||
if (message) {
|
||||
const query = Message.find({ conversationId, user: req.user.id });
|
||||
const query = db.models.Message.find({ conversationId, user: req.user.id });
|
||||
return await query.deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
|
|
@ -306,10 +303,10 @@ async function deleteMessagesSince(req, { messageId, conversationId }) {
|
|||
async function getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
return await db.models.Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
return await db.models.Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
|
|
@ -326,7 +323,7 @@ async function getMessages(filter, select) {
|
|||
*/
|
||||
async function getMessage({ user, messageId }) {
|
||||
try {
|
||||
return await Message.findOne({
|
||||
return await db.models.Message.findOne({
|
||||
user,
|
||||
messageId,
|
||||
}).lean();
|
||||
|
|
@ -347,7 +344,7 @@ async function getMessage({ user, messageId }) {
|
|||
*/
|
||||
async function deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
return await db.models.Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
|
|
@ -355,7 +352,6 @@ async function deleteMessages(filter) {
|
|||
}
|
||||
|
||||
module.exports = {
|
||||
Message,
|
||||
saveMessage,
|
||||
bulkSaveMessages,
|
||||
recordMessage,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
jest.mock('mongoose');
|
||||
|
||||
|
|
@ -20,14 +21,28 @@ const mockSchema = {
|
|||
deleteMany: jest.fn(),
|
||||
};
|
||||
|
||||
mongoose.model.mockReturnValue(mockSchema);
|
||||
|
||||
jest.mock('~/models/schema/messageSchema', () => mockSchema);
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
}));
|
||||
|
||||
const mockModels = {
|
||||
Message: {
|
||||
findOneAndUpdate: mockSchema.findOneAndUpdate,
|
||||
updateOne: mockSchema.updateOne,
|
||||
findOne: mockSchema.findOne,
|
||||
find: mockSchema.find,
|
||||
deleteMany: mockSchema.deleteMany,
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock('~/lib/db/connectDb', () => {
|
||||
return {
|
||||
get models() {
|
||||
return mockModels;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
|
|
@ -153,7 +168,7 @@ describe('Message Operations', () => {
|
|||
});
|
||||
|
||||
describe('Conversation Hijacking Prevention', () => {
|
||||
it('should not allow editing a message in another user\'s conversation', async () => {
|
||||
it("should not allow editing a message in another user's conversation", async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
|
@ -175,7 +190,7 @@ describe('Message Operations', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should not allow deleting messages from another user\'s conversation', async () => {
|
||||
it("should not allow deleting messages from another user's conversation", async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
|
@ -193,7 +208,7 @@ describe('Message Operations', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('should not allow inserting a new message into another user\'s conversation', async () => {
|
||||
it("should not allow inserting a new message into another user's conversation", async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = uuidv4(); // Use a valid UUID
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
const Preset = require('./schema/presetSchema');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const getPreset = async (user, presetId) => {
|
||||
try {
|
||||
return await Preset.findOne({ user, presetId }).lean();
|
||||
return await db.models.Preset.findOne({ user, presetId }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getPreset] Error getting single preset', error);
|
||||
return { message: 'Error getting single preset' };
|
||||
|
|
@ -11,11 +11,10 @@ const getPreset = async (user, presetId) => {
|
|||
};
|
||||
|
||||
module.exports = {
|
||||
Preset,
|
||||
getPreset,
|
||||
getPresets: async (user, filter) => {
|
||||
try {
|
||||
const presets = await Preset.find({ ...filter, user }).lean();
|
||||
const presets = await db.models.Preset.find({ ...filter, user }).lean();
|
||||
const defaultValue = 10000;
|
||||
|
||||
presets.sort((a, b) => {
|
||||
|
|
@ -40,6 +39,7 @@ module.exports = {
|
|||
const setter = { $set: {} };
|
||||
const { user: _, ...cleanPreset } = preset;
|
||||
const update = { presetId, ...cleanPreset };
|
||||
const Preset = db.models.Preset;
|
||||
if (preset.tools && Array.isArray(preset.tools)) {
|
||||
update.tools =
|
||||
preset.tools
|
||||
|
|
@ -77,7 +77,7 @@ module.exports = {
|
|||
deletePresets: async (user, filter) => {
|
||||
// let toRemove = await Preset.find({ ...filter, user }).select('presetId');
|
||||
// const ids = toRemove.map((instance) => instance.presetId);
|
||||
let deleteCount = await Preset.deleteMany({ ...filter, user });
|
||||
let deleteCount = await db.models.Preset.deleteMany({ ...filter, user });
|
||||
return deleteCount;
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
const { model } = require('mongoose');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const { projectSchema } = require('@librechat/data-schemas');
|
||||
|
||||
const Project = model('Project', projectSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||
|
|
@ -12,7 +9,7 @@ const Project = model('Project', projectSchema);
|
|||
* @returns {Promise<IMongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||
*/
|
||||
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||
const query = Project.findById(projectId);
|
||||
const query = db.models.Project.findById(projectId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
|
|
@ -39,7 +36,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
|||
select: fieldsToSelect,
|
||||
};
|
||||
|
||||
return await Project.findOneAndUpdate(query, update, options);
|
||||
return await db.models.Project.findOneAndUpdate(query, update, options);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -50,7 +47,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
|||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
return await db.models.Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||
{ new: true },
|
||||
|
|
@ -65,7 +62,7 @@ const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
|||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
return await db.models.Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||
{ new: true },
|
||||
|
|
@ -79,7 +76,7 @@ const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
await db.models.Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -90,7 +87,7 @@ const removeGroupFromAllProjects = async (promptGroupId) => {
|
|||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const addAgentIdsToProject = async function (projectId, agentIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
return await db.models.Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { agentIds: { $each: agentIds } } },
|
||||
{ new: true },
|
||||
|
|
@ -105,7 +102,7 @@ const addAgentIdsToProject = async function (projectId, agentIds) {
|
|||
* @returns {Promise<IMongoProject>} The updated project document.
|
||||
*/
|
||||
const removeAgentIdsFromProject = async function (projectId, agentIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
return await db.models.Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { agentIds: { $in: agentIds } } },
|
||||
{ new: true },
|
||||
|
|
@ -119,7 +116,7 @@ const removeAgentIdsFromProject = async function (projectId, agentIds) {
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeAgentFromAllProjects = async (agentId) => {
|
||||
await Project.updateMany({}, { $pull: { agentIds: agentId } });
|
||||
await db.models.Project.updateMany({}, { $pull: { agentIds: agentId } });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -7,12 +6,9 @@ const {
|
|||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
} = require('./Project');
|
||||
const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas');
|
||||
const { escapeRegExp } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
|
||||
const Prompt = mongoose.model('Prompt', promptSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get prompt groups
|
||||
|
|
@ -137,7 +133,7 @@ const getAllPromptGroups = async (req, filter) => {
|
|||
}
|
||||
|
||||
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
return await db.models.PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
|
|
@ -237,7 +233,7 @@ const deletePromptGroup = async ({ _id, author, role }) => {
|
|||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
await Prompt.deleteMany(groupQuery);
|
||||
await db.models.Prompt.deleteMany(groupQuery);
|
||||
await removeGroupFromAllProjects(_id);
|
||||
return { message: 'Prompt group deleted successfully' };
|
||||
};
|
||||
|
|
@ -254,6 +250,7 @@ module.exports = {
|
|||
createPromptGroup: async (saveData) => {
|
||||
try {
|
||||
const { prompt, group, author, authorName } = saveData;
|
||||
const { Prompt, PromptGroup } = db.models;
|
||||
|
||||
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||
{ ...group, author, authorName, productionId: null },
|
||||
|
|
@ -309,6 +306,7 @@ module.exports = {
|
|||
|
||||
/** @type {TPrompt} */
|
||||
let newPrompt;
|
||||
const { Prompt } = db.models;
|
||||
try {
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
} catch (error) {
|
||||
|
|
@ -328,7 +326,7 @@ module.exports = {
|
|||
},
|
||||
getPrompts: async (filter) => {
|
||||
try {
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
return await db.models.Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { message: 'Error getting prompts' };
|
||||
|
|
@ -339,7 +337,7 @@ module.exports = {
|
|||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
return await db.models.Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
|
|
@ -352,7 +350,7 @@ module.exports = {
|
|||
*/
|
||||
getRandomPromptGroups: async (filter) => {
|
||||
try {
|
||||
const result = await PromptGroup.aggregate([
|
||||
const result = await db.models.PromptGroup.aggregate([
|
||||
{
|
||||
$match: {
|
||||
category: { $ne: '' },
|
||||
|
|
@ -385,7 +383,7 @@ module.exports = {
|
|||
},
|
||||
getPromptGroupsWithPrompts: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter)
|
||||
return await db.models.PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
|
|
@ -399,7 +397,7 @@ module.exports = {
|
|||
},
|
||||
getPromptGroup: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
return await db.models.PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
|
|
@ -420,6 +418,7 @@ module.exports = {
|
|||
*/
|
||||
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||
const query = { _id: promptId, groupId, author };
|
||||
const { Prompt, PromptGroup } = db.models;
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
|
|
@ -484,7 +483,7 @@ module.exports = {
|
|||
}
|
||||
|
||||
const updateData = { ...data, ...updateOps };
|
||||
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
const updatedDoc = await db.models.PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
|
|
@ -506,6 +505,7 @@ module.exports = {
|
|||
*/
|
||||
makePromptProduction: async (promptId) => {
|
||||
try {
|
||||
const { Prompt, PromptGroup } = db.models;
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
|
|
@ -530,7 +530,7 @@ module.exports = {
|
|||
},
|
||||
updatePromptLabels: async (_id, labels) => {
|
||||
try {
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
const response = await db.models.Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
const mongoose = require('mongoose');
|
||||
const {
|
||||
CacheKeys,
|
||||
SystemRoles,
|
||||
|
|
@ -8,10 +7,8 @@ const {
|
|||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { roleSchema } = require('@librechat/data-schemas');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const Role = mongoose.model('Role', roleSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
|
|
@ -24,6 +21,7 @@ const Role = mongoose.model('Role', roleSchema);
|
|||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const { Role } = db.models;
|
||||
try {
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
|
|
@ -57,7 +55,7 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
|||
const updateRoleByName = async function (roleName, updates) {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
try {
|
||||
const role = await Role.findOneAndUpdate(
|
||||
const role = await db.models.Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
{ new: true, lean: true },
|
||||
|
|
@ -78,6 +76,7 @@ const updateRoleByName = async function (roleName, updates) {
|
|||
* @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
|
||||
*/
|
||||
async function updateAccessPermissions(roleName, permissionsUpdate) {
|
||||
const { Role } = db.models;
|
||||
// Filter and clean the permission updates based on our schema definition.
|
||||
const updates = {};
|
||||
for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
|
||||
|
|
@ -181,6 +180,7 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
|
|||
* @returns {Promise<void>}
|
||||
*/
|
||||
const initializeRoles = async function () {
|
||||
const { Role } = db.models;
|
||||
for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) {
|
||||
let role = await Role.findOne({ name: roleName });
|
||||
const defaultPerms = roleDefaults[roleName].permissions;
|
||||
|
|
@ -210,6 +210,7 @@ const initializeRoles = async function () {
|
|||
* @returns {Promise<number>} Number of roles migrated.
|
||||
*/
|
||||
const migrateRoleSchema = async function (roleName) {
|
||||
const { Role } = db.models;
|
||||
try {
|
||||
// Get roles to migrate
|
||||
let roles;
|
||||
|
|
@ -282,7 +283,6 @@ const migrateRoleSchema = async function (roleName) {
|
|||
};
|
||||
|
||||
module.exports = {
|
||||
Role,
|
||||
getRoleByName,
|
||||
initializeRoles,
|
||||
updateRoleByName,
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@ const {
|
|||
roleDefaults,
|
||||
PermissionTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
|
||||
const { getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
// Mock the cache
|
||||
jest.mock('~/cache/getLogStores', () =>
|
||||
jest.fn().mockReturnValue({
|
||||
|
|
@ -19,11 +21,14 @@ jest.mock('~/cache/getLogStores', () =>
|
|||
);
|
||||
|
||||
let mongoServer;
|
||||
let Role;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
await db.connectDb(mongoUri);
|
||||
|
||||
Role = db.models.Role;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
|
|
|
|||
|
|
@ -1,275 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const signPayload = require('~/server/services/signPayload');
|
||||
const { hashToken } = require('~/server/utils/crypto');
|
||||
const { sessionSchema } = require('@librechat/data-schemas');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const Session = mongoose.model('Session', sessionSchema);
|
||||
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
|
||||
/**
|
||||
* Error class for Session-related errors
|
||||
*/
|
||||
class SessionError extends Error {
|
||||
constructor(message, code = 'SESSION_ERROR') {
|
||||
super(message);
|
||||
this.name = 'SessionError';
|
||||
this.code = code;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new session for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @param {Object} options - Additional options for session creation
|
||||
* @param {Date} options.expiration - Custom expiration date
|
||||
* @returns {Promise<{session: Session, refreshToken: string}>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const createSession = async (userId, options = {}) => {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
try {
|
||||
const session = new Session({
|
||||
user: userId,
|
||||
expiration: options.expiration || new Date(Date.now() + expires),
|
||||
});
|
||||
const refreshToken = await generateRefreshToken(session);
|
||||
return { session, refreshToken };
|
||||
} catch (error) {
|
||||
logger.error('[createSession] Error creating session:', error);
|
||||
throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Finds a session by various parameters
|
||||
* @param {Object} params - Search parameters
|
||||
* @param {string} [params.refreshToken] - The refresh token to search by
|
||||
* @param {string} [params.userId] - The user ID to search by
|
||||
* @param {string} [params.sessionId] - The session ID to search by
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents
|
||||
* @returns {Promise<Session|null>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const findSession = async (params, options = { lean: true }) => {
|
||||
try {
|
||||
const query = {};
|
||||
|
||||
if (!params.refreshToken && !params.userId && !params.sessionId) {
|
||||
throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS');
|
||||
}
|
||||
|
||||
if (params.refreshToken) {
|
||||
const tokenHash = await hashToken(params.refreshToken);
|
||||
query.refreshTokenHash = tokenHash;
|
||||
}
|
||||
|
||||
if (params.userId) {
|
||||
query.user = params.userId;
|
||||
}
|
||||
|
||||
if (params.sessionId) {
|
||||
const sessionId = params.sessionId.sessionId || params.sessionId;
|
||||
if (!mongoose.Types.ObjectId.isValid(sessionId)) {
|
||||
throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID');
|
||||
}
|
||||
query._id = sessionId;
|
||||
}
|
||||
|
||||
// Add expiration check to only return valid sessions
|
||||
query.expiration = { $gt: new Date() };
|
||||
|
||||
const sessionQuery = Session.findOne(query);
|
||||
|
||||
if (options.lean) {
|
||||
return await sessionQuery.lean();
|
||||
}
|
||||
|
||||
return await sessionQuery.exec();
|
||||
} catch (error) {
|
||||
logger.error('[findSession] Error finding session:', error);
|
||||
throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates session expiration
|
||||
* @param {Session|string} session - The session or session ID to update
|
||||
* @param {Date} [newExpiration] - Optional new expiration date
|
||||
* @returns {Promise<Session>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const updateExpiration = async (session, newExpiration) => {
|
||||
try {
|
||||
const sessionDoc = typeof session === 'string' ? await Session.findById(session) : session;
|
||||
|
||||
if (!sessionDoc) {
|
||||
throw new SessionError('Session not found', 'SESSION_NOT_FOUND');
|
||||
}
|
||||
|
||||
sessionDoc.expiration = newExpiration || new Date(Date.now() + expires);
|
||||
return await sessionDoc.save();
|
||||
} catch (error) {
|
||||
logger.error('[updateExpiration] Error updating session:', error);
|
||||
throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes a session by refresh token or session ID
|
||||
* @param {Object} params - Delete parameters
|
||||
* @param {string} [params.refreshToken] - The refresh token of the session to delete
|
||||
* @param {string} [params.sessionId] - The ID of the session to delete
|
||||
* @returns {Promise<Object>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const deleteSession = async (params) => {
|
||||
try {
|
||||
if (!params.refreshToken && !params.sessionId) {
|
||||
throw new SessionError(
|
||||
'Either refreshToken or sessionId is required',
|
||||
'INVALID_DELETE_PARAMS',
|
||||
);
|
||||
}
|
||||
|
||||
const query = {};
|
||||
|
||||
if (params.refreshToken) {
|
||||
query.refreshTokenHash = await hashToken(params.refreshToken);
|
||||
}
|
||||
|
||||
if (params.sessionId) {
|
||||
query._id = params.sessionId;
|
||||
}
|
||||
|
||||
const result = await Session.deleteOne(query);
|
||||
|
||||
if (result.deletedCount === 0) {
|
||||
logger.warn('[deleteSession] No session found to delete');
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[deleteSession] Error deleting session:', error);
|
||||
throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes all sessions for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session
|
||||
* @param {string} [options.currentSessionId] - The ID of the current session to exclude
|
||||
* @returns {Promise<Object>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const deleteAllUserSessions = async (userId, options = {}) => {
|
||||
try {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
// Extract userId if it's passed as an object
|
||||
const userIdString = userId.userId || userId;
|
||||
|
||||
if (!mongoose.Types.ObjectId.isValid(userIdString)) {
|
||||
throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT');
|
||||
}
|
||||
|
||||
const query = { user: userIdString };
|
||||
|
||||
if (options.excludeCurrentSession && options.currentSessionId) {
|
||||
query._id = { $ne: options.currentSessionId };
|
||||
}
|
||||
|
||||
const result = await Session.deleteMany(query);
|
||||
|
||||
if (result.deletedCount > 0) {
|
||||
logger.debug(
|
||||
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`,
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllUserSessions] Error deleting user sessions:', error);
|
||||
throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates a refresh token for a session
|
||||
* @param {Session} session - The session to generate a token for
|
||||
* @returns {Promise<string>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const generateRefreshToken = async (session) => {
|
||||
if (!session || !session.user) {
|
||||
throw new SessionError('Invalid session object', 'INVALID_SESSION');
|
||||
}
|
||||
|
||||
try {
|
||||
const expiresIn = session.expiration ? session.expiration.getTime() : Date.now() + expires;
|
||||
|
||||
if (!session.expiration) {
|
||||
session.expiration = new Date(expiresIn);
|
||||
}
|
||||
|
||||
const refreshToken = await signPayload({
|
||||
payload: {
|
||||
id: session.user,
|
||||
sessionId: session._id,
|
||||
},
|
||||
secret: process.env.JWT_REFRESH_SECRET,
|
||||
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
|
||||
});
|
||||
|
||||
session.refreshTokenHash = await hashToken(refreshToken);
|
||||
await session.save();
|
||||
|
||||
return refreshToken;
|
||||
} catch (error) {
|
||||
logger.error('[generateRefreshToken] Error generating refresh token:', error);
|
||||
throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Counts active sessions for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @returns {Promise<number>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const countActiveSessions = async (userId) => {
|
||||
try {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
return await Session.countDocuments({
|
||||
user: userId,
|
||||
expiration: { $gt: new Date() },
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[countActiveSessions] Error counting active sessions:', error);
|
||||
throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
SessionError,
|
||||
};
|
||||
|
|
@ -1,9 +1,6 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { Conversation } = require('~/models/Conversation');
|
||||
const { shareSchema } = require('@librechat/data-schemas');
|
||||
const SharedLink = mongoose.model('SharedLink', shareSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { getMessages } = require('./Message');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
|
|
@ -76,7 +73,7 @@ function anonymizeMessages(messages, newConvoId) {
|
|||
|
||||
async function getSharedMessages(shareId) {
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId, isPublic: true })
|
||||
const share = await db.models.SharedLink.findOne({ shareId, isPublic: true })
|
||||
.populate({
|
||||
path: 'messages',
|
||||
select: '-_id -__v -user',
|
||||
|
|
@ -151,7 +148,7 @@ async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortD
|
|||
query.conversationId = { $in: query.conversationId };
|
||||
}
|
||||
|
||||
const sharedLinks = await SharedLink.find(query)
|
||||
const sharedLinks = await db.models.SharedLink.find(query)
|
||||
.sort(sort)
|
||||
.limit(pageSize + 1)
|
||||
.select('-__v -user')
|
||||
|
|
@ -184,7 +181,7 @@ async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortD
|
|||
|
||||
async function deleteAllSharedLinks(user) {
|
||||
try {
|
||||
const result = await SharedLink.deleteMany({ user });
|
||||
const result = await db.models.SharedLink.deleteMany({ user });
|
||||
return {
|
||||
message: 'All shared links deleted successfully',
|
||||
deletedCount: result.deletedCount,
|
||||
|
|
@ -202,7 +199,7 @@ async function createSharedLink(user, conversationId) {
|
|||
if (!user || !conversationId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
const { SharedLink, Conversation } = db.models;
|
||||
try {
|
||||
const [existingShare, conversationMessages] = await Promise.all([
|
||||
SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(),
|
||||
|
|
@ -244,7 +241,7 @@ async function getSharedLink(user, conversationId) {
|
|||
}
|
||||
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId, user, isPublic: true })
|
||||
const share = await db.models.SharedLink.findOne({ conversationId, user, isPublic: true })
|
||||
.select('shareId -_id')
|
||||
.lean();
|
||||
|
||||
|
|
@ -268,6 +265,7 @@ async function updateSharedLink(user, shareId) {
|
|||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
const { SharedLink } = db.models;
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean();
|
||||
|
||||
|
|
@ -318,7 +316,7 @@ async function deleteSharedLink(user, shareId) {
|
|||
}
|
||||
|
||||
try {
|
||||
const result = await SharedLink.findOneAndDelete({ shareId, user }).lean();
|
||||
const result = await db.models.SharedLink.findOneAndDelete({ shareId, user }).lean();
|
||||
|
||||
if (!result) {
|
||||
return null;
|
||||
|
|
@ -340,7 +338,6 @@ async function deleteSharedLink(user, shareId) {
|
|||
}
|
||||
|
||||
module.exports = {
|
||||
SharedLink,
|
||||
getSharedLink,
|
||||
getSharedLinks,
|
||||
createSharedLink,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,6 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { encryptV2 } = require('~/server/utils/crypto');
|
||||
const { tokenSchema } = require('@librechat/data-schemas');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Token model.
|
||||
* @type {mongoose.Model}
|
||||
*/
|
||||
const Token = mongoose.model('Token', tokenSchema);
|
||||
const db = require('~/lib/db/connectDb');
|
||||
/**
|
||||
* Fixes the indexes for the Token collection from legacy TTL indexes to the new expiresAt index.
|
||||
*/
|
||||
|
|
@ -20,7 +13,7 @@ async function fixIndexes() {
|
|||
) {
|
||||
return;
|
||||
}
|
||||
const indexes = await Token.collection.indexes();
|
||||
const indexes = await db.models.Token.collection.indexes();
|
||||
logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2));
|
||||
const unwantedTTLIndexes = indexes.filter(
|
||||
(index) => index.key.createdAt === 1 && index.expireAfterSeconds !== undefined,
|
||||
|
|
@ -31,7 +24,7 @@ async function fixIndexes() {
|
|||
}
|
||||
for (const index of unwantedTTLIndexes) {
|
||||
logger.debug(`Dropping unwanted Token index: ${index.name}`);
|
||||
await Token.collection.dropIndex(index.name);
|
||||
await db.models.Token.collection.dropIndex(index.name);
|
||||
logger.debug(`Dropped Token index: ${index.name}`);
|
||||
}
|
||||
logger.debug('Token index cleanup completed successfully.');
|
||||
|
|
@ -42,118 +35,6 @@ async function fixIndexes() {
|
|||
|
||||
fixIndexes();
|
||||
|
||||
/**
|
||||
* Creates a new Token instance.
|
||||
* @param {Object} tokenData - The data for the new Token.
|
||||
* @param {mongoose.Types.ObjectId} tokenData.userId - The user's ID. It is required.
|
||||
* @param {String} tokenData.email - The user's email.
|
||||
* @param {String} tokenData.token - The token. It is required.
|
||||
* @param {Number} tokenData.expiresIn - The number of seconds until the token expires.
|
||||
* @returns {Promise<mongoose.Document>} The new Token instance.
|
||||
* @throws Will throw an error if token creation fails.
|
||||
*/
|
||||
async function createToken(tokenData) {
|
||||
try {
|
||||
const currentTime = new Date();
|
||||
const expiresAt = new Date(currentTime.getTime() + tokenData.expiresIn * 1000);
|
||||
|
||||
const newTokenData = {
|
||||
...tokenData,
|
||||
createdAt: currentTime,
|
||||
expiresAt,
|
||||
};
|
||||
|
||||
return await Token.create(newTokenData);
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while creating token:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a Token document that matches the provided query.
|
||||
* @param {Object} query - The query to match against.
|
||||
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
|
||||
* @param {String} query.token - The token value.
|
||||
* @param {String} [query.email] - The email of the user.
|
||||
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
|
||||
* @returns {Promise<Object|null>} The matched Token document, or null if not found.
|
||||
* @throws Will throw an error if the find operation fails.
|
||||
*/
|
||||
async function findToken(query) {
|
||||
try {
|
||||
const conditions = [];
|
||||
|
||||
if (query.userId) {
|
||||
conditions.push({ userId: query.userId });
|
||||
}
|
||||
if (query.token) {
|
||||
conditions.push({ token: query.token });
|
||||
}
|
||||
if (query.email) {
|
||||
conditions.push({ email: query.email });
|
||||
}
|
||||
if (query.identifier) {
|
||||
conditions.push({ identifier: query.identifier });
|
||||
}
|
||||
|
||||
const token = await Token.findOne({
|
||||
$and: conditions,
|
||||
}).lean();
|
||||
|
||||
return token;
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while finding token:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a Token document that matches the provided query.
|
||||
* @param {Object} query - The query to match against.
|
||||
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
|
||||
* @param {String} query.token - The token value.
|
||||
* @param {String} [query.email] - The email of the user.
|
||||
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
|
||||
* @param {Object} updateData - The data to update the Token with.
|
||||
* @returns {Promise<mongoose.Document|null>} The updated Token document, or null if not found.
|
||||
* @throws Will throw an error if the update operation fails.
|
||||
*/
|
||||
async function updateToken(query, updateData) {
|
||||
try {
|
||||
return await Token.findOneAndUpdate(query, updateData, { new: true });
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while updating token:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all Token documents that match the provided token, user ID, or email.
|
||||
* @param {Object} query - The query to match against.
|
||||
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
|
||||
* @param {String} query.token - The token value.
|
||||
* @param {String} [query.email] - The email of the user.
|
||||
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
|
||||
* @returns {Promise<Object>} The result of the delete operation.
|
||||
* @throws Will throw an error if the delete operation fails.
|
||||
*/
|
||||
async function deleteTokens(query) {
|
||||
try {
|
||||
return await Token.deleteMany({
|
||||
$or: [
|
||||
{ userId: query.userId },
|
||||
{ token: query.token },
|
||||
{ email: query.email },
|
||||
{ identifier: query.identifier },
|
||||
],
|
||||
});
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while deleting tokens:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the OAuth token by creating or updating the token.
|
||||
* @param {object} fields
|
||||
|
|
@ -182,18 +63,15 @@ async function handleOAuthToken({
|
|||
expiresIn: parseInt(expiresIn, 10) || 3600,
|
||||
};
|
||||
|
||||
const existingToken = await findToken({ userId, identifier });
|
||||
const {Token} = db.models;
|
||||
const existingToken = await Token.findToken({ userId, identifier });
|
||||
if (existingToken) {
|
||||
return await updateToken({ identifier }, tokenData);
|
||||
return await Token.updateToken({ identifier }, tokenData);
|
||||
} else {
|
||||
return await createToken(tokenData);
|
||||
return await Token.createToken(tokenData);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
findToken,
|
||||
createToken,
|
||||
updateToken,
|
||||
deleteTokens,
|
||||
handleOAuthToken,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,7 +1,4 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { toolCallSchema } = require('@librechat/data-schemas');
|
||||
const ToolCall = mongoose.model('ToolCall', toolCallSchema);
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
/**
|
||||
* Create a new tool call
|
||||
* @param {IToolCallData} toolCallData - The tool call data
|
||||
|
|
@ -9,7 +6,7 @@ const ToolCall = mongoose.model('ToolCall', toolCallSchema);
|
|||
*/
|
||||
async function createToolCall(toolCallData) {
|
||||
try {
|
||||
return await ToolCall.create(toolCallData);
|
||||
return await db.models.ToolCall.create(toolCallData);
|
||||
} catch (error) {
|
||||
throw new Error(`Error creating tool call: ${error.message}`);
|
||||
}
|
||||
|
|
@ -22,7 +19,7 @@ async function createToolCall(toolCallData) {
|
|||
*/
|
||||
async function getToolCallById(id) {
|
||||
try {
|
||||
return await ToolCall.findById(id).lean();
|
||||
return await db.models.ToolCall.findById(id).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool call: ${error.message}`);
|
||||
}
|
||||
|
|
@ -36,7 +33,7 @@ async function getToolCallById(id) {
|
|||
*/
|
||||
async function getToolCallsByMessage(messageId, userId) {
|
||||
try {
|
||||
return await ToolCall.find({ messageId, user: userId }).lean();
|
||||
return await db.models.ToolCall.find({ messageId, user: userId }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool calls: ${error.message}`);
|
||||
}
|
||||
|
|
@ -50,7 +47,7 @@ async function getToolCallsByMessage(messageId, userId) {
|
|||
*/
|
||||
async function getToolCallsByConvo(conversationId, userId) {
|
||||
try {
|
||||
return await ToolCall.find({ conversationId, user: userId }).lean();
|
||||
return await db.models.ToolCall.find({ conversationId, user: userId }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool calls: ${error.message}`);
|
||||
}
|
||||
|
|
@ -64,7 +61,7 @@ async function getToolCallsByConvo(conversationId, userId) {
|
|||
*/
|
||||
async function updateToolCall(id, updateData) {
|
||||
try {
|
||||
return await ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
|
||||
return await db.models.ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error updating tool call: ${error.message}`);
|
||||
}
|
||||
|
|
@ -82,7 +79,7 @@ async function deleteToolCalls(userId, conversationId) {
|
|||
if (conversationId) {
|
||||
query.conversationId = conversationId;
|
||||
}
|
||||
return await ToolCall.deleteMany(query);
|
||||
return await db.models.ToolCall.deleteMany(query);
|
||||
} catch (error) {
|
||||
throw new Error(`Error deleting tool call: ${error.message}`);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ const { transactionSchema } = require('@librechat/data-schemas');
|
|||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { logger } = require('~/config');
|
||||
const Balance = require('./Balance');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const cancelRate = 1.15;
|
||||
|
||||
|
|
@ -23,6 +23,7 @@ const updateBalance = async ({ user, incrementValue, setValues }) => {
|
|||
let maxRetries = 10; // Number of times to retry on conflict
|
||||
let delay = 50; // Initial retry delay in ms
|
||||
let lastError = null;
|
||||
const { Balance } = db.models;
|
||||
|
||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
||||
let currentBalanceDoc;
|
||||
|
|
@ -140,19 +141,19 @@ const updateBalance = async ({ user, incrementValue, setValues }) => {
|
|||
};
|
||||
|
||||
/** Method to calculate and set the tokenValue for a transaction */
|
||||
transactionSchema.methods.calculateTokenValue = function () {
|
||||
if (!this.valueKey || !this.tokenType) {
|
||||
this.tokenValue = this.rawAmount;
|
||||
function calculateTokenValue(txn) {
|
||||
if (!txn.valueKey || !txn.tokenType) {
|
||||
txn.tokenValue = txn.rawAmount;
|
||||
}
|
||||
const { valueKey, tokenType, model, endpointTokenConfig } = this;
|
||||
const { valueKey, tokenType, model, endpointTokenConfig } = txn;
|
||||
const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig }));
|
||||
this.rate = multiplier;
|
||||
this.tokenValue = this.rawAmount * multiplier;
|
||||
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
|
||||
this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
|
||||
this.rate *= cancelRate;
|
||||
txn.rate = multiplier;
|
||||
txn.tokenValue = txn.rawAmount * multiplier;
|
||||
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
||||
txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
|
||||
txn.rate *= cancelRate;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* New static method to create an auto-refill transaction that does NOT trigger a balance update.
|
||||
|
|
@ -163,13 +164,14 @@ transactionSchema.methods.calculateTokenValue = function () {
|
|||
* @param {number} txData.rawAmount - The raw amount of tokens.
|
||||
* @returns {Promise<object>} - The created transaction.
|
||||
*/
|
||||
transactionSchema.statics.createAutoRefillTransaction = async function (txData) {
|
||||
async function createAutoRefillTransaction(txData) {
|
||||
const Transaction = db.models.Transaction;
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
const transaction = new this(txData);
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.calculateTokenValue();
|
||||
calculateTokenValue(transaction);
|
||||
await transaction.save();
|
||||
|
||||
const balanceResponse = await updateBalance({
|
||||
|
|
@ -185,21 +187,20 @@ transactionSchema.statics.createAutoRefillTransaction = async function (txData)
|
|||
logger.debug('[Balance.check] Auto-refill performed', result);
|
||||
result.transaction = transaction;
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Static method to create a transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
*/
|
||||
transactionSchema.statics.create = async function (txData) {
|
||||
const Transaction = this;
|
||||
async function createTransaction(txData) {
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction(txData);
|
||||
const transaction = new db.models.Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.calculateTokenValue();
|
||||
calculateTokenValue(transaction);
|
||||
|
||||
await transaction.save();
|
||||
|
||||
|
|
@ -209,7 +210,6 @@ transactionSchema.statics.create = async function (txData) {
|
|||
}
|
||||
|
||||
let incrementValue = transaction.tokenValue;
|
||||
|
||||
const balanceResponse = await updateBalance({
|
||||
user: transaction.user,
|
||||
incrementValue,
|
||||
|
|
@ -221,21 +221,19 @@ transactionSchema.statics.create = async function (txData) {
|
|||
balance: balanceResponse.tokenCredits,
|
||||
[transaction.tokenType]: incrementValue,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Static method to create a structured transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
*/
|
||||
transactionSchema.statics.createStructured = async function (txData) {
|
||||
const Transaction = this;
|
||||
|
||||
const transaction = new Transaction({
|
||||
async function createStructuredTransaction(txData) {
|
||||
const transaction = new db.models.Transaction({
|
||||
...txData,
|
||||
endpointTokenConfig: txData.endpointTokenConfig,
|
||||
});
|
||||
|
||||
transaction.calculateStructuredTokenValue();
|
||||
calculateStructuredTokenValue(transaction);
|
||||
|
||||
await transaction.save();
|
||||
|
||||
|
|
@ -257,71 +255,69 @@ transactionSchema.statics.createStructured = async function (txData) {
|
|||
balance: balanceResponse.tokenCredits,
|
||||
[transaction.tokenType]: incrementValue,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/** Method to calculate token value for structured tokens */
|
||||
transactionSchema.methods.calculateStructuredTokenValue = function () {
|
||||
if (!this.tokenType) {
|
||||
this.tokenValue = this.rawAmount;
|
||||
function calculateStructuredTokenValue(txn) {
|
||||
if (!txn.tokenType) {
|
||||
txn.tokenValue = txn.rawAmount;
|
||||
return;
|
||||
}
|
||||
|
||||
const { model, endpointTokenConfig } = this;
|
||||
const { model, endpointTokenConfig } = txn;
|
||||
|
||||
if (this.tokenType === 'prompt') {
|
||||
if (txn.tokenType === 'prompt') {
|
||||
const inputMultiplier = getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig });
|
||||
const writeMultiplier =
|
||||
getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier;
|
||||
const readMultiplier =
|
||||
getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? inputMultiplier;
|
||||
|
||||
this.rateDetail = {
|
||||
txn.rateDetail = {
|
||||
input: inputMultiplier,
|
||||
write: writeMultiplier,
|
||||
read: readMultiplier,
|
||||
};
|
||||
|
||||
const totalPromptTokens =
|
||||
Math.abs(this.inputTokens || 0) +
|
||||
Math.abs(this.writeTokens || 0) +
|
||||
Math.abs(this.readTokens || 0);
|
||||
Math.abs(txn.inputTokens || 0) +
|
||||
Math.abs(txn.writeTokens || 0) +
|
||||
Math.abs(txn.readTokens || 0);
|
||||
|
||||
if (totalPromptTokens > 0) {
|
||||
this.rate =
|
||||
(Math.abs(inputMultiplier * (this.inputTokens || 0)) +
|
||||
Math.abs(writeMultiplier * (this.writeTokens || 0)) +
|
||||
Math.abs(readMultiplier * (this.readTokens || 0))) /
|
||||
txn.rate =
|
||||
(Math.abs(inputMultiplier * (txn.inputTokens || 0)) +
|
||||
Math.abs(writeMultiplier * (txn.writeTokens || 0)) +
|
||||
Math.abs(readMultiplier * (txn.readTokens || 0))) /
|
||||
totalPromptTokens;
|
||||
} else {
|
||||
this.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
|
||||
txn.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
|
||||
}
|
||||
|
||||
this.tokenValue = -(
|
||||
Math.abs(this.inputTokens || 0) * inputMultiplier +
|
||||
Math.abs(this.writeTokens || 0) * writeMultiplier +
|
||||
Math.abs(this.readTokens || 0) * readMultiplier
|
||||
txn.tokenValue = -(
|
||||
Math.abs(txn.inputTokens || 0) * inputMultiplier +
|
||||
Math.abs(txn.writeTokens || 0) * writeMultiplier +
|
||||
Math.abs(txn.readTokens || 0) * readMultiplier
|
||||
);
|
||||
|
||||
this.rawAmount = -totalPromptTokens;
|
||||
} else if (this.tokenType === 'completion') {
|
||||
const multiplier = getMultiplier({ tokenType: this.tokenType, model, endpointTokenConfig });
|
||||
this.rate = Math.abs(multiplier);
|
||||
this.tokenValue = -Math.abs(this.rawAmount) * multiplier;
|
||||
this.rawAmount = -Math.abs(this.rawAmount);
|
||||
txn.rawAmount = -totalPromptTokens;
|
||||
} else if (txn.tokenType === 'completion') {
|
||||
const multiplier = getMultiplier({ tokenType: txn.tokenType, model, endpointTokenConfig });
|
||||
txn.rate = Math.abs(multiplier);
|
||||
txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier;
|
||||
txn.rawAmount = -Math.abs(txn.rawAmount);
|
||||
}
|
||||
|
||||
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
|
||||
this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
|
||||
this.rate *= cancelRate;
|
||||
if (this.rateDetail) {
|
||||
this.rateDetail = Object.fromEntries(
|
||||
Object.entries(this.rateDetail).map(([k, v]) => [k, v * cancelRate]),
|
||||
if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
|
||||
txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
|
||||
txn.rate *= cancelRate;
|
||||
if (txn.rateDetail) {
|
||||
txn.rateDetail = Object.fromEntries(
|
||||
Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]),
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const Transaction = mongoose.model('Transaction', transactionSchema);
|
||||
}
|
||||
|
||||
/**
|
||||
* Queries and retrieves transactions based on a given filter.
|
||||
|
|
@ -333,11 +329,16 @@ const Transaction = mongoose.model('Transaction', transactionSchema);
|
|||
*/
|
||||
async function getTransactions(filter) {
|
||||
try {
|
||||
return await Transaction.find(filter).lean();
|
||||
return await db.models.Transaction.find(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error querying transactions:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { Transaction, getTransactions };
|
||||
module.exports = {
|
||||
getTransactions,
|
||||
createTransaction,
|
||||
createAutoRefillTransaction,
|
||||
createStructuredTransaction,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,18 +3,22 @@ const { MongoMemoryServer } = require('mongodb-memory-server');
|
|||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const Balance = require('./Balance');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { createTransaction } = require('./Transaction');
|
||||
|
||||
// Mock the custom config module so we can control the balance flag.
|
||||
jest.mock('~/server/services/Config');
|
||||
|
||||
let mongoServer;
|
||||
|
||||
let Balance;
|
||||
let Transaction;
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
await db.connectDb(mongoUri);
|
||||
|
||||
Balance = db.models.Balance;
|
||||
Transaction = db.models.Transaction;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
|
|
@ -368,7 +372,7 @@ describe('NaN Handling Tests', () => {
|
|||
};
|
||||
|
||||
// Act
|
||||
const result = await Transaction.create(txData);
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created and balance remains unchanged.
|
||||
expect(result).toBeUndefined();
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { userSchema } = require('@librechat/data-schemas');
|
||||
|
||||
const User = mongoose.model('User', userSchema);
|
||||
|
||||
module.exports = User;
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const { createAutoRefillTransaction } = require('./Transaction');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { getMultiplier } = require('./tx');
|
||||
const { logger } = require('~/config');
|
||||
const Balance = require('./Balance');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
function isInvalidDate(date) {
|
||||
return isNaN(date);
|
||||
|
|
@ -26,7 +26,7 @@ const checkBalanceRecord = async function ({
|
|||
const tokenCost = amount * multiplier;
|
||||
|
||||
// Retrieve the balance record
|
||||
let record = await Balance.findOne({ user }).lean();
|
||||
let record = await db.models.Balance.findOne({ user }).lean();
|
||||
if (!record) {
|
||||
logger.debug('[Balance.check] No balance record found for user', { user });
|
||||
return {
|
||||
|
|
@ -60,7 +60,7 @@ const checkBalanceRecord = async function ({
|
|||
) {
|
||||
try {
|
||||
/** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
|
||||
const result = await Transaction.createAutoRefillTransaction({
|
||||
const result = await createAutoRefillTransaction({
|
||||
user: user,
|
||||
tokenType: 'credits',
|
||||
context: 'autoRefill',
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { Message, getMessages, bulkSaveMessages } = require('./Message');
|
||||
const { getMessages, bulkSaveMessages } = require('./Message');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
// Original version of buildTree function
|
||||
function buildTree({ messages, fileMap }) {
|
||||
|
|
@ -42,11 +43,13 @@ function buildTree({ messages, fileMap }) {
|
|||
}
|
||||
|
||||
let mongod;
|
||||
|
||||
let Message;
|
||||
beforeAll(async () => {
|
||||
mongod = await MongoMemoryServer.create();
|
||||
const uri = mongod.getUri();
|
||||
await mongoose.connect(uri);
|
||||
await db.connectDb(uri);
|
||||
|
||||
Message = db.models.Message;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,4 @@
|
|||
const {
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
updateUser,
|
||||
createUser,
|
||||
countUsers,
|
||||
findUser,
|
||||
} = require('./userMethods');
|
||||
const { comparePassword } = require('./userMethods');
|
||||
const {
|
||||
findFileById,
|
||||
createFile,
|
||||
|
|
@ -26,32 +17,11 @@ const {
|
|||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
} = require('./Message');
|
||||
const {
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
} = require('./Session');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { createToken, findToken, updateToken, deleteTokens } = require('./Token');
|
||||
const Balance = require('./Balance');
|
||||
const User = require('./User');
|
||||
const Key = require('./Key');
|
||||
|
||||
module.exports = {
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
updateUser,
|
||||
createUser,
|
||||
countUsers,
|
||||
findUser,
|
||||
|
||||
findFileById,
|
||||
createFile,
|
||||
updateFile,
|
||||
|
|
@ -77,21 +47,4 @@ module.exports = {
|
|||
getPresets,
|
||||
savePreset,
|
||||
deletePresets,
|
||||
|
||||
createToken,
|
||||
findToken,
|
||||
updateToken,
|
||||
deleteTokens,
|
||||
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
|
||||
User,
|
||||
Key,
|
||||
Balance,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { getRandomValues, hashToken } = require('~/server/utils/crypto');
|
||||
const { createToken, findToken } = require('./Token');
|
||||
const logger = require('~/config/winston');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* @module inviteUser
|
||||
|
|
@ -23,7 +23,7 @@ const createInvite = async (email) => {
|
|||
|
||||
const fakeUserId = new mongoose.Types.ObjectId();
|
||||
|
||||
await createToken({
|
||||
await db.models.Token.createToken({
|
||||
userId: fakeUserId,
|
||||
email,
|
||||
token: hash,
|
||||
|
|
@ -50,7 +50,7 @@ const getInvite = async (encodedToken, email) => {
|
|||
try {
|
||||
const token = decodeURIComponent(encodedToken);
|
||||
const hash = await hashToken(token);
|
||||
const invite = await findToken({ token: hash, email });
|
||||
const invite = await db.models.Token.findToken({ token: hash, email });
|
||||
|
||||
if (!invite) {
|
||||
throw new Error('Invite not found or email does not match');
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const mongoMeili = require('../plugins/mongoMeili');
|
||||
|
||||
const { convoSchema } = require('@librechat/data-schemas');
|
||||
|
||||
if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
|
||||
convoSchema.plugin(mongoMeili, {
|
||||
host: process.env.MEILI_HOST,
|
||||
apiKey: process.env.MEILI_MASTER_KEY,
|
||||
/** Note: Will get created automatically if it doesn't exist already */
|
||||
indexName: 'convos',
|
||||
primaryKey: 'conversationId',
|
||||
});
|
||||
}
|
||||
|
||||
const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema);
|
||||
|
||||
module.exports = Conversation;
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
const mongoose = require('mongoose');
|
||||
const mongoMeili = require('~/models/plugins/mongoMeili');
|
||||
const { messageSchema } = require('@librechat/data-schemas');
|
||||
|
||||
|
|
@ -11,6 +10,6 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
|
|||
});
|
||||
}
|
||||
|
||||
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
|
||||
// const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
|
||||
|
||||
module.exports = Message;
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { presetSchema } = require('@librechat/data-schemas');
|
||||
|
||||
const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema);
|
||||
|
||||
module.exports = Preset;
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
const { Transaction } = require('./Transaction');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { createTransaction, createStructuredTransaction } = require('./Transaction');
|
||||
/**
|
||||
* Creates up to two transactions to record the spending of tokens.
|
||||
*
|
||||
|
|
@ -33,7 +33,7 @@ const spendTokens = async (txData, tokenUsage) => {
|
|||
let prompt, completion;
|
||||
try {
|
||||
if (promptTokens !== undefined) {
|
||||
prompt = await Transaction.create({
|
||||
prompt = await createTransaction({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0),
|
||||
|
|
@ -41,7 +41,7 @@ const spendTokens = async (txData, tokenUsage) => {
|
|||
}
|
||||
|
||||
if (completionTokens !== undefined) {
|
||||
completion = await Transaction.create({
|
||||
completion = await createTransaction({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
|
||||
|
|
@ -101,7 +101,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => {
|
|||
try {
|
||||
if (promptTokens) {
|
||||
const { input = 0, write = 0, read = 0 } = promptTokens;
|
||||
prompt = await Transaction.createStructured({
|
||||
prompt = await createStructuredTransaction({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -input,
|
||||
|
|
@ -111,7 +111,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => {
|
|||
}
|
||||
|
||||
if (completionTokens) {
|
||||
completion = await Transaction.create({
|
||||
completion = await createTransaction({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -completionTokens,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const Balance = require('./Balance');
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { createTransaction, createAutoRefillTransaction } = require('./Transaction');
|
||||
|
||||
// Mock the logger to prevent console output during tests
|
||||
jest.mock('~/config', () => ({
|
||||
|
|
@ -20,10 +20,15 @@ describe('spendTokens', () => {
|
|||
let mongoServer;
|
||||
let userId;
|
||||
|
||||
let Transaction;
|
||||
let Balance;
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
await db.connectDb(mongoUri);
|
||||
|
||||
Balance = db.models.Balance;
|
||||
Transaction = db.models.Transaction;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
|
|
@ -197,7 +202,7 @@ describe('spendTokens', () => {
|
|||
// Check that the transaction records show the adjusted values
|
||||
const transactionResults = await Promise.all(
|
||||
transactions.map((t) =>
|
||||
Transaction.create({
|
||||
createTransaction({
|
||||
...txData,
|
||||
tokenType: t.tokenType,
|
||||
rawAmount: t.rawAmount,
|
||||
|
|
@ -280,7 +285,7 @@ describe('spendTokens', () => {
|
|||
|
||||
// Check the return values from Transaction.create directly
|
||||
// This is to verify that the incrementValue is not becoming positive
|
||||
const directResult = await Transaction.create({
|
||||
const directResult = await createTransaction({
|
||||
user: userId,
|
||||
conversationId: 'test-convo-3',
|
||||
model: 'gpt-4',
|
||||
|
|
@ -607,7 +612,7 @@ describe('spendTokens', () => {
|
|||
const promises = [];
|
||||
for (let i = 0; i < numberOfRefills; i++) {
|
||||
promises.push(
|
||||
Transaction.createAutoRefillTransaction({
|
||||
createAutoRefillTransaction({
|
||||
user: userId,
|
||||
tokenType: 'credits',
|
||||
context: 'concurrent-refill-test',
|
||||
|
|
|
|||
|
|
@ -1,159 +1,4 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const signPayload = require('~/server/services/signPayload');
|
||||
const Balance = require('./Balance');
|
||||
const User = require('./User');
|
||||
|
||||
/**
|
||||
* Retrieve a user by ID and convert the found user document to a plain object.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to find and return as a plain object.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||
*/
|
||||
const getUserById = async function (userId, fieldsToSelect = null) {
|
||||
const query = User.findById(userId);
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for a single user based on partial data and return matching user document as plain object.
|
||||
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||
*/
|
||||
const findUser = async function (searchCriteria, fieldsToSelect = null) {
|
||||
const query = User.findOne(searchCriteria);
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Update a user with new data without overwriting existing properties.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to update.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
|
||||
*/
|
||||
const updateUser = async function (userId, updateData) {
|
||||
const updateOperation = {
|
||||
$set: updateData,
|
||||
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
|
||||
};
|
||||
return await User.findByIdAndUpdate(userId, updateOperation, {
|
||||
new: true,
|
||||
runValidators: true,
|
||||
}).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a new user, optionally with a TTL of 1 week.
|
||||
* @param {MongoUser} data - The user data to be created, must contain user_id.
|
||||
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
|
||||
* @param {boolean} [returnUser=false] - Whether to return the created user object.
|
||||
* @returns {Promise<ObjectId|MongoUser>} A promise that resolves to the created user document ID or user object.
|
||||
* @throws {Error} If a user with the same user_id already exists.
|
||||
*/
|
||||
const createUser = async (data, disableTTL = true, returnUser = false) => {
|
||||
const balance = await getBalanceConfig();
|
||||
const userData = {
|
||||
...data,
|
||||
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
|
||||
};
|
||||
|
||||
if (disableTTL) {
|
||||
delete userData.expiresAt;
|
||||
}
|
||||
|
||||
const user = await User.create(userData);
|
||||
|
||||
// If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance
|
||||
if (balance?.enabled && balance?.startBalance) {
|
||||
const update = {
|
||||
$inc: { tokenCredits: balance.startBalance },
|
||||
};
|
||||
|
||||
if (
|
||||
balance.autoRefillEnabled &&
|
||||
balance.refillIntervalValue != null &&
|
||||
balance.refillIntervalUnit != null &&
|
||||
balance.refillAmount != null
|
||||
) {
|
||||
update.$set = {
|
||||
autoRefillEnabled: true,
|
||||
refillIntervalValue: balance.refillIntervalValue,
|
||||
refillIntervalUnit: balance.refillIntervalUnit,
|
||||
refillAmount: balance.refillAmount,
|
||||
};
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean();
|
||||
}
|
||||
|
||||
if (returnUser) {
|
||||
return user.toObject();
|
||||
}
|
||||
return user._id;
|
||||
};
|
||||
|
||||
/**
|
||||
* Count the number of user documents in the collection based on the provided filter.
|
||||
*
|
||||
* @param {Object} [filter={}] - The filter to apply when counting the documents.
|
||||
* @returns {Promise<number>} The count of documents that match the filter.
|
||||
*/
|
||||
const countUsers = async function (filter = {}) {
|
||||
return await User.countDocuments(filter);
|
||||
};
|
||||
|
||||
/**
|
||||
* Delete a user by their unique ID.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to delete.
|
||||
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
|
||||
*/
|
||||
const deleteUserById = async function (userId) {
|
||||
try {
|
||||
const result = await User.deleteOne({ _id: userId });
|
||||
if (result.deletedCount === 0) {
|
||||
return { deletedCount: 0, message: 'No user found with that ID.' };
|
||||
}
|
||||
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
|
||||
} catch (error) {
|
||||
throw new Error('Error deleting user: ' + error.message);
|
||||
}
|
||||
};
|
||||
|
||||
const { SESSION_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||
|
||||
/**
|
||||
* Generates a JWT token for a given user.
|
||||
*
|
||||
* @param {MongoUser} user - The user for whom the token is being generated.
|
||||
* @returns {Promise<string>} A promise that resolves to a JWT token.
|
||||
*/
|
||||
const generateToken = async (user) => {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
return await signPayload({
|
||||
payload: {
|
||||
id: user._id,
|
||||
username: user.username,
|
||||
provider: user.provider,
|
||||
email: user.email,
|
||||
},
|
||||
secret: process.env.JWT_SECRET,
|
||||
expirationTime: expires / 1000,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Compares the provided password with the user's password.
|
||||
|
|
@ -179,11 +24,4 @@ const comparePassword = async (user, candidatePassword) => {
|
|||
|
||||
module.exports = {
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
countUsers,
|
||||
createUser,
|
||||
updateUser,
|
||||
findUser,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@
|
|||
"librechat-data-provider": "*",
|
||||
"librechat-mcp": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"meilisearch": "^0.38.0",
|
||||
"meilisearch": "^0.50.0",
|
||||
"memorystore": "^1.6.7",
|
||||
"mime": "^3.0.0",
|
||||
"module-alias": "^2.2.3",
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ const {
|
|||
requestPasswordReset,
|
||||
setOpenIDAuthTokens,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findSession, getUserById, deleteAllUserSessions, findUser } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { logger } = require('~/config');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
|
|
@ -48,7 +48,7 @@ const resetPasswordController = async (req, res) => {
|
|||
if (resetPasswordService instanceof Error) {
|
||||
return res.status(400).json(resetPasswordService);
|
||||
} else {
|
||||
await deleteAllUserSessions({ userId: req.body.userId });
|
||||
await db.models.Session.deleteAllUserSessions({ userId: req.body.userId });
|
||||
return res.status(200).json(resetPasswordService);
|
||||
}
|
||||
} catch (e) {
|
||||
|
|
@ -83,7 +83,7 @@ const refreshController = async (req, res) => {
|
|||
}
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret');
|
||||
const user = await db.models.User.getUserById(payload.id, '-password -__v -totpSecret');
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
|
|
@ -96,7 +96,10 @@ const refreshController = async (req, res) => {
|
|||
}
|
||||
|
||||
// Find the session with the hashed refresh token
|
||||
const session = await findSession({ userId: userId, refreshToken: refreshToken });
|
||||
const session = await db.models.Session.findSession({
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
});
|
||||
|
||||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
const Balance = require('~/models/Balance');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
async function balanceController(req, res) {
|
||||
const balanceData = await Balance.findOne(
|
||||
const balanceData = await db.models.Balance.findOne(
|
||||
{ user: req.user.id },
|
||||
'-_id tokenCredits autoRefillEnabled refillIntervalValue refillIntervalUnit lastRefill refillAmount',
|
||||
).lean();
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ const {
|
|||
verifyBackupCode,
|
||||
getTOTPSecret,
|
||||
} = require('~/server/services/twoFactorService');
|
||||
const { updateUser, getUserById } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const { encryptV3 } = require('~/server/utils/crypto');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
|
||||
|
||||
/**
|
||||
|
|
@ -25,7 +24,7 @@ const enable2FA = async (req, res) => {
|
|||
const encryptedSecret = encryptV3(secret);
|
||||
|
||||
// Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
|
||||
const user = await updateUser(userId, {
|
||||
const user = await db.models.User.updateUser(userId, {
|
||||
totpSecret: encryptedSecret,
|
||||
backupCodes: codeObjects,
|
||||
twoFactorEnabled: false,
|
||||
|
|
@ -47,7 +46,7 @@ const verify2FA = async (req, res) => {
|
|||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await db.models.User.getUserById(userId);
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
|
|
@ -79,7 +78,8 @@ const confirm2FA = async (req, res) => {
|
|||
try {
|
||||
const userId = req.user.id;
|
||||
const { token } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const { User } = db.models;
|
||||
const user = await User.getUserById(userId);
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
|
|
@ -87,7 +87,7 @@ const confirm2FA = async (req, res) => {
|
|||
|
||||
const secret = await getTOTPSecret(user.totpSecret);
|
||||
if (await verifyTOTP(secret, token)) {
|
||||
await updateUser(userId, { twoFactorEnabled: true });
|
||||
await User.updateUser(userId, { twoFactorEnabled: true });
|
||||
return res.status(200).json();
|
||||
}
|
||||
return res.status(400).json({ message: 'Invalid token.' });
|
||||
|
|
@ -103,7 +103,7 @@ const confirm2FA = async (req, res) => {
|
|||
const disable2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
|
||||
await db.models.User.updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
|
||||
return res.status(200).json();
|
||||
} catch (err) {
|
||||
logger.error('[disable2FA]', err);
|
||||
|
|
@ -118,7 +118,7 @@ const regenerateBackupCodes = async (req, res) => {
|
|||
try {
|
||||
const userId = req.user.id;
|
||||
const { plainCodes, codeObjects } = await generateBackupCodes();
|
||||
await updateUser(userId, { backupCodes: codeObjects });
|
||||
await db.models.User.updateUser(userId, { backupCodes: codeObjects });
|
||||
return res.status(200).json({
|
||||
backupCodes: plainCodes,
|
||||
backupCodesHash: codeObjects,
|
||||
|
|
|
|||
|
|
@ -8,15 +8,11 @@ const {
|
|||
const {
|
||||
Balance,
|
||||
getFiles,
|
||||
updateUser,
|
||||
deleteFiles,
|
||||
deleteConvos,
|
||||
deletePresets,
|
||||
deleteMessages,
|
||||
deleteUserById,
|
||||
deleteAllUserSessions,
|
||||
} = require('~/models');
|
||||
const User = require('~/models/User');
|
||||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
|
|
@ -26,6 +22,7 @@ const { deleteAllSharedLinks } = require('~/models/Share');
|
|||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { Transaction } = require('~/models/Transaction');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
/** @type {MongoUser} */
|
||||
|
|
@ -39,7 +36,7 @@ const getUserController = async (req, res) => {
|
|||
const originalAvatar = userData.avatar;
|
||||
try {
|
||||
userData.avatar = await getNewS3URL(userData.avatar);
|
||||
await updateUser(userData.id, { avatar: userData.avatar });
|
||||
await db.models.User.updateUser(userData.id, { avatar: userData.avatar });
|
||||
} catch (error) {
|
||||
userData.avatar = originalAvatar;
|
||||
logger.error('Error getting new S3 URL for avatar:', error);
|
||||
|
|
@ -50,7 +47,7 @@ const getUserController = async (req, res) => {
|
|||
|
||||
const getTermsStatusController = async (req, res) => {
|
||||
try {
|
||||
const user = await User.findById(req.user.id);
|
||||
const user = await db.models.User.findById(req.user.id);
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
|
@ -63,7 +60,7 @@ const getTermsStatusController = async (req, res) => {
|
|||
|
||||
const acceptTermsController = async (req, res) => {
|
||||
try {
|
||||
const user = await User.findByIdAndUpdate(req.user.id, { termsAccepted: true }, { new: true });
|
||||
const user = await db.models.User.findByIdAndUpdate(req.user.id, { termsAccepted: true }, { new: true });
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
|
@ -160,7 +157,7 @@ const deleteUserController = async (req, res) => {
|
|||
|
||||
try {
|
||||
await deleteMessages({ user: user.id }); // delete user messages
|
||||
await deleteAllUserSessions({ userId: user.id }); // delete user sessions
|
||||
await db.models.Session.deleteAllUserSessions({ userId: user.id }); // delete user sessions
|
||||
await Transaction.deleteMany({ user: user.id }); // delete user transactions
|
||||
await deleteUserKey({ userId: user.id, all: true }); // delete user keys
|
||||
await Balance.deleteMany({ user: user._id }); // delete user balances
|
||||
|
|
@ -168,7 +165,7 @@ const deleteUserController = async (req, res) => {
|
|||
/* TODO: Delete Assistant Threads */
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
|
||||
await deleteUserById(user.id); // delete user
|
||||
await db.models.User.deleteUserById(user.id); // delete user
|
||||
await deleteAllSharedLinks(user.id); // delete user shared links
|
||||
await deleteUserFiles(req); // delete user files
|
||||
await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ const {
|
|||
getTOTPSecret,
|
||||
} = require('~/server/services/twoFactorService');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { getUserById } = require('~/models/userMethods');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Verifies the 2FA code during login using a temporary token.
|
||||
|
|
@ -25,7 +25,7 @@ const verify2FAWithTempToken = async (req, res) => {
|
|||
return res.status(401).json({ message: 'Invalid or expired temporary token' });
|
||||
}
|
||||
|
||||
const user = await getUserById(payload.userId);
|
||||
const user = await db.models.User.getUserById(payload.userId);
|
||||
if (!user || !user.twoFactorEnabled) {
|
||||
return res.status(400).json({ message: '2FA is not enabled for this user' });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ const mongoSanitize = require('express-mongo-sanitize');
|
|||
const fs = require('fs');
|
||||
const cookieParser = require('cookie-parser');
|
||||
const { jwtLogin, passportLogin } = require('~/strategies');
|
||||
const { connectDb, indexSync } = require('~/lib/db');
|
||||
const { connectDb, indexSync, getModels } = require('~/lib/db');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { ldapLogin } = require('~/strategies');
|
||||
const { logger } = require('~/config');
|
||||
|
|
@ -36,6 +36,7 @@ const startServer = async () => {
|
|||
axios.defaults.headers.common['Accept-Encoding'] = 'gzip';
|
||||
}
|
||||
await connectDb();
|
||||
|
||||
logger.info('Connected to MongoDB');
|
||||
await indexSync();
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@ const { isEnabled, removePorts } = require('~/server/utils');
|
|||
const keyvMongo = require('~/cache/keyvMongo');
|
||||
const denyRequest = require('./denyRequest');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { findUser } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
|
||||
const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 });
|
||||
const message = 'Your account has been temporarily banned due to violations of our service.';
|
||||
|
|
@ -57,7 +58,7 @@ const checkBan = async (req, res, next = () => {}) => {
|
|||
let userId = req.user?.id ?? req.user?._id ?? null;
|
||||
|
||||
if (!userId && req?.body?.email) {
|
||||
const user = await findUser({ email: req.body.email }, '_id');
|
||||
const user = await db.models.User.findUser({ email: req.body.email }, '_id');
|
||||
userId = user?._id ? user._id.toString() : userId;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { getInvite } = require('~/models/inviteUser');
|
||||
const { deleteTokens } = require('~/models/Token');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
async function checkInviteUser(req, res, next) {
|
||||
const token = req.body.token;
|
||||
|
|
@ -16,7 +16,7 @@ async function checkInviteUser(req, res, next) {
|
|||
return res.status(400).json({ message: 'Invalid invite token' });
|
||||
}
|
||||
|
||||
await deleteTokens({ token: invite.token });
|
||||
await db.models.Token.deleteTokens({ token: invite.token });
|
||||
req.invite = invite;
|
||||
next();
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const Balance = require('~/models/Balance');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Middleware to synchronize user balance settings with current balance configuration.
|
||||
|
|
@ -20,14 +20,14 @@ const setBalanceConfig = async (req, res, next) => {
|
|||
}
|
||||
|
||||
const userId = req.user._id;
|
||||
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
|
||||
const userBalanceRecord = await db.models.Balance.findOne({ user: userId }).lean();
|
||||
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
|
||||
|
||||
if (Object.keys(updateFields).length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate(
|
||||
await db.models.Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: updateFields },
|
||||
{ upsert: true, new: true },
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
|
|||
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
|
||||
const { getConvosQueried } = require('~/models/Conversation');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { Message } = require('~/models/Message');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const router = express.Router();
|
||||
router.use(requireJwtAuth);
|
||||
|
|
@ -40,21 +40,25 @@ router.get('/', async (req, res) => {
|
|||
const sortOrder = sortDirection === 'asc' ? 1 : -1;
|
||||
|
||||
if (conversationId && messageId) {
|
||||
const message = await Message.findOne({ conversationId, messageId, user: user }).lean();
|
||||
const message = await db.models.Message.findOne({
|
||||
conversationId,
|
||||
messageId,
|
||||
user: user,
|
||||
}).lean();
|
||||
response = { messages: message ? [message] : [], nextCursor: null };
|
||||
} else if (conversationId) {
|
||||
const filter = { conversationId, user: user };
|
||||
if (cursor) {
|
||||
filter[sortField] = sortOrder === 1 ? { $gt: cursor } : { $lt: cursor };
|
||||
}
|
||||
const messages = await Message.find(filter)
|
||||
const messages = await db.models.Message.find(filter)
|
||||
.sort({ [sortField]: sortOrder })
|
||||
.limit(pageSize + 1)
|
||||
.lean();
|
||||
const nextCursor = messages.length > pageSize ? messages.pop()[sortField] : null;
|
||||
response = { messages, nextCursor };
|
||||
} else if (search) {
|
||||
const searchResults = await Message.meiliSearch(search, undefined, true);
|
||||
const searchResults = await db.models.Message.meiliSearch(search, undefined, true);
|
||||
|
||||
const messages = searchResults.hits || [];
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ const { logger, getFlowStateManager, sendEvent } = require('~/config');
|
|||
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
|
||||
const { getActions, deleteActions } = require('~/models/Action');
|
||||
const { deleteAssistant } = require('~/models/Assistant');
|
||||
const { findToken } = require('~/models/Token');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const JWT_SECRET = process.env.JWT_SECRET;
|
||||
const toolNameRegex = /^[a-zA-Z0-9_-]+$/;
|
||||
|
|
@ -231,9 +231,10 @@ async function createActionTool({
|
|||
};
|
||||
|
||||
const tokenPromises = [];
|
||||
tokenPromises.push(findToken({ userId, type: 'oauth', identifier }));
|
||||
const { Token } = db.models;
|
||||
tokenPromises.push(Token.findToken({ userId, type: 'oauth', identifier }));
|
||||
tokenPromises.push(
|
||||
findToken({
|
||||
Token.findToken({
|
||||
userId,
|
||||
type: 'oauth_refresh',
|
||||
identifier: `${identifier}:refresh`,
|
||||
|
|
|
|||
|
|
@ -1,28 +1,12 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const {
|
||||
findUser,
|
||||
countUsers,
|
||||
createUser,
|
||||
updateUser,
|
||||
getUserById,
|
||||
generateToken,
|
||||
deleteUserById,
|
||||
} = require('~/models/userMethods');
|
||||
const {
|
||||
createToken,
|
||||
findToken,
|
||||
deleteTokens,
|
||||
findSession,
|
||||
deleteSession,
|
||||
createSession,
|
||||
generateRefreshToken,
|
||||
} = require('~/models');
|
||||
const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { registerSchema } = require('~/strategies/validators');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
|
||||
const domains = {
|
||||
client: process.env.DOMAIN_CLIENT,
|
||||
|
|
@ -40,13 +24,14 @@ const genericVerificationMessage = 'Please check your email to verify your email
|
|||
* @returns
|
||||
*/
|
||||
const logoutUser = async (req, refreshToken) => {
|
||||
const { Session } = db.models;
|
||||
try {
|
||||
const userId = req.user._id;
|
||||
const session = await findSession({ userId: userId, refreshToken });
|
||||
const session = await Session.findSession({ userId: userId, refreshToken });
|
||||
|
||||
if (session) {
|
||||
try {
|
||||
await deleteSession({ sessionId: session._id });
|
||||
await Session.deleteSession({ sessionId: session._id });
|
||||
} catch (deleteErr) {
|
||||
logger.error('[logoutUser] Failed to delete session.', deleteErr);
|
||||
return { status: 500, message: 'Failed to delete session.' };
|
||||
|
|
@ -98,7 +83,7 @@ const sendVerificationEmail = async (user) => {
|
|||
template: 'verifyEmail.handlebars',
|
||||
});
|
||||
|
||||
await createToken({
|
||||
await db.models.Token.createToken({
|
||||
userId: user._id,
|
||||
email: user.email,
|
||||
token: hash,
|
||||
|
|
@ -117,7 +102,8 @@ const verifyEmail = async (req) => {
|
|||
const { email, token } = req.body;
|
||||
const decodedEmail = decodeURIComponent(email);
|
||||
|
||||
const user = await findUser({ email: decodedEmail }, 'email _id emailVerified');
|
||||
const { User, Token } = db.models;
|
||||
const user = await User.findUser({ email: decodedEmail }, 'email _id emailVerified');
|
||||
|
||||
if (!user) {
|
||||
logger.warn(`[verifyEmail] [User not found] [Email: ${decodedEmail}]`);
|
||||
|
|
@ -129,7 +115,7 @@ const verifyEmail = async (req) => {
|
|||
return { message: 'Email already verified', status: 'success' };
|
||||
}
|
||||
|
||||
let emailVerificationData = await findToken({ email: decodedEmail });
|
||||
let emailVerificationData = await Token.findToken({ email: decodedEmail });
|
||||
|
||||
if (!emailVerificationData) {
|
||||
logger.warn(`[verifyEmail] [No email verification data found] [Email: ${decodedEmail}]`);
|
||||
|
|
@ -145,13 +131,13 @@ const verifyEmail = async (req) => {
|
|||
return new Error('Invalid or expired email verification token');
|
||||
}
|
||||
|
||||
const updatedUser = await updateUser(emailVerificationData.userId, { emailVerified: true });
|
||||
const updatedUser = await User.updateUser(emailVerificationData.userId, { emailVerified: true });
|
||||
if (!updatedUser) {
|
||||
logger.warn(`[verifyEmail] [User update failed] [Email: ${decodedEmail}]`);
|
||||
return new Error('Failed to update user verification status');
|
||||
}
|
||||
|
||||
await deleteTokens({ token: emailVerificationData.token });
|
||||
await Token.deleteTokens({ token: emailVerificationData.token });
|
||||
logger.info(`[verifyEmail] Email verification successful [Email: ${decodedEmail}]`);
|
||||
return { message: 'Email verification was successful', status: 'success' };
|
||||
};
|
||||
|
|
@ -162,6 +148,7 @@ const verifyEmail = async (req) => {
|
|||
* @returns {Promise<{status: number, message: string, user?: MongoUser}>}
|
||||
*/
|
||||
const registerUser = async (user, additionalData = {}) => {
|
||||
const { User } = db.models;
|
||||
const { error } = registerSchema.safeParse(user);
|
||||
if (error) {
|
||||
const errorMessage = errorsToString(error.errors);
|
||||
|
|
@ -178,7 +165,7 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
|
||||
let newUserId;
|
||||
try {
|
||||
const existingUser = await findUser({ email }, 'email _id');
|
||||
const existingUser = await User.findUser({ email }, 'email _id');
|
||||
|
||||
if (existingUser) {
|
||||
logger.info(
|
||||
|
|
@ -200,7 +187,7 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
}
|
||||
|
||||
//determine if this is the first registered user (not counting anonymous_user)
|
||||
const isFirstRegisteredUser = (await countUsers()) === 0;
|
||||
const isFirstRegisteredUser = (await User.countUsers()) === 0;
|
||||
|
||||
const salt = bcrypt.genSaltSync(10);
|
||||
const newUserData = {
|
||||
|
|
@ -216,7 +203,9 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
|
||||
const emailEnabled = checkEmailConfig();
|
||||
const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN);
|
||||
const newUser = await createUser(newUserData, disableTTL, true);
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
|
||||
const newUser = await User.createUser(newUserData, balanceConfig, disableTTL, true);
|
||||
newUserId = newUser._id;
|
||||
if (emailEnabled && !newUser.emailVerified) {
|
||||
await sendVerificationEmail({
|
||||
|
|
@ -225,14 +214,14 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
name,
|
||||
});
|
||||
} else {
|
||||
await updateUser(newUserId, { emailVerified: true });
|
||||
await User.updateUser(newUserId, { emailVerified: true });
|
||||
}
|
||||
|
||||
return { status: 200, message: genericVerificationMessage };
|
||||
} catch (err) {
|
||||
logger.error('[registerUser] Error in registering user:', err);
|
||||
if (newUserId) {
|
||||
const result = await deleteUserById(newUserId);
|
||||
const result = await User.deleteUserById(newUserId);
|
||||
logger.warn(
|
||||
`[registerUser] [Email: ${email}] [Temporary User deleted: ${JSON.stringify(result)}]`,
|
||||
);
|
||||
|
|
@ -247,7 +236,8 @@ const registerUser = async (user, additionalData = {}) => {
|
|||
*/
|
||||
const requestPasswordReset = async (req) => {
|
||||
const { email } = req.body;
|
||||
const user = await findUser({ email }, 'email _id');
|
||||
const { User, Token } = db.models;
|
||||
const user = await User.findUser({ email }, 'email _id');
|
||||
const emailEnabled = checkEmailConfig();
|
||||
|
||||
logger.warn(`[requestPasswordReset] [Password reset request initiated] [Email: ${email}]`);
|
||||
|
|
@ -259,11 +249,11 @@ const requestPasswordReset = async (req) => {
|
|||
};
|
||||
}
|
||||
|
||||
await deleteTokens({ userId: user._id });
|
||||
await Token.deleteTokens({ userId: user._id });
|
||||
|
||||
const [resetToken, hash] = createTokenHash();
|
||||
|
||||
await createToken({
|
||||
await Token.createToken({
|
||||
userId: user._id,
|
||||
token: hash,
|
||||
createdAt: Date.now(),
|
||||
|
|
@ -308,7 +298,8 @@ const requestPasswordReset = async (req) => {
|
|||
* @returns
|
||||
*/
|
||||
const resetPassword = async (userId, token, password) => {
|
||||
let passwordResetToken = await findToken({
|
||||
const { User, Token } = db.models;
|
||||
let passwordResetToken = await Token.findToken({
|
||||
userId,
|
||||
});
|
||||
|
||||
|
|
@ -323,7 +314,7 @@ const resetPassword = async (userId, token, password) => {
|
|||
}
|
||||
|
||||
const hash = bcrypt.hashSync(password, 10);
|
||||
const user = await updateUser(userId, { password: hash });
|
||||
const user = await User.updateUser(userId, { password: hash });
|
||||
|
||||
if (checkEmailConfig()) {
|
||||
await sendEmail({
|
||||
|
|
@ -338,7 +329,7 @@ const resetPassword = async (userId, token, password) => {
|
|||
});
|
||||
}
|
||||
|
||||
await deleteTokens({ token: passwordResetToken.token });
|
||||
await Token.deleteTokens({ token: passwordResetToken.token });
|
||||
logger.info(`[resetPassword] Password reset successful. [Email: ${user.email}]`);
|
||||
return { message: 'Password reset was successful' };
|
||||
};
|
||||
|
|
@ -353,19 +344,20 @@ const resetPassword = async (userId, token, password) => {
|
|||
*/
|
||||
const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
try {
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
const { User, Session } = db.models;
|
||||
const user = await User.getUserById(userId);
|
||||
const token = await User.generateToken(user);
|
||||
|
||||
let session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
|
||||
if (sessionId) {
|
||||
session = await findSession({ sessionId: sessionId }, { lean: false });
|
||||
session = await Session.findSession({ sessionId: sessionId }, { lean: false });
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
refreshToken = await Session.generateRefreshToken(session);
|
||||
} else {
|
||||
const result = await createSession(userId);
|
||||
const result = await Session.createSession(userId);
|
||||
session = result.session;
|
||||
refreshToken = result.refreshToken;
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
|
|
@ -444,8 +436,9 @@ const setOpenIDAuthTokens = (tokenset, res) => {
|
|||
const resendVerificationEmail = async (req) => {
|
||||
try {
|
||||
const { email } = req.body;
|
||||
await deleteTokens(email);
|
||||
const user = await findUser({ email }, 'email _id name');
|
||||
const { User, Token } = db.models;
|
||||
await Token.deleteTokens(email);
|
||||
const user = await User.findUser({ email }, 'email _id name');
|
||||
|
||||
if (!user) {
|
||||
logger.warn(`[resendVerificationEmail] [No user found] [Email: ${email}]`);
|
||||
|
|
@ -470,7 +463,7 @@ const resendVerificationEmail = async (req) => {
|
|||
template: 'verifyEmail.handlebars',
|
||||
});
|
||||
|
||||
await createToken({
|
||||
await Token.createToken({
|
||||
userId: user._id,
|
||||
email: user.email,
|
||||
token: hash,
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
const { saveBufferToAzure } = require('./crud');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
/**
|
||||
* Uploads an image file to Azure Blob Storage.
|
||||
* It resizes and converts the image similar to your Firebase implementation.
|
||||
|
|
@ -108,7 +107,7 @@ async function processAzureAvatar({ buffer, userId, manual, basePath = 'images',
|
|||
const isManual = manual === 'true';
|
||||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
if (isManual) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
await db.models?.User.updateUser(userId, { avatar: url });
|
||||
}
|
||||
return url;
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { saveBufferToFirebase } = require('./crud');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Converts an image file to the target format. The function first resizes the image based on the specified
|
||||
|
|
@ -99,7 +99,7 @@ async function processFirebaseAvatar({ buffer, userId, manual }) {
|
|||
const url = `${downloadURL}?manual=${isManual}`;
|
||||
|
||||
if (isManual) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
await db.models.User.updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
return url;
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
/**
|
||||
* Converts an image file to the target format. The function first resizes the image based on the specified
|
||||
|
|
@ -141,7 +141,7 @@ async function processLocalAvatar({ buffer, userId, manual }) {
|
|||
let url = `${urlRoute}?manual=${isManual}`;
|
||||
|
||||
if (isManual) {
|
||||
await updateUser(userId, { avatar: url });
|
||||
await db.models?.User.updateUser(userId, { avatar: url });
|
||||
}
|
||||
|
||||
return url;
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ const fs = require('fs');
|
|||
const path = require('path');
|
||||
const sharp = require('sharp');
|
||||
const { resizeImageBuffer } = require('../images/resize');
|
||||
const { updateUser } = require('~/models/userMethods');
|
||||
const { saveBufferToS3 } = require('./crud');
|
||||
const { updateFile } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const defaultBasePath = 'images';
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ async function processS3Avatar({ buffer, userId, manual, basePath = defaultBaseP
|
|||
try {
|
||||
const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath });
|
||||
if (manual === 'true') {
|
||||
await updateUser(userId, { avatar: downloadURL });
|
||||
await db.models?.User.updateUser(userId, { avatar: downloadURL });
|
||||
}
|
||||
return downloadURL;
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { encrypt, decrypt } = require('~/server/utils');
|
||||
const { updateUser, Key } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
/**
|
||||
* Updates the plugins for a user based on the action specified (install/uninstall).
|
||||
* @async
|
||||
|
|
@ -17,10 +16,11 @@ const { logger } = require('~/config');
|
|||
const updateUserPluginsService = async (user, pluginKey, action) => {
|
||||
try {
|
||||
const userPlugins = user.plugins || [];
|
||||
const { User } = db.models;
|
||||
if (action === 'install') {
|
||||
return await updateUser(user._id, { plugins: [...userPlugins, pluginKey] });
|
||||
return await User.updateUser(user._id, { plugins: [...userPlugins, pluginKey] });
|
||||
} else if (action === 'uninstall') {
|
||||
return await updateUser(user._id, {
|
||||
return await User.updateUser(user._id, {
|
||||
plugins: userPlugins.filter((plugin) => plugin !== pluginKey),
|
||||
});
|
||||
}
|
||||
|
|
@ -42,7 +42,7 @@ const updateUserPluginsService = async (user, pluginKey, action) => {
|
|||
* an error indicating that there is no user key available.
|
||||
*/
|
||||
const getUserKey = async ({ userId, name }) => {
|
||||
const keyValue = await Key.findOne({ userId, name }).lean();
|
||||
const keyValue = await db.models.Key.findOne({ userId, name }).lean();
|
||||
if (!keyValue) {
|
||||
throw new Error(
|
||||
JSON.stringify({
|
||||
|
|
@ -89,7 +89,7 @@ const getUserKeyValues = async ({ userId, name }) => {
|
|||
* returns its expiry date. If the key is not found, it returns null for the expiry date.
|
||||
*/
|
||||
const getUserKeyExpiry = async ({ userId, name }) => {
|
||||
const keyValue = await Key.findOne({ userId, name }).lean();
|
||||
const keyValue = await db.models.Key.findOne({ userId, name }).lean();
|
||||
if (!keyValue) {
|
||||
return { expiresAt: null };
|
||||
}
|
||||
|
|
@ -123,7 +123,7 @@ const updateUserKey = async ({ userId, name, value, expiresAt = null }) => {
|
|||
// make sure to remove if already present
|
||||
updateQuery.$unset = { expiresAt };
|
||||
}
|
||||
return await Key.findOneAndUpdate({ userId, name }, updateQuery, {
|
||||
return await db.models.Key.findOneAndUpdate({ userId, name }, updateQuery, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
}).lean();
|
||||
|
|
@ -143,10 +143,10 @@ const updateUserKey = async ({ userId, name, value, expiresAt = null }) => {
|
|||
*/
|
||||
const deleteUserKey = async ({ userId, name, all = false }) => {
|
||||
if (all) {
|
||||
return await Key.deleteMany({ userId });
|
||||
return await db.models.Key.deleteMany({ userId });
|
||||
}
|
||||
|
||||
await Key.findOneAndDelete({ userId, name }).lean();
|
||||
await db.models.Key.findOneAndDelete({ userId, name }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
const { webcrypto } = require('node:crypto');
|
||||
const { decryptV3, decryptV2 } = require('../utils/crypto');
|
||||
const { hashBackupCode } = require('~/server/utils/crypto');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
// Base32 alphabet for TOTP secret encoding.
|
||||
const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567';
|
||||
|
|
@ -172,8 +173,7 @@ const verifyBackupCode = async ({ user, backupCode }) => {
|
|||
: codeObj,
|
||||
);
|
||||
// Update the user record with the marked backup code.
|
||||
const { updateUser } = require('~/models');
|
||||
await updateUser(user._id, { backupCodes: updatedBackupCodes });
|
||||
await db.models.User.updateUser(user._id, { backupCodes: updatedBackupCodes });
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ const crypto = require('node:crypto');
|
|||
const { webcrypto } = crypto;
|
||||
|
||||
// Use hex decoding for both key and IV for legacy methods.
|
||||
const key = Buffer.from(process.env.CREDS_KEY, 'hex');
|
||||
const iv = Buffer.from(process.env.CREDS_IV, 'hex');
|
||||
const key = Buffer.from('f34be427ebb29de8d88c107a71546019685ed8b241d8f2ed00c3df97ad2566f0', 'hex');
|
||||
const iv = Buffer.from('e2341419ec3dd3d19b13a1a87fafcbfb', 'hex');
|
||||
const algorithm = 'AES-CBC';
|
||||
|
||||
// --- Legacy v1/v2 Setup: AES-CBC with fixed key and IV ---
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
jest.mock('~/models/Message', () => ({
|
||||
Message: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
Conversation: jest.fn(),
|
||||
}));
|
||||
const { isEnabled, sanitizeFilename } = require('./handleText');
|
||||
|
||||
describe('isEnabled', () => {
|
||||
|
|
|
|||
|
|
@ -18,17 +18,13 @@ const getProfileDetails = ({ idToken, profile }) => {
|
|||
|
||||
const decoded = jwt.decode(idToken);
|
||||
|
||||
logger.debug(
|
||||
`Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`,
|
||||
);
|
||||
logger.debug(`Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`);
|
||||
|
||||
return {
|
||||
email: decoded.email,
|
||||
id: decoded.sub,
|
||||
avatarUrl: null, // Apple does not provide an avatar URL
|
||||
username: decoded.email
|
||||
? decoded.email.split('@')[0].toLowerCase()
|
||||
: `user_${decoded.sub}`,
|
||||
username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`,
|
||||
name: decoded.name
|
||||
? `${decoded.name.firstName} ${decoded.name.lastName}`
|
||||
: profile.displayName || null,
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@ const { MongoMemoryServer } = require('mongodb-memory-server');
|
|||
const jwt = require('jsonwebtoken');
|
||||
const { Strategy: AppleStrategy } = require('passport-apple');
|
||||
const socialLogin = require('./socialLogin');
|
||||
const User = require('~/models/User');
|
||||
const { logger } = require('~/config');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { findUser } = require('~/models');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
// Mocking external dependencies
|
||||
jest.mock('jsonwebtoken');
|
||||
|
|
@ -24,21 +23,20 @@ jest.mock('./process', () => ({
|
|||
jest.mock('~/server/utils', () => ({
|
||||
isEnabled: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('Apple Login Strategy', () => {
|
||||
let mongoServer;
|
||||
let appleStrategyInstance;
|
||||
const OLD_ENV = process.env;
|
||||
let getProfileDetails;
|
||||
|
||||
let User;
|
||||
// Start and stop in-memory MongoDB
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
await db.connectDb(mongoUri);
|
||||
|
||||
User = db.models.User;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
|
|
@ -64,7 +62,6 @@ describe('Apple Login Strategy', () => {
|
|||
|
||||
// Define getProfileDetails within the test scope
|
||||
getProfileDetails = ({ idToken, profile }) => {
|
||||
console.log('getProfileDetails called with idToken:', idToken);
|
||||
if (!idToken) {
|
||||
logger.error('idToken is missing');
|
||||
throw new Error('idToken is missing');
|
||||
|
|
@ -84,9 +81,7 @@ describe('Apple Login Strategy', () => {
|
|||
email: decoded.email,
|
||||
id: decoded.sub,
|
||||
avatarUrl: null, // Apple does not provide an avatar URL
|
||||
username: decoded.email
|
||||
? decoded.email.split('@')[0].toLowerCase()
|
||||
: `user_${decoded.sub}`,
|
||||
username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`,
|
||||
name: decoded.name
|
||||
? `${decoded.name.firstName} ${decoded.name.lastName}`
|
||||
: profile.displayName || null,
|
||||
|
|
@ -96,8 +91,12 @@ describe('Apple Login Strategy', () => {
|
|||
|
||||
// Mock isEnabled based on environment variable
|
||||
isEnabled.mockImplementation((flag) => {
|
||||
if (flag === 'true') { return true; }
|
||||
if (flag === 'false') { return false; }
|
||||
if (flag === 'true') {
|
||||
return true;
|
||||
}
|
||||
if (flag === 'false') {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
|
|
@ -154,9 +153,7 @@ describe('Apple Login Strategy', () => {
|
|||
});
|
||||
|
||||
expect(jwt.decode).toHaveBeenCalledWith('fake_id_token');
|
||||
expect(logger.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Decoded Apple JWT'),
|
||||
);
|
||||
expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Decoded Apple JWT'));
|
||||
expect(profileDetails).toEqual({
|
||||
email: 'john.doe@example.com',
|
||||
id: 'apple-sub-1234',
|
||||
|
|
@ -209,12 +206,13 @@ describe('Apple Login Strategy', () => {
|
|||
|
||||
beforeEach(() => {
|
||||
jwt.decode.mockReturnValue(decodedToken);
|
||||
findUser.mockImplementation(({ email }) => User.findOne({ email }));
|
||||
User.findUser = jest.fn();
|
||||
User.findUser.mockImplementation(({ email }) => User.findOne({ email }));
|
||||
});
|
||||
|
||||
it('should create a new user if one does not exist and registration is allowed', async () => {
|
||||
// Mock findUser to return null (user does not exist)
|
||||
findUser.mockResolvedValue(null);
|
||||
User.findUser.mockResolvedValue(null);
|
||||
|
||||
// Mock createSocialUser to create a user
|
||||
createSocialUser.mockImplementation(async (userData) => {
|
||||
|
|
@ -260,7 +258,7 @@ describe('Apple Login Strategy', () => {
|
|||
await existingUser.save();
|
||||
|
||||
// Mock findUser to return the existing user
|
||||
findUser.mockResolvedValue(existingUser);
|
||||
User.findUser.mockResolvedValue(existingUser);
|
||||
|
||||
// Mock handleExistingUser to update avatarUrl
|
||||
handleExistingUser.mockImplementation(async (user, avatarUrl) => {
|
||||
|
|
@ -297,7 +295,7 @@ describe('Apple Login Strategy', () => {
|
|||
appleStrategyInstance._verify(
|
||||
fakeAccessToken,
|
||||
fakeRefreshToken,
|
||||
null, // idToken is missing
|
||||
null, // idToken is missing
|
||||
mockProfile,
|
||||
(err, user) => {
|
||||
mockVerifyCallback(err, user);
|
||||
|
|
@ -344,7 +342,7 @@ describe('Apple Login Strategy', () => {
|
|||
|
||||
it('should handle errors during user creation', async () => {
|
||||
// Mock findUser to return null (user does not exist)
|
||||
findUser.mockResolvedValue(null);
|
||||
User.findUser.mockResolvedValue(null);
|
||||
|
||||
// Mock createSocialUser to throw an error
|
||||
createSocialUser.mockImplementation(() => {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { getUserById, updateUser } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
// JWT strategy
|
||||
const jwtLogin = () =>
|
||||
|
|
@ -12,12 +12,13 @@ const jwtLogin = () =>
|
|||
},
|
||||
async (payload, done) => {
|
||||
try {
|
||||
const user = await getUserById(payload?.id, '-password -__v -totpSecret');
|
||||
const {User} = db.models;
|
||||
const user = await User.getUserById(payload?.id, '-password -__v -totpSecret');
|
||||
if (user) {
|
||||
user.id = user._id.toString();
|
||||
if (!user.role) {
|
||||
user.role = SystemRoles.USER;
|
||||
await updateUser(user.id, { role: user.role });
|
||||
await User.updateUser(user.id, { role: user.role });
|
||||
}
|
||||
done(null, user);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
const fs = require('fs');
|
||||
const LdapStrategy = require('passport-ldapauth');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { findUser, createUser, updateUser } = require('~/models/userMethods');
|
||||
const { countUsers } = require('~/models/userMethods');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const logger = require('~/utils/logger');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
|
||||
const {
|
||||
LDAP_URL,
|
||||
|
|
@ -81,6 +81,7 @@ const ldapOptions = {
|
|||
};
|
||||
|
||||
const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
||||
const { User } = db.models;
|
||||
if (!userinfo) {
|
||||
return done(null, false, { message: 'Invalid credentials' });
|
||||
}
|
||||
|
|
@ -89,7 +90,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
const ldapId =
|
||||
(LDAP_ID && userinfo[LDAP_ID]) || userinfo.uid || userinfo.sAMAccountName || userinfo.mail;
|
||||
|
||||
let user = await findUser({ ldapId });
|
||||
let user = await User.findUser({ ldapId });
|
||||
|
||||
const fullNameAttributes = LDAP_FULL_NAME && LDAP_FULL_NAME.split(',');
|
||||
const fullName =
|
||||
|
|
@ -114,7 +115,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
}
|
||||
|
||||
if (!user) {
|
||||
const isFirstRegisteredUser = (await countUsers()) === 0;
|
||||
const isFirstRegisteredUser = (await User.countUsers()) === 0;
|
||||
user = {
|
||||
provider: 'ldap',
|
||||
ldapId,
|
||||
|
|
@ -124,7 +125,9 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
name: fullName,
|
||||
role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER,
|
||||
};
|
||||
const userId = await createUser(user);
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
|
||||
const userId = await User.createUser(user, balanceConfig);
|
||||
user._id = userId;
|
||||
} else {
|
||||
// Users registered in LDAP are assumed to have their user information managed in LDAP,
|
||||
|
|
@ -136,7 +139,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
|||
user.name = fullName;
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
user = await User.updateUser(user._id, user);
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
logger.error('[ldapStrategy]', err);
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
const { errorsToString } = require('librechat-data-provider');
|
||||
const { Strategy: PassportLocalStrategy } = require('passport-local');
|
||||
const { findUser, comparePassword, updateUser } = require('~/models');
|
||||
const { comparePassword } = require('~/models');
|
||||
const { isEnabled, checkEmailConfig } = require('~/server/utils');
|
||||
const { loginSchema } = require('./validators');
|
||||
const logger = require('~/utils/logger');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
// Unix timestamp for 2024-06-07 15:20:18 Eastern Time
|
||||
const verificationEnabledTimestamp = 1717788018;
|
||||
|
|
@ -14,6 +15,7 @@ async function validateLoginRequest(req) {
|
|||
}
|
||||
|
||||
async function passportLogin(req, email, password, done) {
|
||||
const {User} = db.models;
|
||||
try {
|
||||
const validationError = await validateLoginRequest(req);
|
||||
if (validationError) {
|
||||
|
|
@ -22,7 +24,7 @@ async function passportLogin(req, email, password, done) {
|
|||
return done(null, false, { message: validationError });
|
||||
}
|
||||
|
||||
const user = await findUser({ email: email.trim() });
|
||||
const user = await User.findUser({ email: email.trim() });
|
||||
if (!user) {
|
||||
logError('Passport Local Strategy - User Not Found', { email });
|
||||
logger.error(`[Login] [Login failed] [Username: ${email}] [Request-IP: ${req.ip}]`);
|
||||
|
|
@ -44,13 +46,13 @@ async function passportLogin(req, email, password, done) {
|
|||
!user.emailVerified &&
|
||||
userCreatedAtTimestamp < verificationEnabledTimestamp
|
||||
) {
|
||||
await updateUser(user._id, { emailVerified: true });
|
||||
await User.updateUser(user._id, { emailVerified: true });
|
||||
user.emailVerified = true;
|
||||
}
|
||||
|
||||
const unverifiedAllowed = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN);
|
||||
if (user.expiresAt && unverifiedAllowed) {
|
||||
await updateUser(user._id, {});
|
||||
await User.updateUser(user._id, {});
|
||||
}
|
||||
|
||||
if (!user.emailVerified && !unverifiedAllowed) {
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ const { HttpsProxyAgent } = require('https-proxy-agent');
|
|||
const client = require('openid-client');
|
||||
const { Strategy: OpenIDStrategy } = require('openid-client/passport');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { findUser, createUser, updateUser } = require('~/models/userMethods');
|
||||
const { hashToken } = require('~/server/utils/crypto');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
|
@ -37,6 +35,24 @@ class CustomOpenIDStrategy extends OpenIDStrategy {
|
|||
}
|
||||
}
|
||||
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
|
||||
let crypto;
|
||||
let webcrypto;
|
||||
try {
|
||||
crypto = require('node:crypto');
|
||||
webcrypto = crypto;
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] crypto support is disabled!', err);
|
||||
}
|
||||
|
||||
async function hashToken(str) {
|
||||
const data = new TextEncoder().encode(str);
|
||||
const hashBuffer = await webcrypto.subtle.digest('SHA-256', data);
|
||||
return Buffer.from(hashBuffer).toString('hex');
|
||||
}
|
||||
|
||||
/**
|
||||
* Exchange the access token for a new access token using the on-behalf-of flow if required.
|
||||
* @param {Configuration} config
|
||||
|
|
@ -196,6 +212,7 @@ function convertToUsername(input, defaultValue = '') {
|
|||
* @throws {Error} If an error occurs during the setup process.
|
||||
*/
|
||||
async function setupOpenId() {
|
||||
const { User } = db.models;
|
||||
try {
|
||||
/** @type {ClientMetadata} */
|
||||
const clientMetadata = {
|
||||
|
|
@ -230,13 +247,13 @@ async function setupOpenId() {
|
|||
async (tokenset, done) => {
|
||||
try {
|
||||
const claims = tokenset.claims();
|
||||
let user = await findUser({ openidId: claims.sub });
|
||||
let user = await User.findUser({ openidId: claims.sub });
|
||||
logger.info(
|
||||
`[openidStrategy] user ${user ? 'found' : 'not found'} with openidId: ${claims.sub}`,
|
||||
);
|
||||
|
||||
if (!user) {
|
||||
user = await findUser({ email: claims.email });
|
||||
user = await User.findUser({ email: claims.email });
|
||||
logger.info(
|
||||
`[openidStrategy] user ${user ? 'found' : 'not found'} with email: ${
|
||||
claims.email
|
||||
|
|
@ -297,7 +314,10 @@ async function setupOpenId() {
|
|||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
};
|
||||
user = await createUser(user, true, true);
|
||||
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
|
||||
user = await User.createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
|
|
@ -333,7 +353,7 @@ async function setupOpenId() {
|
|||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
user = await User.updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,48 @@
|
|||
const fetch = require('node-fetch');
|
||||
const jwtDecode = require('jsonwebtoken/decode');
|
||||
const { findUser, createUser, updateUser } = require('~/models/userMethods');
|
||||
const { setupOpenId } = require('./openidStrategy');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
|
||||
// --- Mocks ---
|
||||
const mockCreateUser = jest.fn();
|
||||
const mockFindUser = jest.fn();
|
||||
const mockUpdateUser = jest.fn();
|
||||
let User;
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
return {
|
||||
registerModels: jest.fn().mockReturnValue({
|
||||
User: {
|
||||
createUser: mockCreateUser,
|
||||
findUser: mockFindUser,
|
||||
updateUser: mockUpdateUser,
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const mockModels = {
|
||||
User: {
|
||||
createUser: mockCreateUser,
|
||||
findUser: mockFindUser,
|
||||
updateUser: mockUpdateUser,
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock('~/lib/db/connectDb', () => {
|
||||
return {
|
||||
getModels: jest.fn(() => mockModels),
|
||||
connectDb: jest.fn(),
|
||||
get models() {
|
||||
return mockModels;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getBalanceConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('node-fetch');
|
||||
jest.mock('jsonwebtoken/decode');
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
|
|
@ -11,11 +50,7 @@ jest.mock('~/server/services/Files/strategies', () => ({
|
|||
saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'),
|
||||
})),
|
||||
}));
|
||||
jest.mock('~/models/userMethods', () => ({
|
||||
findUser: jest.fn(),
|
||||
createUser: jest.fn(),
|
||||
updateUser: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils/crypto', () => ({
|
||||
hashToken: jest.fn().mockResolvedValue('hashed-token'),
|
||||
}));
|
||||
|
|
@ -109,7 +144,8 @@ describe('setupOpenId', () => {
|
|||
picture: 'https://example.com/avatar.png',
|
||||
}),
|
||||
};
|
||||
|
||||
const { registerModels } = require('@librechat/data-schemas');
|
||||
User = registerModels().User;
|
||||
beforeEach(async () => {
|
||||
// Clear previous mock calls and reset implementations
|
||||
jest.clearAllMocks();
|
||||
|
|
@ -134,13 +170,16 @@ describe('setupOpenId', () => {
|
|||
roles: ['requiredRole'],
|
||||
});
|
||||
|
||||
User.findUser = jest.fn();
|
||||
// By default, assume that no user is found, so createUser will be called
|
||||
findUser.mockResolvedValue(null);
|
||||
createUser.mockImplementation(async (userData) => {
|
||||
|
||||
const balance = await getBalanceConfig.mockResolvedValue({ enabled: false });
|
||||
User.createUser.mockImplementation(async (userData, balance) => {
|
||||
// simulate created user with an _id property
|
||||
return { _id: 'newUserId', ...userData };
|
||||
});
|
||||
updateUser.mockImplementation(async (id, userData) => {
|
||||
// User.updateUser = jest.fn();
|
||||
User.updateUser.mockImplementation(async (id, userData) => {
|
||||
return { _id: id, ...userData };
|
||||
});
|
||||
|
||||
|
|
@ -166,7 +205,7 @@ describe('setupOpenId', () => {
|
|||
|
||||
// Assert
|
||||
expect(user.username).toBe(userinfo.username);
|
||||
expect(createUser).toHaveBeenCalledWith(
|
||||
expect(User.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
|
|
@ -174,6 +213,7 @@ describe('setupOpenId', () => {
|
|||
email: userinfo.email,
|
||||
name: `${userinfo.given_name} ${userinfo.family_name}`,
|
||||
}),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
|
|
@ -191,8 +231,9 @@ describe('setupOpenId', () => {
|
|||
|
||||
// Assert
|
||||
expect(user.username).toBe(expectUsername);
|
||||
expect(createUser).toHaveBeenCalledWith(
|
||||
expect(User.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ username: expectUsername }),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
|
|
@ -210,8 +251,9 @@ describe('setupOpenId', () => {
|
|||
|
||||
// Assert
|
||||
expect(user.username).toBe(expectUsername);
|
||||
expect(createUser).toHaveBeenCalledWith(
|
||||
expect(User.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ username: expectUsername }),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
|
|
@ -227,8 +269,9 @@ describe('setupOpenId', () => {
|
|||
|
||||
// Assert – username should equal the sub (converted as-is)
|
||||
expect(user.username).toBe(userinfo.sub);
|
||||
expect(createUser).toHaveBeenCalledWith(
|
||||
expect(User.createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ username: userinfo.sub }),
|
||||
{ enabled: false },
|
||||
true,
|
||||
true,
|
||||
);
|
||||
|
|
@ -268,11 +311,8 @@ describe('setupOpenId', () => {
|
|||
username: '',
|
||||
name: '',
|
||||
};
|
||||
findUser.mockImplementation(async (query) => {
|
||||
if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) {
|
||||
return existingUser;
|
||||
}
|
||||
return null;
|
||||
mockFindUser.mockImplementation(async (query) => {
|
||||
return existingUser;
|
||||
});
|
||||
|
||||
const userinfo = tokenset.claims();
|
||||
|
|
@ -281,7 +321,7 @@ describe('setupOpenId', () => {
|
|||
await validate(tokenset);
|
||||
|
||||
// Assert – updateUser should be called and the user object updated
|
||||
expect(updateUser).toHaveBeenCalledWith(
|
||||
expect(User.updateUser).toHaveBeenCalledWith(
|
||||
existingUser._id,
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
const { FileSources } = require('librechat-data-provider');
|
||||
const { createUser, updateUser, getUserById } = require('~/models/userMethods');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
|
||||
/**
|
||||
* Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter
|
||||
|
|
@ -34,7 +35,7 @@ const handleExistingUser = async (oldUser, avatarUrl) => {
|
|||
}
|
||||
|
||||
if (updatedAvatar) {
|
||||
await updateUser(oldUser._id, { avatar: updatedAvatar });
|
||||
await db.models.User.updateUser(oldUser._id, { avatar: updatedAvatar });
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -78,7 +79,8 @@ const createSocialUser = async ({
|
|||
emailVerified,
|
||||
};
|
||||
|
||||
const newUserId = await createUser(update);
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
const newUserId = await db.models.User.createUser(update, balanceConfig);
|
||||
const fileStrategy = process.env.CDN_PROVIDER;
|
||||
const isLocal = fileStrategy === FileSources.local;
|
||||
|
||||
|
|
@ -89,10 +91,10 @@ const createSocialUser = async ({
|
|||
});
|
||||
const { processAvatar } = getStrategyFunctions(fileStrategy);
|
||||
const avatar = await processAvatar({ buffer: resizedBuffer, userId: newUserId });
|
||||
await updateUser(newUserId, { avatar });
|
||||
await User.updateUser(newUserId, { avatar });
|
||||
}
|
||||
|
||||
return await getUserById(newUserId);
|
||||
return await User.getUserById(newUserId);
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { findUser } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const db = require('~/lib/db/connectDb');
|
||||
|
||||
const socialLogin =
|
||||
(provider, getProfileDetails) => async (accessToken, refreshToken, idToken, profile, cb) => {
|
||||
|
|
@ -11,7 +11,7 @@ const socialLogin =
|
|||
profile,
|
||||
});
|
||||
|
||||
const oldUser = await findUser({ email: email.trim() });
|
||||
const oldUser = await db.models.User.findUser({ email: email.trim() });
|
||||
const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION);
|
||||
|
||||
if (oldUser) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue