mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 08:12:00 +02:00

* ci(backend-review.yml): add linter step to the backend review workflow
* chore(backend-review.yml): remove prettier from lint-action configuration
* chore: apply new linting workflow
* chore(lint-staged.config.js): reorder lint-staged tasks for JavaScript and TypeScript files
* chore(eslint): update ignorePatterns in .eslintrc.js
* chore(lint-action): remove prettier option in backend-review.yml
* chore(package.json): add lint and lint:fix scripts
* chore(lint-staged.config.js): remove prettier --write command for js, jsx, ts, tsx files
* chore(titleConvo.js): remove unnecessary console.log statement
* chore(titleConvo.js): add missing comma in options object
* chore: apply linting to all files
* chore(lint-staged.config.js): update lint-staged configuration to include prettier formatting
280 lines
9 KiB
JavaScript
const BaseClient = require('./BaseClient');
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const {
  encoding_for_model: encodingForModel,
  get_encoding: getEncoding,
} = require('@dqbd/tiktoken');

const tokenizersCache = {};
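
/**
 * Client for Google's PaLM 2 chat/text models, called through the Vertex AI
 * `:predict` endpoint (see constructUrl below). It authenticates with a service
 * account (client_email / project_id / private_key) and approximates token
 * counts with OpenAI tiktoken encoders until Google tokenization is handled.
 */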
class GoogleClient extends BaseClient {
  constructor(credentials, options = {}) {
    super('apiKey', options);
    this.client_email = credentials.client_email;
    this.project_id = credentials.project_id;
    this.private_key = credentials.private_key;
    this.sender = 'PaLM2';
    this.setOptions(options);
  }

  /* Google/PaLM2 specific methods */
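  // Builds the Vertex AI prediction URL; the region (us-central1) and publisher path are fixed here.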
  constructUrl() {
    return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
  }

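  // Creates a JWT client scoped to cloud-platform using the service-account key.
  // Note: authorize() is invoked with a callback and not awaited, so auth errors surface asynchronously.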
  async getClient() {
    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);

    jwtClient.authorize((err) => {
      if (err) {
        console.log(err);
        throw err;
      }
    });

    return jwtClient;
  }

  /* Required Client methods */
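  // Merges incoming options (deep-merging modelOptions), applies PaLM 2 defaults,
  // derives token limits, and selects a tokenizer and stop sequences for the model family.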
  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      // nested options aren't spread properly, so we need to do this manually
      this.options.modelOptions = {
        ...this.options.modelOptions,
        ...options.modelOptions,
      };
      delete options.modelOptions;
      // now we can merge options
      this.options = {
        ...this.options,
        ...options,
      };
    } else {
      this.options = options;
    }

    this.options.examples = this.options.examples.filter(
      (obj) => obj.input.content !== '' && obj.output.content !== '',
    );

    const modelOptions = this.options.modelOptions || {};
    this.modelOptions = {
      ...modelOptions,
      // set some good defaults (check for undefined in some cases because they may be 0)
      model: modelOptions.model || 'chat-bison',
      temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended
      topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95
      topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1 - 40, default: 40
      // stop: modelOptions.stop // no stop method for now
    };

    this.isChatModel = this.modelOptions.model.startsWith('chat-');
    const { isChatModel } = this;
    this.isTextModel = this.modelOptions.model.startsWith('text-');
    const { isTextModel } = this;
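
    // Context window defaults below (8000 tokens for text models, 4096 for chat models) presumably
    // approximate PaLM 2's input limits; they can be overridden via options.maxContextTokens.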
    this.maxContextTokens = this.options.maxContextTokens || (isTextModel ? 8000 : 4096);
    // The max prompt tokens is determined by the max context tokens minus the max response tokens.
    // Earlier messages will be dropped until the prompt is within the limit.
    this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1024;
    this.maxPromptTokens =
      this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
      throw new Error(
        `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
          this.maxPromptTokens + this.maxResponseTokens
        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
      );
    }

    this.userLabel = this.options.userLabel || 'User';
    this.modelLabel = this.options.modelLabel || 'Assistant';

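    // Tokenizer selection: PaLM token counts are approximated with OpenAI tiktoken encoders
    // (cl100k_base for chat models, text-davinci-003 otherwise) until Google tokenization is wired in.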
    if (isChatModel) {
      // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
      // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
      // without tripping the stop sequences, so I'm using "||>" instead.
      this.startToken = '||>';
      this.endToken = '';
      this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
    } else if (isTextModel) {
      this.startToken = '<|im_start|>';
      this.endToken = '<|im_end|>';
      this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
        '<|im_start|>': 100264,
        '<|im_end|>': 100265,
      });
    } else {
      // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
      // system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
      // as a single token. So we're using this instead.
      this.startToken = '||>';
      this.endToken = '';
      try {
        this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
      } catch {
        this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
      }
    }

    if (!this.modelOptions.stop) {
      const stopTokens = [this.startToken];
      if (this.endToken && this.endToken !== this.startToken) {
        stopTokens.push(this.endToken);
      }
      stopTokens.push(`\n${this.userLabel}:`);
      stopTokens.push('<|diff_marker|>');
      // I chose not to do one for `modelLabel` because I've never seen it happen
      this.modelOptions.stop = stopTokens;
    }

    if (this.options.reverseProxyUrl) {
      this.completionsUrl = this.options.reverseProxyUrl;
    } else {
      this.completionsUrl = this.constructUrl();
    }

    return this;
  }
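
  // Maps a stored message to the { author, content } shape expected by the Vertex AI chat payload.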
  getMessageMapMethod() {
    return ((message) => ({
      author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
      content: message?.content ?? message.text,
    })).bind(this);
  }

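  // Builds the Vertex AI request body: { instances: [{ context?, examples?, messages }], parameters }.
  // Text models instead get a single { prompt } instance built from the last message.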
  buildMessages(messages = []) {
    const formattedMessages = messages.map(this.getMessageMapMethod());
    let payload = {
      instances: [
        {
          messages: formattedMessages,
        },
      ],
      parameters: this.options.modelOptions,
    };

    if (this.options.promptPrefix) {
      payload.instances[0].context = this.options.promptPrefix;
    }

    if (this.options.examples.length > 0) {
      payload.instances[0].examples = this.options.examples;
    }

    /* TO-DO: text model needs more context since it can't process an array of messages */
    if (this.isTextModel) {
      payload.instances = [
        {
          prompt: messages[messages.length - 1].content,
        },
      ];
    }

    if (this.options.debug) {
      console.debug('GoogleClient buildMessages');
      console.dir(payload, { depth: null });
    }

    return { prompt: payload };
  }

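  // Sends the payload to the predict endpoint using the authorized JWT client.
  // Note: the undici agent/proxy/abort settings assembled in `opts` are never passed to
  // client.request, so proxying and abort signals currently have no effect on this call.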
  async getCompletion(payload, abortController = null) {
    if (!abortController) {
      abortController = new AbortController();
    }
    const { debug } = this.options;
    const url = this.completionsUrl;
    if (debug) {
      console.debug();
      console.debug(url);
      console.debug(this.modelOptions);
      console.debug();
    }
    const opts = {
      method: 'POST',
      agent: new Agent({
        bodyTimeout: 0,
        headersTimeout: 0,
      }),
      signal: abortController.signal,
    };

    if (this.options.proxy) {
      opts.agent = new ProxyAgent(this.options.proxy);
    }

    const client = await this.getClient();
    const res = await client.request({ url, method: 'POST', data: payload });
    console.dir(res.data, { depth: null });
    return res.data;
  }

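  // Options persisted with the conversation: prompt prefix, model label, and the model options.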
  getSaveOptions() {
    return {
      promptPrefix: this.options.promptPrefix,
      modelLabel: this.options.modelLabel,
      ...this.modelOptions,
    };
  }

  getBuildMessagesOptions() {
    // console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
  }

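  // Requests a completion, extracts the reply from predictions[0], and surfaces safety blocking.
  // Unless the response was blocked, the reply is replayed through generateTextStream for the UI.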
  async sendCompletion(payload, opts = {}) {
    console.log('GoogleClient: sendCompletion', payload, opts);
    let reply = '';
    let blocked = false;
    try {
      const result = await this.getCompletion(payload, opts.abortController);
      blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
      reply =
        result?.predictions?.[0]?.candidates?.[0]?.content ||
        result?.predictions?.[0]?.content ||
        '';
      if (blocked === true) {
        reply = `Google blocked a proper response to your message:\n${JSON.stringify(
          result.predictions[0].safetyAttributes,
        )}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
      }
      if (this.options.debug) {
        console.debug('result');
        console.debug(result);
      }
    } catch (err) {
      console.error(err);
    }

    if (!blocked) {
      await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
    }

    return reply.trim();
  }

  /* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
    if (tokenizersCache[encoding]) {
      return tokenizersCache[encoding];
    }
    let tokenizer;
    if (isModelName) {
      tokenizer = encodingForModel(encoding, extendSpecialTokens);
    } else {
      tokenizer = getEncoding(encoding, extendSpecialTokens);
    }
    tokenizersCache[encoding] = tokenizer;
    return tokenizer;
  }

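  // Counts tokens in `text` with the selected tiktoken encoder; 'all' presumably permits special tokens.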
  getTokenCount(text) {
    return this.gptEncoder.encode(text, 'all').length;
  }
}

module.exports = GoogleClient;
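
/*
 * Example usage (a minimal sketch; the credential values and message objects below are
 * hypothetical, and `examples` must be supplied because setOptions filters it unconditionally):
 *
 *   const GoogleClient = require('./GoogleClient');
 *   const client = new GoogleClient(
 *     { client_email, project_id, private_key }, // fields from a service-account JSON key
 *     { modelOptions: { model: 'chat-bison' }, examples: [] },
 *   );
 *   const { prompt } = client.buildMessages([{ isCreatedByUser: true, text: 'Hello!' }]);
 *   const reply = await client.sendCompletion(prompt, { onProgress: () => {} });
 */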