📸 feat: Gemini vision, Improved Logs and Multi-modal Handling (#1368)
* feat: add GOOGLE_MODELS env var
* feat: add gemini vision support
* refactor(GoogleClient): adjust clientOptions handling depending on model
* fix(logger): fix redact logic and redact errors only
* fix(GoogleClient): do not allow non-multiModal messages when gemini-pro-vision is selected
* refactor(OpenAIClient): use `isVisionModel` client property to avoid calling validateVisionModel multiple times
* refactor: better debug logging by correctly traversing, redacting sensitive info, and logging condensed versions of long values
* refactor(GoogleClient): allow response errors to be thrown/caught above client handling so the user receives a meaningful error message; debug orderedMessages, parentMessageId, and buildMessages result
* refactor(AskController): use the model from client.modelOptions.model when saving intermediate messages, which requires the progress callback to be initialized after the client
* feat(useSSE): revert to the previous model if the model was auto-switched by the backend due to message attachments
* docs: update with Google updates, notes about Gemini Pro Vision
* fix: redis should not be initialized without USE_REDIS; increase max listeners to 20
Parent: 676f133545
Commit: 0c326797dd
21 changed files with 356 additions and 210 deletions
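
Note on the GOOGLE_MODELS env var referenced above (its diff is not among the hunks shown): LibreChat's model-list env vars are comma-separated, so a minimal sketch of how such a variable might be consumed looks like the following. The default list here is an illustrative assumption, not taken from this commit.

// Hypothetical sketch: parse a comma-separated GOOGLE_MODELS value,
// falling back to an assumed default list when the variable is unset.
const defaultModels = ['chat-bison', 'text-bison', 'gemini-pro', 'gemini-pro-vision'];
const models = process.env.GOOGLE_MODELS
  ? process.env.GOOGLE_MODELS.split(',').map((model) => model.trim())
  : defaultModels;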
api/app/clients/BaseClient.js

@@ -357,11 +357,11 @@ class BaseClient {
     const promptTokens = this.maxContextTokens - remainingContextTokens;

-    logger.debug('[BaseClient] Payload size:', payload.length);
     logger.debug('[BaseClient] tokenCountMap:', tokenCountMap);
+    logger.debug('[BaseClient]', {
+      promptTokens,
+      remainingContextTokens,
+      payloadSize: payload.length,
+      maxContextTokens: this.maxContextTokens,
+    });
@@ -414,7 +414,6 @@ class BaseClient {
     logger.debug('[BaseClient] tokenCountMap', tokenCountMap);
     if (tokenCountMap[userMessage.messageId]) {
       userMessage.tokenCount = tokenCountMap[userMessage.messageId];
-      logger.debug('[BaseClient] userMessage.tokenCount', userMessage.tokenCount);
       logger.debug('[BaseClient] userMessage', userMessage);
     }

api/app/clients/GoogleClient.js

@@ -4,6 +4,7 @@
 const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
 const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
 const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
 const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
+const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
   getResponseSender,
@@ -122,9 +123,18 @@ class GoogleClient extends BaseClient {
       // stop: modelOptions.stop // no stop method for now
     };

+    if (this.options.attachments) {
+      this.modelOptions.model = 'gemini-pro-vision';
+    }
+
     // TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
     this.isGenerativeModel = this.modelOptions.model.includes('gemini');
+    this.isVisionModel = validateVisionModel(this.modelOptions.model);
     const { isGenerativeModel } = this;
+    if (this.isVisionModel && !this.options.attachments) {
+      this.modelOptions.model = 'gemini-pro';
+      this.isVisionModel = false;
+    }
     this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
     const { isChatModel } = this;
     this.isTextModel =
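
The auto-switch rules added above reduce to a small pure function. A sketch for illustration only (the helper name is hypothetical; the real validateVisionModel lives in ~/server/services/Files/images):

// Hypothetical summary of the switching behavior in setOptions:
// attachments force the vision model; a vision model selected
// without attachments falls back to gemini-pro.
function resolveGoogleModel(model, hasAttachments) {
  if (hasAttachments) {
    return 'gemini-pro-vision';
  }
  if (model === 'gemini-pro-vision') {
    return 'gemini-pro';
  }
  return model;
}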
@@ -216,7 +226,34 @@ class GoogleClient extends BaseClient {
     })).bind(this);
   }

-  buildMessages(messages = [], parentMessageId) {
+  async buildVisionMessages(messages = [], parentMessageId) {
+    const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
+    const attachments = await this.options.attachments;
+    const { files, image_urls } = await encodeAndFormat(
+      this.options.req,
+      attachments.filter((file) => file.type.includes('image')),
+      EModelEndpoint.google,
+    );
+
+    const latestMessage = { ...messages[messages.length - 1] };
+
+    latestMessage.image_urls = image_urls;
+    this.options.attachments = files;
+
+    latestMessage.text = prompt;
+
+    const payload = {
+      instances: [
+        {
+          messages: [new HumanMessage(formatMessage({ message: latestMessage }))],
+        },
+      ],
+      parameters: this.modelOptions,
+    };
+    return { prompt: payload };
+  }
+
+  async buildMessages(messages = [], parentMessageId) {
     if (!this.isGenerativeModel && !this.project_id) {
       throw new Error(
         '[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
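
For context, buildVisionMessages folds the whole rendered prompt plus the encoded images into a single HumanMessage. The shape below is an assumed illustration of latestMessage after encodeAndFormat runs, not output captured from the code:

// Assumed shape only: the latest user turn after encoding attachments.
const latestMessage = {
  text: 'Human: What is in this image?\nAssistant:', // prompt from buildMessagesPrompt
  image_urls: [{ type: 'image_url', image_url: { url: 'data:image/jpeg;base64,...' } }],
};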
@@ -227,17 +264,24 @@ class GoogleClient extends BaseClient {
       );
     }

+    if (this.options.attachments) {
+      return this.buildVisionMessages(messages, parentMessageId);
+    }
+
     if (this.isTextModel) {
       return this.buildMessagesPrompt(messages, parentMessageId);
     }
-    const formattedMessages = messages.map(this.formatMessages());
-
     let payload = {
       instances: [
         {
-          messages: formattedMessages,
+          messages: messages
+            .map(this.formatMessages())
+            .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
+            .map((message) => formatMessage({ message, langChain: true })),
         },
       ],
-      parameters: this.options.modelOptions,
+      parameters: this.modelOptions,
     };

     if (this.options.promptPrefix) {
@@ -248,9 +292,7 @@ class GoogleClient extends BaseClient {
       payload.instances[0].examples = this.options.examples;
     }

-    if (this.options.debug) {
-      logger.debug('GoogleClient buildMessages', payload);
-    }
+    logger.debug('[GoogleClient] buildMessages', payload);

     return { prompt: payload };
   }
@@ -260,12 +302,11 @@ class GoogleClient extends BaseClient {
       messages,
       parentMessageId,
     });
-    if (this.options.debug) {
-      logger.debug('GoogleClient: orderedMessages, parentMessageId', {
-        orderedMessages,
-        parentMessageId,
-      });
-    }
+
+    logger.debug('[GoogleClient]', {
+      orderedMessages,
+      parentMessageId,
+    });

     const formattedMessages = orderedMessages.map((message) => ({
       author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
@@ -394,7 +435,7 @@ class GoogleClient extends BaseClient {
       context.shift();
     }

-    let prompt = `${promptBody}${promptSuffix}`;
+    let prompt = `${promptBody}${promptSuffix}`.trim();

     // Add 2 tokens for metadata after all messages have been counted.
     currentTokenCount += 2;
@@ -453,20 +494,26 @@ class GoogleClient extends BaseClient {

     let examples;

-    let clientOptions = {
-      authOptions: {
-        credentials: {
-          ...this.serviceKey,
-        },
-        projectId: this.project_id,
-      },
-      ...parameters,
-    };
+    let clientOptions = { ...parameters, maxRetries: 2 };
+
+    if (!this.isGenerativeModel) {
+      clientOptions['authOptions'] = {
+        credentials: {
+          ...this.serviceKey,
+        },
+        projectId: this.project_id,
+      };
+    }

     if (!parameters) {
       clientOptions = { ...clientOptions, ...this.modelOptions };
     }

+    if (this.isGenerativeModel) {
+      clientOptions.modelName = clientOptions.model;
+      delete clientOptions.model;
+    }
+
     if (_examples && _examples.length) {
       examples = _examples
         .map((ex) => {
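
The refactor above yields two possible clientOptions shapes depending on the model family; the option values below are illustrative assumptions, not values from the diff:

// Gemini ("Generative AI" API): no service-account auth; model is renamed to modelName.
// { temperature: 0.2, maxRetries: 2, modelName: 'gemini-pro' }
//
// PaLM 2 / Codey (Vertex AI): service-account credentials are attached.
// {
//   temperature: 0.2,
//   maxRetries: 2,
//   model: 'chat-bison',
//   authOptions: { credentials: { ...serviceKey }, projectId: 'my-project-id' },
// }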
@@ -487,13 +534,9 @@ class GoogleClient extends BaseClient {
     const model = this.createLLM(clientOptions);

     let reply = '';
-    const messages = this.isTextModel
-      ? _payload.trim()
-      : _messages
-        .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
-        .map((message) => formatMessage({ message, langChain: true }));
+    const messages = this.isTextModel ? _payload.trim() : _messages;

-    if (context && messages?.length > 0) {
+    if (!this.isVisionModel && context && messages?.length > 0) {
       messages.unshift(new SystemMessage(context));
     }
@@ -526,14 +569,7 @@ class GoogleClient extends BaseClient {

   async sendCompletion(payload, opts = {}) {
     let reply = '';
-    try {
-      reply = await this.getCompletion(payload, opts);
-      if (this.options.debug) {
-        logger.debug('GoogleClient sendCompletion', { reply });
-      }
-    } catch (err) {
-      logger.error('failed to send completion to Google', err);
-    }
+    reply = await this.getCompletion(payload, opts);
     return reply.trim();
   }
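
With the try/catch removed, a getCompletion failure now propagates out of sendCompletion, matching the commit note that response errors are thrown/caught above client handling so the user receives a meaningful message. A hedged sketch of a caller, not the actual controller code:

// Illustrative caller: the Google API error message now reaches this level.
try {
  const reply = await client.sendCompletion(payload, opts);
  // ...stream or persist the reply...
} catch (err) {
  // err.message carries the upstream error instead of being swallowed
  // inside the client; sendErrorToUser is a hypothetical handler.
  sendErrorToUser(err.message);
}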

api/app/clients/OpenAIClient.js
@@ -1,7 +1,7 @@
 const OpenAI = require('openai');
 const { HttpsProxyAgent } = require('https-proxy-agent');
-const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
+const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
 const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
 const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
@@ -76,11 +76,14 @@ class OpenAIClient extends BaseClient {
       };
     }

-    if (this.options.attachments && !validateVisionModel(this.modelOptions.model)) {
+    this.isVisionModel = validateVisionModel(this.modelOptions.model);
+
+    if (this.options.attachments && !this.isVisionModel) {
       this.modelOptions.model = 'gpt-4-vision-preview';
+      this.isVisionModel = true;
     }

-    if (validateVisionModel(this.modelOptions.model)) {
+    if (this.isVisionModel) {
       delete this.modelOptions.stop;
     }
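
Caching the check in this.isVisionModel means validateVisionModel runs once in setOptions instead of at every call site below. A minimal sketch of what such a validator could look like, assuming it matches against a list of known vision-capable models (the real implementation lives in ~/server/services/Files/images):

// Assumed sketch: report whether a model name is vision-capable.
const visionModels = ['gpt-4-vision-preview', 'gemini-pro-vision'];

function validateVisionModel(model) {
  if (!model) {
    return false;
  }
  return visionModels.some((visionModel) => model.includes(visionModel));
}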
@@ -152,7 +155,7 @@ class OpenAIClient extends BaseClient {

     this.setupTokens();

-    if (!this.modelOptions.stop && !validateVisionModel(this.modelOptions.model)) {
+    if (!this.modelOptions.stop && !this.isVisionModel) {
       const stopTokens = [this.startToken];
       if (this.endToken && this.endToken !== this.startToken) {
         stopTokens.push(this.endToken);
@@ -689,7 +692,7 @@ ${convo}
   }

   async recordTokenUsage({ promptTokens, completionTokens }) {
-    logger.debug('[OpenAIClient]', { promptTokens, completionTokens });
+    logger.debug('[OpenAIClient] recordTokenUsage:', { promptTokens, completionTokens });
     await spendTokens(
       {
         user: this.user,
@@ -757,7 +760,7 @@ ${convo}
       opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
     }

-    if (validateVisionModel(modelOptions.model)) {
+    if (this.isVisionModel) {
       modelOptions.max_tokens = 4000;
     }

api/app/clients/PluginsClient.js
@@ -180,7 +180,7 @@ class PluginsClient extends OpenAIClient {
       logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);

       if (errorMessage.length > 0) {
-        logger.debug('[PluginsClient] Caught error, input:', input);
+        logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
       }

       try {
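
Inlining JSON.stringify(input) produces one condensed log line, in keeping with the commit's goal of logging condensed versions of long values; how the two calls differ in practice depends on the logger's formatter (output shapes assumed):

// Passing the object as a second argument leaves serialization to the
// logger's formatter, which may expand or truncate it across lines:
logger.debug('[PluginsClient] Caught error, input:', input);

// Concatenating JSON.stringify guarantees a single compact line:
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));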