feat: Google Gemini ❇️ (#1355)

* refactor: add gemini-pro to google Models list; use defaultModels for central model listing

* refactor(SetKeyDialog): create useMultipleKeys hook to use for Azure, export `isJson` from utils, use EModelEndpoint

* refactor(useUserKey): change variable names to make keyName setting more clear

* refactor(FileUpload): allow passing container className string

* feat(GoogleClient): Gemini support

* refactor(GoogleClient): alternate stream speed for Gemini models

* feat(Gemini): styling/settings configuration for Gemini

* refactor(GoogleClient): subtract max response tokens from max context tokens if context is above 32k (I/O max is combined between the two)

* refactor(tokens): correct google max token counts and subtract max response tokens when input/output count are combined towards max context count

* feat(google/initializeClient): handle both local and user_provided credentials and write tests

* fix(GoogleClient): catch if credentials are undefined, handle if serviceKey is string or object correctly, handle no examples passed, throw error if not a Generative Language model and no service account JSON key is provided, throw error if it is a Generative model, but no Google API key was provided

* refactor(loadAsyncEndpoints/google): activate Google endpoint if either the service key JSON file is provided in /api/data, or a GOOGLE_KEY is defined.

* docs: updated Google configuration

* fix(ci): Mock import of Service Account Key JSON file (auth.json)

* Update apis_and_tokens.md

* feat: increase max output tokens slider for gemini pro

* refactor(GoogleSettings): handle max and default maxOutputTokens on model change

* chore: add sensitive redact regex

* docs: add warning about data privacy

* Update apis_and_tokens.md
This commit is contained in:
Danny Avila 2023-12-15 02:18:07 -05:00 committed by GitHub
parent d259431316
commit 561ce8e86a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 702 additions and 219 deletions

View file

@ -8,9 +8,9 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
*/
async function loadAsyncEndpoints() {
let i = 0;
let key, googleUserProvides;
let serviceKey, googleUserProvides;
try {
key = require('~/data/auth.json');
serviceKey = require('~/data/auth.json');
} catch (e) {
if (i === 0) {
i++;
@ -33,7 +33,7 @@ async function loadAsyncEndpoints() {
}
const plugins = transformToolsToMap(tools);
const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
const gptPlugins =
openAIApiKey || azureOpenAIApiKey

View file

@ -1,10 +1,10 @@
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
const {
getOpenAIModels,
getChatGPTBrowserModels,
getAnthropicModels,
} = require('~/server/services/ModelService');
const { EModelEndpoint } = require('librechat-data-provider');
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
const fitlerAssistantModels = (str) => {
return /gpt-4|gpt-3\\.5/i.test(str) && !/vision|instruct/i.test(str);
@ -21,18 +21,7 @@ async function loadDefaultModels() {
[EModelEndpoint.openAI]: openAI,
[EModelEndpoint.azureOpenAI]: azureOpenAI,
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
[EModelEndpoint.google]: [
'chat-bison',
'chat-bison-32k',
'codechat-bison',
'codechat-bison-32k',
'text-bison',
'text-bison-32k',
'text-unicorn',
'code-gecko',
'code-bison',
'code-bison-32k',
],
[EModelEndpoint.google]: defaultModels[EModelEndpoint.google],
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
[EModelEndpoint.gptPlugins]: gptPlugins,

View file

@ -1,5 +1,5 @@
const { GoogleClient } = require('~/app');
const { EModelEndpoint } = require('librechat-data-provider');
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const initializeClient = async ({ req, res, endpointOption }) => {
@ -11,14 +11,26 @@ const initializeClient = async ({ req, res, endpointOption }) => {
if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
expiresAt,
'Your Google key has expired. Please provide your JSON credentials again.',
'Your Google Credentials have expired. Please provide your Service Account JSON Key or Generative Language API Key again.',
);
userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google });
}
const apiKey = isUserProvided ? userKey : require('~/data/auth.json');
let serviceKey = {};
try {
serviceKey = require('~/data/auth.json');
} catch (e) {
// Do nothing
}
const client = new GoogleClient(apiKey, {
const credentials = isUserProvided
? userKey
: {
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
};
const client = new GoogleClient(credentials, {
req,
res,
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
@ -28,7 +40,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
return {
client,
apiKey,
credentials,
};
};

View file

@ -0,0 +1,84 @@
// Jest spec for the Google endpoint's `initializeClient`.
// `../../UserService` is mocked below so the tests control what
// `checkUserKeyExpiry` and `getUserKey` do.
const initializeClient = require('./initializeClient');
const { GoogleClient } = require('~/app');
const { checkUserKeyExpiry, getUserKey } = require('../../UserService');
jest.mock('../../UserService', () => ({
// Default expiry check: throws the supplied message when `expiresAt` is in the past.
checkUserKeyExpiry: jest.fn().mockImplementation((expiresAt, errorMessage) => {
if (new Date(expiresAt) < new Date()) {
throw new Error(errorMessage);
}
}),
// Stored user key is stubbed as an empty object; the first test asserts on this value.
getUserKey: jest.fn().mockImplementation(() => ({})),
}));
describe('google/initializeClient', () => {
// Clears recorded mock calls between tests.
// NOTE(review): per Jest docs `clearAllMocks` does not restore implementations, so the
// `mockImplementation` override in the last test would persist into any test added after it.
afterEach(() => {
jest.clearAllMocks();
});
// NOTE(review): each test assigns process.env values and never restores them — confirm
// that leaking these into other suites in the same worker is acceptable.
test('should initialize GoogleClient with user-provided credentials', async () => {
// Presumably GOOGLE_KEY === 'user_provided' selects the user-supplied key path — verify
// against initializeClient.
process.env.GOOGLE_KEY = 'user_provided';
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';
// Expiry timestamp 60 seconds in the future, so the default expiry mock does not throw.
const expiresAt = new Date(Date.now() + 60000).toISOString();
const req = {
body: { key: expiresAt },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client, credentials } = await initializeClient({ req, res, endpointOption });
expect(getUserKey).toHaveBeenCalledWith({ userId: '123', name: 'google' });
expect(client).toBeInstanceOf(GoogleClient);
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.options.proxy).toBe('http://proxy');
// Matches the `{}` returned by the mocked getUserKey above.
expect(credentials).toEqual({});
});
test('should initialize GoogleClient with service key credentials', async () => {
// Any value other than 'user_provided' is presumably treated as a server-side API key
// — verify against initializeClient.
process.env.GOOGLE_KEY = 'service_key';
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';
const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client, credentials } = await initializeClient({ req, res, endpointOption });
expect(client).toBeInstanceOf(GoogleClient);
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.options.proxy).toBe('http://proxy');
// NOTE(review): the empty GOOGLE_SERVICE_KEY assumes `~/data/auth.json` is absent or
// mocked as empty in the test environment — confirm.
expect(credentials).toEqual({
GOOGLE_SERVICE_KEY: {},
GOOGLE_API_KEY: 'service_key',
});
});
test('should handle expired user-provided key', async () => {
process.env.GOOGLE_KEY = 'user_provided';
const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
const req = {
body: { key: expiresAt },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
// Forces the expiry check to throw unconditionally for this test.
checkUserKeyExpiry.mockImplementation((expiresAt, errorMessage) => {
throw new Error(errorMessage);
});
// The rejection must carry the expiry message passed to checkUserKeyExpiry.
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/Your Google Credentials have expired/,
);
});
});

View file

@ -1,6 +1,7 @@
const Keyv = require('keyv');
const axios = require('axios');
const HttpsProxyAgent = require('https-proxy-agent');
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { extractBaseURL } = require('~/utils');
@ -117,15 +118,7 @@ const getChatGPTBrowserModels = () => {
};
const getAnthropicModels = () => {
let models = [
'claude-2.1',
'claude-2',
'claude-1.2',
'claude-1',
'claude-1-100k',
'claude-instant-1',
'claude-instant-1-100k',
];
let models = defaultModels[EModelEndpoint.anthropic];
if (ANTHROPIC_MODELS) {
models = String(ANTHROPIC_MODELS).split(',');
}