📸 feat: Gemini Vision, Improved Logs and Multi-modal Handling (#1368)

* feat: add GOOGLE_MODELS env var

* feat: add gemini vision support

* refactor(GoogleClient): adjust clientOptions handling depending on model

* fix(logger): fix redact logic and only redact error-level messages

* fix(GoogleClient): do not allow non-multimodal messages when gemini-pro-vision is selected

* refactor(OpenAIClient): use `isVisionModel` client property to avoid calling validateVisionModel multiple times

* refactor: better debug logging by correctly traversing metadata, redacting sensitive info, and logging condensed versions of long values

* refactor(GoogleClient): allow response errors to be thrown/caught above client handling so the user receives a meaningful error message;
debug-log orderedMessages, parentMessageId, and the buildMessages result

* refactor(AskController): use the model from client.modelOptions.model when saving intermediate messages, which requires the progress callback to be initialized after the client is initialized

* feat(useSSE): revert to the previous model if the model was auto-switched by the backend due to message attachments

* docs: update Google setup instructions, with notes about Gemini Pro Vision

* fix: do not initialize Redis without USE_REDIS; increase max listeners to 20
Danny Avila 2023-12-16 20:45:27 -05:00 committed by GitHub
parent 676f133545
commit 0c326797dd
21 changed files with 356 additions and 210 deletions

View file

@ -29,7 +29,7 @@
# Features
- 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and 11-2023 updates
- 💬 Multimodal Chat:
- Upload and analyze images with GPT-4-Vision 📸
- Upload and analyze images with GPT-4 and Gemini Vision 📸
- More filetypes and Assistants API integration in Active Development 🚧
- 🌎 Multilingual UI:
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro,

View file

@ -357,11 +357,11 @@ class BaseClient {
const promptTokens = this.maxContextTokens - remainingContextTokens;
logger.debug('[BaseClient] Payload size:', payload.length);
logger.debug('[BaseClient] tokenCountMap:', tokenCountMap);
logger.debug('[BaseClient]', {
promptTokens,
remainingContextTokens,
payloadSize: payload.length,
maxContextTokens: this.maxContextTokens,
});
@ -414,7 +414,6 @@ class BaseClient {
logger.debug('[BaseClient] tokenCountMap', tokenCountMap);
if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
logger.debug('[BaseClient] userMessage.tokenCount', userMessage.tokenCount);
logger.debug('[BaseClient] userMessage', userMessage);
}

View file

@ -4,6 +4,7 @@ const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
getResponseSender,
@ -122,9 +123,18 @@ class GoogleClient extends BaseClient {
// stop: modelOptions.stop // no stop method for now
};
if (this.options.attachments) {
this.modelOptions.model = 'gemini-pro-vision';
}
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
this.isVisionModel = validateVisionModel(this.modelOptions.model);
const { isGenerativeModel } = this;
if (this.isVisionModel && !this.options.attachments) {
this.modelOptions.model = 'gemini-pro';
this.isVisionModel = false;
}
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
const { isChatModel } = this;
this.isTextModel =
@ -216,7 +226,34 @@ class GoogleClient extends BaseClient {
})).bind(this);
}
buildMessages(messages = [], parentMessageId) {
async buildVisionMessages(messages = [], parentMessageId) {
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
const attachments = await this.options.attachments;
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments.filter((file) => file.type.includes('image')),
EModelEndpoint.google,
);
const latestMessage = { ...messages[messages.length - 1] };
latestMessage.image_urls = image_urls;
this.options.attachments = files;
latestMessage.text = prompt;
const payload = {
instances: [
{
messages: [new HumanMessage(formatMessage({ message: latestMessage }))],
},
],
parameters: this.modelOptions,
};
return { prompt: payload };
}
async buildMessages(messages = [], parentMessageId) {
if (!this.isGenerativeModel && !this.project_id) {
throw new Error(
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
@ -227,17 +264,24 @@ class GoogleClient extends BaseClient {
);
}
if (this.options.attachments) {
return this.buildVisionMessages(messages, parentMessageId);
}
if (this.isTextModel) {
return this.buildMessagesPrompt(messages, parentMessageId);
}
const formattedMessages = messages.map(this.formatMessages());
let payload = {
instances: [
{
messages: formattedMessages,
messages: messages
.map(this.formatMessages())
.map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
.map((message) => formatMessage({ message, langChain: true })),
},
],
parameters: this.options.modelOptions,
parameters: this.modelOptions,
};
if (this.options.promptPrefix) {
@ -248,9 +292,7 @@ class GoogleClient extends BaseClient {
payload.instances[0].examples = this.options.examples;
}
if (this.options.debug) {
logger.debug('GoogleClient buildMessages', payload);
}
logger.debug('[GoogleClient] buildMessages', payload);
return { prompt: payload };
}
@ -260,12 +302,11 @@ class GoogleClient extends BaseClient {
messages,
parentMessageId,
});
if (this.options.debug) {
logger.debug('GoogleClient: orderedMessages, parentMessageId', {
orderedMessages,
parentMessageId,
});
}
logger.debug('[GoogleClient]', {
orderedMessages,
parentMessageId,
});
const formattedMessages = orderedMessages.map((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
@ -394,7 +435,7 @@ class GoogleClient extends BaseClient {
context.shift();
}
let prompt = `${promptBody}${promptSuffix}`;
let prompt = `${promptBody}${promptSuffix}`.trim();
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
@ -453,20 +494,26 @@ class GoogleClient extends BaseClient {
let examples;
let clientOptions = {
authOptions: {
let clientOptions = { ...parameters, maxRetries: 2 };
if (!this.isGenerativeModel) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
},
projectId: this.project_id,
},
...parameters,
};
};
}
if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}
if (this.isGenerativeModel) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}
if (_examples && _examples.length) {
examples = _examples
.map((ex) => {
@ -487,13 +534,9 @@ class GoogleClient extends BaseClient {
const model = this.createLLM(clientOptions);
let reply = '';
const messages = this.isTextModel
? _payload.trim()
: _messages
.map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
.map((message) => formatMessage({ message, langChain: true }));
const messages = this.isTextModel ? _payload.trim() : _messages;
if (context && messages?.length > 0) {
if (!this.isVisionModel && context && messages?.length > 0) {
messages.unshift(new SystemMessage(context));
}
@ -526,14 +569,7 @@ class GoogleClient extends BaseClient {
async sendCompletion(payload, opts = {}) {
let reply = '';
try {
reply = await this.getCompletion(payload, opts);
if (this.options.debug) {
logger.debug('GoogleClient sendCompletion', { reply });
}
} catch (err) {
logger.error('failed to send completion to Google', err);
}
reply = await this.getCompletion(payload, opts);
return reply.trim();
}
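
For readers skimming the GoogleClient changes, the attachment-driven model switching above can be restated as a standalone sketch. This is illustrative, not the client's literal code; `validateVisionModel` is the helper imported at the top of this file.

```js
const { validateVisionModel } = require('~/server/services/Files/images');

// Attachments force gemini-pro-vision; selecting the vision model without
// attachments falls back to gemini-pro, since gemini-pro-vision rejects
// text-only (non-multimodal) requests.
function resolveGeminiModel(modelOptions, attachments) {
  const options = { ...modelOptions };
  if (attachments) {
    options.model = 'gemini-pro-vision';
  }
  let isVisionModel = validateVisionModel(options.model);
  if (isVisionModel && !attachments) {
    options.model = 'gemini-pro';
    isVisionModel = false;
  }
  return { options, isVisionModel };
}

// resolveGeminiModel({ model: 'gemini-pro-vision' }, undefined)
//   -> { options: { model: 'gemini-pro' }, isVisionModel: false }
```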

View file

@ -1,7 +1,7 @@
const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
@ -76,11 +76,14 @@ class OpenAIClient extends BaseClient {
};
}
if (this.options.attachments && !validateVisionModel(this.modelOptions.model)) {
this.isVisionModel = validateVisionModel(this.modelOptions.model);
if (this.options.attachments && !this.isVisionModel) {
this.modelOptions.model = 'gpt-4-vision-preview';
this.isVisionModel = true;
}
if (validateVisionModel(this.modelOptions.model)) {
if (this.isVisionModel) {
delete this.modelOptions.stop;
}
@ -152,7 +155,7 @@ class OpenAIClient extends BaseClient {
this.setupTokens();
if (!this.modelOptions.stop && !validateVisionModel(this.modelOptions.model)) {
if (!this.modelOptions.stop && !this.isVisionModel) {
const stopTokens = [this.startToken];
if (this.endToken && this.endToken !== this.startToken) {
stopTokens.push(this.endToken);
@ -689,7 +692,7 @@ ${convo}
}
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[OpenAIClient]', { promptTokens, completionTokens });
logger.debug('[OpenAIClient] recordTokenUsage:', { promptTokens, completionTokens });
await spendTokens(
{
user: this.user,
@ -757,7 +760,7 @@ ${convo}
opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
if (validateVisionModel(modelOptions.model)) {
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
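
The OpenAIClient side of the same feature follows an analogous pattern; a compressed, illustrative restatement of the changes in this file (not the exact client code):

```js
const { validateVisionModel } = require('~/server/services/Files/images');

// Image attachments auto-switch non-vision models to gpt-4-vision-preview;
// vision requests then drop custom stop sequences, and max_tokens is pinned
// to 4000 when the chat completion options are built.
function applyOpenAIVisionHandling(modelOptions, attachments) {
  const options = { ...modelOptions };
  let isVisionModel = validateVisionModel(options.model);

  if (attachments && !isVisionModel) {
    options.model = 'gpt-4-vision-preview';
    isVisionModel = true;
  }
  if (isVisionModel) {
    delete options.stop;
    options.max_tokens = 4000;
  }
  return { options, isVisionModel };
}
```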

View file

@ -180,7 +180,7 @@ class PluginsClient extends OpenAIClient {
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
if (errorMessage.length > 0) {
logger.debug('[PluginsClient] Caught error, input:', input);
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
}
try {

View file

@ -1,15 +1,19 @@
const KeyvRedis = require('@keyv/redis');
const { logger } = require('~/config');
const { isEnabled } = require('~/server/utils');
const { REDIS_URI } = process.env;
const { REDIS_URI, USE_REDIS } = process.env;
let keyvRedis;
if (REDIS_URI) {
if (REDIS_URI && isEnabled(USE_REDIS)) {
keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false });
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(20);
} else {
logger.info('REDIS_URI not provided. Redis module will not be initialized.');
logger.info(
'`REDIS_URI` not provided, or `USE_REDIS` not set. Redis module will not be initialized.',
);
}
module.exports = keyvRedis;
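
A hedged usage sketch of this module, mirroring the consumer pattern in the ModelService diff later in this commit (the require path, cache key, and TTL are assumptions for illustration):

```js
const Keyv = require('keyv');
const keyvRedis = require('~/cache/keyvRedis'); // undefined unless REDIS_URI is set and USE_REDIS is enabled

// Fall back to an in-memory namespace when Redis is not configured.
const modelsCache = keyvRedis
  ? new Keyv({ store: keyvRedis })
  : new Keyv({ namespace: 'models' });

(async () => {
  await modelsCache.set('google', ['gemini-pro', 'gemini-pro-vision'], 60 * 1000);
  console.log(await modelsCache.get('google'));
})();
```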

View file

@ -1,128 +1,160 @@
const util = require('util');
const winston = require('winston');
const traverse = require('traverse');
const { klona } = require('klona/full');
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/, /api-key: \w+/];
const SPLAT_SYMBOL = Symbol.for('splat');
const MESSAGE_SYMBOL = Symbol.for('message');
const sensitiveKeys = [/^(sk-)[^\s]+/, /(Bearer )[^\s]+/, /(api-key:? )[^\s]+/, /(key=)[^\s]+/];
/**
* Determines if a given key string is sensitive.
* Determines if a given value string is sensitive and returns matching regex patterns.
*
* @param {string} keyStr - The key string to check.
* @returns {boolean} True if the key string matches known sensitive key patterns.
* @param {string} valueStr - The value string to check.
* @returns {Array<RegExp>} An array of regex patterns that match the value string.
*/
function isSensitiveKey(keyStr) {
if (keyStr) {
return sensitiveKeys.some((regex) => regex.test(keyStr));
function getMatchingSensitivePatterns(valueStr) {
if (valueStr) {
// Filter and return all regex patterns that match the value string
return sensitiveKeys.filter((regex) => regex.test(valueStr));
}
return false;
return [];
}
/**
* Recursively redacts sensitive information from an object.
* Redacts sensitive information from a console message.
*
* @param {object} obj - The object to traverse and redact.
* @param {string} str - The console message to be redacted.
* @returns {string} - The redacted console message.
*/
function redactObject(obj) {
traverse(obj).forEach(function redactor() {
if (isSensitiveKey(this.key)) {
this.update('[REDACTED]');
}
function redactMessage(str) {
const patterns = getMatchingSensitivePatterns(str);
if (patterns.length === 0) {
return str;
}
patterns.forEach((pattern) => {
str = str.replace(pattern, '$1[REDACTED]');
});
return str;
}
/**
* Deep copies and redacts sensitive information from an object.
*
* @param {object} obj - The object to copy and redact.
* @returns {object} The redacted copy of the original object.
* Redacts sensitive information from log messages if the log level is 'error'.
* Note: Intentionally mutates the object.
* @param {Object} info - The log information object.
* @returns {Object} - The modified log information object.
*/
function redact(obj) {
const copy = klona(obj); // Making a deep copy to prevent side effects
redactObject(copy);
const splat = copy[Symbol.for('splat')];
redactObject(splat); // Specifically redact splat Symbol
return copy;
}
const redactFormat = winston.format((info) => {
if (info.level === 'error') {
info.message = redactMessage(info.message);
if (info[MESSAGE_SYMBOL]) {
info[MESSAGE_SYMBOL] = redactMessage(info[MESSAGE_SYMBOL]);
}
}
return info;
});
/**
* Truncates long strings, especially base64 image data, within log messages.
*
* @param {any} value - The value to be inspected and potentially truncated.
* @param {number} [length] - The length at which to truncate the value. Default: 100.
* @returns {any} - The truncated or original value.
*/
const truncateLongStrings = (value) => {
const truncateLongStrings = (value, length = 100) => {
if (typeof value === 'string') {
return value.length > 100 ? value.substring(0, 100) + '... [truncated]' : value;
return value.length > length ? value.substring(0, length) + '... [truncated]' : value;
}
return value;
};
// /**
// * Processes each message in the messages array, specifically looking for and truncating
// * base64 image URLs in the content. If a base64 image URL is found, it replaces the URL
// * with a truncated message.
// *
// * @param {PayloadMessage} message - The payload message object to format.
// * @returns {PayloadMessage} - The processed message object with base64 image URLs truncated.
// */
// const truncateBase64ImageURLs = (message) => {
// // Create a deep copy of the message
// const messageCopy = JSON.parse(JSON.stringify(message));
// if (messageCopy.content && Array.isArray(messageCopy.content)) {
// messageCopy.content = messageCopy.content.map(contentItem => {
// if (contentItem.type === 'image_url' && contentItem.image_url && isBase64String(contentItem.image_url.url)) {
// return { ...contentItem, image_url: { ...contentItem.image_url, url: 'Base64 Image Data... [truncated]' } };
// }
// return contentItem;
// });
// }
// return messageCopy;
// };
// /**
// * Checks if a string is a base64 image data string.
// *
// * @param {string} str - The string to be checked.
// * @returns {boolean} - True if the string is base64 image data, otherwise false.
// */
// const isBase64String = (str) => /^data:image\/[a-zA-Z]+;base64,/.test(str);
/**
* An array mapping function that truncates long strings (objects converted to JSON strings).
* @param {any} item - The item to be condensed.
* @returns {any} - The condensed item.
*/
const condenseArray = (item) => {
if (typeof item === 'string') {
return truncateLongStrings(JSON.stringify(item));
} else if (typeof item === 'object') {
return truncateLongStrings(JSON.stringify(item));
}
return item;
};
/**
* Custom log format for Winston that handles deep object inspection.
* It specifically truncates long strings and handles nested structures within metadata.
* Formats log messages for debugging purposes.
* - Truncates long strings within log messages.
* - Condenses arrays by truncating long strings and objects as strings within array items.
* - Redacts sensitive information from log messages if the log level is 'error'.
* - Converts log information object to a formatted string.
*
* @param {Object} info - Information about the log entry.
* @param {Object} options - The options for formatting log messages.
* @param {string} options.level - The log level.
* @param {string} options.message - The log message.
* @param {string} options.timestamp - The timestamp of the log message.
* @param {Object} options.metadata - Additional metadata associated with the log message.
* @returns {string} - The formatted log message.
*/
const deepObjectFormat = winston.format.printf(({ level, message, timestamp, ...metadata }) => {
let msg = `${timestamp} ${level}: ${message}`;
const debugTraverse = winston.format.printf(({ level, message, timestamp, ...metadata }) => {
let msg = `${timestamp} ${level}: ${truncateLongStrings(message?.trim(), 150)}`;
if (Object.keys(metadata).length) {
Object.entries(metadata).forEach(([key, value]) => {
let val = value;
if (key === 'modelOptions' && value && Array.isArray(value.messages)) {
// Create a shallow copy of the messages array
// val = { ...value, messages: value.messages.map(truncateBase64ImageURLs) };
val = { ...value, messages: `${value.messages.length} message(s) in payload` };
}
// Inspects each metadata value; applies special handling for 'messages'
const inspectedValue =
typeof val === 'string'
? truncateLongStrings(val)
: util.inspect(val, { depth: null, colors: false }); // Use 'val' here
msg += ` ${key}: ${inspectedValue}`;
});
if (level !== 'debug') {
return msg;
}
if (!metadata) {
return msg;
}
const debugValue = metadata[SPLAT_SYMBOL]?.[0];
if (!debugValue) {
return msg;
}
if (debugValue && Array.isArray(debugValue)) {
msg += `\n${JSON.stringify(debugValue.map(condenseArray))}`;
return msg;
}
if (typeof debugValue !== 'object') {
return (msg += ` ${debugValue}`);
}
msg += '\n{';
const copy = klona(metadata);
traverse(copy).forEach(function (value) {
const parent = this.parent;
const parentKey = `${parent && parent.notRoot ? parent.key + '.' : ''}`;
const tabs = `${parent && parent.notRoot ? '\t\t' : '\t'}`;
if (this.isLeaf && typeof value === 'string') {
const truncatedText = truncateLongStrings(value);
msg += `\n${tabs}${parentKey}${this.key}: ${JSON.stringify(truncatedText)},`;
} else if (this.notLeaf && Array.isArray(value) && value.length > 0) {
const currentMessage = `\n${tabs}// ${value.length} ${this.key.replace(/s$/, '')}(s)`;
this.update(currentMessage, true);
msg += currentMessage;
const stringifiedArray = value.map(condenseArray);
msg += `\n${tabs}${parentKey}${this.key}: [${stringifiedArray}],`;
} else if (this.isLeaf && typeof value === 'function') {
msg += `\n${tabs}${parentKey}${this.key}: function,`;
} else if (this.isLeaf) {
msg += `\n${tabs}${parentKey}${this.key}: ${value},`;
}
});
msg += '\n}';
return msg;
});
module.exports = {
redact,
deepObjectFormat,
redactFormat,
redactMessage,
debugTraverse,
};
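
To make the new redaction behavior concrete, a small usage example (the require path is assumed; the outputs follow from the patterns defined above, each of which keeps its captured prefix and replaces the secret that follows it):

```js
const { redactMessage } = require('~/config/parsers'); // path assumed

const line = 'Google request failed with Bearer ya29.secret-token and api-key: abc123';
console.log(redactMessage(line));
// -> 'Google request failed with Bearer [REDACTED] and api-key: [REDACTED]'

// Strings starting with an sk- key are caught by the /^(sk-)[^\s]+/ pattern:
console.log(redactMessage('sk-live1234567890 leaked in error output'));
// -> 'sk-[REDACTED] leaked in error output'
```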

View file

@ -1,7 +1,7 @@
const path = require('path');
const winston = require('winston');
require('winston-daily-rotate-file');
const { redact, deepObjectFormat } = require('./parsers');
const { redactFormat, redactMessage, debugTraverse } = require('./parsers');
const logDir = path.join(__dirname, '..', 'logs');
@ -32,10 +32,11 @@ const level = () => {
};
const fileFormat = winston.format.combine(
redactFormat(),
winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }),
winston.format.errors({ stack: true }),
winston.format.splat(),
winston.format((info) => redact(info))(),
// redactErrors(),
);
const transports = [
@ -78,16 +79,24 @@ if (
zippedArchive: true,
maxSize: '20m',
maxFiles: '14d',
format: winston.format.combine(fileFormat, deepObjectFormat),
format: winston.format.combine(fileFormat, debugTraverse),
}),
);
}
const consoleFormat = winston.format.combine(
redactFormat(),
winston.format.colorize({ all: true }),
winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }),
winston.format((info) => redact(info))(),
winston.format.printf((info) => `${info.timestamp} ${info.level}: ${info.message}`),
// redactErrors(),
winston.format.printf((info) => {
const message = `${info.timestamp} ${info.level}: ${info.message}`;
if (info.level.includes('error')) {
return redactMessage(message);
}
return message;
}),
);
if (
@ -97,7 +106,7 @@ if (
transports.push(
new winston.transports.Console({
level: 'debug',
format: winston.format.combine(consoleFormat, deepObjectFormat),
format: winston.format.combine(consoleFormat, debugTraverse),
}),
);
} else {
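
In practice, callers keep the same two-argument pattern used throughout this commit: a short label plus a metadata object. Debug entries go through `debugTraverse`, error entries through the redaction path. A hedged illustration:

```js
const { logger } = require('~/config');

// Debug-level metadata is condensed by debugTraverse: leaf strings are
// truncated to ~100 characters and arrays are summarized rather than dumped.
logger.debug('[GoogleClient] buildMessages', {
  model: 'gemini-pro-vision',
  promptPrefix: 'x'.repeat(500), // logged truncated, not all 500 characters
  messages: [{ text: 'describe this image' }], // logged as a condensed array
});

// Error-level messages additionally pass through redactFormat/redactMessage,
// so secrets matching the sensitive patterns never reach the transports.
logger.error('Request failed with api-key: abc123'); // logged as 'api-key: [REDACTED]'
```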

View file

@ -24,7 +24,7 @@ const getFiles = async (filter) => {
/**
* Creates a new file with a TTL of 1 hour.
* @param {Object} data - The file data to be created, must contain file_id.
* @param {MongoFile} data - The file data to be created, must contain file_id.
* @returns {Promise<MongoFile>} A promise that resolves to the created file document.
*/
const createFile = async (data) => {
@ -40,7 +40,7 @@ const createFile = async (data) => {
/**
* Updates a file identified by file_id with new data and removes the TTL.
* @param {Object} data - The data to update, must contain file_id.
* @param {MongoFile} data - The data to update, must contain file_id.
* @returns {Promise<MongoFile>} A promise that resolves to the updated file document.
*/
const updateFile = async (data) => {
@ -54,7 +54,7 @@ const updateFile = async (data) => {
/**
* Increments the usage of a file identified by file_id.
* @param {Object} data - The data to update, must contain file_id and the increment value for usage.
* @param {MongoFile} data - The data to update, must contain file_id and the increment value for usage.
* @returns {Promise<MongoFile>} A promise that resolves to the updated file document.
*/
const updateFileUsage = async (data) => {

View file

@ -39,7 +39,7 @@ transactionSchema.statics.create = async function (transactionData) {
{ user: transaction.user },
{ $inc: { tokenCredits: transaction.tokenValue } },
{ upsert: true, new: true },
);
).lean();
};
module.exports = mongoose.model('Transaction', transactionSchema);

View file

@ -43,46 +43,51 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
}
};
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
const getAbortData = () => ({
sender,
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
let getText;
try {
const { client } = await initializeClient({ req, res, endpointOption });
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text: partialText }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > saveDelay) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: client.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
user,
});
}
if (saveDelay < 500) {
saveDelay = 500;
}
},
});
getText = getPartialText;
const getAbortData = () => ({
sender,
conversationId,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const messageOptions = {
user,
parentMessageId,
@ -134,7 +139,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
});
}
} catch (error) {
const partialText = getPartialText();
const partialText = getText && getText();
handleAbortError(res, req, error, {
partialText,
conversationId,
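
Per the commit message, the point of this change is ordering: the progress callback reads `client.modelOptions.model`, so it can only be created after `initializeClient` resolves, and `getText` is hoisted so the catch block can still recover partial text. A condensed restatement (helpers and variables as already defined in this controller):

```js
const AskController = async (req, res, next, initializeClient, addTitle) => {
  let getText;
  try {
    const { client } = await initializeClient({ req, res, endpointOption });

    // Created only after the client exists, so intermediate saves record the
    // model actually in use (including any attachment-driven auto-switch).
    const { onProgress: progressCallback, getPartialText } = createOnProgress({
      onProgress: ({ text: partialText }) =>
        saveMessage({ text: partialText, model: client.modelOptions.model, unfinished: true }),
    });
    getText = getPartialText;

    // ... client.sendMessage(text, { onProgress: progressCallback, ... }) ...
  } catch (error) {
    const partialText = getText && getText(); // safe even if init threw before getText was set
    // handleAbortError(res, req, error, { partialText, ... });
  }
};
```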

View file

@ -2,6 +2,7 @@ const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/uti
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
const clearPendingReq = require('~/cache/clearPendingReq');
const abortControllers = require('./abortControllers');
const { redactMessage } = require('~/config/parsers');
const spendTokens = require('~/models/spendTokens');
const { logger } = require('~/config');
@ -92,7 +93,7 @@ const handleAbortError = async (res, req, error, data) => {
messageId,
conversationId,
parentMessageId,
text: error.message,
text: redactMessage(error.message),
shouldSaveMessage: true,
user: req.user.id,
};

View file

@ -1,9 +1,10 @@
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { EModelEndpoint } = require('librechat-data-provider');
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
const {
getOpenAIModels,
getChatGPTBrowserModels,
getGoogleModels,
getAnthropicModels,
getChatGPTBrowserModels,
} = require('~/server/services/ModelService');
const fitlerAssistantModels = (str) => {
@ -11,6 +12,7 @@ const fitlerAssistantModels = (str) => {
};
async function loadDefaultModels() {
const google = getGoogleModels();
const openAI = await getOpenAIModels();
const anthropic = getAnthropicModels();
const chatGPTBrowser = getChatGPTBrowserModels();
@ -19,13 +21,13 @@ async function loadDefaultModels() {
return {
[EModelEndpoint.openAI]: openAI,
[EModelEndpoint.google]: google,
[EModelEndpoint.anthropic]: anthropic,
[EModelEndpoint.gptPlugins]: gptPlugins,
[EModelEndpoint.azureOpenAI]: azureOpenAI,
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
[EModelEndpoint.google]: defaultModels[EModelEndpoint.google],
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
[EModelEndpoint.gptPlugins]: gptPlugins,
[EModelEndpoint.anthropic]: anthropic,
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
};
}

View file

@ -1,7 +1,13 @@
const fs = require('fs');
const path = require('path');
const { EModelEndpoint } = require('librechat-data-provider');
const { updateFile } = require('~/models');
/**
* Encodes an image file to base64.
* @param {string} imagePath - The path to the image file.
* @returns {Promise<string>} A promise that resolves with the base64 encoded image data.
*/
function encodeImage(imagePath) {
return new Promise((resolve, reject) => {
fs.readFile(imagePath, (err, data) => {
@ -14,6 +20,12 @@ function encodeImage(imagePath) {
});
}
/**
* Updates the file and encodes the image.
* @param {Object} req - The request object.
* @param {Object} file - The file object.
* @returns {Promise<[MongoFile, string]>} - A promise that resolves to an array of results from updateFile and encodeImage.
*/
async function updateAndEncode(req, file) {
const { publicPath, imageOutput } = req.app.locals.config;
const userPath = path.join(imageOutput, req.user.id);
@ -29,7 +41,14 @@ async function updateAndEncode(req, file) {
return await Promise.all(promises);
}
async function encodeAndFormat(req, files) {
/**
* Encodes and formats the given files.
* @param {Express.Request} req - The request object.
* @param {Array<MongoFile>} files - The array of files to encode and format.
* @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image.
* @returns {Promise<Object>} - A promise that resolves to the result object containing the encoded images and file details.
*/
async function encodeAndFormat(req, files, endpoint) {
const promises = [];
for (let file of files) {
promises.push(updateAndEncode(req, file));
@ -46,13 +65,19 @@ async function encodeAndFormat(req, files) {
};
for (const [file, base64] of encodedImages) {
result.image_urls.push({
const imagePart = {
type: 'image_url',
image_url: {
url: `data:image/webp;base64,${base64}`,
detail,
},
});
};
if (endpoint && endpoint === EModelEndpoint.google) {
imagePart.image_url = imagePart.image_url.url;
}
result.image_urls.push(imagePart);
result.files.push({
file_id: file.file_id,
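
The new `endpoint` parameter exists because Google and OpenAI expect different image part shapes. A hedged illustration of the two outputs (the `detail` value is whatever the caller computes; 'auto' is a placeholder):

```js
const base64 = '<webp image data>';

// OpenAI-style part: image_url is an object carrying a data URL plus detail.
const openAIImagePart = {
  type: 'image_url',
  image_url: { url: `data:image/webp;base64,${base64}`, detail: 'auto' },
};

// Google (EModelEndpoint.google): the same part, but image_url is flattened
// to the bare data-URL string, per the branch above.
const googleImagePart = {
  type: 'image_url',
  image_url: `data:image/webp;base64,${base64}`,
};
```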

View file

@ -15,8 +15,14 @@ const modelsCache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'models' });
const { OPENROUTER_API_KEY, OPENAI_REVERSE_PROXY, CHATGPT_MODELS, ANTHROPIC_MODELS, PROXY } =
process.env ?? {};
const {
OPENROUTER_API_KEY,
OPENAI_REVERSE_PROXY,
CHATGPT_MODELS,
ANTHROPIC_MODELS,
GOOGLE_MODELS,
PROXY,
} = process.env ?? {};
const fetchOpenAIModels = async (opts = { azure: false, plugins: false }, _models = []) => {
let models = _models.slice() ?? [];
@ -126,8 +132,18 @@ const getAnthropicModels = () => {
return models;
};
const getGoogleModels = () => {
let models = defaultModels[EModelEndpoint.google];
if (GOOGLE_MODELS) {
models = String(GOOGLE_MODELS).split(',');
}
return models;
};
module.exports = {
getOpenAIModels,
getChatGPTBrowserModels,
getAnthropicModels,
getGoogleModels,
};
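
A standalone sketch of the `GOOGLE_MODELS` override added above (the real function reads the variable once at module load; parsing is a plain comma split with no trimming, so avoid spaces, as the docs change later in this commit also notes):

```js
const defaultGoogleModels = ['gemini-pro', 'gemini-pro-vision', 'chat-bison']; // abbreviated

const getGoogleModels = (envValue = process.env.GOOGLE_MODELS) =>
  envValue ? String(envValue).split(',') : defaultGoogleModels;

console.log(getGoogleModels('gemini-pro,gemini-pro-vision'));
// -> ['gemini-pro', 'gemini-pro-vision']  (the first entry becomes the default model)
```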

View file

@ -172,7 +172,7 @@ export default function useSSE(submission: TSubmission | null, index = 0) {
const finalHandler = (data: TResData, submission: TSubmission) => {
const { requestMessage, responseMessage, conversation } = data;
const { messages, isRegenerate = false } = submission;
const { messages, conversation: submissionConvo, isRegenerate = false } = submission;
// update the messages
if (isRegenerate) {
@ -199,6 +199,11 @@ export default function useSSE(submission: TSubmission | null, index = 0) {
...conversation,
};
// Revert to previous model if the model was auto-switched by backend due to message attachments
if (conversation.model?.includes('vision') && !submissionConvo.model?.includes('vision')) {
update.model = submissionConvo?.model;
}
setStorage(update);
return update;
});
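
The revert check above can be restated framework-free: if the backend switched the conversation to a vision model because of attachments but the submission did not ask for one, the originally selected model is restored in client state (illustrative sketch, not the hook's code):

```js
function resolveDisplayedModel(conversationModel, submissionModel) {
  const autoSwitched =
    conversationModel?.includes('vision') && !submissionModel?.includes('vision');
  return autoSwitched ? submissionModel : conversationModel;
}

console.log(resolveDisplayedModel('gemini-pro-vision', 'gemini-pro')); // 'gemini-pro'
console.log(resolveDisplayedModel('gpt-4-vision-preview', 'gpt-4-vision-preview')); // unchanged
```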

View file

@ -31,7 +31,7 @@
# Features
- 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and 11-2023 updates
- 💬 Multimodal Chat:
- Upload and analyze images with GPT-4-Vision 📸
- Upload and analyze images with GPT-4 and Gemini Vision 📸
- More filetypes and Assistants API integration in Active Development 🚧
- 🌎 Multilingual UI:
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro, Русский

View file

@ -70,10 +70,6 @@ For Vertex AI, you need a Service Account JSON key file, with appropriate access
Instructions for both are given below.
Setting `GOOGLE_KEY=user_provided` in your .env file will configure both values to be provided from the client (or frontend) like so:
![image](https://github.com/danny-avila/LibreChat/assets/110412045/728cbc04-4180-45a8-848c-ae5de2b02996)
### Generative Language API (Gemini)
**60 Gemini requests/minute are currently free until early next year when it enters general availability.**
@ -85,21 +81,22 @@ To use Gemini models, you'll need an API key. If you don't already have one, cre
<p><a class="button button-primary" href="https://makersuite.google.com/app/apikey" target="_blank" rel="noopener noreferrer">Get an API key here</a></p>
Once you have your key, you can either provide it from the frontend by setting the following:
```bash
GOOGLE_KEY=user_provided
```
Or, provide the key in your .env file, which allows all users of your instance to use it.
Once you have your key, provide it in your .env file, which allows all users of your instance to use it.
```bash
GOOGLE_KEY=mY_SeCreT_w9347w8_kEY
```
> Notes:
> - As of 12/15/23, Gemini Pro Vision is not yet supported but is planned.
> - PaLM2 and Codey models cannot be accessed through the Generative Language API.
Or, you can make users provide it from the frontend by setting the following:
```bash
GOOGLE_KEY=user_provided
```
Note: PaLM2 and Codey models cannot be accessed through the Generative Language API, only through Vertex AI.
Setting `GOOGLE_KEY=user_provided` in your .env file will configure both the Vertex AI Service Account JSON key file and the Generative Language API key to be provided from the frontend like so:
![image](https://github.com/danny-avila/LibreChat/assets/110412045/728cbc04-4180-45a8-848c-ae5de2b02996)
### Vertex AI (PaLM 2 & Codey)
@ -132,14 +129,15 @@ You can usually get **$300 starting credit**, which makes this option free for 9
**Saving your JSON key file in the project directory which allows all users of your LibreChat instance to use it.**
Alternatively, Once you have your JSON key file, you can also provide it from the frontend on a user-basis by setting the following:
Alternatively, you can make users provide it from the frontend by setting the following:
```bash
# Note: this configures both the Vertex AI Service Account JSON key file
# and the Generative Language API key to be provided from the frontend.
GOOGLE_KEY=user_provided
```
> Notes:
> - As of 12/15/23, Gemini and Gemini Pro Vision are not yet supported through Vertex AI but are planned.
Note: Gemini models are available through Vertex AI, but LibreChat does not support that route yet.
## Azure OpenAI

View file

@ -199,6 +199,15 @@ GOOGLE_KEY=user_provided
GOOGLE_REVERSE_PROXY=
```
- Customize the available models, separated by commas, **without spaces**.
- The first model listed will be the default.
- Leave it blank or commented out to use internal settings (default: all listed below).
```bash
# all available models as of 12/16/23
GOOGLE_MODELS=gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k
```
### OpenAI
- To get your OpenAI API key, you need to:

package-lock.json (generated, 2 lines changed)
View file

@ -25558,7 +25558,7 @@
},
"packages/data-provider": {
"name": "librechat-data-provider",
"version": "0.3.1",
"version": "0.3.2",
"license": "ISC",
"dependencies": {
"axios": "^1.3.4",

View file

@ -25,6 +25,7 @@ export const defaultEndpoints: EModelEndpoint[] = [
export const defaultModels = {
[EModelEndpoint.google]: [
'gemini-pro',
'gemini-pro-vision',
'chat-bison',
'chat-bison-32k',
'codechat-bison',
@ -135,6 +136,7 @@ export const modularEndpoints = new Set<EModelEndpoint | string>([
export const supportsFiles = {
[EModelEndpoint.openAI]: true,
[EModelEndpoint.google]: true,
[EModelEndpoint.assistant]: true,
};
@ -144,7 +146,7 @@ export const supportsBalanceCheck = {
[EModelEndpoint.gptPlugins]: true,
};
export const visionModels = ['gpt-4-vision', 'llava-13b'];
export const visionModels = ['gpt-4-vision', 'llava-13b', 'gemini-pro-vision'];
export const eModelEndpointSchema = z.nativeEnum(EModelEndpoint);
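
The `validateVisionModel` helper referenced in the client diffs above is not shown in this commit; a plausible implementation (an assumption, not confirmed by the diff) is a substring match against this exported list:

```js
const { visionModels } = require('librechat-data-provider'); // assumes the list is re-exported by the package

const validateVisionModel = (model) =>
  visionModels.some((name) => model && model.includes(name));

console.log(validateVisionModel('gpt-4-vision-preview')); // true
console.log(validateVisionModel('gemini-pro'));           // false
```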