mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-09-22 08:12:00 +02:00
refactor(BaseClient, GoogleClient): make sendCompletion required, refactor Google to use Base sendMessage (#591)
This commit is contained in:
parent
4e317c85fd
commit
77d5fb0c58
2 changed files with 39 additions and 56 deletions
|
@ -26,6 +26,10 @@ class BaseClient {
|
|||
throw new Error("Method 'getCompletion' must be implemented.");
|
||||
}
|
||||
|
||||
sendCompletion() {
|
||||
throw new Error("Method 'sendCompletion' must be implemented.");
|
||||
}
|
||||
|
||||
getSaveOptions() {
|
||||
throw new Error('Subclasses must implement getSaveOptions');
|
||||
}
|
||||
|
|
|
@ -18,10 +18,26 @@ class GoogleClient extends BaseClient {
|
|||
this.setOptions(options);
|
||||
}
|
||||
|
||||
/* Google/PaLM2 specific methods */
|
||||
constructUrl() {
|
||||
return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
|
||||
}
|
||||
|
||||
async getClient() {
|
||||
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
|
||||
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
|
||||
|
||||
jwtClient.authorize((err) => {
|
||||
if (err) {
|
||||
console.log(err);
|
||||
throw err;
|
||||
}
|
||||
});
|
||||
|
||||
return jwtClient;
|
||||
}
|
||||
|
||||
/* Required Client methods */
|
||||
setOptions(options) {
|
||||
if (this.options && !this.options.replaceOptions) {
|
||||
// nested options aren't spread properly, so we need to do this manually
|
||||
|
@ -124,25 +140,19 @@ class GoogleClient extends BaseClient {
|
|||
return this;
|
||||
}
|
||||
|
||||
async getClient() {
|
||||
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
|
||||
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
|
||||
|
||||
jwtClient.authorize((err) => {
|
||||
if (err) {
|
||||
console.log(err);
|
||||
throw err;
|
||||
}
|
||||
});
|
||||
|
||||
return jwtClient;
|
||||
getMessageMapMethod() {
|
||||
return ((message) => ({
|
||||
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
|
||||
content: message?.content ?? message.text
|
||||
})).bind(this);
|
||||
}
|
||||
|
||||
buildMessages(input, { messages = [] }) {
|
||||
buildMessages(messages = []) {
|
||||
const formattedMessages = messages.map(this.getMessageMapMethod());
|
||||
let payload = {
|
||||
instances: [
|
||||
{
|
||||
messages: [...messages, { author: this.userLabel, content: input }]
|
||||
messages: formattedMessages,
|
||||
}
|
||||
],
|
||||
parameters: this.options.modelOptions
|
||||
|
@ -156,23 +166,24 @@ class GoogleClient extends BaseClient {
|
|||
payload.instances[0].examples = this.options.examples;
|
||||
}
|
||||
|
||||
/* TO-DO: text model needs more context since it can't process an array of messages */
|
||||
if (this.isTextModel) {
|
||||
payload.instances = [
|
||||
{
|
||||
prompt: input
|
||||
prompt: messages[messages.length -1].content
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('buildMessages');
|
||||
console.debug('GoogleClient buildMessages');
|
||||
console.dir(payload, { depth: null });
|
||||
}
|
||||
|
||||
return payload;
|
||||
return { prompt: payload };
|
||||
}
|
||||
|
||||
async getCompletion(input, messages = [], abortController = null) {
|
||||
async getCompletion(payload, abortController = null) {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
|
@ -198,19 +209,11 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
|
||||
const client = await this.getClient();
|
||||
const payload = this.buildMessages(input, { messages });
|
||||
const res = await client.request({ url, method: 'POST', data: payload });
|
||||
console.dir(res.data, { depth: null });
|
||||
return res.data;
|
||||
}
|
||||
|
||||
getMessageMapMethod() {
|
||||
return ((message) => ({
|
||||
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
|
||||
content: message?.content ?? message.text
|
||||
})).bind(this);
|
||||
}
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
...this.modelOptions
|
||||
|
@ -218,24 +221,15 @@ class GoogleClient extends BaseClient {
|
|||
}
|
||||
|
||||
getBuildMessagesOptions() {
|
||||
console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
|
||||
// console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
console.log('GoogleClient: sendMessage', message, opts);
|
||||
const {
|
||||
user,
|
||||
conversationId,
|
||||
responseMessageId,
|
||||
saveOptions,
|
||||
userMessage,
|
||||
} = await this.handleStartMethods(message, opts);
|
||||
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
async sendCompletion(payload, opts = {}) {
|
||||
console.log('GoogleClient: sendcompletion', payload, opts);
|
||||
let reply = '';
|
||||
let blocked = false;
|
||||
try {
|
||||
const result = await this.getCompletion(message, this.currentMessages, opts.abortController);
|
||||
const result = await this.getCompletion(payload, opts.abortController);
|
||||
blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
|
||||
reply =
|
||||
result?.predictions?.[0]?.candidates?.[0]?.content ||
|
||||
|
@ -254,29 +248,14 @@ class GoogleClient extends BaseClient {
|
|||
console.error(err);
|
||||
}
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('options');
|
||||
console.debug(this.options);
|
||||
}
|
||||
|
||||
if (!blocked) {
|
||||
await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
|
||||
}
|
||||
|
||||
const responseMessage = {
|
||||
messageId: responseMessageId,
|
||||
conversationId,
|
||||
parentMessageId: userMessage.messageId,
|
||||
sender: this.sender,
|
||||
text: reply,
|
||||
error: blocked,
|
||||
isCreatedByUser: false
|
||||
};
|
||||
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
return responseMessage;
|
||||
return reply.trim();
|
||||
}
|
||||
|
||||
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
|
||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||
if (tokenizersCache[encoding]) {
|
||||
return tokenizersCache[encoding];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue