🚀 Feat: Streamline File Strategies & GPT-4-Vision Settings (#1535)

* chore: fix `endpoint` typescript issues and typo in console info message

* feat(api): files GET endpoint and save only file_id references to messages

* refactor(client): `useGetFiles` query hook, update file types, optimistic update of filesQuery on file upload

* refactor(buildTree): update to use params object and accept fileMap

* feat: map files to messages; refactor(ChatView): messages only available after files are fetched

* fix: fetch files only when authenticated

* feat(api): AppService
- rename app.locals.configs to app.locals.paths
- load custom config use fileStrategy from yaml config in app.locals

* refactor: separate Firebase and Local strategies, call based on config

* refactor: modularize file strategies and employ with use of DALL-E

* refactor(librechat.yaml): add fileStrategy field

* feat: add source to MongoFile schema, as well as BatchFile, and ExtendedFile types

* feat: employ file strategies for upload/delete files

* refactor(deleteFirebaseFile): add user id validation for firebase file deletion

* chore(deleteFirebaseFile): update jsdocs

* feat: employ strategies for vision requests

* fix(client): handle messages with deleted files

* fix(client): ensure `filesToDelete` always saves/sends `file.source`

* feat(openAI): configurable `resendImages` and `imageDetail`

* refactor(getTokenCountForMessage): recursive process only when array of Objects and only their values (not keys) aside from `image_url` types

* feat(OpenAIClient): calculateImageTokenCost

* chore: remove comment

* refactor(uploadAvatar): employ fileStrategy for avatars, from social logins or user upload

* docs: update docs on how to configure fileStrategy

* fix(ci): mock winston and winston related modules, update DALLE3.spec.js with changes made

* refactor(redis): change terminal message to reflect current development state

* fix(DALL-E-2): pass fileStrategy to dall-e
This commit is contained in:
Danny Avila 2024-01-11 11:37:54 -05:00 committed by GitHub
parent 28a6807176
commit d20970f5c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
81 changed files with 1729 additions and 855 deletions

View file

@ -46,6 +46,10 @@ class BaseClient {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', response);
}
async addPreviousAttachments(messages) {
return messages;
}
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
promptTokens,
@ -484,20 +488,22 @@ class BaseClient {
mapMethod = this.getMessageMapMethod();
}
const orderedMessages = this.constructor.getMessagesForConversation({
let _messages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
mapMethod,
});
_messages = await this.addPreviousAttachments(_messages);
if (!this.shouldSummarize) {
return orderedMessages;
return _messages;
}
// Find the latest message with a 'summary' property
for (let i = orderedMessages.length - 1; i >= 0; i--) {
if (orderedMessages[i]?.summary) {
this.previous_summary = orderedMessages[i];
for (let i = _messages.length - 1; i >= 0; i--) {
if (_messages[i]?.summary) {
this.previous_summary = _messages[i];
break;
}
}
@ -512,7 +518,7 @@ class BaseClient {
});
}
return orderedMessages;
return _messages;
}
async saveMessageToDatabase(message, endpointOptions, user = null) {
@ -618,6 +624,11 @@ class BaseClient {
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
* In our implementation, this is accounted for in the getMessagesWithinTokenLimit method.
*
* The content parts example was adapted from the following example:
* https://github.com/openai/openai-cookbook/pull/881/files
*
* Note: image token calculation is to be done elsewhere where we have access to the image metadata
*
* @param {Object} message
*/
getTokenCountForMessage(message) {
@ -631,11 +642,18 @@ class BaseClient {
}
const processValue = (value) => {
if (typeof value === 'object' && value !== null) {
for (let [nestedKey, nestedValue] of Object.entries(value)) {
if (nestedKey === 'image_url' || nestedValue === 'image_url') {
if (Array.isArray(value)) {
for (let item of value) {
if (!item || !item.type || item.type === 'image_url') {
continue;
}
const nestedValue = item[item.type];
if (!nestedValue) {
continue;
}
processValue(nestedValue);
}
} else {

View file

@ -1,6 +1,6 @@
const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { getResponseSender } = require('librechat-data-provider');
const { getResponseSender, ImageDetailCost, ImageDetail } = require('librechat-data-provider');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
@ -8,8 +8,9 @@ const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
const { handleOpenAIErrors } = require('./tools/util');
const spendTokens = require('~/models/spendTokens');
const { createLLM, RunManager } = require('./llm');
const { isEnabled } = require('~/server/utils');
const ChatGPTClient = require('./ChatGPTClient');
const { isEnabled } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
@ -76,16 +77,7 @@ class OpenAIClient extends BaseClient {
};
}
this.isVisionModel = validateVisionModel(this.modelOptions.model);
if (this.options.attachments && !this.isVisionModel) {
this.modelOptions.model = 'gpt-4-vision-preview';
this.isVisionModel = true;
}
if (this.isVisionModel) {
delete this.modelOptions.stop;
}
this.checkVisionRequest(this.options.attachments);
const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
if (OPENROUTER_API_KEY && !this.azure) {
@ -204,6 +196,27 @@ class OpenAIClient extends BaseClient {
return this;
}
/**
*
* Checks if the model is a vision model based on request attachments and sets the appropriate options:
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
* - Sets `this.isVisionModel` to `true` if vision request.
* - Deletes `this.modelOptions.stop` if vision request.
* @param {Array<Promise<MongoFile[]> | MongoFile[]> | Record<string, MongoFile[]>} attachments
*/
checkVisionRequest(attachments) {
this.isVisionModel = validateVisionModel(this.modelOptions.model);
if (attachments && !this.isVisionModel) {
this.modelOptions.model = 'gpt-4-vision-preview';
this.isVisionModel = true;
}
if (this.isVisionModel) {
delete this.modelOptions.stop;
}
}
setupTokens() {
if (this.isChatCompletion) {
this.startToken = '||>';
@ -288,7 +301,11 @@ class OpenAIClient extends BaseClient {
tokenizerCallsCount++;
}
// Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
/**
* Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
* @param {string} text - The text to get the token count for.
* @returns {number} The token count of the given text.
*/
getTokenCount(text) {
this.resetTokenizersIfNecessary();
try {
@ -301,10 +318,33 @@ class OpenAIClient extends BaseClient {
}
}
/**
* Calculate the token cost for an image based on its dimensions and detail level.
*
* @param {Object} image - The image object.
* @param {number} image.width - The width of the image.
* @param {number} image.height - The height of the image.
* @param {'low'|'high'|string|undefined} [image.detail] - The detail level ('low', 'high', or other).
* @returns {number} The calculated token cost.
*/
calculateImageTokenCost({ width, height, detail }) {
if (detail === 'low') {
return ImageDetailCost.LOW;
}
// Calculate the number of 512px squares
const numSquares = Math.ceil(width / 512) * Math.ceil(height / 512);
// Default to high detail cost calculation
return numSquares * ImageDetailCost.HIGH + ImageDetailCost.ADDITIONAL;
}
getSaveOptions() {
return {
chatGptLabel: this.options.chatGptLabel,
promptPrefix: this.options.promptPrefix,
resendImages: this.options.resendImages,
imageDetail: this.options.imageDetail,
...this.modelOptions,
};
}
@ -317,6 +357,69 @@ class OpenAIClient extends BaseClient {
};
}
/**
*
* @param {TMessage[]} _messages
* @returns {TMessage[]}
*/
async addPreviousAttachments(_messages) {
if (!this.options.resendImages) {
return _messages;
}
/**
*
* @param {TMessage} message
*/
const processMessage = async (message) => {
if (!this.message_file_map) {
/** @type {Record<string, MongoFile[]> */
this.message_file_map = {};
}
const fileIds = message.files.map((file) => file.file_id);
const files = await getFiles({
file_id: { $in: fileIds },
});
await this.addImageURLs(message, files);
this.message_file_map[message.messageId] = files;
return message;
};
const promises = [];
for (const message of _messages) {
if (!message.files) {
promises.push(message);
continue;
}
promises.push(processMessage(message));
}
const messages = await Promise.all(promises);
this.checkVisionRequest(this.message_file_map);
return messages;
}
/**
*
* Adds image URLs to the message object and returns the files
*
* @param {TMessage[]} messages
* @param {MongoFile[]} files
* @returns {Promise<MongoFile[]>}
*/
async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(this.options.req, attachments);
message.image_urls = image_urls;
return files;
}
async buildMessages(
messages,
parentMessageId,
@ -355,13 +458,23 @@ class OpenAIClient extends BaseClient {
}
if (this.options.attachments) {
const attachments = await this.options.attachments;
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments.filter((file) => file.type.includes('image')),
const attachments = (await this.options.attachments).filter((file) =>
file.type.includes('image'),
);
if (this.message_file_map) {
this.message_file_map[orderedMessages[orderedMessages.length - 1].messageId] = attachments;
} else {
this.message_file_map = {
[orderedMessages[orderedMessages.length - 1].messageId]: attachments,
};
}
const files = await this.addImageURLs(
orderedMessages[orderedMessages.length - 1],
attachments,
);
orderedMessages[orderedMessages.length - 1].image_urls = image_urls;
this.options.attachments = files;
}
@ -372,10 +485,25 @@ class OpenAIClient extends BaseClient {
assistantName: this.options?.chatGptLabel,
});
if (this.contextStrategy && !orderedMessages[i].tokenCount) {
const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount;
/* If tokens were never counted, or, is a Vision request and the message has files, count again */
if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) {
orderedMessages[i].tokenCount = this.getTokenCountForMessage(formattedMessage);
}
/* If message has files, calculate image token cost */
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
orderedMessages[i].tokenCount += this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
return formattedMessage;
});
@ -780,7 +908,6 @@ ${convo}
if (this.isChatCompletion) {
modelOptions.messages = payload;
} else {
// TODO: unreachable code. Need to implement completions call for non-chat models
modelOptions.prompt = payload;
}
@ -916,6 +1043,8 @@ ${convo}
clientOptions.addMetadata({ finish_reason });
}
logger.debug('[OpenAIClient] chatCompletion response', chatCompletion);
return message.content;
} catch (err) {
if (

View file

@ -112,7 +112,7 @@ class PluginsClient extends OpenAIClient {
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
debug: this.options?.debug,
fileStrategy: this.options.req.app.locals.fileStrategy,
message,
},
});

View file

@ -546,6 +546,39 @@ describe('OpenAIClient', () => {
expect(totalTokens).toBe(testCase.expected);
});
});
const vision_request = [
{
role: 'user',
content: [
{
type: 'text',
text: 'describe what is in this image?',
},
{
type: 'image_url',
image_url: {
url: 'https://venturebeat.com/wp-content/uploads/2019/03/openai-1.png',
detail: 'high',
},
},
],
},
];
const expectedTokens = 14;
const visionModel = 'gpt-4-vision-preview';
it(`should return ${expectedTokens} tokens for model ${visionModel} (Vision Request)`, () => {
client.modelOptions.model = visionModel;
client.selectTokenizer();
// 3 tokens for assistant label
let totalTokens = 3;
for (let message of vision_request) {
totalTokens += client.getTokenCountForMessage(message);
}
expect(totalTokens).toBe(expectedTokens);
});
});
describe('sendMessage/getCompletion/chatCompletion', () => {

View file

@ -1,20 +1,13 @@
// From https://platform.openai.com/docs/api-reference/images/create
// To use this tool, you must pass in a configured OpenAIApi object.
const fs = require('fs');
const path = require('path');
const OpenAI = require('openai');
// const { genAzureEndpoint } = require('~/utils/genAzureEndpoints');
const { v4: uuidv4 } = require('uuid');
const { Tool } = require('langchain/tools');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
saveImageToFirebaseStorage,
getFirebaseStorageImageUrl,
getFirebaseStorage,
} = require('~/server/services/Files/Firebase');
const { getImageBasename } = require('~/server/services/Files/images');
const { processFileURL } = require('~/server/services/Files/process');
const extractBaseURL = require('~/utils/extractBaseURL');
const saveImageFromUrl = require('./saveImageFromUrl');
const { logger } = require('~/config');
const { DALLE_REVERSE_PROXY, PROXY } = process.env;
@ -23,6 +16,7 @@ class OpenAICreateImage extends Tool {
super();
this.userId = fields.userId;
this.fileStrategy = fields.fileStrategy;
let apiKey = fields.DALLE_API_KEY || this.getApiKey();
const config = { apiKey };
@ -82,12 +76,8 @@ Guidelines:
.trim();
}
getMarkdownImageUrl(imageName) {
const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
return `![generated image](/${imageUrl})`;
wrapInMarkdown(imageUrl) {
return `![generated image](${imageUrl})`;
}
async _call(input) {
@ -118,45 +108,21 @@ Guidelines:
});
}
this.outputPath = path.resolve(
__dirname,
'..',
'..',
'..',
'..',
'client',
'public',
'images',
this.userId,
);
try {
const result = await processFileURL({
fileStrategy: this.fileStrategy,
userId: this.userId,
URL: theImageUrl,
fileName: imageName,
basePath: 'images',
});
const appRoot = path.resolve(__dirname, '..', '..', '..', '..', 'client');
this.relativeImageUrl = path.relative(appRoot, this.outputPath);
// Check if directory exists, if not create it
if (!fs.existsSync(this.outputPath)) {
fs.mkdirSync(this.outputPath, { recursive: true });
this.result = this.wrapInMarkdown(result);
} catch (error) {
logger.error('Error while saving the image:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}
const storage = getFirebaseStorage();
if (storage) {
try {
await saveImageToFirebaseStorage(this.userId, theImageUrl, imageName);
this.result = await getFirebaseStorageImageUrl(`${this.userId}/${imageName}`);
logger.debug('[DALL-E] result: ' + this.result);
} catch (error) {
logger.error('Error while saving the image to Firebase Storage:', error);
this.result = `Failed to save the image to Firebase Storage. ${error.message}`;
}
} else {
try {
await saveImageFromUrl(theImageUrl, this.outputPath, imageName);
this.result = this.getMarkdownImageUrl(imageName);
} catch (error) {
logger.error('Error while saving the image locally:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}
}
return this.result;
}
}

View file

@ -1,46 +0,0 @@
const fs = require('fs');
const path = require('path');
const axios = require('axios');
const { logger } = require('~/config');
async function saveImageFromUrl(url, outputPath, outputFilename) {
try {
// Fetch the image from the URL
const response = await axios({
url,
responseType: 'stream',
});
// Get the content type from the response headers
const contentType = response.headers['content-type'];
let extension = contentType.split('/').pop();
// Check if the output directory exists, if not, create it
if (!fs.existsSync(outputPath)) {
fs.mkdirSync(outputPath, { recursive: true });
}
// Replace or append the correct extension
const extRegExp = new RegExp(path.extname(outputFilename) + '$');
outputFilename = outputFilename.replace(extRegExp, `.${extension}`);
if (!path.extname(outputFilename)) {
outputFilename += `.${extension}`;
}
// Create a writable stream for the output path
const outputFilePath = path.join(outputPath, outputFilename);
const writer = fs.createWriteStream(outputFilePath);
// Pipe the response data to the output file
response.data.pipe(writer);
return new Promise((resolve, reject) => {
writer.on('finish', resolve);
writer.on('error', reject);
});
} catch (error) {
logger.error('[saveImageFromUrl] Error while saving the image:', error);
}
}
module.exports = saveImageFromUrl;

View file

@ -1,20 +1,13 @@
// From https://platform.openai.com/docs/guides/images/usage?context=node
// To use this tool, you must pass in a configured OpenAIApi object.
const fs = require('fs');
const path = require('path');
const { z } = require('zod');
const OpenAI = require('openai');
const { v4: uuidv4 } = require('uuid');
const { Tool } = require('langchain/tools');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
saveImageToFirebaseStorage,
getFirebaseStorageImageUrl,
getFirebaseStorage,
} = require('~/server/services/Files/Firebase');
const { getImageBasename } = require('~/server/services/Files/images');
const { processFileURL } = require('~/server/services/Files/process');
const extractBaseURL = require('~/utils/extractBaseURL');
const saveImageFromUrl = require('../saveImageFromUrl');
const { logger } = require('~/config');
const { DALLE3_SYSTEM_PROMPT, DALLE_REVERSE_PROXY, PROXY } = process.env;
@ -23,6 +16,7 @@ class DALLE3 extends Tool {
super();
this.userId = fields.userId;
this.fileStrategy = fields.fileStrategy;
let apiKey = fields.DALLE_API_KEY || this.getApiKey();
const config = { apiKey };
if (DALLE_REVERSE_PROXY) {
@ -91,12 +85,8 @@ class DALLE3 extends Tool {
.trim();
}
getMarkdownImageUrl(imageName) {
const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
return `![generated image](/${imageUrl})`;
wrapInMarkdown(imageUrl) {
return `![generated image](${imageUrl})`;
}
async _call(data) {
@ -143,43 +133,19 @@ Error Message: ${error.message}`;
});
}
this.outputPath = path.resolve(
__dirname,
'..',
'..',
'..',
'..',
'..',
'client',
'public',
'images',
this.userId,
);
const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client');
this.relativeImageUrl = path.relative(appRoot, this.outputPath);
try {
const result = await processFileURL({
fileStrategy: this.fileStrategy,
userId: this.userId,
URL: theImageUrl,
fileName: imageName,
basePath: 'images',
});
// Check if directory exists, if not create it
if (!fs.existsSync(this.outputPath)) {
fs.mkdirSync(this.outputPath, { recursive: true });
}
const storage = getFirebaseStorage();
if (storage) {
try {
await saveImageToFirebaseStorage(this.userId, theImageUrl, imageName);
this.result = await getFirebaseStorageImageUrl(`${this.userId}/${imageName}`);
logger.debug('[DALL-E-3] result: ' + this.result);
} catch (error) {
logger.error('Error while saving the image to Firebase Storage:', error);
this.result = `Failed to save the image to Firebase Storage. ${error.message}`;
}
} else {
try {
await saveImageFromUrl(theImageUrl, this.outputPath, imageName);
this.result = this.getMarkdownImageUrl(imageName);
} catch (error) {
logger.error('Error while saving the image locally:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}
this.result = this.wrapInMarkdown(result);
} catch (error) {
logger.error('Error while saving the image:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}
return this.result;

View file

@ -1,20 +1,13 @@
const fs = require('fs');
const path = require('path');
const OpenAI = require('openai');
const DALLE3 = require('../DALLE3');
const {
getFirebaseStorage,
saveImageToFirebaseStorage,
} = require('~/server/services/Files/Firebase');
const saveImageFromUrl = require('../../saveImageFromUrl');
const { processFileURL } = require('~/server/services/Files/process');
const { logger } = require('~/config');
jest.mock('openai');
jest.mock('~/server/services/Files/Firebase', () => ({
getFirebaseStorage: jest.fn(),
saveImageToFirebaseStorage: jest.fn(),
getFirebaseStorageImageUrl: jest.fn(),
jest.mock('~/server/services/Files/process', () => ({
processFileURL: jest.fn(),
}));
jest.mock('~/server/services/Files/images', () => ({
@ -50,10 +43,6 @@ jest.mock('fs', () => {
};
});
jest.mock('../../saveImageFromUrl', () => {
return jest.fn();
});
jest.mock('path', () => {
return {
resolve: jest.fn(),
@ -99,10 +88,8 @@ describe('DALLE3', () => {
it('should generate markdown image URL correctly', () => {
const imageName = 'test.png';
path.join.mockReturnValue('images/test.png');
path.relative.mockReturnValue('images/test.png');
const markdownImage = dalle.getMarkdownImageUrl(imageName);
expect(markdownImage).toBe('![generated image](/images/test.png)');
const markdownImage = dalle.wrapInMarkdown(imageName);
expect(markdownImage).toBe('![generated image](test.png)');
});
it('should call OpenAI API with correct parameters', async () => {
@ -122,11 +109,7 @@ describe('DALLE3', () => {
};
generate.mockResolvedValue(mockResponse);
saveImageFromUrl.mockResolvedValue(true);
fs.existsSync.mockReturnValue(true);
path.resolve.mockReturnValue('/fakepath/images');
path.join.mockReturnValue('/fakepath/images/img-test.png');
path.relative.mockReturnValue('images/img-test.png');
processFileURL.mockResolvedValue('http://example.com/img-test.png');
const result = await dalle._call(mockData);
@ -138,6 +121,7 @@ describe('DALLE3', () => {
prompt: mockData.prompt,
n: 1,
});
expect(result).toContain('![generated image]');
});
@ -184,23 +168,6 @@ describe('DALLE3', () => {
});
});
it('should create the directory if it does not exist', async () => {
const mockData = {
prompt: 'A test prompt',
};
const mockResponse = {
data: [
{
url: 'http://example.com/img-test.png',
},
],
};
generate.mockResolvedValue(mockResponse);
fs.existsSync.mockReturnValue(false); // Simulate directory does not exist
await dalle._call(mockData);
expect(fs.mkdirSync).toHaveBeenCalledWith(expect.any(String), { recursive: true });
});
it('should log an error and return the image URL if there is an error saving the image', async () => {
const mockData = {
prompt: 'A test prompt',
@ -214,31 +181,12 @@ describe('DALLE3', () => {
};
const error = new Error('Error while saving the image');
generate.mockResolvedValue(mockResponse);
saveImageFromUrl.mockRejectedValue(error);
processFileURL.mockRejectedValue(error);
const result = await dalle._call(mockData);
expect(logger.error).toHaveBeenCalledWith('Error while saving the image locally:', error);
expect(logger.error).toHaveBeenCalledWith('Error while saving the image:', error);
expect(result).toBe('Failed to save the image locally. Error while saving the image');
});
it('should save image to Firebase Storage if Firebase is initialized', async () => {
const mockData = {
prompt: 'A test prompt',
};
const mockImageUrl = 'http://example.com/img-test.png';
const mockResponse = { data: [{ url: mockImageUrl }] };
generate.mockResolvedValue(mockResponse);
getFirebaseStorage.mockReturnValue({}); // Simulate Firebase being initialized
await dalle._call(mockData);
expect(getFirebaseStorage).toHaveBeenCalled();
expect(saveImageToFirebaseStorage).toHaveBeenCalledWith(
undefined,
mockImageUrl,
expect.any(String),
);
});
it('should handle error when saving image to Firebase Storage fails', async () => {
const mockData = {
prompt: 'A test prompt',
@ -247,17 +195,11 @@ describe('DALLE3', () => {
const mockResponse = { data: [{ url: mockImageUrl }] };
const error = new Error('Error while saving to Firebase');
generate.mockResolvedValue(mockResponse);
getFirebaseStorage.mockReturnValue({}); // Simulate Firebase being initialized
saveImageToFirebaseStorage.mockRejectedValue(error);
processFileURL.mockRejectedValue(error);
const result = await dalle._call(mockData);
expect(logger.error).toHaveBeenCalledWith(
'Error while saving the image to Firebase Storage:',
error,
);
expect(result).toBe(
'Failed to save the image to Firebase Storage. Error while saving to Firebase',
);
expect(logger.error).toHaveBeenCalledWith('Error while saving the image:', error);
expect(result).toContain('Failed to save the image');
});
});

View file

@ -170,6 +170,8 @@ const loadTools = async ({
const toolOptions = {
serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' },
dalle: { fileStrategy: options.fileStrategy },
'dall-e': { fileStrategy: options.fileStrategy },
};
const toolAuthFields = {};