diff --git a/.env.example b/.env.example index dfde0428d7..f6930b8564 100644 --- a/.env.example +++ b/.env.example @@ -248,6 +248,7 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT= # IMAGE_GEN_OAI_API_KEY= # Create or reuse OpenAI API key for image generation tool # IMAGE_GEN_OAI_BASEURL= # Custom OpenAI base URL for image generation tool # IMAGE_GEN_OAI_AZURE_API_VERSION= # Custom Azure OpenAI deployments +# IMAGE_GEN_OAI_MODEL=gpt-image-1 # OpenAI image model (e.g., gpt-image-1, gpt-image-1.5) # IMAGE_GEN_OAI_DESCRIPTION= # IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present # IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present diff --git a/api/app/clients/tools/structured/OpenAIImageTools.js b/api/app/clients/tools/structured/OpenAIImageTools.js index 3771167c51..e27a01786e 100644 --- a/api/app/clients/tools/structured/OpenAIImageTools.js +++ b/api/app/clients/tools/structured/OpenAIImageTools.js @@ -78,6 +78,8 @@ function createOpenAIImageTools(fields = {}) { let apiKey = fields.IMAGE_GEN_OAI_API_KEY ?? getApiKey(); const closureConfig = { apiKey }; + const imageModel = process.env.IMAGE_GEN_OAI_MODEL || 'gpt-image-1'; + let baseURL = 'https://api.openai.com/v1/'; if (!override && process.env.IMAGE_GEN_OAI_BASEURL) { baseURL = extractBaseURL(process.env.IMAGE_GEN_OAI_BASEURL); @@ -157,7 +159,7 @@ function createOpenAIImageTools(fields = {}) { resp = await openai.images.generate( { - model: 'gpt-image-1', + model: imageModel, prompt: replaceUnwantedChars(prompt), n: Math.min(Math.max(1, n), 10), background, @@ -239,7 +241,7 @@ Error Message: ${error.message}`); } const formData = new FormData(); - formData.append('model', 'gpt-image-1'); + formData.append('model', imageModel); formData.append('prompt', replaceUnwantedChars(prompt)); // TODO: `mask` support // TODO: more than 1 image support diff --git a/api/test/app/clients/tools/structured/OpenAIImageTools.test.js b/api/test/app/clients/tools/structured/OpenAIImageTools.test.js new file mode 100644 index 0000000000..aa0726b916 --- /dev/null +++ b/api/test/app/clients/tools/structured/OpenAIImageTools.test.js @@ -0,0 +1,162 @@ +const OpenAI = require('openai'); +const createOpenAIImageTools = require('~/app/clients/tools/structured/OpenAIImageTools'); + +jest.mock('openai'); +jest.mock('@librechat/data-schemas', () => ({ + logger: { + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + logAxiosError: jest.fn(), + oaiToolkit: { + image_gen_oai: { + name: 'image_gen_oai', + description: 'Generate an image', + schema: {}, + }, + image_edit_oai: { + name: 'image_edit_oai', + description: 'Edit an image', + schema: {}, + }, + }, + extractBaseURL: jest.fn((url) => url), +})); + +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +jest.mock('~/models', () => ({ + getFiles: jest.fn().mockResolvedValue([]), +})); + +describe('OpenAIImageTools - IMAGE_GEN_OAI_MODEL environment variable', () => { + let originalEnv; + + beforeEach(() => { + jest.clearAllMocks(); + originalEnv = { ...process.env }; + + process.env.IMAGE_GEN_OAI_API_KEY = 'test-api-key'; + + OpenAI.mockImplementation(() => ({ + images: { + generate: jest.fn().mockResolvedValue({ + data: [ + { + b64_json: 'base64-encoded-image-data', + }, + ], + }), + }, + })); + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it('should use default model "gpt-image-1" when IMAGE_GEN_OAI_MODEL is not set', async () => { + delete process.env.IMAGE_GEN_OAI_MODEL; + + const [imageGenTool] = createOpenAIImageTools({ + isAgent: true, + override: false, + req: { user: { id: 'test-user' } }, + }); + + const mockGenerate = jest.fn().mockResolvedValue({ + data: [ + { + b64_json: 'base64-encoded-image-data', + }, + ], + }); + + OpenAI.mockImplementation(() => ({ + images: { + generate: mockGenerate, + }, + })); + + await imageGenTool.func({ prompt: 'test prompt' }); + + expect(mockGenerate).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gpt-image-1', + }), + expect.any(Object), + ); + }); + + it('should use "gpt-image-1.5" when IMAGE_GEN_OAI_MODEL is set to "gpt-image-1.5"', async () => { + process.env.IMAGE_GEN_OAI_MODEL = 'gpt-image-1.5'; + + const mockGenerate = jest.fn().mockResolvedValue({ + data: [ + { + b64_json: 'base64-encoded-image-data', + }, + ], + }); + + OpenAI.mockImplementation(() => ({ + images: { + generate: mockGenerate, + }, + })); + + const [imageGenTool] = createOpenAIImageTools({ + isAgent: true, + override: false, + req: { user: { id: 'test-user' } }, + }); + + await imageGenTool.func({ prompt: 'test prompt' }); + + expect(mockGenerate).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gpt-image-1.5', + }), + expect.any(Object), + ); + }); + + it('should use custom model name from IMAGE_GEN_OAI_MODEL environment variable', async () => { + process.env.IMAGE_GEN_OAI_MODEL = 'custom-image-model'; + + const mockGenerate = jest.fn().mockResolvedValue({ + data: [ + { + b64_json: 'base64-encoded-image-data', + }, + ], + }); + + OpenAI.mockImplementation(() => ({ + images: { + generate: mockGenerate, + }, + })); + + const [imageGenTool] = createOpenAIImageTools({ + isAgent: true, + override: false, + req: { user: { id: 'test-user' } }, + }); + + await imageGenTool.func({ prompt: 'test prompt' }); + + expect(mockGenerate).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'custom-image-model', + }), + expect.any(Object), + ); + }); +});