Mirror of https://github.com/danny-avila/LibreChat.git
feat(Google): Support all Text/Chat Models, Response streaming, PaLM -> Google 🤖 (#1316)
* feat: update PaLM icons
* feat: add additional google models
* POC: formatting inputs for Vertex AI streaming
* refactor: move endpoints services outside of /routes dir to /services/Endpoints
* refactor: shorten schemas import
* refactor: rename PALM to GOOGLE
* feat: make Google editable endpoint
* feat: reusable Ask and Edit controllers based off Anthropic
* chore: organize imports/logic
* fix(parseConvo): include examples in googleSchema
* fix: google only allows odd number of messages to be sent
* fix: pass proxy to AnthropicClient
* refactor: change `google` altName to `Google`
* refactor: update getModelMaxTokens and related functions to handle maxTokensMap with nested endpoint model key/values
* refactor: google Icon and response sender changes (Codey and Google logo instead of PaLM in all cases)
* feat: google support for maxTokensMap
* feat: google updated endpoints with Ask/Edit controllers, buildOptions, and initializeClient
* feat(GoogleClient): now builds prompt for text models and supports real streaming from Vertex AI through langchain
* chore(GoogleClient): remove comments, left before for reference in git history
* docs: update google instructions (WIP)
* docs(apis_and_tokens.md): add images to google instructions
* docs: remove typo in apis_and_tokens.md
* Update apis_and_tokens.md
* feat(Google): use default settings map, fully support context for both text and chat models, fully support examples for chat models
* chore: update more PaLM references to Google
* chore: move playwright out of workflows to avoid failing tests
Parent: 8a1968b2f8
Commit: 583e978a82
90 changed files with 1613 additions and 784 deletions
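The getModelMaxTokens refactor called out above (nested endpoint model key/values in maxTokensMap) is what the AnthropicClient and GoogleClient hunks below rely on when they pass an EModelEndpoint value as a second argument. A minimal sketch of that lookup, with a hypothetical map shape and illustrative values (the real map lives in ~/utils):

    // Sketch only, not the LibreChat implementation. maxTokensMap is keyed
    // by endpoint first, then by model name; values here are illustrative.
    const EModelEndpoint = { anthropic: 'anthropic', google: 'google' };

    const maxTokensMap = {
      [EModelEndpoint.anthropic]: { 'claude-2': 100000 },
      [EModelEndpoint.google]: { 'chat-bison': 4096, 'text-bison': 8000 },
    };

    function getModelMaxTokens(modelName, endpoint = 'openAI') {
      const endpointTokens = maxTokensMap[endpoint];
      if (!endpointTokens) {
        return undefined;
      }
      // Callers apply their own fallback, e.g. `getModelMaxTokens(model, endpoint) ?? 100000`.
      return endpointTokens[modelName];
    }

    console.log(getModelMaxTokens('chat-bison', EModelEndpoint.google)); // 4096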
@@ -1,6 +1,6 @@
 const Anthropic = require('@anthropic-ai/sdk');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
-const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
+const { getResponseSender, EModelEndpoint } = require('~/server/services/Endpoints');
 const { getModelMaxTokens } = require('~/utils');
 const BaseClient = require('./BaseClient');

@@ -46,7 +46,8 @@ class AnthropicClient extends BaseClient {
       stop: modelOptions.stop, // no stop method for now
     };

-    this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 100000;
+    this.maxContextTokens =
+      getModelMaxTokens(this.modelOptions.model, EModelEndpoint.anthropic) ?? 100000;
     this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500;
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

@@ -445,6 +445,7 @@ class BaseClient {
         amount: promptTokens,
         debug: this.options.debug,
         model: this.modelOptions.model,
+        endpoint: this.options.endpoint,
       },
     });
   }

@@ -1,23 +1,43 @@
-const BaseClient = require('./BaseClient');
 const { google } = require('googleapis');
 const { Agent, ProxyAgent } = require('undici');
+const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
+const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
+const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const {
+  getResponseSender,
+  EModelEndpoint,
+  endpointSettings,
+} = require('~/server/services/Endpoints');
+const { getModelMaxTokens } = require('~/utils');
+const { formatMessage } = require('./prompts');
+const BaseClient = require('./BaseClient');

+const loc = 'us-central1';
+const publisher = 'google';
+const endpointPrefix = `https://${loc}-aiplatform.googleapis.com`;
+// const apiEndpoint = loc + '-aiplatform.googleapis.com';
 const tokenizersCache = {};

+const settings = endpointSettings[EModelEndpoint.google];

 class GoogleClient extends BaseClient {
   constructor(credentials, options = {}) {
     super('apiKey', options);
     this.credentials = credentials;
     this.client_email = credentials.client_email;
     this.project_id = credentials.project_id;
     this.private_key = credentials.private_key;
-    this.sender = 'PaLM2';
     this.access_token = null;
     if (options.skipSetOptions) {
       return;
     }
     this.setOptions(options);
   }

-  /* Google/PaLM2 specific methods */
+  /* Google specific methods */
   constructUrl() {
-    return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
+    return `${endpointPrefix}/v1/projects/${this.project_id}/locations/${loc}/publishers/${publisher}/models/${this.modelOptions.model}:serverStreamingPredict`;
   }

   async getClient() {

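Worth noting: constructUrl() now targets the :serverStreamingPredict route instead of :predict, matching the streaming support added further down. Evaluating the template with a hypothetical project id shows the resulting URL:

    // Worked example of the URL template above; project id and model are hypothetical.
    const loc = 'us-central1';
    const publisher = 'google';
    const endpointPrefix = `https://${loc}-aiplatform.googleapis.com`;
    const projectId = 'my-project';
    const model = 'chat-bison';

    console.log(
      `${endpointPrefix}/v1/projects/${projectId}/locations/${loc}/publishers/${publisher}/models/${model}:serverStreamingPredict`,
    );
    // https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/chat-bison:serverStreamingPredict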
@@ -35,6 +55,24 @@ class GoogleClient extends BaseClient {
     return jwtClient;
   }

+  async getAccessToken() {
+    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
+    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
+
+    return new Promise((resolve, reject) => {
+      jwtClient.authorize((err, tokens) => {
+        if (err) {
+          console.error('Error: jwtClient failed to authorize');
+          console.error(err.message);
+          reject(err);
+        } else {
+          console.log('Access Token:', tokens.access_token);
+          resolve(tokens.access_token);
+        }
+      });
+    });
+  }
+
   /* Required Client methods */
   setOptions(options) {
     if (this.options && !this.options.replaceOptions) {

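A usage sketch for the new method, assuming a standard service-account credentials object; skipSetOptions (from the constructor above) lets you mint a token without configuring model options:

    // Sketch under assumed credentials; all values are placeholders.
    const GoogleClient = require('./GoogleClient');

    const credentials = {
      client_email: 'svc-account@my-project.iam.gserviceaccount.com', // hypothetical
      project_id: 'my-project', // hypothetical
      private_key: process.env.GOOGLE_PRIVATE_KEY, // assumed to hold a PEM key
    };

    const client = new GoogleClient(credentials, { skipSetOptions: true });
    client
      .getAccessToken()
      .then((token) => console.log('received token?', Boolean(token)))
      .catch((err) => console.error('authorization failed:', err.message));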
@@ -53,30 +91,33 @@ class GoogleClient extends BaseClient {
       this.options = options;
     }

-    this.options.examples = this.options.examples.filter(
-      (obj) => obj.input.content !== '' && obj.output.content !== '',
-    );
+    this.options.examples = this.options.examples
+      .filter((ex) => ex)
+      .filter((obj) => obj.input.content !== '' && obj.output.content !== '');

     const modelOptions = this.options.modelOptions || {};
     this.modelOptions = {
       ...modelOptions,
       // set some good defaults (check for undefined in some cases because they may be 0)
-      model: modelOptions.model || 'chat-bison',
-      temperature: typeof modelOptions.temperature === 'undefined' ? 0.2 : modelOptions.temperature, // 0 - 1, 0.2 is recommended
-      topP: typeof modelOptions.topP === 'undefined' ? 0.95 : modelOptions.topP, // 0 - 1, default: 0.95
-      topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
+      model: modelOptions.model || settings.model.default,
+      temperature:
+        typeof modelOptions.temperature === 'undefined'
+          ? settings.temperature.default
+          : modelOptions.temperature,
+      topP: typeof modelOptions.topP === 'undefined' ? settings.topP.default : modelOptions.topP,
+      topK: typeof modelOptions.topK === 'undefined' ? settings.topK.default : modelOptions.topK,
       // stop: modelOptions.stop // no stop method for now
     };

-    this.isChatModel = this.modelOptions.model.startsWith('chat-');
+    this.isChatModel = this.modelOptions.model.includes('chat');
     const { isChatModel } = this;
-    this.isTextModel = this.modelOptions.model.startsWith('text-');
+    this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
     const { isTextModel } = this;

-    this.maxContextTokens = this.options.maxContextTokens || (isTextModel ? 8000 : 4096);
+    this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
     // The max prompt tokens is determined by the max context tokens minus the max response tokens.
     // Earlier messages will be dropped until the prompt is within the limit.
-    this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1024;
+    this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

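The inline defaults removed in this hunk (chat-bison, 0.2, 0.95, 40, and 1024 for maxOutputTokens) move into the shared endpointSettings map. A hypothetical sketch of the google entry, consistent with those former defaults (the real definition lives in ~/server/services/Endpoints):

    // Hypothetical shape, inferred from the defaults this hunk removes;
    // not the actual endpointSettings definition.
    const endpointSettings = {
      google: {
        model: { default: 'chat-bison' },
        temperature: { min: 0, max: 1, default: 0.2 }, // 0.2 is recommended
        topP: { min: 0, max: 1, default: 0.95 },
        topK: { min: 1, max: 40, default: 40 },
        maxOutputTokens: { default: 1024 },
      },
    };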
@@ -88,6 +129,14 @@ class GoogleClient extends BaseClient {
       );
     }

+    this.sender =
+      this.options.sender ??
+      getResponseSender({
+        model: this.modelOptions.model,
+        endpoint: EModelEndpoint.google,
+        modelLabel: this.options.modelLabel,
+      });
+
     this.userLabel = this.options.userLabel || 'User';
     this.modelLabel = this.options.modelLabel || 'Assistant';

@@ -99,8 +148,8 @@ class GoogleClient extends BaseClient {
       this.endToken = '';
       this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
     } else if (isTextModel) {
-      this.startToken = '<|im_start|>';
-      this.endToken = '<|im_end|>';
+      this.startToken = '||>';
+      this.endToken = '';
       this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
         '<|im_start|>': 100264,
         '<|im_end|>': 100265,

@@ -138,15 +187,18 @@ class GoogleClient extends BaseClient {
     return this;
   }

-  getMessageMapMethod() {
+  formatMessages() {
     return ((message) => ({
       author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
       content: message?.content ?? message.text,
     })).bind(this);
   }

-  buildMessages(messages = []) {
-    const formattedMessages = messages.map(this.getMessageMapMethod());
+  buildMessages(messages = [], parentMessageId) {
+    if (this.isTextModel) {
+      return this.buildMessagesPrompt(messages, parentMessageId);
+    }
+    const formattedMessages = messages.map(this.formatMessages());
     let payload = {
       instances: [
         {

@@ -164,15 +216,6 @@ class GoogleClient extends BaseClient {
       payload.instances[0].examples = this.options.examples;
     }

-    /* TO-DO: text model needs more context since it can't process an array of messages */
-    if (this.isTextModel) {
-      payload.instances = [
-        {
-          prompt: messages[messages.length - 1].content,
-        },
-      ];
-    }
-
     if (this.options.debug) {
       console.debug('GoogleClient buildMessages');
       console.dir(payload, { depth: null });

@@ -181,7 +224,157 @@ class GoogleClient extends BaseClient {
     return { prompt: payload };
   }

-  async getCompletion(payload, abortController = null) {
+  async buildMessagesPrompt(messages, parentMessageId) {
+    const orderedMessages = this.constructor.getMessagesForConversation({
+      messages,
+      parentMessageId,
+    });
+    if (this.options.debug) {
+      console.debug('GoogleClient: orderedMessages', orderedMessages, parentMessageId);
+    }
+
+    const formattedMessages = orderedMessages.map((message) => ({
+      author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
+      content: message?.content ?? message.text,
+    }));
+
+    let lastAuthor = '';
+    let groupedMessages = [];
+
+    for (let message of formattedMessages) {
+      // If last author is not same as current author, add to new group
+      if (lastAuthor !== message.author) {
+        groupedMessages.push({
+          author: message.author,
+          content: [message.content],
+        });
+        lastAuthor = message.author;
+        // If same author, append content to the last group
+      } else {
+        groupedMessages[groupedMessages.length - 1].content.push(message.content);
+      }
+    }
+
+    let identityPrefix = '';
+    if (this.options.userLabel) {
+      identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
+    }
+
+    if (this.options.modelLabel) {
+      identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`;
+    }
+
+    let promptPrefix = (this.options.promptPrefix || '').trim();
+    if (promptPrefix) {
+      // If the prompt prefix doesn't end with the end token, add it.
+      if (!promptPrefix.endsWith(`${this.endToken}`)) {
+        promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
+      }
+      promptPrefix = `\nContext:\n${promptPrefix}`;
+    }
+
+    if (identityPrefix) {
+      promptPrefix = `${identityPrefix}${promptPrefix}`;
+    }
+
+    // Prompt AI to respond, empty if last message was from AI
+    let isEdited = lastAuthor === this.modelLabel;
+    const promptSuffix = isEdited ? '' : `${promptPrefix}\n\n${this.modelLabel}:\n`;
+    let currentTokenCount = isEdited
+      ? this.getTokenCount(promptPrefix)
+      : this.getTokenCount(promptSuffix);
+
+    let promptBody = '';
+    const maxTokenCount = this.maxPromptTokens;
+
+    const context = [];
+
+    // Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
+    // Do this within a recursive async function so that it doesn't block the event loop for too long.
+    // Also, remove the next message when the message that puts us over the token limit is created by the user.
+    // Otherwise, remove only the exceeding message. This is due to Anthropic's strict payload rule to start with "Human:".
+    const nextMessage = {
+      remove: false,
+      tokenCount: 0,
+      messageString: '',
+    };
+
+    const buildPromptBody = async () => {
+      if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
+        const message = groupedMessages.pop();
+        const isCreatedByUser = message.author === this.userLabel;
+        // Use promptPrefix if message is an edited assistant message
+        const messagePrefix =
+          isCreatedByUser || !isEdited
+            ? `\n\n${message.author}:`
+            : `${promptPrefix}\n\n${message.author}:`;
+        const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`;
+        let newPromptBody = `${messageString}${promptBody}`;
+
+        context.unshift(message);
+
+        const tokenCountForMessage = this.getTokenCount(messageString);
+        const newTokenCount = currentTokenCount + tokenCountForMessage;
+
+        if (!isCreatedByUser) {
+          nextMessage.messageString = messageString;
+          nextMessage.tokenCount = tokenCountForMessage;
+        }
+
+        if (newTokenCount > maxTokenCount) {
+          if (!promptBody) {
+            // This is the first message, so we can't add it. Just throw an error.
+            throw new Error(
+              `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
+            );
+          }
+
+          // Otherwise, this message would put us over the token limit, so don't add it.
+          // If created by user, remove next message, otherwise remove only this message.
+          if (isCreatedByUser) {
+            nextMessage.remove = true;
+          }
+
+          return false;
+        }
+        promptBody = newPromptBody;
+        currentTokenCount = newTokenCount;
+
+        // Switch off isEdited after using it for the first time
+        if (isEdited) {
+          isEdited = false;
+        }
+
+        // Wait for next tick to avoid blocking the event loop
+        await new Promise((resolve) => setImmediate(resolve));
+        return buildPromptBody();
+      }
+      return true;
+    };
+
+    await buildPromptBody();
+
+    if (nextMessage.remove) {
+      promptBody = promptBody.replace(nextMessage.messageString, '');
+      currentTokenCount -= nextMessage.tokenCount;
+      context.shift();
+    }
+
+    let prompt = `${promptBody}${promptSuffix}`;
+
+    // Add 2 tokens for metadata after all messages have been counted.
+    currentTokenCount += 2;
+
+    // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
+    this.modelOptions.maxOutputTokens = Math.min(
+      this.maxContextTokens - currentTokenCount,
+      this.maxResponseTokens,
+    );
+
+    return { prompt, context };
+  }
+
+  async _getCompletion(payload, abortController = null) {
     if (!abortController) {
       abortController = new AbortController();
     }

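To make the transcript format concrete: with hypothetical userLabel 'Dan' and modelLabel 'Codey' options and a prompt prefix set, buildMessagesPrompt assembles roughly the following string — message body first, then the identity/context prefix, then the trailing cue for the model to respond (the end token is empty for Google text models, and exact whitespace depends on the options):

    Dan:
    hi

    Codey:
    Hello! How can I help?

    Dan:
    what's the weather like?
    Human's name: Dan
    You are Codey
    Context:
    <your prompt prefix>

    Codey: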
@@ -212,6 +405,72 @@ class GoogleClient extends BaseClient {
     return res.data;
   }

+  async getCompletion(_payload, options = {}) {
+    const { onProgress, abortController } = options;
+    const { parameters, instances } = _payload;
+    const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
+
+    let examples;
+
+    let clientOptions = {
+      authOptions: {
+        credentials: {
+          ...this.credentials,
+        },
+        projectId: this.project_id,
+      },
+      ...parameters,
+    };
+
+    if (!parameters) {
+      clientOptions = { ...clientOptions, ...this.modelOptions };
+    }
+
+    if (_examples && _examples.length) {
+      examples = _examples
+        .map((ex) => {
+          const { input, output } = ex;
+          if (!input || !output) {
+            return undefined;
+          }
+          return {
+            input: new HumanMessage(input.content),
+            output: new AIMessage(output.content),
+          };
+        })
+        .filter((ex) => ex);
+
+      clientOptions.examples = examples;
+    }
+
+    const model = this.isTextModel
+      ? new GoogleVertexAI(clientOptions)
+      : new ChatGoogleVertexAI(clientOptions);
+
+    let reply = '';
+    const messages = this.isTextModel
+      ? _payload.trim()
+      : _messages
+        .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' }))
+        .map((message) => formatMessage({ message, langChain: true }));
+
+    if (context && messages?.length > 0) {
+      messages.unshift(new SystemMessage(context));
+    }
+
+    const stream = await model.stream(messages, {
+      signal: abortController.signal,
+      timeout: 7000,
+    });
+
+    for await (const chunk of stream) {
+      await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
+      reply += chunk?.content ?? chunk;
+    }
+
+    return reply;
+  }
+
   getSaveOptions() {
     return {
       promptPrefix: this.options.promptPrefix,

@@ -225,34 +484,18 @@ class GoogleClient extends BaseClient {
   }

   async sendCompletion(payload, opts = {}) {
-    console.log('GoogleClient: sendcompletion', payload, opts);
     let reply = '';
-    let blocked = false;
     try {
-      const result = await this.getCompletion(payload, opts.abortController);
-      blocked = result?.predictions?.[0]?.safetyAttributes?.blocked;
-      reply =
-        result?.predictions?.[0]?.candidates?.[0]?.content ||
-        result?.predictions?.[0]?.content ||
-        '';
-      if (blocked === true) {
-        reply = `Google blocked a proper response to your message:\n${JSON.stringify(
-          result.predictions[0].safetyAttributes,
-        )}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
-      }
+      reply = await this.getCompletion(payload, opts);
       if (this.options.debug) {
         console.debug('result');
-        console.debug(result);
+        console.debug(reply);
       }
     } catch (err) {
       console.error('Error: failed to send completion to Google');
-      console.error(err);
       console.error(err.message);
     }

-    if (!blocked) {
-      await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
-    }
-
     return reply.trim();
   }

@@ -1,10 +1,10 @@
 const OpenAI = require('openai');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
+const { getResponseSender, EModelEndpoint } = require('~/server/services/Endpoints');
 const { encodeAndFormat, validateVisionModel } = require('~/server/services/Files/images');
 const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('~/utils');
 const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
-const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
 const { handleOpenAIErrors } = require('./tools/util');
 const spendTokens = require('~/models/spendTokens');
 const { createLLM, RunManager } = require('./llm');

@@ -3,11 +3,12 @@ const { CallbackManager } = require('langchain/callbacks');
 const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
 const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
 const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
-const checkBalance = require('../../models/checkBalance');
+const { EModelEndpoint } = require('~/server/services/Endpoints');
 const { formatLangChainMessages } = require('./prompts');
-const { isEnabled } = require('../../server/utils');
-const { extractBaseURL } = require('../../utils');
+const checkBalance = require('~/models/checkBalance');
 const { SelfReflectionTool } = require('./tools');
+const { isEnabled } = require('~/server/utils');
+const { extractBaseURL } = require('~/utils');
 const { loadTools } = require('./tools/util');

 class PluginsClient extends OpenAIClient {

@@ -304,6 +305,7 @@ class PluginsClient extends OpenAIClient {
         amount: promptTokens,
         debug: this.options.debug,
         model: this.modelOptions.model,
+        endpoint: EModelEndpoint.openAI,
       },
     });
   }

@@ -1,7 +1,8 @@
 const { promptTokensEstimate } = require('openai-chat-tokens');
-const checkBalance = require('../../../models/checkBalance');
-const { isEnabled } = require('../../../server/utils');
-const { formatFromLangChain } = require('../prompts');
+const { EModelEndpoint } = require('~/server/services/Endpoints');
+const { formatFromLangChain } = require('~/app/clients/prompts');
+const checkBalance = require('~/models/checkBalance');
+const { isEnabled } = require('~/server/utils');

 const createStartHandler = ({
   context,

@@ -55,6 +56,7 @@ const createStartHandler = ({
         debug: manager.debug,
         generations,
         model,
+        endpoint: EModelEndpoint.openAI,
       },
     });
   }

api/app/clients/prompts/formatGoogleInputs.js (new file, +42 lines)

@@ -0,0 +1,42 @@
+/**
+ * Formats an object to match the struct_val, list_val, string_val, float_val, and int_val format.
+ *
+ * @param {Object} obj - The object to be formatted.
+ * @returns {Object} The formatted object.
+ *
+ * Handles different types:
+ * - Arrays are wrapped in list_val and each element is processed.
+ * - Objects are recursively processed.
+ * - Strings are wrapped in string_val.
+ * - Numbers are wrapped in float_val or int_val depending on whether they are floating-point or integers.
+ */
+function formatGoogleInputs(obj) {
+  const formattedObj = {};
+
+  for (const key in obj) {
+    if (Object.prototype.hasOwnProperty.call(obj, key)) {
+      const value = obj[key];
+
+      // Handle arrays
+      if (Array.isArray(value)) {
+        formattedObj[key] = { list_val: value.map((item) => formatGoogleInputs(item)) };
+      }
+      // Handle objects
+      else if (typeof value === 'object' && value !== null) {
+        formattedObj[key] = formatGoogleInputs(value);
+      }
+      // Handle numbers
+      else if (typeof value === 'number') {
+        formattedObj[key] = Number.isInteger(value) ? { int_val: value } : { float_val: value };
+      }
+      // Handle other types (e.g., strings)
+      else {
+        formattedObj[key] = { string_val: [value] };
+      }
+    }
+  }
+
+  return { struct_val: formattedObj };
+}
+
+module.exports = formatGoogleInputs;

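A quick usage example of the helper (the spec file added next exercises fuller payloads):

    const formatGoogleInputs = require('./formatGoogleInputs');

    // Integers become int_val, floats become float_val, strings are wrapped
    // in single-element string_val arrays, and the whole object in struct_val.
    const result = formatGoogleInputs({ temperature: 0.2, topK: 40, context: 'context' });
    console.log(JSON.stringify(result, null, 2));
    // {
    //   "struct_val": {
    //     "temperature": { "float_val": 0.2 },
    //     "topK": { "int_val": 40 },
    //     "context": { "string_val": ["context"] }
    //   }
    // }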
api/app/clients/prompts/formatGoogleInputs.spec.js (new file, +274 lines)

@@ -0,0 +1,274 @@
+const formatGoogleInputs = require('./formatGoogleInputs');
+
+describe('formatGoogleInputs', () => {
+  it('formats message correctly', () => {
+    const input = {
+      messages: [
+        {
+          content: 'hi',
+          author: 'user',
+        },
+      ],
+      context: 'context',
+      examples: [
+        {
+          input: {
+            author: 'user',
+            content: 'user input',
+          },
+          output: {
+            author: 'bot',
+            content: 'bot output',
+          },
+        },
+      ],
+      parameters: {
+        temperature: 0.2,
+        topP: 0.8,
+        topK: 40,
+        maxOutputTokens: 1024,
+      },
+    };
+
+    const expectedOutput = {
+      struct_val: {
+        messages: {
+          list_val: [
+            {
+              struct_val: {
+                content: {
+                  string_val: ['hi'],
+                },
+                author: {
+                  string_val: ['user'],
+                },
+              },
+            },
+          ],
+        },
+        context: {
+          string_val: ['context'],
+        },
+        examples: {
+          list_val: [
+            {
+              struct_val: {
+                input: {
+                  struct_val: {
+                    author: {
+                      string_val: ['user'],
+                    },
+                    content: {
+                      string_val: ['user input'],
+                    },
+                  },
+                },
+                output: {
+                  struct_val: {
+                    author: {
+                      string_val: ['bot'],
+                    },
+                    content: {
+                      string_val: ['bot output'],
+                    },
+                  },
+                },
+              },
+            },
+          ],
+        },
+        parameters: {
+          struct_val: {
+            temperature: {
+              float_val: 0.2,
+            },
+            topP: {
+              float_val: 0.8,
+            },
+            topK: {
+              int_val: 40,
+            },
+            maxOutputTokens: {
+              int_val: 1024,
+            },
+          },
+        },
+      },
+    };
+
+    const result = formatGoogleInputs(input);
+    expect(JSON.stringify(result)).toEqual(JSON.stringify(expectedOutput));
+  });
+
+  it('formats real payload parts', () => {
+    const input = {
+      instances: [
+        {
+          context: 'context',
+          examples: [
+            {
+              input: {
+                author: 'user',
+                content: 'user input',
+              },
+              output: {
+                author: 'bot',
+                content: 'user output',
+              },
+            },
+          ],
+          messages: [
+            {
+              author: 'user',
+              content: 'hi',
+            },
+          ],
+        },
+      ],
+      parameters: {
+        candidateCount: 1,
+        maxOutputTokens: 1024,
+        temperature: 0.2,
+        topP: 0.8,
+        topK: 40,
+      },
+    };
+    const expectedOutput = {
+      struct_val: {
+        instances: {
+          list_val: [
+            {
+              struct_val: {
+                context: { string_val: ['context'] },
+                examples: {
+                  list_val: [
+                    {
+                      struct_val: {
+                        input: {
+                          struct_val: {
+                            author: { string_val: ['user'] },
+                            content: { string_val: ['user input'] },
+                          },
+                        },
+                        output: {
+                          struct_val: {
+                            author: { string_val: ['bot'] },
+                            content: { string_val: ['user output'] },
+                          },
+                        },
+                      },
+                    },
+                  ],
+                },
+                messages: {
+                  list_val: [
+                    {
+                      struct_val: {
+                        author: { string_val: ['user'] },
+                        content: { string_val: ['hi'] },
+                      },
+                    },
+                  ],
+                },
+              },
+            },
+          ],
+        },
+        parameters: {
+          struct_val: {
+            candidateCount: { int_val: 1 },
+            maxOutputTokens: { int_val: 1024 },
+            temperature: { float_val: 0.2 },
+            topP: { float_val: 0.8 },
+            topK: { int_val: 40 },
+          },
+        },
+      },
+    };
+
+    const result = formatGoogleInputs(input);
+    expect(JSON.stringify(result)).toEqual(JSON.stringify(expectedOutput));
+  });
+
+  it('helps create valid payload parts', () => {
+    const instances = {
+      context: 'context',
+      examples: [
+        {
+          input: {
+            author: 'user',
+            content: 'user input',
+          },
+          output: {
+            author: 'bot',
+            content: 'user output',
+          },
+        },
+      ],
+      messages: [
+        {
+          author: 'user',
+          content: 'hi',
+        },
+      ],
+    };
+
+    const expectedInstances = {
+      struct_val: {
+        context: { string_val: ['context'] },
+        examples: {
+          list_val: [
+            {
+              struct_val: {
+                input: {
+                  struct_val: {
+                    author: { string_val: ['user'] },
+                    content: { string_val: ['user input'] },
+                  },
+                },
+                output: {
+                  struct_val: {
+                    author: { string_val: ['bot'] },
+                    content: { string_val: ['user output'] },
+                  },
+                },
+              },
+            },
+          ],
+        },
+        messages: {
+          list_val: [
+            {
+              struct_val: {
+                author: { string_val: ['user'] },
+                content: { string_val: ['hi'] },
+              },
+            },
+          ],
+        },
+      },
+    };
+
+    const parameters = {
+      candidateCount: 1,
+      maxOutputTokens: 1024,
+      temperature: 0.2,
+      topP: 0.8,
+      topK: 40,
+    };
+    const expectedParameters = {
+      struct_val: {
+        candidateCount: { int_val: 1 },
+        maxOutputTokens: { int_val: 1024 },
+        temperature: { float_val: 0.2 },
+        topP: { float_val: 0.8 },
+        topK: { int_val: 40 },
+      },
+    };
+
+    const instancesResult = formatGoogleInputs(instances);
+    const parametersResult = formatGoogleInputs(parameters);
+    expect(JSON.stringify(instancesResult)).toEqual(JSON.stringify(expectedInstances));
+    expect(JSON.stringify(parametersResult)).toEqual(JSON.stringify(expectedParameters));
+  });
+});