mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-18 17:30:16 +01:00
refactor(BaseClient, GoogleClient): make sendCompletion required, refactor Google to use Base sendMessage (#591)
This commit is contained in:
parent
4e317c85fd
commit
77d5fb0c58
2 changed files with 39 additions and 56 deletions
|
|
@ -26,6 +26,10 @@ class BaseClient {
|
||||||
throw new Error("Method 'getCompletion' must be implemented.");
|
throw new Error("Method 'getCompletion' must be implemented.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sendCompletion() {
|
||||||
|
throw new Error("Method 'sendCompletion' must be implemented.");
|
||||||
|
}
|
||||||
|
|
||||||
getSaveOptions() {
|
getSaveOptions() {
|
||||||
throw new Error('Subclasses must implement getSaveOptions');
|
throw new Error('Subclasses must implement getSaveOptions');
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,26 @@ class GoogleClient extends BaseClient {
|
||||||
this.setOptions(options);
|
this.setOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Google/PaLM2 specific methods */
|
||||||
constructUrl() {
|
constructUrl() {
|
||||||
return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
|
return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getClient() {
|
||||||
|
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
|
||||||
|
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
|
||||||
|
|
||||||
|
jwtClient.authorize((err) => {
|
||||||
|
if (err) {
|
||||||
|
console.log(err);
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return jwtClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Required Client methods */
|
||||||
setOptions(options) {
|
setOptions(options) {
|
||||||
if (this.options && !this.options.replaceOptions) {
|
if (this.options && !this.options.replaceOptions) {
|
||||||
// nested options aren't spread properly, so we need to do this manually
|
// nested options aren't spread properly, so we need to do this manually
|
||||||
|
|
@ -124,25 +140,19 @@ class GoogleClient extends BaseClient {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getClient() {
|
getMessageMapMethod() {
|
||||||
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
|
return ((message) => ({
|
||||||
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
|
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
|
||||||
|
content: message?.content ?? message.text
|
||||||
jwtClient.authorize((err) => {
|
})).bind(this);
|
||||||
if (err) {
|
|
||||||
console.log(err);
|
|
||||||
throw err;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return jwtClient;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
buildMessages(input, { messages = [] }) {
|
buildMessages(messages = []) {
|
||||||
|
const formattedMessages = messages.map(this.getMessageMapMethod());
|
||||||
let payload = {
|
let payload = {
|
||||||
instances: [
|
instances: [
|
||||||
{
|
{
|
||||||
messages: [...messages, { author: this.userLabel, content: input }]
|
messages: formattedMessages,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
parameters: this.options.modelOptions
|
parameters: this.options.modelOptions
|
||||||
|
|
@ -156,23 +166,24 @@ class GoogleClient extends BaseClient {
|
||||||
payload.instances[0].examples = this.options.examples;
|
payload.instances[0].examples = this.options.examples;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* TO-DO: text model needs more context since it can't process an array of messages */
|
||||||
if (this.isTextModel) {
|
if (this.isTextModel) {
|
||||||
payload.instances = [
|
payload.instances = [
|
||||||
{
|
{
|
||||||
prompt: input
|
prompt: messages[messages.length -1].content
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.options.debug) {
|
if (this.options.debug) {
|
||||||
console.debug('buildMessages');
|
console.debug('GoogleClient buildMessages');
|
||||||
console.dir(payload, { depth: null });
|
console.dir(payload, { depth: null });
|
||||||
}
|
}
|
||||||
|
|
||||||
return payload;
|
return { prompt: payload };
|
||||||
}
|
}
|
||||||
|
|
||||||
async getCompletion(input, messages = [], abortController = null) {
|
async getCompletion(payload, abortController = null) {
|
||||||
if (!abortController) {
|
if (!abortController) {
|
||||||
abortController = new AbortController();
|
abortController = new AbortController();
|
||||||
}
|
}
|
||||||
|
|
@ -198,19 +209,11 @@ class GoogleClient extends BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = await this.getClient();
|
const client = await this.getClient();
|
||||||
const payload = this.buildMessages(input, { messages });
|
|
||||||
const res = await client.request({ url, method: 'POST', data: payload });
|
const res = await client.request({ url, method: 'POST', data: payload });
|
||||||
console.dir(res.data, { depth: null });
|
console.dir(res.data, { depth: null });
|
||||||
return res.data;
|
return res.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
getMessageMapMethod() {
|
|
||||||
return ((message) => ({
|
|
||||||
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
|
|
||||||
content: message?.content ?? message.text
|
|
||||||
})).bind(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
getSaveOptions() {
|
getSaveOptions() {
|
||||||
return {
|
return {
|
||||||
...this.modelOptions
|
...this.modelOptions
|
||||||
|
|
@ -218,24 +221,15 @@ class GoogleClient extends BaseClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
getBuildMessagesOptions() {
|
getBuildMessagesOptions() {
|
||||||
console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
|
// console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
|
||||||
}
|
}
|
||||||
|
|
||||||
async sendMessage(message, opts = {}) {
|
async sendCompletion(payload, opts = {}) {
|
||||||
console.log('GoogleClient: sendMessage', message, opts);
|
console.log('GoogleClient: sendcompletion', payload, opts);
|
||||||
const {
|
|
||||||
user,
|
|
||||||
conversationId,
|
|
||||||
responseMessageId,
|
|
||||||
saveOptions,
|
|
||||||
userMessage,
|
|
||||||
} = await this.handleStartMethods(message, opts);
|
|
||||||
|
|
||||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
|
||||||
let reply = '';
|
let reply = '';
|
||||||
let blocked = false;
|
let blocked = false;
|
||||||
try {
|
try {
|
||||||
const result = await this.getCompletion(message, this.currentMessages, opts.abortController);
|
const result = await this.getCompletion(payload, opts.abortController);
|
||||||
blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
|
blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
|
||||||
reply =
|
reply =
|
||||||
result?.predictions?.[0]?.candidates?.[0]?.content ||
|
result?.predictions?.[0]?.candidates?.[0]?.content ||
|
||||||
|
|
@ -254,29 +248,14 @@ class GoogleClient extends BaseClient {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.options.debug) {
|
|
||||||
console.debug('options');
|
|
||||||
console.debug(this.options);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!blocked) {
|
if (!blocked) {
|
||||||
await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
|
await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
|
||||||
}
|
}
|
||||||
|
|
||||||
const responseMessage = {
|
return reply.trim();
|
||||||
messageId: responseMessageId,
|
|
||||||
conversationId,
|
|
||||||
parentMessageId: userMessage.messageId,
|
|
||||||
sender: this.sender,
|
|
||||||
text: reply,
|
|
||||||
error: blocked,
|
|
||||||
isCreatedByUser: false
|
|
||||||
};
|
|
||||||
|
|
||||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
|
||||||
return responseMessage;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
|
||||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||||
if (tokenizersCache[encoding]) {
|
if (tokenizersCache[encoding]) {
|
||||||
return tokenizersCache[encoding];
|
return tokenizersCache[encoding];
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue