mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-06 02:28:51 +01:00
Merge branch 'main' into partial-filter-index
This commit is contained in:
commit
915022bc08
1573 changed files with 145791 additions and 49740 deletions
|
|
@ -11,13 +11,11 @@ const Action = mongoose.model('action', actionSchema);
|
|||
* @param {string} searchParams.action_id - The ID of the action to update.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<Object>} The updated or newly created action document as a plain object.
|
||||
* @returns {Promise<Action>} The updated or newly created action document as a plain object.
|
||||
*/
|
||||
const updateAction = async (searchParams, updateData) => {
|
||||
return await Action.findOneAndUpdate(searchParams, updateData, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
}).lean();
|
||||
const options = { new: true, upsert: true };
|
||||
return await Action.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -25,7 +23,7 @@ const updateAction = async (searchParams, updateData) => {
|
|||
*
|
||||
* @param {Object} searchParams - The search parameters to find matching actions.
|
||||
* @param {boolean} includeSensitive - Flag to include sensitive data in the metadata.
|
||||
* @returns {Promise<Array<Object>>} A promise that resolves to an array of action documents as plain objects.
|
||||
* @returns {Promise<Array<Action>>} A promise that resolves to an array of action documents as plain objects.
|
||||
*/
|
||||
const getActions = async (searchParams, includeSensitive = false) => {
|
||||
const actions = await Action.find(searchParams).lean();
|
||||
|
|
@ -50,19 +48,33 @@ const getActions = async (searchParams, includeSensitive = false) => {
|
|||
};
|
||||
|
||||
/**
|
||||
* Deletes an action by its ID.
|
||||
* Deletes an action by params.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the action to update.
|
||||
* @param {string} searchParams.action_id - The ID of the action to update.
|
||||
* @param {Object} searchParams - The search parameters to find the action to delete.
|
||||
* @param {string} searchParams.action_id - The ID of the action to delete.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @returns {Promise<Object>} A promise that resolves to the deleted action document as a plain object, or null if no document was found.
|
||||
* @returns {Promise<Action>} A promise that resolves to the deleted action document as a plain object, or null if no document was found.
|
||||
*/
|
||||
const deleteAction = async (searchParams) => {
|
||||
return await Action.findOneAndDelete(searchParams).lean();
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateAction,
|
||||
getActions,
|
||||
deleteAction,
|
||||
/**
|
||||
* Deletes actions by params.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the actions to delete.
|
||||
* @param {string} searchParams.action_id - The ID of the action(s) to delete.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @returns {Promise<Number>} A promise that resolves to the number of deleted action documents.
|
||||
*/
|
||||
const deleteActions = async (searchParams) => {
|
||||
const result = await Action.deleteMany(searchParams);
|
||||
return result.deletedCount;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getActions,
|
||||
updateAction,
|
||||
deleteAction,
|
||||
deleteActions,
|
||||
};
|
||||
|
|
|
|||
302
api/models/Agent.js
Normal file
302
api/models/Agent.js
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys;
|
||||
const {
|
||||
getProjectByName,
|
||||
addAgentIdsToProject,
|
||||
removeAgentIdsFromProject,
|
||||
removeAgentFromAllProjects,
|
||||
} = require('./Project');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const agentSchema = require('./schema/agent');
|
||||
|
||||
const Agent = mongoose.model('agent', agentSchema);
|
||||
|
||||
/**
|
||||
* Create an agent with the provided data.
|
||||
* @param {Object} agentData - The agent data to create.
|
||||
* @returns {Promise<Agent>} The created agent document as a plain object.
|
||||
* @throws {Error} If the agent creation fails.
|
||||
*/
|
||||
const createAgent = async (agentData) => {
|
||||
return (await Agent.create(agentData)).toObject();
|
||||
};
|
||||
|
||||
/**
|
||||
* Get an agent document based on the provided ID.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find the agent to update.
|
||||
* @param {string} searchParameter.id - The ID of the agent to update.
|
||||
* @param {string} searchParameter.author - The user ID of the agent's author.
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {string} params.agent_id
|
||||
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
|
||||
*/
|
||||
const loadAgent = async ({ req, agent_id }) => {
|
||||
const agent = await getAgent({
|
||||
id: agent_id,
|
||||
});
|
||||
|
||||
if (agent.author.toString() === req.user.id) {
|
||||
return agent;
|
||||
}
|
||||
|
||||
if (!agent.projectIds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const cache = getLogStores(CONFIG_STORE);
|
||||
/** @type {TStartupConfig} */
|
||||
const cachedStartupConfig = await cache.get(STARTUP_CONFIG);
|
||||
let { instanceProjectId } = cachedStartupConfig ?? {};
|
||||
if (!instanceProjectId) {
|
||||
instanceProjectId = (await getProjectByName(GLOBAL_PROJECT_NAME, '_id'))._id.toString();
|
||||
}
|
||||
|
||||
for (const projectObjectId of agent.projectIds) {
|
||||
const projectId = projectObjectId.toString();
|
||||
if (projectId === instanceProjectId) {
|
||||
return agent;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Update an agent with new data without overwriting existing
|
||||
* properties, or create a new agent if it doesn't exist.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find the agent to update.
|
||||
* @param {string} searchParameter.id - The ID of the agent to update.
|
||||
* @param {string} [searchParameter.author] - The user ID of the agent's author.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<Agent>} The updated or newly created agent document as a plain object.
|
||||
*/
|
||||
const updateAgent = async (searchParameter, updateData) => {
|
||||
const options = { new: true, upsert: false };
|
||||
return Agent.findOneAndUpdate(searchParameter, updateData, options).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Modifies an agent with the resource file id.
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {string} params.agent_id
|
||||
* @param {string} params.tool_resource
|
||||
* @param {string} params.file_id
|
||||
* @returns {Promise<Agent>} The updated agent.
|
||||
*/
|
||||
const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
|
||||
// build the update to push or create the file ids set
|
||||
const fileIdsPath = `tool_resources.${tool_resource}.file_ids`;
|
||||
const updateData = { $addToSet: { [fileIdsPath]: file_id } };
|
||||
|
||||
// return the updated agent or throw if no agent matches
|
||||
const updatedAgent = await updateAgent(searchParameter, updateData);
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for adding resource file');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Removes multiple resource files from an agent in a single update.
|
||||
* @param {object} params
|
||||
* @param {string} params.agent_id
|
||||
* @param {Array<{tool_resource: string, file_id: string}>} params.files
|
||||
* @returns {Promise<Agent>} The updated agent.
|
||||
*/
|
||||
const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
|
||||
// associate each tool resource with the respective file ids array
|
||||
const filesByResource = files.reduce((acc, { tool_resource, file_id }) => {
|
||||
if (!acc[tool_resource]) {
|
||||
acc[tool_resource] = [];
|
||||
}
|
||||
acc[tool_resource].push(file_id);
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// build the update aggregation pipeline wich removes file ids from tool resources array
|
||||
// and eventually deletes empty tool resources
|
||||
const updateData = [];
|
||||
Object.entries(filesByResource).forEach(([resource, fileIds]) => {
|
||||
const toolResourcePath = `tool_resources.${resource}`;
|
||||
const fileIdsPath = `${toolResourcePath}.file_ids`;
|
||||
|
||||
// file ids removal stage
|
||||
updateData.push({
|
||||
$set: {
|
||||
[fileIdsPath]: {
|
||||
$filter: {
|
||||
input: `$${fileIdsPath}`,
|
||||
cond: { $not: [{ $in: ['$$this', fileIds] }] },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// empty tool resource deletion stage
|
||||
updateData.push({
|
||||
$set: {
|
||||
[toolResourcePath]: {
|
||||
$cond: [{ $eq: [`$${fileIdsPath}`, []] }, '$$REMOVE', `$${toolResourcePath}`],
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
// return the updated agent or throw if no agent matches
|
||||
const updatedAgent = await updateAgent(searchParameter, updateData);
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes an agent based on the provided ID.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find the agent to delete.
|
||||
* @param {string} searchParameter.id - The ID of the agent to delete.
|
||||
* @param {string} [searchParameter.author] - The user ID of the agent's author.
|
||||
* @returns {Promise<void>} Resolves when the agent has been successfully deleted.
|
||||
*/
|
||||
const deleteAgent = async (searchParameter) => {
|
||||
const agent = await Agent.findOneAndDelete(searchParameter);
|
||||
if (agent) {
|
||||
await removeAgentFromAllProjects(agent.id);
|
||||
}
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all agents.
|
||||
* @param {Object} searchParameter - The search parameters to find matching agents.
|
||||
* @param {string} searchParameter.author - The user ID of the agent's author.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
|
||||
*/
|
||||
const getListAgents = async (searchParameter) => {
|
||||
const { author, ...otherParams } = searchParameter;
|
||||
|
||||
let query = Object.assign({ author }, otherParams);
|
||||
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, ['agentIds']);
|
||||
if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) {
|
||||
const globalQuery = { id: { $in: globalProject.agentIds }, ...otherParams };
|
||||
delete globalQuery.author;
|
||||
query = { $or: [globalQuery, query] };
|
||||
}
|
||||
|
||||
const agents = (
|
||||
await Agent.find(query, {
|
||||
id: 1,
|
||||
_id: 0,
|
||||
name: 1,
|
||||
avatar: 1,
|
||||
author: 1,
|
||||
projectIds: 1,
|
||||
description: 1,
|
||||
isCollaborative: 1,
|
||||
}).lean()
|
||||
).map((agent) => {
|
||||
if (agent.author?.toString() !== author) {
|
||||
delete agent.author;
|
||||
}
|
||||
if (agent.author) {
|
||||
agent.author = agent.author.toString();
|
||||
}
|
||||
return agent;
|
||||
});
|
||||
|
||||
const hasMore = agents.length > 0;
|
||||
const firstId = agents.length > 0 ? agents[0].id : null;
|
||||
const lastId = agents.length > 0 ? agents[agents.length - 1].id : null;
|
||||
|
||||
return {
|
||||
data: agents,
|
||||
has_more: hasMore,
|
||||
first_id: firstId,
|
||||
last_id: lastId,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates the projects associated with an agent, adding and removing project IDs as specified.
|
||||
* This function also updates the corresponding projects to include or exclude the agent ID.
|
||||
*
|
||||
* @param {Object} params - Parameters for updating the agent's projects.
|
||||
* @param {import('librechat-data-provider').TUser} params.user - Parameters for updating the agent's projects.
|
||||
* @param {string} params.agentId - The ID of the agent to update.
|
||||
* @param {string[]} [params.projectIds] - Array of project IDs to add to the agent.
|
||||
* @param {string[]} [params.removeProjectIds] - Array of project IDs to remove from the agent.
|
||||
* @returns {Promise<MongoAgent>} The updated agent document.
|
||||
* @throws {Error} If there's an error updating the agent or projects.
|
||||
*/
|
||||
const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds }) => {
|
||||
const updateOps = {};
|
||||
|
||||
if (removeProjectIds && removeProjectIds.length > 0) {
|
||||
for (const projectId of removeProjectIds) {
|
||||
await removeAgentIdsFromProject(projectId, [agentId]);
|
||||
}
|
||||
updateOps.$pull = { projectIds: { $in: removeProjectIds } };
|
||||
}
|
||||
|
||||
if (projectIds && projectIds.length > 0) {
|
||||
for (const projectId of projectIds) {
|
||||
await addAgentIdsToProject(projectId, [agentId]);
|
||||
}
|
||||
updateOps.$addToSet = { projectIds: { $each: projectIds } };
|
||||
}
|
||||
|
||||
if (Object.keys(updateOps).length === 0) {
|
||||
return await getAgent({ id: agentId });
|
||||
}
|
||||
|
||||
const updateQuery = { id: agentId, author: user.id };
|
||||
if (user.role === SystemRoles.ADMIN) {
|
||||
delete updateQuery.author;
|
||||
}
|
||||
|
||||
const updatedAgent = await updateAgent(updateQuery, updateOps);
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
}
|
||||
if (updateOps.$addToSet) {
|
||||
for (const projectId of projectIds) {
|
||||
await removeAgentIdsFromProject(projectId, [agentId]);
|
||||
}
|
||||
} else if (updateOps.$pull) {
|
||||
for (const projectId of removeProjectIds) {
|
||||
await addAgentIdsToProject(projectId, [agentId]);
|
||||
}
|
||||
}
|
||||
|
||||
return await getAgent({ id: agentId });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getAgent,
|
||||
loadAgent,
|
||||
createAgent,
|
||||
updateAgent,
|
||||
deleteAgent,
|
||||
getListAgents,
|
||||
updateAgentProjects,
|
||||
addAgentResourceFile,
|
||||
removeAgentResourceFiles,
|
||||
};
|
||||
|
|
@ -11,13 +11,11 @@ const Assistant = mongoose.model('assistant', assistantSchema);
|
|||
* @param {string} searchParams.assistant_id - The ID of the assistant to update.
|
||||
* @param {string} searchParams.user - The user ID of the assistant's author.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
|
||||
* @returns {Promise<AssistantDocument>} The updated or newly created assistant document as a plain object.
|
||||
*/
|
||||
const updateAssistant = async (searchParams, updateData) => {
|
||||
return await Assistant.findOneAndUpdate(searchParams, updateData, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
}).lean();
|
||||
const updateAssistantDoc = async (searchParams, updateData) => {
|
||||
const options = { new: true, upsert: true };
|
||||
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -26,7 +24,7 @@ const updateAssistant = async (searchParams, updateData) => {
|
|||
* @param {Object} searchParams - The search parameters to find the assistant to update.
|
||||
* @param {string} searchParams.assistant_id - The ID of the assistant to update.
|
||||
* @param {string} searchParams.user - The user ID of the assistant's author.
|
||||
* @returns {Promise<Object|null>} The assistant document as a plain object, or null if not found.
|
||||
* @returns {Promise<AssistantDocument|null>} The assistant document as a plain object, or null if not found.
|
||||
*/
|
||||
const getAssistant = async (searchParams) => await Assistant.findOne(searchParams).lean();
|
||||
|
||||
|
|
@ -34,14 +32,34 @@ const getAssistant = async (searchParams) => await Assistant.findOne(searchParam
|
|||
* Retrieves all assistants that match the given search parameters.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find matching assistants.
|
||||
* @returns {Promise<Array<Object>>} A promise that resolves to an array of action documents as plain objects.
|
||||
* @param {Object} [select] - Optional. Specifies which document fields to include or exclude.
|
||||
* @returns {Promise<Array<AssistantDocument>>} A promise that resolves to an array of assistant documents as plain objects.
|
||||
*/
|
||||
const getAssistants = async (searchParams) => {
|
||||
return await Assistant.find(searchParams).lean();
|
||||
const getAssistants = async (searchParams, select = null) => {
|
||||
let query = Assistant.find(searchParams);
|
||||
|
||||
if (select) {
|
||||
query = query.select(select);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes an assistant based on the provided ID.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the assistant to delete.
|
||||
* @param {string} searchParams.assistant_id - The ID of the assistant to delete.
|
||||
* @param {string} searchParams.user - The user ID of the assistant's author.
|
||||
* @returns {Promise<void>} Resolves when the assistant has been successfully deleted.
|
||||
*/
|
||||
const deleteAssistant = async (searchParams) => {
|
||||
return await Assistant.findOneAndDelete(searchParams);
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateAssistant,
|
||||
updateAssistantDoc,
|
||||
deleteAssistant,
|
||||
getAssistants,
|
||||
getAssistant,
|
||||
};
|
||||
|
|
|
|||
27
api/models/Banner.js
Normal file
27
api/models/Banner.js
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
const Banner = require('./schema/banner');
|
||||
const logger = require('~/config/winston');
|
||||
/**
|
||||
* Retrieves the current active banner.
|
||||
* @returns {Promise<Object|null>} The active banner object or null if no active banner is found.
|
||||
*/
|
||||
const getBanner = async (user) => {
|
||||
try {
|
||||
const now = new Date();
|
||||
const banner = await Banner.findOne({
|
||||
displayFrom: { $lte: now },
|
||||
$or: [{ displayTo: { $gte: now } }, { displayTo: null }],
|
||||
type: 'banner',
|
||||
}).lean();
|
||||
|
||||
if (!banner || banner.isPublic || user) {
|
||||
return banner;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
logger.error('[getBanners] Error getting banners', error);
|
||||
throw new Error('Error getting banners');
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = { getBanner };
|
||||
57
api/models/Categories.js
Normal file
57
api/models/Categories.js
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
const { logger } = require('~/config');
|
||||
// const { Categories } = require('./schema/categories');
|
||||
const options = [
|
||||
{
|
||||
label: 'idea',
|
||||
value: 'idea',
|
||||
},
|
||||
{
|
||||
label: 'travel',
|
||||
value: 'travel',
|
||||
},
|
||||
{
|
||||
label: 'teach_or_explain',
|
||||
value: 'teach_or_explain',
|
||||
},
|
||||
{
|
||||
label: 'write',
|
||||
value: 'write',
|
||||
},
|
||||
{
|
||||
label: 'shop',
|
||||
value: 'shop',
|
||||
},
|
||||
{
|
||||
label: 'code',
|
||||
value: 'code',
|
||||
},
|
||||
{
|
||||
label: 'misc',
|
||||
value: 'misc',
|
||||
},
|
||||
{
|
||||
label: 'roleplay',
|
||||
value: 'roleplay',
|
||||
},
|
||||
{
|
||||
label: 'finance',
|
||||
value: 'finance',
|
||||
},
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
/**
|
||||
* Retrieves the categories asynchronously.
|
||||
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
|
||||
* @throws {Error} If there is an error retrieving the categories.
|
||||
*/
|
||||
getCategories: async () => {
|
||||
try {
|
||||
// const categories = await Categories.find();
|
||||
return options;
|
||||
} catch (error) {
|
||||
logger.error('Error getting categories', error);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -2,6 +2,39 @@ const Conversation = require('./schema/convoSchema');
|
|||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @returns {Promise<{conversationId: string, user: string} | null>} The conversation object with selected fields or null if not found.
|
||||
*/
|
||||
const searchConversation = async (conversationId) => {
|
||||
try {
|
||||
return await Conversation.findOne({ conversationId }, 'conversationId user').lean();
|
||||
} catch (error) {
|
||||
logger.error('[searchConversation] Error searching conversation', error);
|
||||
throw new Error('Error searching conversation');
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Searches for a conversation by conversationId and returns associated file ids.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @returns {Promise<string[] | null>}
|
||||
*/
|
||||
const getConvoFiles = async (conversationId) => {
|
||||
try {
|
||||
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
|
||||
} catch (error) {
|
||||
logger.error('[getConvoFiles] Error getting conversation files', error);
|
||||
throw new Error('Error getting conversation files');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves a single conversation for a given user and conversation ID.
|
||||
* @param {string} user - The user's ID.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @returns {Promise<TConversation>} The conversation object.
|
||||
*/
|
||||
const getConvo = async (user, conversationId) => {
|
||||
try {
|
||||
return await Conversation.findOne({ user, conversationId }).lean();
|
||||
|
|
@ -11,30 +44,120 @@ const getConvo = async (user, conversationId) => {
|
|||
}
|
||||
};
|
||||
|
||||
const deleteNullOrEmptyConversations = async () => {
|
||||
try {
|
||||
const filter = {
|
||||
$or: [
|
||||
{ conversationId: null },
|
||||
{ conversationId: '' },
|
||||
{ conversationId: { $exists: false } },
|
||||
],
|
||||
};
|
||||
|
||||
const result = await Conversation.deleteMany(filter);
|
||||
|
||||
// Delete associated messages
|
||||
const messageDeleteResult = await deleteMessages(filter);
|
||||
|
||||
logger.info(
|
||||
`[deleteNullOrEmptyConversations] Deleted ${result.deletedCount} conversations and ${messageDeleteResult.deletedCount} messages`,
|
||||
);
|
||||
|
||||
return {
|
||||
conversations: result,
|
||||
messages: messageDeleteResult,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteNullOrEmptyConversations] Error deleting conversations', error);
|
||||
throw new Error('Error deleting conversations with null or empty conversationId');
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
Conversation,
|
||||
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
|
||||
getConvoFiles,
|
||||
searchConversation,
|
||||
deleteNullOrEmptyConversations,
|
||||
/**
|
||||
* Saves a conversation to the database.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @param {Object} metadata - Additional metadata to log for operation.
|
||||
* @returns {Promise<TConversation>} The conversation object.
|
||||
*/
|
||||
saveConvo: async (req, { conversationId, newConversationId, ...convo }, metadata) => {
|
||||
try {
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...convo, messages, user };
|
||||
if (metadata && metadata?.context) {
|
||||
logger.debug(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
const messages = await getMessages({ conversationId }, '_id');
|
||||
const update = { ...convo, messages, user: req.user.id };
|
||||
if (newConversationId) {
|
||||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
});
|
||||
if (req.body.isTemporary) {
|
||||
const expiredAt = new Date();
|
||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
||||
update.expiredAt = expiredAt;
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
/** Note: the resulting Model object is necessary for Meilisearch operations */
|
||||
const conversation = await Conversation.findOneAndUpdate(
|
||||
{ conversationId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
upsert: true,
|
||||
},
|
||||
);
|
||||
|
||||
return conversation.toObject();
|
||||
} catch (error) {
|
||||
logger.error('[saveConvo] Error saving conversation', error);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
return { message: 'Error saving conversation' };
|
||||
}
|
||||
},
|
||||
getConvosByPage: async (user, pageNumber = 1, pageSize = 25) => {
|
||||
bulkSaveConvos: async (conversations) => {
|
||||
try {
|
||||
const totalConvos = (await Conversation.countDocuments({ user })) || 1;
|
||||
const bulkOps = conversations.map((convo) => ({
|
||||
updateOne: {
|
||||
filter: { conversationId: convo.conversationId, user: convo.user },
|
||||
update: convo,
|
||||
upsert: true,
|
||||
timestamps: false,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Conversation.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[saveBulkConversations] Error saving conversations in bulk', error);
|
||||
throw new Error('Failed to save conversations in bulk.');
|
||||
}
|
||||
},
|
||||
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false, tags) => {
|
||||
const query = { user };
|
||||
if (isArchived) {
|
||||
query.isArchived = true;
|
||||
} else {
|
||||
query.$or = [{ isArchived: false }, { isArchived: { $exists: false } }];
|
||||
}
|
||||
if (Array.isArray(tags) && tags.length > 0) {
|
||||
query.tags = { $in: tags };
|
||||
}
|
||||
|
||||
query.$and = [{ $or: [{ expiredAt: null }, { expiredAt: { $exists: false } }] }];
|
||||
|
||||
try {
|
||||
const totalConvos = (await Conversation.countDocuments(query)) || 1;
|
||||
const totalPages = Math.ceil(totalConvos / pageSize);
|
||||
const convos = await Conversation.find({ user })
|
||||
const convos = await Conversation.find(query)
|
||||
.sort({ updatedAt: -1 })
|
||||
.skip((pageNumber - 1) * pageSize)
|
||||
.limit(pageSize)
|
||||
|
|
@ -60,6 +183,7 @@ module.exports = {
|
|||
Conversation.findOne({
|
||||
user,
|
||||
conversationId: convo.conversationId,
|
||||
$or: [{ expiredAt: { $exists: false } }, { expiredAt: null }],
|
||||
}).lean(),
|
||||
),
|
||||
);
|
||||
|
|
|
|||
249
api/models/ConversationTag.js
Normal file
249
api/models/ConversationTag.js
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
const ConversationTag = require('./schema/conversationTagSchema');
|
||||
const Conversation = require('./schema/convoSchema');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
* Retrieves all conversation tags for a user.
|
||||
* @param {string} user - The user ID.
|
||||
* @returns {Promise<Array>} An array of conversation tags.
|
||||
*/
|
||||
const getConversationTags = async (user) => {
|
||||
try {
|
||||
return await ConversationTag.find({ user }).sort({ position: 1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('[getConversationTags] Error getting conversation tags', error);
|
||||
throw new Error('Error getting conversation tags');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a new conversation tag.
|
||||
* @param {string} user - The user ID.
|
||||
* @param {Object} data - The tag data.
|
||||
* @param {string} data.tag - The tag name.
|
||||
* @param {string} [data.description] - The tag description.
|
||||
* @param {boolean} [data.addToConversation] - Whether to add the tag to a conversation.
|
||||
* @param {string} [data.conversationId] - The conversation ID to add the tag to.
|
||||
* @returns {Promise<Object>} The created tag.
|
||||
*/
|
||||
const createConversationTag = async (user, data) => {
|
||||
try {
|
||||
const { tag, description, addToConversation, conversationId } = data;
|
||||
|
||||
const existingTag = await ConversationTag.findOne({ user, tag }).lean();
|
||||
if (existingTag) {
|
||||
return existingTag;
|
||||
}
|
||||
|
||||
const maxPosition = await ConversationTag.findOne({ user }).sort('-position').lean();
|
||||
const position = (maxPosition?.position || 0) + 1;
|
||||
|
||||
const newTag = await ConversationTag.findOneAndUpdate(
|
||||
{ tag, user },
|
||||
{
|
||||
tag,
|
||||
user,
|
||||
count: addToConversation ? 1 : 0,
|
||||
position,
|
||||
description,
|
||||
$setOnInsert: { createdAt: new Date() },
|
||||
},
|
||||
{
|
||||
new: true,
|
||||
upsert: true,
|
||||
lean: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (addToConversation && conversationId) {
|
||||
await Conversation.findOneAndUpdate(
|
||||
{ user, conversationId },
|
||||
{ $addToSet: { tags: tag } },
|
||||
{ new: true },
|
||||
);
|
||||
}
|
||||
|
||||
return newTag;
|
||||
} catch (error) {
|
||||
logger.error('[createConversationTag] Error creating conversation tag', error);
|
||||
throw new Error('Error creating conversation tag');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates an existing conversation tag.
|
||||
* @param {string} user - The user ID.
|
||||
* @param {string} oldTag - The current tag name.
|
||||
* @param {Object} data - The updated tag data.
|
||||
* @param {string} [data.tag] - The new tag name.
|
||||
* @param {string} [data.description] - The updated description.
|
||||
* @param {number} [data.position] - The new position.
|
||||
* @returns {Promise<Object>} The updated tag.
|
||||
*/
|
||||
const updateConversationTag = async (user, oldTag, data) => {
|
||||
try {
|
||||
const { tag: newTag, description, position } = data;
|
||||
|
||||
const existingTag = await ConversationTag.findOne({ user, tag: oldTag }).lean();
|
||||
if (!existingTag) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (newTag && newTag !== oldTag) {
|
||||
const tagAlreadyExists = await ConversationTag.findOne({ user, tag: newTag }).lean();
|
||||
if (tagAlreadyExists) {
|
||||
throw new Error('Tag already exists');
|
||||
}
|
||||
|
||||
await Conversation.updateMany({ user, tags: oldTag }, { $set: { 'tags.$': newTag } });
|
||||
}
|
||||
|
||||
const updateData = {};
|
||||
if (newTag) {
|
||||
updateData.tag = newTag;
|
||||
}
|
||||
if (description !== undefined) {
|
||||
updateData.description = description;
|
||||
}
|
||||
if (position !== undefined) {
|
||||
await adjustPositions(user, existingTag.position, position);
|
||||
updateData.position = position;
|
||||
}
|
||||
|
||||
return await ConversationTag.findOneAndUpdate({ user, tag: oldTag }, updateData, {
|
||||
new: true,
|
||||
lean: true,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[updateConversationTag] Error updating conversation tag', error);
|
||||
throw new Error('Error updating conversation tag');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adjusts positions of tags when a tag's position is changed.
|
||||
* @param {string} user - The user ID.
|
||||
* @param {number} oldPosition - The old position of the tag.
|
||||
* @param {number} newPosition - The new position of the tag.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const adjustPositions = async (user, oldPosition, newPosition) => {
|
||||
if (oldPosition === newPosition) {
|
||||
return;
|
||||
}
|
||||
|
||||
const update = oldPosition < newPosition ? { $inc: { position: -1 } } : { $inc: { position: 1 } };
|
||||
const position =
|
||||
oldPosition < newPosition
|
||||
? {
|
||||
$gt: Math.min(oldPosition, newPosition),
|
||||
$lte: Math.max(oldPosition, newPosition),
|
||||
}
|
||||
: {
|
||||
$gte: Math.min(oldPosition, newPosition),
|
||||
$lt: Math.max(oldPosition, newPosition),
|
||||
};
|
||||
|
||||
await ConversationTag.updateMany(
|
||||
{
|
||||
user,
|
||||
position,
|
||||
},
|
||||
update,
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes a conversation tag.
|
||||
* @param {string} user - The user ID.
|
||||
* @param {string} tag - The tag to delete.
|
||||
* @returns {Promise<Object>} The deleted tag.
|
||||
*/
|
||||
const deleteConversationTag = async (user, tag) => {
|
||||
try {
|
||||
const deletedTag = await ConversationTag.findOneAndDelete({ user, tag }).lean();
|
||||
if (!deletedTag) {
|
||||
return null;
|
||||
}
|
||||
|
||||
await Conversation.updateMany({ user, tags: tag }, { $pull: { tags: tag } });
|
||||
|
||||
await ConversationTag.updateMany(
|
||||
{ user, position: { $gt: deletedTag.position } },
|
||||
{ $inc: { position: -1 } },
|
||||
);
|
||||
|
||||
return deletedTag;
|
||||
} catch (error) {
|
||||
logger.error('[deleteConversationTag] Error deleting conversation tag', error);
|
||||
throw new Error('Error deleting conversation tag');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates tags for a specific conversation.
|
||||
* @param {string} user - The user ID.
|
||||
* @param {string} conversationId - The conversation ID.
|
||||
* @param {string[]} tags - The new set of tags for the conversation.
|
||||
* @returns {Promise<string[]>} The updated list of tags for the conversation.
|
||||
*/
|
||||
const updateTagsForConversation = async (user, conversationId, tags) => {
|
||||
try {
|
||||
const conversation = await Conversation.findOne({ user, conversationId }).lean();
|
||||
if (!conversation) {
|
||||
throw new Error('Conversation not found');
|
||||
}
|
||||
|
||||
const oldTags = new Set(conversation.tags);
|
||||
const newTags = new Set(tags);
|
||||
|
||||
const addedTags = [...newTags].filter((tag) => !oldTags.has(tag));
|
||||
const removedTags = [...oldTags].filter((tag) => !newTags.has(tag));
|
||||
|
||||
const bulkOps = [];
|
||||
|
||||
for (const tag of addedTags) {
|
||||
bulkOps.push({
|
||||
updateOne: {
|
||||
filter: { user, tag },
|
||||
update: { $inc: { count: 1 } },
|
||||
upsert: true,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
for (const tag of removedTags) {
|
||||
bulkOps.push({
|
||||
updateOne: {
|
||||
filter: { user, tag },
|
||||
update: { $inc: { count: -1 } },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (bulkOps.length > 0) {
|
||||
await ConversationTag.bulkWrite(bulkOps);
|
||||
}
|
||||
|
||||
const updatedConversation = (
|
||||
await Conversation.findOneAndUpdate(
|
||||
{ user, conversationId },
|
||||
{ $set: { tags: [...newTags] } },
|
||||
{ new: true },
|
||||
)
|
||||
).toObject();
|
||||
|
||||
return updatedConversation.tags;
|
||||
} catch (error) {
|
||||
logger.error('[updateTagsForConversation] Error updating tags', error);
|
||||
throw new Error('Error updating tags for conversation');
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getConversationTags,
|
||||
createConversationTag,
|
||||
updateConversationTag,
|
||||
deleteConversationTag,
|
||||
updateTagsForConversation,
|
||||
};
|
||||
|
|
@ -69,7 +69,7 @@ const updateFileUsage = async (data) => {
|
|||
const { file_id, inc = 1 } = data;
|
||||
const updateOperation = {
|
||||
$inc: { usage: inc },
|
||||
$unset: { expiresAt: '' },
|
||||
$unset: { expiresAt: '', temp_file_id: '' },
|
||||
};
|
||||
return await File.findOneAndUpdate({ file_id }, updateOperation, { new: true }).lean();
|
||||
};
|
||||
|
|
@ -97,8 +97,12 @@ const deleteFileByFilter = async (filter) => {
|
|||
* @param {Array<string>} file_ids - The unique identifiers of the files to delete.
|
||||
* @returns {Promise<Object>} A promise that resolves to the result of the deletion operation.
|
||||
*/
|
||||
const deleteFiles = async (file_ids) => {
|
||||
return await File.deleteMany({ file_id: { $in: file_ids } });
|
||||
const deleteFiles = async (file_ids, user) => {
|
||||
let deleteQuery = { file_id: { $in: file_ids } };
|
||||
if (user) {
|
||||
deleteQuery = { user: user };
|
||||
}
|
||||
return await File.deleteMany(deleteQuery);
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
|
|
|||
|
|
@ -1,170 +1,323 @@
|
|||
const { z } = require('zod');
|
||||
const Message = require('./schema/messageSchema');
|
||||
const logger = require('~/config/winston');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
/**
|
||||
* Saves a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function saveMessage
|
||||
* @param {Express.Request} req - The request object containing user information.
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.iconURL - The URL of the sender's icon.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.newMessageId - The new unique identifier for the message (if applicable).
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {string} params.sender - The identifier of the sender.
|
||||
* @param {string} params.text - The text content of the message.
|
||||
* @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user.
|
||||
* @param {string} [params.error] - Any error associated with the message.
|
||||
* @param {boolean} [params.unfinished] - Indicates if the message is unfinished.
|
||||
* @param {Object[]} [params.files] - An array of files associated with the message.
|
||||
* @param {string} [params.finish_reason] - Reason for finishing the message.
|
||||
* @param {number} [params.tokenCount] - The number of tokens in the message.
|
||||
* @param {string} [params.plugin] - Plugin associated with the message.
|
||||
* @param {string[]} [params.plugins] - An array of plugins associated with the message.
|
||||
* @param {string} [params.model] - The model used to generate the message.
|
||||
* @param {Object} [metadata] - Additional metadata for this operation
|
||||
* @param {string} [metadata.context] - The context of the operation
|
||||
* @returns {Promise<TMessage>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function saveMessage(req, params, metadata) {
|
||||
if (!req?.user?.id) {
|
||||
throw new Error('User not authenticated');
|
||||
}
|
||||
|
||||
const validConvoId = idSchema.safeParse(params.conversationId);
|
||||
if (!validConvoId.success) {
|
||||
logger.warn(`Invalid conversation ID: ${params.conversationId}`);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
logger.info(`---Invalid conversation ID Params: ${JSON.stringify(params, null, 2)}`);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const update = {
|
||||
...params,
|
||||
user: req.user.id,
|
||||
messageId: params.newMessageId || params.messageId,
|
||||
};
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
const expiredAt = new Date();
|
||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
||||
update.expiredAt = expiredAt;
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
const message = await Message.findOneAndUpdate(
|
||||
{ messageId: params.messageId, user: req.user.id },
|
||||
update,
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
|
||||
return message.toObject();
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves multiple messages in the database in bulk.
|
||||
*
|
||||
* @async
|
||||
* @function bulkSaveMessages
|
||||
* @param {Object[]} messages - An array of message objects to save.
|
||||
* @param {boolean} [overrideTimestamp=false] - Indicates whether to override the timestamps of the messages. Defaults to false.
|
||||
* @returns {Promise<Object>} The result of the bulk write operation.
|
||||
* @throws {Error} If there is an error in saving messages in bulk.
|
||||
*/
|
||||
async function bulkSaveMessages(messages, overrideTimestamp = false) {
|
||||
try {
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
timestamps: !overrideTimestamp,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function recordMessage({
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest
|
||||
}) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error recording message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the text of a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessageText
|
||||
* @param {Object} params - The update data object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.text - The new text content of the message.
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} If there is an error in updating the message text.
|
||||
*/
|
||||
async function updateMessageText(req, { messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId, user: req.user.id }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessage
|
||||
* @param {Object} req - The request object.
|
||||
* @param {Object} message - The message object containing update data.
|
||||
* @param {string} message.messageId - The unique identifier for the message.
|
||||
* @param {string} [message.text] - The new text content of the message.
|
||||
* @param {Object[]} [message.files] - The files associated with the message.
|
||||
* @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user.
|
||||
* @param {string} [message.sender] - The identifier of the sender.
|
||||
* @param {number} [message.tokenCount] - The number of tokens in the message.
|
||||
* @param {Object} [metadata] - The operation metadata
|
||||
* @param {string} [metadata.context] - The operation metadata
|
||||
* @returns {Promise<TMessage>} The updated message document.
|
||||
* @throws {Error} If there is an error in updating the message or if the message is not found.
|
||||
*/
|
||||
async function updateMessage(req, message, metadata) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
const updatedMessage = await Message.findOneAndUpdate(
|
||||
{ messageId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found or user not authorized.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`updateMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages in a conversation since a specific message.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessagesSince
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @returns {Promise<Number>} The number of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessagesSince(req, { messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
|
||||
|
||||
if (message) {
|
||||
const query = Message.find({ conversationId, user: req.user.id });
|
||||
return await query.deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
return undefined;
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
* @async
|
||||
* @function getMessages
|
||||
* @param {Record<string, unknown>} filter - The filter criteria.
|
||||
* @param {string | undefined} [select] - The fields to select.
|
||||
* @returns {Promise<TMessage[]>} The messages that match the filter criteria.
|
||||
* @throws {Error} If there is an error in retrieving messages.
|
||||
*/
|
||||
async function getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a single message from the database.
|
||||
* @async
|
||||
* @function getMessage
|
||||
* @param {{ user: string, messageId: string }} params - The search parameters
|
||||
* @returns {Promise<TMessage | null>} The message that matches the criteria or null if not found
|
||||
* @throws {Error} If there is an error in retrieving the message
|
||||
*/
|
||||
async function getMessage({ user, messageId }) {
|
||||
try {
|
||||
return await Message.findOne({
|
||||
user,
|
||||
messageId,
|
||||
}).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages from the database.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessages
|
||||
* @param {Object} filter - The filter criteria to find messages to delete.
|
||||
* @returns {Promise<Object>} The metadata with count of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Message,
|
||||
|
||||
async saveMessage({
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
newMessageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
error,
|
||||
unfinished,
|
||||
files,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
}) {
|
||||
try {
|
||||
const validConvoId = idSchema.safeParse(conversationId);
|
||||
if (!validConvoId.success) {
|
||||
return;
|
||||
}
|
||||
|
||||
const update = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId: newMessageId || messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
error,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
};
|
||||
|
||||
if (files) {
|
||||
update.files = files;
|
||||
}
|
||||
// may also need to update the conversation here
|
||||
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
|
||||
|
||||
return {
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
tokenCount,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
async updateMessage(message) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
update.isEdited = true;
|
||||
const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, {
|
||||
new: true,
|
||||
});
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
isEdited: true,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
throw new Error('Failed to update message.');
|
||||
}
|
||||
},
|
||||
async deleteMessagesSince({ messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId }).lean();
|
||||
|
||||
if (message) {
|
||||
return await Message.find({ conversationId }).deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw new Error('Failed to delete messages.');
|
||||
}
|
||||
},
|
||||
|
||||
async getMessages(filter) {
|
||||
try {
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw new Error('Failed to get messages.');
|
||||
}
|
||||
},
|
||||
|
||||
async deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw new Error('Failed to delete messages.');
|
||||
}
|
||||
},
|
||||
saveMessage,
|
||||
bulkSaveMessages,
|
||||
recordMessage,
|
||||
updateMessageText,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
getMessages,
|
||||
getMessage,
|
||||
deleteMessages,
|
||||
};
|
||||
|
|
|
|||
238
api/models/Message.spec.js
Normal file
238
api/models/Message.spec.js
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
|
||||
jest.mock('mongoose');
|
||||
|
||||
const mockFindQuery = {
|
||||
select: jest.fn().mockReturnThis(),
|
||||
sort: jest.fn().mockReturnThis(),
|
||||
lean: jest.fn().mockReturnThis(),
|
||||
deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
|
||||
};
|
||||
|
||||
const mockSchema = {
|
||||
findOneAndUpdate: jest.fn(),
|
||||
updateOne: jest.fn(),
|
||||
findOne: jest.fn(() => ({
|
||||
lean: jest.fn(),
|
||||
})),
|
||||
find: jest.fn(() => mockFindQuery),
|
||||
deleteMany: jest.fn(),
|
||||
};
|
||||
|
||||
mongoose.model.mockReturnValue(mockSchema);
|
||||
|
||||
jest.mock('~/models/schema/messageSchema', () => mockSchema);
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
}));
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
updateMessage,
|
||||
deleteMessages,
|
||||
updateMessageText,
|
||||
deleteMessagesSince,
|
||||
} = require('~/models/Message');
|
||||
|
||||
describe('Message Operations', () => {
|
||||
let mockReq;
|
||||
let mockMessage;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
};
|
||||
|
||||
mockMessage = {
|
||||
messageId: 'msg123',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Hello, world!',
|
||||
user: 'user123',
|
||||
};
|
||||
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue({
|
||||
toObject: () => mockMessage,
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveMessage', () => {
|
||||
it('should save a message for an authenticated user', async () => {
|
||||
const result = await saveMessage(mockReq, mockMessage);
|
||||
expect(result).toEqual(mockMessage);
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: 'msg123', user: 'user123' },
|
||||
expect.objectContaining({ user: 'user123' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error for unauthenticated user', async () => {
|
||||
mockReq.user = null;
|
||||
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
|
||||
});
|
||||
|
||||
it('should throw an error for invalid conversation ID', async () => {
|
||||
mockMessage.conversationId = 'invalid-id';
|
||||
await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessageText', () => {
|
||||
it('should update message text for the authenticated user', async () => {
|
||||
await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
|
||||
expect(mockSchema.updateOne).toHaveBeenCalledWith(
|
||||
{ messageId: 'msg123', user: 'user123' },
|
||||
{ text: 'Updated text' },
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessage', () => {
|
||||
it('should update a message for the authenticated user', async () => {
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
|
||||
const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({
|
||||
messageId: 'msg123',
|
||||
text: 'Hello, world!',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if message is not found', async () => {
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(null);
|
||||
await expect(
|
||||
updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessagesSince', () => {
|
||||
it('should delete messages only for the authenticated user', async () => {
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
|
||||
mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
|
||||
const result = await deleteMessagesSince(mockReq, {
|
||||
messageId: 'msg123',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
|
||||
expect(mockSchema.find).not.toHaveBeenCalled();
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if no message is found', async () => {
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce(null);
|
||||
const result = await deleteMessagesSince(mockReq, {
|
||||
messageId: 'nonexistent',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessages', () => {
|
||||
it('should retrieve messages with the correct filter', async () => {
|
||||
const filter = { conversationId: 'convo123' };
|
||||
await getMessages(filter);
|
||||
expect(mockSchema.find).toHaveBeenCalledWith(filter);
|
||||
expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
|
||||
expect(mockFindQuery.lean).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessages', () => {
|
||||
it('should delete messages with the correct filter', async () => {
|
||||
await deleteMessages({ user: 'user123' });
|
||||
expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('Conversation Hijacking Prevention', () => {
|
||||
it('should not allow editing a message in another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
updateMessage(attackerReq, {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
text: 'Hacked message',
|
||||
}),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: victimMessageId, user: 'attacker123' },
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow deleting messages from another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
|
||||
const result = await deleteMessagesSince(attackerReq, {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(mockSchema.findOne).toHaveBeenCalledWith({
|
||||
messageId: victimMessageId,
|
||||
user: 'attacker123',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not allow inserting a new message into another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = uuidv4(); // Use a valid UUID
|
||||
|
||||
await expect(
|
||||
saveMessage(attackerReq, {
|
||||
conversationId: victimConversationId,
|
||||
text: 'Inserted malicious message',
|
||||
messageId: 'new-msg-123',
|
||||
}),
|
||||
).resolves.not.toThrow(); // It should not throw an error
|
||||
|
||||
// Check that the message was saved with the attacker's user ID
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: 'new-msg-123', user: 'attacker123' },
|
||||
expect.objectContaining({
|
||||
user: 'attacker123',
|
||||
conversationId: victimConversationId,
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow retrieving messages from any conversation', async () => {
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
|
||||
await getMessages({ conversationId: victimConversationId });
|
||||
|
||||
expect(mockSchema.find).toHaveBeenCalledWith({
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
mockSchema.find.mockReturnValueOnce({
|
||||
select: jest.fn().mockReturnThis(),
|
||||
sort: jest.fn().mockReturnThis(),
|
||||
lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
|
||||
});
|
||||
|
||||
const result = await getMessages({ conversationId: victimConversationId });
|
||||
expect(result).toEqual([{ text: 'Test message' }]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -38,7 +38,14 @@ module.exports = {
|
|||
savePreset: async (user, { presetId, newPresetId, defaultPreset, ...preset }) => {
|
||||
try {
|
||||
const setter = { $set: {} };
|
||||
const update = { presetId, ...preset };
|
||||
const { user: _, ...cleanPreset } = preset;
|
||||
const update = { presetId, ...cleanPreset };
|
||||
if (preset.tools && Array.isArray(preset.tools)) {
|
||||
update.tools =
|
||||
preset.tools
|
||||
.map((tool) => tool?.pluginKey ?? tool)
|
||||
.filter((toolName) => typeof toolName === 'string') ?? [];
|
||||
}
|
||||
if (newPresetId) {
|
||||
update.presetId = newPresetId;
|
||||
}
|
||||
|
|
|
|||
136
api/models/Project.js
Normal file
136
api/models/Project.js
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
const { model } = require('mongoose');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const projectSchema = require('~/models/schema/projectSchema');
|
||||
|
||||
const Project = model('Project', projectSchema);
|
||||
|
||||
/**
|
||||
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to find and return as a plain object.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||
*/
|
||||
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||
const query = Project.findById(projectId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a project by name and convert the found project document to a plain object.
|
||||
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
|
||||
*
|
||||
* @param {string} projectName - The name of the project to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document.
|
||||
*/
|
||||
const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
||||
const query = { name: projectName };
|
||||
const update = { $setOnInsert: { name: projectName } };
|
||||
const options = {
|
||||
new: true,
|
||||
upsert: projectName === GLOBAL_PROJECT_NAME,
|
||||
lean: true,
|
||||
select: fieldsToSelect,
|
||||
};
|
||||
|
||||
return await Project.findOneAndUpdate(query, update, options);
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of prompt group IDs from a project's promptGroupIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove a prompt group ID from all projects.
|
||||
*
|
||||
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of agent IDs to a project's agentIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} agentIds - The array of agent IDs to add to the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const addAgentIdsToProject = async function (projectId, agentIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { agentIds: { $each: agentIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of agent IDs from a project's agentIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} agentIds - The array of agent IDs to remove from the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const removeAgentIdsFromProject = async function (projectId, agentIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { agentIds: { $in: agentIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an agent ID from all projects.
|
||||
*
|
||||
* @param {string} agentId - The ID of the agent to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeAgentFromAllProjects = async (agentId) => {
|
||||
await Project.updateMany({}, { $pull: { agentIds: agentId } });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getProjectById,
|
||||
getProjectByName,
|
||||
/* prompts */
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
/* agents */
|
||||
addAgentIdsToProject,
|
||||
removeAgentIdsFromProject,
|
||||
removeAgentFromAllProjects,
|
||||
};
|
||||
|
|
@ -1,52 +1,539 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
} = require('./Project');
|
||||
const { Prompt, PromptGroup } = require('./schema/promptSchema');
|
||||
const { escapeRegExp } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const promptSchema = mongoose.Schema(
|
||||
{
|
||||
title: {
|
||||
type: String,
|
||||
required: true,
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get prompt groups
|
||||
* @param {Object} query
|
||||
* @param {number} skip
|
||||
* @param {number} limit
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createGroupPipeline = (query, skip, limit) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{ $skip: skip },
|
||||
{ $limit: limit },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
// 'productionPrompt._id': 1,
|
||||
// 'productionPrompt.type': 1,
|
||||
},
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
];
|
||||
};
|
||||
|
||||
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get all prompt groups
|
||||
* @param {Object} query
|
||||
* @param {Partial<MongoPromptGroup>} $project
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createAllGroupsPipeline = (
|
||||
query,
|
||||
$project = {
|
||||
name: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
command: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project,
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all prompt groups with filters
|
||||
* @param {ServerRequest} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getAllPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { name, ...query } = filter;
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds');
|
||||
if (project && project.promptGroupIds && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {ServerRequest} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
|
||||
|
||||
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(escapeRegExp(name), 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
// const projects = req.user.projects || []; // TODO: handle multiple projects
|
||||
const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds');
|
||||
if (project && project.promptGroupIds && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||
const limit = validatedPageSize;
|
||||
|
||||
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
|
||||
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
|
||||
|
||||
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
|
||||
PromptGroup.aggregate(promptGroupsPipeline).exec(),
|
||||
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
|
||||
]);
|
||||
|
||||
const promptGroups = promptGroupsResults;
|
||||
const totalPromptGroups =
|
||||
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: validatedPageNumber.toString(),
|
||||
pageSize: validatedPageSize.toString(),
|
||||
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {Object} fields
|
||||
* @param {string} fields._id
|
||||
* @param {string} fields.author
|
||||
* @param {string} fields.role
|
||||
* @returns {Promise<TDeletePromptGroupResponse>}
|
||||
*/
|
||||
const deletePromptGroup = async ({ _id, author, role }) => {
|
||||
const query = { _id, author };
|
||||
const groupQuery = { groupId: new ObjectId(_id), author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
delete groupQuery.author;
|
||||
}
|
||||
const response = await PromptGroup.deleteOne(query);
|
||||
|
||||
if (!response || response.deletedCount === 0) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
await Prompt.deleteMany(groupQuery);
|
||||
await removeGroupFromAllProjects(_id);
|
||||
return { message: 'Prompt group deleted successfully' };
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
savePrompt: async ({ title, prompt }) => {
|
||||
getPromptGroups,
|
||||
deletePromptGroup,
|
||||
getAllPromptGroups,
|
||||
/**
|
||||
* Create a prompt and its respective group
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
createPromptGroup: async (saveData) => {
|
||||
try {
|
||||
await Prompt.create({
|
||||
title,
|
||||
prompt,
|
||||
});
|
||||
return { title, prompt };
|
||||
const { prompt, group, author, authorName } = saveData;
|
||||
|
||||
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||
{ ...group, author, authorName, productionId: null },
|
||||
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
const newPrompt = await Prompt.findOneAndUpdate(
|
||||
{ ...prompt, author, groupId: newPromptGroup._id },
|
||||
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
newPromptGroup = await PromptGroup.findByIdAndUpdate(
|
||||
newPromptGroup._id,
|
||||
{ productionId: newPrompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
return {
|
||||
prompt: newPrompt,
|
||||
group: {
|
||||
...newPromptGroup,
|
||||
productionPrompt: { prompt: newPrompt.prompt },
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt group', error);
|
||||
throw new Error('Error saving prompt group');
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Save a prompt
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
savePrompt: async (saveData) => {
|
||||
try {
|
||||
const { prompt, author } = saveData;
|
||||
const newPromptData = {
|
||||
...prompt,
|
||||
author,
|
||||
};
|
||||
|
||||
/** @type {TPrompt} */
|
||||
let newPrompt;
|
||||
try {
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
} catch (error) {
|
||||
if (error?.message?.includes('groupId_1_version_1')) {
|
||||
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
}
|
||||
|
||||
return { prompt: newPrompt };
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt', error);
|
||||
return { prompt: 'Error saving prompt' };
|
||||
return { message: 'Error saving prompt' };
|
||||
}
|
||||
},
|
||||
getPrompts: async (filter) => {
|
||||
try {
|
||||
return await Prompt.find(filter).lean();
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { prompt: 'Error getting prompts' };
|
||||
return { message: 'Error getting prompts' };
|
||||
}
|
||||
},
|
||||
deletePrompts: async (filter) => {
|
||||
getPrompt: async (filter) => {
|
||||
try {
|
||||
return await Prompt.deleteMany(filter);
|
||||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompts', error);
|
||||
return { prompt: 'Error deleting prompts' };
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {TGetRandomPromptsRequest} filter
|
||||
* @returns {Promise<TGetRandomPromptsResponse>}
|
||||
*/
|
||||
getRandomPromptGroups: async (filter) => {
|
||||
try {
|
||||
const result = await PromptGroup.aggregate([
|
||||
{
|
||||
$match: {
|
||||
category: { $ne: '' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$group: {
|
||||
_id: '$category',
|
||||
promptGroup: { $first: '$$ROOT' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$replaceRoot: { newRoot: '$promptGroup' },
|
||||
},
|
||||
{
|
||||
$sample: { size: +filter.limit + +filter.skip },
|
||||
},
|
||||
{
|
||||
$skip: +filter.skip,
|
||||
},
|
||||
{
|
||||
$limit: +filter.limit,
|
||||
},
|
||||
]);
|
||||
return { prompts: result };
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroupsWithPrompts: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroup: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
|
||||
*
|
||||
* @param {Object} options - The options for deleting the prompt.
|
||||
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
|
||||
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
|
||||
* @param {ObjectId|string} options.author - The ID of the prompt's author.
|
||||
* @param {string} options.role - The role of the prompt's author.
|
||||
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
|
||||
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
|
||||
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
|
||||
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
|
||||
*/
|
||||
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||
const query = { _id: promptId, groupId, author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const { deletedCount } = await Prompt.deleteOne(query);
|
||||
if (deletedCount === 0) {
|
||||
throw new Error('Failed to delete the prompt');
|
||||
}
|
||||
|
||||
const remainingPrompts = await Prompt.find({ groupId })
|
||||
.select('_id')
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
|
||||
return {
|
||||
prompt: 'Prompt deleted successfully',
|
||||
promptGroup: {
|
||||
message: 'Prompt group deleted successfully',
|
||||
id: groupId,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const promptGroup = await PromptGroup.findById(groupId).lean();
|
||||
if (promptGroup.productionId.toString() === promptId.toString()) {
|
||||
await PromptGroup.updateOne(
|
||||
{ _id: groupId },
|
||||
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||
);
|
||||
}
|
||||
|
||||
return { prompt: 'Prompt deleted successfully' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Update prompt group
|
||||
* @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
|
||||
* @param {Partial<MongoPromptGroup>} data - Data to update
|
||||
* @returns {Promise<TUpdatePromptGroupResponse>}
|
||||
*/
|
||||
updatePromptGroup: async (filter, data) => {
|
||||
try {
|
||||
const updateOps = {};
|
||||
if (data.removeProjectIds) {
|
||||
for (const projectId of data.removeProjectIds) {
|
||||
await removeGroupIdsFromProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
|
||||
delete data.removeProjectIds;
|
||||
}
|
||||
|
||||
if (data.projectIds) {
|
||||
for (const projectId of data.projectIds) {
|
||||
await addGroupIdsToProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
|
||||
delete data.projectIds;
|
||||
}
|
||||
|
||||
const updateData = { ...data, ...updateOps };
|
||||
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
|
||||
if (!updatedDoc) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
return updatedDoc;
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt group', error);
|
||||
return { message: 'Error updating prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Function to make a prompt production based on its ID.
|
||||
* @param {String} promptId - The ID of the prompt to make production.
|
||||
* @returns {Object} The result of the production operation.
|
||||
*/
|
||||
makePromptProduction: async (promptId) => {
|
||||
try {
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
|
||||
await PromptGroup.findByIdAndUpdate(
|
||||
prompt.groupId,
|
||||
{ productionId: prompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.exec();
|
||||
|
||||
return {
|
||||
message: 'Prompt production made successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error making prompt production', error);
|
||||
return { message: 'Error making prompt production' };
|
||||
}
|
||||
},
|
||||
updatePromptLabels: async (_id, labels) => {
|
||||
try {
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
return { message: 'Prompt labels updated successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt labels', error);
|
||||
return { message: 'Error updating prompt labels' };
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
|
|||
171
api/models/Role.js
Normal file
171
api/models/Role.js
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
const {
|
||||
CacheKeys,
|
||||
SystemRoles,
|
||||
roleDefaults,
|
||||
PermissionTypes,
|
||||
removeNullishValues,
|
||||
agentPermissionsSchema,
|
||||
promptPermissionsSchema,
|
||||
bookmarkPermissionsSchema,
|
||||
multiConvoPermissionsSchema,
|
||||
} = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const Role = require('~/models/schema/roleSchema');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<Object>} A plain object representing the role document.
|
||||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
return cachedRole;
|
||||
}
|
||||
let query = Role.findOne({ name: roleName });
|
||||
if (fieldsToSelect) {
|
||||
query = query.select(fieldsToSelect);
|
||||
}
|
||||
let role = await query.lean().exec();
|
||||
|
||||
if (!role && SystemRoles[roleName]) {
|
||||
role = roleDefaults[roleName];
|
||||
role = await new Role(role).save();
|
||||
await cache.set(roleName, role);
|
||||
return role.toObject();
|
||||
}
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to retrieve or create role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Update role values by name.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to update.
|
||||
* @param {Partial<TRole>} updates - The fields to update.
|
||||
* @returns {Promise<TRole>} Updated role document.
|
||||
*/
|
||||
const updateRoleByName = async function (roleName, updates) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const role = await Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
{ new: true, lean: true },
|
||||
)
|
||||
.select('-__v')
|
||||
.lean()
|
||||
.exec();
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to update role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
const permissionSchemas = {
|
||||
[PermissionTypes.AGENTS]: agentPermissionsSchema,
|
||||
[PermissionTypes.PROMPTS]: promptPermissionsSchema,
|
||||
[PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema,
|
||||
[PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema,
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates access permissions for a specific role and multiple permission types.
|
||||
* @param {SystemRoles} roleName - The role to update.
|
||||
* @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
|
||||
*/
|
||||
async function updateAccessPermissions(roleName, permissionsUpdate) {
|
||||
const updates = {};
|
||||
for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
|
||||
if (permissionSchemas[permissionType]) {
|
||||
updates[permissionType] = removeNullishValues(permissions);
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(updates).length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return;
|
||||
}
|
||||
|
||||
const updatedPermissions = {};
|
||||
let hasChanges = false;
|
||||
|
||||
for (const [permissionType, permissions] of Object.entries(updates)) {
|
||||
const currentPermissions = role[permissionType] || {};
|
||||
updatedPermissions[permissionType] = { ...currentPermissions };
|
||||
|
||||
for (const [permission, value] of Object.entries(permissions)) {
|
||||
if (currentPermissions[permission] !== value) {
|
||||
updatedPermissions[permissionType][permission] = value;
|
||||
hasChanges = true;
|
||||
logger.info(
|
||||
`Updating '${roleName}' role ${permissionType} '${permission}' permission from ${currentPermissions[permission]} to: ${value}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasChanges) {
|
||||
await updateRoleByName(roleName, updatedPermissions);
|
||||
logger.info(`Updated '${roleName}' role permissions`);
|
||||
} else {
|
||||
logger.info(`No changes needed for '${roleName}' role permissions`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to update ${roleName} role permissions:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize default roles in the system.
|
||||
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
|
||||
* Updates existing roles with new permission types if they're missing.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const initializeRoles = async function () {
|
||||
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
|
||||
|
||||
for (const roleName of defaultRoles) {
|
||||
let role = await Role.findOne({ name: roleName });
|
||||
|
||||
if (!role) {
|
||||
// Create new role if it doesn't exist
|
||||
role = new Role(roleDefaults[roleName]);
|
||||
} else {
|
||||
// Add missing permission types
|
||||
let isUpdated = false;
|
||||
for (const permType of Object.values(PermissionTypes)) {
|
||||
if (!role[permType]) {
|
||||
role[permType] = roleDefaults[roleName][permType];
|
||||
isUpdated = true;
|
||||
}
|
||||
}
|
||||
if (isUpdated) {
|
||||
await role.save();
|
||||
}
|
||||
}
|
||||
await role.save();
|
||||
}
|
||||
};
|
||||
module.exports = {
|
||||
getRoleByName,
|
||||
initializeRoles,
|
||||
updateRoleByName,
|
||||
updateAccessPermissions,
|
||||
};
|
||||
420
api/models/Role.spec.js
Normal file
420
api/models/Role.spec.js
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
SystemRoles,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
Permissions,
|
||||
} = require('librechat-data-provider');
|
||||
const { updateAccessPermissions, initializeRoles } = require('~/models/Role');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const Role = require('~/models/schema/roleSchema');
|
||||
|
||||
// Mock the cache
|
||||
jest.mock('~/cache/getLogStores', () => {
|
||||
return jest.fn().mockReturnValue({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
del: jest.fn(),
|
||||
});
|
||||
});
|
||||
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Role.deleteMany({});
|
||||
getLogStores.mockClear();
|
||||
});
|
||||
|
||||
describe('updateAccessPermissions', () => {
|
||||
it('should update permissions when changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should not update permissions when no changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle non-existent roles', async () => {
|
||||
await updateAccessPermissions('NON_EXISTENT_ROLE', {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
},
|
||||
});
|
||||
|
||||
const role = await Role.findOne({ name: 'NON_EXISTENT_ROLE' });
|
||||
expect(role).toBeNull();
|
||||
});
|
||||
|
||||
it('should update only specified permissions', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
SHARED_GLOBAL: true,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle partial updates', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
USE: false,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should update multiple permission types at once', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: {
|
||||
USE: true,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true },
|
||||
[PermissionTypes.BOOKMARKS]: { USE: false },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
expect(updatedRole[PermissionTypes.BOOKMARKS]).toEqual({
|
||||
USE: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle updates for a single permission type', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: false,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should update MULTI_CONVO permissions', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: true,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
|
||||
USE: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should update MULTI_CONVO permissions along with other permission types', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: false,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true },
|
||||
[PermissionTypes.MULTI_CONVO]: { USE: true },
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({
|
||||
CREATE: true,
|
||||
USE: true,
|
||||
SHARED_GLOBAL: true,
|
||||
});
|
||||
expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
|
||||
USE: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should not update MULTI_CONVO permissions when no changes are needed', async () => {
|
||||
await new Role({
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: true,
|
||||
},
|
||||
}).save();
|
||||
|
||||
await updateAccessPermissions(SystemRoles.USER, {
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
USE: true,
|
||||
},
|
||||
});
|
||||
|
||||
const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({
|
||||
USE: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('initializeRoles', () => {
|
||||
beforeEach(async () => {
|
||||
await Role.deleteMany({});
|
||||
});
|
||||
|
||||
it('should create default roles if they do not exist', async () => {
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(adminRole).toBeTruthy();
|
||||
expect(userRole).toBeTruthy();
|
||||
|
||||
// Check if all permission types exist
|
||||
Object.values(PermissionTypes).forEach((permType) => {
|
||||
expect(adminRole[permType]).toBeDefined();
|
||||
expect(userRole[permType]).toBeDefined();
|
||||
});
|
||||
|
||||
// Check if permissions match defaults (example for ADMIN role)
|
||||
expect(adminRole[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true);
|
||||
expect(adminRole[PermissionTypes.BOOKMARKS].USE).toBe(true);
|
||||
expect(adminRole[PermissionTypes.AGENTS].CREATE).toBe(true);
|
||||
});
|
||||
|
||||
it('should not modify existing permissions for existing roles', async () => {
|
||||
const customUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: {
|
||||
[Permissions.USE]: false,
|
||||
},
|
||||
};
|
||||
|
||||
await new Role(customUserRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(userRole[PermissionTypes.PROMPTS]).toEqual(customUserRole[PermissionTypes.PROMPTS]);
|
||||
expect(userRole[PermissionTypes.BOOKMARKS]).toEqual(customUserRole[PermissionTypes.BOOKMARKS]);
|
||||
expect(userRole[PermissionTypes.AGENTS]).toBeDefined();
|
||||
});
|
||||
|
||||
it('should add new permission types to existing roles', async () => {
|
||||
const partialUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS],
|
||||
};
|
||||
|
||||
await new Role(partialUserRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(userRole[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(userRole[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(userRole[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(userRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle multiple runs without duplicating or modifying data', async () => {
|
||||
await initializeRoles();
|
||||
await initializeRoles();
|
||||
|
||||
const adminRoles = await Role.find({ name: SystemRoles.ADMIN });
|
||||
const userRoles = await Role.find({ name: SystemRoles.USER });
|
||||
|
||||
expect(adminRoles).toHaveLength(1);
|
||||
expect(userRoles).toHaveLength(1);
|
||||
|
||||
const adminRole = adminRoles[0].toObject();
|
||||
const userRole = userRoles[0].toObject();
|
||||
|
||||
// Check if all permission types exist
|
||||
Object.values(PermissionTypes).forEach((permType) => {
|
||||
expect(adminRole[permType]).toBeDefined();
|
||||
expect(userRole[permType]).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
it('should update roles with missing permission types from roleDefaults', async () => {
|
||||
const partialAdminRole = {
|
||||
name: SystemRoles.ADMIN,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.ADMIN][PermissionTypes.BOOKMARKS],
|
||||
};
|
||||
|
||||
await new Role(partialAdminRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
|
||||
|
||||
expect(adminRole[PermissionTypes.PROMPTS]).toEqual(partialAdminRole[PermissionTypes.PROMPTS]);
|
||||
expect(adminRole[PermissionTypes.AGENTS]).toBeDefined();
|
||||
expect(adminRole[PermissionTypes.AGENTS].CREATE).toBeDefined();
|
||||
expect(adminRole[PermissionTypes.AGENTS].USE).toBeDefined();
|
||||
expect(adminRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined();
|
||||
});
|
||||
|
||||
it('should include MULTI_CONVO permissions when creating default roles', async () => {
|
||||
await initializeRoles();
|
||||
|
||||
const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean();
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(adminRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
|
||||
// Check if MULTI_CONVO permissions match defaults
|
||||
expect(adminRole[PermissionTypes.MULTI_CONVO].USE).toBe(
|
||||
roleDefaults[SystemRoles.ADMIN][PermissionTypes.MULTI_CONVO].USE,
|
||||
);
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBe(
|
||||
roleDefaults[SystemRoles.USER][PermissionTypes.MULTI_CONVO].USE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add MULTI_CONVO permissions to existing roles without them', async () => {
|
||||
const partialUserRole = {
|
||||
name: SystemRoles.USER,
|
||||
[PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS],
|
||||
[PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS],
|
||||
};
|
||||
|
||||
await new Role(partialUserRole).save();
|
||||
|
||||
await initializeRoles();
|
||||
|
||||
const userRole = await Role.findOne({ name: SystemRoles.USER }).lean();
|
||||
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined();
|
||||
expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
|
@ -1,76 +1,275 @@
|
|||
const crypto = require('crypto');
|
||||
const mongoose = require('mongoose');
|
||||
const signPayload = require('~/server/services/signPayload');
|
||||
const { hashToken } = require('~/server/utils/crypto');
|
||||
const sessionSchema = require('./schema/session');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const Session = mongoose.model('Session', sessionSchema);
|
||||
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7;
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
|
||||
const sessionSchema = mongoose.Schema({
|
||||
refreshTokenHash: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
expiration: {
|
||||
type: Date,
|
||||
required: true,
|
||||
expires: 0,
|
||||
},
|
||||
user: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
});
|
||||
/**
|
||||
* Error class for Session-related errors
|
||||
*/
|
||||
class SessionError extends Error {
|
||||
constructor(message, code = 'SESSION_ERROR') {
|
||||
super(message);
|
||||
this.name = 'SessionError';
|
||||
this.code = code;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new session for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @param {Object} options - Additional options for session creation
|
||||
* @param {Date} options.expiration - Custom expiration date
|
||||
* @returns {Promise<{session: Session, refreshToken: string}>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const createSession = async (userId, options = {}) => {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
sessionSchema.methods.generateRefreshToken = async function () {
|
||||
try {
|
||||
let expiresIn;
|
||||
if (this.expiration) {
|
||||
expiresIn = this.expiration.getTime();
|
||||
} else {
|
||||
expiresIn = Date.now() + expires;
|
||||
this.expiration = new Date(expiresIn);
|
||||
const session = new Session({
|
||||
user: userId,
|
||||
expiration: options.expiration || new Date(Date.now() + expires),
|
||||
});
|
||||
const refreshToken = await generateRefreshToken(session);
|
||||
return { session, refreshToken };
|
||||
} catch (error) {
|
||||
logger.error('[createSession] Error creating session:', error);
|
||||
throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Finds a session by various parameters
|
||||
* @param {Object} params - Search parameters
|
||||
* @param {string} [params.refreshToken] - The refresh token to search by
|
||||
* @param {string} [params.userId] - The user ID to search by
|
||||
* @param {string} [params.sessionId] - The session ID to search by
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents
|
||||
* @returns {Promise<Session|null>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const findSession = async (params, options = { lean: true }) => {
|
||||
try {
|
||||
const query = {};
|
||||
|
||||
if (!params.refreshToken && !params.userId && !params.sessionId) {
|
||||
throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS');
|
||||
}
|
||||
|
||||
if (params.refreshToken) {
|
||||
const tokenHash = await hashToken(params.refreshToken);
|
||||
query.refreshTokenHash = tokenHash;
|
||||
}
|
||||
|
||||
if (params.userId) {
|
||||
query.user = params.userId;
|
||||
}
|
||||
|
||||
if (params.sessionId) {
|
||||
const sessionId = params.sessionId.sessionId || params.sessionId;
|
||||
if (!mongoose.Types.ObjectId.isValid(sessionId)) {
|
||||
throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID');
|
||||
}
|
||||
query._id = sessionId;
|
||||
}
|
||||
|
||||
// Add expiration check to only return valid sessions
|
||||
query.expiration = { $gt: new Date() };
|
||||
|
||||
const sessionQuery = Session.findOne(query);
|
||||
|
||||
if (options.lean) {
|
||||
return await sessionQuery.lean();
|
||||
}
|
||||
|
||||
return await sessionQuery.exec();
|
||||
} catch (error) {
|
||||
logger.error('[findSession] Error finding session:', error);
|
||||
throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates session expiration
|
||||
* @param {Session|string} session - The session or session ID to update
|
||||
* @param {Date} [newExpiration] - Optional new expiration date
|
||||
* @returns {Promise<Session>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const updateExpiration = async (session, newExpiration) => {
|
||||
try {
|
||||
const sessionDoc = typeof session === 'string' ? await Session.findById(session) : session;
|
||||
|
||||
if (!sessionDoc) {
|
||||
throw new SessionError('Session not found', 'SESSION_NOT_FOUND');
|
||||
}
|
||||
|
||||
sessionDoc.expiration = newExpiration || new Date(Date.now() + expires);
|
||||
return await sessionDoc.save();
|
||||
} catch (error) {
|
||||
logger.error('[updateExpiration] Error updating session:', error);
|
||||
throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes a session by refresh token or session ID
|
||||
* @param {Object} params - Delete parameters
|
||||
* @param {string} [params.refreshToken] - The refresh token of the session to delete
|
||||
* @param {string} [params.sessionId] - The ID of the session to delete
|
||||
* @returns {Promise<Object>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const deleteSession = async (params) => {
|
||||
try {
|
||||
if (!params.refreshToken && !params.sessionId) {
|
||||
throw new SessionError(
|
||||
'Either refreshToken or sessionId is required',
|
||||
'INVALID_DELETE_PARAMS',
|
||||
);
|
||||
}
|
||||
|
||||
const query = {};
|
||||
|
||||
if (params.refreshToken) {
|
||||
query.refreshTokenHash = await hashToken(params.refreshToken);
|
||||
}
|
||||
|
||||
if (params.sessionId) {
|
||||
query._id = params.sessionId;
|
||||
}
|
||||
|
||||
const result = await Session.deleteOne(query);
|
||||
|
||||
if (result.deletedCount === 0) {
|
||||
logger.warn('[deleteSession] No session found to delete');
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[deleteSession] Error deleting session:', error);
|
||||
throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes all sessions for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session
|
||||
* @param {string} [options.currentSessionId] - The ID of the current session to exclude
|
||||
* @returns {Promise<Object>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const deleteAllUserSessions = async (userId, options = {}) => {
|
||||
try {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
// Extract userId if it's passed as an object
|
||||
const userIdString = userId.userId || userId;
|
||||
|
||||
if (!mongoose.Types.ObjectId.isValid(userIdString)) {
|
||||
throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT');
|
||||
}
|
||||
|
||||
const query = { user: userIdString };
|
||||
|
||||
if (options.excludeCurrentSession && options.currentSessionId) {
|
||||
query._id = { $ne: options.currentSessionId };
|
||||
}
|
||||
|
||||
const result = await Session.deleteMany(query);
|
||||
|
||||
if (result.deletedCount > 0) {
|
||||
logger.debug(
|
||||
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`,
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllUserSessions] Error deleting user sessions:', error);
|
||||
throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates a refresh token for a session
|
||||
* @param {Session} session - The session to generate a token for
|
||||
* @returns {Promise<string>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const generateRefreshToken = async (session) => {
|
||||
if (!session || !session.user) {
|
||||
throw new SessionError('Invalid session object', 'INVALID_SESSION');
|
||||
}
|
||||
|
||||
try {
|
||||
const expiresIn = session.expiration ? session.expiration.getTime() : Date.now() + expires;
|
||||
|
||||
if (!session.expiration) {
|
||||
session.expiration = new Date(expiresIn);
|
||||
}
|
||||
|
||||
const refreshToken = await signPayload({
|
||||
payload: { id: this.user },
|
||||
payload: {
|
||||
id: session.user,
|
||||
sessionId: session._id,
|
||||
},
|
||||
secret: process.env.JWT_REFRESH_SECRET,
|
||||
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
|
||||
});
|
||||
|
||||
const hash = crypto.createHash('sha256');
|
||||
this.refreshTokenHash = hash.update(refreshToken).digest('hex');
|
||||
|
||||
await this.save();
|
||||
session.refreshTokenHash = await hashToken(refreshToken);
|
||||
await session.save();
|
||||
|
||||
return refreshToken;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'Error generating refresh token. Is a `JWT_REFRESH_SECRET` set in the .env file?\n\n',
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
logger.error('[generateRefreshToken] Error generating refresh token:', error);
|
||||
throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
sessionSchema.statics.deleteAllUserSessions = async function (userId) {
|
||||
/**
|
||||
* Counts active sessions for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @returns {Promise<number>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const countActiveSessions = async (userId) => {
|
||||
try {
|
||||
if (!userId) {
|
||||
return;
|
||||
}
|
||||
const result = await this.deleteMany({ user: userId });
|
||||
if (result && result?.deletedCount > 0) {
|
||||
logger.debug(
|
||||
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userId}.`,
|
||||
);
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
return await Session.countDocuments({
|
||||
user: userId,
|
||||
expiration: { $gt: new Date() },
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllUserSessions] Error in deleting user sessions:', error);
|
||||
throw error;
|
||||
logger.error('[countActiveSessions] Error counting active sessions:', error);
|
||||
throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
const Session = mongoose.model('Session', sessionSchema);
|
||||
|
||||
module.exports = Session;
|
||||
module.exports = {
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
SessionError,
|
||||
};
|
||||
|
|
|
|||
340
api/models/Share.js
Normal file
340
api/models/Share.js
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { Conversation } = require('~/models/Conversation');
|
||||
const SharedLink = require('./schema/shareSchema');
|
||||
const { getMessages } = require('./Message');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
class ShareServiceError extends Error {
|
||||
constructor(message, code) {
|
||||
super(message);
|
||||
this.name = 'ShareServiceError';
|
||||
this.code = code;
|
||||
}
|
||||
}
|
||||
|
||||
const memoizedAnonymizeId = (prefix) => {
|
||||
const memo = new Map();
|
||||
return (id) => {
|
||||
if (!memo.has(id)) {
|
||||
memo.set(id, `${prefix}_${nanoid()}`);
|
||||
}
|
||||
return memo.get(id);
|
||||
};
|
||||
};
|
||||
|
||||
const anonymizeConvoId = memoizedAnonymizeId('convo');
|
||||
const anonymizeAssistantId = memoizedAnonymizeId('a');
|
||||
const anonymizeMessageId = (id) =>
|
||||
id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id);
|
||||
|
||||
function anonymizeConvo(conversation) {
|
||||
if (!conversation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const newConvo = { ...conversation };
|
||||
if (newConvo.assistant_id) {
|
||||
newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id);
|
||||
}
|
||||
return newConvo;
|
||||
}
|
||||
|
||||
function anonymizeMessages(messages, newConvoId) {
|
||||
if (!Array.isArray(messages)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const idMap = new Map();
|
||||
return messages.map((message) => {
|
||||
const newMessageId = anonymizeMessageId(message.messageId);
|
||||
idMap.set(message.messageId, newMessageId);
|
||||
|
||||
return {
|
||||
...message,
|
||||
messageId: newMessageId,
|
||||
parentMessageId:
|
||||
idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId),
|
||||
conversationId: newConvoId,
|
||||
model: message.model?.startsWith('asst_')
|
||||
? anonymizeAssistantId(message.model)
|
||||
: message.model,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
async function getSharedMessages(shareId) {
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId, isPublic: true })
|
||||
.populate({
|
||||
path: 'messages',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
|
||||
if (!share?.conversationId || !share.isPublic) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const newConvoId = anonymizeConvoId(share.conversationId);
|
||||
const result = {
|
||||
...share,
|
||||
conversationId: newConvoId,
|
||||
messages: anonymizeMessages(share.messages, newConvoId),
|
||||
};
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[getShare] Error getting share link', {
|
||||
error: error.message,
|
||||
shareId,
|
||||
});
|
||||
throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortDirection, search) {
|
||||
try {
|
||||
const query = { user, isPublic };
|
||||
|
||||
if (pageParam) {
|
||||
if (sortDirection === 'desc') {
|
||||
query[sortBy] = { $lt: pageParam };
|
||||
} else {
|
||||
query[sortBy] = { $gt: pageParam };
|
||||
}
|
||||
}
|
||||
|
||||
if (search && search.trim()) {
|
||||
try {
|
||||
const searchResults = await Conversation.meiliSearch(search);
|
||||
|
||||
if (!searchResults?.hits?.length) {
|
||||
return {
|
||||
links: [],
|
||||
nextCursor: undefined,
|
||||
hasNextPage: false,
|
||||
};
|
||||
}
|
||||
|
||||
const conversationIds = searchResults.hits.map((hit) => hit.conversationId);
|
||||
query['conversationId'] = { $in: conversationIds };
|
||||
} catch (searchError) {
|
||||
logger.error('[getSharedLinks] Meilisearch error', {
|
||||
error: searchError.message,
|
||||
user,
|
||||
});
|
||||
return {
|
||||
links: [],
|
||||
nextCursor: undefined,
|
||||
hasNextPage: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const sort = {};
|
||||
sort[sortBy] = sortDirection === 'desc' ? -1 : 1;
|
||||
|
||||
if (Array.isArray(query.conversationId)) {
|
||||
query.conversationId = { $in: query.conversationId };
|
||||
}
|
||||
|
||||
const sharedLinks = await SharedLink.find(query)
|
||||
.sort(sort)
|
||||
.limit(pageSize + 1)
|
||||
.select('-__v -user')
|
||||
.lean();
|
||||
|
||||
const hasNextPage = sharedLinks.length > pageSize;
|
||||
const links = sharedLinks.slice(0, pageSize);
|
||||
|
||||
const nextCursor = hasNextPage ? links[links.length - 1][sortBy] : undefined;
|
||||
|
||||
return {
|
||||
links: links.map((link) => ({
|
||||
shareId: link.shareId,
|
||||
title: link?.title || 'Untitled',
|
||||
isPublic: link.isPublic,
|
||||
createdAt: link.createdAt,
|
||||
conversationId: link.conversationId,
|
||||
})),
|
||||
nextCursor,
|
||||
hasNextPage,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[getSharedLinks] Error getting shares', {
|
||||
error: error.message,
|
||||
user,
|
||||
});
|
||||
throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteAllSharedLinks(user) {
|
||||
try {
|
||||
const result = await SharedLink.deleteMany({ user });
|
||||
return {
|
||||
message: 'All shared links deleted successfully',
|
||||
deletedCount: result.deletedCount,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllSharedLinks] Error deleting shared links', {
|
||||
error: error.message,
|
||||
user,
|
||||
});
|
||||
throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
async function createSharedLink(user, conversationId) {
|
||||
if (!user || !conversationId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const [existingShare, conversationMessages] = await Promise.all([
|
||||
SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(),
|
||||
getMessages({ conversationId }),
|
||||
]);
|
||||
|
||||
if (existingShare && existingShare.isPublic) {
|
||||
throw new ShareServiceError('Share already exists', 'SHARE_EXISTS');
|
||||
} else if (existingShare) {
|
||||
await SharedLink.deleteOne({ conversationId });
|
||||
}
|
||||
|
||||
const conversation = await Conversation.findOne({ conversationId }).lean();
|
||||
const title = conversation?.title || 'Untitled';
|
||||
|
||||
const shareId = nanoid();
|
||||
await SharedLink.create({
|
||||
shareId,
|
||||
conversationId,
|
||||
messages: conversationMessages,
|
||||
title,
|
||||
user,
|
||||
});
|
||||
|
||||
return { shareId, conversationId };
|
||||
} catch (error) {
|
||||
logger.error('[createSharedLink] Error creating shared link', {
|
||||
error: error.message,
|
||||
user,
|
||||
conversationId,
|
||||
});
|
||||
throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
async function getSharedLink(user, conversationId) {
|
||||
if (!user || !conversationId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId, user, isPublic: true })
|
||||
.select('shareId -_id')
|
||||
.lean();
|
||||
|
||||
if (!share) {
|
||||
return { shareId: null, success: false };
|
||||
}
|
||||
|
||||
return { shareId: share.shareId, success: true };
|
||||
} catch (error) {
|
||||
logger.error('[getSharedLink] Error getting shared link', {
|
||||
error: error.message,
|
||||
user,
|
||||
conversationId,
|
||||
});
|
||||
throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
async function updateSharedLink(user, shareId) {
|
||||
if (!user || !shareId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean();
|
||||
|
||||
if (!share) {
|
||||
throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND');
|
||||
}
|
||||
|
||||
const [updatedMessages] = await Promise.all([
|
||||
getMessages({ conversationId: share.conversationId }),
|
||||
]);
|
||||
|
||||
const newShareId = nanoid();
|
||||
const update = {
|
||||
messages: updatedMessages,
|
||||
user,
|
||||
shareId: newShareId,
|
||||
};
|
||||
|
||||
const updatedShare = await SharedLink.findOneAndUpdate({ shareId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
runValidators: true,
|
||||
}).lean();
|
||||
|
||||
if (!updatedShare) {
|
||||
throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR');
|
||||
}
|
||||
|
||||
anonymizeConvo(updatedShare);
|
||||
|
||||
return { shareId: newShareId, conversationId: updatedShare.conversationId };
|
||||
} catch (error) {
|
||||
logger.error('[updateSharedLink] Error updating shared link', {
|
||||
error: error.message,
|
||||
user,
|
||||
shareId,
|
||||
});
|
||||
throw new ShareServiceError(
|
||||
error.code === 'SHARE_UPDATE_ERROR' ? error.message : 'Error updating shared link',
|
||||
error.code || 'SHARE_UPDATE_ERROR',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteSharedLink(user, shareId) {
|
||||
if (!user || !shareId) {
|
||||
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await SharedLink.findOneAndDelete({ shareId, user }).lean();
|
||||
|
||||
if (!result) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
shareId,
|
||||
message: 'Share deleted successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteSharedLink] Error deleting shared link', {
|
||||
error: error.message,
|
||||
user,
|
||||
shareId,
|
||||
});
|
||||
throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
SharedLink,
|
||||
getSharedLink,
|
||||
getSharedLinks,
|
||||
createSharedLink,
|
||||
updateSharedLink,
|
||||
deleteSharedLink,
|
||||
getSharedMessages,
|
||||
deleteAllSharedLinks,
|
||||
};
|
||||
192
api/models/Token.js
Normal file
192
api/models/Token.js
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { encryptV2 } = require('~/server/utils/crypto');
|
||||
const tokenSchema = require('./schema/tokenSchema');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Token model.
|
||||
* @type {mongoose.Model}
|
||||
*/
|
||||
const Token = mongoose.model('Token', tokenSchema);
|
||||
/**
|
||||
* Fixes the indexes for the Token collection from legacy TTL indexes to the new expiresAt index.
|
||||
*/
|
||||
async function fixIndexes() {
|
||||
try {
|
||||
const indexes = await Token.collection.indexes();
|
||||
logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2));
|
||||
const unwantedTTLIndexes = indexes.filter(
|
||||
(index) => index.key.createdAt === 1 && index.expireAfterSeconds !== undefined,
|
||||
);
|
||||
if (unwantedTTLIndexes.length === 0) {
|
||||
logger.debug('No unwanted Token indexes found.');
|
||||
return;
|
||||
}
|
||||
for (const index of unwantedTTLIndexes) {
|
||||
logger.debug(`Dropping unwanted Token index: ${index.name}`);
|
||||
await Token.collection.dropIndex(index.name);
|
||||
logger.debug(`Dropped Token index: ${index.name}`);
|
||||
}
|
||||
logger.debug('Token index cleanup completed successfully.');
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while fixing Token indexes:', error);
|
||||
}
|
||||
}
|
||||
|
||||
fixIndexes();
|
||||
|
||||
/**
|
||||
* Creates a new Token instance.
|
||||
* @param {Object} tokenData - The data for the new Token.
|
||||
* @param {mongoose.Types.ObjectId} tokenData.userId - The user's ID. It is required.
|
||||
* @param {String} tokenData.email - The user's email.
|
||||
* @param {String} tokenData.token - The token. It is required.
|
||||
* @param {Number} tokenData.expiresIn - The number of seconds until the token expires.
|
||||
* @returns {Promise<mongoose.Document>} The new Token instance.
|
||||
* @throws Will throw an error if token creation fails.
|
||||
*/
|
||||
async function createToken(tokenData) {
|
||||
try {
|
||||
const currentTime = new Date();
|
||||
const expiresAt = new Date(currentTime.getTime() + tokenData.expiresIn * 1000);
|
||||
|
||||
const newTokenData = {
|
||||
...tokenData,
|
||||
createdAt: currentTime,
|
||||
expiresAt,
|
||||
};
|
||||
|
||||
return await Token.create(newTokenData);
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while creating token:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a Token document that matches the provided query.
|
||||
* @param {Object} query - The query to match against.
|
||||
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
|
||||
* @param {String} query.token - The token value.
|
||||
* @param {String} [query.email] - The email of the user.
|
||||
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
|
||||
* @returns {Promise<Object|null>} The matched Token document, or null if not found.
|
||||
* @throws Will throw an error if the find operation fails.
|
||||
*/
|
||||
async function findToken(query) {
|
||||
try {
|
||||
const conditions = [];
|
||||
|
||||
if (query.userId) {
|
||||
conditions.push({ userId: query.userId });
|
||||
}
|
||||
if (query.token) {
|
||||
conditions.push({ token: query.token });
|
||||
}
|
||||
if (query.email) {
|
||||
conditions.push({ email: query.email });
|
||||
}
|
||||
if (query.identifier) {
|
||||
conditions.push({ identifier: query.identifier });
|
||||
}
|
||||
|
||||
const token = await Token.findOne({
|
||||
$and: conditions,
|
||||
}).lean();
|
||||
|
||||
return token;
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while finding token:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a Token document that matches the provided query.
|
||||
* @param {Object} query - The query to match against.
|
||||
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
|
||||
* @param {String} query.token - The token value.
|
||||
* @param {String} [query.email] - The email of the user.
|
||||
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
|
||||
* @param {Object} updateData - The data to update the Token with.
|
||||
* @returns {Promise<mongoose.Document|null>} The updated Token document, or null if not found.
|
||||
* @throws Will throw an error if the update operation fails.
|
||||
*/
|
||||
async function updateToken(query, updateData) {
|
||||
try {
|
||||
return await Token.findOneAndUpdate(query, updateData, { new: true });
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while updating token:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes all Token documents that match the provided token, user ID, or email.
|
||||
* @param {Object} query - The query to match against.
|
||||
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
|
||||
* @param {String} query.token - The token value.
|
||||
* @param {String} [query.email] - The email of the user.
|
||||
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
|
||||
* @returns {Promise<Object>} The result of the delete operation.
|
||||
* @throws Will throw an error if the delete operation fails.
|
||||
*/
|
||||
async function deleteTokens(query) {
|
||||
try {
|
||||
return await Token.deleteMany({
|
||||
$or: [
|
||||
{ userId: query.userId },
|
||||
{ token: query.token },
|
||||
{ email: query.email },
|
||||
{ identifier: query.identifier },
|
||||
],
|
||||
});
|
||||
} catch (error) {
|
||||
logger.debug('An error occurred while deleting tokens:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the OAuth token by creating or updating the token.
|
||||
* @param {object} fields
|
||||
* @param {string} fields.userId - The user's ID.
|
||||
* @param {string} fields.token - The full token to store.
|
||||
* @param {string} fields.identifier - Unique, alternative identifier for the token.
|
||||
* @param {number} fields.expiresIn - The number of seconds until the token expires.
|
||||
* @param {object} fields.metadata - Additional metadata to store with the token.
|
||||
* @param {string} [fields.type="oauth"] - The type of token. Default is 'oauth'.
|
||||
*/
|
||||
async function handleOAuthToken({
|
||||
token,
|
||||
userId,
|
||||
identifier,
|
||||
expiresIn,
|
||||
metadata,
|
||||
type = 'oauth',
|
||||
}) {
|
||||
const encrypedToken = await encryptV2(token);
|
||||
const tokenData = {
|
||||
type,
|
||||
userId,
|
||||
metadata,
|
||||
identifier,
|
||||
token: encrypedToken,
|
||||
expiresIn: parseInt(expiresIn, 10) || 3600,
|
||||
};
|
||||
|
||||
const existingToken = await findToken({ userId, identifier });
|
||||
if (existingToken) {
|
||||
return await updateToken({ identifier }, tokenData);
|
||||
} else {
|
||||
return await createToken(tokenData);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
findToken,
|
||||
createToken,
|
||||
updateToken,
|
||||
deleteTokens,
|
||||
handleOAuthToken,
|
||||
};
|
||||
96
api/models/ToolCall.js
Normal file
96
api/models/ToolCall.js
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
const ToolCall = require('./schema/toolCallSchema');
|
||||
|
||||
/**
|
||||
* Create a new tool call
|
||||
* @param {ToolCallData} toolCallData - The tool call data
|
||||
* @returns {Promise<ToolCallData>} The created tool call document
|
||||
*/
|
||||
async function createToolCall(toolCallData) {
|
||||
try {
|
||||
return await ToolCall.create(toolCallData);
|
||||
} catch (error) {
|
||||
throw new Error(`Error creating tool call: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a tool call by ID
|
||||
* @param {string} id - The tool call document ID
|
||||
* @returns {Promise<ToolCallData|null>} The tool call document or null if not found
|
||||
*/
|
||||
async function getToolCallById(id) {
|
||||
try {
|
||||
return await ToolCall.findById(id).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool call: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tool calls by message ID and user
|
||||
* @param {string} messageId - The message ID
|
||||
* @param {string} userId - The user's ObjectId
|
||||
* @returns {Promise<Array>} Array of tool call documents
|
||||
*/
|
||||
async function getToolCallsByMessage(messageId, userId) {
|
||||
try {
|
||||
return await ToolCall.find({ messageId, user: userId }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool calls: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tool calls by conversation ID and user
|
||||
* @param {string} conversationId - The conversation ID
|
||||
* @param {string} userId - The user's ObjectId
|
||||
* @returns {Promise<ToolCallData[]>} Array of tool call documents
|
||||
*/
|
||||
async function getToolCallsByConvo(conversationId, userId) {
|
||||
try {
|
||||
return await ToolCall.find({ conversationId, user: userId }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error fetching tool calls: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a tool call
|
||||
* @param {string} id - The tool call document ID
|
||||
* @param {Partial<ToolCallData>} updateData - The data to update
|
||||
* @returns {Promise<ToolCallData|null>} The updated tool call document or null if not found
|
||||
*/
|
||||
async function updateToolCall(id, updateData) {
|
||||
try {
|
||||
return await ToolCall.findByIdAndUpdate(id, updateData, { new: true }).lean();
|
||||
} catch (error) {
|
||||
throw new Error(`Error updating tool call: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a tool call
|
||||
* @param {string} userId - The related user's ObjectId
|
||||
* @param {string} [conversationId] - The tool call conversation ID
|
||||
* @returns {Promise<{ ok?: number; n?: number; deletedCount?: number }>} The result of the delete operation
|
||||
*/
|
||||
async function deleteToolCalls(userId, conversationId) {
|
||||
try {
|
||||
const query = { user: userId };
|
||||
if (conversationId) {
|
||||
query.conversationId = conversationId;
|
||||
}
|
||||
return await ToolCall.deleteMany(query);
|
||||
} catch (error) {
|
||||
throw new Error(`Error deleting tool call: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
createToolCall,
|
||||
updateToolCall,
|
||||
deleteToolCalls,
|
||||
getToolCallById,
|
||||
getToolCallsByConvo,
|
||||
getToolCallsByMessage,
|
||||
};
|
||||
|
|
@ -1,17 +1,18 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { isEnabled } = require('../server/utils/handleText');
|
||||
const { isEnabled } = require('~/server/utils/handleText');
|
||||
const transactionSchema = require('./schema/transaction');
|
||||
const { getMultiplier } = require('./tx');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { logger } = require('~/config');
|
||||
const Balance = require('./Balance');
|
||||
const cancelRate = 1.15;
|
||||
|
||||
// Method to calculate and set the tokenValue for a transaction
|
||||
/** Method to calculate and set the tokenValue for a transaction */
|
||||
transactionSchema.methods.calculateTokenValue = function () {
|
||||
if (!this.valueKey || !this.tokenType) {
|
||||
this.tokenValue = this.rawAmount;
|
||||
}
|
||||
const { valueKey, tokenType, model, endpointTokenConfig } = this;
|
||||
const multiplier = getMultiplier({ valueKey, tokenType, model, endpointTokenConfig });
|
||||
const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig }));
|
||||
this.rate = multiplier;
|
||||
this.tokenValue = this.rawAmount * multiplier;
|
||||
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
|
||||
|
|
@ -20,34 +21,167 @@ transactionSchema.methods.calculateTokenValue = function () {
|
|||
}
|
||||
};
|
||||
|
||||
// Static method to create a transaction and update the balance
|
||||
transactionSchema.statics.create = async function (transactionData) {
|
||||
/**
|
||||
* Static method to create a transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
*/
|
||||
transactionSchema.statics.create = async function (txData) {
|
||||
const Transaction = this;
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction(transactionData);
|
||||
transaction.endpointTokenConfig = transactionData.endpointTokenConfig;
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
transaction.calculateTokenValue();
|
||||
|
||||
// Save the transaction
|
||||
await transaction.save();
|
||||
|
||||
if (!isEnabled(process.env.CHECK_BALANCE)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Adjust the user's balance
|
||||
const updatedBalance = await Balance.findOneAndUpdate(
|
||||
let balance = await Balance.findOne({ user: transaction.user }).lean();
|
||||
let incrementValue = transaction.tokenValue;
|
||||
|
||||
if (balance && balance?.tokenCredits + incrementValue < 0) {
|
||||
incrementValue = -balance.tokenCredits;
|
||||
}
|
||||
|
||||
balance = await Balance.findOneAndUpdate(
|
||||
{ user: transaction.user },
|
||||
{ $inc: { tokenCredits: transaction.tokenValue } },
|
||||
{ $inc: { tokenCredits: incrementValue } },
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
|
||||
return {
|
||||
rate: transaction.rate,
|
||||
user: transaction.user.toString(),
|
||||
balance: updatedBalance.tokenCredits,
|
||||
[transaction.tokenType]: transaction.tokenValue,
|
||||
balance: balance.tokenCredits,
|
||||
[transaction.tokenType]: incrementValue,
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = mongoose.model('Transaction', transactionSchema);
|
||||
/**
|
||||
* Static method to create a structured transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
*/
|
||||
transactionSchema.statics.createStructured = async function (txData) {
|
||||
const Transaction = this;
|
||||
|
||||
const transaction = new Transaction({
|
||||
...txData,
|
||||
endpointTokenConfig: txData.endpointTokenConfig,
|
||||
});
|
||||
|
||||
transaction.calculateStructuredTokenValue();
|
||||
|
||||
await transaction.save();
|
||||
|
||||
if (!isEnabled(process.env.CHECK_BALANCE)) {
|
||||
return;
|
||||
}
|
||||
|
||||
let balance = await Balance.findOne({ user: transaction.user }).lean();
|
||||
let incrementValue = transaction.tokenValue;
|
||||
|
||||
if (balance && balance?.tokenCredits + incrementValue < 0) {
|
||||
incrementValue = -balance.tokenCredits;
|
||||
}
|
||||
|
||||
balance = await Balance.findOneAndUpdate(
|
||||
{ user: transaction.user },
|
||||
{ $inc: { tokenCredits: incrementValue } },
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
|
||||
return {
|
||||
rate: transaction.rate,
|
||||
user: transaction.user.toString(),
|
||||
balance: balance.tokenCredits,
|
||||
[transaction.tokenType]: incrementValue,
|
||||
};
|
||||
};
|
||||
|
||||
/** Method to calculate token value for structured tokens */
|
||||
transactionSchema.methods.calculateStructuredTokenValue = function () {
|
||||
if (!this.tokenType) {
|
||||
this.tokenValue = this.rawAmount;
|
||||
return;
|
||||
}
|
||||
|
||||
const { model, endpointTokenConfig } = this;
|
||||
|
||||
if (this.tokenType === 'prompt') {
|
||||
const inputMultiplier = getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig });
|
||||
const writeMultiplier =
|
||||
getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier;
|
||||
const readMultiplier =
|
||||
getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? inputMultiplier;
|
||||
|
||||
this.rateDetail = {
|
||||
input: inputMultiplier,
|
||||
write: writeMultiplier,
|
||||
read: readMultiplier,
|
||||
};
|
||||
|
||||
const totalPromptTokens =
|
||||
Math.abs(this.inputTokens || 0) +
|
||||
Math.abs(this.writeTokens || 0) +
|
||||
Math.abs(this.readTokens || 0);
|
||||
|
||||
if (totalPromptTokens > 0) {
|
||||
this.rate =
|
||||
(Math.abs(inputMultiplier * (this.inputTokens || 0)) +
|
||||
Math.abs(writeMultiplier * (this.writeTokens || 0)) +
|
||||
Math.abs(readMultiplier * (this.readTokens || 0))) /
|
||||
totalPromptTokens;
|
||||
} else {
|
||||
this.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
|
||||
}
|
||||
|
||||
this.tokenValue = -(
|
||||
Math.abs(this.inputTokens || 0) * inputMultiplier +
|
||||
Math.abs(this.writeTokens || 0) * writeMultiplier +
|
||||
Math.abs(this.readTokens || 0) * readMultiplier
|
||||
);
|
||||
|
||||
this.rawAmount = -totalPromptTokens;
|
||||
} else if (this.tokenType === 'completion') {
|
||||
const multiplier = getMultiplier({ tokenType: this.tokenType, model, endpointTokenConfig });
|
||||
this.rate = Math.abs(multiplier);
|
||||
this.tokenValue = -Math.abs(this.rawAmount) * multiplier;
|
||||
this.rawAmount = -Math.abs(this.rawAmount);
|
||||
}
|
||||
|
||||
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
|
||||
this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
|
||||
this.rate *= cancelRate;
|
||||
if (this.rateDetail) {
|
||||
this.rateDetail = Object.fromEntries(
|
||||
Object.entries(this.rateDetail).map(([k, v]) => [k, v * cancelRate]),
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const Transaction = mongoose.model('Transaction', transactionSchema);
|
||||
|
||||
/**
|
||||
* Queries and retrieves transactions based on a given filter.
|
||||
* @async
|
||||
* @function getTransactions
|
||||
* @param {Object} filter - MongoDB filter object to apply when querying transactions.
|
||||
* @returns {Promise<Array>} A promise that resolves to an array of matched transactions.
|
||||
* @throws {Error} Throws an error if querying the database fails.
|
||||
*/
|
||||
async function getTransactions(filter) {
|
||||
try {
|
||||
return await Transaction.find(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error querying transactions:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { Transaction, getTransactions };
|
||||
|
|
|
|||
374
api/models/Transaction.spec.js
Normal file
374
api/models/Transaction.spec.js
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const Balance = require('./Balance');
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
});
|
||||
|
||||
describe('Regular Token Spending Tests', () => {
|
||||
test('Balance should decrease when spending tokens with spendTokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000; // $10.00
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
// Act
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
console.log('Initial Balance:', initialBalance);
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
console.log('Updated Balance:', updatedBalance.tokenCredits);
|
||||
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
|
||||
|
||||
const expectedPromptCost = tokenUsage.promptTokens * promptMultiplier;
|
||||
const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier;
|
||||
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||
const expectedBalance = initialBalance - expectedTotalCost;
|
||||
|
||||
expect(updatedBalance.tokenCredits).toBeLessThan(initialBalance);
|
||||
expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0);
|
||||
|
||||
console.log('Expected Total Cost:', expectedTotalCost);
|
||||
console.log('Actual Balance Decrease:', initialBalance - updatedBalance.tokenCredits);
|
||||
});
|
||||
|
||||
test('spendTokens should handle zero completion tokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000; // $10.00
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 0,
|
||||
};
|
||||
|
||||
// Act
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const expectedCost = tokenUsage.promptTokens * promptMultiplier;
|
||||
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
|
||||
console.log('Initial Balance:', initialBalance);
|
||||
console.log('Updated Balance:', updatedBalance.tokenCredits);
|
||||
console.log('Expected Cost:', expectedCost);
|
||||
});
|
||||
|
||||
test('spendTokens should handle undefined token counts', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000; // $10.00
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
};
|
||||
|
||||
const tokenUsage = {};
|
||||
|
||||
const result = await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
test('spendTokens should handle only prompt tokens', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000; // $10.00
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
};
|
||||
|
||||
const tokenUsage = { promptTokens: 100 };
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
const updatedBalance = await Balance.findOne({ user: userId });
|
||||
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const expectedCost = 100 * promptMultiplier;
|
||||
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Structured Token Spending Tests', () => {
|
||||
test('Balance should decrease and rawAmount should be set when spending a large number of structured tokens', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55; // $17.61
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'c23a18da-706c-470a-ac28-ec87ed065199',
|
||||
model,
|
||||
context: 'message',
|
||||
endpointTokenConfig: null, // We'll use the default rates
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 11,
|
||||
write: 140522,
|
||||
read: 0,
|
||||
},
|
||||
completionTokens: 5,
|
||||
};
|
||||
|
||||
// Get the actual multipliers
|
||||
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
|
||||
const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
|
||||
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
|
||||
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
|
||||
|
||||
console.log('Multipliers:', {
|
||||
promptMultiplier,
|
||||
completionMultiplier,
|
||||
writeMultiplier,
|
||||
readMultiplier,
|
||||
});
|
||||
|
||||
// Act
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
// Assert
|
||||
console.log('Initial Balance:', initialBalance);
|
||||
console.log('Updated Balance:', result.completion.balance);
|
||||
console.log('Transaction Result:', result);
|
||||
|
||||
const expectedPromptCost =
|
||||
tokenUsage.promptTokens.input * promptMultiplier +
|
||||
tokenUsage.promptTokens.write * writeMultiplier +
|
||||
tokenUsage.promptTokens.read * readMultiplier;
|
||||
const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier;
|
||||
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
|
||||
const expectedBalance = initialBalance - expectedTotalCost;
|
||||
|
||||
console.log('Expected Cost:', expectedTotalCost);
|
||||
console.log('Expected Balance:', expectedBalance);
|
||||
|
||||
expect(result.completion.balance).toBeLessThan(initialBalance);
|
||||
|
||||
// Allow for a small difference (e.g., 100 token credits, which is $0.0001)
|
||||
const allowedDifference = 100;
|
||||
expect(Math.abs(result.completion.balance - expectedBalance)).toBeLessThan(allowedDifference);
|
||||
|
||||
// Check if the decrease is approximately as expected
|
||||
const balanceDecrease = initialBalance - result.completion.balance;
|
||||
expect(balanceDecrease).toBeCloseTo(expectedTotalCost, 0);
|
||||
|
||||
// Check token values
|
||||
const expectedPromptTokenValue = -(
|
||||
tokenUsage.promptTokens.input * promptMultiplier +
|
||||
tokenUsage.promptTokens.write * writeMultiplier +
|
||||
tokenUsage.promptTokens.read * readMultiplier
|
||||
);
|
||||
const expectedCompletionTokenValue = -tokenUsage.completionTokens * completionMultiplier;
|
||||
|
||||
expect(result.prompt.prompt).toBeCloseTo(expectedPromptTokenValue, 1);
|
||||
expect(result.completion.completion).toBe(expectedCompletionTokenValue);
|
||||
|
||||
console.log('Expected prompt tokenValue:', expectedPromptTokenValue);
|
||||
console.log('Actual prompt tokenValue:', result.prompt.prompt);
|
||||
console.log('Expected completion tokenValue:', expectedCompletionTokenValue);
|
||||
console.log('Actual completion tokenValue:', result.completion.completion);
|
||||
});
|
||||
|
||||
test('should handle zero completion tokens in structured spending', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
completionTokens: 0,
|
||||
};
|
||||
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
expect(result.prompt).toBeDefined();
|
||||
expect(result.completion).toBeUndefined();
|
||||
expect(result.prompt.prompt).toBeLessThan(0);
|
||||
});
|
||||
|
||||
test('should handle only prompt tokens in structured spending', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
};
|
||||
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
expect(result.prompt).toBeDefined();
|
||||
expect(result.completion).toBeUndefined();
|
||||
expect(result.prompt.prompt).toBeLessThan(0);
|
||||
});
|
||||
|
||||
test('should handle undefined token counts in structured spending', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
};
|
||||
|
||||
const tokenUsage = {};
|
||||
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
expect(result).toEqual({
|
||||
prompt: undefined,
|
||||
completion: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle incomplete context for completion tokens', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 17613154.55;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'incomplete',
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15
|
||||
});
|
||||
});
|
||||
|
||||
describe('NaN Handling Tests', () => {
|
||||
test('should skip transaction creation when rawAmount is NaN', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: NaN,
|
||||
tokenType: 'prompt',
|
||||
};
|
||||
|
||||
const result = await Transaction.create(txData);
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,61 +1,5 @@
|
|||
const mongoose = require('mongoose');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const signPayload = require('../server/services/signPayload');
|
||||
const userSchema = require('./schema/userSchema.js');
|
||||
const { SESSION_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||
|
||||
userSchema.methods.toJSON = function () {
|
||||
return {
|
||||
id: this._id,
|
||||
provider: this.provider,
|
||||
email: this.email,
|
||||
name: this.name,
|
||||
username: this.username,
|
||||
avatar: this.avatar,
|
||||
role: this.role,
|
||||
emailVerified: this.emailVerified,
|
||||
plugins: this.plugins,
|
||||
createdAt: this.createdAt,
|
||||
updatedAt: this.updatedAt,
|
||||
};
|
||||
};
|
||||
|
||||
userSchema.methods.generateToken = async function () {
|
||||
return await signPayload({
|
||||
payload: {
|
||||
id: this._id,
|
||||
username: this.username,
|
||||
provider: this.provider,
|
||||
email: this.email,
|
||||
},
|
||||
secret: process.env.JWT_SECRET,
|
||||
expirationTime: expires / 1000,
|
||||
});
|
||||
};
|
||||
|
||||
userSchema.methods.comparePassword = function (candidatePassword, callback) {
|
||||
bcrypt.compare(candidatePassword, this.password, (err, isMatch) => {
|
||||
if (err) {
|
||||
return callback(err);
|
||||
}
|
||||
callback(null, isMatch);
|
||||
});
|
||||
};
|
||||
|
||||
module.exports.hashPassword = async (password) => {
|
||||
const hashedPassword = await new Promise((resolve, reject) => {
|
||||
bcrypt.hash(password, 10, function (err, hash) {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve(hash);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return hashedPassword;
|
||||
};
|
||||
const userSchema = require('~/models/schema/userSchema');
|
||||
|
||||
const User = mongoose.model('User', userSchema);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { logViolation } = require('~/cache');
|
||||
const Balance = require('./Balance');
|
||||
const { logViolation } = require('../cache');
|
||||
/**
|
||||
* Checks the balance for a user and determines if they can spend a certain amount.
|
||||
* If the user cannot spend the amount, it logs a violation and denies the request.
|
||||
|
|
@ -25,7 +26,7 @@ const checkBalance = async ({ req, res, txData }) => {
|
|||
return true;
|
||||
}
|
||||
|
||||
const type = 'token_balance';
|
||||
const type = ViolationTypes.TOKEN_BALANCE;
|
||||
const errorMessage = {
|
||||
type,
|
||||
balance,
|
||||
|
|
|
|||
313
api/models/convoStructure.spec.js
Normal file
313
api/models/convoStructure.spec.js
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { Message, getMessages, bulkSaveMessages } = require('./Message');
|
||||
|
||||
// Original version of buildTree function
|
||||
function buildTree({ messages, fileMap }) {
|
||||
if (messages === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const messageMap = {};
|
||||
const rootMessages = [];
|
||||
const childrenCount = {};
|
||||
|
||||
messages.forEach((message) => {
|
||||
const parentId = message.parentMessageId ?? '';
|
||||
childrenCount[parentId] = (childrenCount[parentId] || 0) + 1;
|
||||
|
||||
const extendedMessage = {
|
||||
...message,
|
||||
children: [],
|
||||
depth: 0,
|
||||
siblingIndex: childrenCount[parentId] - 1,
|
||||
};
|
||||
|
||||
if (message.files && fileMap) {
|
||||
extendedMessage.files = message.files.map((file) => fileMap[file.file_id ?? ''] ?? file);
|
||||
}
|
||||
|
||||
messageMap[message.messageId] = extendedMessage;
|
||||
|
||||
const parentMessage = messageMap[parentId];
|
||||
if (parentMessage) {
|
||||
parentMessage.children.push(extendedMessage);
|
||||
extendedMessage.depth = parentMessage.depth + 1;
|
||||
} else {
|
||||
rootMessages.push(extendedMessage);
|
||||
}
|
||||
});
|
||||
|
||||
return rootMessages;
|
||||
}
|
||||
|
||||
let mongod;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongod = await MongoMemoryServer.create();
|
||||
const uri = mongod.getUri();
|
||||
await mongoose.connect(uri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongod.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Message.deleteMany({});
|
||||
});
|
||||
|
||||
describe('Conversation Structure Tests', () => {
|
||||
test('Conversation folding/corrupting with inconsistent timestamps', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create messages with inconsistent timestamps
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'message0',
|
||||
parentMessageId: null,
|
||||
text: 'Message 0',
|
||||
createdAt: new Date('2023-01-01T00:00:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'message1',
|
||||
parentMessageId: 'message0',
|
||||
text: 'Message 1',
|
||||
createdAt: new Date('2023-01-01T00:02:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'message2',
|
||||
parentMessageId: 'message1',
|
||||
text: 'Message 2',
|
||||
createdAt: new Date('2023-01-01T00:01:00Z'),
|
||||
}, // Note: Earlier than its parent
|
||||
{
|
||||
messageId: 'message3',
|
||||
parentMessageId: 'message1',
|
||||
text: 'Message 3',
|
||||
createdAt: new Date('2023-01-01T00:03:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'message4',
|
||||
parentMessageId: 'message2',
|
||||
text: 'Message 4',
|
||||
createdAt: new Date('2023-01-01T00:04:00Z'),
|
||||
},
|
||||
];
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
msg.conversationId = conversationId;
|
||||
msg.user = userId;
|
||||
msg.isCreatedByUser = false;
|
||||
msg.error = false;
|
||||
msg.unfinished = false;
|
||||
});
|
||||
|
||||
// Save messages with overrideTimestamp omitted (default is false)
|
||||
await bulkSaveMessages(messages, true);
|
||||
|
||||
// Retrieve messages (this will sort by createdAt)
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages });
|
||||
|
||||
// Check if the tree is incorrect (folded/corrupted)
|
||||
expect(tree.length).toBeGreaterThan(1); // Should have multiple root messages, indicating corruption
|
||||
});
|
||||
|
||||
test('Fix: Conversation structure maintained with more than 16 messages', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create more than 16 messages
|
||||
const messages = Array.from({ length: 20 }, (_, i) => ({
|
||||
messageId: `message${i}`,
|
||||
parentMessageId: i === 0 ? null : `message${i - 1}`,
|
||||
conversationId,
|
||||
user: userId,
|
||||
text: `Message ${i}`,
|
||||
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 500000 : -i * 500000)),
|
||||
}));
|
||||
|
||||
// Save messages with new timestamps being generated (message objects ignored)
|
||||
await bulkSaveMessages(messages);
|
||||
|
||||
// Retrieve messages (this will sort by createdAt, but it shouldn't matter now)
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages });
|
||||
|
||||
// Check if the tree is correct
|
||||
expect(tree.length).toBe(1); // Should have only one root message
|
||||
let currentNode = tree[0];
|
||||
for (let i = 1; i < 20; i++) {
|
||||
expect(currentNode.children.length).toBe(1);
|
||||
currentNode = currentNode.children[0];
|
||||
expect(currentNode.text).toBe(`Message ${i}`);
|
||||
}
|
||||
expect(currentNode.children.length).toBe(0); // Last message should have no children
|
||||
});
|
||||
|
||||
test('Simulate MongoDB ordering issue with more than 16 messages and close timestamps', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create more than 16 messages with very close timestamps
|
||||
const messages = Array.from({ length: 20 }, (_, i) => ({
|
||||
messageId: `message${i}`,
|
||||
parentMessageId: i === 0 ? null : `message${i - 1}`,
|
||||
conversationId,
|
||||
user: userId,
|
||||
text: `Message ${i}`,
|
||||
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 1 : -i * 1)),
|
||||
}));
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
msg.isCreatedByUser = false;
|
||||
msg.error = false;
|
||||
msg.unfinished = false;
|
||||
});
|
||||
|
||||
await bulkSaveMessages(messages, true);
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
const tree = buildTree({ messages: retrievedMessages });
|
||||
expect(tree.length).toBeGreaterThan(1);
|
||||
});
|
||||
|
||||
test('Fix: Preserve order with more than 16 messages by maintaining original timestamps', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create more than 16 messages with distinct timestamps
|
||||
const messages = Array.from({ length: 20 }, (_, i) => ({
|
||||
messageId: `message${i}`,
|
||||
parentMessageId: i === 0 ? null : `message${i - 1}`,
|
||||
conversationId,
|
||||
user: userId,
|
||||
text: `Message ${i}`,
|
||||
createdAt: new Date(Date.now() + i * 1000), // Ensure each message has a distinct timestamp
|
||||
}));
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
msg.isCreatedByUser = false;
|
||||
msg.error = false;
|
||||
msg.unfinished = false;
|
||||
});
|
||||
|
||||
// Save messages with overriding timestamps (preserve original timestamps)
|
||||
await bulkSaveMessages(messages, true);
|
||||
|
||||
// Retrieve messages (this will sort by createdAt)
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages });
|
||||
|
||||
// Check if the tree is correct
|
||||
expect(tree.length).toBe(1); // Should have only one root message
|
||||
let currentNode = tree[0];
|
||||
for (let i = 1; i < 20; i++) {
|
||||
expect(currentNode.children.length).toBe(1);
|
||||
currentNode = currentNode.children[0];
|
||||
expect(currentNode.text).toBe(`Message ${i}`);
|
||||
}
|
||||
expect(currentNode.children.length).toBe(0); // Last message should have no children
|
||||
});
|
||||
|
||||
test('Random order dates between parent and children messages', async () => {
|
||||
const userId = 'testUser';
|
||||
const conversationId = 'testConversation';
|
||||
|
||||
// Create messages with deliberately out-of-order timestamps but sequential creation
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'parent',
|
||||
parentMessageId: null,
|
||||
text: 'Parent Message',
|
||||
createdAt: new Date('2023-01-01T00:00:00Z'), // Make parent earliest
|
||||
},
|
||||
{
|
||||
messageId: 'child1',
|
||||
parentMessageId: 'parent',
|
||||
text: 'Child Message 1',
|
||||
createdAt: new Date('2023-01-01T00:01:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'child2',
|
||||
parentMessageId: 'parent',
|
||||
text: 'Child Message 2',
|
||||
createdAt: new Date('2023-01-01T00:02:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'grandchild1',
|
||||
parentMessageId: 'child1',
|
||||
text: 'Grandchild Message 1',
|
||||
createdAt: new Date('2023-01-01T00:03:00Z'),
|
||||
},
|
||||
];
|
||||
|
||||
// Add common properties to all messages
|
||||
messages.forEach((msg) => {
|
||||
msg.conversationId = conversationId;
|
||||
msg.user = userId;
|
||||
msg.isCreatedByUser = false;
|
||||
msg.error = false;
|
||||
msg.unfinished = false;
|
||||
});
|
||||
|
||||
// Save messages with overrideTimestamp set to true
|
||||
await bulkSaveMessages(messages, true);
|
||||
|
||||
// Retrieve messages
|
||||
const retrievedMessages = await getMessages({ conversationId, user: userId });
|
||||
|
||||
// Debug log to see what's being returned
|
||||
console.log(
|
||||
'Retrieved Messages:',
|
||||
retrievedMessages.map((msg) => ({
|
||||
messageId: msg.messageId,
|
||||
parentMessageId: msg.parentMessageId,
|
||||
createdAt: msg.createdAt,
|
||||
})),
|
||||
);
|
||||
|
||||
// Build tree
|
||||
const tree = buildTree({ messages: retrievedMessages });
|
||||
|
||||
// Debug log to see the tree structure
|
||||
console.log(
|
||||
'Tree structure:',
|
||||
tree.map((root) => ({
|
||||
messageId: root.messageId,
|
||||
children: root.children.map((child) => ({
|
||||
messageId: child.messageId,
|
||||
children: child.children.map((grandchild) => ({
|
||||
messageId: grandchild.messageId,
|
||||
})),
|
||||
})),
|
||||
})),
|
||||
);
|
||||
|
||||
// Verify the structure before making assertions
|
||||
expect(retrievedMessages.length).toBe(4); // Should have all 4 messages
|
||||
|
||||
// Check if messages are properly linked
|
||||
const parentMsg = retrievedMessages.find((msg) => msg.messageId === 'parent');
|
||||
expect(parentMsg.parentMessageId).toBeNull(); // Parent should have null parentMessageId
|
||||
|
||||
const childMsg1 = retrievedMessages.find((msg) => msg.messageId === 'child1');
|
||||
expect(childMsg1.parentMessageId).toBe('parent');
|
||||
|
||||
// Then check tree structure
|
||||
expect(tree.length).toBe(1); // Should have only one root message
|
||||
expect(tree[0].messageId).toBe('parent');
|
||||
expect(tree[0].children.length).toBe(2); // Should have two children
|
||||
});
|
||||
});
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
const {
|
||||
getMessages,
|
||||
saveMessage,
|
||||
recordMessage,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
} = require('./Message');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { hashPassword, getUser, updateUser } = require('./userMethods');
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
updateUser,
|
||||
createUser,
|
||||
countUsers,
|
||||
findUser,
|
||||
} = require('./userMethods');
|
||||
const {
|
||||
findFileById,
|
||||
createFile,
|
||||
|
|
@ -18,23 +17,50 @@ const {
|
|||
getFiles,
|
||||
updateFileUsage,
|
||||
} = require('./File');
|
||||
const Key = require('./Key');
|
||||
const User = require('./User');
|
||||
const Session = require('./Session');
|
||||
const {
|
||||
getMessage,
|
||||
getMessages,
|
||||
saveMessage,
|
||||
recordMessage,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
} = require('./Message');
|
||||
const {
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
} = require('./Session');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { createToken, findToken, updateToken, deleteTokens } = require('./Token');
|
||||
const Balance = require('./Balance');
|
||||
const Transaction = require('./Transaction');
|
||||
const User = require('./User');
|
||||
const Key = require('./Key');
|
||||
|
||||
module.exports = {
|
||||
User,
|
||||
Key,
|
||||
Session,
|
||||
Balance,
|
||||
Transaction,
|
||||
|
||||
hashPassword,
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
updateUser,
|
||||
getUser,
|
||||
createUser,
|
||||
countUsers,
|
||||
findUser,
|
||||
|
||||
findFileById,
|
||||
createFile,
|
||||
updateFile,
|
||||
deleteFile,
|
||||
deleteFiles,
|
||||
getFiles,
|
||||
updateFileUsage,
|
||||
|
||||
getMessage,
|
||||
getMessages,
|
||||
saveMessage,
|
||||
recordMessage,
|
||||
|
|
@ -52,11 +78,20 @@ module.exports = {
|
|||
savePreset,
|
||||
deletePresets,
|
||||
|
||||
findFileById,
|
||||
createFile,
|
||||
updateFile,
|
||||
deleteFile,
|
||||
deleteFiles,
|
||||
getFiles,
|
||||
updateFileUsage,
|
||||
createToken,
|
||||
findToken,
|
||||
updateToken,
|
||||
deleteTokens,
|
||||
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
|
||||
User,
|
||||
Key,
|
||||
Balance,
|
||||
};
|
||||
|
|
|
|||
69
api/models/inviteUser.js
Normal file
69
api/models/inviteUser.js
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { getRandomValues, hashToken } = require('~/server/utils/crypto');
|
||||
const { createToken, findToken } = require('./Token');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
* @module inviteUser
|
||||
* @description This module provides functions to create and get user invites
|
||||
*/
|
||||
|
||||
/**
|
||||
* @function createInvite
|
||||
* @description This function creates a new user invite
|
||||
* @param {string} email - The email of the user to invite
|
||||
* @returns {Promise<Object>} A promise that resolves to the saved invite document
|
||||
* @throws {Error} If there is an error creating the invite
|
||||
*/
|
||||
const createInvite = async (email) => {
|
||||
try {
|
||||
const token = await getRandomValues(32);
|
||||
const hash = await hashToken(token);
|
||||
const encodedToken = encodeURIComponent(token);
|
||||
|
||||
const fakeUserId = new mongoose.Types.ObjectId();
|
||||
|
||||
await createToken({
|
||||
userId: fakeUserId,
|
||||
email,
|
||||
token: hash,
|
||||
createdAt: Date.now(),
|
||||
expiresIn: 604800,
|
||||
});
|
||||
|
||||
return encodedToken;
|
||||
} catch (error) {
|
||||
logger.error('[createInvite] Error creating invite', error);
|
||||
return { message: 'Error creating invite' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @function getInvite
|
||||
* @description This function retrieves a user invite
|
||||
* @param {string} encodedToken - The token of the invite to retrieve
|
||||
* @param {string} email - The email of the user to validate
|
||||
* @returns {Promise<Object>} A promise that resolves to the retrieved invite document
|
||||
* @throws {Error} If there is an error retrieving the invite, if the invite does not exist, or if the email does not match
|
||||
*/
|
||||
const getInvite = async (encodedToken, email) => {
|
||||
try {
|
||||
const token = decodeURIComponent(encodedToken);
|
||||
const hash = await hashToken(token);
|
||||
const invite = await findToken({ token: hash, email });
|
||||
|
||||
if (!invite) {
|
||||
throw new Error('Invite not found or email does not match');
|
||||
}
|
||||
|
||||
return invite;
|
||||
} catch (error) {
|
||||
logger.error('[getInvite] Error getting invite:', error);
|
||||
return { error: true, message: error.message };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
createInvite,
|
||||
getInvite,
|
||||
};
|
||||
|
|
@ -155,7 +155,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
|
|||
function (results, value, key) {
|
||||
return { ...results, [key]: 1 };
|
||||
},
|
||||
{ _id: 1 },
|
||||
{ _id: 1, __v: 1 },
|
||||
),
|
||||
).lean();
|
||||
|
||||
|
|
@ -348,7 +348,7 @@ module.exports = function mongoMeili(schema, options) {
|
|||
try {
|
||||
meiliDoc = await client.index('convos').getDocument(doc.conversationId);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
logger.debug(
|
||||
'[MeiliMongooseModel.findOneAndUpdate] Convo not found in MeiliSearch and will index ' +
|
||||
doc.conversationId,
|
||||
error,
|
||||
|
|
|
|||
|
|
@ -39,13 +39,13 @@ const actionSchema = new Schema({
|
|||
default: 'action_prototype',
|
||||
},
|
||||
settings: Schema.Types.Mixed,
|
||||
agent_id: String,
|
||||
assistant_id: String,
|
||||
metadata: {
|
||||
api_key: String, // private, encrypted
|
||||
auth: AuthSchema,
|
||||
domain: {
|
||||
type: String,
|
||||
unique: true,
|
||||
required: true,
|
||||
},
|
||||
// json_schema: Schema.Types.Mixed,
|
||||
|
|
|
|||
96
api/models/schema/agent.js
Normal file
96
api/models/schema/agent.js
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
const agentSchema = mongoose.Schema(
|
||||
{
|
||||
id: {
|
||||
type: String,
|
||||
index: true,
|
||||
unique: true,
|
||||
required: true,
|
||||
},
|
||||
name: {
|
||||
type: String,
|
||||
},
|
||||
description: {
|
||||
type: String,
|
||||
},
|
||||
instructions: {
|
||||
type: String,
|
||||
},
|
||||
avatar: {
|
||||
type: {
|
||||
filepath: String,
|
||||
source: String,
|
||||
},
|
||||
default: undefined,
|
||||
},
|
||||
provider: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
model: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
model_parameters: {
|
||||
type: Object,
|
||||
},
|
||||
artifacts: {
|
||||
type: String,
|
||||
},
|
||||
access_level: {
|
||||
type: Number,
|
||||
},
|
||||
tools: {
|
||||
type: [String],
|
||||
default: undefined,
|
||||
},
|
||||
tool_kwargs: {
|
||||
type: [{ type: mongoose.Schema.Types.Mixed }],
|
||||
},
|
||||
actions: {
|
||||
type: [String],
|
||||
default: undefined,
|
||||
},
|
||||
author: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
authorName: {
|
||||
type: String,
|
||||
default: undefined,
|
||||
},
|
||||
hide_sequential_outputs: {
|
||||
type: Boolean,
|
||||
},
|
||||
end_after_tools: {
|
||||
type: Boolean,
|
||||
},
|
||||
agent_ids: {
|
||||
type: [String],
|
||||
},
|
||||
isCollaborative: {
|
||||
type: Boolean,
|
||||
default: undefined,
|
||||
},
|
||||
conversation_starters: {
|
||||
type: [String],
|
||||
default: [],
|
||||
},
|
||||
tool_resources: {
|
||||
type: mongoose.Schema.Types.Mixed,
|
||||
default: {},
|
||||
},
|
||||
projectIds: {
|
||||
type: [mongoose.Schema.Types.ObjectId],
|
||||
ref: 'Project',
|
||||
index: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = agentSchema;
|
||||
|
|
@ -9,7 +9,6 @@ const assistantSchema = mongoose.Schema(
|
|||
},
|
||||
assistant_id: {
|
||||
type: String,
|
||||
unique: true,
|
||||
index: true,
|
||||
required: true,
|
||||
},
|
||||
|
|
@ -20,11 +19,19 @@ const assistantSchema = mongoose.Schema(
|
|||
},
|
||||
default: undefined,
|
||||
},
|
||||
conversation_starters: {
|
||||
type: [String],
|
||||
default: [],
|
||||
},
|
||||
access_level: {
|
||||
type: Number,
|
||||
},
|
||||
file_ids: { type: [String], default: undefined },
|
||||
actions: { type: [String], default: undefined },
|
||||
append_current_datetime: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
|
|
|
|||
36
api/models/schema/banner.js
Normal file
36
api/models/schema/banner.js
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
const bannerSchema = mongoose.Schema(
|
||||
{
|
||||
bannerId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
message: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
displayFrom: {
|
||||
type: Date,
|
||||
required: true,
|
||||
default: Date.now,
|
||||
},
|
||||
displayTo: {
|
||||
type: Date,
|
||||
},
|
||||
type: {
|
||||
type: String,
|
||||
enum: ['banner', 'popup'],
|
||||
default: 'banner',
|
||||
},
|
||||
isPublic: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
},
|
||||
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
const Banner = mongoose.model('Banner', bannerSchema);
|
||||
module.exports = Banner;
|
||||
19
api/models/schema/categories.js
Normal file
19
api/models/schema/categories.js
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
const mongoose = require('mongoose');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
const categoriesSchema = new Schema({
|
||||
label: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
value: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
});
|
||||
|
||||
const categories = mongoose.model('categories', categoriesSchema);
|
||||
|
||||
module.exports = { Categories: categories };
|
||||
32
api/models/schema/conversationTagSchema.js
Normal file
32
api/models/schema/conversationTagSchema.js
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
const conversationTagSchema = mongoose.Schema(
|
||||
{
|
||||
tag: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
user: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
description: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
count: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
position: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
index: true,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
conversationTagSchema.index({ tag: 1, user: 1 }, { unique: true });
|
||||
|
||||
module.exports = mongoose.model('ConversationTag', conversationTagSchema);
|
||||
|
|
@ -26,21 +26,19 @@ const convoSchema = mongoose.Schema(
|
|||
type: mongoose.Schema.Types.Mixed,
|
||||
},
|
||||
...conversationPreset,
|
||||
// for bingAI only
|
||||
bingConversationId: {
|
||||
agent_id: {
|
||||
type: String,
|
||||
},
|
||||
jailbreakConversationId: {
|
||||
type: String,
|
||||
tags: {
|
||||
type: [String],
|
||||
default: [],
|
||||
meiliIndex: true,
|
||||
},
|
||||
conversationSignature: {
|
||||
type: String,
|
||||
files: {
|
||||
type: [String],
|
||||
},
|
||||
clientId: {
|
||||
type: String,
|
||||
},
|
||||
invocationId: {
|
||||
type: Number,
|
||||
expiredAt: {
|
||||
type: Date,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
|
|
@ -55,7 +53,10 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
|
|||
});
|
||||
}
|
||||
|
||||
// Create TTL index
|
||||
convoSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 });
|
||||
convoSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
convoSchema.index({ conversationId: 1, user: 1 }, { unique: true });
|
||||
|
||||
const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const conversationPreset = {
|
||||
// endpoint: [azureOpenAI, openAI, bingAI, anthropic, chatGPTBrowser]
|
||||
// endpoint: [azureOpenAI, openAI, anthropic, chatGPTBrowser]
|
||||
endpoint: {
|
||||
type: String,
|
||||
default: null,
|
||||
|
|
@ -13,6 +13,11 @@ const conversationPreset = {
|
|||
type: String,
|
||||
required: false,
|
||||
},
|
||||
// for bedrock only
|
||||
region: {
|
||||
type: String,
|
||||
required: false,
|
||||
},
|
||||
// for azureOpenAI, openAI only
|
||||
chatGptLabel: {
|
||||
type: String,
|
||||
|
|
@ -56,27 +61,29 @@ const conversationPreset = {
|
|||
type: Number,
|
||||
required: false,
|
||||
},
|
||||
// for bingai only
|
||||
jailbreak: {
|
||||
type: Boolean,
|
||||
},
|
||||
context: {
|
||||
type: String,
|
||||
},
|
||||
systemMessage: {
|
||||
type: String,
|
||||
},
|
||||
toneStyle: {
|
||||
type: String,
|
||||
},
|
||||
file_ids: { type: [{ type: String }], default: undefined },
|
||||
// vision
|
||||
// deprecated
|
||||
resendImages: {
|
||||
type: Boolean,
|
||||
},
|
||||
/* Anthropic only */
|
||||
promptCache: {
|
||||
type: Boolean,
|
||||
},
|
||||
system: {
|
||||
type: String,
|
||||
},
|
||||
// files
|
||||
resendFiles: {
|
||||
type: Boolean,
|
||||
},
|
||||
imageDetail: {
|
||||
type: String,
|
||||
},
|
||||
/* agents */
|
||||
agent_id: {
|
||||
type: String,
|
||||
},
|
||||
/* assistants */
|
||||
assistant_id: {
|
||||
type: String,
|
||||
|
|
@ -84,6 +91,36 @@ const conversationPreset = {
|
|||
instructions: {
|
||||
type: String,
|
||||
},
|
||||
stop: { type: [{ type: String }], default: undefined },
|
||||
isArchived: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
/* UI Components */
|
||||
iconURL: {
|
||||
type: String,
|
||||
},
|
||||
greeting: {
|
||||
type: String,
|
||||
},
|
||||
spec: {
|
||||
type: String,
|
||||
},
|
||||
tags: {
|
||||
type: [String],
|
||||
default: [],
|
||||
},
|
||||
tools: { type: [{ type: String }], default: undefined },
|
||||
maxContextTokens: {
|
||||
type: Number,
|
||||
},
|
||||
max_tokens: {
|
||||
type: Number,
|
||||
},
|
||||
/** omni models only */
|
||||
reasoning_effort: {
|
||||
type: String,
|
||||
},
|
||||
};
|
||||
|
||||
const agentOptions = {
|
||||
|
|
@ -133,12 +170,6 @@ const agentOptions = {
|
|||
type: Number,
|
||||
required: false,
|
||||
},
|
||||
context: {
|
||||
type: String,
|
||||
},
|
||||
systemMessage: {
|
||||
type: String,
|
||||
},
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ const mongoose = require('mongoose');
|
|||
|
||||
/**
|
||||
* @typedef {Object} MongoFile
|
||||
* @property {mongoose.Schema.Types.ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {number} [__v] - MongoDB Version Key
|
||||
* @property {mongoose.Schema.Types.ObjectId} user - User ID
|
||||
* @property {ObjectId} user - User ID
|
||||
* @property {string} [conversationId] - Optional conversation ID
|
||||
* @property {string} file_id - File identifier
|
||||
* @property {string} [temp_file_id] - Temporary File identifier
|
||||
|
|
@ -14,14 +14,21 @@ const mongoose = require('mongoose');
|
|||
* @property {string} filepath - Location of the file
|
||||
* @property {'file'} object - Type of object, always 'file'
|
||||
* @property {string} type - Type of file
|
||||
* @property {number} usage - Number of uses of the file
|
||||
* @property {string} [source] - The source of the file
|
||||
* @property {number} [usage=0] - Number of uses of the file
|
||||
* @property {string} [context] - Context of the file origin
|
||||
* @property {boolean} [embedded=false] - Whether or not the file is embedded in vector db
|
||||
* @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting)
|
||||
* @property {string} [source] - The source of the file (e.g., from FileSources)
|
||||
* @property {number} [width] - Optional width of the file
|
||||
* @property {number} [height] - Optional height of the file
|
||||
* @property {Date} [expiresAt] - Optional height of the file
|
||||
* @property {Object} [metadata] - Metadata related to the file
|
||||
* @property {string} [metadata.fileIdentifier] - Unique identifier for the file in metadata
|
||||
* @property {Date} [expiresAt] - Optional expiration date of the file
|
||||
* @property {Date} [createdAt] - Date when the file was created
|
||||
* @property {Date} [updatedAt] - Date when the file was updated
|
||||
*/
|
||||
|
||||
/** @type {MongooseSchema<MongoFile>} */
|
||||
const fileSchema = mongoose.Schema(
|
||||
{
|
||||
user: {
|
||||
|
|
@ -61,6 +68,9 @@ const fileSchema = mongoose.Schema(
|
|||
required: true,
|
||||
default: 'file',
|
||||
},
|
||||
embedded: {
|
||||
type: Boolean,
|
||||
},
|
||||
type: {
|
||||
type: String,
|
||||
required: true,
|
||||
|
|
@ -78,11 +88,17 @@ const fileSchema = mongoose.Schema(
|
|||
type: String,
|
||||
default: FileSources.local,
|
||||
},
|
||||
model: {
|
||||
type: String,
|
||||
},
|
||||
width: Number,
|
||||
height: Number,
|
||||
metadata: {
|
||||
fileIdentifier: String,
|
||||
},
|
||||
expiresAt: {
|
||||
type: Date,
|
||||
expires: 3600,
|
||||
expires: 3600, // 1 hour in seconds
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -90,4 +106,6 @@ const fileSchema = mongoose.Schema(
|
|||
},
|
||||
);
|
||||
|
||||
fileSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
|
||||
module.exports = fileSchema;
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ const keySchema = mongoose.Schema({
|
|||
},
|
||||
expiresAt: {
|
||||
type: Date,
|
||||
expires: 0,
|
||||
},
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ const messageSchema = mongoose.Schema(
|
|||
},
|
||||
conversationId: {
|
||||
type: String,
|
||||
index: true,
|
||||
required: true,
|
||||
meiliIndex: true,
|
||||
},
|
||||
|
|
@ -61,10 +62,6 @@ const messageSchema = mongoose.Schema(
|
|||
required: true,
|
||||
default: false,
|
||||
},
|
||||
isEdited: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
unfinished: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
|
|
@ -110,6 +107,36 @@ const messageSchema = mongoose.Schema(
|
|||
thread_id: {
|
||||
type: String,
|
||||
},
|
||||
/* frontend components */
|
||||
iconURL: {
|
||||
type: String,
|
||||
},
|
||||
attachments: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined },
|
||||
/*
|
||||
attachments: {
|
||||
type: [
|
||||
{
|
||||
file_id: String,
|
||||
filename: String,
|
||||
filepath: String,
|
||||
expiresAt: Date,
|
||||
width: Number,
|
||||
height: Number,
|
||||
type: String,
|
||||
conversationId: String,
|
||||
messageId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
toolCallId: String,
|
||||
},
|
||||
],
|
||||
default: undefined,
|
||||
},
|
||||
*/
|
||||
expiredAt: {
|
||||
type: Date,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
|
@ -122,9 +149,11 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
|
|||
primaryKey: 'messageId',
|
||||
});
|
||||
}
|
||||
|
||||
messageSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 });
|
||||
messageSchema.index({ createdAt: 1 });
|
||||
messageSchema.index({ messageId: 1, user: 1 }, { unique: true });
|
||||
|
||||
/** @type {mongoose.Model<TMessage>} */
|
||||
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
|
||||
|
||||
module.exports = Message;
|
||||
|
|
|
|||
35
api/models/schema/projectSchema.js
Normal file
35
api/models/schema/projectSchema.js
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
const { Schema } = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoProject
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the project
|
||||
* @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project
|
||||
* @property {Date} [createdAt] - Date when the project was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const projectSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
promptGroupIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'PromptGroup',
|
||||
default: [],
|
||||
},
|
||||
agentIds: {
|
||||
type: [String],
|
||||
ref: 'Agent',
|
||||
default: [],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = projectSchema;
|
||||
118
api/models/schema/promptSchema.js
Normal file
118
api/models/schema/promptSchema.js
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoPromptGroup
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the prompt group
|
||||
* @property {ObjectId} author - The author of the prompt group
|
||||
* @property {ObjectId} [projectId=null] - The project ID of the prompt group
|
||||
* @property {ObjectId} [productionId=null] - The project ID of the prompt group
|
||||
* @property {string} authorName - The name of the author of the prompt group
|
||||
* @property {number} [numberOfGenerations=0] - Number of generations the prompt group has
|
||||
* @property {string} [oneliner=''] - Oneliner description of the prompt group
|
||||
* @property {string} [category=''] - Category of the prompt group
|
||||
* @property {string} [command] - Command for the prompt group
|
||||
* @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const promptGroupSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
numberOfGenerations: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
oneliner: {
|
||||
type: String,
|
||||
default: '',
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
default: '',
|
||||
index: true,
|
||||
},
|
||||
projectIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'Project',
|
||||
index: true,
|
||||
},
|
||||
productionId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'Prompt',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
authorName: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
command: {
|
||||
type: String,
|
||||
index: true,
|
||||
validate: {
|
||||
validator: function (v) {
|
||||
return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);
|
||||
},
|
||||
message: (props) =>
|
||||
`${props.value} is not a valid command. Only lowercase alphanumeric characters and highfins (') are allowed.`,
|
||||
},
|
||||
maxlength: [
|
||||
Constants.COMMANDS_MAX_LENGTH,
|
||||
`Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`,
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
|
||||
|
||||
const promptSchema = new Schema(
|
||||
{
|
||||
groupId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'PromptGroup',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
type: {
|
||||
type: String,
|
||||
enum: ['text', 'chat'],
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const Prompt = mongoose.model('Prompt', promptSchema);
|
||||
|
||||
promptSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
promptGroupSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
|
||||
module.exports = { Prompt, PromptGroup };
|
||||
55
api/models/schema/roleSchema.js
Normal file
55
api/models/schema/roleSchema.js
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const roleSchema = new mongoose.Schema({
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
index: true,
|
||||
},
|
||||
[PermissionTypes.BOOKMARKS]: {
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
[Permissions.CREATE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
[Permissions.CREATE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: {
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const Role = mongoose.model('Role', roleSchema);
|
||||
|
||||
module.exports = Role;
|
||||
20
api/models/schema/session.js
Normal file
20
api/models/schema/session.js
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
const sessionSchema = mongoose.Schema({
|
||||
refreshTokenHash: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
expiration: {
|
||||
type: Date,
|
||||
required: true,
|
||||
expires: 0,
|
||||
},
|
||||
user: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
});
|
||||
|
||||
module.exports = sessionSchema;
|
||||
30
api/models/schema/shareSchema.js
Normal file
30
api/models/schema/shareSchema.js
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
const shareSchema = mongoose.Schema(
|
||||
{
|
||||
conversationId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
title: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
user: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }],
|
||||
shareId: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
isPublic: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
module.exports = mongoose.model('SharedLink', shareSchema);
|
||||
|
|
@ -7,6 +7,13 @@ const tokenSchema = new Schema({
|
|||
required: true,
|
||||
ref: 'user',
|
||||
},
|
||||
email: {
|
||||
type: String,
|
||||
},
|
||||
type: String,
|
||||
identifier: {
|
||||
type: String,
|
||||
},
|
||||
token: {
|
||||
type: String,
|
||||
required: true,
|
||||
|
|
@ -15,8 +22,17 @@ const tokenSchema = new Schema({
|
|||
type: Date,
|
||||
required: true,
|
||||
default: Date.now,
|
||||
expires: 900,
|
||||
},
|
||||
expiresAt: {
|
||||
type: Date,
|
||||
required: true,
|
||||
},
|
||||
metadata: {
|
||||
type: Map,
|
||||
of: Schema.Types.Mixed,
|
||||
},
|
||||
});
|
||||
|
||||
module.exports = mongoose.model('Token', tokenSchema);
|
||||
tokenSchema.index({ expiresAt: 1 }, { expireAfterSeconds: 0 });
|
||||
|
||||
module.exports = tokenSchema;
|
||||
|
|
|
|||
54
api/models/schema/toolCallSchema.js
Normal file
54
api/models/schema/toolCallSchema.js
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ToolCallData
|
||||
* @property {string} conversationId - The ID of the conversation
|
||||
* @property {string} messageId - The ID of the message
|
||||
* @property {string} toolId - The ID of the tool
|
||||
* @property {string | ObjectId} user - The user's ObjectId
|
||||
* @property {unknown} [result] - Optional result data
|
||||
* @property {TAttachment[]} [attachments] - Optional attachments data
|
||||
* @property {number} [blockIndex] - Optional code block index
|
||||
* @property {number} [partIndex] - Optional part index
|
||||
*/
|
||||
|
||||
/** @type {MongooseSchema<ToolCallData>} */
|
||||
const toolCallSchema = mongoose.Schema(
|
||||
{
|
||||
conversationId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
messageId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
toolId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
user: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
result: {
|
||||
type: mongoose.Schema.Types.Mixed,
|
||||
},
|
||||
attachments: {
|
||||
type: mongoose.Schema.Types.Mixed,
|
||||
},
|
||||
blockIndex: {
|
||||
type: Number,
|
||||
},
|
||||
partIndex: {
|
||||
type: Number,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
toolCallSchema.index({ messageId: 1, user: 1 });
|
||||
toolCallSchema.index({ conversationId: 1, user: 1 });
|
||||
|
||||
module.exports = mongoose.model('ToolCall', toolCallSchema);
|
||||
|
|
@ -30,6 +30,9 @@ const transactionSchema = mongoose.Schema(
|
|||
rate: Number,
|
||||
rawAmount: Number,
|
||||
tokenValue: Number,
|
||||
inputTokens: { type: Number },
|
||||
writeTokens: { type: Number },
|
||||
readTokens: { type: Number },
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,37 @@
|
|||
const mongoose = require('mongoose');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoSession
|
||||
* @property {string} [refreshToken] - The refresh token
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoUser
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} [name] - The user's name
|
||||
* @property {string} [username] - The user's username, in lowercase
|
||||
* @property {string} email - The user's email address
|
||||
* @property {boolean} emailVerified - Whether the user's email is verified
|
||||
* @property {string} [password] - The user's password, trimmed with 8-128 characters
|
||||
* @property {string} [avatar] - The URL of the user's avatar
|
||||
* @property {string} provider - The provider of the user's account (e.g., 'local', 'google')
|
||||
* @property {string} [role='USER'] - The role of the user
|
||||
* @property {string} [googleId] - Optional Google ID for the user
|
||||
* @property {string} [facebookId] - Optional Facebook ID for the user
|
||||
* @property {string} [openidId] - Optional OpenID ID for the user
|
||||
* @property {string} [ldapId] - Optional LDAP ID for the user
|
||||
* @property {string} [githubId] - Optional GitHub ID for the user
|
||||
* @property {string} [discordId] - Optional Discord ID for the user
|
||||
* @property {string} [appleId] - Optional Apple ID for the user
|
||||
* @property {Array} [plugins=[]] - List of plugins used by the user
|
||||
* @property {Array.<MongoSession>} [refreshToken] - List of sessions with refresh tokens
|
||||
* @property {Date} [expiresAt] - Optional expiration date of the file
|
||||
* @property {Date} [createdAt] - Date when the user was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
/** @type {MongooseSchema<MongoSession>} */
|
||||
const Session = mongoose.Schema({
|
||||
refreshToken: {
|
||||
type: String,
|
||||
|
|
@ -7,6 +39,7 @@ const Session = mongoose.Schema({
|
|||
},
|
||||
});
|
||||
|
||||
/** @type {MongooseSchema<MongoUser>} */
|
||||
const userSchema = mongoose.Schema(
|
||||
{
|
||||
name: {
|
||||
|
|
@ -47,7 +80,7 @@ const userSchema = mongoose.Schema(
|
|||
},
|
||||
role: {
|
||||
type: String,
|
||||
default: 'USER',
|
||||
default: SystemRoles.USER,
|
||||
},
|
||||
googleId: {
|
||||
type: String,
|
||||
|
|
@ -58,12 +91,22 @@ const userSchema = mongoose.Schema(
|
|||
openidId: {
|
||||
type: String,
|
||||
},
|
||||
ldapId: {
|
||||
type: String,
|
||||
unique: true,
|
||||
sparse: true,
|
||||
},
|
||||
githubId: {
|
||||
type: String,
|
||||
},
|
||||
discordId: {
|
||||
type: String,
|
||||
},
|
||||
appleId: {
|
||||
type: String,
|
||||
unique: true,
|
||||
sparse: true,
|
||||
},
|
||||
plugins: {
|
||||
type: Array,
|
||||
default: [],
|
||||
|
|
@ -71,7 +114,16 @@ const userSchema = mongoose.Schema(
|
|||
refreshToken: {
|
||||
type: [Session],
|
||||
},
|
||||
expiresAt: {
|
||||
type: Date,
|
||||
expires: 604800, // 7 days in seconds
|
||||
},
|
||||
termsAccepted: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
},
|
||||
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const Transaction = require('./Transaction');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
|
|
@ -11,7 +11,7 @@ const { logger } = require('~/config');
|
|||
* @param {String} txData.conversationId - The ID of the conversation.
|
||||
* @param {String} txData.model - The model name.
|
||||
* @param {String} txData.context - The context in which the transaction is made.
|
||||
* @param {String} [txData.endpointTokenConfig] - The current endpoint token config.
|
||||
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
|
||||
* @param {String} [txData.valueKey] - The value key (optional).
|
||||
* @param {Object} tokenUsage - The number of tokens used.
|
||||
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
|
||||
|
|
@ -21,44 +21,120 @@ const { logger } = require('~/config');
|
|||
*/
|
||||
const spendTokens = async (txData, tokenUsage) => {
|
||||
const { promptTokens, completionTokens } = tokenUsage;
|
||||
logger.debug(`[spendTokens] conversationId: ${txData.conversationId} | Token usage: `, {
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
logger.debug(
|
||||
`[spendTokens] conversationId: ${txData.conversationId}${
|
||||
txData?.context ? ` | Context: ${txData?.context}` : ''
|
||||
} | Token usage: `,
|
||||
{
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
},
|
||||
);
|
||||
let prompt, completion;
|
||||
try {
|
||||
if (promptTokens >= 0) {
|
||||
if (promptTokens !== undefined) {
|
||||
prompt = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -promptTokens,
|
||||
rawAmount: -Math.max(promptTokens, 0),
|
||||
});
|
||||
}
|
||||
|
||||
if (!completionTokens) {
|
||||
logger.debug('[spendTokens] !completionTokens', { prompt, completion });
|
||||
return;
|
||||
if (completionTokens !== undefined) {
|
||||
completion = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -Math.max(completionTokens, 0),
|
||||
});
|
||||
}
|
||||
|
||||
completion = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -completionTokens,
|
||||
});
|
||||
|
||||
prompt &&
|
||||
completion &&
|
||||
if (prompt || completion) {
|
||||
logger.debug('[spendTokens] Transaction data record against balance:', {
|
||||
user: prompt.user,
|
||||
prompt: prompt.prompt,
|
||||
promptRate: prompt.rate,
|
||||
completion: completion.completion,
|
||||
completionRate: completion.rate,
|
||||
balance: completion.balance,
|
||||
user: txData.user,
|
||||
prompt: prompt?.prompt,
|
||||
promptRate: prompt?.rate,
|
||||
completion: completion?.completion,
|
||||
completionRate: completion?.rate,
|
||||
balance: completion?.balance ?? prompt?.balance,
|
||||
});
|
||||
} else {
|
||||
logger.debug('[spendTokens] No transactions incurred against balance');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[spendTokens]', err);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = spendTokens;
|
||||
/**
|
||||
* Creates transactions to record the spending of structured tokens.
|
||||
*
|
||||
* @function
|
||||
* @async
|
||||
* @param {Object} txData - Transaction data.
|
||||
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
|
||||
* @param {String} txData.conversationId - The ID of the conversation.
|
||||
* @param {String} txData.model - The model name.
|
||||
* @param {String} txData.context - The context in which the transaction is made.
|
||||
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
|
||||
* @param {String} [txData.valueKey] - The value key (optional).
|
||||
* @param {Object} tokenUsage - The number of tokens used.
|
||||
* @param {Object} tokenUsage.promptTokens - The number of prompt tokens used.
|
||||
* @param {Number} tokenUsage.promptTokens.input - The number of input tokens.
|
||||
* @param {Number} tokenUsage.promptTokens.write - The number of write tokens.
|
||||
* @param {Number} tokenUsage.promptTokens.read - The number of read tokens.
|
||||
* @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
|
||||
* @returns {Promise<void>} - Returns nothing.
|
||||
* @throws {Error} - Throws an error if there's an issue creating the transactions.
|
||||
*/
|
||||
const spendStructuredTokens = async (txData, tokenUsage) => {
|
||||
const { promptTokens, completionTokens } = tokenUsage;
|
||||
logger.debug(
|
||||
`[spendStructuredTokens] conversationId: ${txData.conversationId}${
|
||||
txData?.context ? ` | Context: ${txData?.context}` : ''
|
||||
} | Token usage: `,
|
||||
{
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
},
|
||||
);
|
||||
let prompt, completion;
|
||||
try {
|
||||
if (promptTokens) {
|
||||
const { input = 0, write = 0, read = 0 } = promptTokens;
|
||||
prompt = await Transaction.createStructured({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -input,
|
||||
writeTokens: -write,
|
||||
readTokens: -read,
|
||||
});
|
||||
}
|
||||
|
||||
if (completionTokens) {
|
||||
completion = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -completionTokens,
|
||||
});
|
||||
}
|
||||
|
||||
if (prompt || completion) {
|
||||
logger.debug('[spendStructuredTokens] Transaction data record against balance:', {
|
||||
user: txData.user,
|
||||
prompt: prompt?.prompt,
|
||||
promptRate: prompt?.rate,
|
||||
completion: completion?.completion,
|
||||
completionRate: completion?.rate,
|
||||
balance: completion?.balance ?? prompt?.balance,
|
||||
});
|
||||
} else {
|
||||
logger.debug('[spendStructuredTokens] No transactions incurred against balance');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[spendStructuredTokens]', err);
|
||||
}
|
||||
|
||||
return { prompt, completion };
|
||||
};
|
||||
|
||||
module.exports = { spendTokens, spendStructuredTokens };
|
||||
|
|
|
|||
197
api/models/spendTokens.spec.js
Normal file
197
api/models/spendTokens.spec.js
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
const mongoose = require('mongoose');
|
||||
|
||||
jest.mock('./Transaction', () => ({
|
||||
Transaction: {
|
||||
create: jest.fn(),
|
||||
createStructured: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('./Balance', () => ({
|
||||
findOne: jest.fn(),
|
||||
findOneAndUpdate: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocking
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
const { Transaction } = require('./Transaction');
|
||||
const Balance = require('./Balance');
|
||||
describe('spendTokens', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
process.env.CHECK_BALANCE = 'true';
|
||||
});
|
||||
|
||||
it('should create transactions for both prompt and completion tokens', async () => {
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 });
|
||||
Balance.findOne.mockResolvedValue({ tokenCredits: 10000 });
|
||||
Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 });
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).toHaveBeenCalledTimes(2);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -100,
|
||||
}),
|
||||
);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'completion',
|
||||
rawAmount: -50,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle zero completion tokens', async () => {
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 0,
|
||||
};
|
||||
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -0 });
|
||||
Balance.findOne.mockResolvedValue({ tokenCredits: 10000 });
|
||||
Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 });
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).toHaveBeenCalledTimes(2);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -100,
|
||||
}),
|
||||
);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'completion',
|
||||
rawAmount: -0, // Changed from 0 to -0
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle undefined token counts', async () => {
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
};
|
||||
const tokenUsage = {};
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not update balance when CHECK_BALANCE is false', async () => {
|
||||
process.env.CHECK_BALANCE = 'false';
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
|
||||
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 });
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.create).toHaveBeenCalledTimes(2);
|
||||
expect(Balance.findOne).not.toHaveBeenCalled();
|
||||
expect(Balance.findOneAndUpdate).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should create structured transactions for both prompt and completion tokens', async () => {
|
||||
const txData = {
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
conversationId: 'test-convo',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
input: 10,
|
||||
write: 100,
|
||||
read: 5,
|
||||
},
|
||||
completionTokens: 50,
|
||||
};
|
||||
|
||||
Transaction.createStructured.mockResolvedValueOnce({
|
||||
rate: 3.75,
|
||||
user: txData.user.toString(),
|
||||
balance: 9570,
|
||||
prompt: -430,
|
||||
});
|
||||
Transaction.create.mockResolvedValueOnce({
|
||||
rate: 15,
|
||||
user: txData.user.toString(),
|
||||
balance: 8820,
|
||||
completion: -750,
|
||||
});
|
||||
|
||||
const result = await spendStructuredTokens(txData, tokenUsage);
|
||||
|
||||
expect(Transaction.createStructured).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
}),
|
||||
);
|
||||
expect(Transaction.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tokenType: 'completion',
|
||||
rawAmount: -50,
|
||||
}),
|
||||
);
|
||||
expect(result).toEqual({
|
||||
prompt: expect.objectContaining({
|
||||
rate: 3.75,
|
||||
user: txData.user.toString(),
|
||||
balance: 9570,
|
||||
prompt: -430,
|
||||
}),
|
||||
completion: expect.objectContaining({
|
||||
rate: 15,
|
||||
user: txData.user.toString(),
|
||||
balance: 8820,
|
||||
completion: -750,
|
||||
}),
|
||||
});
|
||||
});
|
||||
});
|
||||
196
api/models/tx.js
196
api/models/tx.js
|
|
@ -1,24 +1,131 @@
|
|||
const { matchModelName } = require('../utils');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
* AWS Bedrock pricing
|
||||
* source: https://aws.amazon.com/bedrock/pricing/
|
||||
* */
|
||||
const bedrockValues = {
|
||||
// Basic llama2 patterns
|
||||
'llama2-13b': { prompt: 0.75, completion: 1.0 },
|
||||
'llama2:13b': { prompt: 0.75, completion: 1.0 },
|
||||
'llama2:70b': { prompt: 1.95, completion: 2.56 },
|
||||
'llama2-70b': { prompt: 1.95, completion: 2.56 },
|
||||
|
||||
// Basic llama3 patterns
|
||||
'llama3-8b': { prompt: 0.3, completion: 0.6 },
|
||||
'llama3:8b': { prompt: 0.3, completion: 0.6 },
|
||||
'llama3-70b': { prompt: 2.65, completion: 3.5 },
|
||||
'llama3:70b': { prompt: 2.65, completion: 3.5 },
|
||||
|
||||
// llama3-x-Nb pattern
|
||||
'llama3-1-8b': { prompt: 0.22, completion: 0.22 },
|
||||
'llama3-1-70b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama3-1-405b': { prompt: 2.4, completion: 2.4 },
|
||||
'llama3-2-1b': { prompt: 0.1, completion: 0.1 },
|
||||
'llama3-2-3b': { prompt: 0.15, completion: 0.15 },
|
||||
'llama3-2-11b': { prompt: 0.16, completion: 0.16 },
|
||||
'llama3-2-90b': { prompt: 0.72, completion: 0.72 },
|
||||
|
||||
// llama3.x:Nb pattern
|
||||
'llama3.1:8b': { prompt: 0.22, completion: 0.22 },
|
||||
'llama3.1:70b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama3.1:405b': { prompt: 2.4, completion: 2.4 },
|
||||
'llama3.2:1b': { prompt: 0.1, completion: 0.1 },
|
||||
'llama3.2:3b': { prompt: 0.15, completion: 0.15 },
|
||||
'llama3.2:11b': { prompt: 0.16, completion: 0.16 },
|
||||
'llama3.2:90b': { prompt: 0.72, completion: 0.72 },
|
||||
|
||||
// llama-3.x-Nb pattern
|
||||
'llama-3.1-8b': { prompt: 0.22, completion: 0.22 },
|
||||
'llama-3.1-70b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama-3.1-405b': { prompt: 2.4, completion: 2.4 },
|
||||
'llama-3.2-1b': { prompt: 0.1, completion: 0.1 },
|
||||
'llama-3.2-3b': { prompt: 0.15, completion: 0.15 },
|
||||
'llama-3.2-11b': { prompt: 0.16, completion: 0.16 },
|
||||
'llama-3.2-90b': { prompt: 0.72, completion: 0.72 },
|
||||
'llama-3.3-70b': { prompt: 2.65, completion: 3.5 },
|
||||
'mistral-7b': { prompt: 0.15, completion: 0.2 },
|
||||
'mistral-small': { prompt: 0.15, completion: 0.2 },
|
||||
'mixtral-8x7b': { prompt: 0.45, completion: 0.7 },
|
||||
'mistral-large-2402': { prompt: 4.0, completion: 12.0 },
|
||||
'mistral-large-2407': { prompt: 3.0, completion: 9.0 },
|
||||
'command-text': { prompt: 1.5, completion: 2.0 },
|
||||
'command-light': { prompt: 0.3, completion: 0.6 },
|
||||
'ai21.j2-mid-v1': { prompt: 12.5, completion: 12.5 },
|
||||
'ai21.j2-ultra-v1': { prompt: 18.8, completion: 18.8 },
|
||||
'ai21.jamba-instruct-v1:0': { prompt: 0.5, completion: 0.7 },
|
||||
'amazon.titan-text-lite-v1': { prompt: 0.15, completion: 0.2 },
|
||||
'amazon.titan-text-express-v1': { prompt: 0.2, completion: 0.6 },
|
||||
'amazon.titan-text-premier-v1:0': { prompt: 0.5, completion: 1.5 },
|
||||
'amazon.nova-micro-v1:0': { prompt: 0.035, completion: 0.14 },
|
||||
'amazon.nova-lite-v1:0': { prompt: 0.06, completion: 0.24 },
|
||||
'amazon.nova-pro-v1:0': { prompt: 0.8, completion: 3.2 },
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of model token sizes to their respective multipliers for prompt and completion.
|
||||
* The rates are 1 USD per 1M tokens.
|
||||
* @type {Object.<string, {prompt: number, completion: number}>}
|
||||
*/
|
||||
const tokenValues = {
|
||||
'8k': { prompt: 30, completion: 60 },
|
||||
'32k': { prompt: 60, completion: 120 },
|
||||
'4k': { prompt: 1.5, completion: 2 },
|
||||
'16k': { prompt: 3, completion: 4 },
|
||||
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
'claude-': { prompt: 0.8, completion: 2.4 },
|
||||
const tokenValues = Object.assign(
|
||||
{
|
||||
'8k': { prompt: 30, completion: 60 },
|
||||
'32k': { prompt: 60, completion: 120 },
|
||||
'4k': { prompt: 1.5, completion: 2 },
|
||||
'16k': { prompt: 3, completion: 4 },
|
||||
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
|
||||
'o3-mini': { prompt: 1.1, completion: 4.4 },
|
||||
'o1-mini': { prompt: 1.1, completion: 4.4 },
|
||||
'o1-preview': { prompt: 15, completion: 60 },
|
||||
o1: { prompt: 15, completion: 60 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-4o': { prompt: 2.5, completion: 10 },
|
||||
'gpt-4o-2024-05-13': { prompt: 5, completion: 15 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3.5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-5-haiku': { prompt: 0.8, completion: 4 },
|
||||
'claude-3.5-haiku': { prompt: 0.8, completion: 4 },
|
||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
'claude-instant': { prompt: 0.8, completion: 2.4 },
|
||||
'claude-': { prompt: 0.8, completion: 2.4 },
|
||||
'command-r-plus': { prompt: 3, completion: 15 },
|
||||
'command-r': { prompt: 0.5, completion: 1.5 },
|
||||
'deepseek-reasoner': { prompt: 0.55, completion: 2.19 },
|
||||
deepseek: { prompt: 0.14, completion: 0.28 },
|
||||
/* cohere doesn't have rates for the older command models,
|
||||
so this was from https://artificialanalysis.ai/models/command-light/providers */
|
||||
command: { prompt: 0.38, completion: 0.38 },
|
||||
'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 },
|
||||
'gemini-2.0-flash': { prompt: 0.1, completion: 0.7 },
|
||||
'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing
|
||||
'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
|
||||
'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
|
||||
'gemini-1.5': { prompt: 2.5, completion: 10 },
|
||||
'gemini-pro-vision': { prompt: 0.5, completion: 1.5 },
|
||||
gemini: { prompt: 0.5, completion: 1.5 },
|
||||
},
|
||||
bedrockValues,
|
||||
);
|
||||
|
||||
/**
|
||||
* Mapping of model token sizes to their respective multipliers for cached input, read and write.
|
||||
* See Anthropic's documentation on this: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#pricing
|
||||
* The rates are 1 USD per 1M tokens.
|
||||
* @type {Object.<string, {write: number, read: number }>}
|
||||
*/
|
||||
const cacheTokenValues = {
|
||||
'claude-3.5-sonnet': { write: 3.75, read: 0.3 },
|
||||
'claude-3-5-sonnet': { write: 3.75, read: 0.3 },
|
||||
'claude-3.5-haiku': { write: 1, read: 0.08 },
|
||||
'claude-3-5-haiku': { write: 1, read: 0.08 },
|
||||
'claude-3-haiku': { write: 0.3, read: 0.03 },
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -42,6 +149,20 @@ const getValueKey = (model, endpoint) => {
|
|||
return 'gpt-3.5-turbo-1106';
|
||||
} else if (modelName.includes('gpt-3.5')) {
|
||||
return '4k';
|
||||
} else if (modelName.includes('o1-preview')) {
|
||||
return 'o1-preview';
|
||||
} else if (modelName.includes('o1-mini')) {
|
||||
return 'o1-mini';
|
||||
} else if (modelName.includes('o1')) {
|
||||
return 'o1';
|
||||
} else if (modelName.includes('gpt-4o-2024-05-13')) {
|
||||
return 'gpt-4o-2024-05-13';
|
||||
} else if (modelName.includes('gpt-4o-mini')) {
|
||||
return 'gpt-4o-mini';
|
||||
} else if (modelName.includes('gpt-4o')) {
|
||||
return 'gpt-4o';
|
||||
} else if (modelName.includes('gpt-4-vision')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-1106')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-0125')) {
|
||||
|
|
@ -65,7 +186,7 @@ const getValueKey = (model, endpoint) => {
|
|||
*
|
||||
* @param {Object} params - The parameters for the function.
|
||||
* @param {string} [params.valueKey] - The key corresponding to the model name.
|
||||
* @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
|
||||
* @param {'prompt' | 'completion'} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
|
||||
* @param {string} [params.model] - The model name to derive the value key from if not provided.
|
||||
* @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
|
||||
* @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint.
|
||||
|
|
@ -90,7 +211,48 @@ const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConf
|
|||
}
|
||||
|
||||
// If we got this far, and values[tokenType] is undefined somehow, return a rough average of default multipliers
|
||||
return tokenValues[valueKey][tokenType] ?? defaultRate;
|
||||
return tokenValues[valueKey]?.[tokenType] ?? defaultRate;
|
||||
};
|
||||
|
||||
module.exports = { tokenValues, getValueKey, getMultiplier, defaultRate };
|
||||
/**
|
||||
* Retrieves the cache multiplier for a given value key and token type. If no value key is provided,
|
||||
* it attempts to derive it from the model name.
|
||||
*
|
||||
* @param {Object} params - The parameters for the function.
|
||||
* @param {string} [params.valueKey] - The key corresponding to the model name.
|
||||
* @param {'write' | 'read'} [params.cacheType] - The type of token (e.g., 'write' or 'read').
|
||||
* @param {string} [params.model] - The model name to derive the value key from if not provided.
|
||||
* @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided.
|
||||
* @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint.
|
||||
* @returns {number | null} The multiplier for the given parameters, or `null` if not found.
|
||||
*/
|
||||
const getCacheMultiplier = ({ valueKey, cacheType, model, endpoint, endpointTokenConfig }) => {
|
||||
if (endpointTokenConfig) {
|
||||
return endpointTokenConfig?.[model]?.[cacheType] ?? null;
|
||||
}
|
||||
|
||||
if (valueKey && cacheType) {
|
||||
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
|
||||
}
|
||||
|
||||
if (!cacheType || !model) {
|
||||
return null;
|
||||
}
|
||||
|
||||
valueKey = getValueKey(model, endpoint);
|
||||
if (!valueKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// If we got this far, and values[cacheType] is undefined somehow, return a rough average of default multipliers
|
||||
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
tokenValues,
|
||||
getValueKey,
|
||||
getMultiplier,
|
||||
getCacheMultiplier,
|
||||
defaultRate,
|
||||
cacheTokenValues,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,4 +1,12 @@
|
|||
const { getValueKey, getMultiplier, defaultRate, tokenValues } = require('./tx');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
defaultRate,
|
||||
tokenValues,
|
||||
getValueKey,
|
||||
getMultiplier,
|
||||
cacheTokenValues,
|
||||
getCacheMultiplier,
|
||||
} = require('./tx');
|
||||
|
||||
describe('getValueKey', () => {
|
||||
it('should return "16k" for model name containing "gpt-3.5-turbo-16k"', () => {
|
||||
|
|
@ -34,6 +42,71 @@ describe('getValueKey', () => {
|
|||
expect(getValueKey('openai/gpt-4-1106')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-1106/openai/')).toBe('gpt-4-1106');
|
||||
});
|
||||
|
||||
it('should return "gpt-4-1106" for model type of "gpt-4-1106"', () => {
|
||||
expect(getValueKey('gpt-4-vision-preview')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('openai/gpt-4-1106')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-turbo')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-0125')).toBe('gpt-4-1106');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o" for model type of "gpt-4o"', () => {
|
||||
expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o');
|
||||
expect(getValueKey('openai/gpt-4o')).toBe('gpt-4o');
|
||||
expect(getValueKey('openai/gpt-4o-2024-08-06')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o-mini" for model type of "gpt-4o-mini"', () => {
|
||||
expect(getValueKey('gpt-4o-mini-2024-07-18')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('openai/gpt-4o-mini')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('gpt-4o-mini-0718')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('gpt-4o-2024-08-06-0718')).not.toBe('gpt-4o-mini');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o-2024-05-13" for model type of "gpt-4o-2024-05-13"', () => {
|
||||
expect(getValueKey('gpt-4o-2024-05-13')).toBe('gpt-4o-2024-05-13');
|
||||
expect(getValueKey('openai/gpt-4o-2024-05-13')).toBe('gpt-4o-2024-05-13');
|
||||
expect(getValueKey('gpt-4o-2024-05-13-0718')).toBe('gpt-4o-2024-05-13');
|
||||
expect(getValueKey('gpt-4o-2024-05-13-0718')).not.toBe('gpt-4o');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o" for model type of "chatgpt-4o"', () => {
|
||||
expect(getValueKey('chatgpt-4o-latest')).toBe('gpt-4o');
|
||||
expect(getValueKey('openai/chatgpt-4o-latest')).toBe('gpt-4o');
|
||||
expect(getValueKey('chatgpt-4o-latest-0916')).toBe('gpt-4o');
|
||||
expect(getValueKey('chatgpt-4o-latest-0718')).toBe('gpt-4o');
|
||||
});
|
||||
|
||||
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
|
||||
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-turbo')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-0125')).toBe('claude-3-5-sonnet');
|
||||
});
|
||||
|
||||
it('should return "claude-3.5-sonnet" for model type of "claude-3.5-sonnet-"', () => {
|
||||
expect(getValueKey('claude-3.5-sonnet-20240620')).toBe('claude-3.5-sonnet');
|
||||
expect(getValueKey('anthropic/claude-3.5-sonnet')).toBe('claude-3.5-sonnet');
|
||||
expect(getValueKey('claude-3.5-sonnet-turbo')).toBe('claude-3.5-sonnet');
|
||||
expect(getValueKey('claude-3.5-sonnet-0125')).toBe('claude-3.5-sonnet');
|
||||
});
|
||||
|
||||
it('should return "claude-3-5-haiku" for model type of "claude-3-5-haiku-"', () => {
|
||||
expect(getValueKey('claude-3-5-haiku-20240620')).toBe('claude-3-5-haiku');
|
||||
expect(getValueKey('anthropic/claude-3-5-haiku')).toBe('claude-3-5-haiku');
|
||||
expect(getValueKey('claude-3-5-haiku-turbo')).toBe('claude-3-5-haiku');
|
||||
expect(getValueKey('claude-3-5-haiku-0125')).toBe('claude-3-5-haiku');
|
||||
});
|
||||
|
||||
it('should return "claude-3.5-haiku" for model type of "claude-3.5-haiku-"', () => {
|
||||
expect(getValueKey('claude-3.5-haiku-20240620')).toBe('claude-3.5-haiku');
|
||||
expect(getValueKey('anthropic/claude-3.5-haiku')).toBe('claude-3.5-haiku');
|
||||
expect(getValueKey('claude-3.5-haiku-turbo')).toBe('claude-3.5-haiku');
|
||||
expect(getValueKey('claude-3.5-haiku-0125')).toBe('claude-3.5-haiku');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMultiplier', () => {
|
||||
|
|
@ -77,6 +150,41 @@ describe('getMultiplier', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o', () => {
|
||||
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4o'].completion,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o-mini', () => {
|
||||
const valueKey = getValueKey('gpt-4o-mini-2024-07-18');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-4o-mini'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4o-mini'].completion,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for chatgpt-4o-latest', () => {
|
||||
const valueKey = getValueKey('chatgpt-4o-latest');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4o'].completion,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
|
||||
tokenValues['gpt-4o-mini'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should derive the valueKey from the model if not provided for new models', () => {
|
||||
expect(
|
||||
getMultiplier({ tokenType: 'prompt', model: 'gpt-3.5-turbo-1106-some-other-info' }),
|
||||
|
|
@ -101,3 +209,252 @@ describe('getMultiplier', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('AWS Bedrock Model Tests', () => {
|
||||
const awsModels = [
|
||||
'anthropic.claude-3-5-haiku-20241022-v1:0',
|
||||
'anthropic.claude-3-haiku-20240307-v1:0',
|
||||
'anthropic.claude-3-sonnet-20240229-v1:0',
|
||||
'anthropic.claude-3-opus-20240229-v1:0',
|
||||
'anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
'anthropic.claude-v2:1',
|
||||
'anthropic.claude-instant-v1',
|
||||
'meta.llama2-13b-chat-v1',
|
||||
'meta.llama2-70b-chat-v1',
|
||||
'meta.llama3-8b-instruct-v1:0',
|
||||
'meta.llama3-70b-instruct-v1:0',
|
||||
'meta.llama3-1-8b-instruct-v1:0',
|
||||
'meta.llama3-1-70b-instruct-v1:0',
|
||||
'meta.llama3-1-405b-instruct-v1:0',
|
||||
'mistral.mistral-7b-instruct-v0:2',
|
||||
'mistral.mistral-small-2402-v1:0',
|
||||
'mistral.mixtral-8x7b-instruct-v0:1',
|
||||
'mistral.mistral-large-2402-v1:0',
|
||||
'mistral.mistral-large-2407-v1:0',
|
||||
'cohere.command-text-v14',
|
||||
'cohere.command-light-text-v14',
|
||||
'cohere.command-r-v1:0',
|
||||
'cohere.command-r-plus-v1:0',
|
||||
'ai21.j2-mid-v1',
|
||||
'ai21.j2-ultra-v1',
|
||||
'amazon.titan-text-lite-v1',
|
||||
'amazon.titan-text-express-v1',
|
||||
'amazon.nova-micro-v1:0',
|
||||
'amazon.nova-lite-v1:0',
|
||||
'amazon.nova-pro-v1:0',
|
||||
];
|
||||
|
||||
it('should return the correct prompt multipliers for all models', () => {
|
||||
const results = awsModels.map((model) => {
|
||||
const valueKey = getValueKey(model, EModelEndpoint.bedrock);
|
||||
const multiplier = getMultiplier({ valueKey, tokenType: 'prompt' });
|
||||
return tokenValues[valueKey].prompt && multiplier === tokenValues[valueKey].prompt;
|
||||
});
|
||||
expect(results.every(Boolean)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return the correct completion multipliers for all models', () => {
|
||||
const results = awsModels.map((model) => {
|
||||
const valueKey = getValueKey(model, EModelEndpoint.bedrock);
|
||||
const multiplier = getMultiplier({ valueKey, tokenType: 'completion' });
|
||||
return tokenValues[valueKey].completion && multiplier === tokenValues[valueKey].completion;
|
||||
});
|
||||
expect(results.every(Boolean)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Deepseek Model Tests', () => {
|
||||
const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner'];
|
||||
|
||||
it('should return the correct prompt multipliers for all models', () => {
|
||||
const results = deepseekModels.map((model) => {
|
||||
const valueKey = getValueKey(model);
|
||||
const multiplier = getMultiplier({ valueKey, tokenType: 'prompt' });
|
||||
return tokenValues[valueKey].prompt && multiplier === tokenValues[valueKey].prompt;
|
||||
});
|
||||
expect(results.every(Boolean)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return the correct completion multipliers for all models', () => {
|
||||
const results = deepseekModels.map((model) => {
|
||||
const valueKey = getValueKey(model);
|
||||
const multiplier = getMultiplier({ valueKey, tokenType: 'completion' });
|
||||
return tokenValues[valueKey].completion && multiplier === tokenValues[valueKey].completion;
|
||||
});
|
||||
expect(results.every(Boolean)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return the correct prompt multipliers for reasoning model', () => {
|
||||
const model = 'deepseek-reasoner';
|
||||
const valueKey = getValueKey(model);
|
||||
expect(valueKey).toBe(model);
|
||||
const multiplier = getMultiplier({ valueKey, tokenType: 'prompt' });
|
||||
const result = tokenValues[valueKey].prompt && multiplier === tokenValues[valueKey].prompt;
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCacheMultiplier', () => {
|
||||
it('should return the correct cache multiplier for a given valueKey and cacheType', () => {
|
||||
expect(getCacheMultiplier({ valueKey: 'claude-3-5-sonnet', cacheType: 'write' })).toBe(
|
||||
cacheTokenValues['claude-3-5-sonnet'].write,
|
||||
);
|
||||
expect(getCacheMultiplier({ valueKey: 'claude-3-5-sonnet', cacheType: 'read' })).toBe(
|
||||
cacheTokenValues['claude-3-5-sonnet'].read,
|
||||
);
|
||||
expect(getCacheMultiplier({ valueKey: 'claude-3-5-haiku', cacheType: 'write' })).toBe(
|
||||
cacheTokenValues['claude-3-5-haiku'].write,
|
||||
);
|
||||
expect(getCacheMultiplier({ valueKey: 'claude-3-5-haiku', cacheType: 'read' })).toBe(
|
||||
cacheTokenValues['claude-3-5-haiku'].read,
|
||||
);
|
||||
expect(getCacheMultiplier({ valueKey: 'claude-3-haiku', cacheType: 'write' })).toBe(
|
||||
cacheTokenValues['claude-3-haiku'].write,
|
||||
);
|
||||
expect(getCacheMultiplier({ valueKey: 'claude-3-haiku', cacheType: 'read' })).toBe(
|
||||
cacheTokenValues['claude-3-haiku'].read,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null if cacheType is provided but not found in cacheTokenValues', () => {
|
||||
expect(
|
||||
getCacheMultiplier({ valueKey: 'claude-3-5-sonnet', cacheType: 'unknownType' }),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it('should derive the valueKey from the model if not provided', () => {
|
||||
expect(getCacheMultiplier({ cacheType: 'write', model: 'claude-3-5-sonnet-20240620' })).toBe(
|
||||
3.75,
|
||||
);
|
||||
expect(getCacheMultiplier({ cacheType: 'read', model: 'claude-3-haiku-20240307' })).toBe(0.03);
|
||||
});
|
||||
|
||||
it('should return null if only model or cacheType is missing', () => {
|
||||
expect(getCacheMultiplier({ cacheType: 'write' })).toBeNull();
|
||||
expect(getCacheMultiplier({ model: 'claude-3-5-sonnet' })).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null if derived valueKey does not match any known patterns', () => {
|
||||
expect(getCacheMultiplier({ cacheType: 'write', model: 'gpt-4-some-other-info' })).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle endpointTokenConfig if provided', () => {
|
||||
const endpointTokenConfig = {
|
||||
'custom-model': {
|
||||
write: 5,
|
||||
read: 1,
|
||||
},
|
||||
};
|
||||
expect(
|
||||
getCacheMultiplier({ model: 'custom-model', cacheType: 'write', endpointTokenConfig }),
|
||||
).toBe(5);
|
||||
expect(
|
||||
getCacheMultiplier({ model: 'custom-model', cacheType: 'read', endpointTokenConfig }),
|
||||
).toBe(1);
|
||||
});
|
||||
|
||||
it('should return null if model is not found in endpointTokenConfig', () => {
|
||||
const endpointTokenConfig = {
|
||||
'custom-model': {
|
||||
write: 5,
|
||||
read: 1,
|
||||
},
|
||||
};
|
||||
expect(
|
||||
getCacheMultiplier({ model: 'unknown-model', cacheType: 'write', endpointTokenConfig }),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle models with "bedrock/" prefix', () => {
|
||||
expect(
|
||||
getCacheMultiplier({
|
||||
model: 'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
cacheType: 'write',
|
||||
}),
|
||||
).toBe(3.75);
|
||||
expect(
|
||||
getCacheMultiplier({
|
||||
model: 'bedrock/anthropic.claude-3-haiku-20240307-v1:0',
|
||||
cacheType: 'read',
|
||||
}),
|
||||
).toBe(0.03);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Google Model Tests', () => {
|
||||
const googleModels = [
|
||||
'gemini-2.0-flash-lite-preview-02-05',
|
||||
'gemini-2.0-flash-001',
|
||||
'gemini-2.0-flash-exp',
|
||||
'gemini-2.0-pro-exp-02-05',
|
||||
'gemini-1.5-flash-8b',
|
||||
'gemini-1.5-flash-thinking',
|
||||
'gemini-1.5-pro-latest',
|
||||
'gemini-1.5-pro-preview-0409',
|
||||
'gemini-pro-vision',
|
||||
'gemini-1.0',
|
||||
'gemini-pro',
|
||||
];
|
||||
|
||||
it('should return the correct prompt and completion rates for all models', () => {
|
||||
const results = googleModels.map((model) => {
|
||||
const valueKey = getValueKey(model, EModelEndpoint.google);
|
||||
const promptRate = getMultiplier({
|
||||
model,
|
||||
tokenType: 'prompt',
|
||||
endpoint: EModelEndpoint.google,
|
||||
});
|
||||
const completionRate = getMultiplier({
|
||||
model,
|
||||
tokenType: 'completion',
|
||||
endpoint: EModelEndpoint.google,
|
||||
});
|
||||
return { model, valueKey, promptRate, completionRate };
|
||||
});
|
||||
|
||||
results.forEach(({ valueKey, promptRate, completionRate }) => {
|
||||
expect(promptRate).toBe(tokenValues[valueKey].prompt);
|
||||
expect(completionRate).toBe(tokenValues[valueKey].completion);
|
||||
});
|
||||
});
|
||||
|
||||
it('should map to the correct model keys', () => {
|
||||
const expected = {
|
||||
'gemini-2.0-flash-lite-preview-02-05': 'gemini-2.0-flash-lite',
|
||||
'gemini-2.0-flash-001': 'gemini-2.0-flash',
|
||||
'gemini-2.0-flash-exp': 'gemini-2.0-flash',
|
||||
'gemini-2.0-pro-exp-02-05': 'gemini-2.0',
|
||||
'gemini-1.5-flash-8b': 'gemini-1.5-flash-8b',
|
||||
'gemini-1.5-flash-thinking': 'gemini-1.5-flash',
|
||||
'gemini-1.5-pro-latest': 'gemini-1.5',
|
||||
'gemini-1.5-pro-preview-0409': 'gemini-1.5',
|
||||
'gemini-pro-vision': 'gemini-pro-vision',
|
||||
'gemini-1.0': 'gemini',
|
||||
'gemini-pro': 'gemini',
|
||||
};
|
||||
|
||||
Object.entries(expected).forEach(([model, expectedKey]) => {
|
||||
const valueKey = getValueKey(model, EModelEndpoint.google);
|
||||
expect(valueKey).toBe(expectedKey);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle model names with different formats', () => {
|
||||
const testCases = [
|
||||
{ input: 'google/gemini-pro', expected: 'gemini' },
|
||||
{ input: 'gemini-pro/google', expected: 'gemini' },
|
||||
{ input: 'google/gemini-2.0-flash-lite', expected: 'gemini-2.0-flash-lite' },
|
||||
];
|
||||
|
||||
testCases.forEach(({ input, expected }) => {
|
||||
const valueKey = getValueKey(input, EModelEndpoint.google);
|
||||
expect(valueKey).toBe(expected);
|
||||
expect(
|
||||
getMultiplier({ model: input, tokenType: 'prompt', endpoint: EModelEndpoint.google }),
|
||||
).toBe(tokenValues[expected].prompt);
|
||||
expect(
|
||||
getMultiplier({ model: input, tokenType: 'completion', endpoint: EModelEndpoint.google }),
|
||||
).toBe(tokenValues[expected].completion);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,28 +1,39 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const signPayload = require('~/server/services/signPayload');
|
||||
const { isEnabled } = require('~/server/utils/handleText');
|
||||
const Balance = require('./Balance');
|
||||
const User = require('./User');
|
||||
|
||||
const hashPassword = async (password) => {
|
||||
const hashedPassword = await new Promise((resolve, reject) => {
|
||||
bcrypt.hash(password, 10, function (err, hash) {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve(hash);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return hashedPassword;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a user by ID and convert the found user document to a plain object.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to find and return as a plain object.
|
||||
* @returns {Promise<Object>} A plain object representing the user document, or `null` if no user is found.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||
*/
|
||||
const getUser = async function (userId) {
|
||||
return await User.findById(userId).lean();
|
||||
const getUserById = async function (userId, fieldsToSelect = null) {
|
||||
const query = User.findById(userId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for a single user based on partial data and return matching user document as plain object.
|
||||
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||
*/
|
||||
const findUser = async function (searchCriteria, fieldsToSelect = null) {
|
||||
const query = User.findOne(searchCriteria);
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -30,17 +41,137 @@ const getUser = async function (userId) {
|
|||
*
|
||||
* @param {string} userId - The ID of the user to update.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<Object>} The updated user document as a plain object, or `null` if no user is found.
|
||||
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
|
||||
*/
|
||||
const updateUser = async function (userId, updateData) {
|
||||
return await User.findByIdAndUpdate(userId, updateData, {
|
||||
const updateOperation = {
|
||||
$set: updateData,
|
||||
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
|
||||
};
|
||||
return await User.findByIdAndUpdate(userId, updateOperation, {
|
||||
new: true,
|
||||
runValidators: true,
|
||||
}).lean();
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
hashPassword,
|
||||
updateUser,
|
||||
getUser,
|
||||
/**
|
||||
* Creates a new user, optionally with a TTL of 1 week.
|
||||
* @param {MongoUser} data - The user data to be created, must contain user_id.
|
||||
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
|
||||
* @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`.
|
||||
* @returns {Promise<ObjectId>} A promise that resolves to the created user document ID.
|
||||
* @throws {Error} If a user with the same user_id already exists.
|
||||
*/
|
||||
const createUser = async (data, disableTTL = true, returnUser = false) => {
|
||||
const userData = {
|
||||
...data,
|
||||
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
|
||||
};
|
||||
|
||||
if (disableTTL) {
|
||||
delete userData.expiresAt;
|
||||
}
|
||||
|
||||
const user = await User.create(userData);
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE) && process.env.START_BALANCE) {
|
||||
let incrementValue = parseInt(process.env.START_BALANCE);
|
||||
await Balance.findOneAndUpdate(
|
||||
{ user: user._id },
|
||||
{ $inc: { tokenCredits: incrementValue } },
|
||||
{ upsert: true, new: true },
|
||||
).lean();
|
||||
}
|
||||
|
||||
if (returnUser) {
|
||||
return user.toObject();
|
||||
}
|
||||
return user._id;
|
||||
};
|
||||
|
||||
/**
|
||||
* Count the number of user documents in the collection based on the provided filter.
|
||||
*
|
||||
* @param {Object} [filter={}] - The filter to apply when counting the documents.
|
||||
* @returns {Promise<number>} The count of documents that match the filter.
|
||||
*/
|
||||
const countUsers = async function (filter = {}) {
|
||||
return await User.countDocuments(filter);
|
||||
};
|
||||
|
||||
/**
|
||||
* Delete a user by their unique ID.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to delete.
|
||||
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
|
||||
*/
|
||||
const deleteUserById = async function (userId) {
|
||||
try {
|
||||
const result = await User.deleteOne({ _id: userId });
|
||||
if (result.deletedCount === 0) {
|
||||
return { deletedCount: 0, message: 'No user found with that ID.' };
|
||||
}
|
||||
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
|
||||
} catch (error) {
|
||||
throw new Error('Error deleting user: ' + error.message);
|
||||
}
|
||||
};
|
||||
|
||||
const { SESSION_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||
|
||||
/**
|
||||
* Generates a JWT token for a given user.
|
||||
*
|
||||
* @param {MongoUser} user - ID of the user for whom the token is being generated.
|
||||
* @returns {Promise<string>} A promise that resolves to a JWT token.
|
||||
*/
|
||||
const generateToken = async (user) => {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
return await signPayload({
|
||||
payload: {
|
||||
id: user._id,
|
||||
username: user.username,
|
||||
provider: user.provider,
|
||||
email: user.email,
|
||||
},
|
||||
secret: process.env.JWT_SECRET,
|
||||
expirationTime: expires / 1000,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Compares the provided password with the user's password.
|
||||
*
|
||||
* @param {MongoUser} user - the user to compare password for.
|
||||
* @param {string} candidatePassword - The password to test against the user's password.
|
||||
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
|
||||
*/
|
||||
const comparePassword = async (user, candidatePassword) => {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
|
||||
if (err) {
|
||||
reject(err);
|
||||
}
|
||||
resolve(isMatch);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
countUsers,
|
||||
createUser,
|
||||
updateUser,
|
||||
findUser,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue