refactor(BaseClient, GoogleClient): make sendCompletion required, refactor Google to use Base sendMessage (#591)

This commit is contained in:
Danny Avila 2023-07-05 14:00:12 -04:00 committed by GitHub
parent 4e317c85fd
commit 77d5fb0c58
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 56 deletions

View file

@ -26,6 +26,10 @@ class BaseClient {
throw new Error("Method 'getCompletion' must be implemented."); throw new Error("Method 'getCompletion' must be implemented.");
} }
sendCompletion() {
throw new Error("Method 'sendCompletion' must be implemented.");
}
getSaveOptions() { getSaveOptions() {
throw new Error('Subclasses must implement getSaveOptions'); throw new Error('Subclasses must implement getSaveOptions');
} }

View file

@ -18,10 +18,26 @@ class GoogleClient extends BaseClient {
this.setOptions(options); this.setOptions(options);
} }
/* Google/PaLM2 specific methods */
constructUrl() { constructUrl() {
return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`; return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
} }
async getClient() {
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
jwtClient.authorize((err) => {
if (err) {
console.log(err);
throw err;
}
});
return jwtClient;
}
/* Required Client methods */
setOptions(options) { setOptions(options) {
if (this.options && !this.options.replaceOptions) { if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually // nested options aren't spread properly, so we need to do this manually
@ -124,25 +140,19 @@ class GoogleClient extends BaseClient {
return this; return this;
} }
async getClient() { getMessageMapMethod() {
const scopes = ['https://www.googleapis.com/auth/cloud-platform']; return ((message) => ({
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes); author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
content: message?.content ?? message.text
jwtClient.authorize((err) => { })).bind(this);
if (err) {
console.log(err);
throw err;
}
});
return jwtClient;
} }
buildMessages(input, { messages = [] }) { buildMessages(messages = []) {
const formattedMessages = messages.map(this.getMessageMapMethod());
let payload = { let payload = {
instances: [ instances: [
{ {
messages: [...messages, { author: this.userLabel, content: input }] messages: formattedMessages,
} }
], ],
parameters: this.options.modelOptions parameters: this.options.modelOptions
@ -156,23 +166,24 @@ class GoogleClient extends BaseClient {
payload.instances[0].examples = this.options.examples; payload.instances[0].examples = this.options.examples;
} }
/* TO-DO: text model needs more context since it can't process an array of messages */
if (this.isTextModel) { if (this.isTextModel) {
payload.instances = [ payload.instances = [
{ {
prompt: input prompt: messages[messages.length -1].content
} }
]; ];
} }
if (this.options.debug) { if (this.options.debug) {
console.debug('buildMessages'); console.debug('GoogleClient buildMessages');
console.dir(payload, { depth: null }); console.dir(payload, { depth: null });
} }
return payload; return { prompt: payload };
} }
async getCompletion(input, messages = [], abortController = null) { async getCompletion(payload, abortController = null) {
if (!abortController) { if (!abortController) {
abortController = new AbortController(); abortController = new AbortController();
} }
@ -198,19 +209,11 @@ class GoogleClient extends BaseClient {
} }
const client = await this.getClient(); const client = await this.getClient();
const payload = this.buildMessages(input, { messages });
const res = await client.request({ url, method: 'POST', data: payload }); const res = await client.request({ url, method: 'POST', data: payload });
console.dir(res.data, { depth: null }); console.dir(res.data, { depth: null });
return res.data; return res.data;
} }
getMessageMapMethod() {
return ((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
content: message?.content ?? message.text
})).bind(this);
}
getSaveOptions() { getSaveOptions() {
return { return {
...this.modelOptions ...this.modelOptions
@ -218,24 +221,15 @@ class GoogleClient extends BaseClient {
} }
getBuildMessagesOptions() { getBuildMessagesOptions() {
console.log('GoogleClient doesn\'t use getBuildMessagesOptions'); // console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
} }
async sendMessage(message, opts = {}) { async sendCompletion(payload, opts = {}) {
console.log('GoogleClient: sendMessage', message, opts); console.log('GoogleClient: sendcompletion', payload, opts);
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
} = await this.handleStartMethods(message, opts);
await this.saveMessageToDatabase(userMessage, saveOptions, user);
let reply = ''; let reply = '';
let blocked = false; let blocked = false;
try { try {
const result = await this.getCompletion(message, this.currentMessages, opts.abortController); const result = await this.getCompletion(payload, opts.abortController);
blocked = result?.predictions?.[0]?.safetyAttributes?.blocked; blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
reply = reply =
result?.predictions?.[0]?.candidates?.[0]?.content || result?.predictions?.[0]?.candidates?.[0]?.content ||
@ -254,29 +248,14 @@ class GoogleClient extends BaseClient {
console.error(err); console.error(err);
} }
if (this.options.debug) {
console.debug('options');
console.debug(this.options);
}
if (!blocked) { if (!blocked) {
await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 }); await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
} }
const responseMessage = { return reply.trim();
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
sender: this.sender,
text: reply,
error: blocked,
isCreatedByUser: false
};
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
return responseMessage;
} }
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) { if (tokenizersCache[encoding]) {
return tokenizersCache[encoding]; return tokenizersCache[encoding];