From 3784c702aa83be6d948507684694881e5d4c5029 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 3 Sep 2025 23:10:33 -0400 Subject: [PATCH] feat: first pass, multi-agent handoffs --- api/server/controllers/agents/client.js | 248 ++++-------------- .../services/Endpoints/agents/initialize.js | 97 ++++--- packages/api/src/agents/run.ts | 136 +++++++--- 3 files changed, 210 insertions(+), 271 deletions(-) diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 1bf9e07bd8..fb606a36b0 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -3,20 +3,17 @@ const { logger } = require('@librechat/data-schemas'); const { DynamicStructuredTool } = require('@langchain/core/tools'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { - sendEvent, createRun, Tokenizer, checkAccess, resolveHeaders, getBalanceConfig, memoryInstructions, - formatContentStrings, createMemoryProcessor, } = require('@librechat/api'); const { Callback, Providers, - GraphEvents, TitleMethod, formatMessage, formatAgentMessages, @@ -35,12 +32,12 @@ const { bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getFormattedMemories, deleteMemory, setMemory } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { getProviderConfig } = require('~/server/services/Endpoints'); +const { createContextHandlers } = require('~/app/clients/prompts'); const { checkCapability } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getRoleByName } = require('~/models/Role'); @@ -77,8 +74,6 @@ const payloadParser = ({ req, agent, endpoint }) => { return req.body.endpointOption.model_parameters; }; -const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; - function createTokenCounter(encoding) { return function (message) { const countTokens = (text) => Tokenizer.getTokenCount(text, encoding); @@ -801,138 +796,78 @@ class AgentClient extends BaseClient { ); /** - * - * @param {Agent} agent * @param {BaseMessage[]} messages - * @param {number} [i] - * @param {TMessageContentParts[]} [contentData] - * @param {Record} [currentIndexCountMap] */ - const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => { - config.configurable.model = agent.model_parameters.model; - const currentIndexCountMap = _currentIndexCountMap ?? indexTokenCountMap; - if (i > 0) { - this.model = agent.model_parameters.model; + const runAgents = async (messages) => { + const agents = [this.options.agent]; + if ( + this.agentConfigs && + this.agentConfigs.size > 0 && + ((this.options.agent.edges?.length ?? 0) > 0 || + (await checkCapability(this.options.req, AgentCapabilities.chain))) + ) { + agents.push(...this.agentConfigs.values()); } - if (i > 0 && config.signal == null) { - config.signal = abortController.signal; - } - if (agent.recursion_limit && typeof agent.recursion_limit === 'number') { - config.recursionLimit = agent.recursion_limit; + + if (agents[0].recursion_limit && typeof agents[0].recursion_limit === 'number') { + config.recursionLimit = agents[0].recursion_limit; } + if ( agentsEConfig?.maxRecursionLimit && config.recursionLimit > agentsEConfig?.maxRecursionLimit ) { config.recursionLimit = agentsEConfig?.maxRecursionLimit; } - config.configurable.agent_id = agent.id; - config.configurable.name = agent.name; - config.configurable.agent_index = i; - const noSystemMessages = noSystemModelRegex.some((regex) => - agent.model_parameters.model.match(regex), - ); - const systemMessage = Object.values(agent.toolContextMap ?? {}) - .join('\n') - .trim(); + // TODO: needs to be added as part of AgentContext initialization + // const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; + // const noSystemMessages = noSystemModelRegex.some((regex) => + // agent.model_parameters.model.match(regex), + // ); + // if (noSystemMessages === true && systemContent?.length) { + // const latestMessageContent = _messages.pop().content; + // if (typeof latestMessageContent !== 'string') { + // latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n'); + // _messages.push(new HumanMessage({ content: latestMessageContent })); + // } else { + // const text = [systemContent, latestMessageContent].join('\n'); + // _messages.push(new HumanMessage(text)); + // } + // } + // let messages = _messages; + // if (agent.useLegacyContent === true) { + // messages = formatContentStrings(messages); + // } + // if ( + // agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes( + // 'prompt-caching', + // ) + // ) { + // messages = addCacheControl(messages); + // } - let systemContent = [ - systemMessage, - agent.instructions ?? '', - i !== 0 ? (agent.additional_instructions ?? '') : '', - ] - .join('\n') - .trim(); - - if (noSystemMessages === true) { - agent.instructions = undefined; - agent.additional_instructions = undefined; - } else { - agent.instructions = systemContent; - agent.additional_instructions = undefined; - } - - if (noSystemMessages === true && systemContent?.length) { - const latestMessageContent = _messages.pop().content; - if (typeof latestMessageContent !== 'string') { - latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n'); - _messages.push(new HumanMessage({ content: latestMessageContent })); - } else { - const text = [systemContent, latestMessageContent].join('\n'); - _messages.push(new HumanMessage(text)); - } - } - - let messages = _messages; - if (agent.useLegacyContent === true) { - messages = formatContentStrings(messages); - } - if ( - agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes( - 'prompt-caching', - ) - ) { - messages = addCacheControl(messages); - } - - if (i === 0) { - memoryPromise = this.runMemory(messages); - } - - /** Resolve request-based headers for Custom Endpoints. Note: if this is added to - * non-custom endpoints, needs consideration of varying provider header configs. - */ - if (agent.model_parameters?.configuration?.defaultHeaders != null) { - agent.model_parameters.configuration.defaultHeaders = resolveHeaders({ - headers: agent.model_parameters.configuration.defaultHeaders, - body: config.configurable.requestBody, - }); - } + memoryPromise = this.runMemory(messages); run = await createRun({ - agent, - req: this.options.req, + agents, + indexTokenCountMap, runId: this.responseMessageId, signal: abortController.signal, customHandlers: this.options.eventHandlers, + requestBody: config.configurable.requestBody, + tokenCounter: createTokenCounter(this.getEncoding()), }); if (!run) { throw new Error('Failed to create run'); } - if (i === 0) { - this.run = run; - } - - if (contentData.length) { - const agentUpdate = { - type: ContentTypes.AGENT_UPDATE, - [ContentTypes.AGENT_UPDATE]: { - index: contentData.length, - runId: this.responseMessageId, - agentId: agent.id, - }, - }; - const streamData = { - event: GraphEvents.ON_AGENT_UPDATE, - data: agentUpdate, - }; - this.options.aggregateContent(streamData); - sendEvent(this.options.res, streamData); - contentData.push(agentUpdate); - run.Graph.contentData = contentData; - } - + this.run = run; if (userMCPAuthMap != null) { config.configurable.userMCPAuthMap = userMCPAuthMap; } await run.processStream({ messages }, config, { - keepContent: i !== 0, - tokenCounter: createTokenCounter(this.getEncoding()), - indexTokenCountMap: currentIndexCountMap, - maxContextTokens: agent.maxContextTokens, callbacks: { [Callback.TOOL_ERROR]: logToolError, }, @@ -941,94 +876,8 @@ class AgentClient extends BaseClient { config.signal = null; }; - await runAgent(this.options.agent, initialMessages); - let finalContentStart = 0; - if ( - this.agentConfigs && - this.agentConfigs.size > 0 && - (await checkCapability(this.options.req, AgentCapabilities.chain)) - ) { - const windowSize = 5; - let latestMessage = initialMessages.pop().content; - if (typeof latestMessage !== 'string') { - latestMessage = latestMessage[0].text; - } - let i = 1; - let runMessages = []; - - const windowIndexCountMap = {}; - const windowMessages = initialMessages.slice(-windowSize); - let currentIndex = 4; - for (let i = initialMessages.length - 1; i >= 0; i--) { - windowIndexCountMap[currentIndex] = indexTokenCountMap[i]; - currentIndex--; - if (currentIndex < 0) { - break; - } - } - const encoding = this.getEncoding(); - const tokenCounter = createTokenCounter(encoding); - for (const [agentId, agent] of this.agentConfigs) { - if (abortController.signal.aborted === true) { - break; - } - const currentRun = await run; - - if ( - i === this.agentConfigs.size && - config.configurable.hide_sequential_outputs === true - ) { - const content = this.contentParts.filter( - (part) => part.type === ContentTypes.TOOL_CALL, - ); - - this.options.res.write( - `event: message\ndata: ${JSON.stringify({ - event: 'on_content_update', - data: { - runId: this.responseMessageId, - content, - }, - })}\n\n`, - ); - } - const _runMessages = currentRun.Graph.getRunMessages(); - finalContentStart = this.contentParts.length; - runMessages = runMessages.concat(_runMessages); - const contentData = currentRun.Graph.contentData.slice(); - const bufferString = getBufferString([new HumanMessage(latestMessage), ...runMessages]); - if (i === this.agentConfigs.size) { - logger.debug(`SEQUENTIAL AGENTS: Last buffer string:\n${bufferString}`); - } - try { - const contextMessages = []; - const runIndexCountMap = {}; - for (let i = 0; i < windowMessages.length; i++) { - const message = windowMessages[i]; - const messageType = message._getType(); - if ( - (!agent.tools || agent.tools.length === 0) && - (messageType === 'tool' || (message.tool_calls?.length ?? 0) > 0) - ) { - continue; - } - runIndexCountMap[contextMessages.length] = windowIndexCountMap[i]; - contextMessages.push(message); - } - const bufferMessage = new HumanMessage(bufferString); - runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage); - const currentMessages = [...contextMessages, bufferMessage]; - await runAgent(agent, currentMessages, i, contentData, runIndexCountMap); - } catch (err) { - logger.error( - `[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`, - err, - ); - } - i++; - } - } - + await runAgents(initialMessages); + let finalContentStart = this.contentParts.length; /** Note: not implemented */ if (config.configurable.hide_sequential_outputs !== true) { finalContentStart = 0; @@ -1043,7 +892,6 @@ class AgentClient extends BaseClient { index >= finalContentStart || part.type === ContentTypes.TOOL_CALL || part.tool_call_ids ); }); - try { const attachments = await this.awaitMemoryWithTimeout(memoryPromise); if (attachments && attachments.length > 0) { diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 7cc0a39fba..e0ee5b862b 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -119,41 +119,78 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const agent_ids = primaryConfig.agent_ids; let userMCPAuthMap = primaryConfig.userMCPAuthMap; + + async function processAgent(agentId) { + const agent = await getAgent({ id: agentId }); + if (!agent) { + throw new Error(`Agent ${agentId} not found`); + } + + const validationResult = await validateAgentModel({ + req, + res, + agent, + modelsConfig, + logViolation, + }); + + if (!validationResult.isValid) { + throw new Error(validationResult.error?.message); + } + + const config = await initializeAgent({ + req, + res, + agent, + loadTools, + requestFiles, + conversationId, + endpointOption, + allowedProviders, + }); + if (userMCPAuthMap != null) { + Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {}); + } else { + userMCPAuthMap = config.userMCPAuthMap; + } + agentConfigs.set(agentId, config); + } + if (agent_ids?.length) { for (const agentId of agent_ids) { - const agent = await getAgent({ id: agentId }); - if (!agent) { - throw new Error(`Agent ${agentId} not found`); + await processAgent(agentId); + } + } + + if ((primaryConfig.edges?.length ?? 0) > 0) { + const edges = primaryConfig.edges; + const checkAgentInit = (agentId) => agentId === primaryConfig.id || agentConfigs.has(agentId); + for (const edge of edges) { + if (Array.isArray(edge.to)) { + for (const to of edge.to) { + if (checkAgentInit(to)) { + continue; + } + await processAgent(to); + } + } else if (typeof edge.to === 'string' && checkAgentInit(edge.to)) { + continue; + } else if (typeof edge.to === 'string') { + await processAgent(edge.to); } - const validationResult = await validateAgentModel({ - req, - res, - agent, - modelsConfig, - logViolation, - }); - - if (!validationResult.isValid) { - throw new Error(validationResult.error?.message); + if (Array.isArray(edge.from)) { + for (const from of edge.from) { + if (checkAgentInit(from)) { + continue; + } + await processAgent(from); + } + } else if (typeof edge.from === 'string' && checkAgentInit(edge.from)) { + continue; + } else if (typeof edge.from === 'string') { + await processAgent(edge.from); } - - const config = await initializeAgent({ - req, - res, - agent, - loadTools, - requestFiles, - conversationId, - endpointOption, - allowedProviders, - }); - if (userMCPAuthMap != null) { - Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {}); - } else { - userMCPAuthMap = config.userMCPAuthMap; - } - agentConfigs.set(agentId, config); } } diff --git a/packages/api/src/agents/run.ts b/packages/api/src/agents/run.ts index 9936a75a77..52e8908ce9 100644 --- a/packages/api/src/agents/run.ts +++ b/packages/api/src/agents/run.ts @@ -1,15 +1,17 @@ import { Run, Providers } from '@librechat/agents'; import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider'; import type { + MultiAgentGraphConfig, OpenAIClientOptions, StandardGraphConfig, - EventHandler, + AgentInputs, GenericTool, - GraphEvents, + RunConfig, IState, } from '@librechat/agents'; import type { Agent } from 'librechat-data-provider'; import type * as t from '~/types'; +import { resolveHeaders } from '~/utils/env'; const customProviders = new Set([ Providers.XAI, @@ -40,13 +42,18 @@ export function getReasoningKey( return reasoningKey; } +type RunAgent = Omit & { + tools?: GenericTool[]; + maxContextTokens?: number; + toolContextMap?: Record; +}; + /** * Creates a new Run instance with custom handlers and configuration. * * @param options - The options for creating the Run instance. - * @param options.agent - The agent for this run. + * @param options.agents - The agents for this run. * @param options.signal - The signal for this run. - * @param options.req - The server request. * @param options.runId - Optional run ID; otherwise, a new run ID will be generated. * @param options.customHandlers - Custom event handlers. * @param options.streaming - Whether to use streaming. @@ -55,61 +62,108 @@ export function getReasoningKey( */ export async function createRun({ runId, - agent, signal, + agents, + requestBody, + tokenCounter, customHandlers, + indexTokenCountMap, streaming = true, streamUsage = true, }: { - agent: Omit & { tools?: GenericTool[] }; + agents: RunAgent[]; signal: AbortSignal; runId?: string; streaming?: boolean; streamUsage?: boolean; - customHandlers?: Record; -}): Promise> { - const provider = - (providerEndpointMap[ - agent.provider as keyof typeof providerEndpointMap - ] as unknown as Providers) ?? agent.provider; + requestBody?: t.RequestBody; +} & Pick): Promise< + Run +> { + const agentInputs: AgentInputs[] = []; + const buildAgentContext = (agent: RunAgent) => { + const provider = + (providerEndpointMap[ + agent.provider as keyof typeof providerEndpointMap + ] as unknown as Providers) ?? agent.provider; - const llmConfig: t.RunLLMConfig = Object.assign( - { + const llmConfig: t.RunLLMConfig = Object.assign( + { + provider, + streaming, + streamUsage, + }, + agent.model_parameters, + ); + + const systemMessage = Object.values(agent.toolContextMap ?? {}) + .join('\n') + .trim(); + + const systemContent = [ + systemMessage, + agent.instructions ?? '', + agent.additional_instructions ?? '', + ] + .join('\n') + .trim(); + + /** + * Resolve request-based headers for Custom Endpoints. Note: if this is added to + * non-custom endpoints, needs consideration of varying provider header configs. + * This is done at this step because the request body may contain dynamic values + * that need to be resolved after agent initialization. + */ + if (llmConfig?.configuration?.defaultHeaders != null) { + llmConfig.configuration.defaultHeaders = resolveHeaders({ + headers: llmConfig.configuration.defaultHeaders as Record, + body: requestBody, + }); + } + + /** Resolves issues with new OpenAI usage field */ + if ( + customProviders.has(agent.provider) || + (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider) + ) { + llmConfig.streamUsage = false; + llmConfig.usage = true; + } + + const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint); + const agentInput: AgentInputs = { provider, - streaming, - streamUsage, - }, - agent.model_parameters, - ); - - /** Resolves issues with new OpenAI usage field */ - if ( - customProviders.has(agent.provider) || - (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider) - ) { - llmConfig.streamUsage = false; - llmConfig.usage = true; - } - - const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint); - const graphConfig: StandardGraphConfig = { - signal, - llmConfig, - reasoningKey, - tools: agent.tools, - instructions: agent.instructions, - additional_instructions: agent.additional_instructions, - // toolEnd: agent.end_after_tools, + reasoningKey, + agentId: agent.id, + tools: agent.tools, + clientOptions: llmConfig, + instructions: systemContent, + maxContextTokens: agent.maxContextTokens, + }; + agentInputs.push(agentInput); }; - // TEMPORARY FOR TESTING - if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) { - graphConfig.streamBuffer = 2000; + for (const agent of agents) { + buildAgentContext(agent); + } + + const graphConfig: RunConfig['graphConfig'] = { + signal, + agents: agentInputs, + edges: agents[0].edges, + }; + + if (agentInputs.length > 1 || ('edges' in graphConfig && graphConfig.edges.length > 0)) { + (graphConfig as unknown as MultiAgentGraphConfig).type = 'multi-agent'; + } else { + (graphConfig as StandardGraphConfig).type = 'standard'; } return Run.create({ runId, graphConfig, + tokenCounter, customHandlers, + indexTokenCountMap, }); }