mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
feat: Google Gemini ❇️ (#1355)
* refactor: add gemini-pro to google Models list; use defaultModels for central model listing * refactor(SetKeyDialog): create useMultipleKeys hook to use for Azure, export `isJson` from utils, use EModelEndpoint * refactor(useUserKey): change variable names to make keyName setting clearer * refactor(FileUpload): allow passing container className string * feat(GoogleClient): Gemini support * refactor(GoogleClient): alternate stream speed for Gemini models * feat(Gemini): styling/settings configuration for Gemini * refactor(GoogleClient): subtract max response tokens from max context tokens if context is above 32k (I/O max is combined between the two) * refactor(tokens): correct google max token counts and subtract max response tokens when input/output counts are combined towards max context count * feat(google/initializeClient): handle both local and user_provided credentials and write tests * fix(GoogleClient): catch if credentials are undefined, handle if serviceKey is string or object correctly, handle no examples passed, throw error if not a Generative Language model and no service account JSON key is provided, throw error if it is a Generative model, but no Google API key was provided * refactor(loadAsyncEndpoints/google): activate Google endpoint if either the service key JSON file is provided in /api/data, or a GOOGLE_KEY is defined. * docs: updated Google configuration * fix(ci): Mock import of Service Account Key JSON file (auth.json) * Update apis_and_tokens.md * feat: increase max output tokens slider for gemini pro * refactor(GoogleSettings): handle max and default maxOutputTokens on model change * chore: add sensitive redact regex * docs: add warning about data privacy * Update apis_and_tokens.md
This commit is contained in:
parent
d259431316
commit
561ce8e86a
37 changed files with 702 additions and 219 deletions
|
|
@ -1,10 +1,16 @@
|
|||
const { google } = require('googleapis');
|
||||
const { Agent, ProxyAgent } = require('undici');
|
||||
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
|
||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
|
||||
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const { getResponseSender, EModelEndpoint, endpointSettings } = require('librechat-data-provider');
|
||||
const {
|
||||
getResponseSender,
|
||||
EModelEndpoint,
|
||||
endpointSettings,
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { formatMessage } = require('./prompts');
|
||||
const BaseClient = require('./BaseClient');
|
||||
|
|
@ -21,11 +27,24 @@ const settings = endpointSettings[EModelEndpoint.google];
|
|||
class GoogleClient extends BaseClient {
|
||||
constructor(credentials, options = {}) {
|
||||
super('apiKey', options);
|
||||
this.credentials = credentials;
|
||||
this.client_email = credentials.client_email;
|
||||
this.project_id = credentials.project_id;
|
||||
this.private_key = credentials.private_key;
|
||||
let creds = {};
|
||||
|
||||
if (typeof credentials === 'string') {
|
||||
creds = JSON.parse(credentials);
|
||||
} else if (credentials) {
|
||||
creds = credentials;
|
||||
}
|
||||
|
||||
const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
|
||||
this.serviceKey =
|
||||
serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {};
|
||||
this.client_email = this.serviceKey.client_email;
|
||||
this.private_key = this.serviceKey.private_key;
|
||||
this.project_id = this.serviceKey.project_id;
|
||||
this.access_token = null;
|
||||
|
||||
this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];
|
||||
|
||||
if (options.skipSetOptions) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -85,7 +104,7 @@ class GoogleClient extends BaseClient {
|
|||
this.options = options;
|
||||
}
|
||||
|
||||
this.options.examples = this.options.examples
|
||||
this.options.examples = (this.options.examples ?? [])
|
||||
.filter((ex) => ex)
|
||||
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');
|
||||
|
||||
|
|
@ -103,15 +122,24 @@ class GoogleClient extends BaseClient {
|
|||
// stop: modelOptions.stop // no stop method for now
|
||||
};
|
||||
|
||||
this.isChatModel = this.modelOptions.model.includes('chat');
|
||||
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
|
||||
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
|
||||
const { isGenerativeModel } = this;
|
||||
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
|
||||
const { isChatModel } = this;
|
||||
this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
|
||||
this.isTextModel =
|
||||
!isGenerativeModel && !isChatModel && /code|text/.test(this.modelOptions.model);
|
||||
const { isTextModel } = this;
|
||||
|
||||
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
|
||||
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
|
||||
// Earlier messages will be dropped until the prompt is within the limit.
|
||||
this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;
|
||||
|
||||
if (this.maxContextTokens > 32000) {
|
||||
this.maxContextTokens = this.maxContextTokens - this.maxResponseTokens;
|
||||
}
|
||||
|
||||
this.maxPromptTokens =
|
||||
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
|
||||
|
||||
|
|
@ -134,7 +162,7 @@ class GoogleClient extends BaseClient {
|
|||
this.userLabel = this.options.userLabel || 'User';
|
||||
this.modelLabel = this.options.modelLabel || 'Assistant';
|
||||
|
||||
if (isChatModel) {
|
||||
if (isChatModel || isGenerativeModel) {
|
||||
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
|
||||
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
|
||||
// without tripping the stop sequences, so I'm using "||>" instead.
|
||||
|
|
@ -189,6 +217,16 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
|
||||
buildMessages(messages = [], parentMessageId) {
|
||||
if (!this.isGenerativeModel && !this.project_id) {
|
||||
throw new Error(
|
||||
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
|
||||
);
|
||||
} else if (this.isGenerativeModel && (!this.apiKey || this.apiKey === 'user_provided')) {
|
||||
throw new Error(
|
||||
'[GoogleClient] an API Key is required for Gemini models (Generative Language API)',
|
||||
);
|
||||
}
|
||||
|
||||
if (this.isTextModel) {
|
||||
return this.buildMessagesPrompt(messages, parentMessageId);
|
||||
}
|
||||
|
|
@ -398,6 +436,16 @@ class GoogleClient extends BaseClient {
|
|||
return res.data;
|
||||
}
|
||||
|
||||
createLLM(clientOptions) {
|
||||
if (this.isGenerativeModel) {
|
||||
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
||||
}
|
||||
|
||||
return this.isTextModel
|
||||
? new GoogleVertexAI(clientOptions)
|
||||
: new ChatGoogleVertexAI(clientOptions);
|
||||
}
|
||||
|
||||
async getCompletion(_payload, options = {}) {
|
||||
const { onProgress, abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
|
|
@ -408,7 +456,7 @@ class GoogleClient extends BaseClient {
|
|||
let clientOptions = {
|
||||
authOptions: {
|
||||
credentials: {
|
||||
...this.credentials,
|
||||
...this.serviceKey,
|
||||
},
|
||||
projectId: this.project_id,
|
||||
},
|
||||
|
|
@ -436,9 +484,7 @@ class GoogleClient extends BaseClient {
|
|||
clientOptions.examples = examples;
|
||||
}
|
||||
|
||||
const model = this.isTextModel
|
||||
? new GoogleVertexAI(clientOptions)
|
||||
: new ChatGoogleVertexAI(clientOptions);
|
||||
const model = this.createLLM(clientOptions);
|
||||
|
||||
let reply = '';
|
||||
const messages = this.isTextModel
|
||||
|
|
@ -457,7 +503,9 @@ class GoogleClient extends BaseClient {
|
|||
});
|
||||
|
||||
for await (const chunk of stream) {
|
||||
await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
|
||||
await this.generateTextStream(chunk?.content ?? chunk, onProgress, {
|
||||
delay: this.isGenerativeModel ? 12 : 8,
|
||||
});
|
||||
reply += chunk?.content ?? chunk;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ const winston = require('winston');
|
|||
const traverse = require('traverse');
|
||||
const { klona } = require('klona/full');
|
||||
|
||||
const sensitiveKeys = [/^sk-\w+$/];
|
||||
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/];
|
||||
|
||||
/**
|
||||
* Determines if a given key string is sensitive.
|
||||
|
|
|
|||
|
|
@ -10,5 +10,6 @@ module.exports = {
|
|||
],
|
||||
moduleNameMapper: {
|
||||
'~/(.*)': '<rootDir>/$1',
|
||||
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@
|
|||
"@azure/search-documents": "^12.0.0",
|
||||
"@keyv/mongo": "^2.1.8",
|
||||
"@keyv/redis": "^2.8.0",
|
||||
"@langchain/google-genai": "^0.0.2",
|
||||
"axios": "^1.3.4",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"cheerio": "^1.0.0-rc.12",
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
|
|||
*/
|
||||
async function loadAsyncEndpoints() {
|
||||
let i = 0;
|
||||
let key, googleUserProvides;
|
||||
let serviceKey, googleUserProvides;
|
||||
try {
|
||||
key = require('~/data/auth.json');
|
||||
serviceKey = require('~/data/auth.json');
|
||||
} catch (e) {
|
||||
if (i === 0) {
|
||||
i++;
|
||||
|
|
@ -33,7 +33,7 @@ async function loadAsyncEndpoints() {
|
|||
}
|
||||
const plugins = transformToolsToMap(tools);
|
||||
|
||||
const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
|
||||
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
|
||||
|
||||
const gptPlugins =
|
||||
openAIApiKey || azureOpenAIApiKey
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
|
||||
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
|
||||
const {
|
||||
getOpenAIModels,
|
||||
getChatGPTBrowserModels,
|
||||
getAnthropicModels,
|
||||
} = require('~/server/services/ModelService');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
|
||||
|
||||
const fitlerAssistantModels = (str) => {
|
||||
return /gpt-4|gpt-3\\.5/i.test(str) && !/vision|instruct/i.test(str);
|
||||
|
|
@ -21,18 +21,7 @@ async function loadDefaultModels() {
|
|||
[EModelEndpoint.openAI]: openAI,
|
||||
[EModelEndpoint.azureOpenAI]: azureOpenAI,
|
||||
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
|
||||
[EModelEndpoint.google]: [
|
||||
'chat-bison',
|
||||
'chat-bison-32k',
|
||||
'codechat-bison',
|
||||
'codechat-bison-32k',
|
||||
'text-bison',
|
||||
'text-bison-32k',
|
||||
'text-unicorn',
|
||||
'code-gecko',
|
||||
'code-bison',
|
||||
'code-bison-32k',
|
||||
],
|
||||
[EModelEndpoint.google]: defaultModels[EModelEndpoint.google],
|
||||
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
|
||||
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
const { GoogleClient } = require('~/app');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
|
|
@ -11,14 +11,26 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
if (expiresAt && isUserProvided) {
|
||||
checkUserKeyExpiry(
|
||||
expiresAt,
|
||||
'Your Google key has expired. Please provide your JSON credentials again.',
|
||||
'Your Google Credentials have expired. Please provide your Service Account JSON Key or Generative Language API Key again.',
|
||||
);
|
||||
userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google });
|
||||
}
|
||||
|
||||
const apiKey = isUserProvided ? userKey : require('~/data/auth.json');
|
||||
let serviceKey = {};
|
||||
try {
|
||||
serviceKey = require('~/data/auth.json');
|
||||
} catch (e) {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
const client = new GoogleClient(apiKey, {
|
||||
const credentials = isUserProvided
|
||||
? userKey
|
||||
: {
|
||||
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
};
|
||||
|
||||
const client = new GoogleClient(credentials, {
|
||||
req,
|
||||
res,
|
||||
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
|
||||
|
|
@ -28,7 +40,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
|||
|
||||
return {
|
||||
client,
|
||||
apiKey,
|
||||
credentials,
|
||||
};
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
const initializeClient = require('./initializeClient');
|
||||
const { GoogleClient } = require('~/app');
|
||||
const { checkUserKeyExpiry, getUserKey } = require('../../UserService');
|
||||
|
||||
jest.mock('../../UserService', () => ({
|
||||
checkUserKeyExpiry: jest.fn().mockImplementation((expiresAt, errorMessage) => {
|
||||
if (new Date(expiresAt) < new Date()) {
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}),
|
||||
getUserKey: jest.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
describe('google/initializeClient', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test('should initialize GoogleClient with user-provided credentials', async () => {
|
||||
process.env.GOOGLE_KEY = 'user_provided';
|
||||
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
|
||||
process.env.PROXY = 'http://proxy';
|
||||
|
||||
const expiresAt = new Date(Date.now() + 60000).toISOString();
|
||||
|
||||
const req = {
|
||||
body: { key: expiresAt },
|
||||
user: { id: '123' },
|
||||
};
|
||||
const res = {};
|
||||
const endpointOption = { modelOptions: { model: 'default-model' } };
|
||||
|
||||
const { client, credentials } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
expect(getUserKey).toHaveBeenCalledWith({ userId: '123', name: 'google' });
|
||||
expect(client).toBeInstanceOf(GoogleClient);
|
||||
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
|
||||
expect(client.options.proxy).toBe('http://proxy');
|
||||
expect(credentials).toEqual({});
|
||||
});
|
||||
|
||||
test('should initialize GoogleClient with service key credentials', async () => {
|
||||
process.env.GOOGLE_KEY = 'service_key';
|
||||
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
|
||||
process.env.PROXY = 'http://proxy';
|
||||
|
||||
const req = {
|
||||
body: { key: null },
|
||||
user: { id: '123' },
|
||||
};
|
||||
const res = {};
|
||||
const endpointOption = { modelOptions: { model: 'default-model' } };
|
||||
|
||||
const { client, credentials } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
expect(client).toBeInstanceOf(GoogleClient);
|
||||
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
|
||||
expect(client.options.proxy).toBe('http://proxy');
|
||||
expect(credentials).toEqual({
|
||||
GOOGLE_SERVICE_KEY: {},
|
||||
GOOGLE_API_KEY: 'service_key',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle expired user-provided key', async () => {
|
||||
process.env.GOOGLE_KEY = 'user_provided';
|
||||
|
||||
const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
|
||||
const req = {
|
||||
body: { key: expiresAt },
|
||||
user: { id: '123' },
|
||||
};
|
||||
const res = {};
|
||||
const endpointOption = { modelOptions: { model: 'default-model' } };
|
||||
|
||||
checkUserKeyExpiry.mockImplementation((expiresAt, errorMessage) => {
|
||||
throw new Error(errorMessage);
|
||||
});
|
||||
|
||||
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
|
||||
/Your Google Credentials have expired/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
const Keyv = require('keyv');
|
||||
const axios = require('axios');
|
||||
const HttpsProxyAgent = require('https-proxy-agent');
|
||||
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
|
|
@ -117,15 +118,7 @@ const getChatGPTBrowserModels = () => {
|
|||
};
|
||||
|
||||
const getAnthropicModels = () => {
|
||||
let models = [
|
||||
'claude-2.1',
|
||||
'claude-2',
|
||||
'claude-1.2',
|
||||
'claude-1',
|
||||
'claude-1-100k',
|
||||
'claude-instant-1',
|
||||
'claude-instant-1-100k',
|
||||
];
|
||||
let models = defaultModels[EModelEndpoint.anthropic];
|
||||
if (ANTHROPIC_MODELS) {
|
||||
models = String(ANTHROPIC_MODELS).split(',');
|
||||
}
|
||||
|
|
|
|||
13
api/test/__mocks__/auth.mock.json
Normal file
13
api/test/__mocks__/auth.mock.json
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"type": "service_account",
|
||||
"project_id": "",
|
||||
"private_key_id": "",
|
||||
"private_key": "",
|
||||
"client_email": "",
|
||||
"client_id": "",
|
||||
"auth_uri": "",
|
||||
"token_uri": "",
|
||||
"auth_provider_x509_cert_url": "",
|
||||
"client_x509_cert_url": "",
|
||||
"universe_domain": ""
|
||||
}
|
||||
|
|
@ -56,11 +56,12 @@ const maxTokensMap = {
|
|||
'gpt-4-1106': 127995, // -5 from max
|
||||
},
|
||||
[EModelEndpoint.google]: {
|
||||
/* Max I/O is 32k combined, so -1000 to leave room for response */
|
||||
'text-bison-32k': 31000,
|
||||
'chat-bison-32k': 31000,
|
||||
'code-bison-32k': 31000,
|
||||
'codechat-bison-32k': 31000,
|
||||
/* Max I/O is combined so we subtract the amount from max response tokens for actual total */
|
||||
gemini: 32750, // -10 from max
|
||||
'text-bison-32k': 32758, // -10 from max
|
||||
'chat-bison-32k': 32758, // -10 from max
|
||||
'code-bison-32k': 32758, // -10 from max
|
||||
'codechat-bison-32k': 32758,
|
||||
/* Codey, -5 from max: 6144 */
|
||||
'code-': 6139,
|
||||
'codechat-': 6139,
|
||||
|
|
|
|||
|
|
@ -114,6 +114,9 @@ describe('getModelMaxTokens', () => {
|
|||
});
|
||||
|
||||
test('should return correct tokens for partial match - Google models', () => {
|
||||
expect(getModelMaxTokens('gemini-pro', EModelEndpoint.google)).toBe(
|
||||
maxTokensMap[EModelEndpoint.google]['gemini'],
|
||||
);
|
||||
expect(getModelMaxTokens('code-', EModelEndpoint.google)).toBe(
|
||||
maxTokensMap[EModelEndpoint.google]['code-'],
|
||||
);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue