mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 00:40:14 +01:00
🛠️fix: apply safety settings to Gemini API (#3533)
- Introduce getSafetySettings method for centralized safety settings - Apply safety settings uniformly in sendCompletion for all models, including Gemini - Ensure consistent safety settings application in getCompletion - Update .env.example to clarify safety settings usage for both APIs
This commit is contained in:
parent
62854c91d3
commit
433d8f832a
2 changed files with 56 additions and 59 deletions
20
.env.example
20
.env.example
|
|
@ -117,7 +117,7 @@ BINGAI_TOKEN=user_provided
|
||||||
GOOGLE_KEY=user_provided
|
GOOGLE_KEY=user_provided
|
||||||
# GOOGLE_REVERSE_PROXY=
|
# GOOGLE_REVERSE_PROXY=
|
||||||
|
|
||||||
# Gemini API
|
# Gemini API (AI Studio)
|
||||||
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
||||||
|
|
||||||
# Vertex AI
|
# Vertex AI
|
||||||
|
|
@ -125,20 +125,24 @@ GOOGLE_KEY=user_provided
|
||||||
|
|
||||||
# GOOGLE_TITLE_MODEL=gemini-pro
|
# GOOGLE_TITLE_MODEL=gemini-pro
|
||||||
|
|
||||||
# Google Gemini Safety Settings
|
# Google Safety Settings
|
||||||
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
|
# NOTE: These settings apply to both Vertex AI and Gemini API (AI Studio)
|
||||||
# To use this restricted HarmBlockThreshold setting, you will need to either:
|
|
||||||
#
|
#
|
||||||
# (a) Get access through an allowlist via your Google account team
|
# For Vertex AI:
|
||||||
# (b) Switch your account type to monthly invoiced billing following this instruction:
|
# To use the BLOCK_NONE setting, you need either:
|
||||||
# https://cloud.google.com/billing/docs/how-to/invoiced-billing
|
# (a) Access through an allowlist via your Google account team, or
|
||||||
|
# (b) Switch to monthly invoiced billing: https://cloud.google.com/billing/docs/how-to/invoiced-billing
|
||||||
|
#
|
||||||
|
# For Gemini API (AI Studio):
|
||||||
|
# BLOCK_NONE is available by default, no special account requirements.
|
||||||
|
#
|
||||||
|
# Available options: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
|
||||||
#
|
#
|
||||||
# GOOGLE_SAFETY_SEXUALLY_EXPLICIT=BLOCK_ONLY_HIGH
|
# GOOGLE_SAFETY_SEXUALLY_EXPLICIT=BLOCK_ONLY_HIGH
|
||||||
# GOOGLE_SAFETY_HATE_SPEECH=BLOCK_ONLY_HIGH
|
# GOOGLE_SAFETY_HATE_SPEECH=BLOCK_ONLY_HIGH
|
||||||
# GOOGLE_SAFETY_HARASSMENT=BLOCK_ONLY_HIGH
|
# GOOGLE_SAFETY_HARASSMENT=BLOCK_ONLY_HIGH
|
||||||
# GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH
|
# GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH
|
||||||
|
|
||||||
|
|
||||||
#============#
|
#============#
|
||||||
# OpenAI #
|
# OpenAI #
|
||||||
#============#
|
#============#
|
||||||
|
|
|
||||||
|
|
@ -626,11 +626,11 @@ class GoogleClient extends BaseClient {
|
||||||
const { onProgress, abortController } = options;
|
const { onProgress, abortController } = options;
|
||||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||||
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
|
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
|
||||||
|
|
||||||
let examples;
|
let examples;
|
||||||
|
|
||||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||||
|
|
||||||
if (this.project_id) {
|
if (this.project_id) {
|
||||||
clientOptions['authOptions'] = {
|
clientOptions['authOptions'] = {
|
||||||
credentials: {
|
credentials: {
|
||||||
|
|
@ -639,16 +639,16 @@ class GoogleClient extends BaseClient {
|
||||||
projectId: this.project_id,
|
projectId: this.project_id,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!parameters) {
|
if (!parameters) {
|
||||||
clientOptions = { ...clientOptions, ...this.modelOptions };
|
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.isGenerativeModel && !this.project_id) {
|
if (this.isGenerativeModel && !this.project_id) {
|
||||||
clientOptions.modelName = clientOptions.model;
|
clientOptions.modelName = clientOptions.model;
|
||||||
delete clientOptions.model;
|
delete clientOptions.model;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_examples && _examples.length) {
|
if (_examples && _examples.length) {
|
||||||
examples = _examples
|
examples = _examples
|
||||||
.map((ex) => {
|
.map((ex) => {
|
||||||
|
|
@ -662,27 +662,26 @@ class GoogleClient extends BaseClient {
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
.filter((ex) => ex);
|
.filter((ex) => ex);
|
||||||
|
|
||||||
clientOptions.examples = examples;
|
clientOptions.examples = examples;
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = this.createLLM(clientOptions);
|
const model = this.createLLM(clientOptions);
|
||||||
|
|
||||||
let reply = '';
|
let reply = '';
|
||||||
const messages = this.isTextModel ? _payload.trim() : _messages;
|
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||||
|
|
||||||
if (!this.isVisionModel && context && messages?.length > 0) {
|
if (!this.isVisionModel && context && messages?.length > 0) {
|
||||||
messages.unshift(new SystemMessage(context));
|
messages.unshift(new SystemMessage(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||||
if (modelName?.includes('1.5') && !this.project_id) {
|
if (modelName?.includes('1.5') && !this.project_id) {
|
||||||
/** @type {GenerativeModel} */
|
|
||||||
const client = model;
|
const client = model;
|
||||||
const requestOptions = {
|
const requestOptions = {
|
||||||
contents: _payload,
|
contents: _payload,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (this.options?.promptPrefix?.length) {
|
if (this.options?.promptPrefix?.length) {
|
||||||
requestOptions.systemInstruction = {
|
requestOptions.systemInstruction = {
|
||||||
parts: [
|
parts: [
|
||||||
|
|
@ -692,10 +691,9 @@ class GoogleClient extends BaseClient {
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const safetySettings = _payload.safetySettings;
|
requestOptions.safetySettings = _payload.safetySettings;
|
||||||
requestOptions.safetySettings = safetySettings;
|
|
||||||
|
|
||||||
const delay = modelName.includes('flash') ? 8 : 14;
|
const delay = modelName.includes('flash') ? 8 : 14;
|
||||||
const result = await client.generateContentStream(requestOptions);
|
const result = await client.generateContentStream(requestOptions);
|
||||||
for await (const chunk of result.stream) {
|
for await (const chunk of result.stream) {
|
||||||
|
|
@ -708,16 +706,15 @@ class GoogleClient extends BaseClient {
|
||||||
}
|
}
|
||||||
return reply;
|
return reply;
|
||||||
}
|
}
|
||||||
|
|
||||||
const safetySettings = _payload.safetySettings;
|
|
||||||
const stream = await model.stream(messages, {
|
const stream = await model.stream(messages, {
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
timeout: 7000,
|
timeout: 7000,
|
||||||
safetySettings: safetySettings,
|
safetySettings: _payload.safetySettings,
|
||||||
});
|
});
|
||||||
|
|
||||||
let delay = this.options.streamRate || 8;
|
let delay = this.options.streamRate || 8;
|
||||||
|
|
||||||
if (!this.options.streamRate) {
|
if (!this.options.streamRate) {
|
||||||
if (this.isGenerativeModel) {
|
if (this.isGenerativeModel) {
|
||||||
delay = 12;
|
delay = 12;
|
||||||
|
|
@ -726,7 +723,7 @@ class GoogleClient extends BaseClient {
|
||||||
delay = 5;
|
delay = 5;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
const chunkText = chunk?.content ?? chunk;
|
const chunkText = chunk?.content ?? chunk;
|
||||||
await this.generateTextStream(chunkText, onProgress, {
|
await this.generateTextStream(chunkText, onProgress, {
|
||||||
|
|
@ -734,7 +731,7 @@ class GoogleClient extends BaseClient {
|
||||||
});
|
});
|
||||||
reply += chunkText;
|
reply += chunkText;
|
||||||
}
|
}
|
||||||
|
|
||||||
return reply;
|
return reply;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -871,37 +868,33 @@ class GoogleClient extends BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
async sendCompletion(payload, opts = {}) {
|
async sendCompletion(payload, opts = {}) {
|
||||||
const modelName = payload.parameters?.model;
|
payload.safetySettings = this.getSafetySettings();
|
||||||
|
|
||||||
if (modelName && modelName.toLowerCase().includes('gemini')) {
|
|
||||||
const safetySettings = [
|
|
||||||
{
|
|
||||||
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
|
|
||||||
threshold:
|
|
||||||
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
category: 'HARM_CATEGORY_HATE_SPEECH',
|
|
||||||
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
category: 'HARM_CATEGORY_HARASSMENT',
|
|
||||||
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
|
|
||||||
threshold:
|
|
||||||
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
payload.safetySettings = safetySettings;
|
|
||||||
}
|
|
||||||
|
|
||||||
let reply = '';
|
let reply = '';
|
||||||
reply = await this.getCompletion(payload, opts);
|
reply = await this.getCompletion(payload, opts);
|
||||||
return reply.trim();
|
return reply.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getSafetySettings() {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
|
||||||
|
threshold: process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: 'HARM_CATEGORY_HATE_SPEECH',
|
||||||
|
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: 'HARM_CATEGORY_HARASSMENT',
|
||||||
|
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
|
||||||
|
threshold: process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
|
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
|
||||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue