🤖 feat: Gemini 1.5 Support (+Vertex AI) (#2383)

* WIP: gemini-1.5 support

* feat: extended vertex ai support

* fix: handle possibly undefined modelName

* fix: gpt-4-turbo-preview invalid vision model

* feat: specify `fileConfig.imageOutputType` and make PNG default image conversion type

* feat: better truncation for errors including base64 strings

* fix: gemini inlineData formatting

* feat: RAG augmented prompt for gemini-1.5

* feat: gemini-1.5 rates and token window

* chore: adjust tokens, update docs, update vision Models

* chore: add back `ChatGoogleVertexAI` for chat models via vertex ai

* refactor: ask/edit controllers to not use `unfinished` field for google endpoint

* chore: remove comment

* chore(ci): fix AppService test

* chore: remove comment

* refactor(GoogleSearch): use `GOOGLE_SEARCH_API_KEY` instead, issue warning for old variable

* chore: bump data-provider to 0.5.4

* chore: update docs

* fix: condition for gemini-1.5 using generative ai lib

* chore: update docs

* ci: add additional AppService test for `imageOutputType`

* refactor: optimize new config value `imageOutputType`

* chore: bump CONFIG_VERSION

* fix(assistants): avatar upload
This commit is contained in:
Danny Avila 2024-04-16 08:32:40 -04:00 committed by GitHub
parent fce7246ac1
commit 9d854dac07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 1030 additions and 258 deletions

View file

@ -1,7 +1,9 @@
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { GoogleVertexAI } = require('@langchain/community/llms/googlevertexai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
@ -10,6 +12,7 @@ const {
getResponseSender,
endpointSettings,
EModelEndpoint,
VisionModes,
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
@ -126,7 +129,7 @@ class GoogleClient extends BaseClient {
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
/** @type {boolean} Whether using a "GenerativeAI" Model */
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
const { isGenerativeModel } = this;
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
@ -247,6 +250,40 @@ class GoogleClient extends BaseClient {
})).bind(this);
}
/**
* Formats messages for generative AI
* @param {TMessage[]} messages
* @returns
*/
async formatGenerativeMessages(messages) {
const formattedMessages = [];
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
const files = await this.addImageURLs(latestMessage, attachments, VisionModes.generative);
this.options.attachments = files;
messages[messages.length - 1] = latestMessage;
for (const _message of messages) {
const role = _message.isCreatedByUser ? this.userLabel : this.modelLabel;
const parts = [];
parts.push({ text: _message.text });
if (!_message.image_urls?.length) {
formattedMessages.push({ role, parts });
continue;
}
for (const images of _message.image_urls) {
if (images.inlineData) {
parts.push({ inlineData: images.inlineData });
}
}
formattedMessages.push({ role, parts });
}
return formattedMessages;
}
/**
*
* Adds image URLs to the message object and returns the files
@ -255,17 +292,23 @@ class GoogleClient extends BaseClient {
* @param {MongoFile[]} files
* @returns {Promise<MongoFile[]>}
*/
async addImageURLs(message, attachments) {
async addImageURLs(message, attachments, mode = '') {
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments,
EModelEndpoint.google,
mode,
);
message.image_urls = image_urls.length ? image_urls : undefined;
return files;
}
async buildVisionMessages(messages = [], parentMessageId) {
/**
* Builds the augmented prompt for attachments
* TODO: Add File API Support
* @param {TMessage[]} messages
*/
async buildAugmentedPrompt(messages = []) {
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text);
@ -281,6 +324,12 @@ class GoogleClient extends BaseClient {
this.augmentedPrompt = await this.contextHandlers.createContext();
this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix;
}
}
async buildVisionMessages(messages = [], parentMessageId) {
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
await this.buildAugmentedPrompt(messages);
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
@ -301,15 +350,26 @@ class GoogleClient extends BaseClient {
return { prompt: payload };
}
/** @param {TMessage[]} [messages=[]] */
async buildGenerativeMessages(messages = []) {
this.userLabel = 'user';
this.modelLabel = 'model';
const promises = [];
promises.push(await this.formatGenerativeMessages(messages));
promises.push(this.buildAugmentedPrompt(messages));
const [formattedMessages] = await Promise.all(promises);
return { prompt: formattedMessages };
}
async buildMessages(messages = [], parentMessageId) {
if (!this.isGenerativeModel && !this.project_id) {
throw new Error(
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
);
} else if (this.isGenerativeModel && (!this.apiKey || this.apiKey === 'user_provided')) {
throw new Error(
'[GoogleClient] an API Key is required for Gemini models (Generative Language API)',
);
}
if (!this.project_id && this.modelOptions.model.includes('1.5')) {
return await this.buildGenerativeMessages(messages);
}
if (this.options.attachments && this.isGenerativeModel) {
@ -526,13 +586,24 @@ class GoogleClient extends BaseClient {
}
createLLM(clientOptions) {
if (this.isGenerativeModel) {
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
const model = clientOptions.modelName ?? clientOptions.model;
if (this.project_id && this.isTextModel) {
return new GoogleVertexAI(clientOptions);
} else if (this.project_id && this.isChatModel) {
return new ChatGoogleVertexAI(clientOptions);
} else if (this.project_id) {
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
model,
},
{ apiVersion: 'v1beta' },
);
}
return this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}
async getCompletion(_payload, options = {}) {
@ -544,7 +615,7 @@ class GoogleClient extends BaseClient {
let clientOptions = { ...parameters, maxRetries: 2 };
if (!this.isGenerativeModel) {
if (this.project_id) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
@ -557,7 +628,7 @@ class GoogleClient extends BaseClient {
clientOptions = { ...clientOptions, ...this.modelOptions };
}
if (this.isGenerativeModel) {
if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}
@ -588,16 +659,46 @@ class GoogleClient extends BaseClient {
messages.unshift(new SystemMessage(context));
}
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
};
if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = {
parts: [
{
text: this.options.promptPrefix,
},
],
};
}
const result = await client.generateContentStream(requestOptions);
for await (const chunk of result.stream) {
const chunkText = chunk.text();
this.generateTextStream(chunkText, onProgress, {
delay: 12,
});
reply += chunkText;
}
return reply;
}
const stream = await model.stream(messages, {
signal: abortController.signal,
timeout: 7000,
});
for await (const chunk of stream) {
await this.generateTextStream(chunk?.content ?? chunk, onProgress, {
const chunkText = chunk?.content ?? chunk;
this.generateTextStream(chunkText, onProgress, {
delay: this.isGenerativeModel ? 12 : 8,
});
reply += chunk?.content ?? chunk;
reply += chunkText;
}
return reply;

View file

@ -13,7 +13,7 @@ module.exports = {
...handleInputs,
...instructions,
...titlePrompts,
truncateText,
...truncateText,
createVisionPrompt,
createContextHandlers,
};

View file

@ -1,10 +1,40 @@
const MAX_CHAR = 255;
function truncateText(text) {
if (text.length > MAX_CHAR) {
return `${text.slice(0, MAX_CHAR)}... [text truncated for brevity]`;
/**
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
* if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
*/
function truncateText(text, maxLength = MAX_CHAR) {
if (text.length > maxLength) {
return `${text.slice(0, maxLength)}... [text truncated for brevity]`;
}
return text;
}
module.exports = truncateText;
/**
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
* of ellipsis and notification if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
*/
function smartTruncateText(text, maxLength = MAX_CHAR) {
const ellipsis = '...';
const notification = ' [text truncated for brevity]';
const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2);
if (text.length > maxLength) {
const startLastHalf = text.length - halfMaxLength;
return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
}
return text;
}
module.exports = { truncateText, smartTruncateText };

View file

@ -24,7 +24,7 @@
"description": "This is your Google Custom Search Engine ID. For instructions on how to obtain this, see <a href='https://github.com/danny-avila/LibreChat/blob/main/docs/features/plugins/google_search.md'>Our Docs</a>."
},
{
"authField": "GOOGLE_API_KEY",
"authField": "GOOGLE_SEARCH_API_KEY",
"label": "Google API Key",
"description": "This is your Google Custom Search API Key. For instructions on how to obtain this, see <a href='https://github.com/danny-avila/LibreChat/blob/main/docs/features/plugins/google_search.md'>Our Docs</a>."
}

View file

@ -9,7 +9,7 @@ class GoogleSearchResults extends Tool {
constructor(fields = {}) {
super(fields);
this.envVarApiKey = 'GOOGLE_API_KEY';
this.envVarApiKey = 'GOOGLE_SEARCH_API_KEY';
this.envVarSearchEngineId = 'GOOGLE_CSE_ID';
this.override = fields.override ?? false;
this.apiKey = fields.apiKey ?? getEnvironmentVariable(this.envVarApiKey);

View file

@ -25,6 +25,10 @@ const tokenValues = {
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
// 'gemini-1.5': { prompt: 7, completion: 21 }, // May 2nd, 2024 pricing
// 'gemini': { prompt: 0.5, completion: 1.5 }, // May 2nd, 2024 pricing
'gemini-1.5': { prompt: 0, completion: 0 }, // currently free
gemini: { prompt: 0, completion: 0 }, // currently free
};
/**

View file

@ -35,10 +35,12 @@
"dependencies": {
"@anthropic-ai/sdk": "^0.16.1",
"@azure/search-documents": "^12.0.0",
"@google/generative-ai": "^0.5.0",
"@keyv/mongo": "^2.1.8",
"@keyv/redis": "^2.8.1",
"@langchain/community": "^0.0.17",
"@langchain/google-genai": "^0.0.8",
"@langchain/community": "^0.0.46",
"@langchain/google-genai": "^0.0.11",
"@langchain/google-vertexai": "^0.0.5",
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",

View file

@ -1,5 +1,5 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
@ -48,7 +48,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
try {
const { client } = await initializeClient({ req, res, endpointOption });
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
@ -59,7 +59,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: client.modelOptions.model,
unfinished: true,
unfinished,
error: false,
user,
});

View file

@ -1,5 +1,5 @@
const throttle = require('lodash/throttle');
const { getResponseSender } = require('librechat-data-provider');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
@ -48,6 +48,7 @@ const EditController = async (req, res, next, initializeClient) => {
}
};
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: throttle(
@ -59,7 +60,7 @@ const EditController = async (req, res, next, initializeClient) => {
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
unfinished,
isEdited: true,
error: false,
user,

View file

@ -1,9 +1,9 @@
const { EModelEndpoint } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
const clearPendingReq = require('~/cache/clearPendingReq');
const abortControllers = require('./abortControllers');
const { redactMessage } = require('~/config/parsers');
const spendTokens = require('~/models/spendTokens');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
@ -100,7 +100,15 @@ const createAbortController = (req, res, getAbortData) => {
};
const handleAbortError = async (res, req, error, data) => {
logger.error('[handleAbortError] AI response error; aborting request:', error);
if (error?.message?.includes('base64')) {
logger.error('[handleAbortError] Error in base64 encoding', {
...error,
stack: smartTruncateText(error?.stack, 1000),
message: truncateText(error.message, 350),
});
} else {
logger.error('[handleAbortError] AI response error; aborting request:', error);
}
const { sender, conversationId, messageId, parentMessageId, partialText } = data;
if (error.stack && error.stack.includes('google')) {
@ -109,13 +117,15 @@ const handleAbortError = async (res, req, error, data) => {
);
}
const errorText = 'An error occurred while processing your request. Please contact the Admin.';
const respondWithError = async (partialText) => {
let options = {
sender,
messageId,
conversationId,
parentMessageId,
text: redactMessage(error.message),
text: errorText,
shouldSaveMessage: true,
user: req.user.id,
};

View file

@ -213,7 +213,13 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>
/** @type {{ openai: OpenAI }} */
const { openai } = await initializeClient({ req, res });
const image = await uploadImageBuffer({ req, context: FileContext.avatar });
const image = await uploadImageBuffer({
req,
context: FileContext.avatar,
metadata: {
buffer: req.file.buffer,
},
});
try {
_metadata = JSON.parse(_metadata);

View file

@ -18,13 +18,15 @@ router.post('/', upload.single('input'), async (req, res) => {
}
const fileStrategy = req.app.locals.fileStrategy;
const webPBuffer = await resizeAvatar({
const desiredFormat = req.app.locals.imageOutputType;
const resizedBuffer = await resizeAvatar({
userId,
input,
desiredFormat,
});
const { processAvatar } = getStrategyFunctions(fileStrategy);
const url = await processAvatar({ buffer: webPBuffer, userId, manual });
const url = await processAvatar({ buffer: resizedBuffer, userId, manual });
res.json({ url });
} catch (error) {

View file

@ -3,6 +3,7 @@ const {
FileSources,
Capabilities,
EModelEndpoint,
EImageOutputType,
defaultSocialLogins,
validateAzureGroups,
mapModelToAzureConfig,
@ -181,6 +182,7 @@ const AppService = async (app) => {
fileConfig: config?.fileConfig,
interface: config?.interface,
secureImageLinks: config?.secureImageLinks,
imageOutputType: config?.imageOutputType?.toLowerCase() ?? EImageOutputType.PNG,
paths,
...endpointLocals,
};
@ -204,6 +206,12 @@ const AppService = async (app) => {
`,
);
}
if (process.env.GOOGLE_API_KEY) {
logger.warn(
'The `GOOGLE_API_KEY` environment variable is deprecated.\nPlease use the `GOOGLE_SEARCH_API_KEY` environment variable instead.',
);
}
};
module.exports = AppService;

View file

@ -1,6 +1,7 @@
const {
FileSources,
EModelEndpoint,
EImageOutputType,
defaultSocialLogins,
validateAzureGroups,
deprecatedAzureVariables,
@ -107,6 +108,10 @@ describe('AppService', () => {
},
},
paths: expect.anything(),
imageOutputType: expect.any(String),
interface: undefined,
fileConfig: undefined,
secureImageLinks: undefined,
});
});
@ -125,6 +130,31 @@ describe('AppService', () => {
expect(logger.info).toHaveBeenCalledWith(expect.stringContaining('Outdated Config version'));
});
it('should change the `imageOutputType` based on config value', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
version: '0.10.0',
imageOutputType: EImageOutputType.WEBP,
}),
);
await AppService(app);
expect(app.locals.imageOutputType).toEqual(EImageOutputType.WEBP);
});
it('should default to `PNG` `imageOutputType` with no provided type', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
version: '0.10.0',
}),
);
await AppService(app);
expect(app.locals.imageOutputType).toEqual(EImageOutputType.PNG);
});
it('should initialize Firebase when fileStrategy is firebase', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({

View file

@ -1,5 +1,5 @@
const path = require('path');
const { CacheKeys, configSchema } = require('librechat-data-provider');
const { CacheKeys, configSchema, EImageOutputType } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const loadYaml = require('~/utils/loadYaml');
const { logger } = require('~/config');
@ -55,6 +55,20 @@ async function loadCustomConfig() {
}
const result = configSchema.strict().safeParse(customConfig);
if (result?.error?.errors?.some((err) => err?.path && err.path?.includes('imageOutputType'))) {
throw new Error(
`
Please specify a correct \`imageOutputType\` value (case-sensitive).
The available options are:
- ${EImageOutputType.JPEG}
- ${EImageOutputType.PNG}
- ${EImageOutputType.WEBP}
Refer to the latest config file guide for more information:
https://docs.librechat.ai/install/configuration/custom_config.html`,
);
}
if (!result.success) {
i === 0 && logger.error(`Invalid custom config file at ${configPath}`, result.error);
i === 0 && i++;

View file

@ -8,7 +8,7 @@ const { updateFile } = require('~/models/File');
const { logger } = require('~/config');
/**
* Converts an image file to the WebP format. The function first resizes the image based on the specified
* Converts an image file to the target format. The function first resizes the image based on the specified
* resolution.
*
* @param {Object} params - The params object.
@ -21,7 +21,7 @@ const { logger } = require('~/config');
*
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number}>}
* A promise that resolves to an object containing:
* - filepath: The path where the converted WebP image is saved.
* - filepath: The path where the converted image is saved.
* - bytes: The size of the converted image in bytes.
* - width: The width of the converted image.
* - height: The height of the converted image.
@ -39,15 +39,16 @@ async function uploadImageToFirebase({ req, file, file_id, endpoint, resolution
let webPBuffer;
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
if (extension.toLowerCase() === '.webp') {
const targetExtension = `.${req.app.locals.imageOutputType}`;
if (extension.toLowerCase() === targetExtension) {
webPBuffer = resizedBuffer;
} else {
webPBuffer = await sharp(resizedBuffer).toFormat('webp').toBuffer();
webPBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer();
// Replace or append the correct extension
const extRegExp = new RegExp(path.extname(fileName) + '$');
fileName = fileName.replace(extRegExp, '.webp');
fileName = fileName.replace(extRegExp, targetExtension);
if (!path.extname(fileName)) {
fileName += '.webp';
fileName += targetExtension;
}
}
@ -79,7 +80,7 @@ async function prepareImageURL(req, file) {
* If the 'manual' flag is set to 'true', it also updates the user's avatar URL in the database.
*
* @param {object} params - The parameters object.
* @param {Buffer} params.buffer - The Buffer containing the avatar image in WebP format.
* @param {Buffer} params.buffer - The Buffer containing the avatar image.
* @param {string} params.userId - The user ID.
* @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false').
* @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar.

View file

@ -6,11 +6,11 @@ const { updateUser } = require('~/models/userMethods');
const { updateFile } = require('~/models/File');
/**
* Converts an image file to the WebP format. The function first resizes the image based on the specified
* Converts an image file to the target format. The function first resizes the image based on the specified
* resolution.
*
* If the original image is already in WebP format, it writes the resized image back. Otherwise,
* it converts the image to WebP format before saving.
* If the original image is already in target format, it writes the resized image back. Otherwise,
* it converts the image to target format before saving.
*
* The original image is deleted after conversion.
* @param {Object} params - The params object.
@ -24,7 +24,7 @@ const { updateFile } = require('~/models/File');
*
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number}>}
* A promise that resolves to an object containing:
* - filepath: The path where the converted WebP image is saved.
* - filepath: The path where the converted image is saved.
* - bytes: The size of the converted image in bytes.
* - width: The width of the converted image.
* - height: The height of the converted image.
@ -48,16 +48,17 @@ async function uploadLocalImage({ req, file, file_id, endpoint, resolution = 'hi
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const newPath = path.join(userPath, fileName);
const targetExtension = `.${req.app.locals.imageOutputType}`;
if (extension.toLowerCase() === '.webp') {
if (extension.toLowerCase() === targetExtension) {
const bytes = Buffer.byteLength(resizedBuffer);
await fs.promises.writeFile(newPath, resizedBuffer);
const filepath = path.posix.join('/', 'images', req.user.id, path.basename(newPath));
return { filepath, bytes, width, height };
}
const outputFilePath = newPath.replace(extension, '.webp');
const data = await sharp(resizedBuffer).toFormat('webp').toBuffer();
const outputFilePath = newPath.replace(extension, targetExtension);
const data = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer();
await fs.promises.writeFile(outputFilePath, data);
const bytes = Buffer.byteLength(data);
const filepath = path.posix.join('/', 'images', req.user.id, path.basename(outputFilePath));
@ -109,7 +110,7 @@ async function prepareImagesLocal(req, file) {
* If the 'manual' flag is set to 'true', it also updates the user's avatar URL in the database.
*
* @param {object} params - The parameters object.
* @param {Buffer} params.buffer - The Buffer containing the avatar image in WebP format.
* @param {Buffer} params.buffer - The Buffer containing the avatar image.
* @param {string} params.userId - The user ID.
* @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false').
* @returns {Promise<string>} - A promise that resolves with the URL of the uploaded avatar.

View file

@ -6,10 +6,11 @@ const { logger } = require('~/config');
/**
* Uploads an avatar image for a user. This function can handle various types of input (URL, Buffer, or File object),
* processes the image to a square format, converts it to WebP format, and returns the resized buffer.
* processes the image to a square format, converts it to target format, and returns the resized buffer.
*
* @param {Object} params - The parameters object.
* @param {string} params.userId - The unique identifier of the user for whom the avatar is being uploaded.
* @param {string} options.desiredFormat - The desired output format of the image.
* @param {(string|Buffer|File)} params.input - The input representing the avatar image. Can be a URL (string),
* a Buffer, or a File object.
*
@ -19,7 +20,7 @@ const { logger } = require('~/config');
* @throws {Error} Throws an error if the user ID is undefined, the input type is invalid, the image fetching fails,
* or any other error occurs during the processing.
*/
async function resizeAvatar({ userId, input }) {
async function resizeAvatar({ userId, input, desiredFormat }) {
try {
if (userId === undefined) {
throw new Error('User ID is undefined');
@ -53,7 +54,10 @@ async function resizeAvatar({ userId, input }) {
})
.toBuffer();
const { buffer } = await resizeAndConvert(squaredBuffer);
const { buffer } = await resizeAndConvert({
inputBuffer: squaredBuffer,
desiredFormat,
});
return buffer;
} catch (error) {
logger.error('Error uploading the avatar:', error);

View file

@ -6,7 +6,7 @@ const { getStrategyFunctions } = require('../strategies');
const { logger } = require('~/config');
/**
* Converts an image file or buffer to WebP format with specified resolution.
* Converts an image file or buffer to target output type with specified resolution.
*
* @param {Express.Request} req - The request object, containing user and app configuration data.
* @param {Buffer | Express.Multer.File} file - The file object, containing either a path or a buffer.
@ -15,7 +15,7 @@ const { logger } = require('~/config');
* @returns {Promise<{filepath: string, bytes: number, width: number, height: number}>} An object containing the path, size, and dimensions of the converted image.
* @throws Throws an error if there is an issue during the conversion process.
*/
async function convertToWebP(req, file, resolution = 'high', basename = '') {
async function convertImage(req, file, resolution = 'high', basename = '') {
try {
let inputBuffer;
let outputBuffer;
@ -38,13 +38,13 @@ async function convertToWebP(req, file, resolution = 'high', basename = '') {
height,
} = await resizeImageBuffer(inputBuffer, resolution);
// Check if the file is already in WebP format
// If it isn't, convert it:
if (extension === '.webp') {
// Check if the file is already in target format; if it isn't, convert it:
const targetExtension = `.${req.app.locals.imageOutputType}`;
if (extension === targetExtension) {
outputBuffer = resizedBuffer;
} else {
outputBuffer = await sharp(resizedBuffer).toFormat('webp').toBuffer();
extension = '.webp';
outputBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer();
extension = targetExtension;
}
// Generate a new filename for the output file
@ -67,4 +67,4 @@ async function convertToWebP(req, file, resolution = 'high', basename = '') {
}
}
module.exports = { convertToWebP };
module.exports = { convertImage };

View file

@ -1,5 +1,5 @@
const axios = require('axios');
const { EModelEndpoint, FileSources } = require('librechat-data-provider');
const { EModelEndpoint, FileSources, VisionModes } = require('librechat-data-provider');
const { getStrategyFunctions } = require('../strategies');
const { logger } = require('~/config');
@ -30,11 +30,20 @@ const base64Only = new Set([EModelEndpoint.google, EModelEndpoint.anthropic]);
* @param {Express.Request} req - The request object.
* @param {Array<MongoFile>} files - The array of files to encode and format.
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
* @param {string} [mode] - Optional: The endpoint mode for the image.
* @returns {Promise<Object>} - A promise that resolves to the result object containing the encoded images and file details.
*/
async function encodeAndFormat(req, files, endpoint) {
async function encodeAndFormat(req, files, endpoint, mode) {
const promises = [];
const encodingMethods = {};
const result = {
files: [],
image_urls: [],
};
if (!files || !files.length) {
return result;
}
for (let file of files) {
const source = file.source ?? FileSources.local;
@ -69,11 +78,6 @@ async function encodeAndFormat(req, files, endpoint) {
/** @type {Array<[MongoFile, string]>} */
const formattedImages = await Promise.all(promises);
const result = {
files: [],
image_urls: [],
};
for (const [file, imageContent] of formattedImages) {
const fileMetadata = {
type: file.type,
@ -98,12 +102,18 @@ async function encodeAndFormat(req, files, endpoint) {
image_url: {
url: imageContent.startsWith('http')
? imageContent
: `data:image/webp;base64,${imageContent}`,
: `data:${file.type};base64,${imageContent}`,
detail,
},
};
if (endpoint && endpoint === EModelEndpoint.google) {
if (endpoint && endpoint === EModelEndpoint.google && mode === VisionModes.generative) {
delete imagePart.image_url;
imagePart.inlineData = {
mimeType: file.type,
data: imageContent,
};
} else if (endpoint && endpoint === EModelEndpoint.google) {
imagePart.image_url = imagePart.image_url.url;
} else if (endpoint && endpoint === EModelEndpoint.anthropic) {
imagePart.type = 'image';

View file

@ -62,14 +62,20 @@ async function resizeImageBuffer(inputBuffer, resolution, endpoint) {
}
/**
* Resizes an image buffer to webp format as well as reduces by specified or default 150 px width.
* Resizes an image buffer to a specified format and width.
*
* @param {Buffer} inputBuffer - The buffer of the image to be resized.
* @returns {Promise<{ buffer: Buffer, width: number, height: number, bytes: number }>} An object containing the resized image buffer, its size and dimensions.
* @throws Will throw an error if the resolution parameter is invalid.
* @param {Object} options - The options for resizing and converting the image.
* @param {Buffer} options.inputBuffer - The buffer of the image to be resized.
* @param {string} options.desiredFormat - The desired output format of the image.
* @param {number} [options.width=150] - The desired width of the image. Defaults to 150 pixels.
* @returns {Promise<{ buffer: Buffer, width: number, height: number, bytes: number }>} An object containing the resized image buffer, its size, and dimensions.
* @throws Will throw an error if the resolution or format parameters are invalid.
*/
async function resizeAndConvert(inputBuffer, width = 150) {
const resizedBuffer = await sharp(inputBuffer).resize({ width }).toFormat('webp').toBuffer();
async function resizeAndConvert({ inputBuffer, desiredFormat, width = 150 }) {
const resizedBuffer = await sharp(inputBuffer)
.resize({ width })
.toFormat(desiredFormat)
.toBuffer();
const resizedMetadata = await sharp(resizedBuffer).metadata();
return {
buffer: resizedBuffer,

View file

@ -12,7 +12,7 @@ const {
hostImageIdSuffix,
hostImageNamePrefix,
} = require('librechat-data-provider');
const { convertToWebP, resizeAndConvert } = require('~/server/services/Files/images');
const { convertImage, resizeAndConvert } = require('~/server/services/Files/images');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
const { LB_QueueAsyncCall } = require('~/server/utils/queue');
@ -207,7 +207,7 @@ const processImageFile = async ({ req, res, file, metadata }) => {
filename: file.originalname,
context: FileContext.message_attachment,
source,
type: 'image/webp',
type: `image/${req.app.locals.imageOutputType}`,
width,
height,
},
@ -223,9 +223,9 @@ const processImageFile = async ({ req, res, file, metadata }) => {
* @param {Object} params - The parameters object.
* @param {Express.Request} params.req - The Express request object.
* @param {FileContext} params.context - The context of the file (e.g., 'avatar', 'image_generation', etc.)
* @param {boolean} [params.resize=true] - Whether to resize and convert the image to WebP. Default is `true`.
* @param {boolean} [params.resize=true] - Whether to resize and convert the image to target format. Default is `true`.
* @param {{ buffer: Buffer, width: number, height: number, bytes: number, filename: string, type: string, file_id: string }} [params.metadata] - Required metadata for the file if resize is false.
* @returns {Promise<{ filepath: string, filename: string, source: string, type: 'image/webp'}>}
* @returns {Promise<{ filepath: string, filename: string, source: string, type: string}>}
*/
const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true }) => {
const source = req.app.locals.fileStrategy;
@ -233,9 +233,14 @@ const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true })
let { buffer, width, height, bytes, filename, file_id, type } = metadata;
if (resize) {
file_id = v4();
type = 'image/webp';
({ buffer, width, height, bytes } = await resizeAndConvert(req.file.buffer));
filename = path.basename(req.file.originalname, path.extname(req.file.originalname)) + '.webp';
type = `image/${req.app.locals.imageOutputType}`;
({ buffer, width, height, bytes } = await resizeAndConvert({
inputBuffer: buffer,
desiredFormat: req.app.locals.imageOutputType,
}));
filename = `${path.basename(req.file.originalname, path.extname(req.file.originalname))}.${
req.app.locals.imageOutputType
}`;
}
const filepath = await saveBuffer({ userId: req.user.id, fileName: filename, buffer });
@ -363,7 +368,7 @@ const processOpenAIFile = async ({
};
/**
* Process OpenAI image files, convert to webp, save and return file metadata.
* Process OpenAI image files, convert to target format, save and return file metadata.
* @param {object} params - The params object.
* @param {Express.Request} params.req - The Express request object.
* @param {Buffer} params.buffer - The image buffer.
@ -375,12 +380,12 @@ const processOpenAIFile = async ({
const processOpenAIImageOutput = async ({ req, buffer, file_id, filename, fileExt }) => {
const currentDate = new Date();
const formattedDate = currentDate.toISOString();
const _file = await convertToWebP(req, buffer, 'high', `${file_id}${fileExt}`);
const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`);
const file = {
..._file,
usage: 1,
user: req.user.id,
type: 'image/webp',
type: `image/${req.app.locals.imageOutputType}`,
createdAt: formattedDate,
updatedAt: formattedDate,
source: req.app.locals.fileStrategy,

View file

@ -25,12 +25,12 @@ const handleExistingUser = async (oldUser, avatarUrl) => {
await oldUser.save();
} else if (!isLocal && (oldUser.avatar === null || !oldUser.avatar.includes('?manual=true'))) {
const userId = oldUser._id;
const webPBuffer = await resizeAvatar({
const resizedBuffer = await resizeAvatar({
userId,
input: avatarUrl,
});
const { processAvatar } = getStrategyFunctions(fileStrategy);
oldUser.avatar = await processAvatar({ buffer: webPBuffer, userId });
oldUser.avatar = await processAvatar({ buffer: resizedBuffer, userId });
await oldUser.save();
}
};
@ -83,12 +83,12 @@ const createNewUser = async ({
if (!isLocal) {
const userId = newUser._id;
const webPBuffer = await resizeAvatar({
const resizedBuffer = await resizeAvatar({
userId,
input: avatarUrl,
});
const { processAvatar } = getStrategyFunctions(fileStrategy);
newUser.avatar = await processAvatar({ buffer: webPBuffer, userId });
newUser.avatar = await processAvatar({ buffer: resizedBuffer, userId });
await newUser.save();
}

View file

@ -14,6 +14,12 @@
* @memberof typedefs
*/
/**
* @exports GenerativeModel
* @typedef {import('@google/generative-ai').GenerativeModel} GenerativeModel
* @memberof typedefs
*/
/**
* @exports AssistantStreamEvent
* @typedef {import('openai').default.Beta.AssistantStreamEvent} AssistantStreamEvent
@ -295,6 +301,12 @@
* @memberof typedefs
*/
/**
* @exports EImageOutputType
* @typedef {import('librechat-data-provider').EImageOutputType} EImageOutputType
* @memberof typedefs
*/
/**
* @exports TCustomConfig
* @typedef {import('librechat-data-provider').TCustomConfig} TCustomConfig

View file

@ -65,12 +65,14 @@ const cohereModels = {
command: 4086, // -10 from max
'command-nightly': 8182, // -10 from max
'command-r': 127500, // -500 from max
'command-r-plus:': 127500, // -500 from max
'command-r-plus': 127500, // -500 from max
};
const googleModels = {
/* Max I/O is combined so we subtract the amount from max response tokens for actual total */
gemini: 32750, // -10 from max
gemini: 30720, // -2048 from max
'gemini-pro-vision': 12288, // -4096 from max
'gemini-1.5': 1048576, // -8192 from max
'text-bison-32k': 32758, // -10 from max
'chat-bison-32k': 32758, // -10 from max
'code-bison-32k': 32758, // -10 from max

View file

@ -131,6 +131,18 @@ describe('getModelMaxTokens', () => {
});
test('should return correct tokens for partial match - Google models', () => {
expect(getModelMaxTokens('gemini-1.5-pro-latest', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['gemini-1.5'],
);
expect(getModelMaxTokens('gemini-1.5-pro-preview-0409', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['gemini-1.5'],
);
expect(getModelMaxTokens('gemini-pro-vision', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['gemini-pro-vision'],
);
expect(getModelMaxTokens('gemini-1.0', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['gemini'],
);
expect(getModelMaxTokens('gemini-pro', EModelEndpoint.google)).toBe(
maxTokensMap[EModelEndpoint.google]['gemini'],
);
@ -142,6 +154,15 @@ describe('getModelMaxTokens', () => {
);
});
test('should return correct tokens for partial match - Cohere models', () => {
expect(getModelMaxTokens('command', EModelEndpoint.custom)).toBe(
maxTokensMap[EModelEndpoint.custom]['command'],
);
expect(getModelMaxTokens('command-r-plus', EModelEndpoint.custom)).toBe(
maxTokensMap[EModelEndpoint.custom]['command-r-plus'],
);
});
test('should return correct tokens when using a custom endpointTokenConfig', () => {
const customTokenConfig = {
'custom-model': 12345,