🤖 feat: Add titling to Google client (#2983)

* feat: Add titling to Google client

* feat: Add titling to Google client

* PR feedback changes
This commit is contained in:
Matthew Unrath 2024-06-22 08:42:51 -07:00 committed by GitHub
parent aac01df80c
commit b5081bfe86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 193 additions and 5 deletions

View file

@ -123,6 +123,8 @@ GOOGLE_KEY=user_provided
# Vertex AI
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
# GOOGLE_TITLE_MODEL=gemini-pro
# Google Gemini Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
# To use this restricted HarmBlockThreshold setting, you will need to either:

View file

@ -16,10 +16,15 @@ const {
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { formatMessage, createContextHandlers } = require('./prompts');
const { getModelMaxTokens } = require('~/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const {
formatMessage,
createContextHandlers,
titleInstruction,
truncateText,
} = require('./prompts');
const BaseClient = require('./BaseClient');
const loc = 'us-central1';
const publisher = 'google';
@ -591,12 +596,16 @@ class GoogleClient extends BaseClient {
createLLM(clientOptions) {
const model = clientOptions.modelName ?? clientOptions.model;
if (this.project_id && this.isTextModel) {
logger.debug('Creating Google VertexAI client');
return new GoogleVertexAI(clientOptions);
} else if (this.project_id && this.isChatModel) {
logger.debug('Creating Chat Google VertexAI client');
return new ChatGoogleVertexAI(clientOptions);
} else if (this.project_id) {
logger.debug('Creating VertexAI client');
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
logger.debug('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
@ -606,6 +615,7 @@ class GoogleClient extends BaseClient {
);
}
logger.debug('Creating Chat Google Generative AI client');
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}
@ -717,6 +727,123 @@ class GoogleClient extends BaseClient {
return reply;
}
/**
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
*/
async titleChatCompletion(_payload, options = {}) {
const { abortController } = options;
const { parameters, instances } = _payload;
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
let clientOptions = { ...parameters, maxRetries: 2 };
logger.info('Initialized title client options');
if (this.project_id) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
},
projectId: this.project_id,
};
}
if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}
if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}
const model = this.createLLM(clientOptions);
let reply = '';
const messages = this.isTextModel ? _payload.trim() : _messages;
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
logger.info('Identified titling model as 1.5 version');
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
};
if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = {
parts: [
{
text: this.options.promptPrefix,
},
],
};
}
const safetySettings = _payload.safetySettings;
requestOptions.safetySettings = safetySettings;
const result = await client.generateContent(requestOptions);
reply = result.response?.text();
return reply;
} else {
logger.info('Beginning titling');
const safetySettings = _payload.safetySettings;
const titleResponse = await model.invoke(messages, {
signal: abortController.signal,
timeout: 7000,
safetySettings: safetySettings,
});
reply = titleResponse.content;
return reply;
}
}
async titleConvo({ text, responseText = '' }) {
let title = 'New Chat';
const convo = `||>User:
"${truncateText(text)}"
||>Response:
"${JSON.stringify(truncateText(responseText))}"`;
let { prompt: payload } = await this.buildMessages([
{
text: `Please generate ${titleInstruction}
${convo}
||>Title:`,
isCreatedByUser: true,
author: this.userLabel,
},
]);
if (this.isVisionModel) {
logger.warn(
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
);
payload.parameters = { ...payload.parameters, model: settings.model.default };
}
try {
title = await this.titleChatCompletion(payload, {
abortController: new AbortController(),
onProgress: () => {},
});
} catch (e) {
logger.error('[GoogleClient] There was an issue generating the title', e);
}
logger.info(`Title response: ${title}`);
return title;
}
getSaveOptions() {
return {
promptPrefix: this.options.promptPrefix,

View file

@ -1,6 +1,6 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
@ -20,7 +20,7 @@ router.post(
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
await AskController(req, res, next, initializeClient, addTitle);
},
);

View file

@ -0,0 +1,58 @@
const { CacheKeys, Constants } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models');
const { logger } = require('~/config');
const initializeClient = require('./initializeClient');
const addTitle = async (req, { text, response, client }) => {
const { TITLE_CONVO = 'true' } = process.env ?? {};
if (!isEnabled(TITLE_CONVO)) {
return;
}
if (client.options.titleConvo === false) {
return;
}
const DEFAULT_TITLE_MODEL = 'gemini-pro';
const { GOOGLE_TITLE_MODEL } = process.env ?? {};
let model = GOOGLE_TITLE_MODEL ?? DEFAULT_TITLE_MODEL;
if (GOOGLE_TITLE_MODEL === Constants.CURRENT_MODEL) {
model = client.options?.modelOptions.model;
if (client.isVisionModel) {
logger.warn(
`current_model was specified for Google title request, but the model ${model} cannot process a text-only conversation. Falling back to ${DEFAULT_TITLE_MODEL}`,
);
model = DEFAULT_TITLE_MODEL;
}
}
const titleEndpointOptions = {
...client.options,
modelOptions: { ...client.options?.modelOptions, model: model },
attachments: undefined, // After a response, this is set to an empty array which results in an error during setOptions
};
const { client: titleClient } = await initializeClient({
req,
res: response,
endpointOption: titleEndpointOptions,
});
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;
const title = await titleClient.titleConvo({ text, responseText: response?.text });
await titleCache.set(key, title, 120000);
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
};
module.exports = addTitle;

View file

@ -1,8 +1,9 @@
const addTitle = require('./addTitle');
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');
module.exports = {
// addTitle, // todo
addTitle,
buildOptions,
initializeClient,
};