diff --git a/.env.example b/.env.example index ca48394983..9b157f0c30 100644 --- a/.env.example +++ b/.env.example @@ -117,7 +117,7 @@ BINGAI_TOKEN=user_provided GOOGLE_KEY=user_provided # GOOGLE_REVERSE_PROXY= -# Gemini API +# Gemini API (AI Studio) # GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision # Vertex AI @@ -125,20 +125,24 @@ GOOGLE_KEY=user_provided # GOOGLE_TITLE_MODEL=gemini-pro -# Google Gemini Safety Settings -# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default. -# To use this restricted HarmBlockThreshold setting, you will need to either: +# Google Safety Settings +# NOTE: These settings apply to both Vertex AI and Gemini API (AI Studio) # -# (a) Get access through an allowlist via your Google account team -# (b) Switch your account type to monthly invoiced billing following this instruction: -# https://cloud.google.com/billing/docs/how-to/invoiced-billing +# For Vertex AI: +# To use the BLOCK_NONE setting, you need either: +# (a) Access through an allowlist via your Google account team, or +# (b) Switch to monthly invoiced billing: https://cloud.google.com/billing/docs/how-to/invoiced-billing +# +# For Gemini API (AI Studio): +# BLOCK_NONE is available by default, no special account requirements. +# +# Available options: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE # # GOOGLE_SAFETY_SEXUALLY_EXPLICIT=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_HATE_SPEECH=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_HARASSMENT=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH - #============# # OpenAI # #============# diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index e115ab1db8..225d73d935 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -626,11 +626,11 @@ class GoogleClient extends BaseClient { const { onProgress, abortController } = options; const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {}; - + let examples; - + let clientOptions = { ...parameters, maxRetries: 2 }; - + if (this.project_id) { clientOptions['authOptions'] = { credentials: { @@ -639,16 +639,16 @@ class GoogleClient extends BaseClient { projectId: this.project_id, }; } - + if (!parameters) { clientOptions = { ...clientOptions, ...this.modelOptions }; } - + if (this.isGenerativeModel && !this.project_id) { clientOptions.modelName = clientOptions.model; delete clientOptions.model; } - + if (_examples && _examples.length) { examples = _examples .map((ex) => { @@ -662,27 +662,26 @@ class GoogleClient extends BaseClient { }; }) .filter((ex) => ex); - + clientOptions.examples = examples; } - + const model = this.createLLM(clientOptions); - + let reply = ''; const messages = this.isTextModel ? _payload.trim() : _messages; - + if (!this.isVisionModel && context && messages?.length > 0) { messages.unshift(new SystemMessage(context)); } - + const modelName = clientOptions.modelName ?? clientOptions.model ?? ''; if (modelName?.includes('1.5') && !this.project_id) { - /** @type {GenerativeModel} */ const client = model; const requestOptions = { contents: _payload, }; - + if (this.options?.promptPrefix?.length) { requestOptions.systemInstruction = { parts: [ @@ -692,10 +691,9 @@ class GoogleClient extends BaseClient { ], }; } - - const safetySettings = _payload.safetySettings; - requestOptions.safetySettings = safetySettings; - + + requestOptions.safetySettings = _payload.safetySettings; + const delay = modelName.includes('flash') ? 8 : 14; const result = await client.generateContentStream(requestOptions); for await (const chunk of result.stream) { @@ -708,16 +706,15 @@ class GoogleClient extends BaseClient { } return reply; } - - const safetySettings = _payload.safetySettings; + const stream = await model.stream(messages, { signal: abortController.signal, timeout: 7000, - safetySettings: safetySettings, + safetySettings: _payload.safetySettings, }); - + let delay = this.options.streamRate || 8; - + if (!this.options.streamRate) { if (this.isGenerativeModel) { delay = 12; @@ -726,7 +723,7 @@ class GoogleClient extends BaseClient { delay = 5; } } - + for await (const chunk of stream) { const chunkText = chunk?.content ?? chunk; await this.generateTextStream(chunkText, onProgress, { @@ -734,7 +731,7 @@ class GoogleClient extends BaseClient { }); reply += chunkText; } - + return reply; } @@ -871,37 +868,33 @@ class GoogleClient extends BaseClient { } async sendCompletion(payload, opts = {}) { - const modelName = payload.parameters?.model; - - if (modelName && modelName.toLowerCase().includes('gemini')) { - const safetySettings = [ - { - category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - threshold: - process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', - }, - { - category: 'HARM_CATEGORY_HATE_SPEECH', - threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', - }, - { - category: 'HARM_CATEGORY_HARASSMENT', - threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', - }, - { - category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - threshold: - process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', - }, - ]; - - payload.safetySettings = safetySettings; - } - + payload.safetySettings = this.getSafetySettings(); + let reply = ''; reply = await this.getCompletion(payload, opts); return reply.trim(); } + + getSafetySettings() { + return [ + { + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + threshold: process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + }, + { + category: 'HARM_CATEGORY_HATE_SPEECH', + threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + }, + { + category: 'HARM_CATEGORY_HARASSMENT', + threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + }, + { + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', + threshold: process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + }, + ]; + } /* TO-DO: Handle tokens with Google tokenization NOTE: these are required */ static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {