Merge branch 'dev' into feat/context-window-ui

This commit is contained in:
Marco Beretta 2025-12-29 02:07:54 +01:00
commit cb8322ca85
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
407 changed files with 25479 additions and 19894 deletions

View file

@ -350,9 +350,6 @@ function disposeClient(client) {
if (client.agentConfigs) {
client.agentConfigs = null;
}
if (client.agentIdMap) {
client.agentIdMap = null;
}
if (client.artifactPromises) {
client.artifactPromises = null;
}

View file

@ -10,7 +10,13 @@ const {
setAuthTokens,
registerUser,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const {
deleteAllUserSessions,
getUserById,
findSession,
updateUser,
findUser,
} = require('~/models');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
@ -72,16 +78,38 @@ const refreshController = async (req, res) => {
const openIdConfig = getOpenIdConfig();
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
const claims = tokenset.claims();
const { user, error } = await findOpenIDUser({
const { user, error, migration } = await findOpenIDUser({
findUser,
email: claims.email,
openidId: claims.sub,
idOnTheSource: claims.oid,
strategyName: 'refreshController',
});
logger.debug(
`[refreshController] findOpenIDUser result: user=${user?.email ?? 'null'}, error=${error ?? 'null'}, migration=${migration}, userOpenidId=${user?.openidId ?? 'null'}, claimsSub=${claims.sub}`,
);
if (error || !user) {
logger.warn(
`[refreshController] Redirecting to /login: error=${error ?? 'null'}, user=${user ? 'exists' : 'null'}`,
);
return res.status(401).redirect('/login');
}
// Handle migration: update user with openidId if found by email without openidId
// Also handle case where user has mismatched openidId (e.g., after database switch)
if (migration || user.openidId !== claims.sub) {
const reason = migration ? 'migration' : 'openidId mismatch';
await updateUser(user._id.toString(), {
provider: 'openid',
openidId: claims.sub,
});
logger.info(
`[refreshController] Updated user ${user.email} openidId (${reason}): ${user.openidId ?? 'null'} -> ${claims.sub}`,
);
}
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString(), refreshToken);
user.federatedTokens = {

View file

@ -1,5 +1,5 @@
const { nanoid } = require('nanoid');
const { sendEvent } = require('@librechat/api');
const { sendEvent, GenerationJobManager } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider');
const {
@ -144,17 +144,38 @@ function checkIfLastAgent(last_agent_id, langgraph_node) {
return langgraph_node?.endsWith(last_agent_id);
}
/**
* Helper to emit events either to res (standard mode) or to job emitter (resumable mode).
* @param {ServerResponse} res - The server response object
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
* @param {Object} eventData - The event data to send
*/
function emitEvent(res, streamId, eventData) {
if (streamId) {
GenerationJobManager.emitChunk(streamId, eventData);
} else {
sendEvent(res, eventData);
}
}
/**
* Get default handlers for stream events.
* @param {Object} options - The options object.
* @param {ServerResponse} options.res - The options object.
* @param {ContentAggregator} options.aggregateContent - The options object.
* @param {ServerResponse} options.res - The server response object.
* @param {ContentAggregator} options.aggregateContent - Content aggregator function.
* @param {ToolEndCallback} options.toolEndCallback - Callback to use when tool ends.
* @param {Array<UsageMetadata>} options.collectedUsage - The list of collected usage metadata.
* @param {string | null} [options.streamId] - The stream ID for resumable mode, or null for standard mode.
* @returns {Record<string, t.EventHandler>} The default handlers.
* @throws {Error} If the request is not found.
*/
function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedUsage }) {
function getDefaultHandlers({
res,
aggregateContent,
toolEndCallback,
collectedUsage,
streamId = null,
}) {
if (!res || !aggregateContent) {
throw new Error(
`[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`,
@ -173,16 +194,16 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
*/
handle: (event, data, metadata) => {
if (data?.stepDetails.type === StepTypes.TOOL_CALLS) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else {
const agentName = metadata?.name ?? 'Agent';
const isToolCall = data?.stepDetails.type === StepTypes.TOOL_CALLS;
const action = isToolCall ? 'performing a task...' : 'thinking...';
sendEvent(res, {
emitEvent(res, streamId, {
event: 'on_agent_update',
data: {
runId: metadata?.run_id,
@ -202,11 +223,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
*/
handle: (event, data, metadata) => {
if (data?.delta.type === StepTypes.TOOL_CALLS) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
@ -220,11 +241,11 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
*/
handle: (event, data, metadata) => {
if (data?.result != null) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
@ -238,9 +259,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
*/
handle: (event, data, metadata) => {
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
@ -254,9 +275,9 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
*/
handle: (event, data, metadata) => {
if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
} else if (!metadata?.hide_sequential_outputs) {
sendEvent(res, { event, data });
emitEvent(res, streamId, { event, data });
}
aggregateContent({ event, data });
},
@ -266,15 +287,30 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
return handlers;
}
/**
* Helper to write attachment events either to res or to job emitter.
* @param {ServerResponse} res - The server response object
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
* @param {Object} attachment - The attachment data
*/
function writeAttachment(res, streamId, attachment) {
if (streamId) {
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
} else {
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
}
}
/**
*
* @param {Object} params
* @param {ServerRequest} params.req
* @param {ServerResponse} params.res
* @param {Promise<MongoFile | { filename: string; filepath: string; expires: number;} | null>[]} params.artifactPromises
* @param {string | null} [params.streamId] - The stream ID for resumable mode, or null for standard mode.
* @returns {ToolEndCallback} The tool end callback.
*/
function createToolEndCallback({ req, res, artifactPromises }) {
function createToolEndCallback({ req, res, artifactPromises, streamId = null }) {
/**
* @type {ToolEndCallback}
*/
@ -302,10 +338,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
if (!attachment) {
return null;
}
if (!res.headersSent) {
if (!streamId && !res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
writeAttachment(res, streamId, attachment);
return attachment;
})().catch((error) => {
logger.error('Error processing file citations:', error);
@ -314,8 +350,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
);
}
// TODO: a lot of duplicated code in createToolEndCallback
// we should refactor this to use a helper function in a follow-up PR
if (output.artifact[Tools.ui_resources]) {
artifactPromises.push(
(async () => {
@ -326,10 +360,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
conversationId: metadata.thread_id,
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
};
if (!res.headersSent) {
if (!streamId && !res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
writeAttachment(res, streamId, attachment);
return attachment;
})().catch((error) => {
logger.error('Error processing artifact content:', error);
@ -348,10 +382,10 @@ function createToolEndCallback({ req, res, artifactPromises }) {
conversationId: metadata.thread_id,
[Tools.web_search]: { ...output.artifact[Tools.web_search] },
};
if (!res.headersSent) {
if (!streamId && !res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
writeAttachment(res, streamId, attachment);
return attachment;
})().catch((error) => {
logger.error('Error processing artifact content:', error);
@ -388,7 +422,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
toolCallId: output.tool_call_id,
conversationId: metadata.thread_id,
});
if (!res.headersSent) {
if (!streamId && !res.headersSent) {
return fileMetadata;
}
@ -396,7 +430,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
return null;
}
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
writeAttachment(res, streamId, fileMetadata);
return fileMetadata;
})().catch((error) => {
logger.error('Error processing artifact content:', error);
@ -435,7 +469,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
conversationId: metadata.thread_id,
session_id: output.artifact.session_id,
});
if (!res.headersSent) {
if (!streamId && !res.headersSent) {
return fileMetadata;
}
@ -443,7 +477,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
return null;
}
res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
writeAttachment(res, streamId, fileMetadata);
return fileMetadata;
})().catch((error) => {
logger.error('Error processing code output:', error);

View file

@ -14,6 +14,7 @@ const {
getBalanceConfig,
getProviderConfig,
memoryInstructions,
GenerationJobManager,
getTransactionsConfig,
createMemoryProcessor,
filterMalformedContentParts,
@ -36,14 +37,13 @@ const {
EModelEndpoint,
PermissionTypes,
isAgentsEndpoint,
AgentCapabilities,
isEphemeralAgentId,
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { createContextHandlers } = require('~/app/clients/prompts');
const { checkCapability } = require('~/server/services/Config');
const { getConvoFiles } = require('~/models/Conversation');
const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
@ -95,59 +95,101 @@ function logToolError(graph, error, toolId) {
});
}
/** Regex pattern to match agent ID suffix (____N) */
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
/**
* Applies agent labeling to conversation history when multi-agent patterns are detected.
* Labels content parts by their originating agent to prevent identity confusion.
* Creates a mapMethod for getMessagesForConversation that processes agent content.
* - Strips agentId/groupId metadata from all content
* - For multi-agent: filters to primary agent content only (no suffix or lowest suffix)
* - For multi-agent: applies agent labels to content
*
* @param {TMessage[]} orderedMessages - The ordered conversation messages
* @param {Agent} primaryAgent - The primary agent configuration
* @param {Map<string, Agent>} agentConfigs - Map of additional agent configurations
* @returns {TMessage[]} Messages with agent labels applied where appropriate
* @param {Agent} primaryAgent - Primary agent configuration
* @param {Map<string, Agent>} [agentConfigs] - Additional agent configurations
* @returns {(message: TMessage) => TMessage} Map method for processing messages
*/
function applyAgentLabelsToHistory(orderedMessages, primaryAgent, agentConfigs) {
const shouldLabelByAgent = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
if (!shouldLabelByAgent) {
return orderedMessages;
}
const processedMessages = [];
for (let i = 0; i < orderedMessages.length; i++) {
const message = orderedMessages[i];
/** @type {Record<string, string>} */
const agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
function createMultiAgentMapper(primaryAgent, agentConfigs) {
const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
/** @type {Record<string, string> | null} */
let agentNames = null;
if (hasMultipleAgents) {
agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
if (agentConfigs) {
for (const [agentId, agentConfig] of agentConfigs.entries()) {
agentNames[agentId] = agentConfig.name || agentConfig.id;
}
}
if (
!message.isCreatedByUser &&
message.metadata?.agentIdMap &&
Array.isArray(message.content)
) {
try {
const labeledContent = labelContentByAgent(
message.content,
message.metadata.agentIdMap,
agentNames,
);
processedMessages.push({ ...message, content: labeledContent });
} catch (error) {
logger.error('[AgentClient] Error applying agent labels to message:', error);
processedMessages.push(message);
}
} else {
processedMessages.push(message);
}
}
return processedMessages;
return (message) => {
if (message.isCreatedByUser || !Array.isArray(message.content)) {
return message;
}
// Find primary agent ID (no suffix, or lowest suffix number) - only needed for multi-agent
let primaryAgentId = null;
let hasAgentMetadata = false;
if (hasMultipleAgents) {
let lowestSuffixIndex = Infinity;
for (const part of message.content) {
const agentId = part?.agentId;
if (!agentId) {
continue;
}
hasAgentMetadata = true;
const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN);
if (!suffixMatch) {
primaryAgentId = agentId;
break;
}
const suffixIndex = parseInt(suffixMatch[1], 10);
if (suffixIndex < lowestSuffixIndex) {
lowestSuffixIndex = suffixIndex;
primaryAgentId = agentId;
}
}
} else {
// Single agent: just check if any metadata exists
hasAgentMetadata = message.content.some((part) => part?.agentId || part?.groupId);
}
if (!hasAgentMetadata) {
return message;
}
try {
/** @type {Array<TMessageContentParts>} */
const filteredContent = [];
/** @type {Record<number, string>} */
const agentIdMap = {};
for (const part of message.content) {
const agentId = part?.agentId;
// For single agent: include all parts; for multi-agent: filter to primary
if (!hasMultipleAgents || !agentId || agentId === primaryAgentId) {
const newIndex = filteredContent.length;
const { agentId: _a, groupId: _g, ...cleanPart } = part;
filteredContent.push(cleanPart);
if (agentId && hasMultipleAgents) {
agentIdMap[newIndex] = agentId;
}
}
}
const finalContent =
Object.keys(agentIdMap).length > 0 && agentNames
? labelContentByAgent(filteredContent, agentIdMap, agentNames)
: filteredContent;
return { ...message, content: finalContent };
} catch (error) {
logger.error('[AgentClient] Error processing multi-agent message:', error);
return message;
}
};
}
class AgentClient extends BaseClient {
@ -199,8 +241,6 @@ class AgentClient extends BaseClient {
this.indexTokenCountMap = {};
/** @type {(messages: BaseMessage[]) => Promise<void>} */
this.processMemory;
/** @type {Record<number, string> | null} */
this.agentIdMap = null;
}
/**
@ -289,18 +329,13 @@ class AgentClient extends BaseClient {
{ instructions = null, additional_instructions = null },
opts,
) {
let orderedMessages = this.constructor.getMessagesForConversation({
const orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
summary: this.shouldSummarize,
mapMethod: createMultiAgentMapper(this.options.agent, this.agentConfigs),
});
orderedMessages = applyAgentLabelsToHistory(
orderedMessages,
this.options.agent,
this.agentConfigs,
);
let payload;
/** @type {number | undefined} */
let promptTokens;
@ -552,10 +587,9 @@ class AgentClient extends BaseClient {
agent: prelimAgent,
allowedProviders,
endpointOption: {
endpoint:
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
? EModelEndpoint.agents
: memoryConfig.agent?.provider,
endpoint: !isEphemeralAgentId(prelimAgent.id)
? EModelEndpoint.agents
: memoryConfig.agent?.provider,
},
},
{
@ -595,10 +629,12 @@ class AgentClient extends BaseClient {
const userId = this.options.req.user.id + '';
const messageId = this.responseMessageId + '';
const conversationId = this.conversationId + '';
const streamId = this.options.req?._resumableStreamId || null;
const [withoutKeys, processMemory] = await createMemoryProcessor({
userId,
config,
messageId,
streamId,
conversationId,
memoryMethods: {
setMemory: db.setMemory,
@ -692,9 +728,7 @@ class AgentClient extends BaseClient {
});
const completion = filterMalformedContentParts(this.contentParts);
const metadata = this.agentIdMap ? { agentIdMap: this.agentIdMap } : undefined;
return { completion, metadata };
return { completion };
}
/**
@ -890,12 +924,10 @@ class AgentClient extends BaseClient {
*/
const runAgents = async (messages) => {
const agents = [this.options.agent];
if (
this.agentConfigs &&
this.agentConfigs.size > 0 &&
((this.options.agent.edges?.length ?? 0) > 0 ||
(await checkCapability(this.options.req, AgentCapabilities.chain)))
) {
// Include additional agents when:
// - agentConfigs has agents (from addedConvo parallel execution or agent handoffs)
// - Agents without incoming edges become start nodes and run in parallel automatically
if (this.agentConfigs && this.agentConfigs.size > 0) {
agents.push(...this.agentConfigs.values());
}
@ -955,6 +987,12 @@ class AgentClient extends BaseClient {
}
this.run = run;
const streamId = this.options.req?._resumableStreamId;
if (streamId && run.Graph) {
GenerationJobManager.setGraph(streamId, run.Graph);
}
if (userMCPAuthMap != null) {
config.configurable.userMCPAuthMap = userMCPAuthMap;
}
@ -985,24 +1023,6 @@ class AgentClient extends BaseClient {
);
});
}
try {
/** Capture agent ID map if we have edges or multiple agents */
const shouldStoreAgentMap =
(this.options.agent.edges?.length ?? 0) > 0 || (this.agentConfigs?.size ?? 0) > 0;
if (shouldStoreAgentMap && run?.Graph) {
const contentPartAgentMap = run.Graph.getContentPartAgentMap();
if (contentPartAgentMap && contentPartAgentMap.size > 0) {
this.agentIdMap = Object.fromEntries(contentPartAgentMap);
logger.debug('[AgentClient] Captured agent ID map:', {
totalParts: this.contentParts.length,
mappedParts: Object.keys(this.agentIdMap).length,
});
}
}
} catch (error) {
logger.error('[AgentClient] Error capturing agent ID map:', error);
}
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',

View file

@ -2,14 +2,11 @@ const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
sendEvent,
GenerationJobManager,
sanitizeFileForTransmit,
sanitizeMessageForTransmit,
} = require('@librechat/api');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const { handleAbortError } = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { saveMessage } = require('~/models');
@ -31,12 +28,16 @@ function createCloseHandler(abortController) {
};
}
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
/**
* Resumable Agent Controller - Generation runs independently of HTTP connection.
* Returns streamId immediately, client subscribes separately via SSE.
*/
const ResumableAgentController = async (req, res, next, initializeClient, addTitle) => {
const {
text,
isRegenerate,
endpointOption,
conversationId,
conversationId: reqConversationId,
isContinued = false,
editedContent = null,
parentMessageId = null,
@ -44,18 +45,354 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
responseMessageId: editedResponseMessageId = null,
} = req.body;
let sender;
let abortKey;
const userId = req.user.id;
// Generate conversationId upfront if not provided - streamId === conversationId always
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
const conversationId =
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
const streamId = conversationId;
let client = null;
try {
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
req._resumableStreamId = streamId;
// Send JSON response IMMEDIATELY so client can connect to SSE stream
// This is critical: tool loading (MCP OAuth) may emit events that the client needs to receive
res.json({ streamId, conversationId, status: 'started' });
// Note: We no longer use res.on('close') to abort since we send JSON immediately.
// The response closes normally after res.json(), which is not an abort condition.
// Abort handling is done through GenerationJobManager via the SSE stream connection.
// Track if partial response was already saved to avoid duplicates
let partialResponseSaved = false;
/**
* Listen for all subscribers leaving to save partial response.
* This ensures the response is saved to DB even if all clients disconnect
* while generation continues.
*
* Note: The messageId used here falls back to `${userMessage.messageId}_` if the
* actual response messageId isn't available yet. The final response save will
* overwrite this with the complete response using the same messageId pattern.
*/
job.emitter.on('allSubscribersLeft', async (aggregatedContent) => {
if (partialResponseSaved || !aggregatedContent || aggregatedContent.length === 0) {
return;
}
const resumeState = await GenerationJobManager.getResumeState(streamId);
if (!resumeState?.userMessage) {
logger.debug('[ResumableAgentController] No user message to save partial response for');
return;
}
partialResponseSaved = true;
const responseConversationId = resumeState.conversationId || conversationId;
try {
const partialMessage = {
messageId: resumeState.responseMessageId || `${resumeState.userMessage.messageId}_`,
conversationId: responseConversationId,
parentMessageId: resumeState.userMessage.messageId,
sender: client?.sender ?? 'AI',
content: aggregatedContent,
unfinished: true,
error: false,
isCreatedByUser: false,
user: userId,
endpoint: endpointOption.endpoint,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
};
if (req.body?.agent_id) {
partialMessage.agent_id = req.body.agent_id;
}
await saveMessage(req, partialMessage, {
context: 'api/server/controllers/agents/request.js - partial response on disconnect',
});
logger.debug(
`[ResumableAgentController] Saved partial response for ${streamId}, content parts: ${aggregatedContent.length}`,
);
} catch (error) {
logger.error('[ResumableAgentController] Error saving partial response:', error);
// Reset flag so we can try again if subscribers reconnect and leave again
partialResponseSaved = false;
}
});
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
const result = await initializeClient({
req,
res,
endpointOption,
// Use the job's abort controller signal - allows abort via GenerationJobManager.abortJob()
signal: job.abortController.signal,
});
if (job.abortController.signal.aborted) {
GenerationJobManager.completeJob(streamId, 'Request aborted during initialization');
return;
}
client = result.client;
if (client?.sender) {
GenerationJobManager.updateMetadata(streamId, { sender: client.sender });
}
// Store reference to client's contentParts - graph will be set when run is created
if (client?.contentParts) {
GenerationJobManager.setContentParts(streamId, client.contentParts);
}
let userMessage;
const getReqData = (data = {}) => {
if (data.userMessage) {
userMessage = data.userMessage;
}
// conversationId is pre-generated, no need to update from callback
};
// Start background generation - readyPromise resolves immediately now
// (sync mechanism handles late subscribers)
const startGeneration = async () => {
try {
// Short timeout as safety net - promise should already be resolved
await Promise.race([job.readyPromise, new Promise((resolve) => setTimeout(resolve, 100))]);
} catch (waitError) {
logger.warn(
`[ResumableAgentController] Error waiting for subscriber: ${waitError.message}`,
);
}
try {
const onStart = (userMsg, respMsgId, _isNewConvo) => {
userMessage = userMsg;
// Store userMessage and responseMessageId upfront for resume capability
GenerationJobManager.updateMetadata(streamId, {
responseMessageId: respMsgId,
userMessage: {
messageId: userMsg.messageId,
parentMessageId: userMsg.parentMessageId,
conversationId: userMsg.conversationId,
text: userMsg.text,
},
});
GenerationJobManager.emitChunk(streamId, {
created: true,
message: userMessage,
streamId,
});
};
const messageOptions = {
user: userId,
onStart,
getReqData,
isContinued,
isRegenerate,
editedContent,
conversationId,
parentMessageId,
abortController: job.abortController,
overrideParentMessageId,
isEdited: !!editedContent,
userMCPAuthMap: result.userMCPAuthMap,
responseMessageId: editedResponseMessageId,
progressOptions: {
res: {
write: () => true,
end: () => {},
headersSent: false,
writableEnded: false,
},
},
};
const response = await client.sendMessage(text, messageOptions);
const messageId = response.messageId;
const endpoint = endpointOption.endpoint;
response.endpoint = endpoint;
const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
if (req.body.files && client.options?.attachments) {
userMessage.files = [];
const messageFiles = new Set(req.body.files.map((file) => file.file_id));
for (const attachment of client.options.attachments) {
if (messageFiles.has(attachment.file_id)) {
userMessage.files.push(sanitizeFileForTransmit(attachment));
}
}
delete userMessage.image_urls;
}
// Check abort state BEFORE calling completeJob (which triggers abort signal for cleanup)
const wasAbortedBeforeComplete = job.abortController.signal.aborted;
const isNewConvo = !reqConversationId || reqConversationId === 'new';
const shouldGenerateTitle =
addTitle &&
parentMessageId === Constants.NO_PARENT &&
isNewConvo &&
!wasAbortedBeforeComplete;
if (!wasAbortedBeforeComplete) {
const finalEvent = {
final: true,
conversation,
title: conversation.title,
requestMessage: sanitizeMessageForTransmit(userMessage),
responseMessage: { ...response },
};
GenerationJobManager.emitDone(streamId, finalEvent);
GenerationJobManager.completeJob(streamId);
if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
await saveMessage(
req,
{ ...response, user: userId },
{ context: 'api/server/controllers/agents/request.js - resumable response end' },
);
}
} else {
const finalEvent = {
final: true,
conversation,
title: conversation.title,
requestMessage: sanitizeMessageForTransmit(userMessage),
responseMessage: { ...response, error: true },
error: { message: 'Request was aborted' },
};
GenerationJobManager.emitDone(streamId, finalEvent);
GenerationJobManager.completeJob(streamId, 'Request aborted');
}
if (!client.skipSaveUserMessage && userMessage) {
await saveMessage(req, userMessage, {
context: 'api/server/controllers/agents/request.js - resumable user message',
});
}
if (shouldGenerateTitle) {
addTitle(req, {
text,
response: { ...response },
client,
})
.catch((err) => {
logger.error('[ResumableAgentController] Error in title generation', err);
})
.finally(() => {
if (client) {
disposeClient(client);
}
});
} else {
if (client) {
disposeClient(client);
}
}
} catch (error) {
// Check if this was an abort (not a real error)
const wasAborted = job.abortController.signal.aborted || error.message?.includes('abort');
if (wasAborted) {
logger.debug(`[ResumableAgentController] Generation aborted for ${streamId}`);
// abortJob already handled emitDone and completeJob
} else {
logger.error(`[ResumableAgentController] Generation error for ${streamId}:`, error);
GenerationJobManager.emitError(streamId, error.message || 'Generation failed');
GenerationJobManager.completeJob(streamId, error.message);
}
if (client) {
disposeClient(client);
}
// Don't continue to title generation after error/abort
return;
}
};
// Start generation and handle any unhandled errors
startGeneration().catch((err) => {
logger.error(
`[ResumableAgentController] Unhandled error in background generation: ${err.message}`,
);
GenerationJobManager.completeJob(streamId, err.message);
});
} catch (error) {
logger.error('[ResumableAgentController] Initialization error:', error);
if (!res.headersSent) {
res.status(500).json({ error: error.message || 'Failed to start generation' });
} else {
// JSON already sent, emit error to stream so client can receive it
GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation');
}
GenerationJobManager.completeJob(streamId, error.message);
if (client) {
disposeClient(client);
}
}
};
/**
* Agent Controller - Routes to ResumableAgentController for all requests.
* The legacy non-resumable path is kept below but no longer used by default.
*/
const AgentController = async (req, res, next, initializeClient, addTitle) => {
return ResumableAgentController(req, res, next, initializeClient, addTitle);
};
/**
* Legacy Non-resumable Agent Controller - Uses GenerationJobManager for abort handling.
* Response is streamed directly to client via res, but abort state is managed centrally.
* @deprecated Use ResumableAgentController instead
*/
const _LegacyAgentController = async (req, res, next, initializeClient, addTitle) => {
const {
text,
isRegenerate,
endpointOption,
conversationId: reqConversationId,
isContinued = false,
editedContent = null,
parentMessageId = null,
overrideParentMessageId = null,
responseMessageId: editedResponseMessageId = null,
} = req.body;
// Generate conversationId upfront if not provided - streamId === conversationId always
// Treat "new" as a placeholder that needs a real UUID (frontend may send "new" for new convos)
const conversationId =
!reqConversationId || reqConversationId === 'new' ? crypto.randomUUID() : reqConversationId;
const streamId = conversationId;
let userMessage;
let promptTokens;
let userMessageId;
let responseMessageId;
let userMessagePromise;
let getAbortData;
let client = null;
let cleanupHandlers = [];
const newConvo = !conversationId;
// Match the same logic used for conversationId generation above
const isNewConvo = !reqConversationId || reqConversationId === 'new';
const userId = req.user.id;
// Create handler to avoid capturing the entire parent scope
@ -64,24 +401,20 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
// Update job metadata with prompt tokens for abort handling
GenerationJobManager.updateMetadata(streamId, { promptTokens: data[key] });
} else if (key === 'sender') {
sender = data[key];
} else if (key === 'abortKey') {
abortKey = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
GenerationJobManager.updateMetadata(streamId, { sender: data[key] });
}
// conversationId is pre-generated, no need to update from callback
}
};
// Create a function to handle final cleanup
const performCleanup = () => {
const performCleanup = async () => {
logger.debug('[AgentController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
@ -95,10 +428,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}
}
// Clean up abort controller
if (abortKey) {
logger.debug('[AgentController] Cleaning up abort controller');
cleanupAbortController(abortKey);
// Complete the job in GenerationJobManager
if (streamId) {
logger.debug('[AgentController] Completing job in GenerationJobManager');
await GenerationJobManager.completeJob(streamId);
}
// Dispose client properly
@ -110,11 +443,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
client = null;
getReqData = null;
userMessage = null;
getAbortData = null;
endpointOption.agent = null;
endpointOption = null;
cleanupHandlers = null;
userMessagePromise = null;
// Clear request data map
if (requestDataMap.has(req)) {
@ -136,6 +465,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}
};
cleanupHandlers.push(removePrelimHandler);
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
const result = await initializeClient({
req,
@ -143,6 +473,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
endpointOption,
signal: prelimAbortController.signal,
});
if (prelimAbortController.signal?.aborted) {
prelimAbortController = null;
throw new Error('Request was aborted before initialization could complete');
@ -161,28 +492,24 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Store request data in WeakMap keyed by req object
requestDataMap.set(req, { client });
// Use WeakRef to allow GC but still access content if it exists
const contentRef = new WeakRef(client.contentParts || []);
// Create job in GenerationJobManager for abort handling
// streamId === conversationId (pre-generated above)
const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
// Minimize closure scope - only capture small primitives and WeakRef
getAbortData = () => {
// Dereference WeakRef each time
const content = contentRef.deref();
// Store endpoint metadata for abort handling
GenerationJobManager.updateMetadata(streamId, {
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
sender: client?.sender,
});
return {
sender,
content: content || [],
userMessage,
promptTokens,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
};
};
// Store content parts reference for abort
if (client?.contentParts) {
GenerationJobManager.setContentParts(streamId, client.contentParts);
}
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
const closeHandler = createCloseHandler(abortController);
const closeHandler = createCloseHandler(job.abortController);
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
@ -192,6 +519,27 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}
});
/**
* onStart callback - stores user message and response ID for abort handling
*/
const onStart = (userMsg, respMsgId, _isNewConvo) => {
sendEvent(res, { message: userMsg, created: true });
userMessage = userMsg;
userMessageId = userMsg.messageId;
responseMessageId = respMsgId;
// Store metadata for abort handling (conversationId is pre-generated)
GenerationJobManager.updateMetadata(streamId, {
responseMessageId: respMsgId,
userMessage: {
messageId: userMsg.messageId,
parentMessageId: userMsg.parentMessageId,
conversationId,
text: userMsg.text,
},
});
};
const messageOptions = {
user: userId,
onStart,
@ -201,7 +549,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
editedContent,
conversationId,
parentMessageId,
abortController,
abortController: job.abortController,
overrideParentMessageId,
isEdited: !!editedContent,
userMCPAuthMap: result.userMCPAuthMap,
@ -241,7 +589,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}
// Only send if not aborted
if (!abortController.signal.aborted) {
if (!job.abortController.signal.aborted) {
// Create a new response object with minimal copies
const finalResponse = { ...response };
@ -292,7 +640,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}
// Add title if needed - extract minimal data
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
if (addTitle && parentMessageId === Constants.NO_PARENT && isNewConvo) {
addTitle(req, {
text,
response: { ...response },
@ -315,7 +663,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Handle error without capturing much scope
handleAbortError(res, req, error, {
conversationId,
sender,
sender: client?.sender,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
userMessageId,

View file

@ -6,10 +6,54 @@
* @import { MCPServerDocument } from 'librechat-data-provider'
*/
const { logger } = require('@librechat/data-schemas');
const {
isMCPDomainNotAllowedError,
isMCPInspectionFailedError,
MCPErrorCodes,
} = require('@librechat/api');
const { Constants, MCPServerUserInputSchema } = require('librechat-data-provider');
const { cacheMCPServerTools, getMCPServerTools } = require('~/server/services/Config');
const { getMCPManager, getMCPServersRegistry } = require('~/config');
/**
* Handles MCP-specific errors and sends appropriate HTTP responses.
* @param {Error} error - The error to handle
* @param {import('express').Response} res - Express response object
* @returns {import('express').Response | null} Response if handled, null if not an MCP error
*/
function handleMCPError(error, res) {
if (isMCPDomainNotAllowedError(error)) {
return res.status(error.statusCode).json({
error: error.code,
message: error.message,
});
}
if (isMCPInspectionFailedError(error)) {
return res.status(error.statusCode).json({
error: error.code,
message: error.message,
});
}
// Fallback for legacy string-based error handling (backwards compatibility)
if (error.message?.startsWith(MCPErrorCodes.DOMAIN_NOT_ALLOWED)) {
return res.status(403).json({
error: MCPErrorCodes.DOMAIN_NOT_ALLOWED,
message: error.message.replace(/^MCP_DOMAIN_NOT_ALLOWED\s*:\s*/i, ''),
});
}
if (error.message?.startsWith(MCPErrorCodes.INSPECTION_FAILED)) {
return res.status(400).json({
error: MCPErrorCodes.INSPECTION_FAILED,
message: error.message,
});
}
return null;
}
/**
* Get all MCP tools available to the user
*/
@ -175,11 +219,9 @@ const createMCPServerController = async (req, res) => {
});
} catch (error) {
logger.error('[createMCPServer]', error);
if (error.message?.startsWith('MCP_INSPECTION_FAILED')) {
return res.status(400).json({
error: 'MCP_INSPECTION_FAILED',
message: error.message,
});
const mcpErrorResponse = handleMCPError(error, res);
if (mcpErrorResponse) {
return mcpErrorResponse;
}
res.status(500).json({ message: error.message });
}
@ -235,11 +277,9 @@ const updateMCPServerController = async (req, res) => {
res.status(200).json(parsedConfig);
} catch (error) {
logger.error('[updateMCPServer]', error);
if (error.message?.startsWith('MCP_INSPECTION_FAILED:')) {
return res.status(400).json({
error: 'MCP_INSPECTION_FAILED',
message: error.message,
});
const mcpErrorResponse = handleMCPError(error, res);
if (mcpErrorResponse) {
return mcpErrorResponse;
}
res.status(500).json({ message: error.message });
}

View file

@ -16,6 +16,8 @@ const {
performStartupChecks,
handleJsonParseError,
initializeFileStorage,
GenerationJobManager,
createStreamServices,
} = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
@ -192,6 +194,11 @@ const startServer = async () => {
await initializeMCPs();
await initializeOAuthReconnectManager();
await checkMigrations();
// Configure stream services (auto-detects Redis from USE_REDIS env var)
const streamServices = createStreamServices();
GenerationJobManager.configure(streamServices);
GenerationJobManager.initialize();
});
};

View file

@ -1,2 +0,0 @@
// abortControllers.js
module.exports = new Map();

View file

@ -1,124 +1,102 @@
const { logger } = require('@librechat/data-schemas');
const { countTokens, isEnabled, sendEvent, sanitizeMessageForTransmit } = require('@librechat/api');
const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
const {
countTokens,
isEnabled,
sendEvent,
GenerationJobManager,
sanitizeMessageForTransmit,
} = require('@librechat/api');
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const clearPendingReq = require('~/cache/clearPendingReq');
const { sendError } = require('~/server/middleware/error');
const { spendTokens } = require('~/models/spendTokens');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun');
const abortDataMap = new WeakMap();
/**
* @param {string} abortKey
* @returns {boolean}
* Abort an active message generation.
* Uses GenerationJobManager for all agent requests.
* Since streamId === conversationId, we can directly abort by conversationId.
*/
function cleanupAbortController(abortKey) {
if (!abortControllers.has(abortKey)) {
return false;
}
const { abortController } = abortControllers.get(abortKey);
if (!abortController) {
abortControllers.delete(abortKey);
return true;
}
// 1. Check if this controller has any composed signals and clean them up
try {
// This creates a temporary composed signal to use for cleanup
const composedSignal = AbortSignal.any([abortController.signal]);
// Get all event types - in practice, AbortSignal typically only uses 'abort'
const eventTypes = ['abort'];
// First, execute a dummy listener removal to handle potential composed signals
for (const eventType of eventTypes) {
const dummyHandler = () => {};
composedSignal.addEventListener(eventType, dummyHandler);
composedSignal.removeEventListener(eventType, dummyHandler);
const listeners = composedSignal.listeners?.(eventType) || [];
for (const listener of listeners) {
composedSignal.removeEventListener(eventType, listener);
}
}
} catch (e) {
logger.debug(`Error cleaning up composed signals: ${e}`);
}
// 2. Abort the controller if not already aborted
if (!abortController.signal.aborted) {
abortController.abort();
}
// 3. Remove from registry
abortControllers.delete(abortKey);
// 4. Clean up any data stored in the WeakMap
if (abortDataMap.has(abortController)) {
abortDataMap.delete(abortController);
}
// 5. Clean up function references on the controller
if (abortController.getAbortData) {
abortController.getAbortData = null;
}
if (abortController.abortCompletion) {
abortController.abortCompletion = null;
}
return true;
}
/**
* @param {string} abortKey
* @returns {function(): void}
*/
function createCleanUpHandler(abortKey) {
return function () {
try {
cleanupAbortController(abortKey);
} catch {
// Ignore cleanup errors
}
};
}
async function abortMessage(req, res) {
let { abortKey, endpoint } = req.body;
const { abortKey, endpoint } = req.body;
if (isAssistantsEndpoint(endpoint)) {
return await abortRun(req, res);
}
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
const userId = req.user.id;
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
abortKey = conversationId;
// Use GenerationJobManager to abort the job (streamId === conversationId)
const abortResult = await GenerationJobManager.abortJob(conversationId);
if (!abortResult.success) {
if (!res.headersSent) {
return res.status(204).send({ message: 'Request not found' });
}
return;
}
if (!abortControllers.has(abortKey) && !res.headersSent) {
return res.status(204).send({ message: 'Request not found' });
}
const { jobData, content, text } = abortResult;
const { abortController } = abortControllers.get(abortKey) ?? {};
if (!abortController) {
return res.status(204).send({ message: 'Request not found' });
}
// Count tokens and spend them
const completionTokens = await countTokens(text);
const promptTokens = jobData?.promptTokens ?? 0;
const finalEvent = await abortController.abortCompletion?.();
logger.debug(
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
JSON.stringify({ abortKey }),
const responseMessage = {
messageId: jobData?.responseMessageId,
parentMessageId: jobData?.userMessage?.messageId,
conversationId: jobData?.conversationId,
content,
text,
sender: jobData?.sender ?? 'AI',
finish_reason: 'incomplete',
endpoint: jobData?.endpoint,
iconURL: jobData?.iconURL,
model: jobData?.model,
unfinished: false,
error: false,
isCreatedByUser: false,
tokenCount: completionTokens,
};
await spendTokens(
{ ...responseMessage, context: 'incomplete', user: userId },
{ promptTokens, completionTokens },
);
cleanupAbortController(abortKey);
if (res.headersSent && finalEvent) {
await saveMessage(
req,
{ ...responseMessage, user: userId },
{ context: 'api/server/middleware/abortMiddleware.js' },
);
// Get conversation for title
const conversation = await getConvo(userId, conversationId);
const finalEvent = {
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
final: true,
conversation,
requestMessage: jobData?.userMessage
? sanitizeMessageForTransmit({
messageId: jobData.userMessage.messageId,
parentMessageId: jobData.userMessage.parentMessageId,
conversationId: jobData.userMessage.conversationId,
text: jobData.userMessage.text,
isCreatedByUser: true,
})
: null,
responseMessage,
};
logger.debug(
`[abortMessage] ID: ${userId} | ${req.user.email} | Aborted request: ${conversationId}`,
);
if (res.headersSent) {
return sendEvent(res, finalEvent);
}
@ -139,171 +117,13 @@ const handleAbort = function () {
};
};
const createAbortController = (req, res, getAbortData, getReqData) => {
const abortController = new AbortController();
const { endpointOption } = req.body;
// Store minimal data in WeakMap to avoid circular references
abortDataMap.set(abortController, {
getAbortDataFn: getAbortData,
userId: req.user.id,
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
});
// Replace the direct function reference with a wrapper that uses WeakMap
abortController.getAbortData = function () {
const data = abortDataMap.get(this);
if (!data || typeof data.getAbortDataFn !== 'function') {
return {};
}
try {
const result = data.getAbortDataFn();
// Create a copy without circular references
const cleanResult = { ...result };
// If userMessagePromise exists, break its reference to client
if (
cleanResult.userMessagePromise &&
typeof cleanResult.userMessagePromise.then === 'function'
) {
// Create a new promise that fulfills with the same result but doesn't reference the original
const originalPromise = cleanResult.userMessagePromise;
cleanResult.userMessagePromise = new Promise((resolve, reject) => {
originalPromise.then(
(result) => resolve({ ...result }),
(error) => reject(error),
);
});
}
return cleanResult;
} catch (err) {
logger.error('[abortController.getAbortData] Error:', err);
return {};
}
};
/**
* @param {TMessage} userMessage
* @param {string} responseMessageId
* @param {boolean} [isNewConvo]
*/
const onStart = (userMessage, responseMessageId, isNewConvo) => {
sendEvent(res, { message: userMessage, created: true });
const prelimAbortKey = userMessage?.conversationId ?? req.user.id;
const abortKey = isNewConvo
? `${prelimAbortKey}${Constants.COMMON_DIVIDER}${Constants.NEW_CONVO}`
: prelimAbortKey;
getReqData({ abortKey });
const prevRequest = abortControllers.get(abortKey);
const { overrideUserMessageId } = req?.body ?? {};
if (overrideUserMessageId != null && prevRequest && prevRequest?.abortController) {
const data = prevRequest.abortController.getAbortData();
getReqData({ userMessage: data?.userMessage });
const addedAbortKey = `${abortKey}:${responseMessageId}`;
// Store minimal options
const minimalOptions = {
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
};
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
const cleanupHandler = createCleanUpHandler(addedAbortKey);
res.on('finish', cleanupHandler);
return;
}
// Store minimal options
const minimalOptions = {
endpoint: endpointOption.endpoint,
iconURL: endpointOption.iconURL,
model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
};
abortControllers.set(abortKey, { abortController, ...minimalOptions });
const cleanupHandler = createCleanUpHandler(abortKey);
res.on('finish', cleanupHandler);
};
// Define abortCompletion without capturing the entire parent scope
abortController.abortCompletion = async function () {
this.abort();
// Get data from WeakMap
const ctrlData = abortDataMap.get(this);
if (!ctrlData || !ctrlData.getAbortDataFn) {
return { final: true, conversation: {}, title: 'New Chat' };
}
// Get abort data using stored function
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
ctrlData.getAbortDataFn();
const completionTokens = await countTokens(responseData?.text ?? '');
const user = ctrlData.userId;
const responseMessage = {
...responseData,
conversationId,
finish_reason: 'incomplete',
endpoint: ctrlData.endpoint,
iconURL: ctrlData.iconURL,
model: ctrlData.modelOptions?.model ?? ctrlData.model_parameters?.model,
unfinished: false,
error: false,
isCreatedByUser: false,
tokenCount: completionTokens,
};
await spendTokens(
{ ...responseMessage, context: 'incomplete', user },
{ promptTokens, completionTokens },
);
await saveMessage(
req,
{ ...responseMessage, user },
{ context: 'api/server/middleware/abortMiddleware.js' },
);
let conversation;
if (userMessagePromise) {
const resolved = await userMessagePromise;
conversation = resolved?.conversation;
// Break reference to promise
resolved.conversation = null;
}
if (!conversation) {
conversation = await getConvo(user, conversationId);
}
return {
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
final: true,
conversation,
requestMessage: sanitizeMessageForTransmit(userMessage),
responseMessage: responseMessage,
};
};
return { abortController, onStart };
};
/**
* Handle abort errors during generation.
* @param {ServerResponse} res
* @param {ServerRequest} req
* @param {Error | unknown} error
* @param {Partial<TMessage> & { partialText?: string }} data
* @returns { Promise<void> }
* @returns {Promise<void>}
*/
const handleAbortError = async (res, req, error, data) => {
if (error?.message?.includes('base64')) {
@ -368,8 +188,7 @@ const handleAbortError = async (res, req, error, data) => {
};
}
const callback = createCleanUpHandler(conversationId);
await sendError(req, res, options, callback);
await sendError(req, res, options);
};
if (partialText && partialText.length > 5) {
@ -387,6 +206,4 @@ const handleAbortError = async (res, req, error, data) => {
module.exports = {
handleAbort,
handleAbortError,
createAbortController,
cleanupAbortController,
};

View file

@ -1,5 +1,10 @@
const { logger } = require('@librechat/data-schemas');
const { Constants, isAgentsEndpoint, ResourceType } = require('librechat-data-provider');
const {
Constants,
ResourceType,
isAgentsEndpoint,
isEphemeralAgentId,
} = require('librechat-data-provider');
const { canAccessResource } = require('./canAccessResource');
const { getAgent } = require('~/models/Agent');
@ -13,7 +18,8 @@ const { getAgent } = require('~/models/Agent');
*/
const resolveAgentIdFromBody = async (agentCustomId) => {
// Handle ephemeral agents - they don't need permission checks
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
// Real agent IDs always start with "agent_", so anything else is ephemeral
if (isEphemeralAgentId(agentCustomId)) {
return null; // No permission check needed for ephemeral agents
}
@ -62,7 +68,8 @@ const canAccessAgentFromBody = (options) => {
}
// Skip permission checks for ephemeral agents
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
// Real agent IDs always start with "agent_", so anything else is ephemeral
if (isEphemeralAgentId(agentId)) {
return next();
}

View file

@ -23,9 +23,10 @@ async function buildEndpointOption(req, res, next) {
try {
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
} catch (error) {
logger.warn(
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
);
logger.error(`Error parsing compact conversation for endpoint ${endpoint}`, error);
logger.debug({
'Error parsing compact conversation': { endpoint, endpointType, conversation: req.body },
});
return handleError(res, { text: 'Error parsing conversation' });
}

View file

@ -6,6 +6,15 @@ const { logViolation, getLogStores } = require('~/cache');
const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {};
/**
* Helper function to get conversationId from different request body structures.
* @param {Object} body - The request body.
* @returns {string|undefined} The conversationId.
*/
const getConversationId = (body) => {
return body.conversationId ?? body.arg?.conversationId;
};
/**
* Middleware to validate user's authorization for a conversation.
*
@ -24,7 +33,7 @@ const validateConvoAccess = async (req, res, next) => {
const namespace = ViolationTypes.CONVO_ACCESS;
const cache = getLogStores(namespace);
const conversationId = req.body.conversationId;
const conversationId = getConversationId(req.body);
if (!conversationId || conversationId === Constants.NEW_CONVO) {
return next();

View file

@ -43,7 +43,6 @@ afterEach(() => {
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
// eslint-disable-next-line jest/no-disabled-tests
describe.skip('GET /', () => {
it('should return 200 and the correct body', async () => {
process.env.APP_TITLE = 'Test Title';

View file

@ -59,6 +59,7 @@ jest.mock('~/server/middleware', () => ({
forkUserLimiter: (req, res, next) => next(),
})),
configMiddleware: (req, res, next) => next(),
validateConvoAccess: (req, res, next) => next(),
}));
jest.mock('~/server/utils/import/fork', () => ({

View file

@ -2,6 +2,7 @@ const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { getBasePath } = require('@librechat/api');
const mockRegistryInstance = {
getServerConfig: jest.fn(),
@ -12,26 +13,36 @@ const mockRegistryInstance = {
removeServer: jest.fn(),
};
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
MCPOAuthHandler: {
initiateOAuthFlow: jest.fn(),
getFlowState: jest.fn(),
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
getClientInfoAndMetadata: jest.fn(),
getTokens: jest.fn(),
deleteUserTokens: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
},
}));
jest.mock('@librechat/api', () => {
const actual = jest.requireActual('@librechat/api');
return {
...actual,
MCPOAuthHandler: {
initiateOAuthFlow: jest.fn(),
getFlowState: jest.fn(),
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
getClientInfoAndMetadata: jest.fn(),
getTokens: jest.fn(),
deleteUserTokens: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
generateCheckAccess: jest.fn(() => (req, res, next) => next()),
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
},
// Error handling utilities (from @librechat/api mcp/errors)
isMCPDomainNotAllowedError: (error) => error?.code === 'MCP_DOMAIN_NOT_ALLOWED',
isMCPInspectionFailedError: (error) => error?.code === 'MCP_INSPECTION_FAILED',
MCPErrorCodes: {
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
},
};
});
jest.mock('@librechat/data-schemas', () => ({
logger: {
@ -271,27 +282,30 @@ describe('MCP Routes', () => {
error: 'access_denied',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=access_denied');
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=access_denied`);
});
it('should redirect to error page when code is missing', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=missing_code');
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_code`);
});
it('should redirect to error page when state is missing', async () => {
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=missing_state');
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`);
});
it('should redirect to error page when flow state is not found', async () => {
@ -301,9 +315,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'invalid-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=invalid_state');
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`);
});
it('should handle OAuth callback successfully', async () => {
@ -358,9 +373,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith(
'test-flow-id',
'test-auth-code',
@ -394,9 +410,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
});
it('should handle system-level OAuth completion', async () => {
@ -429,9 +446,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
@ -474,9 +492,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
@ -515,9 +534,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
expect(response.headers.location).toBe(`${basePath}/oauth/error?error=callback_failed`);
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
});
@ -573,9 +593,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
// Verify storeTokens was called with ORIGINAL flow state credentials
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
@ -614,9 +635,10 @@ describe('MCP Routes', () => {
code: 'test-auth-code',
state: 'test-flow-id',
});
const basePath = getBasePath();
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`);
// Verify completeOAuthFlow was NOT called (prevented duplicate)
expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled();
@ -1385,8 +1407,10 @@ describe('MCP Routes', () => {
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
.expect(302);
const basePath = getBasePath();
expect(mockFlowManager.completeFlow).not.toHaveBeenCalled();
expect(response.headers.location).toContain('/oauth/success');
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
});
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
@ -1433,7 +1457,9 @@ describe('MCP Routes', () => {
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id')
.expect(302);
expect(response.headers.location).toContain('/oauth/success');
const basePath = getBasePath();
expect(response.headers.location).toContain(`${basePath}/oauth/success`);
});
});

View file

@ -1,6 +1,6 @@
const express = require('express');
const jwt = require('jsonwebtoken');
const { getAccessToken } = require('@librechat/api');
const { getAccessToken, getBasePath } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { findToken, updateToken, createToken } = require('~/models');
@ -24,6 +24,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
const { code, state } = req.query;
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const basePath = getBasePath();
let identifier = action_id;
try {
let decodedState;
@ -32,17 +33,17 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
} catch (err) {
logger.error('Error verifying state parameter:', err);
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
return res.redirect('/oauth/error?error=invalid_state');
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
if (decodedState.action_id !== action_id) {
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
return res.redirect('/oauth/error?error=invalid_state');
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
if (!decodedState.user) {
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
return res.redirect('/oauth/error?error=invalid_state');
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
identifier = `${decodedState.user}:${action_id}`;
const flowState = await flowManager.getFlowState(identifier, 'oauth');
@ -72,12 +73,12 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
/** Redirect to React success page */
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);
} catch (error) {
logger.error('Error in OAuth callback:', error);
await flowManager.failFlow(identifier, 'oauth', error);
res.redirect('/oauth/error?error=callback_failed');
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
}
});

View file

@ -2,7 +2,6 @@ const express = require('express');
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
// validateModel,
validateConvoAccess,
@ -16,8 +15,6 @@ const { getRoleByName } = require('~/models/Role');
const router = express.Router();
router.use(moderateText);
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
@ -28,11 +25,11 @@ const checkAgentResourceAccess = canAccessAgentFromBody({
requiredPermission: PermissionBits.VIEW,
});
router.use(moderateText);
router.use(checkAgentAccess);
router.use(checkAgentResourceAccess);
router.use(validateConvoAccess);
router.use(buildEndpointOption);
router.use(setHeaders);
const controller = async (req, res, next) => {
await AgentController(req, res, next, initializeClient, addTitle);

View file

@ -1,5 +1,6 @@
const express = require('express');
const { isEnabled } = require('@librechat/api');
const { isEnabled, GenerationJobManager } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
uaParser,
checkBan,
@ -22,6 +23,188 @@ router.use(uaParser);
router.use('/', v1);
/**
* Stream endpoints - mounted before chatRouter to bypass rate limiters
* These are GET requests and don't need message body validation or rate limiting
*/
/**
* @route GET /chat/stream/:streamId
* @desc Subscribe to an ongoing generation job's SSE stream with replay support
* @access Private
* @description Sends sync event with resume state, replays missed chunks, then streams live
* @query resume=true - Indicates this is a reconnection (sends sync event)
*/
router.get('/chat/stream/:streamId', async (req, res) => {
const { streamId } = req.params;
const isResume = req.query.resume === 'true';
const job = await GenerationJobManager.getJob(streamId);
if (!job) {
return res.status(404).json({
error: 'Stream not found',
message: 'The generation job does not exist or has expired.',
});
}
res.setHeader('Content-Encoding', 'identity');
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache, no-transform');
res.setHeader('Connection', 'keep-alive');
res.setHeader('X-Accel-Buffering', 'no');
res.flushHeaders();
logger.debug(`[AgentStream] Client subscribed to ${streamId}, resume: ${isResume}`);
// Send sync event with resume state for ALL reconnecting clients
// This supports multi-tab scenarios where each tab needs run step data
if (isResume) {
const resumeState = await GenerationJobManager.getResumeState(streamId);
if (resumeState && !res.writableEnded) {
// Send sync event with run steps AND aggregatedContent
// Client will use aggregatedContent to initialize message state
res.write(`event: message\ndata: ${JSON.stringify({ sync: true, resumeState })}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
}
logger.debug(
`[AgentStream] Sent sync event for ${streamId} with ${resumeState.runSteps.length} run steps`,
);
}
}
const result = await GenerationJobManager.subscribe(
streamId,
(event) => {
if (!res.writableEnded) {
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
}
}
},
(event) => {
if (!res.writableEnded) {
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
}
res.end();
}
},
(error) => {
if (!res.writableEnded) {
res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`);
if (typeof res.flush === 'function') {
res.flush();
}
res.end();
}
},
);
if (!result) {
return res.status(404).json({ error: 'Failed to subscribe to stream' });
}
req.on('close', () => {
logger.debug(`[AgentStream] Client disconnected from ${streamId}`);
result.unsubscribe();
});
});
/**
* @route GET /chat/active
* @desc Get all active generation job IDs for the current user
* @access Private
* @returns { activeJobIds: string[] }
*/
router.get('/chat/active', async (req, res) => {
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(req.user.id);
res.json({ activeJobIds });
});
/**
* @route GET /chat/status/:conversationId
* @desc Check if there's an active generation job for a conversation
* @access Private
* @returns { active, streamId, status, aggregatedContent, createdAt, resumeState }
*/
router.get('/chat/status/:conversationId', async (req, res) => {
const { conversationId } = req.params;
// streamId === conversationId, so we can use getJob directly
const job = await GenerationJobManager.getJob(conversationId);
if (!job) {
return res.json({ active: false });
}
if (job.metadata.userId !== req.user.id) {
return res.status(403).json({ error: 'Unauthorized' });
}
// Get resume state which contains aggregatedContent
// Avoid calling both getStreamInfo and getResumeState (both fetch content)
const resumeState = await GenerationJobManager.getResumeState(conversationId);
const isActive = job.status === 'running';
res.json({
active: isActive,
streamId: conversationId,
status: job.status,
aggregatedContent: resumeState?.aggregatedContent ?? [],
createdAt: job.createdAt,
resumeState,
});
});
/**
* @route POST /chat/abort
* @desc Abort an ongoing generation job
* @access Private
* @description Mounted before chatRouter to bypass buildEndpointOption middleware
*/
router.post('/chat/abort', async (req, res) => {
logger.debug(`[AgentStream] ========== ABORT ENDPOINT HIT ==========`);
logger.debug(`[AgentStream] Method: ${req.method}, Path: ${req.path}`);
logger.debug(`[AgentStream] Body:`, req.body);
const { streamId, conversationId, abortKey } = req.body;
const userId = req.user?.id;
// streamId === conversationId, so try any of the provided IDs
// Skip "new" as it's a placeholder for new conversations, not an actual ID
let jobStreamId =
streamId || (conversationId !== 'new' ? conversationId : null) || abortKey?.split(':')[0];
let job = jobStreamId ? await GenerationJobManager.getJob(jobStreamId) : null;
// Fallback: if job not found and we have a userId, look up active jobs for user
// This handles the case where frontend sends "new" but job was created with a UUID
if (!job && userId) {
logger.debug(`[AgentStream] Job not found by ID, checking active jobs for user: ${userId}`);
const activeJobIds = await GenerationJobManager.getActiveJobIdsForUser(userId);
if (activeJobIds.length > 0) {
// Abort the most recent active job for this user
jobStreamId = activeJobIds[0];
job = await GenerationJobManager.getJob(jobStreamId);
logger.debug(`[AgentStream] Found active job for user: ${jobStreamId}`);
}
}
logger.debug(`[AgentStream] Computed jobStreamId: ${jobStreamId}`);
if (job && jobStreamId) {
logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`);
await GenerationJobManager.abortJob(jobStreamId);
logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`);
return res.json({ success: true, aborted: jobStreamId });
}
logger.warn(`[AgentStream] Job not found for streamId: ${jobStreamId}`);
return res.status(404).json({ error: 'Job not found', streamId: jobStreamId });
});
const chatRouter = express.Router();
chatRouter.use(configMiddleware);

View file

@ -6,6 +6,7 @@ const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const {
createImportLimiters,
validateConvoAccess,
createForkLimiters,
configMiddleware,
} = require('~/server/middleware');
@ -67,16 +68,17 @@ router.get('/:conversationId', async (req, res) => {
}
});
router.post('/gen_title', async (req, res) => {
const { conversationId } = req.body;
router.get('/gen_title/:conversationId', async (req, res) => {
const { conversationId } = req.params;
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${conversationId}`;
let title = await titleCache.get(key);
if (!title) {
// Retry every 1s for up to 20s
for (let i = 0; i < 20; i++) {
await sleep(1000);
// Exponential backoff: 500ms, 1s, 2s, 4s, 8s (total ~15.5s max wait)
const delays = [500, 1000, 2000, 4000, 8000];
for (const delay of delays) {
await sleep(delay);
title = await titleCache.get(key);
if (title) {
break;
@ -150,17 +152,39 @@ router.delete('/all', async (req, res) => {
}
});
router.post('/update', async (req, res) => {
const update = req.body.arg;
/** Maximum allowed length for conversation titles */
const MAX_CONVO_TITLE_LENGTH = 1024;
if (!update.conversationId) {
/**
* Updates a conversation's title.
* @route POST /update
* @param {string} req.body.arg.conversationId - The conversation ID to update.
* @param {string} req.body.arg.title - The new title for the conversation.
* @returns {object} 201 - The updated conversation object.
*/
router.post('/update', validateConvoAccess, async (req, res) => {
const { conversationId, title } = req.body.arg ?? {};
if (!conversationId) {
return res.status(400).json({ error: 'conversationId is required' });
}
if (title === undefined) {
return res.status(400).json({ error: 'title is required' });
}
if (typeof title !== 'string') {
return res.status(400).json({ error: 'title must be a string' });
}
const sanitizedTitle = title.trim().slice(0, MAX_CONVO_TITLE_LENGTH);
try {
const dbResponse = await saveConvo(req, update, {
context: `POST /api/convos/update ${update.conversationId}`,
});
const dbResponse = await saveConvo(
req,
{ conversationId, title: sanitizedTitle },
{ context: `POST /api/convos/update ${conversationId}` },
);
res.status(201).json(dbResponse);
} catch (error) {
logger.error('Error updating conversation', error);

View file

@ -11,6 +11,7 @@ const {
createSafeUser,
MCPOAuthHandler,
MCPTokenStorage,
getBasePath,
getUserMCPAuthMap,
generateCheckAccess,
} = require('@librechat/api');
@ -105,6 +106,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
* This handles the OAuth callback after the user has authorized the application
*/
router.get('/:serverName/oauth/callback', async (req, res) => {
const basePath = getBasePath();
try {
const { serverName } = req.params;
const { code, state, error: oauthError } = req.query;
@ -118,17 +120,19 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
if (oauthError) {
logger.error('[MCP OAuth] OAuth error received', { error: oauthError });
return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`);
return res.redirect(
`${basePath}/oauth/error?error=${encodeURIComponent(String(oauthError))}`,
);
}
if (!code || typeof code !== 'string') {
logger.error('[MCP OAuth] Missing or invalid code');
return res.redirect('/oauth/error?error=missing_code');
return res.redirect(`${basePath}/oauth/error?error=missing_code`);
}
if (!state || typeof state !== 'string') {
logger.error('[MCP OAuth] Missing or invalid state');
return res.redirect('/oauth/error?error=missing_state');
return res.redirect(`${basePath}/oauth/error?error=missing_state`);
}
const flowId = state;
@ -142,7 +146,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
if (!flowState) {
logger.error('[MCP OAuth] Flow state not found for flowId:', flowId);
return res.redirect('/oauth/error?error=invalid_state');
return res.redirect(`${basePath}/oauth/error?error=invalid_state`);
}
logger.debug('[MCP OAuth] Flow state details', {
@ -160,7 +164,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
flowId,
serverName,
});
return res.redirect(`/oauth/success?serverName=${encodeURIComponent(serverName)}`);
return res.redirect(`${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`);
}
logger.debug('[MCP OAuth] Completing OAuth flow');
@ -254,11 +258,11 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
}
/** Redirect to success page with flowId and serverName */
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);
} catch (error) {
logger.error('[MCP OAuth] OAuth callback error', error);
res.redirect('/oauth/error?error=callback_failed');
res.redirect(`${basePath}/oauth/error?error=callback_failed`);
}
});
@ -588,7 +592,7 @@ async function getOAuthHeaders(serverName, userId) {
return serverConfig?.oauth_headers ?? {};
}
/**
/**
MCP Server CRUD Routes (User-Managed MCP Servers)
*/

View file

@ -1,4 +1,5 @@
const express = require('express');
const { v4: uuidv4 } = require('uuid');
const { logger } = require('@librechat/data-schemas');
const { ContentTypes } = require('librechat-data-provider');
const { unescapeLaTeX, countTokens } = require('@librechat/api');
@ -111,6 +112,91 @@ router.get('/', async (req, res) => {
}
});
/**
* Creates a new branch message from a specific agent's content within a parallel response message.
* Filters the original message's content to only include parts attributed to the specified agentId.
* Only available for non-user messages with content attributions.
*
* @route POST /branch
* @param {string} req.body.messageId - The ID of the source message
* @param {string} req.body.agentId - The agentId to filter content by
* @returns {TMessage} The newly created branch message
*/
router.post('/branch', async (req, res) => {
try {
const { messageId, agentId } = req.body;
const userId = req.user.id;
if (!messageId || !agentId) {
return res.status(400).json({ error: 'messageId and agentId are required' });
}
const sourceMessage = await getMessage({ user: userId, messageId });
if (!sourceMessage) {
return res.status(404).json({ error: 'Source message not found' });
}
if (sourceMessage.isCreatedByUser) {
return res.status(400).json({ error: 'Cannot branch from user messages' });
}
if (!Array.isArray(sourceMessage.content)) {
return res.status(400).json({ error: 'Message does not have content' });
}
const hasAgentMetadata = sourceMessage.content.some((part) => part?.agentId);
if (!hasAgentMetadata) {
return res
.status(400)
.json({ error: 'Message does not have parallel content with attributions' });
}
/** @type {Array<import('librechat-data-provider').TMessageContentParts>} */
const filteredContent = [];
for (const part of sourceMessage.content) {
if (part?.agentId === agentId) {
const { agentId: _a, groupId: _g, ...cleanPart } = part;
filteredContent.push(cleanPart);
}
}
if (filteredContent.length === 0) {
return res.status(400).json({ error: 'No content found for the specified agentId' });
}
const newMessageId = uuidv4();
/** @type {import('librechat-data-provider').TMessage} */
const newMessage = {
messageId: newMessageId,
conversationId: sourceMessage.conversationId,
parentMessageId: sourceMessage.parentMessageId,
attachments: sourceMessage.attachments,
isCreatedByUser: false,
model: sourceMessage.model,
endpoint: sourceMessage.endpoint,
sender: sourceMessage.sender,
iconURL: sourceMessage.iconURL,
content: filteredContent,
unfinished: false,
error: false,
user: userId,
};
const savedMessage = await saveMessage(req, newMessage, {
context: 'POST /api/messages/branch',
});
if (!savedMessage) {
return res.status(500).json({ error: 'Failed to save branch message' });
}
res.status(201).json(savedMessage);
} catch (error) {
logger.error('Error creating branch message:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
router.post('/artifact/:messageId', async (req, res) => {
try {
const { messageId } = req.params;

View file

@ -3,7 +3,12 @@ const { nanoid } = require('nanoid');
const { tool } = require('@langchain/core/tools');
const { GraphEvents, sleep } = require('@librechat/agents');
const { logger, encryptV2, decryptV2 } = require('@librechat/data-schemas');
const { sendEvent, logAxiosError, refreshAccessToken } = require('@librechat/api');
const {
sendEvent,
logAxiosError,
refreshAccessToken,
GenerationJobManager,
} = require('@librechat/api');
const {
Time,
CacheKeys,
@ -127,6 +132,7 @@ async function loadActionSets(searchParams) {
* @param {string | undefined} [params.description] - The description for the tool.
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
* @param {string | null} [params.streamId] - The stream ID for resumable streams.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createActionTool({
@ -138,6 +144,7 @@ async function createActionTool({
name,
description,
encrypted,
streamId = null,
}) {
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolInput, config) => {
@ -192,7 +199,12 @@ async function createActionTool({
`${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`,
'oauth_login',
async () => {
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
if (streamId) {
GenerationJobManager.emitChunk(streamId, eventData);
} else {
sendEvent(res, eventData);
}
logger.debug('Sent OAuth login request to client', { action_id, identifier });
return true;
},
@ -217,7 +229,12 @@ async function createActionTool({
logger.debug('Received OAuth Authorization response', { action_id, identifier });
data.delta.auth = undefined;
data.delta.expires_at = undefined;
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
const successEventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
if (streamId) {
GenerationJobManager.emitChunk(streamId, successEventData);
} else {
sendEvent(res, successEventData);
}
await sleep(3000);
metadata.oauth_access_token = result.access_token;
metadata.oauth_refresh_token = result.refresh_token;

View file

@ -1,9 +1,13 @@
const bcrypt = require('bcryptjs');
const jwt = require('jsonwebtoken');
const { webcrypto } = require('node:crypto');
const { logger } = require('@librechat/data-schemas');
const { isEnabled, checkEmailConfig, isEmailDomainAllowed } = require('@librechat/api');
const {
logger,
DEFAULT_SESSION_EXPIRY,
DEFAULT_REFRESH_TOKEN_EXPIRY,
} = require('@librechat/data-schemas');
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api');
const {
findUser,
findToken,
@ -369,19 +373,21 @@ const setAuthTokens = async (userId, res, _session = null) => {
let session = _session;
let refreshToken;
let refreshTokenExpires;
const expiresIn = math(process.env.REFRESH_TOKEN_EXPIRY, DEFAULT_REFRESH_TOKEN_EXPIRY);
if (session && session._id && session.expiration != null) {
refreshTokenExpires = session.expiration.getTime();
refreshToken = await generateRefreshToken(session);
} else {
const result = await createSession(userId);
const result = await createSession(userId, { expiresIn });
session = result.session;
refreshToken = result.refreshToken;
refreshTokenExpires = session.expiration.getTime();
}
const user = await getUserById(userId);
const token = await generateToken(user);
const sessionExpiry = math(process.env.SESSION_EXPIRY, DEFAULT_SESSION_EXPIRY);
const token = await generateToken(user, sessionExpiry);
res.cookie('refreshToken', refreshToken, {
expires: new Date(refreshTokenExpires),
@ -418,10 +424,10 @@ const setOpenIDAuthTokens = (tokenset, res, userId, existingRefreshToken) => {
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
return;
}
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY
? eval(REFRESH_TOKEN_EXPIRY)
: 1000 * 60 * 60 * 24 * 7; // 7 days default
const expiryInMilliseconds = math(
process.env.REFRESH_TOKEN_EXPIRY,
DEFAULT_REFRESH_TOKEN_EXPIRY,
);
const expirationDate = new Date(Date.now() + expiryInMilliseconds);
if (tokenset == null) {
logger.error('[setOpenIDAuthTokens] No tokenset found in request');

View file

@ -1,4 +1,4 @@
const { CacheKeys } = require('librechat-data-provider');
const { CacheKeys, Time } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
/**
@ -39,12 +39,12 @@ async function getCachedTools(options = {}) {
* @param {Object} options - Options for caching tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name for server-specific tools
* @param {number} [options.ttl] - Time to live in milliseconds
* @param {number} [options.ttl] - Time to live in milliseconds (default: 12 hours)
* @returns {Promise<boolean>} Whether the operation was successful
*/
async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { userId, serverName, ttl } = options;
const { userId, serverName, ttl = Time.TWELVE_HOURS } = options;
// Cache by MCP server if specified (requires userId)
if (serverName && userId) {

View file

@ -19,7 +19,11 @@ async function getEndpointsConfig(req) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
if (cachedEndpointsConfig) {
return cachedEndpointsConfig;
if (cachedEndpointsConfig.gptPlugins) {
await cache.delete(CacheKeys.ENDPOINT_CONFIG);
} else {
return cachedEndpointsConfig;
}
}
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));

View file

@ -0,0 +1,136 @@
const { logger } = require('@librechat/data-schemas');
const { initializeAgent, validateAgentModel } = require('@librechat/api');
const { loadAddedAgent, setGetAgent, ADDED_AGENT_ID } = require('~/models/loadAddedAgent');
const { getConvoFiles } = require('~/models/Conversation');
const { getAgent } = require('~/models/Agent');
const db = require('~/models');
// Initialize the getAgent dependency
setGetAgent(getAgent);
/**
* Process addedConvo for parallel agent execution.
* Creates a parallel agent config from an added conversation.
*
* When an added agent has no incoming edges, it becomes a start node
* and runs in parallel with the primary agent automatically.
*
* Edge cases handled:
* - Primary agent has edges (handoffs): Added agent runs in parallel with primary,
* but doesn't participate in the primary's handoff graph
* - Primary agent has agent_ids (legacy chain): Added agent runs in parallel with primary,
* but doesn't participate in the chain
* - Primary agent has both: Added agent is independent, runs parallel from start
*
* @param {Object} params
* @param {import('express').Request} params.req
* @param {import('express').Response} params.res
* @param {Object} params.endpointOption - The endpoint option containing addedConvo
* @param {Object} params.modelsConfig - The models configuration
* @param {Function} params.logViolation - Function to log violations
* @param {Function} params.loadTools - Function to load agent tools
* @param {Array} params.requestFiles - Request files
* @param {string} params.conversationId - The conversation ID
* @param {Set} params.allowedProviders - Set of allowed providers
* @param {Map} params.agentConfigs - Map of agent configs to add to
* @param {string} params.primaryAgentId - The primary agent ID
* @param {Object|undefined} params.userMCPAuthMap - User MCP auth map to merge into
* @returns {Promise<{userMCPAuthMap: Object|undefined}>} The updated userMCPAuthMap
*/
const processAddedConvo = async ({
req,
res,
endpointOption,
modelsConfig,
logViolation,
loadTools,
requestFiles,
conversationId,
allowedProviders,
agentConfigs,
primaryAgentId,
primaryAgent,
userMCPAuthMap,
}) => {
const addedConvo = endpointOption.addedConvo;
logger.debug('[processAddedConvo] Called with addedConvo:', {
hasAddedConvo: addedConvo != null,
addedConvoEndpoint: addedConvo?.endpoint,
addedConvoModel: addedConvo?.model,
addedConvoAgentId: addedConvo?.agent_id,
});
if (addedConvo == null) {
return { userMCPAuthMap };
}
try {
const addedAgent = await loadAddedAgent({ req, conversation: addedConvo, primaryAgent });
if (!addedAgent) {
return { userMCPAuthMap };
}
const addedValidation = await validateAgentModel({
req,
res,
modelsConfig,
logViolation,
agent: addedAgent,
});
if (!addedValidation.isValid) {
logger.warn(
`[processAddedConvo] Added agent validation failed: ${addedValidation.error?.message}`,
);
return { userMCPAuthMap };
}
const addedConfig = await initializeAgent(
{
req,
res,
loadTools,
requestFiles,
conversationId,
agent: addedAgent,
endpointOption,
allowedProviders,
},
{
getConvoFiles,
getFiles: db.getFiles,
getUserKey: db.getUserKey,
updateFilesUsage: db.updateFilesUsage,
getUserKeyValues: db.getUserKeyValues,
getToolFilesByIds: db.getToolFilesByIds,
},
);
if (userMCPAuthMap != null) {
Object.assign(userMCPAuthMap, addedConfig.userMCPAuthMap ?? {});
} else {
userMCPAuthMap = addedConfig.userMCPAuthMap;
}
const addedAgentId = addedConfig.id || ADDED_AGENT_ID;
agentConfigs.set(addedAgentId, addedConfig);
// No edges needed - agent without incoming edges becomes a start node
// and runs in parallel with the primary agent automatically.
// This is independent of any edges/agent_ids the primary agent has.
logger.debug(
`[processAddedConvo] Added parallel agent: ${addedAgentId} (primary: ${primaryAgentId}, ` +
`primary has edges: ${!!endpointOption.edges}, primary has agent_ids: ${!!endpointOption.agent_ids})`,
);
return { userMCPAuthMap };
} catch (err) {
logger.error('[processAddedConvo] Error processing addedConvo for parallel agent', err);
return { userMCPAuthMap };
}
};
module.exports = {
processAddedConvo,
ADDED_AGENT_ID,
};

View file

@ -15,6 +15,9 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
return undefined;
});
/** @type {import('librechat-data-provider').TConversation | undefined} */
const addedConvo = req.body?.addedConvo;
return removeNullishValues({
spec,
iconURL,
@ -23,6 +26,7 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
endpointType,
model_parameters,
agent: agentPromise,
addedConvo,
});
};

View file

@ -7,10 +7,10 @@ const {
createSequentialChainEdges,
} = require('@librechat/api');
const {
Constants,
EModelEndpoint,
isAgentsEndpoint,
getResponseSender,
isEphemeralAgentId,
} = require('librechat-data-provider');
const {
createToolEndCallback,
@ -20,14 +20,17 @@ const { getModelsConfig } = require('~/server/controllers/ModelController');
const { loadAgentTools } = require('~/server/services/ToolService');
const AgentClient = require('~/server/controllers/agents/client');
const { getConvoFiles } = require('~/models/Conversation');
const { processAddedConvo } = require('./addedConvo');
const { getAgent } = require('~/models/Agent');
const { logViolation } = require('~/cache');
const db = require('~/models');
/**
* @param {AbortSignal} signal
* Creates a tool loader function for the agent.
* @param {AbortSignal} signal - The abort signal
* @param {string | null} [streamId] - The stream ID for resumable mode
*/
function createToolLoader(signal) {
function createToolLoader(signal, streamId = null) {
/**
* @param {object} params
* @param {ServerRequest} params.req
@ -52,6 +55,7 @@ function createToolLoader(signal) {
agent,
signal,
tool_resources,
streamId,
});
} catch (error) {
logger.error('Error loading tools for agent ' + agentId, error);
@ -65,18 +69,21 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
}
const appConfig = req.config;
// TODO: use endpointOption to determine options/modelOptions
/** @type {string | null} */
const streamId = req._resumableStreamId || null;
/** @type {Array<UsageMetadata>} */
const collectedUsage = [];
/** @type {ArtifactPromises} */
const artifactPromises = [];
const { contentParts, aggregateContent } = createContentAggregator();
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId });
const eventHandlers = getDefaultHandlers({
res,
aggregateContent,
toolEndCallback,
collectedUsage,
streamId,
});
if (!endpointOption.agent) {
@ -105,7 +112,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
const agentConfigs = new Map();
const allowedProviders = new Set(appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders);
const loadTools = createToolLoader(signal);
const loadTools = createToolLoader(signal, streamId);
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];
/** @type {string} */
@ -227,6 +234,33 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
edges = edges ? edges.concat(chain) : chain;
}
/** Multi-Convo: Process addedConvo for parallel agent execution */
const { userMCPAuthMap: updatedMCPAuthMap } = await processAddedConvo({
req,
res,
endpointOption,
modelsConfig,
logViolation,
loadTools,
requestFiles,
conversationId,
allowedProviders,
agentConfigs,
primaryAgentId: primaryConfig.id,
primaryAgent,
userMCPAuthMap,
});
if (updatedMCPAuthMap) {
userMCPAuthMap = updatedMCPAuthMap;
}
// Ensure edges is an array when we have multiple agents (multi-agent mode)
// MultiAgentGraph.categorizeEdges requires edges to be iterable
if (agentConfigs.size > 0 && !edges) {
edges = [];
}
primaryConfig.edges = edges;
let endpointConfig = appConfig.endpoints?.[primaryConfig.endpoint];
@ -270,10 +304,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => {
endpointType: endpointOption.endpointType,
resendFiles: primaryConfig.resendFiles ?? true,
maxContextTokens: primaryConfig.maxContextTokens,
endpoint:
primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
? primaryConfig.endpoint
: EModelEndpoint.agents,
endpoint: isEphemeralAgentId(primaryConfig.id) ? primaryConfig.endpoint : EModelEndpoint.agents,
});
return { client, userMCPAuthMap };

View file

@ -3,7 +3,6 @@ const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getAssistant } = require('~/models/Assistant');
const buildOptions = async (endpoint, parsedBody) => {
const { promptPrefix, assistant_id, iconURL, greeting, spec, artifacts, ...modelOptions } =
parsedBody;
const endpointOption = removeNullishValues({

View file

@ -10,8 +10,10 @@ const {
const {
sendEvent,
MCPOAuthHandler,
isMCPDomainAllowed,
normalizeServerName,
convertWithResolvedRefs,
GenerationJobManager,
} = require('@librechat/api');
const {
Time,
@ -21,13 +23,14 @@ const {
isAssistantsEndpoint,
} = require('librechat-data-provider');
const {
getMCPManager,
getFlowStateManager,
getOAuthReconnectionManager,
getMCPServersRegistry,
getFlowStateManager,
getMCPManager,
} = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache');
/**
@ -35,8 +38,9 @@ const { getLogStores } = require('~/cache');
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
*/
function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) {
/**
* @param {string} authURL - The URL to redirect the user for OAuth authentication.
* @returns {void}
@ -52,7 +56,12 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
expires_at: Date.now() + Time.TWO_MINUTES,
},
};
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
if (streamId) {
GenerationJobManager.emitChunk(streamId, eventData);
} else {
sendEvent(res, eventData);
}
};
}
@ -63,8 +72,9 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall }) {
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {number} [params.index]
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
*/
function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
function createRunStepEmitter({ res, runId, stepId, toolCall, index, streamId = null }) {
return function () {
/** @type {import('@librechat/agents').RunStep} */
const data = {
@ -77,7 +87,12 @@ function createRunStepEmitter({ res, runId, stepId, toolCall, index }) {
tool_calls: [toolCall],
},
};
sendEvent(res, { event: GraphEvents.ON_RUN_STEP, data });
const eventData = { event: GraphEvents.ON_RUN_STEP, data };
if (streamId) {
GenerationJobManager.emitChunk(streamId, eventData);
} else {
sendEvent(res, eventData);
}
};
}
@ -108,10 +123,9 @@ function createOAuthStart({ flowId, flowManager, callback }) {
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.stepId - The ID of the step in the flow.
* @param {ToolCallChunk} params.toolCall - The tool call object containing tool information.
* @param {string} params.loginFlowId - The ID of the login flow.
* @param {FlowStateManager<any>} params.flowManager - The flow manager instance.
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
*/
function createOAuthEnd({ res, stepId, toolCall }) {
function createOAuthEnd({ res, stepId, toolCall, streamId = null }) {
return async function () {
/** @type {{ id: string; delta: AgentToolCallDelta }} */
const data = {
@ -121,7 +135,12 @@ function createOAuthEnd({ res, stepId, toolCall }) {
tool_calls: [{ ...toolCall }],
},
};
sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data });
const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data };
if (streamId) {
GenerationJobManager.emitChunk(streamId, eventData);
} else {
sendEvent(res, eventData);
}
logger.debug('Sent OAuth login success to client');
};
}
@ -137,7 +156,9 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
return function () {
logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`);
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
// Clean up both mcp_oauth and mcp_get_tokens flows
flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted'));
flowManager.failFlow(flowId, 'mcp_get_tokens', new Error('Tool call aborted'));
};
}
@ -162,10 +183,19 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
* @param {AbortSignal} params.signal
* @param {string} params.model
* @param {number} [params.index]
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
async function reconnectServer({
res,
user,
index,
signal,
serverName,
userMCPAuthMap,
streamId = null,
}) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${user.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
@ -176,36 +206,60 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
type: 'tool_call_chunk',
};
const runStepEmitter = createRunStepEmitter({
res,
index,
runId,
stepId,
toolCall,
});
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
});
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
const oauthStart = createOAuthStart({
res,
flowId,
callback,
flowManager,
});
return await reinitMCPServer({
user,
signal,
serverName,
oauthStart,
flowManager,
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
});
// Set up abort handler to clean up OAuth flows if request is aborted
const oauthFlowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
const abortHandler = () => {
logger.info(
`[MCP][User: ${user.id}][${serverName}] Tool loading aborted, cleaning up OAuth flows`,
);
// Clean up both mcp_oauth and mcp_get_tokens flows
flowManager.failFlow(oauthFlowId, 'mcp_oauth', new Error('Tool loading aborted'));
flowManager.failFlow(oauthFlowId, 'mcp_get_tokens', new Error('Tool loading aborted'));
};
if (signal) {
signal.addEventListener('abort', abortHandler, { once: true });
}
try {
const runStepEmitter = createRunStepEmitter({
res,
index,
runId,
stepId,
toolCall,
streamId,
});
const runStepDeltaEmitter = createRunStepDeltaEmitter({
res,
stepId,
toolCall,
streamId,
});
const callback = createOAuthCallback({ runStepEmitter, runStepDeltaEmitter });
const oauthStart = createOAuthStart({
res,
flowId,
callback,
flowManager,
});
return await reinitMCPServer({
user,
signal,
serverName,
oauthStart,
flowManager,
userMCPAuthMap,
forceNew: true,
returnOnOAuth: false,
connectionTimeout: Time.TWO_MINUTES,
});
} finally {
// Clean up abort handler to prevent memory leaks
if (signal) {
signal.removeEventListener('abort', abortHandler);
}
}
}
/**
@ -222,11 +276,45 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
async function createMCPTools({
res,
user,
index,
signal,
config,
provider,
serverName,
userMCPAuthMap,
streamId = null,
}) {
// Early domain validation before reconnecting server (avoid wasted work on disallowed domains)
// Use getAppConfig() to support per-user/role domain restrictions
const serverConfig =
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
if (serverConfig?.url) {
const appConfig = await getAppConfig({ role: user?.role });
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
if (!isDomainAllowed) {
logger.warn(`[MCP][${serverName}] Domain not allowed, skipping all tools`);
return [];
}
}
const result = await reconnectServer({
res,
user,
index,
signal,
serverName,
userMCPAuthMap,
streamId,
});
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
@ -239,8 +327,10 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
user,
provider,
userMCPAuthMap,
streamId,
availableTools: result.availableTools,
toolKey: `${tool.name}${Constants.mcp_delimiter}${serverName}`,
config: serverConfig,
});
if (toolInstance) {
serverTools.push(toolInstance);
@ -259,9 +349,11 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
* @param {string} params.model - The model for the tool.
* @param {number} [params.index]
* @param {AbortSignal} [params.signal]
* @param {string | null} [params.streamId] - The stream ID for resumable mode.
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
* @param {LCAvailableTools} [params.availableTools]
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @param {import('@librechat/api').ParsedServerConfig} [params.config]
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({
@ -273,9 +365,25 @@ async function createMCPTool({
provider,
userMCPAuthMap,
availableTools,
config,
streamId = null,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
// Runtime domain validation: check if the server's domain is still allowed
// Use getAppConfig() to support per-user/role domain restrictions
const serverConfig =
config ?? (await getMCPServersRegistry().getServerConfig(serverName, user?.id));
if (serverConfig?.url) {
const appConfig = await getAppConfig({ role: user?.role });
const allowedDomains = appConfig?.mcpSettings?.allowedDomains;
const isDomainAllowed = await isMCPDomainAllowed(serverConfig, allowedDomains);
if (!isDomainAllowed) {
logger.warn(`[MCP][${serverName}] Domain no longer allowed, skipping tool: ${toolName}`);
return undefined;
}
}
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
@ -289,6 +397,7 @@ async function createMCPTool({
signal,
serverName,
userMCPAuthMap,
streamId,
});
toolDefinition = result?.availableTools?.[toolKey]?.function;
}
@ -304,10 +413,18 @@ async function createMCPTool({
toolName,
serverName,
toolDefinition,
streamId,
});
}
function createToolInstance({ res, toolName, serverName, toolDefinition, provider: _provider }) {
function createToolInstance({
res,
toolName,
serverName,
toolDefinition,
provider: _provider,
streamId = null,
}) {
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
@ -343,6 +460,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
res,
stepId,
toolCall,
streamId,
});
const oauthStart = createOAuthStart({
flowId,
@ -353,6 +471,7 @@ function createToolInstance({ res, toolName, serverName, toolDefinition, provide
res,
stepId,
toolCall,
streamId,
});
if (derivedSignal) {
@ -448,7 +567,10 @@ async function getMCPSetupData(userId) {
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
let appConnections = new Map();
try {
appConnections = (await mcpManager.appConnections?.getAll()) || new Map();
// Use getLoaded() instead of getAll() to avoid forcing connection creation
// getAll() creates connections for all servers, which is problematic for servers
// that require user context (e.g., those with {{LIBRECHAT_USER_ID}} placeholders)
appConnections = (await mcpManager.appConnections?.getLoaded()) || new Map();
} catch (error) {
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
}

View file

@ -1,14 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
} = require('./MCP');
// Mock all dependencies - define mocks before imports
// Mock all dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
@ -43,22 +33,46 @@ jest.mock('@librechat/agents', () => ({
},
}));
// Create mock registry instance
const mockRegistryInstance = {
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
getAllServerConfigs: jest.fn(() => Promise.resolve({})),
getServerConfig: jest.fn(() => Promise.resolve(null)),
};
jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
},
}));
// Create isMCPDomainAllowed mock that can be configured per-test
const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true));
const mockGetAppConfig = jest.fn(() => Promise.resolve({}));
jest.mock('@librechat/api', () => {
// Access mock via getter to avoid hoisting issues
return {
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
get isMCPDomainAllowed() {
return mockIsMCPDomainAllowed;
},
MCPServersRegistry: {
getInstance: () => mockRegistryInstance,
},
};
});
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
} = require('./MCP');
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
@ -80,7 +94,9 @@ jest.mock('librechat-data-provider', () => ({
jest.mock('./Config', () => ({
loadCustomConfig: jest.fn(),
getAppConfig: jest.fn(),
get getAppConfig() {
return mockGetAppConfig;
},
}));
jest.mock('~/config', () => ({
@ -128,7 +144,7 @@ describe('tests for the new helper functions used by the MCP connection status e
beforeEach(() => {
mockGetMCPManager.mockReturnValue({
appConnections: { getAll: jest.fn(() => new Map()) },
appConnections: { getLoaded: jest.fn(() => new Map()) },
getUserConnections: jest.fn(() => new Map()),
});
mockRegistryInstance.getOAuthServers.mockResolvedValue(new Set());
@ -143,7 +159,7 @@ describe('tests for the new helper functions used by the MCP connection status e
const mockOAuthServers = new Set(['server2']);
const mockMCPManager = {
appConnections: { getAll: jest.fn(() => Promise.resolve(mockAppConnections)) },
appConnections: { getLoaded: jest.fn(() => Promise.resolve(mockAppConnections)) },
getUserConnections: jest.fn(() => mockUserConnections),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
@ -153,7 +169,7 @@ describe('tests for the new helper functions used by the MCP connection status e
expect(mockRegistryInstance.getAllServerConfigs).toHaveBeenCalledWith(mockUserId);
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
expect(mockMCPManager.appConnections.getLoaded).toHaveBeenCalled();
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
expect(mockRegistryInstance.getOAuthServers).toHaveBeenCalledWith(mockUserId);
@ -174,7 +190,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockRegistryInstance.getAllServerConfigs.mockResolvedValue(mockConfig);
const mockMCPManager = {
appConnections: { getAll: jest.fn(() => Promise.resolve(null)) },
appConnections: { getLoaded: jest.fn(() => Promise.resolve(null)) },
getUserConnections: jest.fn(() => null),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
@ -692,6 +708,18 @@ describe('User parameter passing tests', () => {
createFlowWithHandler: jest.fn(),
failFlow: jest.fn(),
});
// Reset domain validation mock to default (allow all)
mockIsMCPDomainAllowed.mockReset();
mockIsMCPDomainAllowed.mockResolvedValue(true);
// Reset registry mocks
mockRegistryInstance.getServerConfig.mockReset();
mockRegistryInstance.getServerConfig.mockResolvedValue(null);
// Reset getAppConfig mock to default (no restrictions)
mockGetAppConfig.mockReset();
mockGetAppConfig.mockResolvedValue({});
});
describe('createMCPTools', () => {
@ -887,6 +915,229 @@ describe('User parameter passing tests', () => {
});
});
describe('Runtime domain validation', () => {
it('should skip tool creation when domain is not allowed', async () => {
const mockUser = { id: 'domain-test-user', role: 'user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config with URL (remote server)
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://disallowed-domain.com/sse',
});
// Mock getAppConfig to return domain restrictions
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
});
// Mock domain validation to return false (domain not allowed)
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools: {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
},
});
// Should return undefined for disallowed domain
expect(result).toBeUndefined();
// Should not call reinitMCPServer since domain check failed
expect(mockReinitMCPServer).not.toHaveBeenCalled();
// Verify getAppConfig was called with user role
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
// Verify domain validation was called with correct parameters
expect(mockIsMCPDomainAllowed).toHaveBeenCalledWith(
{ url: 'https://disallowed-domain.com/sse' },
['allowed-domain.com'],
);
});
it('should allow tool creation when domain is allowed', async () => {
const mockUser = { id: 'domain-test-user', role: 'admin' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config with URL (remote server)
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://allowed-domain.com/sse',
});
// Mock getAppConfig to return domain restrictions
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
});
// Mock domain validation to return true (domain allowed)
mockIsMCPDomainAllowed.mockResolvedValueOnce(true);
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
};
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Should create tool successfully
expect(result).toBeDefined();
// Verify getAppConfig was called with user role
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'admin' });
});
it('should skip domain validation for stdio transports (no URL)', async () => {
const mockUser = { id: 'stdio-test-user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config without URL (stdio transport)
mockRegistryInstance.getServerConfig.mockResolvedValue({
command: 'npx',
args: ['@modelcontextprotocol/server'],
});
// Mock getAppConfig (should not be called for stdio)
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['restricted-domain.com'] },
});
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
};
const result = await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Should create tool successfully without domain check
expect(result).toBeDefined();
// Should not call getAppConfig or isMCPDomainAllowed for stdio transport (no URL)
expect(mockGetAppConfig).not.toHaveBeenCalled();
expect(mockIsMCPDomainAllowed).not.toHaveBeenCalled();
});
it('should return empty array from createMCPTools when domain is not allowed', async () => {
const mockUser = { id: 'domain-test-user', role: 'user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Mock server config with URL (remote server)
const serverConfig = { url: 'https://disallowed-domain.com/sse' };
mockRegistryInstance.getServerConfig.mockResolvedValue(serverConfig);
// Mock getAppConfig to return domain restrictions
mockGetAppConfig.mockResolvedValue({
mcpSettings: { allowedDomains: ['allowed-domain.com'] },
});
// Mock domain validation to return false (domain not allowed)
mockIsMCPDomainAllowed.mockResolvedValueOnce(false);
const result = await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'test-server',
provider: 'openai',
userMCPAuthMap: {},
config: serverConfig,
});
// Should return empty array for disallowed domain
expect(result).toEqual([]);
// Should not call reinitMCPServer since domain check failed early
expect(mockReinitMCPServer).not.toHaveBeenCalled();
// Verify getAppConfig was called with user role
expect(mockGetAppConfig).toHaveBeenCalledWith({ role: 'user' });
});
it('should use user role when fetching domain restrictions', async () => {
const adminUser = { id: 'admin-user', role: 'admin' };
const regularUser = { id: 'regular-user', role: 'user' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://some-domain.com/sse',
});
// Mock different responses based on role
mockGetAppConfig
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['admin-allowed.com'] } })
.mockResolvedValueOnce({ mcpSettings: { allowedDomains: ['user-allowed.com'] } });
mockIsMCPDomainAllowed.mockResolvedValue(true);
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
};
// Call with admin user
await createMCPTool({
res: mockRes,
user: adminUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Reset and call with regular user
mockRegistryInstance.getServerConfig.mockResolvedValue({
url: 'https://some-domain.com/sse',
});
await createMCPTool({
res: mockRes,
user: regularUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools,
});
// Verify getAppConfig was called with correct roles
expect(mockGetAppConfig).toHaveBeenNthCalledWith(1, { role: 'admin' });
expect(mockGetAppConfig).toHaveBeenNthCalledWith(2, { role: 'user' });
});
});
describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => {
const complexUser = {

View file

@ -255,7 +255,7 @@ describe('processMessages', () => {
type: 'text',
text: {
value:
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
"The text you have uploaded is from the book \"Harry Potter and the Philosopher's Stone\" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.",
annotations: [
{
type: 'file_citation',
@ -424,7 +424,7 @@ These points highlight Harry's initial experiences in the magical world and set
type: 'text',
text: {
value:
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
"The text you have uploaded is from the book \"Harry Potter and the Philosopher's Stone\" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.",
annotations: [
{
type: 'file_citation',
@ -582,7 +582,7 @@ These points highlight Harry's initial experiences in the magical world and set
type: 'text',
text: {
value:
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation【11:2†source】.',
"This is a test ^1^ with pre-existing citation-like text. Here's a real citation【11:2†source】.",
annotations: [
{
type: 'file_citation',
@ -610,7 +610,7 @@ These points highlight Harry's initial experiences in the magical world and set
});
const expectedText =
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation^1^.\n\n^1.^ test.txt';
"This is a test ^1^ with pre-existing citation-like text. Here's a real citation^1^.\n\n^1.^ test.txt";
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);

View file

@ -9,7 +9,6 @@ const {
} = require('@librechat/api');
const {
Tools,
Constants,
ErrorTypes,
ContentTypes,
imageGenTools,
@ -18,6 +17,7 @@ const {
ImageVisionTool,
openapiToFunction,
AgentCapabilities,
isEphemeralAgentId,
validateActionDomain,
defaultAgentCapabilities,
validateAndParseOpenAPISpec,
@ -369,7 +369,15 @@ async function processRequiredActions(client, requiredActions) {
* @param {string | undefined} [params.openAIApiKey] - The OpenAI API key.
* @returns {Promise<{ tools?: StructuredTool[]; userMCPAuthMap?: Record<string, Record<string, string>> }>} The agent tools.
*/
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
async function loadAgentTools({
req,
res,
agent,
signal,
tool_resources,
openAIApiKey,
streamId = null,
}) {
if (!agent.tools || agent.tools.length === 0) {
return {};
} else if (
@ -385,7 +393,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
const endpointsConfig = await getEndpointsConfig(req);
let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []);
/** Edge case: use defined/fallback capabilities when the "agents" endpoint is not enabled */
if (enabledCapabilities.size === 0 && agent.id === Constants.EPHEMERAL_AGENT_ID) {
if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) {
enabledCapabilities = new Set(
appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? defaultAgentCapabilities,
);
@ -422,7 +430,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
/** @type {ReturnType<typeof createOnSearchResults>} */
let webSearchCallbacks;
if (includesWebSearch) {
webSearchCallbacks = createOnSearchResults(res);
webSearchCallbacks = createOnSearchResults(res, streamId);
}
/** @type {Record<string, Record<string, string>>} */
@ -622,6 +630,7 @@ async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIA
encrypted,
name: toolName,
description: functionSig.description,
streamId,
});
if (!tool) {

View file

@ -1,13 +1,29 @@
const { nanoid } = require('nanoid');
const { Tools } = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const { GenerationJobManager } = require('@librechat/api');
/**
* Helper to write attachment events either to res or to job emitter.
* @param {import('http').ServerResponse} res - The server response object
* @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode
* @param {Object} attachment - The attachment data
*/
function writeAttachment(res, streamId, attachment) {
if (streamId) {
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
} else {
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
}
}
/**
* Creates a function to handle search results and stream them as attachments
* @param {import('http').ServerResponse} res - The HTTP server response object
* @param {string | null} [streamId] - The stream ID for resumable mode, or null for standard mode
* @returns {{ onSearchResults: function(SearchResult, GraphRunnableConfig): void; onGetHighlights: function(string): void}} - Function that takes search results and returns or streams an attachment
*/
function createOnSearchResults(res) {
function createOnSearchResults(res, streamId = null) {
const context = {
sourceMap: new Map(),
searchResultData: undefined,
@ -70,7 +86,7 @@ function createOnSearchResults(res) {
if (!res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
writeAttachment(res, streamId, attachment);
}
/**
@ -92,7 +108,7 @@ function createOnSearchResults(res) {
}
const attachment = buildAttachment(context);
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
writeAttachment(res, streamId, attachment);
}
return {

View file

@ -14,8 +14,9 @@ async function initializeMCPs() {
}
// Initialize MCPServersRegistry first (required for MCPManager)
// Pass allowedDomains from mcpSettings for domain validation
try {
createMCPServersRegistry(mongoose);
createMCPServersRegistry(mongoose, appConfig?.mcpSettings?.allowedDomains);
} catch (error) {
logger.error('[MCP] Failed to initialize MCPServersRegistry:', error);
throw error;