mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-30 23:28:52 +01:00
Merge branch 'dev' into feat/context-window-ui
This commit is contained in:
commit
cb8322ca85
407 changed files with 25479 additions and 19894 deletions
|
|
@ -350,9 +350,6 @@ function disposeClient(client) {
|
|||
if (client.agentConfigs) {
|
||||
client.agentConfigs = null;
|
||||
}
|
||||
if (client.agentIdMap) {
|
||||
client.agentIdMap = null;
|
||||
}
|
||||
if (client.artifactPromises) {
|
||||
client.artifactPromises = null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ const {
|
|||
setAuthTokens,
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const {
|
||||
deleteAllUserSessions,
|
||||
getUserById,
|
||||
findSession,
|
||||
updateUser,
|
||||
findUser,
|
||||
} = require('~/models');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
|
@ -72,16 +78,38 @@ const refreshController = async (req, res) => {
|
|||
const openIdConfig = getOpenIdConfig();
|
||||
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
|
||||
const claims = tokenset.claims();
|
||||
const { user, error } = await findOpenIDUser({
|
||||
const { user, error, migration } = await findOpenIDUser({
|
||||
findUser,
|
||||
email: claims.email,
|
||||
openidId: claims.sub,
|
||||
idOnTheSource: claims.oid,
|
||||
strategyName: 'refreshController',
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`[refreshController] findOpenIDUser result: user=${user?.email ?? 'null'}, error=${error ?? 'null'}, migration=${migration}, userOpenidId=${user?.openidId ?? 'null'}, claimsSub=${claims.sub}`,
|
||||
);
|
||||
|
||||
if (error || !user) {
|
||||
logger.warn(
|
||||
`[refreshController] Redirecting to /login: error=${error ?? 'null'}, user=${user ? 'exists' : 'null'}`,
|
||||
);
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
|
||||
// Handle migration: update user with openidId if found by email without openidId
|
||||
// Also handle case where user has mismatched openidId (e.g., after database switch)
|
||||
if (migration || user.openidId !== claims.sub) {
|
||||
const reason = migration ? 'migration' : 'openidId mismatch';
|
||||
await updateUser(user._id.toString(), {
|
||||
provider: 'openid',
|
||||
openidId: claims.sub,
|
||||
});
|
||||
logger.info(
|
||||
`[refreshController] Updated user ${user.email} openidId (${reason}): ${user.openidId ?? 'null'} -> ${claims.sub}`,
|
||||
);
|
||||
}
|
||||
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString(), refreshToken);
|
||||
|
||||
user.federatedTokens = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { sendEvent, GenerationJobManager } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
|
|
@ -144,17 +144,38 @@ function checkIfLastAgent(last_agent_id, langgraph_node) {
|
|||
return langgraph_node?.endsWith(last_agent_id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to emit events either to res (standard mode) or to job emitter (resumable mode).
|
||||
* @param {ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} eventData - The event data to send
|
||||
*/
|
||||
function emitEvent(res, streamId, eventData) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default handlers for stream events.
|
||||
* @param {Object} options - The options object.
|
||||
* @param {ServerResponse} options.res - The options object.
|
||||
* @param {ContentAggregator} options.aggregateContent - The options object.
|
||||
* @param {ServerResponse} options.res - The server response object.
|
||||
* @param {ContentAggregator} options.aggregateContent - Content aggregator function.
|
||||
* @param {ToolEndCallback} options.toolEndCallback - Callback to use when tool ends.
|
||||
* @param {Array<UsageMetadata>} options.collectedUsage - The list of collected usage metadata.
|
||||
* @param {string | null} [options.streamId] - The stream ID for resumable mode, or null for standard mode.
|
||||
* @returns {Record<string, t.EventHandler>} The default handlers.
|
||||
* @throws {Error} If the request is not found.
|
||||
*/
|
||||
function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedUsage }) {
|
||||
function getDefaultHandlers({
|
||||
res,
|
||||
aggregateContent,
|
||||
toolEndCallback,
|
||||
collectedUsage,
|
||||
streamId = null,
|
||||
}) {
|
||||
if (!res || !aggregateContent) {
|
||||
throw new Error(
|
||||
`[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`,
|
||||
|
|
@ -173,16 +194,16 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.stepDetails.type === StepTypes.TOOL_CALLS) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else {
|
||||
const agentName = metadata?.name ?? 'Agent';
|
||||
const isToolCall = data?.stepDetails.type === StepTypes.TOOL_CALLS;
|
||||
const action = isToolCall ? 'performing a task...' : 'thinking...';
|
||||
sendEvent(res, {
|
||||
emitEvent(res, streamId, {
|
||||
event: 'on_agent_update',
|
||||
data: {
|
||||
runId: metadata?.run_id,
|
||||
|
|
@ -202,11 +223,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.delta.type === StepTypes.TOOL_CALLS) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -220,11 +241,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (data?.result != null) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -238,9 +259,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -254,9 +275,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
*/
|
||||
handle: (event, data, metadata) => {
|
||||
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
} else if (!metadata?.hide_sequential_outputs) {
|
||||
sendEvent(res, { event, data });
|
||||
emitEvent(res, streamId, { event, data });
|
||||
}
|
||||
aggregateContent({ event, data });
|
||||
},
|
||||
|
|
@ -266,15 +287,30 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
|
|||
return handlers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to write attachment events either to res or to job emitter.
|
||||
* @param {ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} attachment - The attachment data
|
||||
*/
|
||||
function writeAttachment(res, streamId, attachment) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
|
||||
} else {
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {ServerResponse} params.res
|
||||
* @param {Promise<MongoFile | { filename: string; filepath: string; expires: number;} | null>[]} params.artifactPromises
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode, or null for standard mode.
|
||||
* @returns {ToolEndCallback} The tool end callback.
|
||||
*/
|
||||
function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
function createToolEndCallback({ req, res, artifactPromises, streamId = null }) {
|
||||
/**
|
||||
* @type {ToolEndCallback}
|
||||
*/
|
||||
|
|
@ -302,10 +338,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
if (!attachment) {
|
||||
return null;
|
||||
}
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing file citations:', error);
|
||||
|
|
@ -314,8 +350,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
);
|
||||
}
|
||||
|
||||
// TODO: a lot of duplicated code in createToolEndCallback
|
||||
// we should refactor this to use a helper function in a follow-up PR
|
||||
if (output.artifact[Tools.ui_resources]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
|
|
@ -326,10 +360,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -348,10 +382,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
[Tools.web_search]: { ...output.artifact[Tools.web_search] },
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -388,7 +422,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
toolCallId: output.tool_call_id,
|
||||
conversationId: metadata.thread_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
|
|
@ -396,7 +430,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
writeAttachment(res, streamId, fileMetadata);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
|
|
@ -435,7 +469,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
conversationId: metadata.thread_id,
|
||||
session_id: output.artifact.session_id,
|
||||
});
|
||||
if (!res.headersSent) {
|
||||
if (!streamId && !res.headersSent) {
|
||||
return fileMetadata;
|
||||
}
|
||||
|
||||
|
|
@ -443,7 +477,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
|||
return null;
|
||||
}
|
||||
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
|
||||
writeAttachment(res, streamId, fileMetadata);
|
||||
return fileMetadata;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing code output:', error);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const {
|
|||
getBalanceConfig,
|
||||
getProviderConfig,
|
||||
memoryInstructions,
|
||||
GenerationJobManager,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
filterMalformedContentParts,
|
||||
|
|
@ -36,14 +37,13 @@ const {
|
|||
EModelEndpoint,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
isEphemeralAgentId,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
|
@ -95,59 +95,101 @@ function logToolError(graph, error, toolId) {
|
|||
});
|
||||
}
|
||||
|
||||
/** Regex pattern to match agent ID suffix (____N) */
|
||||
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
|
||||
|
||||
/**
|
||||
* Applies agent labeling to conversation history when multi-agent patterns are detected.
|
||||
* Labels content parts by their originating agent to prevent identity confusion.
|
||||
* Creates a mapMethod for getMessagesForConversation that processes agent content.
|
||||
* - Strips agentId/groupId metadata from all content
|
||||
* - For multi-agent: filters to primary agent content only (no suffix or lowest suffix)
|
||||
* - For multi-agent: applies agent labels to content
|
||||
*
|
||||
* @param {TMessage[]} orderedMessages - The ordered conversation messages
|
||||
* @param {Agent} primaryAgent - The primary agent configuration
|
||||
* @param {Map<string, Agent>} agentConfigs - Map of additional agent configurations
|
||||
* @returns {TMessage[]} Messages with agent labels applied where appropriate
|
||||
* @param {Agent} primaryAgent - Primary agent configuration
|
||||
* @param {Map<string, Agent>} [agentConfigs] - Additional agent configurations
|
||||
* @returns {(message: TMessage) => TMessage} Map method for processing messages
|
||||
*/
|
||||
function applyAgentLabelsToHistory(orderedMessages, primaryAgent, agentConfigs) {
|
||||
const shouldLabelByAgent = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
|
||||
if (!shouldLabelByAgent) {
|
||||
return orderedMessages;
|
||||
}
|
||||
|
||||
const processedMessages = [];
|
||||
|
||||
for (let i = 0; i < orderedMessages.length; i++) {
|
||||
const message = orderedMessages[i];
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
const agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
function createMultiAgentMapper(primaryAgent, agentConfigs) {
|
||||
const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
|
||||
/** @type {Record<string, string> | null} */
|
||||
let agentNames = null;
|
||||
if (hasMultipleAgents) {
|
||||
agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
if (agentConfigs) {
|
||||
for (const [agentId, agentConfig] of agentConfigs.entries()) {
|
||||
agentNames[agentId] = agentConfig.name || agentConfig.id;
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
!message.isCreatedByUser &&
|
||||
message.metadata?.agentIdMap &&
|
||||
Array.isArray(message.content)
|
||||
) {
|
||||
try {
|
||||
const labeledContent = labelContentByAgent(
|
||||
message.content,
|
||||
message.metadata.agentIdMap,
|
||||
agentNames,
|
||||
);
|
||||
|
||||
processedMessages.push({ ...message, content: labeledContent });
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error applying agent labels to message:', error);
|
||||
processedMessages.push(message);
|
||||
}
|
||||
} else {
|
||||
processedMessages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
return processedMessages;
|
||||
return (message) => {
|
||||
if (message.isCreatedByUser || !Array.isArray(message.content)) {
|
||||
return message;
|
||||
}
|
||||
|
||||
// Find primary agent ID (no suffix, or lowest suffix number) - only needed for multi-agent
|
||||
let primaryAgentId = null;
|
||||
let hasAgentMetadata = false;
|
||||
|
||||
if (hasMultipleAgents) {
|
||||
let lowestSuffixIndex = Infinity;
|
||||
for (const part of message.content) {
|
||||
const agentId = part?.agentId;
|
||||
if (!agentId) {
|
||||
continue;
|
||||
}
|
||||
hasAgentMetadata = true;
|
||||
|
||||
const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN);
|
||||
if (!suffixMatch) {
|
||||
primaryAgentId = agentId;
|
||||
break;
|
||||
}
|
||||
const suffixIndex = parseInt(suffixMatch[1], 10);
|
||||
if (suffixIndex < lowestSuffixIndex) {
|
||||
lowestSuffixIndex = suffixIndex;
|
||||
primaryAgentId = agentId;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single agent: just check if any metadata exists
|
||||
hasAgentMetadata = message.content.some((part) => part?.agentId || part?.groupId);
|
||||
}
|
||||
|
||||
if (!hasAgentMetadata) {
|
||||
return message;
|
||||
}
|
||||
|
||||
try {
|
||||
/** @type {Array<TMessageContentParts>} */
|
||||
const filteredContent = [];
|
||||
/** @type {Record<number, string>} */
|
||||
const agentIdMap = {};
|
||||
|
||||
for (const part of message.content) {
|
||||
const agentId = part?.agentId;
|
||||
// For single agent: include all parts; for multi-agent: filter to primary
|
||||
if (!hasMultipleAgents || !agentId || agentId === primaryAgentId) {
|
||||
const newIndex = filteredContent.length;
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = part;
|
||||
filteredContent.push(cleanPart);
|
||||
if (agentId && hasMultipleAgents) {
|
||||
agentIdMap[newIndex] = agentId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const finalContent =
|
||||
Object.keys(agentIdMap).length > 0 && agentNames
|
||||
? labelContentByAgent(filteredContent, agentIdMap, agentNames)
|
||||
: filteredContent;
|
||||
|
||||
return { ...message, content: finalContent };
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error processing multi-agent message:', error);
|
||||
return message;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
|
|
@ -199,8 +241,6 @@ class AgentClient extends BaseClient {
|
|||
this.indexTokenCountMap = {};
|
||||
/** @type {(messages: BaseMessage[]) => Promise<void>} */
|
||||
this.processMemory;
|
||||
/** @type {Record<number, string> | null} */
|
||||
this.agentIdMap = null;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -289,18 +329,13 @@ class AgentClient extends BaseClient {
|
|||
{ instructions = null, additional_instructions = null },
|
||||
opts,
|
||||
) {
|
||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||
const orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
summary: this.shouldSummarize,
|
||||
mapMethod: createMultiAgentMapper(this.options.agent, this.agentConfigs),
|
||||
});
|
||||
|
||||
orderedMessages = applyAgentLabelsToHistory(
|
||||
orderedMessages,
|
||||
this.options.agent,
|
||||
this.agentConfigs,
|
||||
);
|
||||
|
||||
let payload;
|
||||
/** @type {number | undefined} */
|
||||
let promptTokens;
|
||||
|
|
@ -552,10 +587,9 @@ class AgentClient extends BaseClient {
|
|||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
endpointOption: {
|
||||
endpoint:
|
||||
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
endpoint: !isEphemeralAgentId(prelimAgent.id)
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -595,10 +629,12 @@ class AgentClient extends BaseClient {
|
|||
const userId = this.options.req.user.id + '';
|
||||
const messageId = this.responseMessageId + '';
|
||||
const conversationId = this.conversationId + '';
|
||||
const streamId = this.options.req?._resumableStreamId || null;
|
||||
const [withoutKeys, processMemory] = await createMemoryProcessor({
|
||||
userId,
|
||||
config,
|
||||
messageId,
|
||||
streamId,
|
||||
conversationId,
|
||||
memoryMethods: {
|
||||
setMemory: db.setMemory,
|
||||
|
|
@ -692,9 +728,7 @@ class AgentClient extends BaseClient {
|
|||
});
|
||||
|
||||
const completion = filterMalformedContentParts(this.contentParts);
|
||||
const metadata = this.agentIdMap ? { agentIdMap: this.agentIdMap } : undefined;
|
||||
|
||||
return { completion, metadata };
|
||||
return { completion };
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -890,12 +924,10 @@ class AgentClient extends BaseClient {
|
|||
*/
|
||||
const runAgents = async (messages) => {
|
||||
const agents = [this.options.agent];
|
||||
if (
|
||||
this.agentConfigs &&
|
||||
this.agentConfigs.size > 0 &&
|
||||
((this.options.agent.edges?.length ?? 0) > 0 ||
|
||||
(await checkCapability(this.options.req, AgentCapabilities.chain)))
|
||||
) {
|
||||
// Include additional agents when:
|
||||
// - agentConfigs has agents (from addedConvo parallel execution or agent handoffs)
|
||||
// - Agents without incoming edges become start nodes and run in parallel automatically
|
||||
if (this.agentConfigs && this.agentConfigs.size > 0) {
|
||||
agents.push(...this.agentConfigs.values());
|
||||
}
|
||||
|
||||
|
|
@ -955,6 +987,12 @@ class AgentClient extends BaseClient {
|
|||
}
|
||||
|
||||
this.run = run;
|
||||
|
||||
const streamId = this.options.req?._resumableStreamId;
|
||||
if (streamId && run.Graph) {
|
||||
GenerationJobManager.setGraph(streamId, run.Graph);
|
||||
}
|
||||
|
||||
if (userMCPAuthMap != null) {
|
||||
config.configurable.userMCPAuthMap = userMCPAuthMap;
|
||||
}
|
||||
|
|
@ -985,24 +1023,6 @@ class AgentClient extends BaseClient {
|
|||
);
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
/** Capture agent ID map if we have edges or multiple agents */
|
||||
const shouldStoreAgentMap =
|
||||
(this.options.agent.edges?.length ?? 0) > 0 || (this.agentConfigs?.size ?? 0) > 0;
|
||||
if (shouldStoreAgentMap && run?.Graph) {
|
||||
const contentPartAgentMap = run.Graph.getContentPartAgentMap();
|
||||
if (contentPartAgentMap && contentPartAgentMap.size > 0) {
|
||||
this.agentIdMap = Object.fromEntries(contentPartAgentMap);
|
||||
logger.debug('[AgentClient] Captured agent ID map:', {
|
||||
totalParts: this.contentParts.length,
|
||||
mappedParts: Object.keys(this.agentIdMap).length,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error capturing agent ID map:', error);
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
|
|
|
|||
|
|
@ -2,14 +2,11 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
sendEvent,
|
||||
GenerationJobManager,
|
||||
sanitizeFileForTransmit,
|
||||
sanitizeMessageForTransmit,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
|
|
@ -31,12 +28,16 @@ function createCloseHandler(abortController) {
|
|||
};
|
||||
}
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
/**
|
||||
* Resumable Agent Controller - Generation runs independently of HTTP connection.
|
||||
* Returns streamId immediately, client subscribes separately via SSE.
|
||||
*/
|
||||
const ResumableAgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
conversationId: reqConversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
|
|
@ -44,18 +45,354 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let sender;
|
||||
let abortKey;
|
||||
const userId = req.user.id;
|
||||
|
||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||
const conversationId =
|
||||
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
|
||||
const streamId = conversationId;
|
||||
|
||||
let client = null;
|
||||
|
||||
try {
|
||||
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
|
||||
req._resumableStreamId = streamId;
|
||||
|
||||
// Send JSON response IMMEDIATELY so client can connect to SSE stream
|
||||
// This is critical: tool loading (MCP OAuth) may emit events that the client needs to receive
|
||||
res.json({ streamId, conversationId, status: 'started' });
|
||||
|
||||
// Note: We no longer use res.on('close') to abort since we send JSON immediately.
|
||||
// The response closes normally after res.json(), which is not an abort condition.
|
||||
// Abort handling is done through GenerationJobManager via the SSE stream connection.
|
||||
|
||||
// Track if partial response was already saved to avoid duplicates
|
||||
let partialResponseSaved = false;
|
||||
|
||||
/**
|
||||
* Listen for all subscribers leaving to save partial response.
|
||||
* This ensures the response is saved to DB even if all clients disconnect
|
||||
* while generation continues.
|
||||
*
|
||||
* Note: The messageId used here falls back to `${userMessage.messageId}_` if the
|
||||
* actual response messageId isn't available yet. The final response save will
|
||||
* overwrite this with the complete response using the same messageId pattern.
|
||||
*/
|
||||
job.emitter.on('allSubscribersLeft', async (aggregatedContent) => {
|
||||
if (partialResponseSaved || !aggregatedContent || aggregatedContent.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
||||
if (!resumeState?.userMessage) {
|
||||
logger.debug('[ResumableAgentController] No user message to save partial response for');
|
||||
return;
|
||||
}
|
||||
|
||||
partialResponseSaved = true;
|
||||
const responseConversationId = resumeState.conversationId || conversationId;
|
||||
|
||||
try {
|
||||
const partialMessage = {
|
||||
messageId: resumeState.responseMessageId || `${resumeState.userMessage.messageId}_`,
|
||||
conversationId: responseConversationId,
|
||||
parentMessageId: resumeState.userMessage.messageId,
|
||||
sender: client?.sender ?? 'AI',
|
||||
content: aggregatedContent,
|
||||
unfinished: true,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
user: userId,
|
||||
endpoint: endpointOption.endpoint,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
if (req.body?.agent_id) {
|
||||
partialMessage.agent_id = req.body.agent_id;
|
||||
}
|
||||
|
||||
await saveMessage(req, partialMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - partial response on disconnect',
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`[ResumableAgentController] Saved partial response for ${streamId}, content parts: ${aggregatedContent.length}`,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('[ResumableAgentController] Error saving partial response:', error);
|
||||
// Reset flag so we can try again if subscribers reconnect and leave again
|
||||
partialResponseSaved = false;
|
||||
}
|
||||
});
|
||||
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
// Use the job's abort controller signal - allows abort via GenerationJobManager.abortJob()
|
||||
signal: job.abortController.signal,
|
||||
});
|
||||
|
||||
if (job.abortController.signal.aborted) {
|
||||
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
|
||||
return;
|
||||
}
|
||||
|
||||
client = result.client;
|
||||
|
||||
if (client?.sender) {
|
||||
GenerationJobManager.updateMetadata(streamId, { sender: client.sender });
|
||||
}
|
||||
|
||||
// Store reference to client's contentParts - graph will be set when run is created
|
||||
if (client?.contentParts) {
|
||||
GenerationJobManager.setContentParts(streamId, client.contentParts);
|
||||
}
|
||||
|
||||
let userMessage;
|
||||
|
||||
const getReqData = (data = {}) => {
|
||||
if (data.userMessage) {
|
||||
userMessage = data.userMessage;
|
||||
}
|
||||
// conversationId is pre-generated, no need to update from callback
|
||||
};
|
||||
|
||||
// Start background generation - readyPromise resolves immediately now
|
||||
// (sync mechanism handles late subscribers)
|
||||
const startGeneration = async () => {
|
||||
try {
|
||||
// Short timeout as safety net - promise should already be resolved
|
||||
await Promise.race([job.readyPromise, new Promise((resolve) => setTimeout(resolve, 100))]);
|
||||
} catch (waitError) {
|
||||
logger.warn(
|
||||
`[ResumableAgentController] Error waiting for subscriber: ${waitError.message}`,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const onStart = (userMsg, respMsgId, _isNewConvo) => {
|
||||
userMessage = userMsg;
|
||||
|
||||
// Store userMessage and responseMessageId upfront for resume capability
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
responseMessageId: respMsgId,
|
||||
userMessage: {
|
||||
messageId: userMsg.messageId,
|
||||
parentMessageId: userMsg.parentMessageId,
|
||||
conversationId: userMsg.conversationId,
|
||||
text: userMsg.text,
|
||||
},
|
||||
});
|
||||
|
||||
GenerationJobManager.emitChunk(streamId, {
|
||||
created: true,
|
||||
message: userMessage,
|
||||
streamId,
|
||||
});
|
||||
};
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
onStart,
|
||||
getReqData,
|
||||
isContinued,
|
||||
isRegenerate,
|
||||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController: job.abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res: {
|
||||
write: () => true,
|
||||
end: () => {},
|
||||
headersSent: false,
|
||||
writableEnded: false,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const response = await client.sendMessage(text, messageOptions);
|
||||
|
||||
const messageId = response.messageId;
|
||||
const endpoint = endpointOption.endpoint;
|
||||
response.endpoint = endpoint;
|
||||
|
||||
const databasePromise = response.databasePromise;
|
||||
delete response.databasePromise;
|
||||
|
||||
const { conversation: convoData = {} } = await databasePromise;
|
||||
const conversation = { ...convoData };
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
if (req.body.files && client.options?.attachments) {
|
||||
userMessage.files = [];
|
||||
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
|
||||
for (const attachment of client.options.attachments) {
|
||||
if (messageFiles.has(attachment.file_id)) {
|
||||
userMessage.files.push(sanitizeFileForTransmit(attachment));
|
||||
}
|
||||
}
|
||||
delete userMessage.image_urls;
|
||||
}
|
||||
|
||||
// Check abort state BEFORE calling completeJob (which triggers abort signal for cleanup)
|
||||
const wasAbortedBeforeComplete = job.abortController.signal.aborted;
|
||||
const isNewConvo = !reqConversationId || reqConversationId === 'new';
|
||||
const shouldGenerateTitle =
|
||||
addTitle &&
|
||||
parentMessageId === Constants.NO_PARENT &&
|
||||
isNewConvo &&
|
||||
!wasAbortedBeforeComplete;
|
||||
|
||||
if (!wasAbortedBeforeComplete) {
|
||||
const finalEvent = {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: { ...response },
|
||||
};
|
||||
|
||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||
GenerationJobManager.completeJob(streamId);
|
||||
|
||||
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user: userId },
|
||||
{ context: 'api/server/controllers/agents/request.js - resumable response end' },
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const finalEvent = {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: { ...response, error: true },
|
||||
error: { message: 'Request was aborted' },
|
||||
};
|
||||
GenerationJobManager.emitDone(streamId, finalEvent);
|
||||
GenerationJobManager.completeJob(streamId, 'Request aborted');
|
||||
}
|
||||
|
||||
if (!client.skipSaveUserMessage && userMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
context: 'api/server/controllers/agents/request.js - resumable user message',
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldGenerateTitle) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
client,
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[ResumableAgentController] Error in title generation', err);
|
||||
})
|
||||
.finally(() => {
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Check if this was an abort (not a real error)
|
||||
const wasAborted = job.abortController.signal.aborted || error.message?.includes('abort');
|
||||
|
||||
if (wasAborted) {
|
||||
logger.debug(`[ResumableAgentController] Generation aborted for ${streamId}`);
|
||||
// abortJob already handled emitDone and completeJob
|
||||
} else {
|
||||
logger.error(`[ResumableAgentController] Generation error for ${streamId}:`, error);
|
||||
GenerationJobManager.emitError(streamId, error.message || 'Generation failed');
|
||||
GenerationJobManager.completeJob(streamId, error.message);
|
||||
}
|
||||
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
|
||||
// Don't continue to title generation after error/abort
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Start generation and handle any unhandled errors
|
||||
startGeneration().catch((err) => {
|
||||
logger.error(
|
||||
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
|
||||
);
|
||||
GenerationJobManager.completeJob(streamId, err.message);
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[ResumableAgentController] Initialization error:', error);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({ error: error.message || 'Failed to start generation' });
|
||||
} else {
|
||||
// JSON already sent, emit error to stream so client can receive it
|
||||
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
|
||||
}
|
||||
GenerationJobManager.completeJob(streamId, error.message);
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Agent Controller - Routes to ResumableAgentController for all requests.
|
||||
* The legacy non-resumable path is kept below but no longer used by default.
|
||||
*/
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
return ResumableAgentController(req, res, next, initializeClient, addTitle);
|
||||
};
|
||||
|
||||
/**
|
||||
* Legacy Non-resumable Agent Controller - Uses GenerationJobManager for abort handling.
|
||||
* Response is streamed directly to client via res, but abort state is managed centrally.
|
||||
* @deprecated Use ResumableAgentController instead
|
||||
*/
|
||||
const _LegacyAgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId: reqConversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
// Generate conversationId upfront if not provided - streamId === conversationId always
|
||||
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
|
||||
const conversationId =
|
||||
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
|
||||
const streamId = conversationId;
|
||||
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let userMessagePromise;
|
||||
let getAbortData;
|
||||
let client = null;
|
||||
let cleanupHandlers = [];
|
||||
|
||||
const newConvo = !conversationId;
|
||||
// Match the same logic used for conversationId generation above
|
||||
const isNewConvo = !reqConversationId || reqConversationId === 'new';
|
||||
const userId = req.user.id;
|
||||
|
||||
// Create handler to avoid capturing the entire parent scope
|
||||
|
|
@ -64,24 +401,20 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
// Update job metadata with prompt tokens for abort handling
|
||||
GenerationJobManager.updateMetadata(streamId, { promptTokens: data[key] });
|
||||
} else if (key === 'sender') {
|
||||
sender = data[key];
|
||||
} else if (key === 'abortKey') {
|
||||
abortKey = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
GenerationJobManager.updateMetadata(streamId, { sender: data[key] });
|
||||
}
|
||||
// conversationId is pre-generated, no need to update from callback
|
||||
}
|
||||
};
|
||||
|
||||
// Create a function to handle final cleanup
|
||||
const performCleanup = () => {
|
||||
const performCleanup = async () => {
|
||||
logger.debug('[AgentController] Performing cleanup');
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
for (const handler of cleanupHandlers) {
|
||||
|
|
@ -95,10 +428,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
}
|
||||
|
||||
// Clean up abort controller
|
||||
if (abortKey) {
|
||||
logger.debug('[AgentController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
// Complete the job in GenerationJobManager
|
||||
if (streamId) {
|
||||
logger.debug('[AgentController] Completing job in GenerationJobManager');
|
||||
await GenerationJobManager.completeJob(streamId);
|
||||
}
|
||||
|
||||
// Dispose client properly
|
||||
|
|
@ -110,11 +443,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
client = null;
|
||||
getReqData = null;
|
||||
userMessage = null;
|
||||
getAbortData = null;
|
||||
endpointOption.agent = null;
|
||||
endpointOption = null;
|
||||
cleanupHandlers = null;
|
||||
userMessagePromise = null;
|
||||
|
||||
// Clear request data map
|
||||
if (requestDataMap.has(req)) {
|
||||
|
|
@ -136,6 +465,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
};
|
||||
cleanupHandlers.push(removePrelimHandler);
|
||||
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
|
|
@ -143,6 +473,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
endpointOption,
|
||||
signal: prelimAbortController.signal,
|
||||
});
|
||||
|
||||
if (prelimAbortController.signal?.aborted) {
|
||||
prelimAbortController = null;
|
||||
throw new Error('Request was aborted before initialization could complete');
|
||||
|
|
@ -161,28 +492,24 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
// Store request data in WeakMap keyed by req object
|
||||
requestDataMap.set(req, { client });
|
||||
|
||||
// Use WeakRef to allow GC but still access content if it exists
|
||||
const contentRef = new WeakRef(client.contentParts || []);
|
||||
// Create job in GenerationJobManager for abort handling
|
||||
// streamId === conversationId (pre-generated above)
|
||||
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
|
||||
|
||||
// Minimize closure scope - only capture small primitives and WeakRef
|
||||
getAbortData = () => {
|
||||
// Dereference WeakRef each time
|
||||
const content = contentRef.deref();
|
||||
// Store endpoint metadata for abort handling
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
sender: client?.sender,
|
||||
});
|
||||
|
||||
return {
|
||||
sender,
|
||||
content: content || [],
|
||||
userMessage,
|
||||
promptTokens,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
};
|
||||
};
|
||||
// Store content parts reference for abort
|
||||
if (client?.contentParts) {
|
||||
GenerationJobManager.setContentParts(streamId, client.contentParts);
|
||||
}
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
const closeHandler = createCloseHandler(abortController);
|
||||
const closeHandler = createCloseHandler(job.abortController);
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
|
|
@ -192,6 +519,27 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* onStart callback - stores user message and response ID for abort handling
|
||||
*/
|
||||
const onStart = (userMsg, respMsgId, _isNewConvo) => {
|
||||
sendEvent(res, { message: userMsg, created: true });
|
||||
userMessage = userMsg;
|
||||
userMessageId = userMsg.messageId;
|
||||
responseMessageId = respMsgId;
|
||||
|
||||
// Store metadata for abort handling (conversationId is pre-generated)
|
||||
GenerationJobManager.updateMetadata(streamId, {
|
||||
responseMessageId: respMsgId,
|
||||
userMessage: {
|
||||
messageId: userMsg.messageId,
|
||||
parentMessageId: userMsg.parentMessageId,
|
||||
conversationId,
|
||||
text: userMsg.text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
onStart,
|
||||
|
|
@ -201,7 +549,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController,
|
||||
abortController: job.abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
|
|
@ -241,7 +589,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
|
||||
// Only send if not aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
if (!job.abortController.signal.aborted) {
|
||||
// Create a new response object with minimal copies
|
||||
const finalResponse = { ...response };
|
||||
|
||||
|
|
@ -292,7 +640,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
}
|
||||
|
||||
// Add title if needed - extract minimal data
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && isNewConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
|
|
@ -315,7 +663,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
|||
// Handle error without capturing much scope
|
||||
handleAbortError(res, req, error, {
|
||||
conversationId,
|
||||
sender,
|
||||
sender: client?.sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
|
||||
userMessageId,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,54 @@
|
|||
* @import { MCPServerDocument } from 'librechat-data-provider'
|
||||
*/
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
isMCPDomainNotAllowedError,
|
||||
isMCPInspectionFailedError,
|
||||
MCPErrorCodes,
|
||||
} = require('@librechat/api');
|
||||
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
|
||||
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getMCPManager, getMCPServersRegistry } = require('~/config');
|
||||
|
||||
/**
|
||||
* Handles MCP-specific errors and sends appropriate HTTP responses.
|
||||
* @param {Error} error - The error to handle
|
||||
* @param {import('express').Response} res - Express response object
|
||||
* @returns {import('express').Response | null} Response if handled, null if not an MCP error
|
||||
*/
|
||||
function handleMCPError(error, res) {
|
||||
if (isMCPDomainNotAllowedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
if (isMCPInspectionFailedError(error)) {
|
||||
return res.status(error.statusCode).json({
|
||||
error: error.code,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
// Fallback for legacy string-based error handling (backwards compatibility)
|
||||
if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) {
|
||||
return res.status(403).json({
|
||||
error: MCPErrorCodes.DOMAIN_NOT_ALLOWED,
|
||||
message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''),
|
||||
});
|
||||
}
|
||||
|
||||
if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) {
|
||||
return res.status(400).json({
|
||||
error: MCPErrorCodes.INSPECTION_FAILED,
|
||||
message: error.message,
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
*/
|
||||
|
|
@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => {
|
|||
});
|
||||
} catch (error) {
|
||||
logger.error('[createMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => {
|
|||
res.status(200).json(parsedConfig);
|
||||
} catch (error) {
|
||||
logger.error('[updateMCPServer]', error);
|
||||
if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) {
|
||||
return res.status(400).json({
|
||||
error: 'MCP_INSPECTION_FAILED',
|
||||
message: error.message,
|
||||
});
|
||||
const mcpErrorResponse = handleMCPError(error, res);
|
||||
if (mcpErrorResponse) {
|
||||
return mcpErrorResponse;
|
||||
}
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ const {
|
|||
performStartupChecks,
|
||||
handleJsonParseError,
|
||||
initializeFileStorage,
|
||||
GenerationJobManager,
|
||||
createStreamServices,
|
||||
} = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
|
|
@ -192,6 +194,11 @@ const startServer = async () => {
|
|||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
|
||||
// Configure stream services (auto-detects Redis from USE_REDIS env var)
|
||||
const streamServices = createStreamServices();
|
||||
GenerationJobManager.configure(streamServices);
|
||||
GenerationJobManager.initialize();
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
// abortControllers.js
|
||||
module.exports = new Map();
|
||||
|
|
@ -1,124 +1,102 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent, sanitizeMessageForTransmit } = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
countTokens,
|
||||
isEnabled,
|
||||
sendEvent,
|
||||
GenerationJobManager,
|
||||
sanitizeMessageForTransmit,
|
||||
} = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { abortRun } = require('./abortRun');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {boolean}
|
||||
* Abort an active message generation.
|
||||
* Uses GenerationJobManager for all agent requests.
|
||||
* Since streamId === conversationId, we can directly abort by conversationId.
|
||||
*/
|
||||
function cleanupAbortController(abortKey) {
|
||||
if (!abortControllers.has(abortKey)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
|
||||
if (!abortController) {
|
||||
abortControllers.delete(abortKey);
|
||||
return true;
|
||||
}
|
||||
|
||||
// 1. Check if this controller has any composed signals and clean them up
|
||||
try {
|
||||
// This creates a temporary composed signal to use for cleanup
|
||||
const composedSignal = AbortSignal.any([abortController.signal]);
|
||||
|
||||
// Get all event types - in practice, AbortSignal typically only uses 'abort'
|
||||
const eventTypes = ['abort'];
|
||||
|
||||
// First, execute a dummy listener removal to handle potential composed signals
|
||||
for (const eventType of eventTypes) {
|
||||
const dummyHandler = () => {};
|
||||
composedSignal.addEventListener(eventType, dummyHandler);
|
||||
composedSignal.removeEventListener(eventType, dummyHandler);
|
||||
|
||||
const listeners = composedSignal.listeners?.(eventType) || [];
|
||||
for (const listener of listeners) {
|
||||
composedSignal.removeEventListener(eventType, listener);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug(`Error cleaning up composed signals: ${e}`);
|
||||
}
|
||||
|
||||
// 2. Abort the controller if not already aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
|
||||
// 3. Remove from registry
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
// 4. Clean up any data stored in the WeakMap
|
||||
if (abortDataMap.has(abortController)) {
|
||||
abortDataMap.delete(abortController);
|
||||
}
|
||||
|
||||
// 5. Clean up function references on the controller
|
||||
if (abortController.getAbortData) {
|
||||
abortController.getAbortData = null;
|
||||
}
|
||||
|
||||
if (abortController.abortCompletion) {
|
||||
abortController.abortCompletion = null;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {function(): void}
|
||||
*/
|
||||
function createCleanUpHandler(abortKey) {
|
||||
return function () {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, endpoint } = req.body;
|
||||
const { abortKey, endpoint } = req.body;
|
||||
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
// Use GenerationJobManager to abort the job (streamId === conversationId)
|
||||
const abortResult = await GenerationJobManager.abortJob(conversationId);
|
||||
|
||||
if (!abortResult.success) {
|
||||
if (!res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const { jobData, content, text } = abortResult;
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
// Count tokens and spend them
|
||||
const completionTokens = await countTokens(text);
|
||||
const promptTokens = jobData?.promptTokens ?? 0;
|
||||
|
||||
const finalEvent = await abortController.abortCompletion?.();
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
|
||||
JSON.stringify({ abortKey }),
|
||||
const responseMessage = {
|
||||
messageId: jobData?.responseMessageId,
|
||||
parentMessageId: jobData?.userMessage?.messageId,
|
||||
conversationId: jobData?.conversationId,
|
||||
content,
|
||||
text,
|
||||
sender: jobData?.sender ?? 'AI',
|
||||
finish_reason: 'incomplete',
|
||||
endpoint: jobData?.endpoint,
|
||||
iconURL: jobData?.iconURL,
|
||||
model: jobData?.model,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user: userId },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
cleanupAbortController(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user: userId },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
// Get conversation for title
|
||||
const conversation = await getConvo(userId, conversationId);
|
||||
|
||||
const finalEvent = {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: jobData?.userMessage
|
||||
? sanitizeMessageForTransmit({
|
||||
messageId: jobData.userMessage.messageId,
|
||||
parentMessageId: jobData.userMessage.parentMessageId,
|
||||
conversationId: jobData.userMessage.conversationId,
|
||||
text: jobData.userMessage.text,
|
||||
isCreatedByUser: true,
|
||||
})
|
||||
: null,
|
||||
responseMessage,
|
||||
};
|
||||
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${userId} | ${req.user.email} | Aborted request: ${conversationId}`,
|
||||
);
|
||||
|
||||
if (res.headersSent) {
|
||||
return sendEvent(res, finalEvent);
|
||||
}
|
||||
|
||||
|
|
@ -139,171 +117,13 @@ const handleAbort = function () {
|
|||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
|
||||
// Store minimal data in WeakMap to avoid circular references
|
||||
abortDataMap.set(abortController, {
|
||||
getAbortDataFn: getAbortData,
|
||||
userId: req.user.id,
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
});
|
||||
|
||||
// Replace the direct function reference with a wrapper that uses WeakMap
|
||||
abortController.getAbortData = function () {
|
||||
const data = abortDataMap.get(this);
|
||||
if (!data || typeof data.getAbortDataFn !== 'function') {
|
||||
return {};
|
||||
}
|
||||
|
||||
try {
|
||||
const result = data.getAbortDataFn();
|
||||
|
||||
// Create a copy without circular references
|
||||
const cleanResult = { ...result };
|
||||
|
||||
// If userMessagePromise exists, break its reference to client
|
||||
if (
|
||||
cleanResult.userMessagePromise &&
|
||||
typeof cleanResult.userMessagePromise.then === 'function'
|
||||
) {
|
||||
// Create a new promise that fulfills with the same result but doesn't reference the original
|
||||
const originalPromise = cleanResult.userMessagePromise;
|
||||
cleanResult.userMessagePromise = new Promise((resolve, reject) => {
|
||||
originalPromise.then(
|
||||
(result) => resolve({ ...result }),
|
||||
(error) => reject(error),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return cleanResult;
|
||||
} catch (err) {
|
||||
logger.error('[abortController.getAbortData] Error:', err);
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
* @param {boolean} [isNewConvo]
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId, isNewConvo) => {
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const prelimAbortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const abortKey = isNewConvo
|
||||
? `${prelimAbortKey}${Constants.COMMON_DIVIDER}${Constants.NEW_CONVO}`
|
||||
: prelimAbortKey;
|
||||
getReqData({ abortKey });
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
const { overrideUserMessageId } = req?.body ?? {};
|
||||
|
||||
if (overrideUserMessageId != null && prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
|
||||
// Store minimal options
|
||||
const minimalOptions = {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
|
||||
const cleanupHandler = createCleanUpHandler(addedAbortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store minimal options
|
||||
const minimalOptions = {
|
||||
endpoint: endpointOption.endpoint,
|
||||
iconURL: endpointOption.iconURL,
|
||||
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
|
||||
};
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...minimalOptions });
|
||||
const cleanupHandler = createCleanUpHandler(abortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
};
|
||||
|
||||
// Define abortCompletion without capturing the entire parent scope
|
||||
abortController.abortCompletion = async function () {
|
||||
this.abort();
|
||||
|
||||
// Get data from WeakMap
|
||||
const ctrlData = abortDataMap.get(this);
|
||||
if (!ctrlData || !ctrlData.getAbortDataFn) {
|
||||
return { final: true, conversation: {}, title: 'New Chat' };
|
||||
}
|
||||
|
||||
// Get abort data using stored function
|
||||
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||
ctrlData.getAbortDataFn();
|
||||
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = ctrlData.userId;
|
||||
|
||||
const responseMessage = {
|
||||
...responseData,
|
||||
conversationId,
|
||||
finish_reason: 'incomplete',
|
||||
endpoint: ctrlData.endpoint,
|
||||
iconURL: ctrlData.iconURL,
|
||||
model: ctrlData.modelOptions?.model ?? ctrlData.model_parameters?.model,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
const resolved = await userMessagePromise;
|
||||
conversation = resolved?.conversation;
|
||||
// Break reference to promise
|
||||
resolved.conversation = null;
|
||||
}
|
||||
|
||||
if (!conversation) {
|
||||
conversation = await getConvo(user, conversationId);
|
||||
}
|
||||
|
||||
return {
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: sanitizeMessageForTransmit(userMessage),
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
};
|
||||
|
||||
return { abortController, onStart };
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle abort errors during generation.
|
||||
* @param {ServerResponse} res
|
||||
* @param {ServerRequest} req
|
||||
* @param {Error | unknown} error
|
||||
* @param {Partial<TMessage> & { partialText?: string }} data
|
||||
* @returns { Promise<void> }
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const handleAbortError = async (res, req, error, data) => {
|
||||
if (error?.message?.includes('base64')) {
|
||||
|
|
@ -368,8 +188,7 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
};
|
||||
}
|
||||
|
||||
const callback = createCleanUpHandler(conversationId);
|
||||
await sendError(req, res, options, callback);
|
||||
await sendError(req, res, options);
|
||||
};
|
||||
|
||||
if (partialText && partialText.length > 5) {
|
||||
|
|
@ -387,6 +206,4 @@ const handleAbortError = async (res, req, error, data) => {
|
|||
module.exports = {
|
||||
handleAbort,
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants, isAgentsEndpoint, ResourceType } = require('librechat-data-provider');
|
||||
const {
|
||||
Constants,
|
||||
ResourceType,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
|
||||
|
|
@ -13,7 +18,8 @@ const { getAgent } = require('~/models/Agent');
|
|||
*/
|
||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||
// Handle ephemeral agents - they don't need permission checks
|
||||
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentCustomId)) {
|
||||
return null; // No permission check needed for ephemeral agents
|
||||
}
|
||||
|
||||
|
|
@ -62,7 +68,8 @@ const canAccessAgentFromBody = (options) => {
|
|||
}
|
||||
|
||||
// Skip permission checks for ephemeral agents
|
||||
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
// Real agent IDs always start with "agent_", so anything else is ephemeral
|
||||
if (isEphemeralAgentId(agentId)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,10 @@ async function buildEndpointOption(req, res, next) {
|
|||
try {
|
||||
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
|
||||
);
|
||||
logger.error(`Error parsing compact conversation for endpoint ${endpoint}`, error);
|
||||
logger.debug({
|
||||
'Error parsing compact conversation': { endpoint, endpointType, conversation: req.body },
|
||||
});
|
||||
return handleError(res, { text: 'Error parsing conversation' });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,15 @@ const { logViolation, getLogStores } = require('~/cache');
|
|||
|
||||
const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {};
|
||||
|
||||
/**
|
||||
* Helper function to get conversationId from different request body structures.
|
||||
* @param {Object} body - The request body.
|
||||
* @returns {string|undefined} The conversationId.
|
||||
*/
|
||||
const getConversationId = (body) => {
|
||||
return body.conversationId ?? body.arg?.conversationId;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to validate user's authorization for a conversation.
|
||||
*
|
||||
|
|
@ -24,7 +33,7 @@ const validateConvoAccess = async (req, res, next) => {
|
|||
const namespace = ViolationTypes.CONVO_ACCESS;
|
||||
const cache = getLogStores(namespace);
|
||||
|
||||
const conversationId = req.body.conversationId;
|
||||
const conversationId = getConversationId(req.body);
|
||||
|
||||
if (!conversationId || conversationId === Constants.NEW_CONVO) {
|
||||
return next();
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ afterEach(() => {
|
|||
|
||||
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
|
||||
|
||||
// eslint-disable-next-line jest/no-disabled-tests
|
||||
describe.skip('GET /', () => {
|
||||
it('should return 200 and the correct body', async () => {
|
||||
process.env.APP_TITLE = 'Test Title';
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ jest.mock('~/server/middleware', () => ({
|
|||
forkUserLimiter: (req, res, next) => next(),
|
||||
})),
|
||||
configMiddleware: (req, res, next) => next(),
|
||||
validateConvoAccess: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils/import/fork', () => ({
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ const express = require('express');
|
|||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { getBasePath } = require('@librechat/api');
|
||||
|
||||
const mockRegistryInstance = {
|
||||
getServerConfig: jest.fn(),
|
||||
|
|
@ -12,26 +13,36 @@ const mockRegistryInstance = {
|
|||
removeServer: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
jest.mock('@librechat/api', () => {
|
||||
const actual = jest.requireActual('@librechat/api');
|
||||
return {
|
||||
...actual,
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
getClientInfoAndMetadata: jest.fn(),
|
||||
getTokens: jest.fn(),
|
||||
deleteUserTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
// Error handling utilities (from @librechat/api mcp/errors)
|
||||
isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED',
|
||||
MCPErrorCodes: {
|
||||
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
|
||||
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -271,27 +282,30 @@ describe('MCP Routes', () => {
|
|||
error: 'access_denied',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=access_denied');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when code is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=missing_code');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_code`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when state is missing', async () => {
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=missing_state');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
|
||||
});
|
||||
|
||||
it('should redirect to error page when flow state is not found', async () => {
|
||||
|
|
@ -301,9 +315,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'invalid-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=invalid_state');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
|
||||
});
|
||||
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
|
|
@ -358,9 +373,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
|
||||
'test-flow-id',
|
||||
'test-auth-code',
|
||||
|
|
@ -394,9 +410,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
|
||||
});
|
||||
|
||||
it('should handle system-level OAuth completion', async () => {
|
||||
|
|
@ -429,9 +446,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
|
|
@ -474,9 +492,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
|
@ -515,9 +534,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
|
||||
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
|
|
@ -573,9 +593,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify storeTokens was called with ORIGINAL flow state credentials
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||
|
|
@ -614,9 +635,10 @@ describe('MCP Routes', () => {
|
|||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
|
||||
|
||||
// Verify completeOAuthFlow was NOT called (prevented duplicate)
|
||||
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
|
||||
|
|
@ -1385,8 +1407,10 @@ describe('MCP Routes', () => {
|
|||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.expect(302);
|
||||
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(mockFlowManager.completeFlow).not.toHaveBeenCalled();
|
||||
expect(response.headers.location).toContain('/oauth/success');
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
|
||||
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
||||
|
|
@ -1433,7 +1457,9 @@ describe('MCP Routes', () => {
|
|||
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
|
||||
.expect(302);
|
||||
|
||||
expect(response.headers.location).toContain('/oauth/success');
|
||||
const basePath = getBasePath();
|
||||
|
||||
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
const express = require('express');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { getAccessToken } = require('@librechat/api');
|
||||
const { getAccessToken, getBasePath } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { findToken, updateToken, createToken } = require('~/models');
|
||||
|
|
@ -24,6 +24,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
const { code, state } = req.query;
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const basePath = getBasePath();
|
||||
let identifier = action_id;
|
||||
try {
|
||||
let decodedState;
|
||||
|
|
@ -32,17 +33,17 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
} catch (err) {
|
||||
logger.error('Error verifying state parameter:', err);
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
if (decodedState.action_id !== action_id) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
if (!decodedState.user) {
|
||||
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
identifier = `${decodedState.user}:${action_id}`;
|
||||
const flowState = await flowManager.getFlowState(identifier, 'oauth');
|
||||
|
|
@ -72,12 +73,12 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
|
|||
|
||||
/** Redirect to React success page */
|
||||
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('Error in OAuth callback:', error);
|
||||
await flowManager.failFlow(identifier, 'oauth', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ const express = require('express');
|
|||
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
|
||||
const {
|
||||
setHeaders,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
validateConvoAccess,
|
||||
|
|
@ -16,8 +15,6 @@ const { getRoleByName } = require('~/models/Role');
|
|||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(moderateText);
|
||||
|
||||
const checkAgentAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
|
|
@ -28,11 +25,11 @@ const checkAgentResourceAccess = canAccessAgentFromBody({
|
|||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
router.use(moderateText);
|
||||
router.use(checkAgentAccess);
|
||||
router.use(checkAgentResourceAccess);
|
||||
router.use(validateConvoAccess);
|
||||
router.use(buildEndpointOption);
|
||||
router.use(setHeaders);
|
||||
|
||||
const controller = async (req, res, next) => {
|
||||
await AgentController(req, res, next, initializeClient, addTitle);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
const express = require('express');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { isEnabled, GenerationJobManager } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
|
|
@ -22,6 +23,188 @@ router.use(uaParser);
|
|||
|
||||
router.use('/', v1);
|
||||
|
||||
/**
|
||||
* Stream endpoints - mounted before chatRouter to bypass rate limiters
|
||||
* These are GET requests and don't need message body validation or rate limiting
|
||||
*/
|
||||
|
||||
/**
|
||||
* @route GET /chat/stream/:streamId
|
||||
* @desc Subscribe to an ongoing generation job's SSE stream with replay support
|
||||
* @access Private
|
||||
* @description Sends sync event with resume state, replays missed chunks, then streams live
|
||||
* @query resume=true - Indicates this is a reconnection (sends sync event)
|
||||
*/
|
||||
router.get('/chat/stream/:streamId', async (req, res) => {
|
||||
const { streamId } = req.params;
|
||||
const isResume = req.query.resume === 'true';
|
||||
|
||||
const job = await GenerationJobManager.getJob(streamId);
|
||||
if (!job) {
|
||||
return res.status(404).json({
|
||||
error: 'Stream not found',
|
||||
message: 'The generation job does not exist or has expired.',
|
||||
});
|
||||
}
|
||||
|
||||
res.setHeader('Content-Encoding', 'identity');
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.flushHeaders();
|
||||
|
||||
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
|
||||
|
||||
// Send sync event with resume state for ALL reconnecting clients
|
||||
// This supports multi-tab scenarios where each tab needs run step data
|
||||
if (isResume) {
|
||||
const resumeState = await GenerationJobManager.getResumeState(streamId);
|
||||
if (resumeState && !res.writableEnded) {
|
||||
// Send sync event with run steps AND aggregatedContent
|
||||
// Client will use aggregatedContent to initialize message state
|
||||
res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
logger.debug(
|
||||
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const result = await GenerationJobManager.subscribe(
|
||||
streamId,
|
||||
(event) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
}
|
||||
},
|
||||
(event) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
},
|
||||
(error) => {
|
||||
if (!res.writableEnded) {
|
||||
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
if (!result) {
|
||||
return res.status(404).json({ error: 'Failed to subscribe to stream' });
|
||||
}
|
||||
|
||||
req.on('close', () => {
|
||||
logger.debug(`[AgentStream] Client disconnected from ${streamId}`);
|
||||
result.unsubscribe();
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @route GET /chat/active
|
||||
* @desc Get all active generation job IDs for the current user
|
||||
* @access Private
|
||||
* @returns { activeJobIds: string[] }
|
||||
*/
|
||||
router.get('/chat/active', async (req, res) => {
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
|
||||
res.json({ activeJobIds });
|
||||
});
|
||||
|
||||
/**
|
||||
* @route GET /chat/status/:conversationId
|
||||
* @desc Check if there's an active generation job for a conversation
|
||||
* @access Private
|
||||
* @returns { active, streamId, status, aggregatedContent, createdAt, resumeState }
|
||||
*/
|
||||
router.get('/chat/status/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
|
||||
// streamId === conversationId, so we can use getJob directly
|
||||
const job = await GenerationJobManager.getJob(conversationId);
|
||||
|
||||
if (!job) {
|
||||
return res.json({ active: false });
|
||||
}
|
||||
|
||||
if (job.metadata.userId !== req.user.id) {
|
||||
return res.status(403).json({ error: 'Unauthorized' });
|
||||
}
|
||||
|
||||
// Get resume state which contains aggregatedContent
|
||||
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
|
||||
const resumeState = await GenerationJobManager.getResumeState(conversationId);
|
||||
const isActive = job.status === 'running';
|
||||
|
||||
res.json({
|
||||
active: isActive,
|
||||
streamId: conversationId,
|
||||
status: job.status,
|
||||
aggregatedContent: resumeState?.aggregatedContent ?? [],
|
||||
createdAt: job.createdAt,
|
||||
resumeState,
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* @route POST /chat/abort
|
||||
* @desc Abort an ongoing generation job
|
||||
* @access Private
|
||||
* @description Mounted before chatRouter to bypass buildEndpointOption middleware
|
||||
*/
|
||||
router.post('/chat/abort', async (req, res) => {
|
||||
logger.debug(`[AgentStream] ========== ABORT ENDPOINT HIT ==========`);
|
||||
logger.debug(`[AgentStream] Method: ${req.method}, Path: ${req.path}`);
|
||||
logger.debug(`[AgentStream] Body:`, req.body);
|
||||
|
||||
const { streamId, conversationId, abortKey } = req.body;
|
||||
const userId = req.user?.id;
|
||||
|
||||
// streamId === conversationId, so try any of the provided IDs
|
||||
// Skip "new" as it's a placeholder for new conversations, not an actual ID
|
||||
let jobStreamId =
|
||||
streamId || (conversationId !== 'new' ? conversationId : null) || abortKey?.split(':')[0];
|
||||
let job = jobStreamId ? await GenerationJobManager.getJob(jobStreamId) : null;
|
||||
|
||||
// Fallback: if job not found and we have a userId, look up active jobs for user
|
||||
// This handles the case where frontend sends "new" but job was created with a UUID
|
||||
if (!job && userId) {
|
||||
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
|
||||
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
|
||||
if (activeJobIds.length > 0) {
|
||||
// Abort the most recent active job for this user
|
||||
jobStreamId = activeJobIds[0];
|
||||
job = await GenerationJobManager.getJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Found active job for user: ${jobStreamId}`);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(`[AgentStream] Computed jobStreamId: ${jobStreamId}`);
|
||||
|
||||
if (job && jobStreamId) {
|
||||
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
|
||||
await GenerationJobManager.abortJob(jobStreamId);
|
||||
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`);
|
||||
return res.json({ success: true, aborted: jobStreamId });
|
||||
}
|
||||
|
||||
logger.warn(`[AgentStream] Job not found for streamId: ${jobStreamId}`);
|
||||
return res.status(404).json({ error: 'Job not found', streamId: jobStreamId });
|
||||
});
|
||||
|
||||
const chatRouter = express.Router();
|
||||
chatRouter.use(configMiddleware);
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const { logger } = require('@librechat/data-schemas');
|
|||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
createImportLimiters,
|
||||
validateConvoAccess,
|
||||
createForkLimiters,
|
||||
configMiddleware,
|
||||
} = require('~/server/middleware');
|
||||
|
|
@ -67,16 +68,17 @@ router.get('/:conversationId', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/gen_title', async (req, res) => {
|
||||
const { conversationId } = req.body;
|
||||
router.get('/gen_title/:conversationId', async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${conversationId}`;
|
||||
let title = await titleCache.get(key);
|
||||
|
||||
if (!title) {
|
||||
// Retry every 1s for up to 20s
|
||||
for (let i = 0; i < 20; i++) {
|
||||
await sleep(1000);
|
||||
// Exponential backoff: 500ms, 1s, 2s, 4s, 8s (total ~15.5s max wait)
|
||||
const delays = [500, 1000, 2000, 4000, 8000];
|
||||
for (const delay of delays) {
|
||||
await sleep(delay);
|
||||
title = await titleCache.get(key);
|
||||
if (title) {
|
||||
break;
|
||||
|
|
@ -150,17 +152,39 @@ router.delete('/all', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
router.post('/update', async (req, res) => {
|
||||
const update = req.body.arg;
|
||||
/** Maximum allowed length for conversation titles */
|
||||
const MAX_CONVO_TITLE_LENGTH = 1024;
|
||||
|
||||
if (!update.conversationId) {
|
||||
/**
|
||||
* Updates a conversation's title.
|
||||
* @route POST /update
|
||||
* @param {string} req.body.arg.conversationId - The conversation ID to update.
|
||||
* @param {string} req.body.arg.title - The new title for the conversation.
|
||||
* @returns {object} 201 - The updated conversation object.
|
||||
*/
|
||||
router.post('/update', validateConvoAccess, async (req, res) => {
|
||||
const { conversationId, title } = req.body.arg ?? {};
|
||||
|
||||
if (!conversationId) {
|
||||
return res.status(400).json({ error: 'conversationId is required' });
|
||||
}
|
||||
|
||||
if (title === undefined) {
|
||||
return res.status(400).json({ error: 'title is required' });
|
||||
}
|
||||
|
||||
if (typeof title !== 'string') {
|
||||
return res.status(400).json({ error: 'title must be a string' });
|
||||
}
|
||||
|
||||
const sanitizedTitle = title.trim().slice(0, MAX_CONVO_TITLE_LENGTH);
|
||||
|
||||
try {
|
||||
const dbResponse = await saveConvo(req, update, {
|
||||
context: `POST /api/convos/update ${update.conversationId}`,
|
||||
});
|
||||
const dbResponse = await saveConvo(
|
||||
req,
|
||||
{ conversationId, title: sanitizedTitle },
|
||||
{ context: `POST /api/convos/update ${conversationId}` },
|
||||
);
|
||||
res.status(201).json(dbResponse);
|
||||
} catch (error) {
|
||||
logger.error('Error updating conversation', error);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ const {
|
|||
createSafeUser,
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
getBasePath,
|
||||
getUserMCPAuthMap,
|
||||
generateCheckAccess,
|
||||
} = require('@librechat/api');
|
||||
|
|
@ -105,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
|||
* This handles the OAuth callback after the user has authorized the application
|
||||
*/
|
||||
router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
const basePath = getBasePath();
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const { code, state, error: oauthError } = req.query;
|
||||
|
|
@ -118,17 +120,19 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
|
||||
if (oauthError) {
|
||||
logger.error('[MCP OAuth] OAuth error received', { error: oauthError });
|
||||
return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`);
|
||||
return res.redirect(
|
||||
`${basePath}/oauth/error?error=${encodeURIComponent(String(oauthError))}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!code || typeof code !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid code');
|
||||
return res.redirect('/oauth/error?error=missing_code');
|
||||
return res.redirect(`${basePath}/oauth/error?error=missing_code`);
|
||||
}
|
||||
|
||||
if (!state || typeof state !== 'string') {
|
||||
logger.error('[MCP OAuth] Missing or invalid state');
|
||||
return res.redirect('/oauth/error?error=missing_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
|
||||
}
|
||||
|
||||
const flowId = state;
|
||||
|
|
@ -142,7 +146,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
|
||||
if (!flowState) {
|
||||
logger.error('[MCP OAuth] Flow state not found for flowId:', flowId);
|
||||
return res.redirect('/oauth/error?error=invalid_state');
|
||||
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Flow state details', {
|
||||
|
|
@ -160,7 +164,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
flowId,
|
||||
serverName,
|
||||
});
|
||||
return res.redirect(`/oauth/success?serverName=${encodeURIComponent(serverName)}`);
|
||||
return res.redirect(`${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`);
|
||||
}
|
||||
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
|
|
@ -254,11 +258,11 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
|||
}
|
||||
|
||||
/** Redirect to success page with flowId and serverName */
|
||||
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
|
||||
res.redirect(redirectUrl);
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] OAuth callback error', error);
|
||||
res.redirect('/oauth/error?error=callback_failed');
|
||||
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -588,7 +592,7 @@ async function getOAuthHeaders(serverName, userId) {
|
|||
return serverConfig?.oauth_headers ?? {};
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
MCP Server CRUD Routes (User-Managed MCP Servers)
|
||||
*/
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const express = require('express');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ContentTypes } = require('librechat-data-provider');
|
||||
const { unescapeLaTeX, countTokens } = require('@librechat/api');
|
||||
|
|
@ -111,6 +112,91 @@ router.get('/', async (req, res) => {
|
|||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Creates a new branch message from a specific agent's content within a parallel response message.
|
||||
* Filters the original message's content to only include parts attributed to the specified agentId.
|
||||
* Only available for non-user messages with content attributions.
|
||||
*
|
||||
* @route POST /branch
|
||||
* @param {string} req.body.messageId - The ID of the source message
|
||||
* @param {string} req.body.agentId - The agentId to filter content by
|
||||
* @returns {TMessage} The newly created branch message
|
||||
*/
|
||||
router.post('/branch', async (req, res) => {
|
||||
try {
|
||||
const { messageId, agentId } = req.body;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!messageId || !agentId) {
|
||||
return res.status(400).json({ error: 'messageId and agentId are required' });
|
||||
}
|
||||
|
||||
const sourceMessage = await getMessage({ user: userId, messageId });
|
||||
if (!sourceMessage) {
|
||||
return res.status(404).json({ error: 'Source message not found' });
|
||||
}
|
||||
|
||||
if (sourceMessage.isCreatedByUser) {
|
||||
return res.status(400).json({ error: 'Cannot branch from user messages' });
|
||||
}
|
||||
|
||||
if (!Array.isArray(sourceMessage.content)) {
|
||||
return res.status(400).json({ error: 'Message does not have content' });
|
||||
}
|
||||
|
||||
const hasAgentMetadata = sourceMessage.content.some((part) => part?.agentId);
|
||||
if (!hasAgentMetadata) {
|
||||
return res
|
||||
.status(400)
|
||||
.json({ error: 'Message does not have parallel content with attributions' });
|
||||
}
|
||||
|
||||
/** @type {Array<import('librechat-data-provider').TMessageContentParts>} */
|
||||
const filteredContent = [];
|
||||
for (const part of sourceMessage.content) {
|
||||
if (part?.agentId === agentId) {
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = part;
|
||||
filteredContent.push(cleanPart);
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredContent.length === 0) {
|
||||
return res.status(400).json({ error: 'No content found for the specified agentId' });
|
||||
}
|
||||
|
||||
const newMessageId = uuidv4();
|
||||
/** @type {import('librechat-data-provider').TMessage} */
|
||||
const newMessage = {
|
||||
messageId: newMessageId,
|
||||
conversationId: sourceMessage.conversationId,
|
||||
parentMessageId: sourceMessage.parentMessageId,
|
||||
attachments: sourceMessage.attachments,
|
||||
isCreatedByUser: false,
|
||||
model: sourceMessage.model,
|
||||
endpoint: sourceMessage.endpoint,
|
||||
sender: sourceMessage.sender,
|
||||
iconURL: sourceMessage.iconURL,
|
||||
content: filteredContent,
|
||||
unfinished: false,
|
||||
error: false,
|
||||
user: userId,
|
||||
};
|
||||
|
||||
const savedMessage = await saveMessage(req, newMessage, {
|
||||
context: 'POST /api/messages/branch',
|
||||
});
|
||||
|
||||
if (!savedMessage) {
|
||||
return res.status(500).json({ error: 'Failed to save branch message' });
|
||||
}
|
||||
|
||||
res.status(201).json(savedMessage);
|
||||
} catch (error) {
|
||||
logger.error('Error creating branch message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/artifact/:messageId', async (req, res) => {
|
||||
try {
|
||||
const { messageId } = req.params;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@ const { nanoid } = require('nanoid');
|
|||
const { tool } = require('@langchain/core/tools');
|
||||
const { GraphEvents, sleep } = require('@librechat/agents');
|
||||
const { logger, encryptV2, decryptV2 } = require('@librechat/data-schemas');
|
||||
const { sendEvent, logAxiosError, refreshAccessToken } = require('@librechat/api');
|
||||
const {
|
||||
sendEvent,
|
||||
logAxiosError,
|
||||
refreshAccessToken,
|
||||
GenerationJobManager,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
CacheKeys,
|
||||
|
|
@ -127,6 +132,7 @@ async function loadActionSets(searchParams) {
|
|||
* @param {string | undefined} [params.description] - The description for the tool.
|
||||
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
|
||||
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable streams.
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createActionTool({
|
||||
|
|
@ -138,6 +144,7 @@ async function createActionTool({
|
|||
name,
|
||||
description,
|
||||
encrypted,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
|
||||
const _call = async (toolInput, config) => {
|
||||
|
|
@ -192,7 +199,12 @@ async function createActionTool({
|
|||
`${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
|
||||
'oauth_login',
|
||||
async () => {
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
logger.debug('Sent OAuth login request to client', { action_id, identifier });
|
||||
return true;
|
||||
},
|
||||
|
|
@ -217,7 +229,12 @@ async function createActionTool({
|
|||
logger.debug('Received OAuth Authorization response', { action_id, identifier });
|
||||
data.delta.auth = undefined;
|
||||
data.delta.expires_at = undefined;
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const successEventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, successEventData);
|
||||
} else {
|
||||
sendEvent(res, successEventData);
|
||||
}
|
||||
await sleep(3000);
|
||||
metadata.oauth_access_token = result.access_token;
|
||||
metadata.oauth_refresh_token = result.refresh_token;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
const bcrypt = require('bcryptjs');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, checkEmailConfig, isEmailDomainAllowed } = require('@librechat/api');
|
||||
const {
|
||||
logger,
|
||||
DEFAULT_SESSION_EXPIRY,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
} = require('@librechat/data-schemas');
|
||||
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api');
|
||||
const {
|
||||
findUser,
|
||||
findToken,
|
||||
|
|
@ -369,19 +373,21 @@ const setAuthTokens = async (userId, res, _session = null) => {
|
|||
let session = _session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
const expiresIn = math(process.env.REFRESH_TOKEN_EXPIRY, DEFAULT_REFRESH_TOKEN_EXPIRY);
|
||||
|
||||
if (session && session._id && session.expiration != null) {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
const result = await createSession(userId);
|
||||
const result = await createSession(userId, { expiresIn });
|
||||
session = result.session;
|
||||
refreshToken = result.refreshToken;
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
const sessionExpiry = math(process.env.SESSION_EXPIRY, DEFAULT_SESSION_EXPIRY);
|
||||
const token = await generateToken(user, sessionExpiry);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
|
|
@ -418,10 +424,10 @@ const setOpenIDAuthTokens = (tokenset, res, userId, existingRefreshToken) => {
|
|||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
return;
|
||||
}
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY
|
||||
? eval(REFRESH_TOKEN_EXPIRY)
|
||||
: 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
const expiryInMilliseconds = math(
|
||||
process.env.REFRESH_TOKEN_EXPIRY,
|
||||
DEFAULT_REFRESH_TOKEN_EXPIRY,
|
||||
);
|
||||
const expirationDate = new Date(Date.now() + expiryInMilliseconds);
|
||||
if (tokenset == null) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { CacheKeys, Time } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
|
|
@ -39,12 +39,12 @@ async function getCachedTools(options = {}) {
|
|||
* @param {Object} options - Options for caching tools
|
||||
* @param {string} [options.userId] - User ID for user-specific MCP tools
|
||||
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds (default: 12 hours)
|
||||
* @returns {Promise<boolean>} Whether the operation was successful
|
||||
*/
|
||||
async function setCachedTools(tools, options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { userId, serverName, ttl } = options;
|
||||
const { userId, serverName, ttl = Time.TWELVE_HOURS } = options;
|
||||
|
||||
// Cache by MCP server if specified (requires userId)
|
||||
if (serverName && userId) {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@ async function getEndpointsConfig(req) {
|
|||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
if (cachedEndpointsConfig) {
|
||||
return cachedEndpointsConfig;
|
||||
if (cachedEndpointsConfig.gptPlugins) {
|
||||
await cache.delete(CacheKeys.ENDPOINT_CONFIG);
|
||||
} else {
|
||||
return cachedEndpointsConfig;
|
||||
}
|
||||
}
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
|
|
|
|||
136
api/server/services/Endpoints/agents/addedConvo.js
Normal file
136
api/server/services/Endpoints/agents/addedConvo.js
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { initializeAgent, validateAgentModel } = require('@librechat/api');
|
||||
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const db = require('~/models');
|
||||
|
||||
// Initialize the getAgent dependency
|
||||
setGetAgent(getAgent);
|
||||
|
||||
/**
|
||||
* Process addedConvo for parallel agent execution.
|
||||
* Creates a parallel agent config from an added conversation.
|
||||
*
|
||||
* When an added agent has no incoming edges, it becomes a start node
|
||||
* and runs in parallel with the primary agent automatically.
|
||||
*
|
||||
* Edge cases handled:
|
||||
* - Primary agent has edges (handoffs): Added agent runs in parallel with primary,
|
||||
* but doesn't participate in the primary's handoff graph
|
||||
* - Primary agent has agent_ids (legacy chain): Added agent runs in parallel with primary,
|
||||
* but doesn't participate in the chain
|
||||
* - Primary agent has both: Added agent is independent, runs parallel from start
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {import('express').Request} params.req
|
||||
* @param {import('express').Response} params.res
|
||||
* @param {Object} params.endpointOption - The endpoint option containing addedConvo
|
||||
* @param {Object} params.modelsConfig - The models configuration
|
||||
* @param {Function} params.logViolation - Function to log violations
|
||||
* @param {Function} params.loadTools - Function to load agent tools
|
||||
* @param {Array} params.requestFiles - Request files
|
||||
* @param {string} params.conversationId - The conversation ID
|
||||
* @param {Set} params.allowedProviders - Set of allowed providers
|
||||
* @param {Map} params.agentConfigs - Map of agent configs to add to
|
||||
* @param {string} params.primaryAgentId - The primary agent ID
|
||||
* @param {Object|undefined} params.userMCPAuthMap - User MCP auth map to merge into
|
||||
* @returns {Promise<{userMCPAuthMap: Object|undefined}>} The updated userMCPAuthMap
|
||||
*/
|
||||
const processAddedConvo = async ({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
allowedProviders,
|
||||
agentConfigs,
|
||||
primaryAgentId,
|
||||
primaryAgent,
|
||||
userMCPAuthMap,
|
||||
}) => {
|
||||
const addedConvo = endpointOption.addedConvo;
|
||||
logger.debug('[processAddedConvo] Called with addedConvo:', {
|
||||
hasAddedConvo: addedConvo != null,
|
||||
addedConvoEndpoint: addedConvo?.endpoint,
|
||||
addedConvoModel: addedConvo?.model,
|
||||
addedConvoAgentId: addedConvo?.agent_id,
|
||||
});
|
||||
if (addedConvo == null) {
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
try {
|
||||
const addedAgent = await loadAddedAgent({ req, conversation: addedConvo, primaryAgent });
|
||||
if (!addedAgent) {
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
const addedValidation = await validateAgentModel({
|
||||
req,
|
||||
res,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
agent: addedAgent,
|
||||
});
|
||||
|
||||
if (!addedValidation.isValid) {
|
||||
logger.warn(
|
||||
`[processAddedConvo] Added agent validation failed: ${addedValidation.error?.message}`,
|
||||
);
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
|
||||
const addedConfig = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
agent: addedAgent,
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
},
|
||||
{
|
||||
getConvoFiles,
|
||||
getFiles: db.getFiles,
|
||||
getUserKey: db.getUserKey,
|
||||
updateFilesUsage: db.updateFilesUsage,
|
||||
getUserKeyValues: db.getUserKeyValues,
|
||||
getToolFilesByIds: db.getToolFilesByIds,
|
||||
},
|
||||
);
|
||||
|
||||
if (userMCPAuthMap != null) {
|
||||
Object.assign(userMCPAuthMap, addedConfig.userMCPAuthMap ?? {});
|
||||
} else {
|
||||
userMCPAuthMap = addedConfig.userMCPAuthMap;
|
||||
}
|
||||
|
||||
const addedAgentId = addedConfig.id || ADDED_AGENT_ID;
|
||||
agentConfigs.set(addedAgentId, addedConfig);
|
||||
|
||||
// No edges needed - agent without incoming edges becomes a start node
|
||||
// and runs in parallel with the primary agent automatically.
|
||||
// This is independent of any edges/agent_ids the primary agent has.
|
||||
|
||||
logger.debug(
|
||||
`[processAddedConvo] Added parallel agent: ${addedAgentId} (primary: ${primaryAgentId}, ` +
|
||||
`primary has edges: ${!!endpointOption.edges}, primary has agent_ids: ${!!endpointOption.agent_ids})`,
|
||||
);
|
||||
|
||||
return { userMCPAuthMap };
|
||||
} catch (err) {
|
||||
logger.error('[processAddedConvo] Error processing addedConvo for parallel agent', err);
|
||||
return { userMCPAuthMap };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
processAddedConvo,
|
||||
ADDED_AGENT_ID,
|
||||
};
|
||||
|
|
@ -15,6 +15,9 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
|||
return undefined;
|
||||
});
|
||||
|
||||
/** @type {import('librechat-data-provider').TConversation | undefined} */
|
||||
const addedConvo = req.body?.addedConvo;
|
||||
|
||||
return removeNullishValues({
|
||||
spec,
|
||||
iconURL,
|
||||
|
|
@ -23,6 +26,7 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
|||
endpointType,
|
||||
model_parameters,
|
||||
agent: agentPromise,
|
||||
addedConvo,
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ const {
|
|||
createSequentialChainEdges,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
getResponseSender,
|
||||
isEphemeralAgentId,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
createToolEndCallback,
|
||||
|
|
@ -20,14 +20,17 @@ const { getModelsConfig } = require('~/server/controllers/ModelController');
|
|||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { processAddedConvo } = require('./addedConvo');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { logViolation } = require('~/cache');
|
||||
const db = require('~/models');
|
||||
|
||||
/**
|
||||
* @param {AbortSignal} signal
|
||||
* Creates a tool loader function for the agent.
|
||||
* @param {AbortSignal} signal - The abort signal
|
||||
* @param {string | null} [streamId] - The stream ID for resumable mode
|
||||
*/
|
||||
function createToolLoader(signal) {
|
||||
function createToolLoader(signal, streamId = null) {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
|
|
@ -52,6 +55,7 @@ function createToolLoader(signal) {
|
|||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
streamId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading tools for agent ' + agentId, error);
|
||||
|
|
@ -65,18 +69,21 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
}
|
||||
const appConfig = req.config;
|
||||
|
||||
// TODO: use endpointOption to determine options/modelOptions
|
||||
/** @type {string | null} */
|
||||
const streamId = req._resumableStreamId || null;
|
||||
|
||||
/** @type {Array<UsageMetadata>} */
|
||||
const collectedUsage = [];
|
||||
/** @type {ArtifactPromises} */
|
||||
const artifactPromises = [];
|
||||
const { contentParts, aggregateContent } = createContentAggregator();
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId });
|
||||
const eventHandlers = getDefaultHandlers({
|
||||
res,
|
||||
aggregateContent,
|
||||
toolEndCallback,
|
||||
collectedUsage,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (!endpointOption.agent) {
|
||||
|
|
@ -105,7 +112,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
const agentConfigs = new Map();
|
||||
const allowedProviders = new Set(appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
const loadTools = createToolLoader(signal);
|
||||
const loadTools = createToolLoader(signal, streamId);
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
/** @type {string} */
|
||||
|
|
@ -227,6 +234,33 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
edges = edges ? edges.concat(chain) : chain;
|
||||
}
|
||||
|
||||
/** Multi-Convo: Process addedConvo for parallel agent execution */
|
||||
const { userMCPAuthMap: updatedMCPAuthMap } = await processAddedConvo({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
modelsConfig,
|
||||
logViolation,
|
||||
loadTools,
|
||||
requestFiles,
|
||||
conversationId,
|
||||
allowedProviders,
|
||||
agentConfigs,
|
||||
primaryAgentId: primaryConfig.id,
|
||||
primaryAgent,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
if (updatedMCPAuthMap) {
|
||||
userMCPAuthMap = updatedMCPAuthMap;
|
||||
}
|
||||
|
||||
// Ensure edges is an array when we have multiple agents (multi-agent mode)
|
||||
// MultiAgentGraph.categorizeEdges requires edges to be iterable
|
||||
if (agentConfigs.size > 0 && !edges) {
|
||||
edges = [];
|
||||
}
|
||||
|
||||
primaryConfig.edges = edges;
|
||||
|
||||
let endpointConfig = appConfig.endpoints?.[primaryConfig.endpoint];
|
||||
|
|
@ -270,10 +304,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
|||
endpointType: endpointOption.endpointType,
|
||||
resendFiles: primaryConfig.resendFiles ?? true,
|
||||
maxContextTokens: primaryConfig.maxContextTokens,
|
||||
endpoint:
|
||||
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
|
||||
? primaryConfig.endpoint
|
||||
: EModelEndpoint.agents,
|
||||
endpoint: isEphemeralAgentId(primaryConfig.id) ? primaryConfig.endpoint : EModelEndpoint.agents,
|
||||
});
|
||||
|
||||
return { client, userMCPAuthMap };
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
|||
const { getAssistant } = require('~/models/Assistant');
|
||||
|
||||
const buildOptions = async (endpoint, parsedBody) => {
|
||||
|
||||
const { promptPrefix, assistant_id, iconURL, greeting, spec, artifacts, ...modelOptions } =
|
||||
parsedBody;
|
||||
const endpointOption = removeNullishValues({
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ const {
|
|||
const {
|
||||
sendEvent,
|
||||
MCPOAuthHandler,
|
||||
isMCPDomainAllowed,
|
||||
normalizeServerName,
|
||||
convertWithResolvedRefs,
|
||||
GenerationJobManager,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
|
|
@ -21,13 +23,14 @@ const {
|
|||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getMCPManager,
|
||||
getFlowStateManager,
|
||||
getOAuthReconnectionManager,
|
||||
getMCPServersRegistry,
|
||||
getFlowStateManager,
|
||||
getMCPManager,
|
||||
} = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
|
|
@ -35,8 +38,9 @@ const { getLogStores } = require('~/cache');
|
|||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
||||
function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) {
|
||||
/**
|
||||
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
|
||||
* @returns {void}
|
||||
|
|
@ -52,7 +56,12 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
|||
expires_at: Date.now() + Time.TWO_MINUTES,
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -63,8 +72,9 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
|
|||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {number} [params.index]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
||||
function createRunStepEmitter({ res, runId, stepId, toolCall, index, streamId = null }) {
|
||||
return function () {
|
||||
/** @type {import('@librechat/agents').RunStep} */
|
||||
const data = {
|
||||
|
|
@ -77,7 +87,12 @@ function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
|
|||
tool_calls: [toolCall],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -108,10 +123,9 @@ function createOAuthStart({ flowId, flowManager, callback }) {
|
|||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.stepId - The ID of the step in the flow.
|
||||
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
|
||||
* @param {string} params.loginFlowId - The ID of the login flow.
|
||||
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
*/
|
||||
function createOAuthEnd({ res, stepId, toolCall }) {
|
||||
function createOAuthEnd({ res, stepId, toolCall, streamId = null }) {
|
||||
return async function () {
|
||||
/** @type {{ id: string; delta: AgentToolCallDelta }} */
|
||||
const data = {
|
||||
|
|
@ -121,7 +135,12 @@ function createOAuthEnd({ res, stepId, toolCall }) {
|
|||
tool_calls: [{ ...toolCall }],
|
||||
},
|
||||
};
|
||||
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
|
||||
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, eventData);
|
||||
} else {
|
||||
sendEvent(res, eventData);
|
||||
}
|
||||
logger.debug('Sent OAuth login success to client');
|
||||
};
|
||||
}
|
||||
|
|
@ -137,7 +156,9 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
|
|||
return function () {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
// Clean up both mcp_oauth and mcp_get_tokens flows
|
||||
flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted'));
|
||||
flowManager.failFlow(flowId, 'mcp_get_tokens', new Error('Tool call aborted'));
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -162,10 +183,19 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
|||
* @param {AbortSignal} params.signal
|
||||
* @param {string} params.model
|
||||
* @param {number} [params.index]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
|
||||
async function reconnectServer({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||
const flowId = `${user.id}:${serverName}:${Date.now()}`;
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
|
|
@ -176,36 +206,60 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
type: 'tool_call_chunk',
|
||||
};
|
||||
|
||||
const runStepEmitter = createRunStepEmitter({
|
||||
res,
|
||||
index,
|
||||
runId,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
});
|
||||
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
flowId,
|
||||
callback,
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
user,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
// Set up abort handler to clean up OAuth flows if request is aborted
|
||||
const oauthFlowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
const abortHandler = () => {
|
||||
logger.info(
|
||||
`[MCP][User: ${user.id}][${serverName}] Tool loading aborted, cleaning up OAuth flows`,
|
||||
);
|
||||
// Clean up both mcp_oauth and mcp_get_tokens flows
|
||||
flowManager.failFlow(oauthFlowId, 'mcp_oauth', new Error('Tool loading aborted'));
|
||||
flowManager.failFlow(oauthFlowId, 'mcp_get_tokens', new Error('Tool loading aborted'));
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
|
||||
try {
|
||||
const runStepEmitter = createRunStepEmitter({
|
||||
res,
|
||||
index,
|
||||
runId,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const runStepDeltaEmitter = createRunStepDeltaEmitter({
|
||||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
|
||||
const oauthStart = createOAuthStart({
|
||||
res,
|
||||
flowId,
|
||||
callback,
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
user,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
flowManager,
|
||||
userMCPAuthMap,
|
||||
forceNew: true,
|
||||
returnOnOAuth: false,
|
||||
connectionTimeout: Time.TWO_MINUTES,
|
||||
});
|
||||
} finally {
|
||||
// Clean up abort handler to prevent memory leaks
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -222,11 +276,45 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
|
|||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
|
||||
async function createMCPTools({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
config,
|
||||
provider,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId = null,
|
||||
}) {
|
||||
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
const result = await reconnectServer({
|
||||
res,
|
||||
user,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
return;
|
||||
|
|
@ -239,8 +327,10 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
user,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
availableTools: result.availableTools,
|
||||
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
|
||||
config: serverConfig,
|
||||
});
|
||||
if (toolInstance) {
|
||||
serverTools.push(toolInstance);
|
||||
|
|
@ -259,9 +349,11 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
|
|||
* @param {string} params.model - The model for the tool.
|
||||
* @param {number} [params.index]
|
||||
* @param {AbortSignal} [params.signal]
|
||||
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
|
||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
* @param {LCAvailableTools} [params.availableTools]
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({
|
||||
|
|
@ -273,9 +365,25 @@ async function createMCPTool({
|
|||
provider,
|
||||
userMCPAuthMap,
|
||||
availableTools,
|
||||
config,
|
||||
streamId = null,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
// Runtime domain validation: check if the server's domain is still allowed
|
||||
// Use getAppConfig() to support per-user/role domain restrictions
|
||||
const serverConfig =
|
||||
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
|
||||
if (serverConfig?.url) {
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
|
||||
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
|
||||
if (!isDomainAllowed) {
|
||||
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
|
|
@ -289,6 +397,7 @@ async function createMCPTool({
|
|||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
streamId,
|
||||
});
|
||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||
}
|
||||
|
|
@ -304,10 +413,18 @@ async function createMCPTool({
|
|||
toolName,
|
||||
serverName,
|
||||
toolDefinition,
|
||||
streamId,
|
||||
});
|
||||
}
|
||||
|
||||
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
|
||||
function createToolInstance({
|
||||
res,
|
||||
toolName,
|
||||
serverName,
|
||||
toolDefinition,
|
||||
provider: _provider,
|
||||
streamId = null,
|
||||
}) {
|
||||
/** @type {LCTool} */
|
||||
const { description, parameters } = toolDefinition;
|
||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||
|
|
@ -343,6 +460,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
|
|||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
const oauthStart = createOAuthStart({
|
||||
flowId,
|
||||
|
|
@ -353,6 +471,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
|
|||
res,
|
||||
stepId,
|
||||
toolCall,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (derivedSignal) {
|
||||
|
|
@ -448,7 +567,10 @@ async function getMCPSetupData(userId) {
|
|||
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
|
||||
let appConnections = new Map();
|
||||
try {
|
||||
appConnections = (await mcpManager.appConnections?.getAll()) || new Map();
|
||||
// Use getLoaded() instead of getAll() to avoid forcing connection creation
|
||||
// getAll() creates connections for all servers, which is problematic for servers
|
||||
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders)
|
||||
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map();
|
||||
} catch (error) {
|
||||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,4 @@
|
|||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
// Mock all dependencies - define mocks before imports
|
||||
// Mock all dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
|
|
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
|
|||
},
|
||||
}));
|
||||
|
||||
// Create mock registry instance
|
||||
const mockRegistryInstance = {
|
||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
|
||||
getServerConfig: jest.fn(() => Promise.resolve(null)),
|
||||
};
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
}));
|
||||
// Create isMCPDomainAllowed mock that can be configured per-test
|
||||
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
|
||||
|
||||
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
|
||||
|
||||
jest.mock('@librechat/api', () => {
|
||||
// Access mock via getter to avoid hoisting issues
|
||||
return {
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
get isMCPDomainAllowed() {
|
||||
return mockIsMCPDomainAllowed;
|
||||
},
|
||||
MCPServersRegistry: {
|
||||
getInstance: () => mockRegistryInstance,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const {
|
||||
createMCPTool,
|
||||
createMCPTools,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
} = require('./MCP');
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
CacheKeys: {
|
||||
|
|
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
|
|||
|
||||
jest.mock('./Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
getAppConfig: jest.fn(),
|
||||
get getAppConfig() {
|
||||
return mockGetAppConfig;
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
|
|
@ -128,7 +144,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
|
||||
beforeEach(() => {
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
appConnections: { getAll: jest.fn(() => new Map()) },
|
||||
appConnections: { getLoaded: jest.fn(() => new Map()) },
|
||||
getUserConnections: jest.fn(() => new Map()),
|
||||
});
|
||||
mockRegistryInstance.getOAuthServers.mockResolvedValue(new Set());
|
||||
|
|
@ -143,7 +159,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
const mockOAuthServers = new Set(['server2']);
|
||||
|
||||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => Promise.resolve(mockAppConnections)) },
|
||||
appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
|
@ -153,7 +169,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
|
||||
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
|
||||
expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId);
|
||||
|
||||
|
|
@ -174,7 +190,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
|||
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
|
||||
|
||||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => Promise.resolve(null)) },
|
||||
appConnections: { getLoaded: jest.fn(() => Promise.resolve(null)) },
|
||||
getUserConnections: jest.fn(() => null),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
|
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
|
|||
createFlowWithHandler: jest.fn(),
|
||||
failFlow: jest.fn(),
|
||||
});
|
||||
|
||||
// Reset domain validation mock to default (allow all)
|
||||
mockIsMCPDomainAllowed.mockReset();
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
// Reset registry mocks
|
||||
mockRegistryInstance.getServerConfig.mockReset();
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
|
||||
|
||||
// Reset getAppConfig mock to default (no restrictions)
|
||||
mockGetAppConfig.mockReset();
|
||||
mockGetAppConfig.mockResolvedValue({});
|
||||
});
|
||||
|
||||
describe('createMCPTools', () => {
|
||||
|
|
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('Runtime domain validation', () => {
|
||||
it('should skip tool creation when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://disallowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools: {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Should return undefined for disallowed domain
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
|
||||
// Verify domain validation was called with correct parameters
|
||||
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
|
||||
{ url: 'https://disallowed-domain.com/sse' },
|
||||
['allowed-domain.com'],
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow tool creation when domain is allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'admin' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://allowed-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return true (domain allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
|
||||
});
|
||||
|
||||
it('should skip domain validation for stdio transports (no URL)', async () => {
|
||||
const mockUser = { id: 'stdio-test-user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config without URL (stdio transport)
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
command: 'npx',
|
||||
args: ['@modelcontextprotocol/server'],
|
||||
});
|
||||
|
||||
// Mock getAppConfig (should not be called for stdio)
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
|
||||
});
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = await createMCPTool({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Should create tool successfully without domain check
|
||||
expect(result).toBeDefined();
|
||||
|
||||
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
|
||||
expect(mockGetAppConfig).not.toHaveBeenCalled();
|
||||
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return empty array from createMCPTools when domain is not allowed', async () => {
|
||||
const mockUser = { id: 'domain-test-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
// Mock server config with URL (remote server)
|
||||
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
|
||||
|
||||
// Mock getAppConfig to return domain restrictions
|
||||
mockGetAppConfig.mockResolvedValue({
|
||||
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
|
||||
});
|
||||
|
||||
// Mock domain validation to return false (domain not allowed)
|
||||
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
|
||||
|
||||
const result = await createMCPTools({
|
||||
res: mockRes,
|
||||
user: mockUser,
|
||||
serverName: 'test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
config: serverConfig,
|
||||
});
|
||||
|
||||
// Should return empty array for disallowed domain
|
||||
expect(result).toEqual([]);
|
||||
|
||||
// Should not call reinitMCPServer since domain check failed early
|
||||
expect(mockReinitMCPServer).not.toHaveBeenCalled();
|
||||
|
||||
// Verify getAppConfig was called with user role
|
||||
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
|
||||
});
|
||||
|
||||
it('should use user role when fetching domain restrictions', async () => {
|
||||
const adminUser = { id: 'admin-user', role: 'admin' };
|
||||
const regularUser = { id: 'regular-user', role: 'user' };
|
||||
const mockRes = { write: jest.fn(), flush: jest.fn() };
|
||||
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
// Mock different responses based on role
|
||||
mockGetAppConfig
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
|
||||
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
|
||||
|
||||
mockIsMCPDomainAllowed.mockResolvedValue(true);
|
||||
|
||||
const availableTools = {
|
||||
'test-tool::test-server': {
|
||||
function: {
|
||||
description: 'Test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Call with admin user
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: adminUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Reset and call with regular user
|
||||
mockRegistryInstance.getServerConfig.mockResolvedValue({
|
||||
url: 'https://some-domain.com/sse',
|
||||
});
|
||||
|
||||
await createMCPTool({
|
||||
res: mockRes,
|
||||
user: regularUser,
|
||||
toolKey: 'test-tool::test-server',
|
||||
provider: 'openai',
|
||||
userMCPAuthMap: {},
|
||||
availableTools,
|
||||
});
|
||||
|
||||
// Verify getAppConfig was called with correct roles
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
|
||||
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('User parameter integrity', () => {
|
||||
it('should preserve user object properties through the call chain', async () => {
|
||||
const complexUser = {
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ describe('processMessages', () => {
|
|||
type: 'text',
|
||||
text: {
|
||||
value:
|
||||
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
|
||||
"The text you have uploaded is from the book \"Harry Potter and the Philosopher's Stone\" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.",
|
||||
annotations: [
|
||||
{
|
||||
type: 'file_citation',
|
||||
|
|
@ -424,7 +424,7 @@ These points highlight Harry's initial experiences in the magical world and set
|
|||
type: 'text',
|
||||
text: {
|
||||
value:
|
||||
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
|
||||
"The text you have uploaded is from the book \"Harry Potter and the Philosopher's Stone\" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.",
|
||||
annotations: [
|
||||
{
|
||||
type: 'file_citation',
|
||||
|
|
@ -582,7 +582,7 @@ These points highlight Harry's initial experiences in the magical world and set
|
|||
type: 'text',
|
||||
text: {
|
||||
value:
|
||||
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation【11:2†source】.',
|
||||
"This is a test ^1^ with pre-existing citation-like text. Here's a real citation【11:2†source】.",
|
||||
annotations: [
|
||||
{
|
||||
type: 'file_citation',
|
||||
|
|
@ -610,7 +610,7 @@ These points highlight Harry's initial experiences in the magical world and set
|
|||
});
|
||||
|
||||
const expectedText =
|
||||
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation^1^.\n\n^1.^ test.txt';
|
||||
"This is a test ^1^ with pre-existing citation-like text. Here's a real citation^1^.\n\n^1.^ test.txt";
|
||||
|
||||
expect(result.text).toBe(expectedText);
|
||||
expect(result.edited).toBe(true);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ const {
|
|||
} = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
ErrorTypes,
|
||||
ContentTypes,
|
||||
imageGenTools,
|
||||
|
|
@ -18,6 +17,7 @@ const {
|
|||
ImageVisionTool,
|
||||
openapiToFunction,
|
||||
AgentCapabilities,
|
||||
isEphemeralAgentId,
|
||||
validateActionDomain,
|
||||
defaultAgentCapabilities,
|
||||
validateAndParseOpenAPISpec,
|
||||
|
|
@ -369,7 +369,15 @@ async function processRequiredActions(client, requiredActions) {
|
|||
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
|
||||
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
|
||||
*/
|
||||
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
|
||||
async function loadAgentTools({
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
openAIApiKey,
|
||||
streamId = null,
|
||||
}) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return {};
|
||||
} else if (
|
||||
|
|
@ -385,7 +393,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
|
||||
/** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */
|
||||
if (enabledCapabilities.size === 0 && agent.id === Constants.EPHEMERAL_AGENT_ID) {
|
||||
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
|
||||
enabledCapabilities = new Set(
|
||||
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
|
||||
);
|
||||
|
|
@ -422,7 +430,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
/** @type {ReturnType<typeof createOnSearchResults>} */
|
||||
let webSearchCallbacks;
|
||||
if (includesWebSearch) {
|
||||
webSearchCallbacks = createOnSearchResults(res);
|
||||
webSearchCallbacks = createOnSearchResults(res, streamId);
|
||||
}
|
||||
|
||||
/** @type {Record<string, Record<string, string>>} */
|
||||
|
|
@ -622,6 +630,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
|
|||
encrypted,
|
||||
name: toolName,
|
||||
description: functionSig.description,
|
||||
streamId,
|
||||
});
|
||||
|
||||
if (!tool) {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,29 @@
|
|||
const { nanoid } = require('nanoid');
|
||||
const { Tools } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { GenerationJobManager } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Helper to write attachment events either to res or to job emitter.
|
||||
* @param {import('http').ServerResponse} res - The server response object
|
||||
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
|
||||
* @param {Object} attachment - The attachment data
|
||||
*/
|
||||
function writeAttachment(res, streamId, attachment) {
|
||||
if (streamId) {
|
||||
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
|
||||
} else {
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a function to handle search results and stream them as attachments
|
||||
* @param {import('http').ServerResponse} res - The HTTP server response object
|
||||
* @param {string | null} [streamId] - The stream ID for resumable mode, or null for standard mode
|
||||
* @returns {{ onSearchResults: function(SearchResult, GraphRunnableConfig): void; onGetHighlights: function(string): void}} - Function that takes search results and returns or streams an attachment
|
||||
*/
|
||||
function createOnSearchResults(res) {
|
||||
function createOnSearchResults(res, streamId = null) {
|
||||
const context = {
|
||||
sourceMap: new Map(),
|
||||
searchResultData: undefined,
|
||||
|
|
@ -70,7 +86,7 @@ function createOnSearchResults(res) {
|
|||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -92,7 +108,7 @@ function createOnSearchResults(res) {
|
|||
}
|
||||
|
||||
const attachment = buildAttachment(context);
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
writeAttachment(res, streamId, attachment);
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ async function initializeMCPs() {
|
|||
}
|
||||
|
||||
// Initialize MCPServersRegistry first (required for MCPManager)
|
||||
// Pass allowedDomains from mcpSettings for domain validation
|
||||
try {
|
||||
createMCPServersRegistry(mongoose);
|
||||
createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains);
|
||||
} catch (error) {
|
||||
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
|
||||
throw error;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue