📸 feat: Gemini vision, Improved Logs and Multi-modal Handling (#1368)

* feat: add GOOGLE_MODELS env var
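
A minimal sketch of consuming such a comma-separated model list; the parsing shown here is illustrative, not the project's actual loader:

```js
// Hypothetical .env entry:
// GOOGLE_MODELS=gemini-pro,gemini-pro-vision,chat-bison,text-bison

// Parse the list into an array of model names, dropping empty entries.
const googleModels = (process.env.GOOGLE_MODELS ?? '')
  .split(',')
  .map((name) => name.trim())
  .filter(Boolean);
```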

* feat: add gemini vision support

* refactor(GoogleClient): adjust clientOptions handling depending on model

* fix(logger): fix redaction logic and redact errors only
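
A minimal sketch of "redact errors only" as a winston custom format; the key patterns and names are assumptions, not the logger's actual config:

```js
const { format } = require('winston');

// Illustrative patterns for sensitive keys.
const SENSITIVE_KEYS = /api[-_]?key|token|password|secret|credential/i;

// Mask sensitive fields, but only on error-level entries.
const redactErrorsOnly = format((info) => {
  if (info.level !== 'error') {
    return info;
  }
  for (const key of Object.keys(info)) {
    if (SENSITIVE_KEYS.test(key)) {
      info[key] = '[REDACTED]';
    }
  }
  return info;
});

// usage: format.combine(redactErrorsOnly(), format.json())
```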

* fix(GoogleClient): do not allow non-multimodal messages when gemini-pro-vision is selected

* refactor(OpenAIClient): use `isVisionModel` client property to avoid calling validateVisionModel multiple times
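
The caching pattern, sketched (class shape simplified; the import path matches the diff below): validate once when options are set, then read the boolean everywhere else:

```js
const { validateVisionModel } = require('~/server/services/Files/images');

class ClientSketch {
  setOptions(options) {
    this.modelOptions = { ...options.modelOptions };
    // Validated once here; request handling checks this.isVisionModel
    // directly instead of re-running validateVisionModel per call.
    this.isVisionModel = validateVisionModel(this.modelOptions.model);
  }
}
```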

* refactor: better debug logging by correctly traversing objects, redacting sensitive info, and logging condensed versions of long values
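
A minimal sketch of the traverse-redact-condense idea; the helper name and 100-character threshold are illustrative:

```js
function condenseForLog(value, maxLength = 100) {
  if (typeof value === 'string') {
    // Long values (base64 images, prompts) are logged as a short preview.
    return value.length > maxLength ? `${value.slice(0, maxLength)}... [truncated]` : value;
  }
  if (Array.isArray(value)) {
    return value.map((item) => condenseForLog(item, maxLength));
  }
  if (value !== null && typeof value === 'object') {
    return Object.fromEntries(
      Object.entries(value).map(([key, val]) =>
        /key|token|secret|password/i.test(key)
          ? [key, '[REDACTED]']
          : [key, condenseForLog(val, maxLength)],
      ),
    );
  }
  return value;
}
```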

* refactor(GoogleClient): allow response errors to be thrown and caught above the client handling so the user receives a meaningful error message; also debug-log orderedMessages, parentMessageId, and the buildMessages result
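
With the client-level try/catch removed (see the sendCompletion hunk at the bottom), errors surface to the caller; a sketch of the catch one layer up, with assumed helper names (sendMessage, sendError, logger):

```js
const ask = async ({ client, text, messageOptions, res }) => {
  try {
    const response = await client.sendMessage(text, messageOptions);
    sendMessage(res, { final: true, responseMessage: response });
  } catch (error) {
    logger.error('[AskController]', error);
    // Forward the provider's actual message instead of a generic failure.
    sendError(res, { text: error.message });
  }
};
```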

* refactor(AskController): use the model from client.modelOptions.model when saving intermediate messages, which requires initializing the progress callback after the client
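
A sketch of that ordering (helper names are assumptions): initialize the client first, so intermediate saves use the model it actually resolved:

```js
const ask = async ({ req, res, endpointOption }) => {
  const { conversationId, responseMessageId } = req.body; // illustrative

  // Client first: modelOptions.model may differ from the request after an
  // auto-switch (e.g. gemini-pro-vision when attachments are present).
  const { client } = await initializeClient({ req, res, endpointOption });

  const { onProgress } = createOnProgress({
    onProgress: ({ text }) => {
      saveMessage({
        messageId: responseMessageId,
        conversationId,
        text,
        model: client.modelOptions.model,
        unfinished: true,
      });
    },
  });
  // ... send the message with onProgress ...
};
```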

* feat(useSSE): revert to the previous model if the model was auto-switched by the backend due to message attachments
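
A sketch of the revert inside the hook's final handler; the submission/conversation field shapes are assumptions:

```js
const finalHandler = (data, submission) => {
  const update = { ...data.conversation };
  const previousModel = submission?.conversation?.model;
  // If the backend auto-switched due to attachments (e.g. to
  // gemini-pro-vision), restore the model the user had selected.
  if (previousModel && update.model !== previousModel) {
    update.model = previousModel;
  }
  setConversation((prev) => ({ ...prev, ...update })); // setter from hook scope
};
```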

* docs: update Google documentation with notes about Gemini Pro Vision

* fix: do not initialize Redis without USE_REDIS; increase max listeners to 20
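
A sketch of the gating, assuming an ioredis client; which emitter receives the listener bump is an implementation detail:

```js
const Redis = require('ioredis');

let redisClient = null;
if (process.env.USE_REDIS === 'true') {
  redisClient = new Redis(process.env.REDIS_URI);
  // Node's default cap of 10 listeners can be exceeded by concurrent
  // streams and trigger MaxListenersExceededWarning; raise it to 20.
  redisClient.setMaxListeners(20);
}

module.exports = redisClient;
```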
Danny Avila, 2023-12-16 20:45:27 -05:00, committed via GitHub
parent 676f133545
commit 0c326797dd
21 changed files with 356 additions and 210 deletions

@@ -4,6 +4,7 @@ const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
getResponseSender,
@@ -122,9 +123,18 @@ class GoogleClient extends BaseClient {
// stop: modelOptions.stop // no stop method for now
};
if (this.options.attachments) {
this.modelOptions.model = 'gemini-pro-vision';
}
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
this.isVisionModel = validateVisionModel(this.modelOptions.model);
const { isGenerativeModel } = this;
if (this.isVisionModel && !this.options.attachments) {
this.modelOptions.model = 'gemini-pro';
this.isVisionModel = false;
}
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
const { isChatModel } = this;
this.isTextModel =
@@ -216,7 +226,34 @@ class GoogleClient extends BaseClient {
})).bind(this);
}
buildMessages(messages = [], parentMessageId) {
async buildVisionMessages(messages = [], parentMessageId) {
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
const attachments = await this.options.attachments;
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments.filter((file) => file.type.includes('image')),
EModelEndpoint.google,
);
const latestMessage = { ...messages[messages.length - 1] };
latestMessage.image_urls = image_urls;
this.options.attachments = files;
latestMessage.text = prompt;
const payload = {
instances: [
{
messages: [new HumanMessage(formatMessage({ message: latestMessage }))],
},
],
parameters: this.modelOptions,
};
return { prompt: payload };
}
async buildMessages(messages = [], parentMessageId) {
if (!this.isGenerativeModel && !this.project_id) {
throw new Error(
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
@@ -227,17 +264,24 @@ class GoogleClient extends BaseClient {
);
}
if (this.options.attachments) {
return this.buildVisionMessages(messages, parentMessageId);
}
if (this.isTextModel) {
return this.buildMessagesPrompt(messages, parentMessageId);
}
const formattedMessages = messages.map(this.formatMessages());
let payload = {
instances: [
{
messages: formattedMessages,
messages: messages
.map(this.formatMessages())
.map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
.map((message) => formatMessage({ message, langChain: true })),
},
],
parameters: this.options.modelOptions,
parameters: this.modelOptions,
};
if (this.options.promptPrefix) {
@@ -248,9 +292,7 @@ class GoogleClient extends BaseClient {
payload.instances[0].examples = this.options.examples;
}
if (this.options.debug) {
logger.debug('GoogleClient buildMessages', payload);
}
logger.debug('[GoogleClient] buildMessages', payload);
return { prompt: payload };
}
@@ -260,12 +302,11 @@ class GoogleClient extends BaseClient {
messages,
parentMessageId,
});
if (this.options.debug) {
logger.debug('GoogleClient: orderedMessages, parentMessageId', {
orderedMessages,
parentMessageId,
});
}
logger.debug('[GoogleClient]', {
orderedMessages,
parentMessageId,
});
const formattedMessages = orderedMessages.map((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
@@ -394,7 +435,7 @@ class GoogleClient extends BaseClient {
context.shift();
}
let prompt = `${promptBody}${promptSuffix}`;
let prompt = `${promptBody}${promptSuffix}`.trim();
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
@@ -453,20 +494,26 @@ class GoogleClient extends BaseClient {
let examples;
let clientOptions = {
authOptions: {
let clientOptions = { ...parameters, maxRetries: 2 };
if (!this.isGenerativeModel) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
},
projectId: this.project_id,
},
...parameters,
};
};
}
if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}
if (this.isGenerativeModel) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}
if (_examples && _examples.length) {
examples = _examples
.map((ex) => {
@@ -487,13 +534,9 @@ class GoogleClient extends BaseClient {
const model = this.createLLM(clientOptions);
let reply = '';
const messages = this.isTextModel
? _payload.trim()
: _messages
.map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
.map((message) => formatMessage({ message, langChain: true }));
const messages = this.isTextModel ? _payload.trim() : _messages;
if (context && messages?.length > 0) {
if (!this.isVisionModel && context && messages?.length > 0) {
messages.unshift(new SystemMessage(context));
}
@@ -526,14 +569,7 @@ class GoogleClient extends BaseClient {
async sendCompletion(payload, opts = {}) {
let reply = '';
try {
reply = await this.getCompletion(payload, opts);
if (this.options.debug) {
logger.debug('GoogleClient sendCompletion', { reply });
}
} catch (err) {
logger.error('failed to send completion to Google', err);
}
reply = await this.getCompletion(payload, opts);
return reply.trim();
}