feat: Google Gemini ❇️ (#1355)

* refactor: add gemini-pro to google Models list; use defaultModels for central model listing

* refactor(SetKeyDialog): create useMultipleKeys hook to use for Azure, export `isJson` from utils, use EModelEndpoint

* refactor(useUserKey): change variable names to make keyName setting more clear

* refactor(FileUpload): allow passing container className string

* feat(GoogleClient): Gemini support

* refactor(GoogleClient): alternate stream speed for Gemini models

* feat(Gemini): styling/settings configuration for Gemini

* refactor(GoogleClient): substract max response tokens from max context tokens if context is above 32k (I/O max is combined between the two)

* refactor(tokens): correct google max token counts and subtract max response tokens when input/output count are combined towards max context count

* feat(google/initializeClient): handle both local and user_provided credentials and write tests

* fix(GoogleClient): catch if credentials are undefined, handle if serviceKey is string or object correctly, handle no examples passed, throw error if not a Generative Language model and no service account JSON key is provided, throw error if it is a Generative m
odel, but not google API key was provided

* refactor(loadAsyncEndpoints/google): activate Google endpoint if either the service key JSON file is provided in /api/data, or a GOOGLE_KEY is defined.

* docs: updated Google configuration

* fix(ci): Mock import of Service Account Key JSON file (auth.json)

* Update apis_and_tokens.md

* feat: increase max output tokens slider for gemini pro

* refactor(GoogleSettings): handle max and default maxOutputTokens on model change

* chore: add sensitive redact regex

* docs: add warning about data privacy

* Update apis_and_tokens.md
This commit is contained in:
Danny Avila 2023-12-15 02:18:07 -05:00 committed by GitHub
parent d259431316
commit 561ce8e86a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 702 additions and 219 deletions

View file

@ -1,10 +1,16 @@
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { getResponseSender, EModelEndpoint, endpointSettings } = require('librechat-data-provider');
const {
getResponseSender,
EModelEndpoint,
endpointSettings,
AuthKeys,
} = require('librechat-data-provider');
const { getModelMaxTokens } = require('~/utils');
const { formatMessage } = require('./prompts');
const BaseClient = require('./BaseClient');
@ -21,11 +27,24 @@ const settings = endpointSettings[EModelEndpoint.google];
class GoogleClient extends BaseClient {
constructor(credentials, options = {}) {
super('apiKey', options);
this.credentials = credentials;
this.client_email = credentials.client_email;
this.project_id = credentials.project_id;
this.private_key = credentials.private_key;
let creds = {};
if (typeof credentials === 'string') {
creds = JSON.parse(credentials);
} else if (credentials) {
creds = credentials;
}
const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
this.serviceKey =
serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {};
this.client_email = this.serviceKey.client_email;
this.private_key = this.serviceKey.private_key;
this.project_id = this.serviceKey.project_id;
this.access_token = null;
this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];
if (options.skipSetOptions) {
return;
}
@ -85,7 +104,7 @@ class GoogleClient extends BaseClient {
this.options = options;
}
this.options.examples = this.options.examples
this.options.examples = (this.options.examples ?? [])
.filter((ex) => ex)
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');
@ -103,15 +122,24 @@ class GoogleClient extends BaseClient {
// stop: modelOptions.stop // no stop method for now
};
this.isChatModel = this.modelOptions.model.includes('chat');
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
const { isGenerativeModel } = this;
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
const { isChatModel } = this;
this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
this.isTextModel =
!isGenerativeModel && !isChatModel && /code|text/.test(this.modelOptions.model);
const { isTextModel } = this;
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;
if (this.maxContextTokens > 32000) {
this.maxContextTokens = this.maxContextTokens - this.maxResponseTokens;
}
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
@ -134,7 +162,7 @@ class GoogleClient extends BaseClient {
this.userLabel = this.options.userLabel || 'User';
this.modelLabel = this.options.modelLabel || 'Assistant';
if (isChatModel) {
if (isChatModel || isGenerativeModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
@ -189,6 +217,16 @@ class GoogleClient extends BaseClient {
}
buildMessages(messages = [], parentMessageId) {
if (!this.isGenerativeModel && !this.project_id) {
throw new Error(
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
);
} else if (this.isGenerativeModel && (!this.apiKey || this.apiKey === 'user_provided')) {
throw new Error(
'[GoogleClient] an API Key is required for Gemini models (Generative Language API)',
);
}
if (this.isTextModel) {
return this.buildMessagesPrompt(messages, parentMessageId);
}
@ -398,6 +436,16 @@ class GoogleClient extends BaseClient {
return res.data;
}
createLLM(clientOptions) {
if (this.isGenerativeModel) {
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}
return this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
}
async getCompletion(_payload, options = {}) {
const { onProgress, abortController } = options;
const { parameters, instances } = _payload;
@ -408,7 +456,7 @@ class GoogleClient extends BaseClient {
let clientOptions = {
authOptions: {
credentials: {
...this.credentials,
...this.serviceKey,
},
projectId: this.project_id,
},
@ -436,9 +484,7 @@ class GoogleClient extends BaseClient {
clientOptions.examples = examples;
}
const model = this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
const model = this.createLLM(clientOptions);
let reply = '';
const messages = this.isTextModel
@ -457,7 +503,9 @@ class GoogleClient extends BaseClient {
});
for await (const chunk of stream) {
await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
await this.generateTextStream(chunk?.content ?? chunk, onProgress, {
delay: this.isGenerativeModel ? 12 : 8,
});
reply += chunk?.content ?? chunk;
}