🛠️fix: apply safety settings to Gemini API (#3533)

- Introduce getSafetySettings method for centralized safety settings
- Apply safety settings uniformly in sendCompletion for all models, including Gemini
- Ensure consistent safety settings application in getCompletion
- Update .env.example to clarify safety settings usage for both APIs
This commit is contained in:
Oliver Faust 2024-08-04 20:08:57 +02:00 committed by GitHub
parent 62854c91d3
commit 433d8f832a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 56 additions and 59 deletions

View file

@ -117,7 +117,7 @@ BINGAI_TOKEN=user_provided
GOOGLE_KEY=user_provided GOOGLE_KEY=user_provided
# GOOGLE_REVERSE_PROXY= # GOOGLE_REVERSE_PROXY=
# Gemini API # Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision # GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
# Vertex AI # Vertex AI
@ -125,20 +125,24 @@ GOOGLE_KEY=user_provided
# GOOGLE_TITLE_MODEL=gemini-pro # GOOGLE_TITLE_MODEL=gemini-pro
# Google Gemini Safety Settings # Google Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default. # NOTE: These settings apply to both Vertex AI and Gemini API (AI Studio)
# To use this restricted HarmBlockThreshold setting, you will need to either:
# #
# (a) Get access through an allowlist via your Google account team # For Vertex AI:
# (b) Switch your account type to monthly invoiced billing following this instruction: # To use the BLOCK_NONE setting, you need either:
# https://cloud.google.com/billing/docs/how-to/invoiced-billing # (a) Access through an allowlist via your Google account team, or
# (b) Switch to monthly invoiced billing: https://cloud.google.com/billing/docs/how-to/invoiced-billing
#
# For Gemini API (AI Studio):
# BLOCK_NONE is available by default; no special account requirements apply.
#
# Available options: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
# #
# GOOGLE_SAFETY_SEXUALLY_EXPLICIT=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_SEXUALLY_EXPLICIT=BLOCK_ONLY_HIGH
# GOOGLE_SAFETY_HATE_SPEECH=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_HATE_SPEECH=BLOCK_ONLY_HIGH
# GOOGLE_SAFETY_HARASSMENT=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_HARASSMENT=BLOCK_ONLY_HIGH
# GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH # GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH
#============# #============#
# OpenAI # # OpenAI #
#============# #============#

View file

@ -626,11 +626,11 @@ class GoogleClient extends BaseClient {
const { onProgress, abortController } = options; const { onProgress, abortController } = options;
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {}; const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
let examples; let examples;
let clientOptions = { ...parameters, maxRetries: 2 }; let clientOptions = { ...parameters, maxRetries: 2 };
if (this.project_id) { if (this.project_id) {
clientOptions['authOptions'] = { clientOptions['authOptions'] = {
credentials: { credentials: {
@ -639,16 +639,16 @@ class GoogleClient extends BaseClient {
projectId: this.project_id, projectId: this.project_id,
}; };
} }
if (!parameters) { if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions }; clientOptions = { ...clientOptions, ...this.modelOptions };
} }
if (this.isGenerativeModel && !this.project_id) { if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model; clientOptions.modelName = clientOptions.model;
delete clientOptions.model; delete clientOptions.model;
} }
if (_examples && _examples.length) { if (_examples && _examples.length) {
examples = _examples examples = _examples
.map((ex) => { .map((ex) => {
@ -662,27 +662,26 @@ class GoogleClient extends BaseClient {
}; };
}) })
.filter((ex) => ex); .filter((ex) => ex);
clientOptions.examples = examples; clientOptions.examples = examples;
} }
const model = this.createLLM(clientOptions); const model = this.createLLM(clientOptions);
let reply = ''; let reply = '';
const messages = this.isTextModel ? _payload.trim() : _messages; const messages = this.isTextModel ? _payload.trim() : _messages;
if (!this.isVisionModel && context && messages?.length > 0) { if (!this.isVisionModel && context && messages?.length > 0) {
messages.unshift(new SystemMessage(context)); messages.unshift(new SystemMessage(context));
} }
const modelName = clientOptions.modelName ?? clientOptions.model ?? ''; const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) { if (modelName?.includes('1.5') && !this.project_id) {
/** @type {GenerativeModel} */
const client = model; const client = model;
const requestOptions = { const requestOptions = {
contents: _payload, contents: _payload,
}; };
if (this.options?.promptPrefix?.length) { if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = { requestOptions.systemInstruction = {
parts: [ parts: [
@ -692,10 +691,9 @@ class GoogleClient extends BaseClient {
], ],
}; };
} }
const safetySettings = _payload.safetySettings; requestOptions.safetySettings = _payload.safetySettings;
requestOptions.safetySettings = safetySettings;
const delay = modelName.includes('flash') ? 8 : 14; const delay = modelName.includes('flash') ? 8 : 14;
const result = await client.generateContentStream(requestOptions); const result = await client.generateContentStream(requestOptions);
for await (const chunk of result.stream) { for await (const chunk of result.stream) {
@ -708,16 +706,15 @@ class GoogleClient extends BaseClient {
} }
return reply; return reply;
} }
const safetySettings = _payload.safetySettings;
const stream = await model.stream(messages, { const stream = await model.stream(messages, {
signal: abortController.signal, signal: abortController.signal,
timeout: 7000, timeout: 7000,
safetySettings: safetySettings, safetySettings: _payload.safetySettings,
}); });
let delay = this.options.streamRate || 8; let delay = this.options.streamRate || 8;
if (!this.options.streamRate) { if (!this.options.streamRate) {
if (this.isGenerativeModel) { if (this.isGenerativeModel) {
delay = 12; delay = 12;
@ -726,7 +723,7 @@ class GoogleClient extends BaseClient {
delay = 5; delay = 5;
} }
} }
for await (const chunk of stream) { for await (const chunk of stream) {
const chunkText = chunk?.content ?? chunk; const chunkText = chunk?.content ?? chunk;
await this.generateTextStream(chunkText, onProgress, { await this.generateTextStream(chunkText, onProgress, {
@ -734,7 +731,7 @@ class GoogleClient extends BaseClient {
}); });
reply += chunkText; reply += chunkText;
} }
return reply; return reply;
} }
@ -871,37 +868,33 @@ class GoogleClient extends BaseClient {
} }
async sendCompletion(payload, opts = {}) { async sendCompletion(payload, opts = {}) {
const modelName = payload.parameters?.model; payload.safetySettings = this.getSafetySettings();
if (modelName && modelName.toLowerCase().includes('gemini')) {
const safetySettings = [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold:
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold:
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
];
payload.safetySettings = safetySettings;
}
let reply = ''; let reply = '';
reply = await this.getCompletion(payload, opts); reply = await this.getCompletion(payload, opts);
return reply.trim(); return reply.trim();
} }
getSafetySettings() {
return [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold: process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold: process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
];
}
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */ /* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {