diff --git a/api/app/clients/tools/structured/TavilySearch.js b/api/app/clients/tools/structured/TavilySearch.js index b5478d0fc8..55f3b6e1c8 100644 --- a/api/app/clients/tools/structured/TavilySearch.js +++ b/api/app/clients/tools/structured/TavilySearch.js @@ -1,4 +1,5 @@ const { z } = require('zod'); +const { ProxyAgent, fetch } = require('undici'); const { tool } = require('@langchain/core/tools'); const { getApiKey } = require('./credentials'); @@ -19,13 +20,19 @@ function createTavilySearchTool(fields = {}) { ...kwargs, }; - const response = await fetch('https://api.tavily.com/search', { + const fetchOptions = { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify(requestBody), - }); + }; + + if (process.env.PROXY) { + fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY); + } + + const response = await fetch('https://api.tavily.com/search', fetchOptions); const json = await response.json(); if (!response.ok) { diff --git a/api/app/clients/tools/structured/TavilySearchResults.js b/api/app/clients/tools/structured/TavilySearchResults.js index 9461293371..796f31dcca 100644 --- a/api/app/clients/tools/structured/TavilySearchResults.js +++ b/api/app/clients/tools/structured/TavilySearchResults.js @@ -1,4 +1,5 @@ const { z } = require('zod'); +const { ProxyAgent, fetch } = require('undici'); const { Tool } = require('@langchain/core/tools'); const { getEnvironmentVariable } = require('@langchain/core/utils/env'); @@ -102,13 +103,19 @@ class TavilySearchResults extends Tool { ...this.kwargs, }; - const response = await fetch('https://api.tavily.com/search', { + const fetchOptions = { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify(requestBody), - }); + }; + + if (process.env.PROXY) { + fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY); + } + + const response = await fetch('https://api.tavily.com/search', fetchOptions); const json = await response.json(); if (!response.ok) { diff --git a/api/app/clients/tools/structured/specs/TavilySearchResults.spec.js b/api/app/clients/tools/structured/specs/TavilySearchResults.spec.js index 5ea00140c7..f37c83e30e 100644 --- a/api/app/clients/tools/structured/specs/TavilySearchResults.spec.js +++ b/api/app/clients/tools/structured/specs/TavilySearchResults.spec.js @@ -1,6 +1,7 @@ +const { fetch, ProxyAgent } = require('undici'); const TavilySearchResults = require('../TavilySearchResults'); -jest.mock('node-fetch'); +jest.mock('undici'); jest.mock('@langchain/core/utils/env'); describe('TavilySearchResults', () => { @@ -13,6 +14,7 @@ describe('TavilySearchResults', () => { beforeEach(() => { jest.resetModules(); + jest.clearAllMocks(); process.env = { ...originalEnv, TAVILY_API_KEY: mockApiKey, @@ -20,7 +22,6 @@ describe('TavilySearchResults', () => { }); afterEach(() => { - jest.clearAllMocks(); process.env = originalEnv; }); @@ -35,4 +36,49 @@ describe('TavilySearchResults', () => { }); expect(instance.apiKey).toBe(mockApiKey); }); + + describe('proxy support', () => { + const mockResponse = { + ok: true, + json: jest.fn().mockResolvedValue({ results: [] }), + }; + + beforeEach(() => { + fetch.mockResolvedValue(mockResponse); + }); + + it('should use ProxyAgent when PROXY env var is set', async () => { + const proxyUrl = 'http://proxy.example.com:8080'; + process.env.PROXY = proxyUrl; + + const mockProxyAgent = { type: 'proxy-agent' }; + ProxyAgent.mockImplementation(() => mockProxyAgent); + + const instance = new TavilySearchResults({ TAVILY_API_KEY: mockApiKey }); + await instance._call({ query: 'test query' }); + + expect(ProxyAgent).toHaveBeenCalledWith(proxyUrl); + expect(fetch).toHaveBeenCalledWith( + 'https://api.tavily.com/search', + expect.objectContaining({ + dispatcher: mockProxyAgent, + }), + ); + }); + + it('should not use ProxyAgent when PROXY env var is not set', async () => { + delete process.env.PROXY; + + const instance = new TavilySearchResults({ TAVILY_API_KEY: mockApiKey }); + await instance._call({ query: 'test query' }); + + expect(ProxyAgent).not.toHaveBeenCalled(); + expect(fetch).toHaveBeenCalledWith( + 'https://api.tavily.com/search', + expect.not.objectContaining({ + dispatcher: expect.anything(), + }), + ); + }); + }); });