diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 4643f096aa..f3970bff22 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -12,7 +12,11 @@ const { validateActionDomain, validateAndParseOpenAPISpec, } = require('librechat-data-provider'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { + legacyDomainEncode, + encryptMetadata, + domainParser, +} = require('~/server/services/ActionService'); const { findAccessibleResources } = require('~/server/services/PermissionService'); const { getAgent, updateAgent, getListAgentsByAccess } = require('~/models/Agent'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); @@ -119,13 +123,14 @@ router.post( return res.status(400).json({ message: 'Domain not allowed' }); } - let { domain } = metadata; - domain = await domainParser(domain, true); + const encodedDomain = await domainParser(metadata.domain, true); - if (!domain) { + if (!encodedDomain) { return res.status(400).json({ message: 'No domain provided' }); } + const legacyDomain = legacyDomainEncode(metadata.domain); + const action_id = _action_id ?? nanoid(); const initialPromises = []; @@ -160,14 +165,23 @@ router.post( actions.push(action); } - actions.push(`${domain}${actionDelimiter}${action_id}`); + actions.push(`${encodedDomain}${actionDelimiter}${action_id}`); /** @type {string[]}} */ const { tools: _tools = [] } = agent; + const shouldRemoveAgentTool = (tool) => { + if (!tool) { + return false; + } + return ( + tool.includes(encodedDomain) || tool.includes(legacyDomain) || tool.includes(action_id) + ); + }; + const tools = _tools - .filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id)))) - .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`)); + .filter((tool) => !shouldRemoveAgentTool(tool)) + .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${encodedDomain}`)); // Force version update since actions are changing const updatedAgent = await updateAgent( @@ -231,22 +245,22 @@ router.delete( const { tools = [], actions = [] } = agent; - let domain = ''; + let storedDomain = ''; const updatedActions = actions.filter((action) => { if (action.includes(action_id)) { - [domain] = action.split(actionDelimiter); + [storedDomain] = action.split(actionDelimiter); return false; } return true; }); - domain = await domainParser(domain, true); - - if (!domain) { + if (!storedDomain) { return res.status(400).json({ message: 'No domain provided' }); } - const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain))); + const updatedTools = tools.filter( + (tool) => !(tool && (tool.includes(storedDomain) || tool.includes(action_id))), + ); // Force version update since actions are being removed await updateAgent( diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index b085fbd36a..75ab879e2b 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -3,7 +3,11 @@ const { nanoid } = require('nanoid'); const { logger } = require('@librechat/data-schemas'); const { isActionDomainAllowed } = require('@librechat/api'); const { actionDelimiter, EModelEndpoint, removeNullishValues } = require('librechat-data-provider'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { + legacyDomainEncode, + encryptMetadata, + domainParser, +} = require('~/server/services/ActionService'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAssistantDoc, getAssistant } = require('~/models/Assistant'); @@ -39,13 +43,14 @@ router.post('/:assistant_id', async (req, res) => { return res.status(400).json({ message: 'Domain not allowed' }); } - let { domain } = metadata; - domain = await domainParser(domain, true); + const encodedDomain = await domainParser(metadata.domain, true); - if (!domain) { + if (!encodedDomain) { return res.status(400).json({ message: 'No domain provided' }); } + const legacyDomain = legacyDomainEncode(metadata.domain); + const action_id = _action_id ?? nanoid(); const initialPromises = []; @@ -81,25 +86,29 @@ router.post('/:assistant_id', async (req, res) => { actions.push(action); } - actions.push(`${domain}${actionDelimiter}${action_id}`); + actions.push(`${encodedDomain}${actionDelimiter}${action_id}`); /** @type {{ tools: FunctionTool[] | { type: 'code_interpreter'|'retrieval'}[]}} */ const { tools: _tools = [] } = assistant; + const shouldRemoveAssistantTool = (tool) => { + if (!tool.function) { + return false; + } + const name = tool.function.name; + return ( + name.includes(encodedDomain) || name.includes(legacyDomain) || name.includes(action_id) + ); + }; + const tools = _tools - .filter( - (tool) => - !( - tool.function && - (tool.function.name.includes(domain) || tool.function.name.includes(action_id)) - ), - ) + .filter((tool) => !shouldRemoveAssistantTool(tool)) .concat( functions.map((tool) => ({ ...tool, function: { ...tool.function, - name: `${tool.function.name}${actionDelimiter}${domain}`, + name: `${tool.function.name}${actionDelimiter}${encodedDomain}`, }, })), ); @@ -171,23 +180,25 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { const { actions = [] } = assistant_data ?? {}; const { tools = [] } = assistant ?? {}; - let domain = ''; + let storedDomain = ''; const updatedActions = actions.filter((action) => { if (action.includes(action_id)) { - [domain] = action.split(actionDelimiter); + [storedDomain] = action.split(actionDelimiter); return false; } return true; }); - domain = await domainParser(domain, true); - - if (!domain) { + if (!storedDomain) { return res.status(400).json({ message: 'No domain provided' }); } const updatedTools = tools.filter( - (tool) => !(tool.function && tool.function.name.includes(domain)), + (tool) => + !( + tool.function && + (tool.function.name.includes(storedDomain) || tool.function.name.includes(action_id)) + ), ); await openai.beta.assistants.update(assistant_id, { tools: updatedTools }); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 5e96726a46..bde052bba4 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -28,6 +28,7 @@ const { getLogStores } = require('~/cache'); const JWT_SECRET = process.env.JWT_SECRET; const toolNameRegex = /^[a-zA-Z0-9_-]+$/; +const protocolRegex = /^https?:\/\//; const replaceSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); /** @@ -48,7 +49,11 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { actions = await getActions({ assistant_id, user: req.user.id }, true); const matchingActions = actions.filter((action) => { const metadata = action.metadata; - return metadata && metadata.domain === domain; + if (!metadata) { + return false; + } + const strippedMetaDomain = stripProtocol(metadata.domain); + return strippedMetaDomain === domain || metadata.domain === domain; }); const action = matchingActions[0]; if (!action) { @@ -66,10 +71,36 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { return tool; }; +/** @param {string} domain */ +function stripProtocol(domain) { + const stripped = domain.replace(protocolRegex, ''); + const pathIdx = stripped.indexOf('/'); + return pathIdx === -1 ? stripped : stripped.substring(0, pathIdx); +} + +/** + * Encodes a domain using the legacy scheme (full URL including protocol). + * Used for backward-compatible matching against agents saved before the collision fix. + * @param {string} domain + * @returns {string} + */ +function legacyDomainEncode(domain) { + if (!domain) { + return ''; + } + if (domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { + return domain.replace(/\./g, actionDomainSeparator); + } + const modifiedDomain = Buffer.from(domain).toString('base64'); + return modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); +} + /** * Encodes or decodes a domain name to/from base64, or replacing periods with a custom separator. * * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum. + * Strips protocol prefix before encoding to prevent base64 collisions + * (all `https://` URLs share the same 10-char base64 prefix). * * @param {string} domain - The domain name to encode/decode. * @param {boolean} inverse - False to decode from base64, true to encode to base64. @@ -79,23 +110,27 @@ async function domainParser(domain, inverse = false) { if (!domain) { return; } - const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); - const cachedDomain = await domainsCache.get(domain); - if (inverse && cachedDomain) { - return domain; - } - if (inverse && domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { - return domain.replace(/\./g, actionDomainSeparator); - } + const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); if (inverse) { - const modifiedDomain = Buffer.from(domain).toString('base64'); + const hostname = stripProtocol(domain); + const cachedDomain = await domainsCache.get(hostname); + if (cachedDomain) { + return hostname; + } + + if (hostname.length <= Constants.ENCODED_DOMAIN_LENGTH) { + return hostname.replace(/\./g, actionDomainSeparator); + } + + const modifiedDomain = Buffer.from(hostname).toString('base64'); const key = modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); await domainsCache.set(key, modifiedDomain); return key; } + const cachedDomain = await domainsCache.get(domain); if (!cachedDomain) { return domain.replace(replaceSeparatorRegex, '.'); } @@ -456,6 +491,7 @@ const deleteAssistantActions = async ({ req, assistant_id }) => { module.exports = { deleteAssistantActions, validateAndUpdateTool, + legacyDomainEncode, createActionTool, encryptMetadata, decryptMetadata, diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js index c60aef7ad1..42def44b4f 100644 --- a/api/server/services/ActionService.spec.js +++ b/api/server/services/ActionService.spec.js @@ -1,175 +1,539 @@ -const { Constants, actionDomainSeparator } = require('librechat-data-provider'); -const { domainParser } = require('./ActionService'); +const { Constants, actionDelimiter, actionDomainSeparator } = require('librechat-data-provider'); +const { domainParser, legacyDomainEncode, validateAndUpdateTool } = require('./ActionService'); jest.mock('keyv'); -const globalCache = {}; +jest.mock('~/models/Action', () => ({ + getActions: jest.fn(), + deleteActions: jest.fn(), +})); + +const { getActions } = require('~/models/Action'); + +let mockDomainCache = {}; jest.mock('~/cache/getLogStores', () => { - return jest.fn().mockImplementation(() => { - const EventEmitter = require('events'); - const { CacheKeys } = require('librechat-data-provider'); + return jest.fn().mockImplementation(() => ({ + get: async (key) => mockDomainCache[key] ?? null, + set: async (key, value) => { + mockDomainCache[key] = value; + return true; + }, + })); +}); - class KeyvMongo extends EventEmitter { - constructor(url = 'mongodb://127.0.0.1:27017', options) { - super(); - this.ttlSupport = false; - url = url ?? {}; - if (typeof url === 'string') { - url = { url }; - } - if (url.uri) { - url = { url: url.uri, ...url }; - } - this.opts = { - url, - collection: 'keyv', - ...url, - ...options, - }; - } +beforeEach(() => { + mockDomainCache = {}; + getActions.mockReset(); +}); - get = async (key) => { - return new Promise((resolve) => { - resolve(globalCache[key] || null); - }); - }; +const SEP = actionDomainSeparator; +const DELIM = actionDelimiter; +const MAX = Constants.ENCODED_DOMAIN_LENGTH; +const domainSepRegex = new RegExp(SEP, 'g'); - set = async (key, value) => { - return new Promise((resolve) => { - globalCache[key] = value; - resolve(true); - }); - }; - } +describe('domainParser', () => { + describe('nullish input', () => { + it.each([null, undefined, ''])('returns undefined for %j', async (input) => { + expect(await domainParser(input, true)).toBeUndefined(); + expect(await domainParser(input, false)).toBeUndefined(); + }); + }); - return new KeyvMongo('', { - namespace: CacheKeys.ENCODED_DOMAINS, - ttl: 0, + describe('short-path encoding (hostname ≤ threshold)', () => { + it.each([ + ['examp.com', `examp${SEP}com`], + ['swapi.tech', `swapi${SEP}tech`], + ['a.b', `a${SEP}b`], + ])('replaces dots in %s → %s', async (domain, expected) => { + expect(await domainParser(domain, true)).toBe(expected); + }); + + it('handles domain exactly at threshold length', async () => { + const domain = 'a'.repeat(MAX - 4) + '.com'; + expect(domain).toHaveLength(MAX); + const result = await domainParser(domain, true); + expect(result).toBe(domain.replace(/\./g, SEP)); + }); + }); + + describe('base64-path encoding (hostname > threshold)', () => { + it('produces a key of exactly ENCODED_DOMAIN_LENGTH chars', async () => { + const result = await domainParser('api.example.com', true); + expect(result).toHaveLength(MAX); + }); + + it('encodes hostname, not full URL', async () => { + const hostname = 'api.example.com'; + const expectedKey = Buffer.from(hostname).toString('base64').substring(0, MAX); + expect(await domainParser(hostname, true)).toBe(expectedKey); + }); + + it('populates decode cache for round-trip', async () => { + const hostname = 'longdomainname.com'; + const key = await domainParser(hostname, true); + + expect(mockDomainCache[key]).toBe(Buffer.from(hostname).toString('base64')); + expect(await domainParser(key, false)).toBe(hostname); + }); + }); + + describe('protocol stripping', () => { + it('https:// URL and bare hostname produce identical encoding', async () => { + const encoded = await domainParser('https://swapi.tech', true); + expect(encoded).toBe(await domainParser('swapi.tech', true)); + expect(encoded).toBe(`swapi${SEP}tech`); + }); + + it('http:// URL and bare hostname produce identical encoding', async () => { + const encoded = await domainParser('http://api.example.com', true); + expect(encoded).toBe(await domainParser('api.example.com', true)); + }); + + it('different https:// domains produce unique keys', async () => { + const keys = await Promise.all([ + domainParser('https://api.example.com', true), + domainParser('https://api.weather.com', true), + domainParser('https://data.github.com', true), + ]); + const unique = new Set(keys); + expect(unique.size).toBe(keys.length); + }); + + it('long hostname after stripping still uses base64 path', async () => { + const result = await domainParser('https://api.example.com', true); + expect(result).toHaveLength(MAX); + expect(result).not.toContain(SEP); + }); + + it('short hostname after stripping uses dot-replacement path', async () => { + const result = await domainParser('https://a.b.c', true); + expect(result).toBe(`a${SEP}b${SEP}c`); + }); + + it('strips path and query from full URL before encoding', async () => { + const result = await domainParser('https://api.example.com/v1/endpoint?foo=bar', true); + expect(result).toBe(await domainParser('api.example.com', true)); + }); + }); + + describe('unicode domains', () => { + it('encodes unicode hostname via base64 path', async () => { + const domain = 'täst.example.com'; + const result = await domainParser(domain, true); + expect(result).toHaveLength(MAX); + expect(result).toBe(Buffer.from(domain).toString('base64').substring(0, MAX)); + }); + + it('round-trips unicode hostname through encode then decode', async () => { + const domain = 'täst.example.com'; + const key = await domainParser(domain, true); + expect(await domainParser(key, false)).toBe(domain); + }); + + it('strips protocol before encoding unicode hostname', async () => { + const withProto = 'https://täst.example.com'; + const bare = 'täst.example.com'; + expect(await domainParser(withProto, true)).toBe(await domainParser(bare, true)); + }); + }); + + describe('decode path', () => { + it('short-path encoded domain decodes via separator replacement', async () => { + expect(await domainParser(`examp${SEP}com`, false)).toBe('examp.com'); + }); + + it('base64-path encoded domain decodes via cache lookup', async () => { + const hostname = 'api.example.com'; + const key = await domainParser(hostname, true); + expect(await domainParser(key, false)).toBe(hostname); + }); + + it('returns input unchanged for unknown non-separator strings', async () => { + expect(await domainParser('not_base64_encoded', false)).toBe('not_base64_encoded'); + }); + + it('returns a string without throwing for corrupt cache entries', async () => { + mockDomainCache['corrupt_key'] = '!!!'; + const result = await domainParser('corrupt_key', false); + expect(typeof result).toBe('string'); }); }); }); -describe('domainParser', () => { - const TLD = '.com'; - - // Non-azure request - it('does not return domain as is if not azure', async () => { - const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; - const result1 = await domainParser(domain, false); - const result2 = await domainParser(domain, true); - expect(result1).not.toEqual(domain); - expect(result2).not.toEqual(domain); +describe('legacyDomainEncode', () => { + it.each(['', null, undefined])('returns empty string for %j', (input) => { + expect(legacyDomainEncode(input)).toBe(''); }); - // Test for Empty or Null Inputs - it('returns undefined for null domain input', async () => { - const result = await domainParser(null, true); - expect(result).toBeUndefined(); + it('is synchronous (returns a string, not a Promise)', () => { + const result = legacyDomainEncode('examp.com'); + expect(result).toBe(`examp${SEP}com`); + expect(result).not.toBeInstanceOf(Promise); }); - it('returns undefined for empty domain input', async () => { - const result = await domainParser('', true); - expect(result).toBeUndefined(); + it('uses dot-replacement for short domains', () => { + expect(legacyDomainEncode('examp.com')).toBe(`examp${SEP}com`); }); - // Verify Correct Caching Behavior - it('caches encoded domain correctly', async () => { - const domain = 'longdomainname.com'; - const encodedDomain = Buffer.from(domain) - .toString('base64') - .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - - await domainParser(domain, true); - - const cachedValue = await globalCache[encodedDomain]; - expect(cachedValue).toEqual(Buffer.from(domain).toString('base64')); + it('uses base64 prefix of full input for long domains', () => { + const domain = 'https://swapi.tech'; + const expected = Buffer.from(domain).toString('base64').substring(0, MAX); + expect(legacyDomainEncode(domain)).toBe(expected); }); - // Test for Edge Cases Around Length Threshold - it('encodes domain exactly at threshold without modification', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); + it('all https:// URLs collide to the same key', () => { + const results = [ + legacyDomainEncode('https://api.example.com'), + legacyDomainEncode('https://api.weather.com'), + legacyDomainEncode('https://totally.different.host'), + ]; + expect(new Set(results).size).toBe(1); }); - it('encodes domain just below threshold without modification', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); + it('matches what old domainParser would have produced', () => { + const domain = 'https://api.example.com'; + const legacy = legacyDomainEncode(domain); + expect(legacy).toBe(Buffer.from(domain).toString('base64').substring(0, MAX)); }); - // Test for Unicode Domain Names - it('handles unicode characters in domain names correctly when encoding', async () => { - const unicodeDomain = 'täst.example.com'; - const encodedDomain = Buffer.from(unicodeDomain) - .toString('base64') - .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - const result = await domainParser(unicodeDomain, true); - expect(result).toEqual(encodedDomain); - }); - - it('decodes unicode domain names correctly', async () => { - const unicodeDomain = 'täst.example.com'; - const encodedDomain = Buffer.from(unicodeDomain).toString('base64'); - globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching - - const result = await domainParser( - encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH), - false, - ); - expect(result).toEqual(unicodeDomain); - }); - - // Core Functionality Tests - it('returns domain with replaced separators if no cached domain exists', async () => { - const domain = 'example.com'; - const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(withSeparator, false); - expect(result).toEqual(domain); - }); - - it('returns domain with replaced separators when inverse is false and under encoding length', async () => { - const domain = 'examp.com'; - const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(withSeparator, false); - expect(result).toEqual(domain); - }); - - it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => { - const domain = 'examp.com'; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); - }); - - it('encodes domain when length is above threshold and inverse is true', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com'); - const result = await domainParser(domain, true); - expect(result).not.toEqual(domain); - expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH); - }); - - it('returns encoded value if no encoded value is cached, and inverse is false', async () => { - const originalDomain = 'example.com'; - const encodedDomain = Buffer.from( - originalDomain.replace(/\./g, actionDomainSeparator), - ).toString('base64'); - const result = await domainParser(encodedDomain, false); - expect(result).toEqual(encodedDomain); - }); - - it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => { - const originalDomain = 'example.com'; - const encodedDomain = await domainParser(originalDomain, true); - const result = await domainParser(encodedDomain, false); - expect(result).toEqual(originalDomain); - }); - - it('handles invalid base64 encoded values gracefully', async () => { - const invalidBase64Domain = 'not_base64_encoded'; - const result = await domainParser(invalidBase64Domain, false); - expect(result).toEqual(invalidBase64Domain); + it('produces same result as new domainParser for short bare hostnames', async () => { + const domain = 'swapi.tech'; + expect(legacyDomainEncode(domain)).toBe(await domainParser(domain, true)); + }); +}); + +describe('validateAndUpdateTool', () => { + const mockReq = { user: { id: 'user123' } }; + + it('returns tool unchanged when name passes tool-name regex', async () => { + const tool = { function: { name: 'getPeople_action_swapi---tech' } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + expect(result).toEqual(tool); + expect(getActions).not.toHaveBeenCalled(); + }); + + it('matches action when metadata.domain has https:// prefix and tool domain is bare hostname', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'https://api.example.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).not.toBeNull(); + expect(result.function.name).toMatch(/^getPeople_action_/); + expect(result.function.name).not.toContain('.'); + }); + + it('matches action when metadata.domain has no protocol', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'api.example.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).not.toBeNull(); + expect(result.function.name).toMatch(/^getPeople_action_/); + }); + + it('returns null when no action matches the domain', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'https://other.domain.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).toBeNull(); + }); + + it('returns null when action has no metadata', async () => { + getActions.mockResolvedValue([{ metadata: null }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).toBeNull(); + }); +}); + +describe('backward-compatible tool name matching', () => { + function normalizeToolName(name) { + return name.replace(domainSepRegex, '_'); + } + + function buildToolName(functionName, encodedDomain) { + return `${functionName}${DELIM}${encodedDomain}`; + } + + describe('definition-phase matching', () => { + it('new encoding matches agent tools stored with new encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const encoded = await domainParser(metadataDomain, true); + const normalized = normalizeToolName(encoded); + + const storedTool = buildToolName('getPeople', encoded); + const defToolName = `getPeople${DELIM}${normalized}`; + + expect(normalizeToolName(storedTool)).toBe(defToolName); + }); + + it('legacy encoding matches agent tools stored with legacy encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const legacy = legacyDomainEncode(metadataDomain); + const legacyNormalized = normalizeToolName(legacy); + + const storedTool = buildToolName('getPeople', legacy); + const legacyDefName = `getPeople${DELIM}${legacyNormalized}`; + + expect(normalizeToolName(storedTool)).toBe(legacyDefName); + }); + + it('new definition matches old stored tools via legacy fallback', async () => { + const metadataDomain = 'https://swapi.tech'; + const newDomain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const newNorm = normalizeToolName(newDomain); + const legacyNorm = normalizeToolName(legacyDomain); + + const oldStoredTool = buildToolName('getPeople', legacyDomain); + const newToolName = `getPeople${DELIM}${newNorm}`; + const legacyToolName = `getPeople${DELIM}${legacyNorm}`; + + const storedNormalized = normalizeToolName(oldStoredTool); + const hasMatch = storedNormalized === newToolName || storedNormalized === legacyToolName; + expect(hasMatch).toBe(true); + }); + + it('pre-normalized Set eliminates per-tool normalization', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const normalizedDomain = normalizeToolName(domain); + const legacyNormalized = normalizeToolName(legacyDomain); + + const storedTools = [ + buildToolName('getWeather', legacyDomain), + buildToolName('getForecast', domain), + ]; + + const preNormalized = new Set(storedTools.map((t) => normalizeToolName(t))); + + const toolName = `getWeather${DELIM}${normalizedDomain}`; + const legacyToolName = `getWeather${DELIM}${legacyNormalized}`; + expect(preNormalized.has(toolName) || preNormalized.has(legacyToolName)).toBe(true); + }); + }); + + describe('execution-phase tool lookup', () => { + it('model-called tool name resolves via normalizedToDomain map (new encoding)', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const normalized = normalizeToolName(domain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(normalized, domain); + + const modelToolName = `getWeather${DELIM}${normalized}`; + + let matched = ''; + for (const [norm, canonical] of normalizedToDomain.entries()) { + if (modelToolName.includes(norm)) { + matched = canonical; + break; + } + } + + expect(matched).toBe(domain); + + const functionName = modelToolName.replace(`${DELIM}${normalizeToolName(matched)}`, ''); + expect(functionName).toBe('getWeather'); + }); + + it('model-called tool name resolves via legacy entry in normalizedToDomain map', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const legacyNorm = normalizeToolName(legacyDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(normalizeToolName(domain), domain); + normalizedToDomain.set(legacyNorm, domain); + + const legacyModelToolName = `getWeather${DELIM}${legacyNorm}`; + + let matched = ''; + for (const [norm, canonical] of normalizedToDomain.entries()) { + if (legacyModelToolName.includes(norm)) { + matched = canonical; + break; + } + } + + expect(matched).toBe(domain); + }); + + it('legacy guard skips duplicate map entry for short bare hostnames', async () => { + const domain = 'swapi.tech'; + const newEncoding = await domainParser(domain, true); + const legacyEncoding = legacyDomainEncode(domain); + + expect(newEncoding).toBe(legacyEncoding); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(newEncoding, newEncoding); + if (legacyEncoding !== newEncoding) { + normalizedToDomain.set(legacyEncoding, newEncoding); + } + expect(normalizedToDomain.size).toBe(1); + }); + }); + + describe('processRequiredActions matching (assistants path)', () => { + it('legacy tool from OpenAI matches via normalizedToDomain with both encodings', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(domain, domain); + if (legacyDomain !== domain) { + normalizedToDomain.set(legacyDomain, domain); + } + + const legacyToolName = buildToolName('getPeople', legacyDomain); + + let currentDomain = ''; + let matchedKey = ''; + for (const [key, canonical] of normalizedToDomain.entries()) { + if (legacyToolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; + break; + } + } + + expect(currentDomain).toBe(domain); + expect(matchedKey).toBe(legacyDomain); + + const functionName = legacyToolName.replace(`${DELIM}${matchedKey}`, ''); + expect(functionName).toBe('getPeople'); + }); + + it('new tool name matches via the canonical domain key', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(domain, domain); + if (legacyDomain !== domain) { + normalizedToDomain.set(legacyDomain, domain); + } + + const newToolName = buildToolName('getPeople', domain); + + let currentDomain = ''; + let matchedKey = ''; + for (const [key, canonical] of normalizedToDomain.entries()) { + if (newToolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; + break; + } + } + + expect(currentDomain).toBe(domain); + expect(matchedKey).toBe(domain); + + const functionName = newToolName.replace(`${DELIM}${matchedKey}`, ''); + expect(functionName).toBe('getPeople'); + }); + }); + + describe('save-route cleanup', () => { + it('tool filter removes tools matching new encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const tools = [ + buildToolName('getPeople', domain), + buildToolName('unrelated', 'other---domain'), + ]; + + const filtered = tools.filter((t) => !t.includes(domain) && !t.includes(legacyDomain)); + + expect(filtered).toEqual([buildToolName('unrelated', 'other---domain')]); + }); + + it('tool filter removes tools matching legacy encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const tools = [ + buildToolName('getPeople', legacyDomain), + buildToolName('unrelated', 'other---domain'), + ]; + + const filtered = tools.filter((t) => !t.includes(domain) && !t.includes(legacyDomain)); + + expect(filtered).toEqual([buildToolName('unrelated', 'other---domain')]); + }); + }); + + describe('delete-route domain extraction', () => { + it('domain extracted from actions array is usable as-is for tool filtering', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const actionId = 'abc123'; + const actionEntry = `${domain}${DELIM}${actionId}`; + + const [storedDomain] = actionEntry.split(DELIM); + expect(storedDomain).toBe(domain); + + const tools = [buildToolName('getWeather', domain), buildToolName('getPeople', 'other')]; + + const filtered = tools.filter((t) => !t.includes(storedDomain)); + expect(filtered).toEqual([buildToolName('getPeople', 'other')]); + }); + }); + + describe('multi-action agents (collision scenario)', () => { + it('two https:// actions now produce distinct tool names', async () => { + const domain1 = await domainParser('https://api.weather.com', true); + const domain2 = await domainParser('https://api.spacex.com', true); + + const tool1 = buildToolName('getData', domain1); + const tool2 = buildToolName('getData', domain2); + + expect(tool1).not.toBe(tool2); + }); + + it('two https:// actions used to collide in legacy encoding', () => { + const legacy1 = legacyDomainEncode('https://api.weather.com'); + const legacy2 = legacyDomainEncode('https://api.spacex.com'); + + const tool1 = buildToolName('getData', legacy1); + const tool2 = buildToolName('getData', legacy2); + + expect(tool1).toBe(tool2); + }); }); }); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 5fc95e748d..ca75e7eb4f 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -42,6 +42,7 @@ const { } = require('librechat-data-provider'); const { createActionTool, + legacyDomainEncode, decryptMetadata, loadActionSets, domainParser, @@ -65,6 +66,8 @@ const { findPluginAuthsByKeys } = require('~/models'); const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + /** * Resolves the set of enabled agent capabilities from endpoints config, * falling back to app-level or default capabilities for ephemeral agents. @@ -172,8 +175,7 @@ async function processRequiredActions(client, requiredActions) { const promises = []; - /** @type {Action[]} */ - let actionSets = []; + let actionSetsData = null; let isActionTool = false; const ActionToolMap = {}; const ActionBuildersMap = {}; @@ -259,9 +261,9 @@ async function processRequiredActions(client, requiredActions) { if (!tool) { // throw new Error(`Tool ${currentAction.tool} not found.`); - // Load all action sets once if not already loaded - if (!actionSets.length) { - actionSets = + if (!actionSetsData) { + /** @type {Action[]} */ + const actionSets = (await loadActionSets({ assistant_id: client.req.body.assistant_id, })) ?? []; @@ -269,11 +271,16 @@ async function processRequiredActions(client, requiredActions) { // Process all action sets once // Map domains to their processed action sets const processedDomains = new Map(); - const domainMap = new Map(); + const domainLookupMap = new Map(); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + domainLookupMap.set(domain, domain); + + const legacyDomain = legacyDomainEncode(action.metadata.domain); + if (legacyDomain !== domain) { + domainLookupMap.set(legacyDomain, domain); + } const isDomainAllowed = await isActionDomainAllowed( action.metadata.domain, @@ -328,27 +335,26 @@ async function processRequiredActions(client, requiredActions) { ActionBuildersMap[action.metadata.domain] = requestBuilders; } - // Update actionSets reference to use the domain map - actionSets = { domainMap, processedDomains }; + actionSetsData = { domainLookupMap, processedDomains }; } - // Find the matching domain for this tool let currentDomain = ''; - for (const domain of actionSets.domainMap.keys()) { - if (currentAction.tool.includes(domain)) { - currentDomain = domain; + let matchedKey = ''; + for (const [key, canonical] of actionSetsData.domainLookupMap.entries()) { + if (currentAction.tool.includes(key)) { + currentDomain = canonical; + matchedKey = key; break; } } - if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) { - // TODO: try `function` if no action set is found - // throw new Error(`Tool ${currentAction.tool} not found.`); + if (!currentDomain || !actionSetsData.processedDomains.has(currentDomain)) { continue; } - const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain); - const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, ''); + const { action, requestBuilders, encrypted } = + actionSetsData.processedDomains.get(currentDomain); + const functionName = currentAction.tool.replace(`${actionDelimiter}${matchedKey}`, ''); const requestBuilder = requestBuilders[functionName]; if (!requestBuilder) { @@ -586,12 +592,17 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const definitions = []; const allowedDomains = appConfig?.actions?.allowedDomains; - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + const normalizedToolNames = new Set( + actionToolNames.map((n) => n.replace(domainSeparatorRegex, '_')), + ); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + const legacyDomain = legacyDomainEncode(action.metadata.domain); + const legacyNormalized = legacyDomain.replace(domainSeparatorRegex, '_'); + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); if (!isDomainAllowed) { logger.warn( @@ -611,7 +622,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to for (const sig of functionSignatures) { const toolName = `${sig.name}${actionDelimiter}${normalizedDomain}`; - if (!actionToolNames.some((name) => name.replace(domainSeparatorRegex, '_') === toolName)) { + const legacyToolName = `${sig.name}${actionDelimiter}${legacyNormalized}`; + if (!normalizedToolNames.has(toolName) && !normalizedToolNames.has(legacyToolName)) { continue; } @@ -990,15 +1002,17 @@ async function loadAgentTools({ }; } - // Process each action set once (validate spec, decrypt metadata) const processedActionSets = new Map(); - const domainMap = new Map(); + const domainLookupMap = new Map(); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + domainLookupMap.set(domain, domain); - // Check if domain is allowed (do this once per action set) + const legacyDomain = legacyDomainEncode(action.metadata.domain); + if (legacyDomain !== domain) { + domainLookupMap.set(legacyDomain, domain); + } const isDomainAllowed = await isActionDomainAllowed( action.metadata.domain, appConfig?.actions?.allowedDomains, @@ -1060,11 +1074,12 @@ async function loadAgentTools({ continue; } - // Find the matching domain for this tool let currentDomain = ''; - for (const domain of domainMap.keys()) { - if (toolName.includes(domain)) { - currentDomain = domain; + let matchedKey = ''; + for (const [key, canonical] of domainLookupMap.entries()) { + if (toolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; break; } } @@ -1075,7 +1090,7 @@ async function loadAgentTools({ const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = processedActionSets.get(currentDomain); - const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, ''); + const functionName = toolName.replace(`${actionDelimiter}${matchedKey}`, ''); const functionSig = functionSignatures.find((sig) => sig.name === functionName); const requestBuilder = requestBuilders[functionName]; const zodSchema = zodSchemas[functionName]; @@ -1310,12 +1325,20 @@ async function loadActionToolsForExecution({ } const processedActionSets = new Map(); - const domainMap = new Map(); + /** Maps both new and legacy normalized domains to their canonical (new) domain key */ + const normalizedToDomain = new Map(); const allowedDomains = appConfig?.actions?.allowedDomains; for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + normalizedToDomain.set(normalizedDomain, domain); + + const legacyDomain = legacyDomainEncode(action.metadata.domain); + const legacyNormalized = legacyDomain.replace(domainSeparatorRegex, '_'); + if (legacyNormalized !== normalizedDomain) { + normalizedToDomain.set(legacyNormalized, domain); + } const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); if (!isDomainAllowed) { @@ -1364,16 +1387,15 @@ async function loadActionToolsForExecution({ functionSignatures, zodSchemas, encrypted, + legacyNormalized, }); } - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); for (const toolName of actionToolNames) { let currentDomain = ''; - for (const domain of domainMap.keys()) { - const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + for (const [normalizedDomain, canonicalDomain] of normalizedToDomain.entries()) { if (toolName.includes(normalizedDomain)) { - currentDomain = domain; + currentDomain = canonicalDomain; break; } } @@ -1382,7 +1404,7 @@ async function loadActionToolsForExecution({ continue; } - const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = + const { action, encrypted, zodSchemas, requestBuilders, functionSignatures, legacyNormalized } = processedActionSets.get(currentDomain); const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_'); const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, ''); @@ -1391,6 +1413,25 @@ async function loadActionToolsForExecution({ const zodSchema = zodSchemas[functionName]; if (!requestBuilder) { + const legacyFnName = toolName.replace(`${actionDelimiter}${legacyNormalized}`, ''); + if (legacyFnName !== toolName && requestBuilders[legacyFnName]) { + const legacyTool = await createActionTool({ + userId: req.user.id, + res, + action, + streamId, + encrypted, + requestBuilder: requestBuilders[legacyFnName], + zodSchema: zodSchemas[legacyFnName], + name: toolName, + description: + functionSignatures.find((sig) => sig.name === legacyFnName)?.description ?? '', + useSSRFProtection: !Array.isArray(allowedDomains) || allowedDomains.length === 0, + }); + if (legacyTool) { + loadedActionTools.push(legacyTool); + } + } continue; }