refactor(BaseClient, GoogleClient): make sendCompletion required, refactor Google to use Base sendMessage (#591)

This commit is contained in:
Danny Avila 2023-07-05 14:00:12 -04:00 committed by GitHub
parent 4e317c85fd
commit 77d5fb0c58
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 56 deletions

View file

@ -26,6 +26,10 @@ class BaseClient {
throw new Error("Method 'getCompletion' must be implemented.");
}
sendCompletion() {
throw new Error("Method 'sendCompletion' must be implemented.");
}
getSaveOptions() {
throw new Error('Subclasses must implement getSaveOptions');
}

View file

@ -18,10 +18,26 @@ class GoogleClient extends BaseClient {
this.setOptions(options);
}
/* Google/PaLM2 specific methods */
constructUrl() {
return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
}
async getClient() {
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
jwtClient.authorize((err) => {
if (err) {
console.log(err);
throw err;
}
});
return jwtClient;
}
/* Required Client methods */
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually
@ -124,25 +140,19 @@ class GoogleClient extends BaseClient {
return this;
}
async getClient() {
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
jwtClient.authorize((err) => {
if (err) {
console.log(err);
throw err;
}
});
return jwtClient;
getMessageMapMethod() {
return ((message) => ({
author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
content: message?.content ?? message.text
})).bind(this);
}
buildMessages(input, { messages = [] }) {
buildMessages(messages = []) {
const formattedMessages = messages.map(this.getMessageMapMethod());
let payload = {
instances: [
{
messages: [...messages, { author: this.userLabel, content: input }]
messages: formattedMessages,
}
],
parameters: this.options.modelOptions
@ -156,23 +166,24 @@ class GoogleClient extends BaseClient {
payload.instances[0].examples = this.options.examples;
}
/* TO-DO: text model needs more context since it can't process an array of messages */
if (this.isTextModel) {
payload.instances = [
{
prompt: input
prompt: messages[messages.length -1].content
}
];
}
if (this.options.debug) {
console.debug('buildMessages');
console.debug('GoogleClient buildMessages');
console.dir(payload, { depth: null });
}
return payload;
return { prompt: payload };
}
async getCompletion(input, messages = [], abortController = null) {
async getCompletion(payload, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
@ -198,19 +209,11 @@ class GoogleClient extends BaseClient {
}
const client = await this.getClient();
const payload = this.buildMessages(input, { messages });
const res = await client.request({ url, method: 'POST', data: payload });
console.dir(res.data, { depth: null });
return res.data;
}
getMessageMapMethod() {
return ((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
content: message?.content ?? message.text
})).bind(this);
}
getSaveOptions() {
return {
...this.modelOptions
@ -218,24 +221,15 @@ class GoogleClient extends BaseClient {
}
getBuildMessagesOptions() {
console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
// console.log('GoogleClient doesn\'t use getBuildMessagesOptions');
}
async sendMessage(message, opts = {}) {
console.log('GoogleClient: sendMessage', message, opts);
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
} = await this.handleStartMethods(message, opts);
await this.saveMessageToDatabase(userMessage, saveOptions, user);
async sendCompletion(payload, opts = {}) {
console.log('GoogleClient: sendcompletion', payload, opts);
let reply = '';
let blocked = false;
try {
const result = await this.getCompletion(message, this.currentMessages, opts.abortController);
const result = await this.getCompletion(payload, opts.abortController);
blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
reply =
result?.predictions?.[0]?.candidates?.[0]?.content ||
@ -254,29 +248,14 @@ class GoogleClient extends BaseClient {
console.error(err);
}
if (this.options.debug) {
console.debug('options');
console.debug(this.options);
}
if (!blocked) {
await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
}
const responseMessage = {
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
sender: this.sender,
text: reply,
error: blocked,
isCreatedByUser: false
};
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
return responseMessage;
return reply.trim();
}
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];