mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-16 16:30:15 +01:00
🌿 feat: Fork Messages/Conversations (#2617)
* typedef for ImportBatchBuilder * feat: first pass, fork conversations * feat: fork - getMessagesUpToTargetLevel * fix: additional tests and fix getAllMessagesUpToParent * chore: arrow function return * refactor: fork 3 options * chore: remove unused genbuttons * chore: remove unused hover buttons code * feat: fork first pass * wip: fork remember setting * style: user icon * chore: move clear chats to data tab * WIP: fork UI options * feat: data-provider fork types/services/vars and use generic MutationOptions * refactor: use single param for fork option, use enum, fix mongo errors, use Date.now(), add records flag for testing, use endpoint from original convo and messages, pass originalConvo to finishConversation * feat: add fork mutation hook and consolidate type imports * refactor: use enum * feat: first pass, fork mutation * chore: add enum for target level fork option * chore: add enum for target level fork option * show toast when checking remember selection * feat: splitAtTarget * feat: split at target option * feat: navigate to new fork, show toasts, set result query data * feat: hover info for all fork options * refactor: add Messages settings tab * fix(Fork): remember text info * ci: test for single message and is target edge case * feat: additional tests for getAllMessagesUpToParent * ci: additional tests and cycle detection for getMessagesUpToTargetLevel * feat: circular dependency checks for getAllMessagesUpToParent * fix: getMessagesUpToTargetLevel circular dep. check * ci: more tests for getMessagesForConversation * style: hover text for checkbox fork items * refactor: add statefulness to conversation import
This commit is contained in:
parent
c8baceac76
commit
25fceb78b7
37 changed files with 1831 additions and 523 deletions
314
api/server/utils/import/fork.js
Normal file
314
api/server/utils/import/fork.js
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint, Constants, ForkOptions } = require('librechat-data-provider');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { getMessages } = require('~/models/Message');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {object} params - The parameters for the importer.
|
||||
* @param {string} params.originalConvoId - The ID of the conversation to fork.
|
||||
* @param {string} params.targetMessageId - The ID of the message to fork from.
|
||||
* @param {string} params.requestUserId - The ID of the user making the request.
|
||||
* @param {string} [params.newTitle] - Optional new title for the forked conversation uses old title if not provided
|
||||
* @param {string} [params.option=''] - Optional flag for fork option
|
||||
* @param {boolean} [params.records=false] - Optional flag for returning actual database records or resulting conversation and messages.
|
||||
* @param {boolean} [params.splitAtTarget=false] - Optional flag for splitting the messages at the target message level.
|
||||
* @param {string} [params.latestMessageId] - latestMessageId - Required if splitAtTarget is true.
|
||||
* @param {(userId: string) => ImportBatchBuilder} [params.builderFactory] - Optional factory function for creating an ImportBatchBuilder instance.
|
||||
* @returns {Promise<TForkConvoResponse>} The response after forking the conversation.
|
||||
*/
|
||||
async function forkConversation({
|
||||
originalConvoId,
|
||||
targetMessageId: targetId,
|
||||
requestUserId,
|
||||
newTitle,
|
||||
option = ForkOptions.TARGET_LEVEL,
|
||||
records = false,
|
||||
splitAtTarget = false,
|
||||
latestMessageId,
|
||||
builderFactory = createImportBatchBuilder,
|
||||
}) {
|
||||
try {
|
||||
const originalConvo = await getConvo(requestUserId, originalConvoId);
|
||||
let originalMessages = await getMessages({
|
||||
user: requestUserId,
|
||||
conversationId: originalConvoId,
|
||||
});
|
||||
|
||||
let targetMessageId = targetId;
|
||||
if (splitAtTarget && !latestMessageId) {
|
||||
throw new Error('Latest `messageId` is required for forking from target message.');
|
||||
} else if (splitAtTarget) {
|
||||
originalMessages = splitAtTargetLevel(originalMessages, targetId);
|
||||
targetMessageId = latestMessageId;
|
||||
}
|
||||
|
||||
const importBatchBuilder = builderFactory(requestUserId);
|
||||
importBatchBuilder.startConversation(originalConvo.endpoint ?? EModelEndpoint.openAI);
|
||||
|
||||
let messagesToClone = [];
|
||||
|
||||
if (option === ForkOptions.DIRECT_PATH) {
|
||||
// Direct path only
|
||||
messagesToClone = BaseClient.getMessagesForConversation({
|
||||
messages: originalMessages,
|
||||
parentMessageId: targetMessageId,
|
||||
});
|
||||
} else if (option === ForkOptions.INCLUDE_BRANCHES) {
|
||||
// Direct path and siblings
|
||||
messagesToClone = getAllMessagesUpToParent(originalMessages, targetMessageId);
|
||||
} else if (option === ForkOptions.TARGET_LEVEL || !option) {
|
||||
// Direct path, siblings, and all descendants
|
||||
messagesToClone = getMessagesUpToTargetLevel(originalMessages, targetMessageId);
|
||||
}
|
||||
|
||||
const idMapping = new Map();
|
||||
|
||||
for (const message of messagesToClone) {
|
||||
const newMessageId = uuidv4();
|
||||
idMapping.set(message.messageId, newMessageId);
|
||||
|
||||
const clonedMessage = {
|
||||
...message,
|
||||
messageId: newMessageId,
|
||||
parentMessageId:
|
||||
message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT
|
||||
? idMapping.get(message.parentMessageId)
|
||||
: Constants.NO_PARENT,
|
||||
};
|
||||
|
||||
importBatchBuilder.saveMessage(clonedMessage);
|
||||
}
|
||||
|
||||
const result = importBatchBuilder.finishConversation(
|
||||
newTitle || originalConvo.title,
|
||||
new Date(),
|
||||
originalConvo,
|
||||
);
|
||||
await importBatchBuilder.saveBatch();
|
||||
logger.debug(
|
||||
`user: ${requestUserId} | New conversation "${
|
||||
newTitle || originalConvo.title
|
||||
}" forked from conversation ID ${originalConvoId}`,
|
||||
);
|
||||
|
||||
if (!records) {
|
||||
return result;
|
||||
}
|
||||
|
||||
const conversation = await getConvo(requestUserId, result.conversation.conversationId);
|
||||
const messages = await getMessages({
|
||||
user: requestUserId,
|
||||
conversationId: conversation.conversationId,
|
||||
});
|
||||
|
||||
return {
|
||||
conversation,
|
||||
messages,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`user: ${requestUserId} | Error forking conversation from original ID ${originalConvoId}`,
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves all messages up to the root from the target message.
|
||||
* @param {TMessage[]} messages - The list of messages to search.
|
||||
* @param {string} targetMessageId - The ID of the target message.
|
||||
* @returns {TMessage[]} The list of messages up to the root from the target message.
|
||||
*/
|
||||
function getAllMessagesUpToParent(messages, targetMessageId) {
|
||||
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
|
||||
if (!targetMessage) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const pathToRoot = new Set();
|
||||
const visited = new Set();
|
||||
let current = targetMessage;
|
||||
|
||||
while (current) {
|
||||
if (visited.has(current.messageId)) {
|
||||
break;
|
||||
}
|
||||
|
||||
visited.add(current.messageId);
|
||||
pathToRoot.add(current.messageId);
|
||||
|
||||
const currentParentId = current.parentMessageId ?? Constants.NO_PARENT;
|
||||
if (currentParentId === Constants.NO_PARENT) {
|
||||
break;
|
||||
}
|
||||
|
||||
current = messages.find((msg) => msg.messageId === currentParentId);
|
||||
}
|
||||
|
||||
// Include all messages that are in the path or whose parent is in the path
|
||||
// Exclude children of the target message
|
||||
return messages.filter(
|
||||
(msg) =>
|
||||
(pathToRoot.has(msg.messageId) && msg.messageId !== targetMessageId) ||
|
||||
(pathToRoot.has(msg.parentMessageId) && msg.parentMessageId !== targetMessageId) ||
|
||||
msg.messageId === targetMessageId,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves all messages up to the root from the target message and its neighbors.
|
||||
* @param {TMessage[]} messages - The list of messages to search.
|
||||
* @param {string} targetMessageId - The ID of the target message.
|
||||
* @returns {TMessage[]} The list of inclusive messages up to the root from the target message.
|
||||
*/
|
||||
function getMessagesUpToTargetLevel(messages, targetMessageId) {
|
||||
if (messages.length === 1 && messages[0] && messages[0].messageId === targetMessageId) {
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Create a map of parentMessageId to children messages
|
||||
const parentToChildrenMap = new Map();
|
||||
for (const message of messages) {
|
||||
if (!parentToChildrenMap.has(message.parentMessageId)) {
|
||||
parentToChildrenMap.set(message.parentMessageId, []);
|
||||
}
|
||||
parentToChildrenMap.get(message.parentMessageId).push(message);
|
||||
}
|
||||
|
||||
// Retrieve the target message
|
||||
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
|
||||
if (!targetMessage) {
|
||||
logger.error('Target message not found.');
|
||||
return [];
|
||||
}
|
||||
|
||||
const visited = new Set();
|
||||
|
||||
const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || [];
|
||||
let currentLevel = rootMessages.length > 0 ? [...rootMessages] : [targetMessage];
|
||||
const results = new Set(currentLevel);
|
||||
|
||||
// Check if the target message is at the root level
|
||||
if (
|
||||
currentLevel.some((msg) => msg.messageId === targetMessageId) &&
|
||||
targetMessage.parentMessageId === Constants.NO_PARENT
|
||||
) {
|
||||
return Array.from(results);
|
||||
}
|
||||
|
||||
// Iterate level by level until the target is found
|
||||
let targetFound = false;
|
||||
while (!targetFound && currentLevel.length > 0) {
|
||||
const nextLevel = [];
|
||||
for (const node of currentLevel) {
|
||||
if (visited.has(node.messageId)) {
|
||||
logger.warn('Cycle detected in message tree');
|
||||
continue;
|
||||
}
|
||||
visited.add(node.messageId);
|
||||
const children = parentToChildrenMap.get(node.messageId) || [];
|
||||
for (const child of children) {
|
||||
if (visited.has(child.messageId)) {
|
||||
logger.warn('Cycle detected in message tree');
|
||||
continue;
|
||||
}
|
||||
nextLevel.push(child);
|
||||
results.add(child);
|
||||
if (child.messageId === targetMessageId) {
|
||||
targetFound = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
currentLevel = nextLevel;
|
||||
}
|
||||
|
||||
return Array.from(results);
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits the conversation at the targeted message level, including the target, its siblings, and all descendant messages.
|
||||
* All target level messages have their parentMessageId set to the root.
|
||||
* @param {TMessage[]} messages - The list of messages to analyze.
|
||||
* @param {string} targetMessageId - The ID of the message to start the split from.
|
||||
* @returns {TMessage[]} The list of messages at and below the target level.
|
||||
*/
|
||||
function splitAtTargetLevel(messages, targetMessageId) {
|
||||
// Create a map of parentMessageId to children messages
|
||||
const parentToChildrenMap = new Map();
|
||||
for (const message of messages) {
|
||||
if (!parentToChildrenMap.has(message.parentMessageId)) {
|
||||
parentToChildrenMap.set(message.parentMessageId, []);
|
||||
}
|
||||
parentToChildrenMap.get(message.parentMessageId).push(message);
|
||||
}
|
||||
|
||||
// Retrieve the target message
|
||||
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
|
||||
if (!targetMessage) {
|
||||
logger.error('Target message not found.');
|
||||
return [];
|
||||
}
|
||||
|
||||
// Initialize the search with root messages
|
||||
const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || [];
|
||||
let currentLevel = [...rootMessages];
|
||||
let currentLevelIndex = 0;
|
||||
const levelMap = {};
|
||||
|
||||
// Map messages to their levels
|
||||
rootMessages.forEach((msg) => {
|
||||
levelMap[msg.messageId] = 0;
|
||||
});
|
||||
|
||||
// Search for the target level
|
||||
while (currentLevel.length > 0) {
|
||||
const nextLevel = [];
|
||||
for (const node of currentLevel) {
|
||||
const children = parentToChildrenMap.get(node.messageId) || [];
|
||||
for (const child of children) {
|
||||
nextLevel.push(child);
|
||||
levelMap[child.messageId] = currentLevelIndex + 1;
|
||||
}
|
||||
}
|
||||
currentLevel = nextLevel;
|
||||
currentLevelIndex++;
|
||||
}
|
||||
|
||||
// Determine the target level
|
||||
const targetLevel = levelMap[targetMessageId];
|
||||
if (targetLevel === undefined) {
|
||||
logger.error('Target level not found.');
|
||||
return [];
|
||||
}
|
||||
|
||||
// Filter messages at or below the target level
|
||||
const filteredMessages = messages
|
||||
.map((msg) => {
|
||||
const messageLevel = levelMap[msg.messageId];
|
||||
if (messageLevel < targetLevel) {
|
||||
return null;
|
||||
} else if (messageLevel === targetLevel) {
|
||||
return {
|
||||
...msg,
|
||||
parentMessageId: Constants.NO_PARENT,
|
||||
};
|
||||
}
|
||||
|
||||
return msg;
|
||||
})
|
||||
.filter((msg) => msg !== null);
|
||||
|
||||
return filteredMessages;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
forkConversation,
|
||||
splitAtTargetLevel,
|
||||
getAllMessagesUpToParent,
|
||||
getMessagesUpToTargetLevel,
|
||||
};
|
||||
Loading…
Add table
Add a link
Reference in a new issue