From 9a64791e3ec8670bbc67d14d6eb5173db64557ef Mon Sep 17 00:00:00 2001
From: Danny Avila
Date: Tue, 17 Mar 2026 01:38:51 -0400
Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=A2=20fix:=20Action=20Domain=20Encodin?=
 =?UTF-8?q?g=20Collision=20for=20HTTPS=20URLs=20(#12271)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix: strip protocol from domain before encoding in `domainParser`

All https:// (and http://) domains produced the same 10-char base64
prefix due to ENCODED_DOMAIN_LENGTH truncation, causing tool name
collisions for agents with multiple actions. Strip the protocol before
encoding so the base64 key is derived from the hostname. Add
`legacyDomainEncode` to preserve the old encoding logic for
backward-compatible matching of existing stored actions.

* fix: backward-compatible tool matching in ToolService

Update `getActionToolDefinitions` to match stored tools against both
new and legacy domain encodings. Update `loadActionToolsForExecution`
to resolve model-called tool names via a `normalizedToDomain` map that
includes both encoding variants, with legacy fallback for request
builder lookup.

* fix: action route save/delete domain encoding issues

Save routes now remove old tools matching either new or legacy domain
encoding, preventing stale entries when an action's encoding changes on
update. Delete routes no longer re-encode the already-encoded domain
extracted from the stored actions array, which was producing incorrect
keys and leaving orphaned tools.

* test: comprehensive coverage for action domain encoding

Rewrite ActionService tests to cover real matching patterns used by
ToolService and action routes. Tests verify encode/decode round-trips,
protocol stripping, backward-compatible tool name matching at both
definition and execution phases, save-route cleanup of old/new
encodings, delete-route domain extraction, and the collision fix for
multi-action agents.

* fix: add legacy domain compat to all execution paths, make legacyDomainEncode sync

CRITICAL: processRequiredActions (assistants path) was not updated with
legacy domain matching — existing assistants with https:// domain
actions would silently fail post-deployment because domainMap only had
new encoding.

MAJOR: loadAgentTools definitionsOnly=false path had the same issue.

Both now use a normalizedToDomain map with legacy+new entries and
extract function names via the matched key (not the canonical domain).

Also: make legacyDomainEncode synchronous (no async operations), store
legacyNormalized in processedActionSets to eliminate recomputation in
the per-tool fallback, and hoist domainSeparatorRegex to module level.

* refactor: clarify domain variable naming and tool-filter helpers in action routes

Rename shadowed 'domain' to 'encodedDomain' to separate raw URL from
encoded key in both agent and assistant save routes. Rename
shouldRemoveTool to shouldRemoveAgentTool / shouldRemoveAssistantTool
to make the distinct data-shape guards explicit. Remove await on
now-synchronous legacyDomainEncode.

* test: expand coverage for all review findings

- Add validateAndUpdateTool tests (protocol-stripping match logic)
- Restore unicode domain encode/decode/round-trip tests
- Add processRequiredActions matching pattern tests (assistants path)
- Add legacy guard skip test for short bare hostnames
- Add pre-normalized Set test for definition-phase optimization
- Fix corrupt-cache test to assert typeof instead of toBeDefined
- Verify legacyDomainEncode is synchronous (not a Promise)
- Remove all await on legacyDomainEncode (now sync)

58 tests, up from 44.

* fix: address follow-up review findings A-E

A: Fix stale JSDoc @returns {Promise} on now-synchronous
legacyDomainEncode — changed to @returns {string}.

B: Rename normalizedToDomain to domainLookupMap in
processRequiredActions and loadAgentTools where keys are raw encoded
domains (not normalized), avoiding confusion with
loadActionToolsForExecution where keys ARE normalized.

C: Pre-normalize actionToolNames into a Set in
getActionToolDefinitions, replacing O(signatures × tools) per-check
.some() + .replace() with O(1) Set.has() lookups.

D: Remove stripProtocol from ActionService exports — it is a one-line
internal helper. Spec tests for it removed; behavior is fully covered
by domainParser protocol-stripping tests.

E: Fix pre-existing bug where processRequiredActions re-loaded action
sets on every missing-tool iteration. The guard !actionSets.length
always re-triggered because actionSets was reassigned to a plain object
(whose .length is undefined). Replaced with a null-check on a dedicated
actionSetsData variable.

* fix: strip path and query from domain URLs in stripProtocol

URLs like 'https://api.example.com/v1/endpoint?foo=bar' previously
retained the path after protocol stripping, contaminating the encoded
domain key. Now strips everything after the first '/' following the
host, using string indexing instead of URL parsing to avoid punycode
normalization of unicode hostnames.

Closes Copilot review comments 1, 2, and 5.
--- api/server/routes/agents/actions.js | 40 +- api/server/routes/assistants/actions.js | 49 +- api/server/services/ActionService.js | 56 +- api/server/services/ActionService.spec.js | 662 +++++++++++++++++----- api/server/services/ToolService.js | 113 ++-- 5 files changed, 693 insertions(+), 227 deletions(-) diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 4643f096aa..f3970bff22 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -12,7 +12,11 @@ const { validateActionDomain, validateAndParseOpenAPISpec, } = require('librechat-data-provider'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { + legacyDomainEncode, + encryptMetadata, + domainParser, +} = require('~/server/services/ActionService'); const { findAccessibleResources } = require('~/server/services/PermissionService'); const { getAgent, updateAgent, getListAgentsByAccess } = require('~/models/Agent'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); @@ -119,13 +123,14 @@ router.post( return res.status(400).json({ message: 'Domain not allowed' }); } - let { domain } = metadata; - domain = await domainParser(domain, true); + const encodedDomain = await domainParser(metadata.domain, true); - if (!domain) { + if (!encodedDomain) { return res.status(400).json({ message: 'No domain provided' }); } + const legacyDomain = legacyDomainEncode(metadata.domain); + const action_id = _action_id ?? 
nanoid(); const initialPromises = []; @@ -160,14 +165,23 @@ router.post( actions.push(action); } - actions.push(`${domain}${actionDelimiter}${action_id}`); + actions.push(`${encodedDomain}${actionDelimiter}${action_id}`); /** @type {string[]}} */ const { tools: _tools = [] } = agent; + const shouldRemoveAgentTool = (tool) => { + if (!tool) { + return false; + } + return ( + tool.includes(encodedDomain) || tool.includes(legacyDomain) || tool.includes(action_id) + ); + }; + const tools = _tools - .filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id)))) - .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`)); + .filter((tool) => !shouldRemoveAgentTool(tool)) + .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${encodedDomain}`)); // Force version update since actions are changing const updatedAgent = await updateAgent( @@ -231,22 +245,22 @@ router.delete( const { tools = [], actions = [] } = agent; - let domain = ''; + let storedDomain = ''; const updatedActions = actions.filter((action) => { if (action.includes(action_id)) { - [domain] = action.split(actionDelimiter); + [storedDomain] = action.split(actionDelimiter); return false; } return true; }); - domain = await domainParser(domain, true); - - if (!domain) { + if (!storedDomain) { return res.status(400).json({ message: 'No domain provided' }); } - const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain))); + const updatedTools = tools.filter( + (tool) => !(tool && (tool.includes(storedDomain) || tool.includes(action_id))), + ); // Force version update since actions are being removed await updateAgent( diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index b085fbd36a..75ab879e2b 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -3,7 +3,11 @@ const { nanoid } = require('nanoid'); const { logger } = 
require('@librechat/data-schemas'); const { isActionDomainAllowed } = require('@librechat/api'); const { actionDelimiter, EModelEndpoint, removeNullishValues } = require('librechat-data-provider'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); +const { + legacyDomainEncode, + encryptMetadata, + domainParser, +} = require('~/server/services/ActionService'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAssistantDoc, getAssistant } = require('~/models/Assistant'); @@ -39,13 +43,14 @@ router.post('/:assistant_id', async (req, res) => { return res.status(400).json({ message: 'Domain not allowed' }); } - let { domain } = metadata; - domain = await domainParser(domain, true); + const encodedDomain = await domainParser(metadata.domain, true); - if (!domain) { + if (!encodedDomain) { return res.status(400).json({ message: 'No domain provided' }); } + const legacyDomain = legacyDomainEncode(metadata.domain); + const action_id = _action_id ?? 
nanoid(); const initialPromises = []; @@ -81,25 +86,29 @@ router.post('/:assistant_id', async (req, res) => { actions.push(action); } - actions.push(`${domain}${actionDelimiter}${action_id}`); + actions.push(`${encodedDomain}${actionDelimiter}${action_id}`); /** @type {{ tools: FunctionTool[] | { type: 'code_interpreter'|'retrieval'}[]}} */ const { tools: _tools = [] } = assistant; + const shouldRemoveAssistantTool = (tool) => { + if (!tool.function) { + return false; + } + const name = tool.function.name; + return ( + name.includes(encodedDomain) || name.includes(legacyDomain) || name.includes(action_id) + ); + }; + const tools = _tools - .filter( - (tool) => - !( - tool.function && - (tool.function.name.includes(domain) || tool.function.name.includes(action_id)) - ), - ) + .filter((tool) => !shouldRemoveAssistantTool(tool)) .concat( functions.map((tool) => ({ ...tool, function: { ...tool.function, - name: `${tool.function.name}${actionDelimiter}${domain}`, + name: `${tool.function.name}${actionDelimiter}${encodedDomain}`, }, })), ); @@ -171,23 +180,25 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { const { actions = [] } = assistant_data ?? {}; const { tools = [] } = assistant ?? 
{}; - let domain = ''; + let storedDomain = ''; const updatedActions = actions.filter((action) => { if (action.includes(action_id)) { - [domain] = action.split(actionDelimiter); + [storedDomain] = action.split(actionDelimiter); return false; } return true; }); - domain = await domainParser(domain, true); - - if (!domain) { + if (!storedDomain) { return res.status(400).json({ message: 'No domain provided' }); } const updatedTools = tools.filter( - (tool) => !(tool.function && tool.function.name.includes(domain)), + (tool) => + !( + tool.function && + (tool.function.name.includes(storedDomain) || tool.function.name.includes(action_id)) + ), ); await openai.beta.assistants.update(assistant_id, { tools: updatedTools }); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 5e96726a46..bde052bba4 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -28,6 +28,7 @@ const { getLogStores } = require('~/cache'); const JWT_SECRET = process.env.JWT_SECRET; const toolNameRegex = /^[a-zA-Z0-9_-]+$/; +const protocolRegex = /^https?:\/\//; const replaceSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); /** @@ -48,7 +49,11 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { actions = await getActions({ assistant_id, user: req.user.id }, true); const matchingActions = actions.filter((action) => { const metadata = action.metadata; - return metadata && metadata.domain === domain; + if (!metadata) { + return false; + } + const strippedMetaDomain = stripProtocol(metadata.domain); + return strippedMetaDomain === domain || metadata.domain === domain; }); const action = matchingActions[0]; if (!action) { @@ -66,10 +71,36 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { return tool; }; +/** @param {string} domain */ +function stripProtocol(domain) { + const stripped = domain.replace(protocolRegex, ''); + const pathIdx = stripped.indexOf('/'); + return 
pathIdx === -1 ? stripped : stripped.substring(0, pathIdx); +} + +/** + * Encodes a domain using the legacy scheme (full URL including protocol). + * Used for backward-compatible matching against agents saved before the collision fix. + * @param {string} domain + * @returns {string} + */ +function legacyDomainEncode(domain) { + if (!domain) { + return ''; + } + if (domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { + return domain.replace(/\./g, actionDomainSeparator); + } + const modifiedDomain = Buffer.from(domain).toString('base64'); + return modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); +} + /** * Encodes or decodes a domain name to/from base64, or replacing periods with a custom separator. * * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum. + * Strips protocol prefix before encoding to prevent base64 collisions + * (all `https://` URLs share the same 10-char base64 prefix). * * @param {string} domain - The domain name to encode/decode. * @param {boolean} inverse - False to decode from base64, true to encode to base64. 
@@ -79,23 +110,27 @@ async function domainParser(domain, inverse = false) { if (!domain) { return; } - const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); - const cachedDomain = await domainsCache.get(domain); - if (inverse && cachedDomain) { - return domain; - } - if (inverse && domain.length <= Constants.ENCODED_DOMAIN_LENGTH) { - return domain.replace(/\./g, actionDomainSeparator); - } + const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); if (inverse) { - const modifiedDomain = Buffer.from(domain).toString('base64'); + const hostname = stripProtocol(domain); + const cachedDomain = await domainsCache.get(hostname); + if (cachedDomain) { + return hostname; + } + + if (hostname.length <= Constants.ENCODED_DOMAIN_LENGTH) { + return hostname.replace(/\./g, actionDomainSeparator); + } + + const modifiedDomain = Buffer.from(hostname).toString('base64'); const key = modifiedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH); await domainsCache.set(key, modifiedDomain); return key; } + const cachedDomain = await domainsCache.get(domain); if (!cachedDomain) { return domain.replace(replaceSeparatorRegex, '.'); } @@ -456,6 +491,7 @@ const deleteAssistantActions = async ({ req, assistant_id }) => { module.exports = { deleteAssistantActions, validateAndUpdateTool, + legacyDomainEncode, createActionTool, encryptMetadata, decryptMetadata, diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js index c60aef7ad1..42def44b4f 100644 --- a/api/server/services/ActionService.spec.js +++ b/api/server/services/ActionService.spec.js @@ -1,175 +1,539 @@ -const { Constants, actionDomainSeparator } = require('librechat-data-provider'); -const { domainParser } = require('./ActionService'); +const { Constants, actionDelimiter, actionDomainSeparator } = require('librechat-data-provider'); +const { domainParser, legacyDomainEncode, validateAndUpdateTool } = require('./ActionService'); jest.mock('keyv'); -const globalCache = {}; 
+jest.mock('~/models/Action', () => ({ + getActions: jest.fn(), + deleteActions: jest.fn(), +})); + +const { getActions } = require('~/models/Action'); + +let mockDomainCache = {}; jest.mock('~/cache/getLogStores', () => { - return jest.fn().mockImplementation(() => { - const EventEmitter = require('events'); - const { CacheKeys } = require('librechat-data-provider'); + return jest.fn().mockImplementation(() => ({ + get: async (key) => mockDomainCache[key] ?? null, + set: async (key, value) => { + mockDomainCache[key] = value; + return true; + }, + })); +}); - class KeyvMongo extends EventEmitter { - constructor(url = 'mongodb://127.0.0.1:27017', options) { - super(); - this.ttlSupport = false; - url = url ?? {}; - if (typeof url === 'string') { - url = { url }; - } - if (url.uri) { - url = { url: url.uri, ...url }; - } - this.opts = { - url, - collection: 'keyv', - ...url, - ...options, - }; - } +beforeEach(() => { + mockDomainCache = {}; + getActions.mockReset(); +}); - get = async (key) => { - return new Promise((resolve) => { - resolve(globalCache[key] || null); - }); - }; +const SEP = actionDomainSeparator; +const DELIM = actionDelimiter; +const MAX = Constants.ENCODED_DOMAIN_LENGTH; +const domainSepRegex = new RegExp(SEP, 'g'); - set = async (key, value) => { - return new Promise((resolve) => { - globalCache[key] = value; - resolve(true); - }); - }; - } +describe('domainParser', () => { + describe('nullish input', () => { + it.each([null, undefined, ''])('returns undefined for %j', async (input) => { + expect(await domainParser(input, true)).toBeUndefined(); + expect(await domainParser(input, false)).toBeUndefined(); + }); + }); - return new KeyvMongo('', { - namespace: CacheKeys.ENCODED_DOMAINS, - ttl: 0, + describe('short-path encoding (hostname ≤ threshold)', () => { + it.each([ + ['examp.com', `examp${SEP}com`], + ['swapi.tech', `swapi${SEP}tech`], + ['a.b', `a${SEP}b`], + ])('replaces dots in %s → %s', async (domain, expected) => { + expect(await 
domainParser(domain, true)).toBe(expected); + }); + + it('handles domain exactly at threshold length', async () => { + const domain = 'a'.repeat(MAX - 4) + '.com'; + expect(domain).toHaveLength(MAX); + const result = await domainParser(domain, true); + expect(result).toBe(domain.replace(/\./g, SEP)); + }); + }); + + describe('base64-path encoding (hostname > threshold)', () => { + it('produces a key of exactly ENCODED_DOMAIN_LENGTH chars', async () => { + const result = await domainParser('api.example.com', true); + expect(result).toHaveLength(MAX); + }); + + it('encodes hostname, not full URL', async () => { + const hostname = 'api.example.com'; + const expectedKey = Buffer.from(hostname).toString('base64').substring(0, MAX); + expect(await domainParser(hostname, true)).toBe(expectedKey); + }); + + it('populates decode cache for round-trip', async () => { + const hostname = 'longdomainname.com'; + const key = await domainParser(hostname, true); + + expect(mockDomainCache[key]).toBe(Buffer.from(hostname).toString('base64')); + expect(await domainParser(key, false)).toBe(hostname); + }); + }); + + describe('protocol stripping', () => { + it('https:// URL and bare hostname produce identical encoding', async () => { + const encoded = await domainParser('https://swapi.tech', true); + expect(encoded).toBe(await domainParser('swapi.tech', true)); + expect(encoded).toBe(`swapi${SEP}tech`); + }); + + it('http:// URL and bare hostname produce identical encoding', async () => { + const encoded = await domainParser('http://api.example.com', true); + expect(encoded).toBe(await domainParser('api.example.com', true)); + }); + + it('different https:// domains produce unique keys', async () => { + const keys = await Promise.all([ + domainParser('https://api.example.com', true), + domainParser('https://api.weather.com', true), + domainParser('https://data.github.com', true), + ]); + const unique = new Set(keys); + expect(unique.size).toBe(keys.length); + }); + + it('long hostname 
after stripping still uses base64 path', async () => { + const result = await domainParser('https://api.example.com', true); + expect(result).toHaveLength(MAX); + expect(result).not.toContain(SEP); + }); + + it('short hostname after stripping uses dot-replacement path', async () => { + const result = await domainParser('https://a.b.c', true); + expect(result).toBe(`a${SEP}b${SEP}c`); + }); + + it('strips path and query from full URL before encoding', async () => { + const result = await domainParser('https://api.example.com/v1/endpoint?foo=bar', true); + expect(result).toBe(await domainParser('api.example.com', true)); + }); + }); + + describe('unicode domains', () => { + it('encodes unicode hostname via base64 path', async () => { + const domain = 'täst.example.com'; + const result = await domainParser(domain, true); + expect(result).toHaveLength(MAX); + expect(result).toBe(Buffer.from(domain).toString('base64').substring(0, MAX)); + }); + + it('round-trips unicode hostname through encode then decode', async () => { + const domain = 'täst.example.com'; + const key = await domainParser(domain, true); + expect(await domainParser(key, false)).toBe(domain); + }); + + it('strips protocol before encoding unicode hostname', async () => { + const withProto = 'https://täst.example.com'; + const bare = 'täst.example.com'; + expect(await domainParser(withProto, true)).toBe(await domainParser(bare, true)); + }); + }); + + describe('decode path', () => { + it('short-path encoded domain decodes via separator replacement', async () => { + expect(await domainParser(`examp${SEP}com`, false)).toBe('examp.com'); + }); + + it('base64-path encoded domain decodes via cache lookup', async () => { + const hostname = 'api.example.com'; + const key = await domainParser(hostname, true); + expect(await domainParser(key, false)).toBe(hostname); + }); + + it('returns input unchanged for unknown non-separator strings', async () => { + expect(await domainParser('not_base64_encoded', 
false)).toBe('not_base64_encoded'); + }); + + it('returns a string without throwing for corrupt cache entries', async () => { + mockDomainCache['corrupt_key'] = '!!!'; + const result = await domainParser('corrupt_key', false); + expect(typeof result).toBe('string'); }); }); }); -describe('domainParser', () => { - const TLD = '.com'; - - // Non-azure request - it('does not return domain as is if not azure', async () => { - const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; - const result1 = await domainParser(domain, false); - const result2 = await domainParser(domain, true); - expect(result1).not.toEqual(domain); - expect(result2).not.toEqual(domain); +describe('legacyDomainEncode', () => { + it.each(['', null, undefined])('returns empty string for %j', (input) => { + expect(legacyDomainEncode(input)).toBe(''); }); - // Test for Empty or Null Inputs - it('returns undefined for null domain input', async () => { - const result = await domainParser(null, true); - expect(result).toBeUndefined(); + it('is synchronous (returns a string, not a Promise)', () => { + const result = legacyDomainEncode('examp.com'); + expect(result).toBe(`examp${SEP}com`); + expect(result).not.toBeInstanceOf(Promise); }); - it('returns undefined for empty domain input', async () => { - const result = await domainParser('', true); - expect(result).toBeUndefined(); + it('uses dot-replacement for short domains', () => { + expect(legacyDomainEncode('examp.com')).toBe(`examp${SEP}com`); }); - // Verify Correct Caching Behavior - it('caches encoded domain correctly', async () => { - const domain = 'longdomainname.com'; - const encodedDomain = Buffer.from(domain) - .toString('base64') - .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - - await domainParser(domain, true); - - const cachedValue = await globalCache[encodedDomain]; - expect(cachedValue).toEqual(Buffer.from(domain).toString('base64')); + it('uses base64 prefix of full input for long domains', () => { + const 
domain = 'https://swapi.tech'; + const expected = Buffer.from(domain).toString('base64').substring(0, MAX); + expect(legacyDomainEncode(domain)).toBe(expected); }); - // Test for Edge Cases Around Length Threshold - it('encodes domain exactly at threshold without modification', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); + it('all https:// URLs collide to the same key', () => { + const results = [ + legacyDomainEncode('https://api.example.com'), + legacyDomainEncode('https://api.weather.com'), + legacyDomainEncode('https://totally.different.host'), + ]; + expect(new Set(results).size).toBe(1); }); - it('encodes domain just below threshold without modification', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); + it('matches what old domainParser would have produced', () => { + const domain = 'https://api.example.com'; + const legacy = legacyDomainEncode(domain); + expect(legacy).toBe(Buffer.from(domain).toString('base64').substring(0, MAX)); }); - // Test for Unicode Domain Names - it('handles unicode characters in domain names correctly when encoding', async () => { - const unicodeDomain = 'täst.example.com'; - const encodedDomain = Buffer.from(unicodeDomain) - .toString('base64') - .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - const result = await domainParser(unicodeDomain, true); - expect(result).toEqual(encodedDomain); - }); - - it('decodes unicode domain names correctly', async () => { - const unicodeDomain = 'täst.example.com'; - const encodedDomain = Buffer.from(unicodeDomain).toString('base64'); - globalCache[encodedDomain.substring(0, 
Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching - - const result = await domainParser( - encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH), - false, - ); - expect(result).toEqual(unicodeDomain); - }); - - // Core Functionality Tests - it('returns domain with replaced separators if no cached domain exists', async () => { - const domain = 'example.com'; - const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(withSeparator, false); - expect(result).toEqual(domain); - }); - - it('returns domain with replaced separators when inverse is false and under encoding length', async () => { - const domain = 'examp.com'; - const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(withSeparator, false); - expect(result).toEqual(domain); - }); - - it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => { - const domain = 'examp.com'; - const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(domain, true); - expect(result).toEqual(expected); - }); - - it('encodes domain when length is above threshold and inverse is true', async () => { - const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com'); - const result = await domainParser(domain, true); - expect(result).not.toEqual(domain); - expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH); - }); - - it('returns encoded value if no encoded value is cached, and inverse is false', async () => { - const originalDomain = 'example.com'; - const encodedDomain = Buffer.from( - originalDomain.replace(/\./g, actionDomainSeparator), - ).toString('base64'); - const result = await domainParser(encodedDomain, false); - expect(result).toEqual(encodedDomain); - }); - - it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => { - const 
originalDomain = 'example.com'; - const encodedDomain = await domainParser(originalDomain, true); - const result = await domainParser(encodedDomain, false); - expect(result).toEqual(originalDomain); - }); - - it('handles invalid base64 encoded values gracefully', async () => { - const invalidBase64Domain = 'not_base64_encoded'; - const result = await domainParser(invalidBase64Domain, false); - expect(result).toEqual(invalidBase64Domain); + it('produces same result as new domainParser for short bare hostnames', async () => { + const domain = 'swapi.tech'; + expect(legacyDomainEncode(domain)).toBe(await domainParser(domain, true)); + }); +}); + +describe('validateAndUpdateTool', () => { + const mockReq = { user: { id: 'user123' } }; + + it('returns tool unchanged when name passes tool-name regex', async () => { + const tool = { function: { name: 'getPeople_action_swapi---tech' } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + expect(result).toEqual(tool); + expect(getActions).not.toHaveBeenCalled(); + }); + + it('matches action when metadata.domain has https:// prefix and tool domain is bare hostname', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'https://api.example.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).not.toBeNull(); + expect(result.function.name).toMatch(/^getPeople_action_/); + expect(result.function.name).not.toContain('.'); + }); + + it('matches action when metadata.domain has no protocol', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'api.example.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).not.toBeNull(); + 
expect(result.function.name).toMatch(/^getPeople_action_/); + }); + + it('returns null when no action matches the domain', async () => { + getActions.mockResolvedValue([{ metadata: { domain: 'https://other.domain.com' } }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).toBeNull(); + }); + + it('returns null when action has no metadata', async () => { + getActions.mockResolvedValue([{ metadata: null }]); + + const tool = { function: { name: `getPeople${DELIM}api.example.com` } }; + const result = await validateAndUpdateTool({ + req: mockReq, + tool, + assistant_id: 'asst_1', + }); + + expect(result).toBeNull(); + }); +}); + +describe('backward-compatible tool name matching', () => { + function normalizeToolName(name) { + return name.replace(domainSepRegex, '_'); + } + + function buildToolName(functionName, encodedDomain) { + return `${functionName}${DELIM}${encodedDomain}`; + } + + describe('definition-phase matching', () => { + it('new encoding matches agent tools stored with new encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const encoded = await domainParser(metadataDomain, true); + const normalized = normalizeToolName(encoded); + + const storedTool = buildToolName('getPeople', encoded); + const defToolName = `getPeople${DELIM}${normalized}`; + + expect(normalizeToolName(storedTool)).toBe(defToolName); + }); + + it('legacy encoding matches agent tools stored with legacy encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const legacy = legacyDomainEncode(metadataDomain); + const legacyNormalized = normalizeToolName(legacy); + + const storedTool = buildToolName('getPeople', legacy); + const legacyDefName = `getPeople${DELIM}${legacyNormalized}`; + + expect(normalizeToolName(storedTool)).toBe(legacyDefName); + }); + + it('new definition matches old stored tools via 
legacy fallback', async () => { + const metadataDomain = 'https://swapi.tech'; + const newDomain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const newNorm = normalizeToolName(newDomain); + const legacyNorm = normalizeToolName(legacyDomain); + + const oldStoredTool = buildToolName('getPeople', legacyDomain); + const newToolName = `getPeople${DELIM}${newNorm}`; + const legacyToolName = `getPeople${DELIM}${legacyNorm}`; + + const storedNormalized = normalizeToolName(oldStoredTool); + const hasMatch = storedNormalized === newToolName || storedNormalized === legacyToolName; + expect(hasMatch).toBe(true); + }); + + it('pre-normalized Set eliminates per-tool normalization', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const normalizedDomain = normalizeToolName(domain); + const legacyNormalized = normalizeToolName(legacyDomain); + + const storedTools = [ + buildToolName('getWeather', legacyDomain), + buildToolName('getForecast', domain), + ]; + + const preNormalized = new Set(storedTools.map((t) => normalizeToolName(t))); + + const toolName = `getWeather${DELIM}${normalizedDomain}`; + const legacyToolName = `getWeather${DELIM}${legacyNormalized}`; + expect(preNormalized.has(toolName) || preNormalized.has(legacyToolName)).toBe(true); + }); + }); + + describe('execution-phase tool lookup', () => { + it('model-called tool name resolves via normalizedToDomain map (new encoding)', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const normalized = normalizeToolName(domain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(normalized, domain); + + const modelToolName = `getWeather${DELIM}${normalized}`; + + let matched = ''; + for (const [norm, canonical] of 
normalizedToDomain.entries()) { + if (modelToolName.includes(norm)) { + matched = canonical; + break; + } + } + + expect(matched).toBe(domain); + + const functionName = modelToolName.replace(`${DELIM}${normalizeToolName(matched)}`, ''); + expect(functionName).toBe('getWeather'); + }); + + it('model-called tool name resolves via legacy entry in normalizedToDomain map', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + const legacyNorm = normalizeToolName(legacyDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(normalizeToolName(domain), domain); + normalizedToDomain.set(legacyNorm, domain); + + const legacyModelToolName = `getWeather${DELIM}${legacyNorm}`; + + let matched = ''; + for (const [norm, canonical] of normalizedToDomain.entries()) { + if (legacyModelToolName.includes(norm)) { + matched = canonical; + break; + } + } + + expect(matched).toBe(domain); + }); + + it('legacy guard skips duplicate map entry for short bare hostnames', async () => { + const domain = 'swapi.tech'; + const newEncoding = await domainParser(domain, true); + const legacyEncoding = legacyDomainEncode(domain); + + expect(newEncoding).toBe(legacyEncoding); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(newEncoding, newEncoding); + if (legacyEncoding !== newEncoding) { + normalizedToDomain.set(legacyEncoding, newEncoding); + } + expect(normalizedToDomain.size).toBe(1); + }); + }); + + describe('processRequiredActions matching (assistants path)', () => { + it('legacy tool from OpenAI matches via normalizedToDomain with both encodings', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(domain, domain); + if (legacyDomain 
!== domain) { + normalizedToDomain.set(legacyDomain, domain); + } + + const legacyToolName = buildToolName('getPeople', legacyDomain); + + let currentDomain = ''; + let matchedKey = ''; + for (const [key, canonical] of normalizedToDomain.entries()) { + if (legacyToolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; + break; + } + } + + expect(currentDomain).toBe(domain); + expect(matchedKey).toBe(legacyDomain); + + const functionName = legacyToolName.replace(`${DELIM}${matchedKey}`, ''); + expect(functionName).toBe('getPeople'); + }); + + it('new tool name matches via the canonical domain key', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const normalizedToDomain = new Map(); + normalizedToDomain.set(domain, domain); + if (legacyDomain !== domain) { + normalizedToDomain.set(legacyDomain, domain); + } + + const newToolName = buildToolName('getPeople', domain); + + let currentDomain = ''; + let matchedKey = ''; + for (const [key, canonical] of normalizedToDomain.entries()) { + if (newToolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; + break; + } + } + + expect(currentDomain).toBe(domain); + expect(matchedKey).toBe(domain); + + const functionName = newToolName.replace(`${DELIM}${matchedKey}`, ''); + expect(functionName).toBe('getPeople'); + }); + }); + + describe('save-route cleanup', () => { + it('tool filter removes tools matching new encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const tools = [ + buildToolName('getPeople', domain), + buildToolName('unrelated', 'other---domain'), + ]; + + const filtered = tools.filter((t) => !t.includes(domain) && !t.includes(legacyDomain)); + + expect(filtered).toEqual([buildToolName('unrelated', 
'other---domain')]); + }); + + it('tool filter removes tools matching legacy encoding', async () => { + const metadataDomain = 'https://swapi.tech'; + const domain = await domainParser(metadataDomain, true); + const legacyDomain = legacyDomainEncode(metadataDomain); + + const tools = [ + buildToolName('getPeople', legacyDomain), + buildToolName('unrelated', 'other---domain'), + ]; + + const filtered = tools.filter((t) => !t.includes(domain) && !t.includes(legacyDomain)); + + expect(filtered).toEqual([buildToolName('unrelated', 'other---domain')]); + }); + }); + + describe('delete-route domain extraction', () => { + it('domain extracted from actions array is usable as-is for tool filtering', async () => { + const metadataDomain = 'https://api.example.com'; + const domain = await domainParser(metadataDomain, true); + const actionId = 'abc123'; + const actionEntry = `${domain}${DELIM}${actionId}`; + + const [storedDomain] = actionEntry.split(DELIM); + expect(storedDomain).toBe(domain); + + const tools = [buildToolName('getWeather', domain), buildToolName('getPeople', 'other')]; + + const filtered = tools.filter((t) => !t.includes(storedDomain)); + expect(filtered).toEqual([buildToolName('getPeople', 'other')]); + }); + }); + + describe('multi-action agents (collision scenario)', () => { + it('two https:// actions now produce distinct tool names', async () => { + const domain1 = await domainParser('https://api.weather.com', true); + const domain2 = await domainParser('https://api.spacex.com', true); + + const tool1 = buildToolName('getData', domain1); + const tool2 = buildToolName('getData', domain2); + + expect(tool1).not.toBe(tool2); + }); + + it('two https:// actions used to collide in legacy encoding', () => { + const legacy1 = legacyDomainEncode('https://api.weather.com'); + const legacy2 = legacyDomainEncode('https://api.spacex.com'); + + const tool1 = buildToolName('getData', legacy1); + const tool2 = buildToolName('getData', legacy2); + + 
expect(tool1).toBe(tool2); + }); }); }); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 5fc95e748d..ca75e7eb4f 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -42,6 +42,7 @@ const { } = require('librechat-data-provider'); const { createActionTool, + legacyDomainEncode, decryptMetadata, loadActionSets, domainParser, @@ -65,6 +66,8 @@ const { findPluginAuthsByKeys } = require('~/models'); const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + /** * Resolves the set of enabled agent capabilities from endpoints config, * falling back to app-level or default capabilities for ephemeral agents. @@ -172,8 +175,7 @@ async function processRequiredActions(client, requiredActions) { const promises = []; - /** @type {Action[]} */ - let actionSets = []; + let actionSetsData = null; let isActionTool = false; const ActionToolMap = {}; const ActionBuildersMap = {}; @@ -259,9 +261,9 @@ async function processRequiredActions(client, requiredActions) { if (!tool) { // throw new Error(`Tool ${currentAction.tool} not found.`); - // Load all action sets once if not already loaded - if (!actionSets.length) { - actionSets = + if (!actionSetsData) { + /** @type {Action[]} */ + const actionSets = (await loadActionSets({ assistant_id: client.req.body.assistant_id, })) ?? 
[]; @@ -269,11 +271,16 @@ async function processRequiredActions(client, requiredActions) { // Process all action sets once // Map domains to their processed action sets const processedDomains = new Map(); - const domainMap = new Map(); + const domainLookupMap = new Map(); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + domainLookupMap.set(domain, domain); + + const legacyDomain = legacyDomainEncode(action.metadata.domain); + if (legacyDomain !== domain) { + domainLookupMap.set(legacyDomain, domain); + } const isDomainAllowed = await isActionDomainAllowed( action.metadata.domain, @@ -328,27 +335,26 @@ async function processRequiredActions(client, requiredActions) { ActionBuildersMap[action.metadata.domain] = requestBuilders; } - // Update actionSets reference to use the domain map - actionSets = { domainMap, processedDomains }; + actionSetsData = { domainLookupMap, processedDomains }; } - // Find the matching domain for this tool let currentDomain = ''; - for (const domain of actionSets.domainMap.keys()) { - if (currentAction.tool.includes(domain)) { - currentDomain = domain; + let matchedKey = ''; + for (const [key, canonical] of actionSetsData.domainLookupMap.entries()) { + if (currentAction.tool.includes(key)) { + currentDomain = canonical; + matchedKey = key; break; } } - if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) { - // TODO: try `function` if no action set is found - // throw new Error(`Tool ${currentAction.tool} not found.`); + if (!currentDomain || !actionSetsData.processedDomains.has(currentDomain)) { continue; } - const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain); - const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, ''); + const { action, requestBuilders, encrypted } = + actionSetsData.processedDomains.get(currentDomain); + const functionName = 
currentAction.tool.replace(`${actionDelimiter}${matchedKey}`, ''); const requestBuilder = requestBuilders[functionName]; if (!requestBuilder) { @@ -586,12 +592,17 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to const definitions = []; const allowedDomains = appConfig?.actions?.allowedDomains; - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + const normalizedToolNames = new Set( + actionToolNames.map((n) => n.replace(domainSeparatorRegex, '_')), + ); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + const legacyDomain = legacyDomainEncode(action.metadata.domain); + const legacyNormalized = legacyDomain.replace(domainSeparatorRegex, '_'); + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); if (!isDomainAllowed) { logger.warn( @@ -611,7 +622,8 @@ async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, to for (const sig of functionSignatures) { const toolName = `${sig.name}${actionDelimiter}${normalizedDomain}`; - if (!actionToolNames.some((name) => name.replace(domainSeparatorRegex, '_') === toolName)) { + const legacyToolName = `${sig.name}${actionDelimiter}${legacyNormalized}`; + if (!normalizedToolNames.has(toolName) && !normalizedToolNames.has(legacyToolName)) { continue; } @@ -990,15 +1002,17 @@ async function loadAgentTools({ }; } - // Process each action set once (validate spec, decrypt metadata) const processedActionSets = new Map(); - const domainMap = new Map(); + const domainLookupMap = new Map(); for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + domainLookupMap.set(domain, domain); - // Check if domain is allowed (do this once per action set) + const legacyDomain = legacyDomainEncode(action.metadata.domain); + if (legacyDomain !== 
domain) { + domainLookupMap.set(legacyDomain, domain); + } const isDomainAllowed = await isActionDomainAllowed( action.metadata.domain, appConfig?.actions?.allowedDomains, @@ -1060,11 +1074,12 @@ async function loadAgentTools({ continue; } - // Find the matching domain for this tool let currentDomain = ''; - for (const domain of domainMap.keys()) { - if (toolName.includes(domain)) { - currentDomain = domain; + let matchedKey = ''; + for (const [key, canonical] of domainLookupMap.entries()) { + if (toolName.includes(key)) { + currentDomain = canonical; + matchedKey = key; break; } } @@ -1075,7 +1090,7 @@ async function loadAgentTools({ const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = processedActionSets.get(currentDomain); - const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, ''); + const functionName = toolName.replace(`${actionDelimiter}${matchedKey}`, ''); const functionSig = functionSignatures.find((sig) => sig.name === functionName); const requestBuilder = requestBuilders[functionName]; const zodSchema = zodSchemas[functionName]; @@ -1310,12 +1325,20 @@ async function loadActionToolsForExecution({ } const processedActionSets = new Map(); - const domainMap = new Map(); + /** Maps both new and legacy normalized domains to their canonical (new) domain key */ + const normalizedToDomain = new Map(); const allowedDomains = appConfig?.actions?.allowedDomains; for (const action of actionSets) { const domain = await domainParser(action.metadata.domain, true); - domainMap.set(domain, action); + const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + normalizedToDomain.set(normalizedDomain, domain); + + const legacyDomain = legacyDomainEncode(action.metadata.domain); + const legacyNormalized = legacyDomain.replace(domainSeparatorRegex, '_'); + if (legacyNormalized !== normalizedDomain) { + normalizedToDomain.set(legacyNormalized, domain); + } const isDomainAllowed = await 
isActionDomainAllowed(action.metadata.domain, allowedDomains); if (!isDomainAllowed) { @@ -1364,16 +1387,15 @@ async function loadActionToolsForExecution({ functionSignatures, zodSchemas, encrypted, + legacyNormalized, }); } - const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); for (const toolName of actionToolNames) { let currentDomain = ''; - for (const domain of domainMap.keys()) { - const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + for (const [normalizedDomain, canonicalDomain] of normalizedToDomain.entries()) { if (toolName.includes(normalizedDomain)) { - currentDomain = domain; + currentDomain = canonicalDomain; break; } } @@ -1382,7 +1404,7 @@ async function loadActionToolsForExecution({ continue; } - const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = + const { action, encrypted, zodSchemas, requestBuilders, functionSignatures, legacyNormalized } = processedActionSets.get(currentDomain); const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_'); const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, ''); @@ -1391,6 +1413,25 @@ async function loadActionToolsForExecution({ const zodSchema = zodSchemas[functionName]; if (!requestBuilder) { + const legacyFnName = toolName.replace(`${actionDelimiter}${legacyNormalized}`, ''); + if (legacyFnName !== toolName && requestBuilders[legacyFnName]) { + const legacyTool = await createActionTool({ + userId: req.user.id, + res, + action, + streamId, + encrypted, + requestBuilder: requestBuilders[legacyFnName], + zodSchema: zodSchemas[legacyFnName], + name: toolName, + description: + functionSignatures.find((sig) => sig.name === legacyFnName)?.description ?? '', + useSSRFProtection: !Array.isArray(allowedDomains) || allowedDomains.length === 0, + }); + if (legacyTool) { + loadedActionTools.push(legacyTool); + } + } continue; }