diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index 26610f73ba..c48db1d764 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -51,6 +51,10 @@ class DALLE3 extends Tool { this.fileStrategy = fields.fileStrategy; /** @type {boolean} */ this.isAgent = fields.isAgent; + if (this.isAgent) { + /** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */ + this.responseFormat = 'content_and_artifact'; + } if (fields.processFileURL) { /** @type {processFileURL} Necessary for output to contain all image metadata. */ this.processFileURL = fields.processFileURL.bind(this); diff --git a/api/app/clients/tools/structured/FluxAPI.js b/api/app/clients/tools/structured/FluxAPI.js index 56f86a707d..f8341f7904 100644 --- a/api/app/clients/tools/structured/FluxAPI.js +++ b/api/app/clients/tools/structured/FluxAPI.js @@ -113,6 +113,10 @@ class FluxAPI extends Tool { /** @type {boolean} **/ this.isAgent = fields.isAgent; + if (this.isAgent) { + /** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */ + this.responseFormat = 'content_and_artifact'; + } this.returnMetadata = fields.returnMetadata ?? false; if (fields.processFileURL) { @@ -524,10 +528,40 @@ class FluxAPI extends Tool { return this.returnValue('No image data received from Flux API.'); } - // Try saving the image locally const imageUrl = resultData.sample; const imageName = `img-${uuidv4()}.png`; + if (this.isAgent) { + try { + const fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(imageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } catch (error) { + logger.error('[FluxAPI] Error processing finetuned image for agent:', error); + return this.returnValue(`Failed to process the finetuned image. ${error.message}`); + } + } + try { logger.debug('[FluxAPI] Saving finetuned image:', imageUrl); const result = await this.processFileURL({ @@ -541,12 +575,6 @@ class FluxAPI extends Tool { logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath); - // Calculate cost based on endpoint - const endpointKey = endpoint.includes('ultra') - ? 'FLUX_PRO_1_1_ULTRA_FINETUNED' - : 'FLUX_PRO_FINETUNED'; - const cost = FluxAPI.PRICING[endpointKey] || 0; - // Return the result based on returnMetadata flag this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); return this.returnValue(this.result); } catch (error) { diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index d7a7a4d96b..8cf4b141bb 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -43,6 +43,10 @@ class StableDiffusionAPI extends Tool { this.returnMetadata = fields.returnMetadata ?? false; /** @type {boolean} */ this.isAgent = fields.isAgent; + if (this.isAgent) { + /** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */ + this.responseFormat = 'content_and_artifact'; + } if (fields.uploadImageBuffer) { /** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */ this.uploadImageBuffer = fields.uploadImageBuffer.bind(this); @@ -115,7 +119,7 @@ class StableDiffusionAPI extends Tool { generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload); } catch (error) { logger.error('[StableDiffusion] Error while generating image:', error); - return 'Error making API request.'; + return this.returnValue('Error making API request.'); } const image = generationResponse.data.images[0]; diff --git a/api/app/clients/tools/structured/specs/imageTools-agent.spec.js b/api/app/clients/tools/structured/specs/imageTools-agent.spec.js new file mode 100644 index 0000000000..b82dd87b3f --- /dev/null +++ b/api/app/clients/tools/structured/specs/imageTools-agent.spec.js @@ -0,0 +1,294 @@ +/** + * Regression tests for image tool agent mode — verifies that invoke() returns + * a ToolMessage with base64 in artifact.content rather than serialized into content. + * + * Root cause: DALLE3/FluxAPI/StableDiffusion extend LangChain's Tool but did not + * set responseFormat = 'content_and_artifact'. LangChain's invoke() would then + * JSON.stringify the entire [content, artifact] tuple into ToolMessage.content, + * dumping base64 into token counting and causing context exhaustion. + */ + +const axios = require('axios'); +const OpenAI = require('openai'); +const undici = require('undici'); +const fetch = require('node-fetch'); +const { ToolMessage } = require('@langchain/core/messages'); +const { ContentTypes } = require('librechat-data-provider'); +const StableDiffusionAPI = require('../StableDiffusion'); +const FluxAPI = require('../FluxAPI'); +const DALLE3 = require('../DALLE3'); + +jest.mock('axios'); +jest.mock('openai'); +jest.mock('node-fetch'); +jest.mock('undici', () => ({ + ProxyAgent: jest.fn(), + fetch: jest.fn(), +})); +jest.mock('@librechat/data-schemas', () => ({ + logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() }, +})); +jest.mock('path', () => ({ + resolve: jest.fn(), + join: jest.fn().mockReturnValue('/mock/path'), + relative: jest.fn().mockReturnValue('relative/path'), + extname: jest.fn().mockReturnValue('.png'), +})); +jest.mock('fs', () => ({ + existsSync: jest.fn().mockReturnValue(true), + mkdirSync: jest.fn(), + promises: { writeFile: jest.fn(), readFile: jest.fn(), unlink: jest.fn() }, +})); + +const FAKE_BASE64 = 'aGVsbG8='; + +const makeToolCall = (name, args) => ({ + id: 'call_test_123', + name, + args, + type: 'tool_call', +}); + +describe('image tools - agent mode ToolMessage format', () => { + const ENV_KEYS = ['DALLE_API_KEY', 'FLUX_API_KEY', 'SD_WEBUI_URL', 'PROXY']; + let savedEnv = {}; + + beforeEach(() => { + jest.clearAllMocks(); + for (const key of ENV_KEYS) { + savedEnv[key] = process.env[key]; + } + process.env.DALLE_API_KEY = 'test-dalle-key'; + process.env.FLUX_API_KEY = 'test-flux-key'; + process.env.SD_WEBUI_URL = 'http://localhost:7860'; + delete process.env.PROXY; + }); + + afterEach(() => { + for (const key of ENV_KEYS) { + if (savedEnv[key] === undefined) { + delete process.env[key]; + } else { + process.env[key] = savedEnv[key]; + } + } + savedEnv = {}; + }); + + describe('DALLE3', () => { + beforeEach(() => { + OpenAI.mockImplementation(() => ({ + images: { + generate: jest.fn().mockResolvedValue({ + data: [{ url: 'https://example.com/image.png' }], + }), + }, + })); + undici.fetch.mockResolvedValue({ + arrayBuffer: () => Promise.resolve(Buffer.from(FAKE_BASE64, 'base64')), + }); + }); + + it('sets responseFormat to content_and_artifact when isAgent is true', () => { + const dalle = new DALLE3({ isAgent: true }); + expect(dalle.responseFormat).toBe('content_and_artifact'); + }); + + it('does not set responseFormat when isAgent is false', () => { + const dalle = new DALLE3({ isAgent: false, processFileURL: jest.fn() }); + expect(dalle.responseFormat).not.toBe('content_and_artifact'); + }); + + it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => { + const dalle = new DALLE3({ isAgent: true }); + const result = await dalle.invoke( + makeToolCall('dalle', { + prompt: 'a box', + quality: 'standard', + size: '1024x1024', + style: 'vivid', + }), + ); + + expect(result).toBeInstanceOf(ToolMessage); + + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).not.toContain(FAKE_BASE64); + + expect(result.artifact).toBeDefined(); + const artifactContent = result.artifact?.content; + expect(Array.isArray(artifactContent)).toBe(true); + expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL); + expect(artifactContent[0].image_url.url).toContain('base64'); + }); + + it('invoke() returns ToolMessage with error string in content when API fails', async () => { + OpenAI.mockImplementation(() => ({ + images: { generate: jest.fn().mockRejectedValue(new Error('API error')) }, + })); + + const dalle = new DALLE3({ isAgent: true }); + const result = await dalle.invoke( + makeToolCall('dalle', { + prompt: 'a box', + quality: 'standard', + size: '1024x1024', + style: 'vivid', + }), + ); + + expect(result).toBeInstanceOf(ToolMessage); + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).toContain('Something went wrong'); + expect(result.artifact).toBeDefined(); + }); + }); + + describe('FluxAPI', () => { + beforeEach(() => { + jest.useFakeTimers(); + axios.post.mockResolvedValue({ data: { id: 'task-123' } }); + axios.get.mockResolvedValue({ + data: { status: 'Ready', result: { sample: 'https://example.com/image.png' } }, + }); + fetch.mockResolvedValue({ + arrayBuffer: () => Promise.resolve(Buffer.from(FAKE_BASE64, 'base64')), + }); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('sets responseFormat to content_and_artifact when isAgent is true', () => { + const flux = new FluxAPI({ isAgent: true }); + expect(flux.responseFormat).toBe('content_and_artifact'); + }); + + it('does not set responseFormat when isAgent is false', () => { + const flux = new FluxAPI({ isAgent: false, processFileURL: jest.fn() }); + expect(flux.responseFormat).not.toBe('content_and_artifact'); + }); + + it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => { + const flux = new FluxAPI({ isAgent: true }); + const invokePromise = flux.invoke( + makeToolCall('flux', { prompt: 'a box', endpoint: '/v1/flux-dev' }), + ); + await jest.runAllTimersAsync(); + const result = await invokePromise; + + expect(result).toBeInstanceOf(ToolMessage); + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).not.toContain(FAKE_BASE64); + + expect(result.artifact).toBeDefined(); + const artifactContent = result.artifact?.content; + expect(Array.isArray(artifactContent)).toBe(true); + expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL); + expect(artifactContent[0].image_url.url).toContain('base64'); + }); + + it('invoke() returns ToolMessage with base64 in artifact for generate_finetuned action', async () => { + const flux = new FluxAPI({ isAgent: true }); + const invokePromise = flux.invoke( + makeToolCall('flux', { + action: 'generate_finetuned', + prompt: 'a box', + finetune_id: 'ft-abc123', + endpoint: '/v1/flux-pro-finetuned', + }), + ); + await jest.runAllTimersAsync(); + const result = await invokePromise; + + expect(result).toBeInstanceOf(ToolMessage); + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).not.toContain(FAKE_BASE64); + + expect(result.artifact).toBeDefined(); + const artifactContent = result.artifact?.content; + expect(Array.isArray(artifactContent)).toBe(true); + expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL); + expect(artifactContent[0].image_url.url).toContain('base64'); + }); + + it('invoke() returns ToolMessage with error string in content when task submission fails', async () => { + axios.post.mockRejectedValue(new Error('Network error')); + + const flux = new FluxAPI({ isAgent: true }); + const invokePromise = flux.invoke( + makeToolCall('flux', { prompt: 'a box', endpoint: '/v1/flux-dev' }), + ); + await jest.runAllTimersAsync(); + const result = await invokePromise; + + expect(result).toBeInstanceOf(ToolMessage); + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).toContain('Something went wrong'); + expect(result.artifact).toBeDefined(); + }); + }); + + describe('StableDiffusion', () => { + beforeEach(() => { + axios.post.mockResolvedValue({ + data: { + images: [FAKE_BASE64], + info: JSON.stringify({ height: 1024, width: 1024, seed: 42, infotexts: [] }), + }, + }); + }); + + it('sets responseFormat to content_and_artifact when isAgent is true', () => { + const sd = new StableDiffusionAPI({ isAgent: true, override: true }); + expect(sd.responseFormat).toBe('content_and_artifact'); + }); + + it('does not set responseFormat when isAgent is false', () => { + const sd = new StableDiffusionAPI({ + isAgent: false, + override: true, + uploadImageBuffer: jest.fn(), + }); + expect(sd.responseFormat).not.toBe('content_and_artifact'); + }); + + it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => { + const sd = new StableDiffusionAPI({ isAgent: true, override: true, userId: 'user-1' }); + const result = await sd.invoke( + makeToolCall('stable-diffusion', { prompt: 'a box', negative_prompt: '' }), + ); + + expect(result).toBeInstanceOf(ToolMessage); + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).not.toContain(FAKE_BASE64); + + expect(result.artifact).toBeDefined(); + const artifactContent = result.artifact?.content; + expect(Array.isArray(artifactContent)).toBe(true); + expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL); + expect(artifactContent[0].image_url.url).toContain('base64'); + }); + + it('invoke() returns ToolMessage with error string in content when API fails', async () => { + axios.post.mockRejectedValue(new Error('Connection refused')); + + const sd = new StableDiffusionAPI({ isAgent: true, override: true, userId: 'user-1' }); + const result = await sd.invoke( + makeToolCall('stable-diffusion', { prompt: 'a box', negative_prompt: '' }), + ); + + expect(result).toBeInstanceOf(ToolMessage); + const contentStr = + typeof result.content === 'string' ? result.content : JSON.stringify(result.content); + expect(contentStr).toContain('Error making API request'); + }); + }); +});