diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 1abba8b2c8..dbb97df24b 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -6,6 +6,7 @@ const { agentCreateSchema, agentUpdateSchema, refreshListAvatars, + collectEdgeAgentIds, mergeAgentOcrConversion, MAX_AVATAR_REFRESH_AGENTS, convertOcrToContextInPlace, @@ -35,6 +36,7 @@ const { } = require('~/models/Agent'); const { findPubliclyAccessibleResources, + getResourcePermissionsMap, findAccessibleResources, hasPublicPermission, grantPermission, @@ -58,6 +60,44 @@ const systemTools = { const MAX_SEARCH_LEN = 100; const escapeRegex = (str = '') => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); +/** + * Validates that the requesting user has VIEW access to every agent referenced in edges. + * Agents that do not exist in the database are skipped — at create time, the `from` field + * often references the agent being built, which has no DB record yet. + * @param {import('librechat-data-provider').GraphEdge[]} edges + * @param {string} userId + * @param {string} userRole - Used for group/role principal resolution + * @returns {Promise} Agent IDs the user cannot VIEW (empty if all accessible) + */ +const validateEdgeAgentAccess = async (edges, userId, userRole) => { + const edgeAgentIds = collectEdgeAgentIds(edges); + if (edgeAgentIds.size === 0) { + return []; + } + + const agents = (await Promise.all([...edgeAgentIds].map((id) => getAgent({ id })))).filter( + Boolean, + ); + + if (agents.length === 0) { + return []; + } + + const permissionsMap = await getResourcePermissionsMap({ + userId, + role: userRole, + resourceType: ResourceType.AGENT, + resourceIds: agents.map((a) => a._id), + }); + + return agents + .filter((a) => { + const bits = permissionsMap.get(a._id.toString()) ?? 0; + return (bits & PermissionBits.VIEW) === 0; + }) + .map((a) => a.id); +}; + /** * Creates an Agent. * @route POST /Agents @@ -75,7 +115,17 @@ const createAgentHandler = async (req, res) => { agentData.model_parameters = removeNullishValues(agentData.model_parameters, true); } - const { id: userId } = req.user; + const { id: userId, role: userRole } = req.user; + + if (agentData.edges?.length) { + const unauthorized = await validateEdgeAgentAccess(agentData.edges, userId, userRole); + if (unauthorized.length > 0) { + return res.status(403).json({ + error: 'You do not have access to one or more agents referenced in edges', + agent_ids: unauthorized, + }); + } + } agentData.id = `agent_${nanoid()}`; agentData.author = userId; @@ -243,6 +293,17 @@ const updateAgentHandler = async (req, res) => { updateData.avatar = avatarField; } + if (updateData.edges?.length) { + const { id: userId, role: userRole } = req.user; + const unauthorized = await validateEdgeAgentAccess(updateData.edges, userId, userRole); + if (unauthorized.length > 0) { + return res.status(403).json({ + error: 'You do not have access to one or more agents referenced in edges', + agent_ids: unauthorized, + }); + } + } + // Convert OCR to context in incoming updateData convertOcrToContextInPlace(updateData); diff --git a/api/server/controllers/agents/v1.spec.js b/api/server/controllers/agents/v1.spec.js index ce68cc241f..ede4ea416a 100644 --- a/api/server/controllers/agents/v1.spec.js +++ b/api/server/controllers/agents/v1.spec.js @@ -2,7 +2,7 @@ const mongoose = require('mongoose'); const { nanoid } = require('nanoid'); const { v4: uuidv4 } = require('uuid'); const { agentSchema } = require('@librechat/data-schemas'); -const { FileSources } = require('librechat-data-provider'); +const { FileSources, PermissionBits } = require('librechat-data-provider'); const { MongoMemoryServer } = require('mongodb-memory-server'); // Only mock the dependencies that are not database-related @@ -46,9 +46,9 @@ jest.mock('~/models/File', () => ({ jest.mock('~/server/services/PermissionService', () => ({ findAccessibleResources: jest.fn().mockResolvedValue([]), findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]), + getResourcePermissionsMap: jest.fn().mockResolvedValue(new Map()), grantPermission: jest.fn(), hasPublicPermission: jest.fn().mockResolvedValue(false), - checkPermission: jest.fn().mockResolvedValue(true), })); jest.mock('~/models', () => ({ @@ -74,6 +74,7 @@ const { const { findAccessibleResources, findPubliclyAccessibleResources, + getResourcePermissionsMap, } = require('~/server/services/PermissionService'); const { refreshS3Url } = require('~/server/services/Files/S3/crud'); @@ -1647,4 +1648,112 @@ describe('Agent Controllers - Mass Assignment Protection', () => { expect(agent.avatar.filepath).toBe('old-s3-path.jpg'); }); }); + + describe('Edge ACL validation', () => { + let targetAgent; + + beforeEach(async () => { + targetAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: new mongoose.Types.ObjectId().toString(), + name: 'Target Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + }); + + test('createAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => { + const permMap = new Map(); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.body = { + name: 'Attacker Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.agent_ids).toContain(targetAgent.id); + }); + + test('createAgentHandler should succeed when user has VIEW on all edge-referenced agents', async () => { + const permMap = new Map([[targetAgent._id.toString(), 1]]); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.body = { + name: 'Legit Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'self_placeholder', to: targetAgent.id, edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + }); + + test('createAgentHandler should allow edges referencing non-existent agents (self-reference at create time)', async () => { + mockReq.body = { + name: 'Self-Ref Agent', + provider: 'openai', + model: 'gpt-4', + edges: [{ from: 'agent_does_not_exist_yet', to: 'agent_also_new', edgeType: 'handoff' }], + }; + + await createAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(201); + }); + + test('updateAgentHandler should return 403 when user lacks VIEW on an edge-referenced agent', async () => { + const ownedAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: mockReq.user.id, + name: 'Owned Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + + const permMap = new Map([[ownedAgent._id.toString(), PermissionBits.VIEW]]); + getResourcePermissionsMap.mockResolvedValueOnce(permMap); + + mockReq.params = { id: ownedAgent.id }; + mockReq.body = { + edges: [{ from: ownedAgent.id, to: targetAgent.id, edgeType: 'handoff' }], + }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.agent_ids).toContain(targetAgent.id); + expect(response.agent_ids).not.toContain(ownedAgent.id); + }); + + test('updateAgentHandler should succeed when edges field is absent from payload', async () => { + const ownedAgent = await Agent.create({ + id: `agent_${nanoid()}`, + author: mockReq.user.id, + name: 'Owned Agent', + provider: 'openai', + model: 'gpt-4', + tools: [], + }); + + mockReq.params = { id: ownedAgent.id }; + mockReq.body = { name: 'Renamed Agent' }; + + await updateAgentHandler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(403); + const response = mockRes.json.mock.calls[0][0]; + expect(response.name).toBe('Renamed Agent'); + }); + }); }); diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index e71270ef85..44583e6dbc 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -10,6 +10,8 @@ const { createSequentialChainEdges, } = require('@librechat/api'); const { + ResourceType, + PermissionBits, EModelEndpoint, isAgentsEndpoint, getResponseSender, @@ -21,6 +23,7 @@ const { } = require('~/server/controllers/agents/callbacks'); const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { getModelsConfig } = require('~/server/controllers/ModelController'); +const { checkPermission } = require('~/server/services/PermissionService'); const AgentClient = require('~/server/controllers/agents/client'); const { getConvoFiles } = require('~/models/Conversation'); const { processAddedConvo } = require('./addedConvo'); @@ -229,6 +232,22 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { return null; } + const hasAccess = await checkPermission({ + userId: req.user.id, + role: req.user.role, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + requiredPermission: PermissionBits.VIEW, + }); + + if (!hasAccess) { + logger.warn( + `[processAgent] User ${req.user.id} lacks VIEW access to handoff agent ${agentId}, skipping`, + ); + skippedAgentIds.add(agentId); + return null; + } + const validationResult = await validateAgentModel({ req, res, diff --git a/api/server/services/Endpoints/agents/initialize.spec.js b/api/server/services/Endpoints/agents/initialize.spec.js new file mode 100644 index 0000000000..16b41aca65 --- /dev/null +++ b/api/server/services/Endpoints/agents/initialize.spec.js @@ -0,0 +1,201 @@ +const mongoose = require('mongoose'); +const { + ResourceType, + PermissionBits, + PrincipalType, + PrincipalModel, +} = require('librechat-data-provider'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +const mockInitializeAgent = jest.fn(); +const mockValidateAgentModel = jest.fn(); + +jest.mock('@librechat/agents', () => ({ + ...jest.requireActual('@librechat/agents'), + createContentAggregator: jest.fn(() => ({ + contentParts: [], + aggregateContent: jest.fn(), + })), +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + initializeAgent: (...args) => mockInitializeAgent(...args), + validateAgentModel: (...args) => mockValidateAgentModel(...args), + GenerationJobManager: { setCollectedUsage: jest.fn() }, + getCustomEndpointConfig: jest.fn(), + createSequentialChainEdges: jest.fn(), +})); + +jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn(() => jest.fn()), + getDefaultHandlers: jest.fn(() => ({})), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn(), + loadToolsForExecution: jest.fn(), +})); + +jest.mock('~/server/controllers/ModelController', () => ({ + getModelsConfig: jest.fn().mockResolvedValue({}), +})); + +let agentClientArgs; +jest.mock('~/server/controllers/agents/client', () => { + return jest.fn().mockImplementation((args) => { + agentClientArgs = args; + return {}; + }); +}); + +jest.mock('./addedConvo', () => ({ + processAddedConvo: jest.fn().mockResolvedValue({ userMCPAuthMap: undefined }), +})); + +jest.mock('~/cache', () => ({ + logViolation: jest.fn(), +})); + +const { initializeClient } = require('./initialize'); +const { createAgent } = require('~/models/Agent'); +const { User, AclEntry } = require('~/db/models'); + +const PRIMARY_ID = 'agent_primary'; +const TARGET_ID = 'agent_target'; +const AUTHORIZED_ID = 'agent_authorized'; + +describe('initializeClient — processAgent ACL gate', () => { + let mongoServer; + let testUser; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + await mongoose.connect(mongoServer.getUri()); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await mongoose.connection.dropDatabase(); + jest.clearAllMocks(); + agentClientArgs = undefined; + + testUser = await User.create({ + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + role: 'USER', + }); + + mockValidateAgentModel.mockResolvedValue({ isValid: true }); + }); + + const makeReq = () => ({ + user: { id: testUser._id.toString(), role: 'USER' }, + body: { conversationId: 'conv_1', files: [] }, + config: { endpoints: {} }, + _resumableStreamId: null, + }); + + const makeEndpointOption = () => ({ + agent: Promise.resolve({ + id: PRIMARY_ID, + name: 'Primary', + provider: 'openai', + model: 'gpt-4', + tools: [], + }), + model_parameters: { model: 'gpt-4' }, + endpoint: 'agents', + }); + + const makePrimaryConfig = (edges) => ({ + id: PRIMARY_ID, + endpoint: 'agents', + edges, + toolDefinitions: [], + toolRegistry: new Map(), + userMCPAuthMap: null, + tool_resources: {}, + resendFiles: true, + maxContextTokens: 4096, + }); + + it('should skip handoff agent and filter its edge when user lacks VIEW access', async () => { + await createAgent({ + id: TARGET_ID, + name: 'Target Agent', + provider: 'openai', + model: 'gpt-4', + author: new mongoose.Types.ObjectId(), + tools: [], + }); + + const edges = [{ from: PRIMARY_ID, to: TARGET_ID, edgeType: 'handoff' }]; + mockInitializeAgent.mockResolvedValue(makePrimaryConfig(edges)); + + await initializeClient({ + req: makeReq(), + res: {}, + signal: new AbortController().signal, + endpointOption: makeEndpointOption(), + }); + + expect(mockInitializeAgent).toHaveBeenCalledTimes(1); + expect(agentClientArgs.agent.edges).toEqual([]); + }); + + it('should initialize handoff agent and keep its edge when user has VIEW access', async () => { + const authorizedAgent = await createAgent({ + id: AUTHORIZED_ID, + name: 'Authorized Agent', + provider: 'openai', + model: 'gpt-4', + author: new mongoose.Types.ObjectId(), + tools: [], + }); + + await AclEntry.create({ + principalType: PrincipalType.USER, + principalId: testUser._id, + principalModel: PrincipalModel.USER, + resourceType: ResourceType.AGENT, + resourceId: authorizedAgent._id, + permBits: PermissionBits.VIEW, + grantedBy: testUser._id, + }); + + const edges = [{ from: PRIMARY_ID, to: AUTHORIZED_ID, edgeType: 'handoff' }]; + const handoffConfig = { + id: AUTHORIZED_ID, + edges: [], + toolDefinitions: [], + toolRegistry: new Map(), + userMCPAuthMap: null, + tool_resources: {}, + }; + + let callCount = 0; + mockInitializeAgent.mockImplementation(() => { + callCount++; + return callCount === 1 + ? Promise.resolve(makePrimaryConfig(edges)) + : Promise.resolve(handoffConfig); + }); + + await initializeClient({ + req: makeReq(), + res: {}, + signal: new AbortController().signal, + endpointOption: makeEndpointOption(), + }); + + expect(mockInitializeAgent).toHaveBeenCalledTimes(2); + expect(agentClientArgs.agent.edges).toHaveLength(1); + expect(agentClientArgs.agent.edges[0].to).toBe(AUTHORIZED_ID); + }); +}); diff --git a/packages/api/src/agents/edges.spec.ts b/packages/api/src/agents/edges.spec.ts index 1b30a202d0..b23f00f63f 100644 --- a/packages/api/src/agents/edges.spec.ts +++ b/packages/api/src/agents/edges.spec.ts @@ -1,5 +1,11 @@ import type { GraphEdge } from 'librechat-data-provider'; -import { getEdgeKey, getEdgeParticipants, filterOrphanedEdges, createEdgeCollector } from './edges'; +import { + getEdgeKey, + getEdgeParticipants, + collectEdgeAgentIds, + filterOrphanedEdges, + createEdgeCollector, +} from './edges'; describe('edges utilities', () => { describe('getEdgeKey', () => { @@ -70,6 +76,49 @@ describe('edges utilities', () => { }); }); + describe('collectEdgeAgentIds', () => { + it('should return empty set for undefined input', () => { + expect(collectEdgeAgentIds(undefined)).toEqual(new Set()); + }); + + it('should return empty set for empty array', () => { + expect(collectEdgeAgentIds([])).toEqual(new Set()); + }); + + it('should collect IDs from simple string from/to', () => { + const edges: GraphEdge[] = [{ from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }]; + expect(collectEdgeAgentIds(edges)).toEqual(new Set(['agent_a', 'agent_b'])); + }); + + it('should collect IDs from array from/to values', () => { + const edges: GraphEdge[] = [ + { from: ['agent_a', 'agent_b'], to: ['agent_c', 'agent_d'], edgeType: 'handoff' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual( + new Set(['agent_a', 'agent_b', 'agent_c', 'agent_d']), + ); + }); + + it('should deduplicate IDs across edges', () => { + const edges: GraphEdge[] = [ + { from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }, + { from: 'agent_b', to: 'agent_c', edgeType: 'handoff' }, + { from: 'agent_a', to: 'agent_c', edgeType: 'direct' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual(new Set(['agent_a', 'agent_b', 'agent_c'])); + }); + + it('should handle mixed scalar and array edges', () => { + const edges: GraphEdge[] = [ + { from: 'agent_a', to: ['agent_b', 'agent_c'], edgeType: 'handoff' }, + { from: ['agent_c', 'agent_d'], to: 'agent_e', edgeType: 'direct' }, + ]; + expect(collectEdgeAgentIds(edges)).toEqual( + new Set(['agent_a', 'agent_b', 'agent_c', 'agent_d', 'agent_e']), + ); + }); + }); + describe('filterOrphanedEdges', () => { const edges: GraphEdge[] = [ { from: 'agent_a', to: 'agent_b', edgeType: 'handoff' }, diff --git a/packages/api/src/agents/edges.ts b/packages/api/src/agents/edges.ts index 4d2883d165..9a36105b74 100644 --- a/packages/api/src/agents/edges.ts +++ b/packages/api/src/agents/edges.ts @@ -43,6 +43,20 @@ export function filterOrphanedEdges(edges: GraphEdge[], skippedAgentIds: Set { + const ids = new Set(); + if (!edges || edges.length === 0) { + return ids; + } + for (const edge of edges) { + for (const id of getEdgeParticipants(edge)) { + ids.add(id); + } + } + return ids; +} + /** * Result of discovering and aggregating edges from connected agents. */