mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-03-21 15:16:33 +01:00
🖼️ fix: Correct ToolMessage Response Format for Agent-Mode Image Tools (#12310)
* fix: Set response format for agent tools in DALLE3, FluxAPI, and StableDiffusion classes - Added logic to set `responseFormat` to 'content_and_artifact' when `isAgent` is true in DALLE3.js, FluxAPI.js, and StableDiffusion.js. * test: Add regression tests for image tool agent mode in imageTools-agent.spec.js - Introduced a new test suite for DALLE3, FluxAPI, and StableDiffusion classes to verify that the invoke() method returns a ToolMessage with base64 in artifact.content, ensuring it is not serialized into content. - Validated that responseFormat is set to 'content_and_artifact' when isAgent is true, and confirmed the correct handling of base64 data in the response. * fix: handle agent error paths and generateFinetunedImage in image tools - StableDiffusion._call() was returning a raw string on API error, bypassing returnValue() and breaking the content_and_artifact contract when isAgent is true - FluxAPI.generateFinetunedImage() had no isAgent branch; it would call processFileURL (unset in agent context) instead of fetching and returning the base64 image as an artifact tuple - Add JSDoc to all three responseFormat assignments clarifying why LangChain requires this property for correct ToolMessage construction * test: expand image tool agent mode regression suite - Add env var save/restore in beforeEach/afterEach to prevent test pollution - Add error path tests for all three tools verifying ToolMessage content and artifact are correctly populated when the upstream API fails - Add generate_finetuned action test for FluxAPI covering the new agent branch in generateFinetunedImage * chore: fix lint errors in FluxAPI and imageTools-agent spec * chore: fix import ordering in imageTools-agent spec
This commit is contained in:
parent
93952f06b4
commit
a88bfae4dd
4 changed files with 338 additions and 8 deletions
|
|
@ -51,6 +51,10 @@ class DALLE3 extends Tool {
|
|||
this.fileStrategy = fields.fileStrategy;
|
||||
/** @type {boolean} */
|
||||
this.isAgent = fields.isAgent;
|
||||
if (this.isAgent) {
|
||||
/** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */
|
||||
this.responseFormat = 'content_and_artifact';
|
||||
}
|
||||
if (fields.processFileURL) {
|
||||
/** @type {processFileURL} Necessary for output to contain all image metadata. */
|
||||
this.processFileURL = fields.processFileURL.bind(this);
|
||||
|
|
|
|||
|
|
@ -113,6 +113,10 @@ class FluxAPI extends Tool {
|
|||
|
||||
/** @type {boolean} **/
|
||||
this.isAgent = fields.isAgent;
|
||||
if (this.isAgent) {
|
||||
/** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */
|
||||
this.responseFormat = 'content_and_artifact';
|
||||
}
|
||||
this.returnMetadata = fields.returnMetadata ?? false;
|
||||
|
||||
if (fields.processFileURL) {
|
||||
|
|
@ -524,10 +528,40 @@ class FluxAPI extends Tool {
|
|||
return this.returnValue('No image data received from Flux API.');
|
||||
}
|
||||
|
||||
// Try saving the image locally
|
||||
const imageUrl = resultData.sample;
|
||||
const imageName = `img-${uuidv4()}.png`;
|
||||
|
||||
if (this.isAgent) {
|
||||
try {
|
||||
const fetchOptions = {};
|
||||
if (process.env.PROXY) {
|
||||
fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
const imageResponse = await fetch(imageUrl, fetchOptions);
|
||||
const arrayBuffer = await imageResponse.arrayBuffer();
|
||||
const base64 = Buffer.from(arrayBuffer).toString('base64');
|
||||
const content = [
|
||||
{
|
||||
type: ContentTypes.IMAGE_URL,
|
||||
image_url: {
|
||||
url: `data:image/png;base64,${base64}`,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const response = [
|
||||
{
|
||||
type: ContentTypes.TEXT,
|
||||
text: displayMessage,
|
||||
},
|
||||
];
|
||||
return [response, { content }];
|
||||
} catch (error) {
|
||||
logger.error('[FluxAPI] Error processing finetuned image for agent:', error);
|
||||
return this.returnValue(`Failed to process the finetuned image. ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
logger.debug('[FluxAPI] Saving finetuned image:', imageUrl);
|
||||
const result = await this.processFileURL({
|
||||
|
|
@ -541,12 +575,6 @@ class FluxAPI extends Tool {
|
|||
|
||||
logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath);
|
||||
|
||||
// Calculate cost based on endpoint
|
||||
const endpointKey = endpoint.includes('ultra')
|
||||
? 'FLUX_PRO_1_1_ULTRA_FINETUNED'
|
||||
: 'FLUX_PRO_FINETUNED';
|
||||
const cost = FluxAPI.PRICING[endpointKey] || 0;
|
||||
// Return the result based on returnMetadata flag
|
||||
this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath);
|
||||
return this.returnValue(this.result);
|
||||
} catch (error) {
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ class StableDiffusionAPI extends Tool {
|
|||
this.returnMetadata = fields.returnMetadata ?? false;
|
||||
/** @type {boolean} */
|
||||
this.isAgent = fields.isAgent;
|
||||
if (this.isAgent) {
|
||||
/** Ensures LangChain maps [content, artifact] tuple to ToolMessage fields instead of serializing it into content. */
|
||||
this.responseFormat = 'content_and_artifact';
|
||||
}
|
||||
if (fields.uploadImageBuffer) {
|
||||
/** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */
|
||||
this.uploadImageBuffer = fields.uploadImageBuffer.bind(this);
|
||||
|
|
@ -115,7 +119,7 @@ class StableDiffusionAPI extends Tool {
|
|||
generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||
} catch (error) {
|
||||
logger.error('[StableDiffusion] Error while generating image:', error);
|
||||
return 'Error making API request.';
|
||||
return this.returnValue('Error making API request.');
|
||||
}
|
||||
const image = generationResponse.data.images[0];
|
||||
|
||||
|
|
|
|||
294
api/app/clients/tools/structured/specs/imageTools-agent.spec.js
Normal file
294
api/app/clients/tools/structured/specs/imageTools-agent.spec.js
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
/**
|
||||
* Regression tests for image tool agent mode — verifies that invoke() returns
|
||||
* a ToolMessage with base64 in artifact.content rather than serialized into content.
|
||||
*
|
||||
* Root cause: DALLE3/FluxAPI/StableDiffusion extend LangChain's Tool but did not
|
||||
* set responseFormat = 'content_and_artifact'. LangChain's invoke() would then
|
||||
* JSON.stringify the entire [content, artifact] tuple into ToolMessage.content,
|
||||
* dumping base64 into token counting and causing context exhaustion.
|
||||
*/
|
||||
|
||||
const axios = require('axios');
|
||||
const OpenAI = require('openai');
|
||||
const undici = require('undici');
|
||||
const fetch = require('node-fetch');
|
||||
const { ToolMessage } = require('@langchain/core/messages');
|
||||
const { ContentTypes } = require('librechat-data-provider');
|
||||
const StableDiffusionAPI = require('../StableDiffusion');
|
||||
const FluxAPI = require('../FluxAPI');
|
||||
const DALLE3 = require('../DALLE3');
|
||||
|
||||
jest.mock('axios');
|
||||
jest.mock('openai');
|
||||
jest.mock('node-fetch');
|
||||
jest.mock('undici', () => ({
|
||||
ProxyAgent: jest.fn(),
|
||||
fetch: jest.fn(),
|
||||
}));
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() },
|
||||
}));
|
||||
jest.mock('path', () => ({
|
||||
resolve: jest.fn(),
|
||||
join: jest.fn().mockReturnValue('/mock/path'),
|
||||
relative: jest.fn().mockReturnValue('relative/path'),
|
||||
extname: jest.fn().mockReturnValue('.png'),
|
||||
}));
|
||||
jest.mock('fs', () => ({
|
||||
existsSync: jest.fn().mockReturnValue(true),
|
||||
mkdirSync: jest.fn(),
|
||||
promises: { writeFile: jest.fn(), readFile: jest.fn(), unlink: jest.fn() },
|
||||
}));
|
||||
|
||||
const FAKE_BASE64 = 'aGVsbG8=';
|
||||
|
||||
const makeToolCall = (name, args) => ({
|
||||
id: 'call_test_123',
|
||||
name,
|
||||
args,
|
||||
type: 'tool_call',
|
||||
});
|
||||
|
||||
describe('image tools - agent mode ToolMessage format', () => {
|
||||
const ENV_KEYS = ['DALLE_API_KEY', 'FLUX_API_KEY', 'SD_WEBUI_URL', 'PROXY'];
|
||||
let savedEnv = {};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
for (const key of ENV_KEYS) {
|
||||
savedEnv[key] = process.env[key];
|
||||
}
|
||||
process.env.DALLE_API_KEY = 'test-dalle-key';
|
||||
process.env.FLUX_API_KEY = 'test-flux-key';
|
||||
process.env.SD_WEBUI_URL = 'http://localhost:7860';
|
||||
delete process.env.PROXY;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
for (const key of ENV_KEYS) {
|
||||
if (savedEnv[key] === undefined) {
|
||||
delete process.env[key];
|
||||
} else {
|
||||
process.env[key] = savedEnv[key];
|
||||
}
|
||||
}
|
||||
savedEnv = {};
|
||||
});
|
||||
|
||||
describe('DALLE3', () => {
|
||||
beforeEach(() => {
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate: jest.fn().mockResolvedValue({
|
||||
data: [{ url: 'https://example.com/image.png' }],
|
||||
}),
|
||||
},
|
||||
}));
|
||||
undici.fetch.mockResolvedValue({
|
||||
arrayBuffer: () => Promise.resolve(Buffer.from(FAKE_BASE64, 'base64')),
|
||||
});
|
||||
});
|
||||
|
||||
it('sets responseFormat to content_and_artifact when isAgent is true', () => {
|
||||
const dalle = new DALLE3({ isAgent: true });
|
||||
expect(dalle.responseFormat).toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('does not set responseFormat when isAgent is false', () => {
|
||||
const dalle = new DALLE3({ isAgent: false, processFileURL: jest.fn() });
|
||||
expect(dalle.responseFormat).not.toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => {
|
||||
const dalle = new DALLE3({ isAgent: true });
|
||||
const result = await dalle.invoke(
|
||||
makeToolCall('dalle', {
|
||||
prompt: 'a box',
|
||||
quality: 'standard',
|
||||
size: '1024x1024',
|
||||
style: 'vivid',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with error string in content when API fails', async () => {
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: { generate: jest.fn().mockRejectedValue(new Error('API error')) },
|
||||
}));
|
||||
|
||||
const dalle = new DALLE3({ isAgent: true });
|
||||
const result = await dalle.invoke(
|
||||
makeToolCall('dalle', {
|
||||
prompt: 'a box',
|
||||
quality: 'standard',
|
||||
size: '1024x1024',
|
||||
style: 'vivid',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).toContain('Something went wrong');
|
||||
expect(result.artifact).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('FluxAPI', () => {
|
||||
beforeEach(() => {
|
||||
jest.useFakeTimers();
|
||||
axios.post.mockResolvedValue({ data: { id: 'task-123' } });
|
||||
axios.get.mockResolvedValue({
|
||||
data: { status: 'Ready', result: { sample: 'https://example.com/image.png' } },
|
||||
});
|
||||
fetch.mockResolvedValue({
|
||||
arrayBuffer: () => Promise.resolve(Buffer.from(FAKE_BASE64, 'base64')),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('sets responseFormat to content_and_artifact when isAgent is true', () => {
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
expect(flux.responseFormat).toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('does not set responseFormat when isAgent is false', () => {
|
||||
const flux = new FluxAPI({ isAgent: false, processFileURL: jest.fn() });
|
||||
expect(flux.responseFormat).not.toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => {
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
const invokePromise = flux.invoke(
|
||||
makeToolCall('flux', { prompt: 'a box', endpoint: '/v1/flux-dev' }),
|
||||
);
|
||||
await jest.runAllTimersAsync();
|
||||
const result = await invokePromise;
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact for generate_finetuned action', async () => {
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
const invokePromise = flux.invoke(
|
||||
makeToolCall('flux', {
|
||||
action: 'generate_finetuned',
|
||||
prompt: 'a box',
|
||||
finetune_id: 'ft-abc123',
|
||||
endpoint: '/v1/flux-pro-finetuned',
|
||||
}),
|
||||
);
|
||||
await jest.runAllTimersAsync();
|
||||
const result = await invokePromise;
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with error string in content when task submission fails', async () => {
|
||||
axios.post.mockRejectedValue(new Error('Network error'));
|
||||
|
||||
const flux = new FluxAPI({ isAgent: true });
|
||||
const invokePromise = flux.invoke(
|
||||
makeToolCall('flux', { prompt: 'a box', endpoint: '/v1/flux-dev' }),
|
||||
);
|
||||
await jest.runAllTimersAsync();
|
||||
const result = await invokePromise;
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).toContain('Something went wrong');
|
||||
expect(result.artifact).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('StableDiffusion', () => {
|
||||
beforeEach(() => {
|
||||
axios.post.mockResolvedValue({
|
||||
data: {
|
||||
images: [FAKE_BASE64],
|
||||
info: JSON.stringify({ height: 1024, width: 1024, seed: 42, infotexts: [] }),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('sets responseFormat to content_and_artifact when isAgent is true', () => {
|
||||
const sd = new StableDiffusionAPI({ isAgent: true, override: true });
|
||||
expect(sd.responseFormat).toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('does not set responseFormat when isAgent is false', () => {
|
||||
const sd = new StableDiffusionAPI({
|
||||
isAgent: false,
|
||||
override: true,
|
||||
uploadImageBuffer: jest.fn(),
|
||||
});
|
||||
expect(sd.responseFormat).not.toBe('content_and_artifact');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with base64 in artifact, not serialized in content', async () => {
|
||||
const sd = new StableDiffusionAPI({ isAgent: true, override: true, userId: 'user-1' });
|
||||
const result = await sd.invoke(
|
||||
makeToolCall('stable-diffusion', { prompt: 'a box', negative_prompt: '' }),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).not.toContain(FAKE_BASE64);
|
||||
|
||||
expect(result.artifact).toBeDefined();
|
||||
const artifactContent = result.artifact?.content;
|
||||
expect(Array.isArray(artifactContent)).toBe(true);
|
||||
expect(artifactContent[0].type).toBe(ContentTypes.IMAGE_URL);
|
||||
expect(artifactContent[0].image_url.url).toContain('base64');
|
||||
});
|
||||
|
||||
it('invoke() returns ToolMessage with error string in content when API fails', async () => {
|
||||
axios.post.mockRejectedValue(new Error('Connection refused'));
|
||||
|
||||
const sd = new StableDiffusionAPI({ isAgent: true, override: true, userId: 'user-1' });
|
||||
const result = await sd.invoke(
|
||||
makeToolCall('stable-diffusion', { prompt: 'a box', negative_prompt: '' }),
|
||||
);
|
||||
|
||||
expect(result).toBeInstanceOf(ToolMessage);
|
||||
const contentStr =
|
||||
typeof result.content === 'string' ? result.content : JSON.stringify(result.content);
|
||||
expect(contentStr).toContain('Error making API request');
|
||||
});
|
||||
});
|
||||
});
|
||||
Loading…
Add table
Add a link
Reference in a new issue