mirror of https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 08:12:00 +02:00

* feat: work started on feedback implementation. TODO: needs some refactoring; needs some UI animations.
* feat: working rate functionality
* feat: now also reads already-rated responses from the server
* feat: added the option to give feedback in text (optional)
* feat: added Dismiss option `x` to the `FeedbackTagOptions`
* ✨ feat: Add rating and ratingContent fields to message schema
* 🔧 chore: Bump version to 0.0.3 in package.json
* ✨ feat: Enhance feedback localization and update UI elements
* 🚀 feat: Implement feedback tagging system with thumbs up/down options
* 🚀 feat: Add data-provider package to unused i18n keys detection
* 🎨 style: update HoverButtons' style
* 🎨 style: Update HoverButtons and Fork components for improved styling and visibility
* 🔧 feat: Implement feedback system with rating and content options
* 🔧 feat: Enhance feedback handling with improved rating toggle and tag options
* 🔧 feat: Integrate toast notifications for feedback submission and clean up unused state
* 🔧 feat: Remove unused feedback tag options from translation file
* ✨ refactor: clean up Feedback component and improve HoverButtons structure
* ✨ refactor: remove unused settings switches for auto scroll, hide side panel, and user message markdown
* refactor: reorganize import order
* ✨ refactor: enhance HoverButtons and Fork components with improved styles and animations
* ✨ refactor: update feedback response phrases for improved user engagement
* ✨ refactor: add CheckboxOption component and streamline fork options rendering
* Refactor feedback components and logic
  - Consolidated feedback handling into a single Feedback component, removing FeedbackButtons and FeedbackTagOptions.
  - Introduced new feedback tagging system with detailed tags for both thumbs up and thumbs down ratings.
  - Updated feedback schema to include new tags and improved type definitions.
  - Enhanced user interface for feedback collection, including a dialog for additional comments.
  - Removed obsolete files and adjusted imports accordingly.
  - Updated translations for new feedback tags and placeholders.
* ✨ refactor: update feedback handling by replacing rating fields with feedback in message updates
* fix: add missing validateMessageReq middleware to feedback route and refactor feedback system
* 🗑️ chore: Remove redundant fork option explanations from translation file
* 🔧 refactor: Remove unused dependency from feedback callback
* 🔧 refactor: Simplify message update response structure and improve error logging
* chore: removed unused tests

---------

Co-authored-by: Marco Beretta <81851188+berry-13@users.noreply.github.com>
981 lines · 31 KiB · JavaScript
const { google } = require('googleapis');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
const {
  googleGenConfigSchema,
  validateVisionModel,
  getResponseSender,
  endpointSettings,
  parseTextParts,
  EModelEndpoint,
  ContentTypes,
  VisionModes,
  ErrorTypes,
  Constants,
  AuthKeys,
} = require('librechat-data-provider');
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
const { encodeAndFormat } = require('~/server/services/Files/images');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const {
  formatMessage,
  createContextHandlers,
  titleInstruction,
  truncateText,
} = require('./prompts');
const BaseClient = require('./BaseClient');

const loc = process.env.GOOGLE_LOC || 'us-central1';
const publisher = 'google';
const endpointPrefix = `${loc}-aiplatform.googleapis.com`;

const settings = endpointSettings[EModelEndpoint.google];
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;

class GoogleClient extends BaseClient {
  constructor(credentials, options = {}) {
    super('apiKey', options);
    let creds = {};

    if (typeof credentials === 'string') {
      creds = JSON.parse(credentials);
    } else if (credentials) {
      creds = credentials;
    }

    const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
    this.serviceKey =
      serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : (serviceKey ?? {});
    /** @type {string | null | undefined} */
    this.project_id = this.serviceKey.project_id;
    this.client_email = this.serviceKey.client_email;
    this.private_key = this.serviceKey.private_key;
    this.access_token = null;

    this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];

    this.reverseProxyUrl = options.reverseProxyUrl;

    this.authHeader = options.authHeader;

    /** @type {UsageMetadata | undefined} */
    this.usage;
    /** The key for the usage object's input tokens
     * @type {string} */
    this.inputTokensKey = 'input_tokens';
    /** The key for the usage object's output tokens
     * @type {string} */
    this.outputTokensKey = 'output_tokens';
    this.visionMode = VisionModes.generative;
    /** @type {string} */
    this.systemMessage;
    if (options.skipSetOptions) {
      return;
    }
    this.setOptions(options);
  }
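
  /*
   * Illustrative sketch (not part of the original file): `credentials` may be a JSON string
   * or an object keyed by AuthKeys, and `options` follows BaseClient conventions. Values
   * below are examples/assumptions, shown only to document the constructor's inputs:
   *
   *   const client = new GoogleClient(
   *     {
   *       [AuthKeys.GOOGLE_SERVICE_KEY]: serviceAccountJson, // Vertex AI path (project_id, client_email, private_key)
   *       [AuthKeys.GOOGLE_API_KEY]: process.env.GOOGLE_KEY, // Generative AI (API key) path
   *     },
   *     { modelOptions: { model: 'gemini-1.5-pro' } },
   *   );
   */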

  /* Google specific methods */
  constructUrl() {
    return `https://${endpointPrefix}/v1/projects/${this.project_id}/locations/${loc}/publishers/${publisher}/models/${this.modelOptions.model}:serverStreamingPredict`;
  }
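
  /*
   * For example (illustrative values only), with GOOGLE_LOC=us-central1, project_id=my-project,
   * and model=gemini-1.5-pro, constructUrl() resolves to:
   *   https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:serverStreamingPredict
   */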

  async getClient() {
    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);

    jwtClient.authorize((err) => {
      if (err) {
        logger.error('jwtClient failed to authorize', err);
        throw err;
      }
    });

    return jwtClient;
  }

  async getAccessToken() {
    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);

    return new Promise((resolve, reject) => {
      jwtClient.authorize((err, tokens) => {
        if (err) {
          logger.error('jwtClient failed to authorize', err);
          reject(err);
        } else {
          resolve(tokens.access_token);
        }
      });
    });
  }

  /* Required Client methods */
  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      // nested options aren't spread properly, so we need to do this manually
      this.options.modelOptions = {
        ...this.options.modelOptions,
        ...options.modelOptions,
      };
      delete options.modelOptions;
      // now we can merge options
      this.options = {
        ...this.options,
        ...options,
      };
    } else {
      this.options = options;
    }

    this.modelOptions = this.options.modelOptions || {};

    this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));

    /** @type {boolean} Whether using a "GenerativeAI" Model */
    this.isGenerativeModel = /gemini|learnlm|gemma/.test(this.modelOptions.model);

    this.maxContextTokens =
      this.options.maxContextTokens ??
      getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);

    // The max prompt tokens is determined by the max context tokens minus the max response tokens.
    // Earlier messages will be dropped until the prompt is within the limit.
    this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;

    if (this.maxContextTokens > 32000) {
      this.maxContextTokens = this.maxContextTokens - this.maxResponseTokens;
    }

    this.maxPromptTokens =
      this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
      throw new Error(
        `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
          this.maxPromptTokens + this.maxResponseTokens
        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
      );
    }

    this.sender =
      this.options.sender ??
      getResponseSender({
        model: this.modelOptions.model,
        endpoint: EModelEndpoint.google,
        modelLabel: this.options.modelLabel,
      });

    this.userLabel = this.options.userLabel || 'User';
    this.modelLabel = this.options.modelLabel || 'Assistant';

    if (this.options.reverseProxyUrl) {
      this.completionsUrl = this.options.reverseProxyUrl;
    } else {
      this.completionsUrl = this.constructUrl();
    }

    let promptPrefix = (this.options.promptPrefix ?? '').trim();
    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
    }
    this.systemMessage = promptPrefix;
    this.initializeClient();
    return this;
  }
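
  /*
   * Sketch of the options shape setOptions() reads (only fields consumed above; values are
   * illustrative assumptions):
   *
   *   client.setOptions({
   *     modelOptions: { model: 'gemini-1.5-flash', maxOutputTokens: 1024 },
   *     promptPrefix: 'You are a helpful assistant.',
   *     maxContextTokens: 32000,
   *     userLabel: 'User',
   *     modelLabel: 'Assistant',
   *   });
   */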

  /**
   *
   * Checks if the model is a vision model based on request attachments and sets the appropriate options.
   * @param {MongoFile[]} attachments
   */
  checkVisionRequest(attachments) {
    /* Validate vision request */
    this.defaultVisionModel =
      this.options.visionModel ??
      (!EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)
        ? this.modelOptions.model
        : 'gemini-pro-vision');
    const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
    this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });

    if (
      attachments &&
      attachments.some((file) => file?.type && file?.type?.includes('image')) &&
      availableModels?.includes(this.defaultVisionModel) &&
      !this.isVisionModel
    ) {
      this.modelOptions.model = this.defaultVisionModel;
      this.isVisionModel = true;
    }

    if (this.isVisionModel && !attachments && this.modelOptions.model.includes('gemini-pro')) {
      this.modelOptions.model = 'gemini-pro';
      this.isVisionModel = false;
    }
  }

  formatMessages() {
    return ((message) => {
      const msg = {
        author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
        content: message?.content ?? message.text,
      };

      if (!message.image_urls?.length) {
        return msg;
      }

      msg.content = (
        !Array.isArray(msg.content)
          ? [
              {
                type: ContentTypes.TEXT,
                [ContentTypes.TEXT]: msg.content,
              },
            ]
          : msg.content
      ).concat(message.image_urls);

      return msg;
    }).bind(this);
  }

  /**
   * Formats messages for generative AI
   * @param {TMessage[]} messages
   * @returns
   */
  async formatGenerativeMessages(messages) {
    const formattedMessages = [];
    const attachments = await this.options.attachments;
    const latestMessage = { ...messages[messages.length - 1] };
    const files = await this.addImageURLs(latestMessage, attachments, VisionModes.generative);
    this.options.attachments = files;
    messages[messages.length - 1] = latestMessage;

    for (const _message of messages) {
      const role = _message.isCreatedByUser ? this.userLabel : this.modelLabel;
      const parts = [];
      parts.push({ text: _message.text });
      if (!_message.image_urls?.length) {
        formattedMessages.push({ role, parts });
        continue;
      }

      for (const images of _message.image_urls) {
        if (images.inlineData) {
          parts.push({ inlineData: images.inlineData });
        }
      }

      formattedMessages.push({ role, parts });
    }

    return formattedMessages;
  }
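
  /*
   * Illustrative output of formatGenerativeMessages() for a text + image turn (the inlineData
   * payload comes from encodeAndFormat; values here are example assumptions):
   *
   *   [
   *     {
   *       role: 'user',
   *       parts: [
   *         { text: 'Describe this image' },
   *         { inlineData: { mimeType: 'image/png', data: '<base64>' } },
   *       ],
   *     },
   *     { role: 'model', parts: [{ text: 'It shows ...' }] },
   *   ]
   */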

  /**
   *
   * Adds image URLs to the message object and returns the files
   *
   * @param {TMessage} message
   * @param {MongoFile[]} attachments
   * @param {string} [mode='']
   * @returns {Promise<MongoFile[]>}
   */
  async addImageURLs(message, attachments, mode = '') {
    const { files, image_urls } = await encodeAndFormat(
      this.options.req,
      attachments,
      EModelEndpoint.google,
      mode,
    );
    message.image_urls = image_urls.length ? image_urls : undefined;
    return files;
  }

  /**
   * Builds the augmented prompt for attachments
   * TODO: Add File API Support
   * @param {TMessage[]} messages
   */
  async buildAugmentedPrompt(messages = []) {
    const attachments = await this.options.attachments;
    const latestMessage = { ...messages[messages.length - 1] };
    this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text);

    if (this.contextHandlers) {
      for (const file of attachments) {
        if (file.embedded) {
          this.contextHandlers?.processFile(file);
          continue;
        }
        if (file.metadata?.fileIdentifier) {
          continue;
        }
      }

      this.augmentedPrompt = await this.contextHandlers.createContext();
      this.systemMessage = this.augmentedPrompt + this.systemMessage;
    }
  }

  async buildVisionMessages(messages = [], parentMessageId) {
    const attachments = await this.options.attachments;
    const latestMessage = { ...messages[messages.length - 1] };
    await this.buildAugmentedPrompt(messages);

    const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);

    const files = await this.addImageURLs(latestMessage, attachments);

    this.options.attachments = files;

    latestMessage.text = prompt;

    const payload = {
      instances: [
        {
          messages: [new HumanMessage(formatMessage({ message: latestMessage }))],
        },
      ],
    };
    return { prompt: payload };
  }

  /** @param {TMessage[]} [messages=[]] */
  async buildGenerativeMessages(messages = []) {
    this.userLabel = 'user';
    this.modelLabel = 'model';
    const promises = [];
    promises.push(await this.formatGenerativeMessages(messages));
    promises.push(this.buildAugmentedPrompt(messages));
    const [formattedMessages] = await Promise.all(promises);
    return { prompt: formattedMessages };
  }

  /**
   * @param {TMessage[]} [messages=[]]
   * @param {string} [parentMessageId]
   */
  async buildMessages(_messages = [], parentMessageId) {
    if (!this.isGenerativeModel && !this.project_id) {
      throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.');
    }

    if (this.systemMessage) {
      const instructionsTokenCount = this.getTokenCount(this.systemMessage);

      this.maxContextTokens = this.maxContextTokens - instructionsTokenCount;
      if (this.maxContextTokens < 0) {
        const info = `${instructionsTokenCount} / ${this.maxContextTokens}`;
        const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
        logger.warn(`Instructions token count exceeds max context (${info}).`);
        throw new Error(errorMessage);
      }
    }

    for (let i = 0; i < _messages.length; i++) {
      const message = _messages[i];
      if (!message.tokenCount) {
        _messages[i].tokenCount = this.getTokenCountForMessage({
          role: message.isCreatedByUser ? 'user' : 'assistant',
          content: message.content ?? message.text,
        });
      }
    }

    const {
      payload: messages,
      tokenCountMap,
      promptTokens,
    } = await this.handleContextStrategy({
      orderedMessages: _messages,
      formattedMessages: _messages,
    });

    if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) {
      const result = await this.buildGenerativeMessages(messages);
      result.tokenCountMap = tokenCountMap;
      result.promptTokens = promptTokens;
      return result;
    }

    if (this.options.attachments && this.isGenerativeModel) {
      const result = await this.buildVisionMessages(messages, parentMessageId);
      result.tokenCountMap = tokenCountMap;
      result.promptTokens = promptTokens;
      return result;
    }

    let payload = {
      instances: [
        {
          messages: messages
            .map(this.formatMessages())
            .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
            .map((message) => formatMessage({ message, langChain: true })),
        },
      ],
    };

    if (this.systemMessage) {
      payload.instances[0].context = this.systemMessage;
    }

    logger.debug('[GoogleClient] buildMessages', payload);
    return { prompt: payload, tokenCountMap, promptTokens };
  }

  async buildMessagesPrompt(messages, parentMessageId) {
    const orderedMessages = this.constructor.getMessagesForConversation({
      messages,
      parentMessageId,
    });

    logger.debug('[GoogleClient]', {
      orderedMessages,
      parentMessageId,
    });

    const formattedMessages = orderedMessages.map(this.formatMessages());

    let lastAuthor = '';
    let groupedMessages = [];

    for (let message of formattedMessages) {
      // If last author is not same as current author, add to new group
      if (lastAuthor !== message.author) {
        groupedMessages.push({
          author: message.author,
          content: [message.content],
        });
        lastAuthor = message.author;
        // If same author, append content to the last group
      } else {
        groupedMessages[groupedMessages.length - 1].content.push(message.content);
      }
    }

    let identityPrefix = '';
    if (this.options.userLabel) {
      identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
    }

    if (this.options.modelLabel) {
      identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`;
    }

    let promptPrefix = (this.systemMessage ?? '').trim();

    if (identityPrefix) {
      promptPrefix = `${identityPrefix}${promptPrefix}`;
    }

    // Prompt AI to respond, empty if last message was from AI
    let isEdited = lastAuthor === this.modelLabel;
    const promptSuffix = isEdited ? '' : `${promptPrefix}\n\n${this.modelLabel}:\n`;
    let currentTokenCount = isEdited
      ? this.getTokenCount(promptPrefix)
      : this.getTokenCount(promptSuffix);

    let promptBody = '';
    const maxTokenCount = this.maxPromptTokens;

    const context = [];

    // Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
    // Do this within a recursive async function so that it doesn't block the event loop for too long.
    // Also, remove the next message when the message that puts us over the token limit is created by the user.
    // Otherwise, remove only the exceeding message. This is due to Anthropic's strict payload rule to start with "Human:".
    const nextMessage = {
      remove: false,
      tokenCount: 0,
      messageString: '',
    };

    const buildPromptBody = async () => {
      if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
        const message = groupedMessages.pop();
        const isCreatedByUser = message.author === this.userLabel;
        // Use promptPrefix if the message is an edited assistant message
        const messagePrefix =
          isCreatedByUser || !isEdited
            ? `\n\n${message.author}:`
            : `${promptPrefix}\n\n${message.author}:`;
        const messageString = `${messagePrefix}\n${message.content}\n`;
        let newPromptBody = `${messageString}${promptBody}`;

        context.unshift(message);

        const tokenCountForMessage = this.getTokenCount(messageString);
        const newTokenCount = currentTokenCount + tokenCountForMessage;

        if (!isCreatedByUser) {
          nextMessage.messageString = messageString;
          nextMessage.tokenCount = tokenCountForMessage;
        }

        if (newTokenCount > maxTokenCount) {
          if (!promptBody) {
            // This is the first message, so we can't add it. Just throw an error.
            throw new Error(
              `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
            );
          }

          // Otherwise, this message would put us over the token limit, so don't add it.
          // If created by user, remove the next message; otherwise remove only this message.
          if (isCreatedByUser) {
            nextMessage.remove = true;
          }

          return false;
        }
        promptBody = newPromptBody;
        currentTokenCount = newTokenCount;

        // Switch off isEdited after using it for the first time
        if (isEdited) {
          isEdited = false;
        }

        // wait for next tick to avoid blocking the event loop
        await new Promise((resolve) => setImmediate(resolve));
        return buildPromptBody();
      }
      return true;
    };

    await buildPromptBody();

    if (nextMessage.remove) {
      promptBody = promptBody.replace(nextMessage.messageString, '');
      currentTokenCount -= nextMessage.tokenCount;
      context.shift();
    }

    let prompt = `${promptBody}${promptSuffix}`.trim();

    // Add 2 tokens for metadata after all messages have been counted.
    currentTokenCount += 2;

    // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxResponseTokens` tokens for the response.
    this.modelOptions.maxOutputTokens = Math.min(
      this.maxContextTokens - currentTokenCount,
      this.maxResponseTokens,
    );

    return { prompt, context };
  }

  createLLM(clientOptions) {
    const model = clientOptions.modelName ?? clientOptions.model;
    clientOptions.location = loc;
    clientOptions.endpoint = endpointPrefix;

    let requestOptions = null;
    if (this.reverseProxyUrl) {
      requestOptions = {
        baseUrl: this.reverseProxyUrl,
      };

      if (this.authHeader) {
        requestOptions.customHeaders = {
          Authorization: `Bearer ${this.apiKey}`,
        };
      }
    }

    if (this.project_id != null) {
      logger.debug('Creating VertexAI client');
      this.visionMode = undefined;
      clientOptions.streaming = true;
      const client = new ChatVertexAI(clientOptions);
      client.temperature = clientOptions.temperature;
      client.topP = clientOptions.topP;
      client.topK = clientOptions.topK;
      client.topLogprobs = clientOptions.topLogprobs;
      client.frequencyPenalty = clientOptions.frequencyPenalty;
      client.presencePenalty = clientOptions.presencePenalty;
      client.maxOutputTokens = clientOptions.maxOutputTokens;
      return client;
    } else if (!EXCLUDED_GENAI_MODELS.test(model)) {
      logger.debug('Creating GenAI client');
      return new GenAI(this.apiKey).getGenerativeModel({ model }, requestOptions);
    }

    logger.debug('Creating Chat Google Generative AI client');
    return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
  }
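
  /*
   * Summary of the branches above: a service-account `project_id` selects ChatVertexAI;
   * otherwise a model not matched by EXCLUDED_GENAI_MODELS uses the @google/generative-ai
   * GenerativeModel with the API key; anything else falls back to ChatGoogleGenerativeAI.
   */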

  initializeClient() {
    let clientOptions = { ...this.modelOptions };

    if (this.project_id) {
      clientOptions['authOptions'] = {
        credentials: {
          ...this.serviceKey,
        },
        projectId: this.project_id,
      };
    }

    if (this.isGenerativeModel && !this.project_id) {
      clientOptions.modelName = clientOptions.model;
      delete clientOptions.model;
    }

    this.client = this.createLLM(clientOptions);
    return this.client;
  }

  async getCompletion(_payload, options = {}) {
    const { onProgress, abortController } = options;
    const safetySettings = getSafetySettings(this.modelOptions.model);
    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
    const modelName = this.modelOptions.modelName ?? this.modelOptions.model ?? '';

    let reply = '';
    /** @type {Error} */
    let error;
    try {
      if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
        /** @type {GenerativeModel} */
        const client = this.client;
        /** @type {GenerateContentRequest} */
        const requestOptions = {
          safetySettings,
          contents: _payload,
          generationConfig: googleGenConfigSchema.parse(this.modelOptions),
        };

        const promptPrefix = (this.systemMessage ?? '').trim();
        if (promptPrefix.length) {
          requestOptions.systemInstruction = {
            parts: [
              {
                text: promptPrefix,
              },
            ],
          };
        }

        const delay = modelName.includes('flash') ? 8 : 15;
        /** @type {GenAIUsageMetadata} */
        let usageMetadata;

        abortController.signal.addEventListener(
          'abort',
          () => {
            logger.warn('[GoogleClient] Request was aborted', abortController.signal.reason);
          },
          { once: true },
        );

        const result = await client.generateContentStream(requestOptions, {
          signal: abortController.signal,
        });
        for await (const chunk of result.stream) {
          usageMetadata = !usageMetadata
            ? chunk?.usageMetadata
            : Object.assign(usageMetadata, chunk?.usageMetadata);
          const chunkText = chunk.text();
          await this.generateTextStream(chunkText, onProgress, {
            delay,
          });
          reply += chunkText;
          await sleep(streamRate);
        }

        if (usageMetadata) {
          this.usage = {
            input_tokens: usageMetadata.promptTokenCount,
            output_tokens: usageMetadata.candidatesTokenCount,
          };
        }

        return reply;
      }

      const { instances } = _payload;
      const { messages, context } = instances?.[0] ?? {};

      if (!this.isVisionModel && context && messages?.length > 0) {
        messages.unshift(new SystemMessage(context));
      }

      /** @type {import('@langchain/core/messages').AIMessageChunk['usage_metadata']} */
      let usageMetadata;
      /** @type {ChatVertexAI} */
      const client = this.client;
      const stream = await client.stream(messages, {
        signal: abortController.signal,
        streamUsage: true,
        safetySettings,
      });

      let delay = this.options.streamRate || 8;

      if (!this.options.streamRate) {
        if (this.isGenerativeModel) {
          delay = 15;
        }
        if (modelName.includes('flash')) {
          delay = 5;
        }
      }

      for await (const chunk of stream) {
        if (chunk?.usage_metadata) {
          const metadata = chunk.usage_metadata;
          for (const key in metadata) {
            if (Number.isNaN(metadata[key])) {
              delete metadata[key];
            }
          }

          usageMetadata = !usageMetadata ? metadata : concat(usageMetadata, metadata);
        }

        const chunkText = chunk?.content ?? '';
        await this.generateTextStream(chunkText, onProgress, {
          delay,
        });
        reply += chunkText;
      }

      if (usageMetadata) {
        this.usage = usageMetadata;
      }
    } catch (e) {
      error = e;
      logger.error('[GoogleClient] There was an issue generating the completion', e);
    }

    if (error != null && reply === '') {
      const errorMessage = `{ "type": "${ErrorTypes.GoogleError}", "info": "${
        error.message ?? 'The Google provider failed to generate content, please contact the Admin.'
      }" }`;
      throw new Error(errorMessage);
    }
    return reply;
  }

  /**
   * Get stream usage as returned by this client's API response.
   * @returns {UsageMetadata} The stream usage object.
   */
  getStreamUsage() {
    return this.usage;
  }

  getMessageMapMethod() {
    /**
     * @param {TMessage} msg
     */
    return (msg) => {
      if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
      } else if (msg.content != null) {
        msg.text = parseTextParts(msg.content, true);
        delete msg.content;
      }

      return msg;
    };
  }

  /**
   * Calculates the correct token count for the current user message based on the token count map and API usage.
   * Edge case: If the calculation results in a negative value, it returns the original estimate.
   * If revisiting a conversation with a chat history entirely composed of token estimates,
   * the cumulative token count going forward should become more accurate as the conversation progresses.
   * @param {Object} params - The parameters for the calculation.
   * @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
   * @param {string} params.currentMessageId - The ID of the current message to calculate.
   * @param {UsageMetadata} params.usage - The usage object returned by the API.
   * @returns {number} The correct token count for the current user message.
   */
  calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
    const originalEstimate = tokenCountMap[currentMessageId] || 0;

    if (!usage || typeof usage.input_tokens !== 'number') {
      return originalEstimate;
    }

    tokenCountMap[currentMessageId] = 0;
    const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
      const numCount = Number(count);
      return sum + (isNaN(numCount) ? 0 : numCount);
    }, 0);
    const totalInputTokens = usage.input_tokens ?? 0;
    const currentMessageTokens = totalInputTokens - totalTokensFromMap;
    return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
  }
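
  /*
   * Worked example (illustrative numbers): with tokenCountMap = { a: 40, b: 60, current: 55 },
   * currentMessageId = 'current', and usage.input_tokens = 120, the map total excluding the
   * current message is 100, so the corrected count is 120 - 100 = 20. Were the difference
   * zero or negative, the original estimate (55) would be returned instead.
   */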

  /**
   * @param {object} params
   * @param {number} params.promptTokens
   * @param {number} params.completionTokens
   * @param {UsageMetadata} [params.usage]
   * @param {string} [params.model]
   * @param {string} [params.context='message']
   * @returns {Promise<void>}
   */
  async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) {
    await spendTokens(
      {
        context,
        user: this.user ?? this.options.req?.user?.id,
        conversationId: this.conversationId,
        model: model ?? this.modelOptions.model,
        endpointTokenConfig: this.options.endpointTokenConfig,
      },
      { promptTokens, completionTokens },
    );
  }

  /**
   * Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
   */
  async titleChatCompletion(_payload, options = {}) {
    let reply = '';
    const { abortController } = options;

    const model =
      this.options.titleModel ?? this.modelOptions.modelName ?? this.modelOptions.model ?? '';
    const safetySettings = getSafetySettings(model);
    if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) {
      logger.debug('Identified titling model as GenAI version');
      /** @type {GenerativeModel} */
      const client = this.client;
      const requestOptions = {
        contents: _payload,
        safetySettings,
        generationConfig: {
          temperature: 0.5,
        },
      };

      const result = await client.generateContent(requestOptions);
      reply = result.response?.text();
      return reply;
    } else {
      const { instances } = _payload;
      const { messages } = instances?.[0] ?? {};
      const titleResponse = await this.client.invoke(messages, {
        signal: abortController.signal,
        timeout: 7000,
        safetySettings,
      });

      if (titleResponse.usage_metadata) {
        await this.recordTokenUsage({
          model,
          promptTokens: titleResponse.usage_metadata.input_tokens,
          completionTokens: titleResponse.usage_metadata.output_tokens,
          context: 'title',
        });
      }

      reply = titleResponse.content;
      return reply;
    }
  }

  async titleConvo({ text, responseText = '' }) {
    let title = 'New Chat';
    const convo = `||>User:
"${truncateText(text)}"
||>Response:
"${JSON.stringify(truncateText(responseText))}"`;

    let { prompt: payload } = await this.buildMessages([
      {
        text: `Please generate ${titleInstruction}

${convo}

||>Title:`,
        isCreatedByUser: true,
        author: this.userLabel,
      },
    ]);

    try {
      this.initializeClient();
      title = await this.titleChatCompletion(payload, {
        abortController: new AbortController(),
        onProgress: () => {},
      });
    } catch (e) {
      logger.error('[GoogleClient] There was an issue generating the title', e);
    }
    logger.debug(`Title response: ${title}`);
    return title;
  }

  getSaveOptions() {
    return {
      endpointType: null,
      artifacts: this.options.artifacts,
      promptPrefix: this.options.promptPrefix,
      maxContextTokens: this.options.maxContextTokens,
      modelLabel: this.options.modelLabel,
      iconURL: this.options.iconURL,
      greeting: this.options.greeting,
      spec: this.options.spec,
      ...this.modelOptions,
    };
  }

  getBuildMessagesOptions() {
    // logger.debug('GoogleClient doesn\'t use getBuildMessagesOptions');
  }

  async sendCompletion(payload, opts = {}) {
    let reply = '';
    reply = await this.getCompletion(payload, opts);
    return reply.trim();
  }

  getEncoding() {
    return 'cl100k_base';
  }

  async getVertexTokenCount(text) {
    /** @type {ChatVertexAI} */
    const client = this.client ?? this.initializeClient();
    const connection = client.connection;
    const gAuthClient = connection.client;
    const tokenEndpoint = `https://${connection._endpoint}/${connection.apiVersion}/projects/${this.project_id}/locations/${connection._location}/publishers/google/models/${connection.model}/:countTokens`;
    const result = await gAuthClient.request({
      url: tokenEndpoint,
      method: 'POST',
      data: {
        contents: [{ role: 'user', parts: [{ text }] }],
      },
    });
    return result;
  }

  /**
   * Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
   * @param {string} text - The text to get the token count for.
   * @returns {number} The token count of the given text.
   */
  getTokenCount(text) {
    const encoding = this.getEncoding();
    return Tokenizer.getTokenCount(text, encoding);
  }
}

module.exports = GoogleClient;
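
/*
 * Minimal usage sketch (illustrative only; the `req`/`res` options and the exact credential
 * payload follow BaseClient/LibreChat conventions and are assumptions here, not part of this file):
 *
 *   const GoogleClient = require('./GoogleClient');
 *   const client = new GoogleClient(credentials, {
 *     req,
 *     res,
 *     modelOptions: { model: 'gemini-1.5-flash' },
 *   });
 *   const { prompt } = await client.buildMessages(orderedMessages, parentMessageId);
 *   const text = await client.sendCompletion(prompt, { onProgress, abortController });
 */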